Initial commit for deployment
This commit is contained in:
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Model parameters for AI models.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
|
||||
class ModelParameters(BaseModel):
    """Parameters for AI model generation.

    Every field is optional. ``to_dict`` flattens the populated fields
    (together with ``extra_params``) into a plain dictionary, and
    ``for_provider`` adapts that dictionary to the parameter names used by a
    specific provider.
    """

    # Basic parameters
    temperature: Optional[float] = Field(
        0.7,
        description="Controls randomness: 0 is deterministic, higher values are more random",
        ge=0.0,
        le=2.0
    )

    max_tokens: Optional[int] = Field(
        1000,
        description="Maximum number of tokens to generate",
        gt=0
    )

    # Sampling parameters
    top_p: Optional[float] = Field(
        1.0,
        description="Nucleus sampling: consider tokens with top_p probability mass",
        ge=0.0,
        le=1.0
    )

    top_k: Optional[int] = Field(
        None,
        description="Only sample from the top k tokens",
        gt=0
    )

    # Repetition control
    frequency_penalty: Optional[float] = Field(
        0.0,
        description="Penalizes repeated tokens",
        ge=-2.0,
        le=2.0
    )

    presence_penalty: Optional[float] = Field(
        0.0,
        description="Penalizes repeated topics",
        ge=-2.0,
        le=2.0
    )

    # Advanced parameters
    stop_sequences: Optional[List[str]] = Field(
        None,
        description="Sequences where the API will stop generating"
    )

    min_p: Optional[float] = Field(
        None,
        description="Minimum probability threshold for token selection",
        ge=0.0,
        le=1.0
    )

    repeat_penalty: Optional[float] = Field(
        None,
        description="Penalty for repeating tokens",
        ge=0.0
    )

    presence_penalty_tokens: Optional[int] = Field(
        None,
        description="Number of tokens to consider for presence penalty",
        gt=0
    )

    # System prompt
    system_prompt: Optional[str] = Field(
        None,
        description="System prompt to guide the model's behavior"
    )

    # Function calling
    function_calling: Optional[bool] = Field(
        None,
        description="Whether to enable function calling"
    )

    # Additional parameters that might be model-specific
    extra_params: Optional[Dict[str, Any]] = Field(
        None,
        description="Additional model-specific parameters"
    )

    # NOTE: min_p and repeat_penalty are not listed in this validator, so they
    # only receive the field's own validation.
    @validator('temperature', 'top_p', 'frequency_penalty', 'presence_penalty', pre=True)
    def validate_float_params(cls, v):
        """Coerce a raw value to ``float`` before field validation runs.

        ``None`` and ``bool`` values are returned untouched and left to the
        field's own validation.

        Args:
            v: Raw value supplied for the field.

        Returns:
            ``float(v)``, or ``v`` itself when it is ``None`` or a ``bool``.

        Raises:
            ValueError: If ``v`` cannot be converted to a float.
        """
        if v is None or isinstance(v, bool):
            return v
        try:
            return float(v)
        except (TypeError, OverflowError) as exc:
            # Only ValueError/AssertionError are reliably turned into a
            # validation error by pydantic; surface every conversion failure
            # as ValueError so callers never see a raw TypeError/OverflowError.
            raise ValueError(str(exc)) from exc

    # NOTE: presence_penalty_tokens is not listed in this validator, so it
    # only receives the field's own validation.
    @validator('max_tokens', 'top_k', pre=True)
    def validate_int_params(cls, v):
        """Coerce a raw value to ``int`` before field validation runs.

        ``None`` and ``bool`` values are returned untouched and left to the
        field's own validation. Because ``int()`` is used, a float input is
        truncated toward zero.

        Args:
            v: Raw value supplied for the field.

        Returns:
            ``int(v)``, or ``v`` itself when it is ``None`` or a ``bool``.

        Raises:
            ValueError: If ``v`` cannot be converted to an int.
        """
        if v is None or isinstance(v, bool):
            return v
        try:
            return int(v)
        except (TypeError, OverflowError) as exc:
            # int(float('inf')) raises OverflowError and int([]) raises
            # TypeError; report both as ValueError (see validate_float_params).
            raise ValueError(str(exc)) from exc

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert parameters to a dictionary, excluding None values.

        The ``extra_params`` field itself is never emitted as a key; instead
        its entries are merged into the result last, so an extra parameter
        overrides a regular field of the same name. Values inside
        ``extra_params`` are copied as-is (``None`` values included).

        Returns:
            Dictionary of parameters.
        """
        result = {
            key: value
            for key, value in self.dict().items()
            if value is not None and key != 'extra_params'
        }

        # Add any extra parameters
        if self.extra_params:
            result.update(self.extra_params)

        return result

    def for_provider(self, provider: str) -> Dict[str, Any]:
        """
        Get parameters formatted for a specific provider.

        The provider name is matched case-insensitively and ignoring
        surrounding whitespace. Unrecognised providers (and 'cohere', which
        needs no conversion) receive the output of ``to_dict`` unchanged.

        Args:
            provider: Provider name (e.g., 'openai', 'ollama', 'anthropic').

        Returns:
            Dictionary of parameters formatted for the provider.
        """
        params = self.to_dict()

        # Normalise so that 'OpenAI' or ' ollama ' select the right branch;
        # a non-string provider simply matches no branch.
        name = provider.strip().lower() if isinstance(provider, str) else provider

        # Handle provider-specific parameter naming
        if name == 'openai':
            # OpenAI uses 'stop' instead of 'stop_sequences'
            if 'stop_sequences' in params:
                params['stop'] = params.pop('stop_sequences')

        elif name == 'ollama':
            # Keep only the parameters Ollama understands. min_p and
            # repeat_penalty are documented Ollama options, so they are kept.
            # NOTE: this filter also discards keys that came from extra_params.
            # NOTE(review): Ollama's native API calls the token limit
            # 'num_predict'; 'max_tokens' is passed through unrenamed here --
            # confirm the caller performs that mapping.
            supported = {
                'temperature',
                'top_p',
                'top_k',
                'min_p',
                'repeat_penalty',
                'max_tokens',
                'stop_sequences',
            }
            params = {k: v for k, v in params.items() if k in supported}

            # Rename stop_sequences to stop if present
            if 'stop_sequences' in params:
                params['stop'] = params.pop('stop_sequences')

        elif name == 'anthropic':
            # Anthropic uses 'stop_sequences' and different temperature scaling
            if 'temperature' in params:
                # Anthropic's temperature is typically 0-1, so clamp the upper end.
                params['temperature'] = min(params['temperature'], 1.0)

        # Cohere uses 'stop_sequences' as-is, so no conversion is needed.
        # Add more provider-specific conversions as needed

        return params
|
||||
Reference in New Issue
Block a user