171 lines
4.8 KiB
Python
171 lines
4.8 KiB
Python
|
|
"""
|
||
|
|
Model parameters for AI models.
|
||
|
|
"""
|
||
|
|
|
||
|
|
from typing import Dict, Any, Optional, List
|
||
|
|
from pydantic import BaseModel, Field, validator
|
||
|
|
|
||
|
|
|
||
|
|
class ModelParameters(BaseModel):
    """Parameters for AI model generation.

    Every field is optional. ``to_dict`` drops fields whose value is ``None``
    and merges ``extra_params`` into the top level; ``for_provider`` then
    adapts that dictionary to a specific provider's naming/value conventions.
    """

    # Basic parameters
    temperature: Optional[float] = Field(
        0.7,
        description="Controls randomness: 0 is deterministic, higher values are more random",
        ge=0.0,
        le=2.0,
    )
    max_tokens: Optional[int] = Field(
        1000,
        description="Maximum number of tokens to generate",
        gt=0,
    )

    # Sampling parameters
    top_p: Optional[float] = Field(
        1.0,
        description="Nucleus sampling: consider tokens with top_p probability mass",
        ge=0.0,
        le=1.0,
    )
    top_k: Optional[int] = Field(
        None,
        description="Only sample from the top k tokens",
        gt=0,
    )

    # Repetition control
    frequency_penalty: Optional[float] = Field(
        0.0,
        description="Penalizes repeated tokens",
        ge=-2.0,
        le=2.0,
    )
    presence_penalty: Optional[float] = Field(
        0.0,
        description="Penalizes repeated topics",
        ge=-2.0,
        le=2.0,
    )

    # Advanced parameters
    stop_sequences: Optional[List[str]] = Field(
        None,
        description="Sequences where the API will stop generating",
    )
    min_p: Optional[float] = Field(
        None,
        description="Minimum probability threshold for token selection",
        ge=0.0,
        le=1.0,
    )
    repeat_penalty: Optional[float] = Field(
        None,
        description="Penalty for repeating tokens",
        ge=0.0,
    )
    presence_penalty_tokens: Optional[int] = Field(
        None,
        description="Number of tokens to consider for presence penalty",
        gt=0,
    )

    # System prompt
    system_prompt: Optional[str] = Field(
        None,
        description="System prompt to guide the model's behavior",
    )

    # Function calling
    function_calling: Optional[bool] = Field(
        None,
        description="Whether to enable function calling",
    )

    # Additional parameters that might be model-specific; merged into the
    # top level of to_dict()/for_provider() output rather than nested.
    extra_params: Optional[Dict[str, Any]] = Field(
        None,
        description="Additional model-specific parameters",
    )

    @validator(
        'temperature', 'top_p', 'frequency_penalty', 'presence_penalty',
        'min_p', 'repeat_penalty',
        pre=True,
    )
    def validate_float_params(cls, v):
        """Coerce float-valued parameters to ``float`` before range checks.

        ``None`` (field unset) and ``bool`` values are returned unchanged, so
        this validator never turns ``True``/``False`` into ``1.0``/``0.0``
        itself; they are left to pydantic's own type validation.
        """
        if v is None or isinstance(v, bool):  # Avoid converting bool to float
            return v
        return float(v)

    @validator('max_tokens', 'top_k', 'presence_penalty_tokens', pre=True)
    def validate_int_params(cls, v):
        """Coerce integer-valued parameters to ``int`` before range checks.

        ``None`` (field unset) and ``bool`` values are returned unchanged.

        Raises:
            ValueError: if ``v`` is a float that is not a whole number
                (e.g. ``1.5``, ``inf``, ``nan``). A bare ``int(v)`` would
                silently truncate the fraction (``0.5`` -> ``0``, which then
                fails ``gt=0`` with a misleading message) or raise
                ``OverflowError`` for infinity, which would likely escape as a
                raw exception instead of a validation error.
        """
        if v is None or isinstance(v, bool):  # Avoid converting bool to int
            return v
        if isinstance(v, float) and not v.is_integer():
            raise ValueError(f"expected a whole number, got {v!r}")
        return int(v)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert parameters to a dictionary, excluding None values.

        ``extra_params`` is not emitted as a nested key: its entries are merged
        into the top level and override any same-named standard field.

        Returns:
            Dictionary of parameters.
        """
        result = {
            key: value
            for key, value in self.dict().items()
            if value is not None and key != 'extra_params'
        }
        if self.extra_params:
            result.update(self.extra_params)
        return result

    def for_provider(self, provider: str) -> Dict[str, Any]:
        """
        Get parameters formatted for a specific provider.

        Args:
            provider: Provider name (e.g., 'openai', 'ollama', 'anthropic').
                Matched case-insensitively, ignoring surrounding whitespace.
                Unknown providers receive the unmodified ``to_dict()`` output.

        Returns:
            Dictionary of parameters formatted for the provider.
        """
        params = self.to_dict()
        name = str(provider).strip().lower()

        # Handle provider-specific parameter naming
        if name == 'openai':
            # OpenAI uses 'stop' instead of 'stop_sequences'
            if 'stop_sequences' in params:
                params['stop'] = params.pop('stop_sequences')

        elif name == 'ollama':
            # Remove parameters not supported by Ollama. min_p and
            # repeat_penalty are listed among Ollama's model options, so they
            # are kept when explicitly set (both default to None and are
            # otherwise absent). Keys the caller supplied via extra_params are
            # also kept: they were explicitly requested as model-specific
            # pass-through values and must not be silently discarded.
            keep = {
                'temperature', 'top_p', 'top_k', 'max_tokens',
                'stop_sequences', 'min_p', 'repeat_penalty',
            }
            keep.update(self.extra_params or {})
            params = {k: v for k, v in params.items() if k in keep}

            # Rename stop_sequences to stop if present
            if 'stop_sequences' in params:
                params['stop'] = params.pop('stop_sequences')

        elif name == 'anthropic':
            # Anthropic already uses 'stop_sequences'; its temperature range is
            # 0-1, narrower than the 0-2 range this model accepts, so clamp.
            if 'temperature' in params:
                params['temperature'] = min(params['temperature'], 1.0)

        elif name == 'cohere':
            # Cohere uses 'stop_sequences' and has some unique parameters
            pass

        # Add more provider-specific conversions as needed

        return params
|