81 lines
2.6 KiB
Python
81 lines
2.6 KiB
Python
|
|
"""
|
||
|
|
Token counting utilities for document processing.
|
||
|
|
"""
|
||
|
|
import tiktoken
|
||
|
|
from typing import Dict, List, Optional, Union
|
||
|
|
from loguru import logger
|
||
|
|
|
||
|
|
# Model whose tokenizer is used by default when counting or truncating tokens.
DEFAULT_MODEL: str = "gpt-4o"
|
||
|
|
|
||
|
|
def count_tokens(text: str, model: str = DEFAULT_MODEL) -> int:
    """
    Count the number of tokens in a text string using tiktoken.

    The encoding registered for ``model`` is tried first. If that fails, the
    ``cl100k_base`` encoding is used. If tiktoken fails entirely, a rough
    ``len(text) // 4`` estimate is returned. Failures are logged, never raised.

    Special-token markup appearing in the text (e.g. ``"<|endoftext|>"``) is
    tokenized as ordinary text rather than rejected.

    Args:
        text: The text to count tokens for
        model: The model to use for token counting (default: gpt-4o)

    Returns:
        Number of tokens in the text (only an approximation when both
        tiktoken attempts fail)
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
        # encode_ordinary is encode(text, disallowed_special=()): plain encode()
        # raises ValueError on special-token text, which used to push such
        # documents all the way down to the len // 4 estimate.
        return len(encoding.encode_ordinary(text))
    except Exception as e:
        logger.warning(f"Error counting tokens with model {model}: {str(e)}")

    # Fallback to cl100k_base encoding if model-specific encoding fails
    try:
        encoding = tiktoken.get_encoding("cl100k_base")
        return len(encoding.encode_ordinary(text))
    except Exception as e:
        logger.error(f"Error counting tokens with fallback encoding: {str(e)}")

    # If all else fails, use a rough approximation (4 chars per token)
    return len(text) // 4
|
||
|
|
|
||
|
|
def truncate_by_tokens(text: str, max_tokens: int, model: str = DEFAULT_MODEL) -> str:
    """
    Truncate text to fit within a maximum token count.

    Text that already fits is returned unchanged. Otherwise the text is cut at
    a token boundary and "...(truncated)" is appended. Room for that indicator
    is reserved inside the budget, so kept-prefix tokens plus indicator tokens
    do not exceed ``max_tokens``. Re-tokenizing the joined string can differ
    slightly at the seam. If the budget cannot even hold the indicator, the
    bare token prefix is returned without it.

    If tiktoken cannot be used, the same scheme is applied to characters,
    assuming roughly 4 characters per token.

    Args:
        text: The text to truncate
        max_tokens: Maximum number of tokens to allow; negative values are
            treated as 0
        model: The model to use for token counting (default: gpt-4o)

    Returns:
        Truncated text sized to fit within max_tokens
    """
    suffix = "...(truncated)"
    # A negative budget would otherwise slice from the end (seq[:-n]).
    budget = max(max_tokens, 0)
    try:
        encoding = tiktoken.encoding_for_model(model)
        # encode_ordinary tokenizes special-token text such as "<|endoftext|>"
        # as plain text; plain encode() raises ValueError on it.
        tokens = encoding.encode_ordinary(text)

        if len(tokens) <= budget:
            return text

        suffix_tokens = len(encoding.encode_ordinary(suffix))
        if suffix_tokens >= budget:
            # No room for the indicator: return as much of the text as fits.
            return encoding.decode(tokens[:budget], errors="ignore")

        # errors="ignore" drops a multi-byte character split by the cut instead
        # of emitting a U+FFFD replacement character at the end.
        truncated_text = encoding.decode(
            tokens[: budget - suffix_tokens], errors="ignore"
        )

        # Add truncation indicator
        return truncated_text + suffix
    except Exception as e:
        logger.warning(f"Error truncating by tokens with model {model}: {str(e)}")
        # Fallback to character-based truncation if token-based fails
        approx_chars = budget * 4  # Rough approximation
        if len(text) <= approx_chars:
            return text
        if len(suffix) >= approx_chars:
            return text[:approx_chars]
        return text[: approx_chars - len(suffix)] + suffix
|
||
|
|
|
||
|
|
def estimate_tokens_from_chars(char_count: int) -> int:
    """
    Approximate a token count from a character count.

    Applies the rule of thumb of about four characters per token and rounds
    down, so the result is only a rough estimate.

    Args:
        char_count: Number of characters

    Returns:
        Estimated number of tokens (floor of ``char_count / 4``)
    """
    chars_per_token = 4
    estimate, _leftover = divmod(char_count, chars_per_token)
    return estimate
|