Files
ds_scp_task_solution/app/utils/token_counter.py
T
Aherobo Ovie Victor 0e3e22e8cb Initial commit
2025-07-17 22:20:25 +01:00

81 lines
2.6 KiB
Python

"""
Token counting utilities for document processing.
"""
import tiktoken
from typing import Dict, List, Optional, Union
from loguru import logger
# Default model name used by count_tokens and truncate_by_tokens when the
# caller does not pass one; tiktoken maps this name to its token encoding.
DEFAULT_MODEL = "gpt-4o"
def count_tokens(text: str, model: str = DEFAULT_MODEL) -> int:
    """
    Count the number of tokens in a text string using tiktoken.

    Tries the model's own encoding first, then the cl100k_base encoding, and
    finally a rough 4-characters-per-token estimate, so tokenizer failures are
    logged rather than propagated to the caller.

    Args:
        text: The text to count tokens for
        model: The model to use for token counting (default: gpt-4o)

    Returns:
        Number of tokens in the text (an approximation if both tiktoken
        encodings fail)
    """
    # The broad excepts are deliberate: this is a best-effort counter, and
    # failures can come from resolving the model name (tiktoken raises
    # KeyError for unknown models), not only from encoding itself.
    try:
        encoding = tiktoken.encoding_for_model(model)
        # disallowed_special=() makes tiktoken treat strings such as
        # "<|endoftext|>" as ordinary text. With the default it raises
        # ValueError, which would push any document containing them all the
        # way down to the crude character estimate.
        return len(encoding.encode(text, disallowed_special=()))
    except Exception as e:
        logger.warning(f"Error counting tokens with model {model}: {str(e)}")

    # Fallback to cl100k_base encoding if model-specific encoding fails
    try:
        encoding = tiktoken.get_encoding("cl100k_base")
        return len(encoding.encode(text, disallowed_special=()))
    except Exception as e:
        logger.error(f"Error counting tokens with fallback encoding: {str(e)}")

    # If all else fails, use a rough approximation (4 chars per token)
    return len(text) // 4
def truncate_by_tokens(text: str, max_tokens: int, model: str = DEFAULT_MODEL) -> str:
    """
    Truncate text to fit within a maximum token count.

    When the text is too long, a "...(truncated)" marker is appended and the
    tokens it needs are reserved out of ``max_tokens``, so the marker does not
    push the result over budget. If ``max_tokens`` is too small to hold the
    marker, the bare prefix is returned without it.

    Args:
        text: The text to truncate
        max_tokens: Maximum number of tokens to allow; negative values are
            treated as 0
        model: The model to use for token counting (default: gpt-4o).
            Unknown model names fall back to the cl100k_base encoding,
            matching count_tokens.

    Returns:
        The original text if it already fits, otherwise a truncated version
        sized to max_tokens. Re-tokenizing the joined prefix + marker can, in
        rare cases, differ by a token at the seam. If tokenization fails
        entirely, a character-based approximation (4 chars per token) is used.
    """
    suffix = "...(truncated)"
    # Clamp: a negative slice bound would otherwise keep almost the whole
    # text (tokens[:-n] drops from the end instead of keeping nothing).
    budget = max(max_tokens, 0)

    try:
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            # Same fallback encoding count_tokens uses for unknown models.
            encoding = tiktoken.get_encoding("cl100k_base")

        # disallowed_special=() treats text like "<|endoftext|>" as ordinary
        # text instead of raising ValueError.
        tokens = encoding.encode(text, disallowed_special=())
        if len(tokens) <= budget:
            return text

        suffix_len = len(encoding.encode(suffix, disallowed_special=()))
        if budget < suffix_len:
            # No room for the marker; return as much text as fits.
            return encoding.decode(tokens[:budget], errors="ignore")

        # Reserve room for the marker. errors="ignore" drops a multi-byte
        # character split by the cut instead of emitting U+FFFD.
        kept = encoding.decode(tokens[:budget - suffix_len], errors="ignore")
        return kept + suffix
    except Exception as e:
        logger.warning(f"Error truncating by tokens with model {model}: {str(e)}")
        # Fallback to character-based truncation if token-based fails,
        # reserving the marker's characters out of the budget as above.
        approx_chars = budget * 4  # Rough approximation
        if len(text) <= approx_chars:
            return text
        if approx_chars < len(suffix):
            return text[:approx_chars]
        return text[:approx_chars - len(suffix)] + suffix
def estimate_tokens_from_chars(char_count: int) -> int:
    """
    Approximate a token count from a character count alone.

    Uses the rule of thumb of about four characters per token and floors the
    quotient, so fewer than four characters maps to 0 tokens.

    Args:
        char_count: Number of characters

    Returns:
        Estimated number of tokens (char_count floor-divided by 4)
    """
    chars_per_token = 4
    whole_tokens, _leftover_chars = divmod(char_count, chars_per_token)
    return whole_tokens