#!/usr/bin/env python3
"""
MCP AI Client with Flexible Transport and LLM Provider Support

This client can connect to MCP servers using either SSE or stdio transport
and use various AI models (OpenAI, Claude, Grok) to process queries with
access to MCP tools.
"""

import asyncio
import json
import sys
import os
from contextlib import AsyncExitStack
from typing import Any, Dict, List, Optional
from enum import Enum

import nest_asyncio
from dotenv import load_dotenv
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client

# Load environment variables (main() reads provider API keys from the environment).
load_dotenv()

# Apply nest_asyncio to allow nested event loops
nest_asyncio.apply()

# Import LLM client factory
try:
    # Try to import from the local ``src`` tree first
    sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
    from mcp_template.llm_client.client_factory import AIClientFactory
    LLM_CLIENT_AVAILABLE = True
except ImportError:
    try:
        # Fallback to build version
        from mcp_template.llm_client.client_factory import AIClientFactory
        LLM_CLIENT_AVAILABLE = True
    except ImportError:
        LLM_CLIENT_AVAILABLE = False


class TransportType(Enum):
    """Supported MCP transport types"""
    STDIO = "stdio"
    SSE = "sse"


class MCPAIClient:
    """Client for interacting with AI models using MCP tools with flexible transport and provider support."""

    def __init__(
        self,
        model: str = "gpt-4o",
        transport: TransportType = TransportType.SSE,
        provider: str = "openai",
        temperature: float = 0.7,
        max_tokens: int = 1000,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        **kwargs
    ):
        """Initialize the AI MCP client.

        Args:
            model: The AI model to use.
            transport: The MCP transport type to use (stdio or sse).
            provider: The AI provider to use (openai, claude, grok).
            temperature: Sampling temperature for AI responses.
            max_tokens: Maximum tokens for AI responses.
            top_k: Optional top-k sampling parameter; forwarded to the AI client only when set.
            top_p: Optional top-p sampling parameter; forwarded to the AI client only when set.
            **kwargs: Additional parameters for the AI client.

        Raises:
            ImportError: If the LLM client factory could not be imported.
        """
        self.transport = transport
        self.model = model
        self.provider = provider
        self.temperature = temperature
        self.max_tokens = max_tokens

        # Initialize session and client objects
        self.session: Optional[ClientSession] = None
        # Owns every transport/session context entered by connect_*; closed in cleanup().
        self.exit_stack = AsyncExitStack()

        # Initialize AI client using factory
        if not LLM_CLIENT_AVAILABLE:
            raise ImportError("LLM client not available. Make sure the LLM client modules are properly installed.")

        # Prepare additional parameters for AI client (copy so the caller's dict is untouched)
        ai_kwargs = kwargs.copy()
        if top_k is not None:
            ai_kwargs["top_k"] = top_k
        if top_p is not None:
            ai_kwargs["top_p"] = top_p

        self.ai_client = AIClientFactory.create_client(
            provider=provider,
            model_name=model,
            temperature=temperature,
            max_tokens=max_tokens,
            **ai_kwargs
        )

        # Streams returned by stdio_client (set by connect_stdio)
        self.stdio: Optional[Any] = None
        self.write: Optional[Any] = None

        # Streams returned by sse_client (set by connect_sse)
        self.read_stream: Optional[Any] = None
        self.write_stream: Optional[Any] = None

    async def connect_stdio(self, server_script_path: str = "run_mcp_server.py",
                            server_args: Optional[List[str]] = None):
        """Connect to an MCP server using stdio transport.

        The server is launched as a ``python`` subprocess; its streams and the
        MCP session are registered on ``self.exit_stack``.

        Args:
            server_script_path: Path to the server script.
            server_args: Additional arguments for the server script
                (defaults to ``["--transport", "stdio"]``).
        """
        if server_args is None:
            server_args = ["--transport", "stdio"]

        # Server configuration
        server_params = StdioServerParameters(
            command="python",
            args=[server_script_path] + server_args,
        )

        # Connect to the server
        stdio_transport = await self.exit_stack.enter_async_context(
            stdio_client(server_params)
        )
        self.stdio, self.write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(self.stdio, self.write)
        )

    async def connect_sse(self, server_url: str = "http://localhost:8050/sse"):
        """Connect to an MCP server using SSE transport.

        Args:
            server_url: The SSE endpoint URL of the server.
        """
        # Connect to the server
        sse_transport = await self.exit_stack.enter_async_context(
            sse_client(server_url)
        )
        self.read_stream, self.write_stream = sse_transport
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(self.read_stream, self.write_stream)
        )

    async def connect_to_server(self, server_url: str = "http://localhost:8050/sse",
                                server_script_path: str = "run_mcp_server.py"):
        """Connect to an MCP server using the configured transport.

        After connecting, the MCP session is initialized and the server's
        tools, prompts and resources are printed.

        Args:
            server_url: The SSE endpoint URL (used for SSE transport).
            server_script_path: Path to the server script (used for stdio transport).

        Raises:
            ValueError: If ``self.transport`` is not a supported transport.
        """
        if self.transport == TransportType.SSE:
            await self.connect_sse(server_url)
        elif self.transport == TransportType.STDIO:
            await self.connect_stdio(server_script_path)
        else:
            raise ValueError(f"Unsupported transport type: {self.transport}")

        # Initialize the connection
        await self.session.initialize()
        print(f"✅ Connected to MCP server using {self.transport.value.upper()} transport")

        # List available components
        await self.list_available_components()

    async def list_available_components(self):
        """Print all available tools, prompts, and resources exposed by the server."""
        print("\n" + "=" * 60)
        print("AVAILABLE MCP COMPONENTS")
        print("=" * 60)

        # List available tools
        tools_result = await self.session.list_tools()
        print(f"\n🔧 Tools ({len(tools_result.tools)}):")
        for tool in tools_result.tools:
            print(f" • {tool.name}: {tool.description}")

        # List available prompts
        prompts_result = await self.session.list_prompts()
        print(f"\n Prompts ({len(prompts_result.prompts)}):")
        for prompt in prompts_result.prompts:
            print(f" • {prompt.name}: {prompt.description}")

        # List available resources
        resources_result = await self.session.list_resources()
        print(f"\n Resources ({len(resources_result.resources)}):")
        for resource in resources_result.resources:
            print(f" • {resource.uri}: {resource.name}")
            if resource.description:
                print(f" └─ {resource.description}")

        print("\n" + "=" * 60)

    async def get_mcp_tools(self) -> List[Dict[str, Any]]:
        """Get available tools from the MCP server in AI provider format.

        Returns:
            A list of tools formatted for the specific AI provider.
        """
        tools_result = await self.session.list_tools()

        # Convert MCP tools to standard format
        standard_tools = [
            {
                "name": tool.name,
                "description": tool.description,
                "inputSchema": tool.inputSchema,
            }
            for tool in tools_result.tools
        ]

        # NOTE(review): relies on a private method of the AI client; keep in sync with that class.
        return self.ai_client._format_tools_for_provider(standard_tools)

    @staticmethod
    def _preview(text: str, limit: int) -> str:
        """Return ``text`` cut to ``limit`` characters, with '...' appended if it was longer."""
        return f"{text[:limit]}{'...' if len(text) > limit else ''}"

    @staticmethod
    def _parse_tool_arguments(raw: Any) -> Any:
        """Decode a tool call's ``arguments`` field.

        A missing or empty value yields ``{}``. A dict is returned as-is, in case a
        provider adapter has already decoded it. Anything else is parsed as JSON,
        so ``json.JSONDecodeError`` propagates for malformed input.
        """
        if raw is None or raw == "":
            return {}
        if isinstance(raw, dict):
            return raw
        return json.loads(raw)

    @staticmethod
    def _extract_tool_text(result: Any) -> str:
        """Join the ``text`` of every content item in an MCP tool result.

        Returns "No result" when the result has no content items carrying text.
        """
        texts = [
            item.text
            for item in (result.content or [])
            if getattr(item, "text", None) is not None
        ]
        return "\n".join(texts) if texts else "No result"

    async def _run_tool_call(self, index: int, tool_call: Dict[str, Any]) -> Dict[str, Any]:
        """Execute one model-requested tool call over MCP and build its ``tool`` message.

        Argument-decoding and execution errors are turned into the message
        content instead of being raised, so a ``tool`` message is still
        produced for this ``tool_call_id``.
        """
        tool_name = tool_call["function"]["name"]
        print(f" {index}. Calling tool: {tool_name}")

        try:
            tool_args = self._parse_tool_arguments(tool_call["function"].get("arguments"))
            print(f" Arguments: {tool_args}")

            # In proper MCP, arguments should be passed directly
            result = await self.session.call_tool(tool_name, arguments=tool_args)
            content = self._extract_tool_text(result)
            print(f" Result: {self._preview(content, 100)}")
        except Exception as e:
            content = f"Tool execution failed: {e}"
            print(f" ❌ Error: {content}")

        return {
            "role": "tool",
            "tool_call_id": tool_call["id"],
            "content": content,
        }

    async def process_query(self, query: str) -> str:
        """Process a query using AI model and available MCP tools.

        The model is called once with the MCP tools available. If it requests
        tool calls, each one is executed over MCP and the model is called a
        second time (with further tool calls disabled) to produce the answer.

        Args:
            query: The user query.

        Returns:
            The response from the AI model ("" if the model returned no text).
        """
        print(f"\n🤔 Processing query: '{query}'")

        # Get available tools
        tools = await self.get_mcp_tools()

        # Initial AI API call
        print(f"🧠 Calling {self.provider}/{self.model} with {len(tools)} available tools...")
        response = await self.ai_client.chat_completion(
            messages=[{"role": "user", "content": query}],
            tools=tools,
            tool_choice="auto",
        )

        assistant_message = response["choices"][0]["message"]
        tool_calls = assistant_message.get("tool_calls") or []

        if not tool_calls:
            # No tool calls, just return the direct response. Content may be None.
            direct_answer = assistant_message.get("content") or ""
            print(f"💡 Direct answer: {self._preview(direct_answer, 200)}")
            return direct_answer

        # Rebuild the conversation: user query + assistant turn that requested the tools
        messages: List[Dict[str, Any]] = [
            {"role": "user", "content": query},
            {
                "role": assistant_message["role"],
                "content": assistant_message.get("content"),
                "tool_calls": [
                    {
                        "id": tc["id"],
                        "type": tc.get("type", "function"),
                        "function": tc["function"],
                    }
                    for tc in tool_calls
                ],
            },
        ]

        print(f"🔧 Assistant wants to use {len(tool_calls)} tool(s)")
        for i, tool_call in enumerate(tool_calls, 1):
            messages.append(await self._run_tool_call(i, tool_call))

        # Get final response from AI with tool results
        print(f"🧠 Getting final response from {self.provider}/{self.model}...")
        final_response = await self.ai_client.chat_completion(
            messages=messages,
            tools=tools,
            tool_choice="none",  # Don't allow more tool calls
        )

        final_answer = final_response["choices"][0]["message"].get("content") or ""
        print(f"💡 Final answer: {self._preview(final_answer, 200)}")
        return final_answer

    async def interactive_session(self):
        """Run a read-eval-print loop that sends each typed query to process_query.

        The loop ends on 'quit'/'exit'/'q', Ctrl-C, or end of input (Ctrl-D / closed stdin).
        """
        print("🚀 Starting interactive MCP-AI session")
        print(f"📡 Transport: {self.transport.value.upper()}")
        print(f"🤖 Model: {self.provider}/{self.model}")
        print(f"🌡️ Temperature: {self.temperature}")
        print(f"📝 Max Tokens: {self.max_tokens}")
        print("💡 Type 'quit' or 'exit' to end the session")
        print("-" * 50)

        while True:
            try:
                # NOTE(review): input() is synchronous and blocks the event loop while waiting.
                query = input("\n❓ Your query: ").strip()

                if query.lower() in ['quit', 'exit', 'q']:
                    print("👋 Goodbye!")
                    break

                if not query:
                    continue

                # Process the query
                response = await self.process_query(query)
                print(f"\n🎯 Response: {response}")

            except EOFError:
                # stdin closed: without this the generic handler below would loop forever.
                print("\n👋 Goodbye!")
                break
            except KeyboardInterrupt:
                print("\n👋 Session interrupted. Goodbye!")
                break
            except Exception as e:
                print(f"❌ Error processing query: {e}")
                continue

    async def cleanup(self):
        """Close every context registered on the exit stack; errors are reported, not raised."""
        try:
            await self.exit_stack.aclose()
            print("🧹 Resources cleaned up successfully")
        except Exception as e:
            print(f"⚠️ Cleanup warning: {e}")


def _is_connection_refused(exc: BaseException) -> bool:
    """Return True if ``exc`` indicates a refused connection.

    Checks ``exc`` itself, its ``__cause__``/``__context__`` chain, and the members of
    any exception group, because transport errors are often wrapped. ``str(e)`` of a
    ConnectionRefusedError does not contain the class name, so an isinstance check is
    required. The original substring check is kept as a fallback.
    """
    stack: List[Optional[BaseException]] = [exc]
    seen = set()
    while stack:
        current = stack.pop()
        if current is None or id(current) in seen:
            continue
        seen.add(id(current))
        if isinstance(current, ConnectionRefusedError) or "ConnectionRefusedError" in str(current):
            return True
        nested = getattr(current, "exceptions", None)
        if isinstance(nested, (list, tuple)):
            stack.extend(e for e in nested if isinstance(e, BaseException))
        stack.append(current.__cause__)
        stack.append(current.__context__)
    return False


async def main():
    """Main entry point for the client: parse CLI args, validate the API key, connect, and query."""
    import argparse

    parser = argparse.ArgumentParser(description="MCP AI Client with Flexible Transport and LLM Provider Support")
    parser.add_argument(
        "--transport",
        choices=["sse", "stdio"],
        default="sse",
        help="MCP transport type (default: sse)"
    )
    parser.add_argument(
        "--provider",
        choices=["openai", "claude", "grok"],
        default="openai",
        help="AI provider to use (default: openai)"
    )
    parser.add_argument(
        "--model",
        help="AI model to use (defaults based on provider: openai=gpt-4o, claude=claude-3-opus-20240229, grok=grok-1)"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature (default: 0.7)"
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=1000,
        help="Maximum tokens for response (default: 1000)"
    )
    parser.add_argument(
        "--top-k",
        type=int,
        help="Top-k sampling parameter (provider-specific)"
    )
    parser.add_argument(
        "--top-p",
        type=float,
        help="Top-p sampling parameter (provider-specific)"
    )
    parser.add_argument(
        "--server-url",
        default="http://localhost:8050/sse",
        help="Server URL for SSE transport (default: http://localhost:8050/sse)"
    )
    parser.add_argument(
        "--server-script",
        default="run_mcp_server.py",
        help="Server script path for stdio transport (default: run_mcp_server.py)"
    )
    parser.add_argument(
        "--query",
        help="Single query to process (if not provided, starts interactive mode)"
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose output"
    )

    args = parser.parse_args()

    # Validate API keys based on provider
    provider = args.provider
    api_key_env_vars = {
        "openai": "OPENAI_API_KEY",
        "claude": "ANTHROPIC_API_KEY",
        "grok": "GROK_API_KEY"
    }

    api_key_env = api_key_env_vars.get(provider)
    if not api_key_env:
        print(f"❌ Error: Unknown provider: {provider}")
        return

    api_key = os.getenv(api_key_env)
    if not api_key:
        print(f"❌ Error: {api_key_env} environment variable not set")
        print(f" Please set your {provider.upper()} API key:")
        print(f" export {api_key_env}='your-api-key-here'")
        print(f" Or create a .env file with: {api_key_env}=your-api-key-here")
        return

    # Basic API key validation
    if provider == "openai" and (not api_key.startswith("sk-") or len(api_key) < 20):
        print("❌ Error: OPENAI_API_KEY appears to be invalid")
        print(" API key should start with 'sk-' and be at least 20 characters long")
        return

    # Set default model if not provided
    model = args.model
    if not model:
        default_models = {
            "openai": "gpt-4o",
            "claude": "claude-3-opus-20240229",
            "grok": "grok-1"
        }
        model = default_models.get(provider, "gpt-4o")

    # Create client
    transport = TransportType(args.transport)
    client = MCPAIClient(
        model=model,
        transport=transport,
        provider=provider,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        top_k=args.top_k,
        top_p=args.top_p
    )

    try:
        # Connect to server
        if transport == TransportType.SSE:
            await client.connect_to_server(server_url=args.server_url)
        else:
            await client.connect_to_server(server_script_path=args.server_script)

        # Handle single query or interactive mode
        if args.query:
            response = await client.process_query(args.query)
            print(f"\n🎯 Response: {response}")
        else:
            await client.interactive_session()

    except Exception as e:
        print(f"❌ Error: {e}")
        if _is_connection_refused(e):
            if transport == TransportType.SSE:
                print("💡 Make sure the MCP server is running with SSE transport:")
                print(f" python {args.server_script} --transport sse")
            else:
                print("💡 Make sure the server script path is correct")
    finally:
        await client.cleanup()


# Backward compatibility alias
MCPOpenAIClient = MCPAIClient


if __name__ == "__main__":
    asyncio.run(main())