# Files
# 2025-09-11 23:13:58 +01:00
#
# 151 lines
# 4.9 KiB
# Python

"""
Base AI Client with common MCP integration functionality
"""
import json
import os
import sys
from typing import Any, Dict, List, Optional
from abc import ABC
# Put the project root on sys.path so the top-level ``config`` module can be
# imported. The root is computed as three directory levels above this file.
project_root = os.path.dirname(
    os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))
    )
)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# ``config`` is optional: record whether it could be imported so later code
# can check CONFIG_AVAILABLE before touching ``Config``.
try:
    from config import Config
except ImportError:
    CONFIG_AVAILABLE = False
else:
    CONFIG_AVAILABLE = True
from ..core.interfaces import IAIClient, IMCPClient
class BaseAIClient(IAIClient, ABC):
    """Base class for AI clients with MCP integration.

    Provides API-key resolution, one-time initialization, a generic
    "ask the model -> run requested tools -> ask again" flow, and cleanup.

    Subclasses customize behavior through the hook methods
    ``_initialize_client``, ``_format_tools_for_provider`` and
    ``_cleanup_client``. ``process_with_tools`` also calls
    ``self.chat_completion``, which is not defined in this class
    (presumably declared by ``IAIClient`` / implemented by subclasses).
    """

    def __init__(self, model_name: str, provider: str, api_key: Optional[str] = None, **kwargs):
        """Store client settings and resolve the API key.

        Args:
            model_name: Name of the model, exposed via ``model_name``.
            provider: Provider identifier used to look up the key in config
                (the keys known here are "openai", "claude" and "grok").
            api_key: Explicit API key. When ``None`` and the ``config``
                module was importable, the key is read from ``Config``.
            **kwargs: Extra settings, stored untouched in ``_extra_config``.

        Raises:
            ValueError: If no non-empty API key could be determined.
        """
        self._model_name = model_name
        self._provider = provider
        self._client = None
        self._initialized = False
        self._extra_config = kwargs

        # Only fall back to config when the caller passed no key at all.
        if api_key is None and CONFIG_AVAILABLE:
            api_key = self._get_api_key_from_config()
        if not api_key:
            raise ValueError(f"API key not provided and could not be loaded from config for provider: {provider}")
        self._api_key = api_key

    def _get_api_key_from_config(self) -> Optional[str]:
        """Return the API key configured for ``self._provider``, if any.

        Returns:
            The matching ``Config`` attribute value, or ``None`` when config
            is unavailable, the provider is unknown, or ``Config`` does not
            define the attribute.
        """
        if not CONFIG_AVAILABLE:
            return None

        # Map to attribute *names* and resolve lazily: reading every key
        # eagerly would raise AttributeError for an unrelated provider whose
        # key is simply not defined on Config.
        attr_by_provider = {
            "openai": "OPENAI_API_KEY",
            "claude": "CLAUDE_API_KEY",
            "grok": "GROK_API_KEY",
        }
        attr_name = attr_by_provider.get(self._provider)
        if attr_name is None:
            return None
        return getattr(Config, attr_name, None)

    @property
    def model_name(self) -> str:
        """Get the AI model name."""
        return self._model_name

    async def initialize(self) -> None:
        """Initialize the client once.

        Delegates the provider-specific work to ``_initialize_client`` and
        marks the instance as initialized; later calls return immediately.
        """
        if self._initialized:
            return
        await self._initialize_client()
        self._initialized = True

    async def _initialize_client(self) -> None:
        """Initialize the specific AI client - to be implemented by subclasses."""
        pass

    async def process_with_tools(
        self,
        query: str,
        available_tools: List[Dict[str, Any]],
        mcp_client: IMCPClient
    ) -> str:
        """Process a query with MCP tools using a common pattern.

        The model is asked once with tools enabled. If it requests tool
        calls, each one is executed through ``mcp_client`` and the results
        are sent back in a second request with tool calling disabled.

        Args:
            query: User query, sent as the single initial user message.
            available_tools: Tool descriptions, passed through
                ``_format_tools_for_provider`` before being sent.
            mcp_client: Client whose ``call_tool(name, args)`` runs a tool.

        Returns:
            The ``content`` of the final assistant message.
        """
        formatted_tools = self._format_tools_for_provider(available_tools)
        messages = [{"role": "user", "content": query}]

        # First round: let the model decide whether to call tools.
        response = await self.chat_completion(
            messages=messages,
            tools=formatted_tools,
            tool_choice="auto"
        )
        assistant_message = response["choices"][0]["message"]

        tool_calls = assistant_message.get("tool_calls")
        if not tool_calls:
            # No tools called, return the direct response.
            return assistant_message["content"]

        # The assistant message carrying the tool calls must precede the
        # tool results in the conversation.
        messages.append(assistant_message)
        for tool_call in tool_calls:
            messages.append(await self._run_tool_call(tool_call, mcp_client))

        # Second round: get the final answer from the tool results.
        final_response = await self.chat_completion(
            messages=messages,
            tools=formatted_tools,
            tool_choice="none"  # Don't allow more tool calls
        )
        return final_response["choices"][0]["message"]["content"]

    async def _run_tool_call(self, tool_call: Dict[str, Any], mcp_client: IMCPClient) -> Dict[str, Any]:
        """Execute one tool call and build the ``tool`` role reply message.

        Any failure while decoding arguments or running the tool is reported
        back in the message content rather than raised, so the remaining
        tool calls and the final model request still happen.
        """
        try:
            function = tool_call["function"]
            tool_name = function["name"]
            tool_args = self._parse_tool_arguments(function.get("arguments"))
            tool_result = await mcp_client.call_tool(tool_name, tool_args)
            content = str(tool_result)
        except Exception as e:
            # Deliberately broad: the error text is handed to the model as
            # the tool's output instead of aborting the whole query.
            content = f"Error calling tool: {str(e)}"
        return {
            "role": "tool",
            "tool_call_id": tool_call["id"],
            "content": content,
        }

    @staticmethod
    def _parse_tool_arguments(raw_arguments: Any) -> Any:
        """Decode the arguments of a tool call.

        Missing or blank arguments are treated as "no arguments" (``{}``),
        an already-decoded dict is returned as is, and anything else is
        parsed as JSON.
        """
        if raw_arguments is None:
            return {}
        if isinstance(raw_arguments, dict):
            return raw_arguments
        if isinstance(raw_arguments, str) and not raw_arguments.strip():
            return {}
        return json.loads(raw_arguments)

    def _format_tools_for_provider(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Format tools for the specific AI provider - to be implemented by subclasses.

        The base implementation returns ``tools`` unchanged.
        """
        return tools

    async def cleanup(self) -> None:
        """Clean up resources.

        Calls ``_cleanup_client`` when a client object is present, then
        marks the instance as not initialized.
        """
        # Identity check: a client object that happens to be falsy must
        # still be cleaned up.
        if self._client is not None:
            await self._cleanup_client()
        self._initialized = False

    async def _cleanup_client(self) -> None:
        """Clean up the specific AI client - to be implemented by subclasses."""
        pass