Files
2025-05-09 15:41:16 +01:00

171 lines
4.8 KiB
Python

"""
Model parameters for AI models.
"""
from typing import Dict, Any, Optional, List
from pydantic import BaseModel, Field, validator
class ModelParameters(BaseModel):
    """Parameters for AI model generation.

    Holds a provider-neutral set of generation settings. ``to_dict`` returns
    the values that are set (non-None) as a flat dict, and ``for_provider``
    returns that dict adapted to a given provider's parameter names.
    """

    # Basic parameters
    temperature: Optional[float] = Field(
        0.7,
        description="Controls randomness: 0 is deterministic, higher values are more random",
        ge=0.0,
        le=2.0,
    )
    max_tokens: Optional[int] = Field(
        1000,
        description="Maximum number of tokens to generate",
        gt=0,
    )

    # Sampling parameters
    top_p: Optional[float] = Field(
        1.0,
        description="Nucleus sampling: consider tokens with top_p probability mass",
        ge=0.0,
        le=1.0,
    )
    top_k: Optional[int] = Field(
        None,
        description="Only sample from the top k tokens",
        gt=0,
    )

    # Repetition control
    frequency_penalty: Optional[float] = Field(
        0.0,
        description="Penalizes repeated tokens",
        ge=-2.0,
        le=2.0,
    )
    presence_penalty: Optional[float] = Field(
        0.0,
        description="Penalizes repeated topics",
        ge=-2.0,
        le=2.0,
    )

    # Advanced parameters
    stop_sequences: Optional[List[str]] = Field(
        None,
        description="Sequences where the API will stop generating",
    )
    min_p: Optional[float] = Field(
        None,
        description="Minimum probability threshold for token selection",
        ge=0.0,
        le=1.0,
    )
    repeat_penalty: Optional[float] = Field(
        None,
        description="Penalty for repeating tokens",
        ge=0.0,
    )
    presence_penalty_tokens: Optional[int] = Field(
        None,
        description="Number of tokens to consider for presence penalty",
        gt=0,
    )

    # System prompt
    system_prompt: Optional[str] = Field(
        None,
        description="System prompt to guide the model's behavior",
    )

    # Function calling
    function_calling: Optional[bool] = Field(
        None,
        description="Whether to enable function calling",
    )

    # Additional parameters that might be model-specific. These are merged
    # verbatim into the output of to_dict()/for_provider() and may override
    # the core fields above.
    extra_params: Optional[Dict[str, Any]] = Field(
        None,
        description="Additional model-specific parameters",
    )

    @validator('temperature', 'top_p', 'frequency_penalty', 'presence_penalty', pre=True)
    def validate_float_params(cls, v):
        """Coerce float-typed parameters to ``float`` before field validation.

        ``None`` and booleans are returned unchanged. Booleans are excluded
        because ``bool`` is a subclass of ``int`` and would otherwise be
        turned into 1.0/0.0 here. Non-numeric input makes ``float()`` raise
        ``ValueError``, which pydantic reports as a validation error.
        """
        if v is not None and not isinstance(v, bool):
            return float(v)
        return v

    @validator('max_tokens', 'top_k', pre=True)
    def validate_int_params(cls, v):
        """Coerce integer-typed parameters to ``int`` before field validation.

        ``None`` and booleans are returned unchanged (see
        ``validate_float_params`` for the rationale). Note that ``int()``
        truncates floats, e.g. ``3.7`` becomes ``3``.
        """
        if v is not None and not isinstance(v, bool):
            return int(v)
        return v

    def _core_params(self) -> Dict[str, Any]:
        """Return the declared fields that are not None, excluding ``extra_params``."""
        return {
            key: value
            for key, value in self.dict().items()
            if value is not None and key != 'extra_params'
        }

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert parameters to a dictionary, excluding None values.

        The contents of ``extra_params`` (if any) are merged in last, so they
        override core fields of the same name.

        Returns:
            Dictionary of parameters.
        """
        result = self._core_params()
        if self.extra_params:
            result.update(self.extra_params)
        return result

    def for_provider(self, provider: str) -> Dict[str, Any]:
        """
        Get parameters formatted for a specific provider.

        Provider matching is case-insensitive. Unknown providers (and
        'cohere') receive the same dict as ``to_dict()``.

        - openai: ``stop_sequences`` is renamed to ``stop``.
        - ollama: only temperature, top_p, top_k, max_tokens and
          stop_sequences are kept from the core fields; ``extra_params`` is
          then merged in unfiltered; ``stop_sequences`` is renamed to ``stop``.
        - anthropic: ``temperature`` is capped at 1.0.

        Args:
            provider: Provider name (e.g., 'openai', 'ollama', 'anthropic').

        Returns:
            Dictionary of parameters formatted for the provider.
        """
        provider = provider.lower()

        if provider == 'ollama':
            # Whitelist only the core fields. extra_params exists precisely to
            # carry model-specific options, so it must not be filtered away.
            # NOTE(review): Ollama's native API names the token limit
            # `num_predict`; confirm the caller maps `max_tokens` accordingly.
            ollama_keys = {'temperature', 'top_p', 'top_k', 'max_tokens', 'stop_sequences'}
            params = {
                key: value
                for key, value in self._core_params().items()
                if key in ollama_keys
            }
            if self.extra_params:
                params.update(self.extra_params)
        else:
            params = self.to_dict()

        if provider in ('openai', 'ollama'):
            # These providers name the stop list 'stop' instead of 'stop_sequences'.
            if 'stop_sequences' in params:
                params['stop'] = params.pop('stop_sequences')
        elif provider == 'anthropic':
            # Anthropic keeps 'stop_sequences'; its temperature range is 0-1.
            if 'temperature' in params:
                params['temperature'] = min(params['temperature'], 1.0)
        # 'cohere' uses 'stop_sequences' as-is; add more conversions as needed.
        return params