106 lines
3.4 KiB
Python
106 lines
3.4 KiB
Python
"""
|
|
OpenAI Client Implementation
|
|
"""
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
try:
|
|
from openai import AsyncOpenAI
|
|
OPENAI_AVAILABLE = True
|
|
except ImportError:
|
|
OPENAI_AVAILABLE = False
|
|
|
|
from .base_client import BaseAIClient
|
|
from config import Config
|
|
|
|
class OpenAIClient(BaseAIClient):
    """OpenAI client with MCP integration.

    Wraps ``openai.AsyncOpenAI`` behind the ``BaseAIClient`` interface:
    chat-completion responses are flattened into plain dictionaries and
    MCP-style tool descriptions (``name`` / ``description`` /
    ``inputSchema``) are converted into OpenAI function tools.
    """

    def __init__(
        self,
        model_name: str = "gpt-4o",
        api_key: Optional[str] = None,
        **kwargs
    ):
        """Store the OpenAI-specific settings for this client.

        Args:
            model_name: Model identifier forwarded to the base class.
            api_key: Explicit API key; when falsy, ``Config.OPENAI_API_KEY``
                is used instead.
            **kwargs: Forwarded unchanged to ``BaseAIClient.__init__``.
                ``temperature`` (default 0.7) and ``max_tokens``
                (default 1000) are additionally read from here.

        Raises:
            ImportError: If the ``openai`` package could not be imported.
        """
        if not OPENAI_AVAILABLE:
            raise ImportError("OpenAI package not installed. Install with: pip install openai")

        super().__init__(model_name, "openai", api_key, **kwargs)

        # OpenAI specific configuration
        self._temperature = kwargs.get("temperature", 0.7)
        self._max_tokens = kwargs.get("max_tokens", 1000)
        self._api_key = api_key or Config.OPENAI_API_KEY

    async def _initialize_client(self) -> None:
        """Create the ``AsyncOpenAI`` client using the resolved API key."""
        self._client = AsyncOpenAI(api_key=self._api_key)

    async def chat_completion(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Perform an OpenAI chat completion.

        Args:
            messages: Chat messages passed to the API as-is.
            tools: Optional tool definitions, already in OpenAI format.
                When non-empty they are sent together with ``tool_choice``.
            **kwargs: Only ``tool_choice`` is read (default ``"auto"``),
                and only when ``tools`` is non-empty.

        Returns:
            ``{"choices": [{"message": {...}}, ...]}`` where each message
            holds ``role``, ``content`` and ``tool_calls`` (a list of
            dictionaries, or ``None`` when the model requested no tools).
        """
        # NOTE(review): ``_initialized``, ``initialize()`` and ``_model_name``
        # are presumably provided by BaseAIClient -- not visible in this file.
        if not self._initialized:
            await self.initialize()

        # Prepare request parameters
        request_params: Dict[str, Any] = {
            "model": self._model_name,
            "messages": messages,
            "temperature": self._temperature,
            "max_tokens": self._max_tokens,
        }

        # Add tools if provided
        if tools:
            request_params["tools"] = tools
            request_params["tool_choice"] = kwargs.get("tool_choice", "auto")

        # Make the API call
        response = await self._client.chat.completions.create(**request_params)

        # Convert to standard format
        return {
            "choices": [
                {"message": self._openai_message_to_dict(choice.message)}
                for choice in response.choices
            ]
        }

    @staticmethod
    def _openai_tool_call_to_dict(tool_call: Any) -> Dict[str, Any]:
        """Convert one SDK tool-call object into a plain dictionary."""
        return {
            "id": tool_call.id,
            # The API requires ``type`` on tool calls when an assistant
            # message is sent back in a follow-up request, so keep it.
            "type": getattr(tool_call, "type", "function"),
            "function": {
                "name": tool_call.function.name,
                "arguments": tool_call.function.arguments,
            },
        }

    @staticmethod
    def _openai_message_to_dict(message: Any) -> Dict[str, Any]:
        """Convert one SDK response message into a plain dictionary."""
        tool_calls = message.tool_calls
        return {
            "role": message.role,
            "content": message.content,
            # ``None`` (not an empty list) when there are no tool calls.
            "tool_calls": [
                OpenAIClient._openai_tool_call_to_dict(tool_call)
                for tool_call in tool_calls
            ] if tool_calls else None,
        }

    def _format_tools_for_provider(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Format tools for OpenAI's expected format.

        Args:
            tools: Tool dictionaries with a ``name`` and, optionally,
                ``description`` and ``inputSchema`` entries.

        Returns:
            One ``{"type": "function", "function": {...}}`` entry per tool,
            in the same order.

        Raises:
            KeyError: If a tool has no ``name``.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": tool["name"],
                    # A missing or None description becomes an empty string
                    # instead of raising KeyError.
                    "description": tool.get("description") or "",
                    # A missing schema becomes "no parameters".
                    "parameters": tool.get("inputSchema")
                    or {"type": "object", "properties": {}},
                },
            }
            for tool in tools
        ]

    async def _cleanup_client(self) -> None:
        """Close the OpenAI client if one exists."""
        # getattr: avoid AttributeError if ``_client`` was never assigned.
        client = getattr(self, "_client", None)
        if client:
            await client.close()
|