Initial commit
This commit is contained in:
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Token counting utilities for document processing.
|
||||
"""
|
||||
import tiktoken
|
||||
from typing import Dict, List, Optional, Union
|
||||
from loguru import logger
|
||||
|
||||
# Default models to use for token counting
|
||||
DEFAULT_MODEL = "gpt-4o"
|
||||
|
||||
def count_tokens(text: str, model: str = DEFAULT_MODEL) -> int:
|
||||
"""
|
||||
Count the number of tokens in a text string using tiktoken.
|
||||
|
||||
Args:
|
||||
text: The text to count tokens for
|
||||
model: The model to use for token counting (default: gpt-4o)
|
||||
|
||||
Returns:
|
||||
Number of tokens in the text
|
||||
"""
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
return len(encoding.encode(text))
|
||||
except Exception as e:
|
||||
logger.warning(f"Error counting tokens with model {model}: {str(e)}")
|
||||
# Fallback to cl100k_base encoding if model-specific encoding fails
|
||||
try:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
return len(encoding.encode(text))
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting tokens with fallback encoding: {str(e)}")
|
||||
# If all else fails, use a rough approximation (4 chars per token)
|
||||
return len(text) // 4
|
||||
|
||||
def truncate_by_tokens(text: str, max_tokens: int, model: str = DEFAULT_MODEL) -> str:
|
||||
"""
|
||||
Truncate text to fit within a maximum token count.
|
||||
|
||||
Args:
|
||||
text: The text to truncate
|
||||
max_tokens: Maximum number of tokens to allow
|
||||
model: The model to use for token counting (default: gpt-4o)
|
||||
|
||||
Returns:
|
||||
Truncated text that fits within max_tokens
|
||||
"""
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
tokens = encoding.encode(text)
|
||||
|
||||
if len(tokens) <= max_tokens:
|
||||
return text
|
||||
|
||||
# Truncate tokens and decode
|
||||
truncated_tokens = tokens[:max_tokens]
|
||||
truncated_text = encoding.decode(truncated_tokens)
|
||||
|
||||
# Add truncation indicator
|
||||
return truncated_text + "...(truncated)"
|
||||
except Exception as e:
|
||||
logger.warning(f"Error truncating by tokens with model {model}: {str(e)}")
|
||||
# Fallback to character-based truncation if token-based fails
|
||||
approx_chars = max_tokens * 4 # Rough approximation
|
||||
if len(text) <= approx_chars:
|
||||
return text
|
||||
return text[:approx_chars] + "...(truncated)"
|
||||
|
||||
def estimate_tokens_from_chars(char_count: int) -> int:
|
||||
"""
|
||||
Estimate the number of tokens from character count.
|
||||
This is a rough approximation (4 chars per token on average).
|
||||
|
||||
Args:
|
||||
char_count: Number of characters
|
||||
|
||||
Returns:
|
||||
Estimated number of tokens
|
||||
"""
|
||||
return char_count // 4
|
||||
Reference in New Issue
Block a user