151 lines
4.9 KiB
Python
151 lines
4.9 KiB
Python
"""
Base AI client with common MCP integration functionality.
"""
|
|
import json
|
|
import os
|
|
import sys
|
|
from typing import Any, Dict, List, Optional
|
|
from abc import ABC
|
|
|
|
# Make the project root importable so the top-level ``config`` module can be
# found. The root is the grandparent of the directory containing this file;
# ``abspath`` normalizes the ``..`` components lexically.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
)
if project_root not in sys.path:
    # Prepend so the project's modules take precedence over same-named ones.
    sys.path.insert(0, project_root)
|
|
|
|
# The project-level ``config`` module is optional. CONFIG_AVAILABLE records
# whether ``Config`` was imported; ``Config`` is only defined when it is True.
CONFIG_AVAILABLE = False
try:
    from config import Config
except ImportError:
    pass  # deliberately optional: leave CONFIG_AVAILABLE as False
else:
    CONFIG_AVAILABLE = True
|
|
|
|
from ..core.interfaces import IAIClient, IMCPClient
|
|
|
|
|
|
class BaseAIClient(IAIClient, ABC):
    """Base class for AI clients with MCP integration.

    Subclasses supply the provider-specific pieces:

    * ``_initialize_client`` / ``_cleanup_client`` to create and release
      ``self._client``;
    * ``_format_tools_for_provider`` to convert MCP tool descriptions into the
      provider's tool schema (default: pass through unchanged);
    * ``chat_completion``, which ``process_with_tools`` calls but which is not
      defined in this class (presumably declared by ``IAIClient`` — confirm).
      Its result is indexed as ``response["choices"][0]["message"]``.
    """

    # Provider name (lower-case) -> attribute on ``config.Config`` that holds
    # that provider's API key. Subclasses may extend this mapping.
    _PROVIDER_KEY_ATTRS: Dict[str, str] = {
        "openai": "OPENAI_API_KEY",
        "claude": "CLAUDE_API_KEY",
        "grok": "GROK_API_KEY",
    }

    def __init__(self, model_name: str, provider: str, api_key: Optional[str] = None, **kwargs: Any):
        """Store model/provider settings and resolve the API key.

        Args:
            model_name: Model identifier, exposed via the ``model_name`` property.
            provider: Provider name; also used to look up the API key in
                ``config.Config`` when ``api_key`` is not given.
            api_key: Explicit API key. When ``None`` and the ``config`` module
                was importable, the key is read from ``Config``.
            **kwargs: Extra provider-specific settings, kept in
                ``self._extra_config`` for subclasses.

        Raises:
            ValueError: If no non-empty API key is available.
        """
        self._model_name = model_name
        self._provider = provider
        self._client = None  # set by the subclass in _initialize_client()
        self._initialized = False
        self._extra_config = kwargs

        # Fall back to the project config only when no key was passed in.
        if api_key is None and CONFIG_AVAILABLE:
            api_key = self._get_api_key_from_config()

        if not api_key:
            raise ValueError(
                f"API key not provided and could not be loaded from config for provider: {provider}"
            )

        self._api_key = api_key

    def _get_api_key_from_config(self) -> Optional[str]:
        """Return the API key for ``self._provider`` from ``Config``, or ``None``.

        The provider name is matched case-insensitively. Only the attribute for
        the requested provider is read, so a ``Config`` that does not define a
        key (for this or any other provider) yields ``None`` instead of raising
        ``AttributeError``.
        """
        if not CONFIG_AVAILABLE:
            return None

        attr_name = self._PROVIDER_KEY_ATTRS.get(str(self._provider).lower())
        if attr_name is None:
            return None
        return getattr(Config, attr_name, None)

    @property
    def model_name(self) -> str:
        """Get the AI model name"""
        return self._model_name

    async def initialize(self) -> None:
        """Initialize the client once; later calls are no-ops until ``cleanup``."""
        if self._initialized:
            return
        await self._initialize_client()
        self._initialized = True

    async def _initialize_client(self) -> None:
        """Initialize the specific AI client - to be implemented by subclasses"""
        pass

    async def process_with_tools(
        self,
        query: str,
        available_tools: List[Dict[str, Any]],
        mcp_client: IMCPClient
    ) -> str:
        """Answer ``query``, letting the model call MCP tools for one round.

        The query is sent with the provider-formatted tools and
        ``tool_choice="auto"``. If the reply requests tool calls, each one is
        executed through ``mcp_client.call_tool`` and its result is appended as
        a ``"tool"`` message. The model is then asked for a final answer with
        ``tool_choice="none"`` so that no further tool calls are made.

        Args:
            query: User prompt.
            available_tools: MCP tool descriptions, passed through
                ``_format_tools_for_provider``.
            mcp_client: Client used to execute the requested tools.

        Returns:
            The assistant's text content, or ``""`` if the model returned none.
        """
        formatted_tools = self._format_tools_for_provider(available_tools)
        messages: List[Dict[str, Any]] = [{"role": "user", "content": query}]

        response = await self.chat_completion(
            messages=messages,
            tools=formatted_tools,
            tool_choice="auto"
        )
        assistant_message = response["choices"][0]["message"]

        tool_calls = assistant_message.get("tool_calls")
        if not tool_calls:
            # No tools requested: the first reply is the answer.
            return assistant_message.get("content") or ""

        # The assistant turn that requested the tools must precede the tool
        # results in the conversation.
        messages.append(assistant_message)
        for tool_call in tool_calls:
            messages.append(await self._execute_tool_call(tool_call, mcp_client))

        final_response = await self.chat_completion(
            messages=messages,
            tools=formatted_tools,
            tool_choice="none"  # Don't allow more tool calls
        )
        return final_response["choices"][0]["message"].get("content") or ""

    async def _execute_tool_call(self, tool_call: Dict[str, Any], mcp_client: IMCPClient) -> Dict[str, Any]:
        """Run one model-requested tool call and return the ``"tool"`` message for it.

        Failures (a malformed call, undecodable arguments, or an exception from
        ``mcp_client.call_tool``) are reported back to the model as the message
        content rather than raised, so one failing tool does not abort the
        whole exchange.
        """
        # Read the id before the try so the error path can always build a reply.
        call_id = tool_call.get("id")
        try:
            function = tool_call["function"]
            tool_name = function["name"]
            tool_args = self._parse_tool_arguments(function.get("arguments"))
            tool_result = await mcp_client.call_tool(tool_name, tool_args)
            content = str(tool_result)
        except Exception as e:
            content = f"Error calling tool: {e}"
        return {
            "role": "tool",
            "tool_call_id": call_id,
            "content": content,
        }

    @staticmethod
    def _parse_tool_arguments(raw_arguments: Any) -> Any:
        """Decode a tool call's arguments.

        A JSON string is decoded with ``json.loads``. An already-decoded dict is
        returned as-is. ``None`` or a blank string means "no arguments" and
        yields ``{}``.
        """
        if raw_arguments is None:
            return {}
        if isinstance(raw_arguments, dict):
            return raw_arguments
        if isinstance(raw_arguments, str) and not raw_arguments.strip():
            return {}
        return json.loads(raw_arguments)

    def _format_tools_for_provider(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Format tools for the specific AI provider - to be implemented by subclasses"""
        return tools

    async def cleanup(self) -> None:
        """Release the provider client, if one was created, and mark the instance uninitialized.

        ``self._client`` is reset to ``None`` afterwards. A repeated ``cleanup``
        therefore does not release the same client twice, and a later
        ``initialize`` starts from a clean state.
        """
        if self._client is not None:
            await self._cleanup_client()
            self._client = None
        self._initialized = False

    async def _cleanup_client(self) -> None:
        """Clean up the specific AI client - to be implemented by subclasses"""
        pass
|