initial mcp server setup

This commit is contained in:
OwusuBlessing
2025-09-11 23:13:58 +01:00
commit 20f96c0f30
141 changed files with 14444 additions and 0 deletions
+198
View File
@@ -0,0 +1,198 @@
# Test Suite
This directory contains comprehensive tests for the MCP Template system.
## Test Structure
```
tests/
├── unit/ # Unit tests (fast, no external dependencies)
├── integration/ # Integration tests (require real server connections)
├── e2e/ # End-to-end tests
└── README.md # This file
```
## Test Types
### Unit Tests
- **Fast**: Run in milliseconds
- **Isolated**: No external dependencies
- **Focused**: Test individual components
- **Run with**: `python run_tests.py unit`
### Integration Tests
- **Real connections**: Test with actual MCP server
- **API calls**: May require API keys
- **Slower**: Take longer to execute
- **Run with**: `python run_tests.py integration`
## Running Tests
### Quick Start
```bash
# Run all tests
python run_tests.py all
# Run unit tests only
python run_tests.py unit
# Run integration tests only
python run_tests.py integration
# Run specific test file
python run_tests.py specific tests/integration/test_mcp_integration.py
```
### Advanced Options
```bash
# Run with coverage report
python run_tests.py all --coverage
# Skip slow tests
python run_tests.py integration --skip-slow
# Run specific test markers
pytest -m "integration and not slow"
pytest -m "unit"
pytest -m "requires_api_key"
```
## Test Markers
- **`@pytest.mark.unit`**: Unit tests (fast, isolated)
- **`@pytest.mark.integration`**: Integration tests (require server)
- **`@pytest.mark.requires_api_key`**: Tests requiring API keys
- **`@pytest.mark.slow`**: Slow-running tests
- **`@pytest.mark.asyncio`**: Async tests
## Environment Setup
### API Keys (for integration tests)
Create a `.env` file in the project root:
```bash
# OpenAI (required for most tests)
OPENAI_API_KEY=your_openai_api_key_here
# Optional: Other providers
ANTHROPIC_API_KEY=your_anthropic_key_here
GROK_API_KEY=your_grok_key_here
```
### Test Configuration
Tests are configured via `pytest.ini`:
- Async mode enabled
- Custom markers defined
- Test discovery patterns
- Warning filters
## Integration Test Details
### MCP Server Management
Integration tests automatically:
1. Start MCP server processes
2. Wait for server readiness
3. Connect test clients
4. Clean up processes after tests
### Transport Testing
- **SSE Transport**: HTTP-based, tested with real server
- **STDIO Transport**: Direct process communication
### AI Provider Testing
- **OpenAI**: Primary provider for most tests
- **Claude**: Tested if ANTHROPIC_API_KEY available
- **Grok**: Tested if GROK_API_KEY available
## Writing Tests
### Unit Test Example
```python
import pytest
from mcp_template.core.types import TransportType
@pytest.mark.unit
def test_transport_enum():
"""Test TransportType enum values"""
assert TransportType.SSE.value == "sse"
assert TransportType.STDIO.value == "stdio"
```
### Integration Test Example
```python
import pytest
from mcp_llm_client import MCPAIClient, TransportType
@pytest.mark.integration
@pytest.mark.requires_api_key
@pytest.mark.asyncio
async def test_client_connection(sse_server):
"""Test MCP client can connect to server"""
if not os.getenv("OPENAI_API_KEY"):
pytest.skip("OpenAI API key not available")
client = MCPAIClient(model="gpt-4o", transport=TransportType.SSE)
try:
await client.connect(sse_server.url)
tools = await client.get_mcp_tools()
assert len(tools) > 0
finally:
await client.disconnect()
```
## Test Coverage
Run tests with coverage:
```bash
# Generate HTML coverage report
python run_tests.py all --coverage
# View coverage report
open htmlcov/index.html
```
## Troubleshooting
### Common Issues
1. **Server won't start**: Check port availability
2. **API key errors**: Ensure `.env` file exists
3. **Connection timeouts**: Increase timeout values
4. **Process cleanup**: Tests handle cleanup automatically
### Debug Mode
```bash
# Run with verbose output
pytest -v -s tests/integration/test_mcp_integration.py::TestMCPClientIntegration::test_sse_transport_connection
# Run single test with debugging
pytest -v --pdb tests/integration/test_mcp_integration.py -k "test_sse_transport"
```
## CI/CD Integration
Tests are designed to work in CI/CD environments:
- Automatic skipping when API keys unavailable
- Process cleanup on test failure
- No manual server startup required
- Cross-platform compatibility
## Performance Testing
Some tests measure performance:
- Connection establishment time
- Query response times
- Concurrent connection handling
- Resource cleanup efficiency
Run performance tests:
```bash
pytest -m "slow" -v
```
+2
View File
@@ -0,0 +1,2 @@
# Test suite for MCP Template
# Run tests with: python -m pytest tests/
+255
View File
@@ -0,0 +1,255 @@
"""
End-to-end tests for complete MCP workflow
"""
import pytest
import asyncio
import tempfile
import os
from pathlib import Path
from unittest.mock import Mock, AsyncMock, patch
from mcp_template.core.types import MCPTool, MCPServerConfig, TransportType
from mcp_template.server.fastmcp_server import FastMCPServer
from mcp_template.llm_client.openai_client import OpenAIClient
from mcp_template.config.config_manager import ConfigManager
class TestEndToEnd:
"""End-to-end tests for complete MCP workflow"""
@pytest.fixture
async def knowledge_base_server(self):
"""Create a knowledge base server for e2e testing"""
kb_data = {
"company_policy": "Our company offers 20 days of paid vacation annually.",
"benefits": "We provide health insurance, 401k matching, and flexible work hours.",
"faq": "Common questions about our policies and procedures."
}
async def get_kb():
formatted = "Knowledge Base:\n\n"
for key, value in kb_data.items():
formatted += f"**{key.replace('_', ' ').title()}:**\n{value}\n\n"
return formatted
async def search_kb(query: str):
query_lower = query.lower()
results = []
for key, value in kb_data.items():
if query_lower in key.lower() or query_lower in value.lower():
results.append(f"**{key.replace('_', ' ').title()}:**\n{value}")
if not results:
return f"No information found for: {query}"
return "\n\n".join(results)
tools = [
MCPTool(
name="get_knowledge_base",
description="Retrieve the entire knowledge base",
input_schema={"type": "object", "properties": {}},
handler=get_kb
),
MCPTool(
name="search_kb",
description="Search the knowledge base for specific information",
input_schema={
"type": "object",
"properties": {
"query": {"type": "string", "description": "Search query"}
},
"required": ["query"]
},
handler=search_kb
)
]
config = MCPServerConfig(
name="Knowledge Base Server",
tools=tools
)
server = FastMCPServer(config)
await server.initialize()
return server
@pytest.fixture
def mock_openai_response(self):
"""Mock OpenAI API response with tool calls"""
return {
"choices": [
{
"message": {
"role": "assistant",
"content": "I'll search the knowledge base for vacation policy information.",
"tool_calls": [
{
"id": "call_123",
"function": {
"name": "search_kb",
"arguments": '{"query": "vacation"}'
}
}
]
}
}
]
}
@patch('src.clients.openai_client.AsyncOpenAI')
@pytest.mark.asyncio
async def test_complete_mcp_workflow(self, mock_openai_class, knowledge_base_server, mock_openai_response):
"""Test complete MCP workflow from query to response"""
# Mock OpenAI client setup
mock_client = Mock()
mock_response = Mock()
mock_response.choices = [
Mock(message=Mock(
role="assistant",
content="I'll search the knowledge base for vacation policy information.",
tool_calls=[
Mock(
id="call_123",
function=Mock(
name="search_kb",
arguments='{"query": "vacation"}'
)
)
]
))
]
# Second call response (after tool execution)
mock_final_response = Mock()
mock_final_response.choices = [
Mock(message=Mock(
content="Based on the knowledge base, our company offers 20 days of paid vacation annually."
))
]
mock_client.chat.completions.create = AsyncMock(side_effect=[mock_response, mock_final_response])
mock_openai_class.return_value = mock_client
# Create AI client
ai_client = OpenAIClient("gpt-4o", "test-key")
await ai_client.initialize()
# Create mock MCP client that uses the real server
class MockMCPClient:
def __init__(self, server):
self.server = server
async def call_tool(self, name, arguments):
return await self.server.call_tool(name, arguments)
async def list_tools(self):
return await self.server.list_tools()
mcp_client = MockMCPClient(knowledge_base_server)
# Test the complete workflow
query = "What's the company vacation policy?"
response = await ai_client.process_with_tools(query, [], mcp_client)
# Verify the workflow
assert "vacation" in response.lower() or "20 days" in response
# Verify OpenAI was called twice (initial + after tool execution)
assert mock_client.chat.completions.create.call_count == 2
# Verify the tool was called correctly
tool_result = await mcp_client.call_tool("search_kb", {"query": "vacation"})
assert "vacation" in tool_result.lower()
@pytest.mark.asyncio
async def test_configuration_workflow(self, knowledge_base_server):
"""Test configuration loading and usage"""
# Create temporary config file
config_data = {
"server": {
"name": "Test Server",
"port": 8080,
"transport": "sse"
},
"client": {
"provider": "openai",
"model": "gpt-4o"
},
"api_keys": {
"openai": "test-key"
}
}
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
import json
json.dump(config_data, f)
config_file = f.name
try:
# Test config manager
config_manager = ConfigManager(config_file)
server_config = await config_manager.get_server_config()
assert server_config["name"] == "Test Server"
assert server_config["port"] == 8080
client_config = await config_manager.get_client_config()
assert client_config["provider"] == "openai"
assert client_config["model"] == "gpt-4o"
api_key = await config_manager.get_api_key("openai")
assert api_key == "test-key"
finally:
# Clean up
os.unlink(config_file)
@pytest.mark.asyncio
async def test_error_recovery(self, knowledge_base_server):
"""Test error recovery in MCP operations"""
# Test with invalid tool call
try:
await knowledge_base_server.call_tool("nonexistent_tool", {})
assert False, "Should have raised ValueError"
except ValueError as e:
assert "not found" in str(e).lower()
# Test with valid tool but invalid arguments
try:
await knowledge_base_server.call_tool("search_kb", {"invalid": "args"})
assert False, "Should have raised error"
except Exception:
# Should handle gracefully
pass
# Verify server still works after errors
result = await knowledge_base_server.call_tool("get_knowledge_base", {})
assert "Knowledge Base" in result
@pytest.mark.asyncio
async def test_multiple_queries_workflow(self, knowledge_base_server):
"""Test multiple queries in sequence"""
queries = [
"vacation policy",
"benefits",
"faq"
]
for query in queries:
# Direct tool call (simulating what AI would do)
result = await knowledge_base_server.call_tool("search_kb", {"query": query})
# Verify we get relevant results
assert query.lower() in result.lower() or len(result) > 50
# Test getting full knowledge base
full_kb = await knowledge_base_server.call_tool("get_knowledge_base", {})
assert "company_policy" in full_kb
assert "benefits" in full_kb
assert "faq" in full_kb
+611
View File
@@ -0,0 +1,611 @@
"""
Comprehensive Integration Tests for MCP Client
This module contains end-to-end integration tests for the MCP client,
testing real server connections, tool calling, and AI provider integration.
"""
import pytest
import asyncio
import subprocess
import time
import signal
import os
import sys
from typing import Optional, Dict, Any, List
from unittest.mock import Mock, patch
import httpx
# Add src to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
from mcp_llm_client import MCPAIClient, TransportType
from mcp_template.llm_client.client_factory import AIClientFactory
class MCPServerProcess:
"""Helper class to manage MCP server process for testing"""
def __init__(self, transport: str = "sse", port: int = 8051):
self.transport = transport
self.port = port
self.process: Optional[subprocess.Popen] = None
self.url = f"http://localhost:{port}"
async def start(self):
"""Start the MCP server process"""
cmd = [
sys.executable,
"run_mcp_server.py",
"--transport", self.transport,
"--port", str(self.port)
]
self.process = subprocess.Popen(
cmd,
cwd=os.path.join(os.path.dirname(__file__), '..', '..'),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True
)
# Wait for server to be ready
await self._wait_for_server()
return self
async def _wait_for_server(self, timeout: int = 30):
"""Wait for server to be ready"""
start_time = time.time()
while time.time() - start_time < timeout:
try:
async with httpx.AsyncClient() as client:
if self.transport == "sse":
response = await client.get(f"{self.url}/sse", timeout=5.0)
if response.status_code == 200:
return
else:
# For stdio, just wait a bit
await asyncio.sleep(2)
return
except Exception:
pass
await asyncio.sleep(1)
raise TimeoutError(f"Server did not start within {timeout} seconds")
async def stop(self):
"""Stop the MCP server process"""
if self.process:
try:
self.process.terminate()
self.process.wait(timeout=10)
except subprocess.TimeoutExpired:
self.process.kill()
self.process.wait()
@pytest.fixture
async def sse_server():
"""Fixture to start and stop SSE MCP server"""
server = MCPServerProcess(transport="sse", port=8051)
await server.start()
yield server
await server.stop()
@pytest.fixture
async def stdio_server():
"""Fixture to start and stop STDIO MCP server"""
server = MCPServerProcess(transport="stdio", port=8052)
await server.start()
yield server
await server.stop()
class TestMCPClientIntegration:
"""Comprehensive integration tests for MCP client"""
@pytest.mark.integration
@pytest.mark.requires_api_key
@pytest.mark.asyncio
async def test_sse_transport_connection(self, sse_server):
"""Test MCP client can connect to server using SSE transport"""
# Skip if no OpenAI API key
if not os.getenv("OPENAI_API_KEY"):
pytest.skip("OpenAI API key not available")
client = MCPAIClient(
model="gpt-4o",
transport=TransportType.SSE,
provider="openai",
temperature=0.1,
max_tokens=100
)
try:
# Test connection
await client.connect(sse_server.url)
# Test getting tools
tools = await client.get_mcp_tools()
assert isinstance(tools, list)
assert len(tools) > 0
# Verify tool structure
for tool in tools:
assert "name" in tool
assert "description" in tool
assert "inputSchema" in tool
finally:
await client.disconnect()
@pytest.mark.integration
@pytest.mark.requires_api_key
@pytest.mark.asyncio
async def test_stdio_transport_connection(self):
"""Test MCP client can connect using STDIO transport"""
# Skip if no OpenAI API key
if not os.getenv("OPENAI_API_KEY"):
pytest.skip("OpenAI API key not available")
client = MCPAIClient(
model="gpt-4o",
transport=TransportType.STDIO,
provider="openai",
temperature=0.1,
max_tokens=100
)
try:
# For STDIO, we need to provide server command
await client.connect_stdio(
server_command=[sys.executable, "run_mcp_server.py", "--transport", "stdio"]
)
# Test getting tools
tools = await client.get_mcp_tools()
assert isinstance(tools, list)
assert len(tools) > 0
finally:
await client.disconnect()
@pytest.mark.integration
@pytest.mark.requires_api_key
@pytest.mark.slow
@pytest.mark.asyncio
async def test_end_to_end_tool_calling_sse(self, sse_server):
"""Test complete end-to-end tool calling with SSE transport"""
if not os.getenv("OPENAI_API_KEY"):
pytest.skip("OpenAI API key not available")
client = MCPAIClient(
model="gpt-4o",
transport=TransportType.SSE,
provider="openai",
temperature=0.1,
max_tokens=200
)
try:
await client.connect(sse_server.url)
# Test a simple mathematical query that should trigger tool calls
query = "Calculate 15 + 27 and then multiply the result by 2"
response = await client.process_query(query)
assert isinstance(response, str)
assert len(response) > 0
# The response should contain the calculation result
# We can't predict exact wording but should contain numbers
assert any(char.isdigit() for char in response)
finally:
await client.disconnect()
@pytest.mark.integration
@pytest.mark.requires_api_key
@pytest.mark.asyncio
async def test_multiple_provider_support(self, sse_server):
"""Test MCP client works with different AI providers"""
if not os.getenv("OPENAI_API_KEY"):
pytest.skip("OpenAI API key not available")
providers_to_test = ["openai"]
for provider in providers_to_test:
client = MCPAIClient(
model="gpt-4o" if provider == "openai" else "claude-3-opus-20240229",
transport=TransportType.SSE,
provider=provider,
temperature=0.1,
max_tokens=100
)
try:
await client.connect(sse_server.url)
# Test basic tool listing
tools = await client.get_mcp_tools()
assert len(tools) > 0
# Test simple query
response = await client.process_query("What tools are available?")
assert isinstance(response, str)
finally:
await client.disconnect()
@pytest.mark.integration
@pytest.mark.asyncio
async def test_error_handling_connection_failure(self):
"""Test error handling when server connection fails"""
client = MCPAIClient(
model="gpt-4o",
transport=TransportType.SSE,
provider="openai"
)
# Try to connect to non-existent server
with pytest.raises(Exception):
await client.connect("http://localhost:9999")
@pytest.mark.integration
@pytest.mark.requires_api_key
@pytest.mark.asyncio
async def test_error_handling_invalid_query(self, sse_server):
"""Test error handling with invalid queries"""
if not os.getenv("OPENAI_API_KEY"):
pytest.skip("OpenAI API key not available")
client = MCPAIClient(
model="gpt-4o",
transport=TransportType.SSE,
provider="openai",
temperature=0.1
)
try:
await client.connect(sse_server.url)
# Test with empty query
response = await client.process_query("")
assert isinstance(response, str)
# Test with very long query
long_query = "test " * 1000
response = await client.process_query(long_query)
assert isinstance(response, str)
finally:
await client.disconnect()
@pytest.mark.integration
@pytest.mark.requires_api_key
@pytest.mark.asyncio
async def test_interactive_session_mode(self, sse_server):
"""Test interactive session functionality"""
if not os.getenv("OPENAI_API_KEY"):
pytest.skip("OpenAI API key not available")
client = MCPAIClient(
model="gpt-4o",
transport=TransportType.SSE,
provider="openai",
temperature=0.1,
max_tokens=100
)
try:
await client.connect(sse_server.url)
# Mock user inputs for interactive session
inputs = ["Calculate 5 + 3", "quit"]
with patch('builtins.input', side_effect=inputs):
# This would normally run an interactive loop
# For testing, we'll just verify the client is ready
assert client.session is not None
assert client.ai_client is not None
finally:
await client.disconnect()
@pytest.mark.integration
@pytest.mark.requires_api_key
@pytest.mark.asyncio
async def test_tool_schema_validation(self, sse_server):
"""Test that tool schemas are properly formatted for AI providers"""
if not os.getenv("OPENAI_API_KEY"):
pytest.skip("OpenAI API key not available")
client = MCPAIClient(
model="gpt-4o",
transport=TransportType.SSE,
provider="openai"
)
try:
await client.connect(sse_server.url)
tools = await client.get_mcp_tools()
# Verify OpenAI-specific tool formatting
for tool in tools:
assert "type" in tool
assert tool["type"] == "function"
assert "function" in tool
assert "name" in tool["function"]
assert "description" in tool["function"]
assert "parameters" in tool["function"]
finally:
await client.disconnect()
@pytest.mark.integration
@pytest.mark.requires_api_key
@pytest.mark.slow
@pytest.mark.asyncio
async def test_concurrent_connections(self, sse_server):
"""Test multiple clients can connect simultaneously"""
if not os.getenv("OPENAI_API_KEY"):
pytest.skip("OpenAI API key not available")
async def create_and_test_client(client_id: int):
client = MCPAIClient(
model="gpt-4o",
transport=TransportType.SSE,
provider="openai",
temperature=0.1
)
try:
await client.connect(sse_server.url)
tools = await client.get_mcp_tools()
return len(tools)
finally:
await client.disconnect()
# Test 3 concurrent connections
tasks = [create_and_test_client(i) for i in range(3)]
results = await asyncio.gather(*tasks)
# All should succeed and return same number of tools
assert all(r > 0 for r in results)
assert len(set(results)) == 1 # All should return same count
@pytest.mark.integration
@pytest.mark.asyncio
async def test_configuration_parameters(self):
"""Test various configuration parameters work correctly"""
# Test different temperature settings
for temp in [0.1, 0.5, 0.9]:
client = MCPAIClient(
model="gpt-4o",
transport=TransportType.SSE,
provider="openai",
temperature=temp,
max_tokens=100
)
assert client.temperature == temp
assert client.ai_client.temperature == temp
@pytest.mark.integration
@pytest.mark.requires_api_key
@pytest.mark.asyncio
async def test_resource_access(self, sse_server):
"""Test accessing MCP resources"""
if not os.getenv("OPENAI_API_KEY"):
pytest.skip("OpenAI API key not available")
client = MCPAIClient(
model="gpt-4o",
transport=TransportType.SSE,
provider="openai"
)
try:
await client.connect(sse_server.url)
# Test listing resources
resources = await client.session.list_resources()
assert isinstance(resources.resources, list)
# Test reading resources if any exist
if resources.resources:
for resource in resources.resources[:2]: # Test first 2 resources
content = await client.session.read_resource(resource.uri)
assert content is not None
finally:
await client.disconnect()
@pytest.mark.integration
@pytest.mark.requires_api_key
@pytest.mark.asyncio
async def test_prompt_access(self, sse_server):
"""Test accessing MCP prompts"""
if not os.getenv("OPENAI_API_KEY"):
pytest.skip("OpenAI API key not available")
client = MCPAIClient(
model="gpt-4o",
transport=TransportType.SSE,
provider="openai"
)
try:
await client.connect(sse_server.url)
# Test listing prompts
prompts = await client.session.list_prompts()
assert isinstance(prompts.prompts, list)
# Test getting prompts if any exist
if prompts.prompts:
for prompt in prompts.prompts[:2]: # Test first 2 prompts
prompt_content = await client.session.get_prompt(prompt.name)
assert prompt_content is not None
finally:
await client.disconnect()
class TestMCPClientUtilities:
"""Test utility functions and edge cases"""
@pytest.mark.unit
def test_transport_type_enum(self):
"""Test TransportType enum values"""
assert TransportType.SSE.value == "sse"
assert TransportType.STDIO.value == "stdio"
@pytest.mark.unit
def test_client_initialization_validation(self):
"""Test client initialization with various parameters"""
# Test with minimal parameters
client = MCPAIClient()
assert client.model == "gpt-4o"
assert client.provider == "openai"
assert client.transport == TransportType.SSE
# Test with custom parameters
client = MCPAIClient(
model="gpt-4o-mini",
provider="openai",
transport=TransportType.STDIO,
temperature=0.5,
max_tokens=500,
top_p=0.9
)
assert client.model == "gpt-4o-mini"
assert client.provider == "openai"
assert client.transport == TransportType.STDIO
assert client.temperature == 0.5
assert client.max_tokens == 500
@pytest.mark.unit
@patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"})
def test_ai_client_factory_integration(self):
"""Test that AI client factory creates correct client types"""
# Test OpenAI client creation
openai_client = AIClientFactory.create_client(
provider="openai",
model_name="gpt-4o",
temperature=0.5
)
assert openai_client is not None
assert hasattr(openai_client, 'chat_completion')
assert hasattr(openai_client, '_format_tools_for_provider')
@pytest.mark.unit
def test_missing_llm_client_handling(self):
"""Test behavior when LLM client is not available"""
with patch('mcp_llm_client.LLM_CLIENT_AVAILABLE', False):
with pytest.raises(ImportError, match="LLM client not available"):
MCPAIClient()
class TestPerformanceAndLoad:
"""Performance and load testing for MCP client"""
@pytest.mark.integration
@pytest.mark.requires_api_key
@pytest.mark.slow
@pytest.mark.asyncio
async def test_multiple_rapid_queries(self, sse_server):
"""Test handling multiple rapid queries"""
if not os.getenv("OPENAI_API_KEY"):
pytest.skip("OpenAI API key not available")
client = MCPAIClient(
model="gpt-4o",
transport=TransportType.SSE,
provider="openai",
temperature=0.1,
max_tokens=50
)
try:
await client.connect(sse_server.url)
queries = [
"What is 2 + 2?",
"Calculate 10 * 5",
"What tools do you have?",
"Test query 4"
]
# Execute queries concurrently
tasks = [client.process_query(query) for query in queries]
responses = await asyncio.gather(*tasks)
# Verify all responses
assert len(responses) == len(queries)
assert all(isinstance(r, str) for r in responses)
assert all(len(r) > 0 for r in responses)
finally:
await client.disconnect()
@pytest.mark.integration
@pytest.mark.requires_api_key
@pytest.mark.asyncio
async def test_connection_reuse(self, sse_server):
"""Test reusing connection for multiple operations"""
if not os.getenv("OPENAI_API_KEY"):
pytest.skip("OpenAI API key not available")
client = MCPAIClient(
model="gpt-4o",
transport=TransportType.SSE,
provider="openai",
temperature=0.1
)
try:
await client.connect(sse_server.url)
# Perform multiple operations on same connection
for i in range(5):
tools = await client.get_mcp_tools()
assert len(tools) > 0
response = await client.process_query(f"Test query {i}")
assert isinstance(response, str)
finally:
await client.disconnect()
# Cleanup fixture to ensure no processes are left running
@pytest.fixture(scope="session", autouse=True)
async def cleanup_processes():
"""Clean up any remaining MCP server processes"""
yield
# Kill any remaining MCP server processes
try:
# Find and kill any remaining server processes
result = subprocess.run(
["pgrep", "-f", "run_mcp_server.py"],
capture_output=True,
text=True
)
if result.returncode == 0:
pids = result.stdout.strip().split('\n')
for pid in pids:
if pid:
try:
os.kill(int(pid), signal.SIGTERM)
except (ProcessLookupError, OSError):
pass # Process already dead
except (subprocess.SubprocessError, FileNotFoundError):
pass # pgrep not available or no processes found
+202
View File
@@ -0,0 +1,202 @@
"""
Unit tests for AI client components
"""
from unittest import async_case
import pytest
from unittest.mock import Mock, AsyncMock, patch
from mcp_template.llm_client.base_client import BaseAIClient
from mcp_template.llm_client.openai_client import OpenAIClient
from mcp_template.llm_client.client_factory import AIClientFactory
class TestBaseAIClient:
"""Test base AI client functionality"""
def test_base_client_creation(self):
"""Test creating a base client"""
# Create a concrete implementation for testing
class TestAIClient(BaseAIClient):
async def chat_completion(self, messages, tools=None, **kwargs):
return {"choices": [{"message": {"content": "test response"}}]}
async def _initialize_client(self):
pass
client = TestAIClient(
model_name="test-model",
provider="test",
api_key="test-key"
)
assert client.model_name == "test-model"
assert client._api_key == "test-key"
assert not client._initialized
def test_abstract_class_cannot_be_instantiated(self):
"""Test that abstract BaseAIClient cannot be instantiated directly"""
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
BaseAIClient(
model_name="test-model",
provider="test",
api_key="test-key"
)
class TestOpenAIClient:
"""Test OpenAI client implementation"""
def test_openai_client_creation(self):
"""Test creating an OpenAI client"""
client = OpenAIClient(
model_name="gpt-4o",
api_key="test-key",
temperature=0.5,
max_tokens=500
)
assert client.model_name == "gpt-4o"
assert client._temperature == 0.5
assert client._max_tokens == 500
def test_openai_client_defaults(self):
"""Test OpenAI client default values"""
client = OpenAIClient(
model_name="gpt-4o",
api_key="test-key"
)
assert client._temperature == 0.7
assert client._max_tokens == 1000
@patch('mcp_template.llm_client.openai_client.AsyncOpenAI')
@pytest.mark.asyncio
async def test_openai_initialization(self, mock_openai_class):
"""Test OpenAI client initialization"""
mock_client = Mock()
mock_openai_class.return_value = mock_client
client = OpenAIClient("gpt-4o", "test-key")
await client._initialize_client()
mock_openai_class.assert_called_once_with(api_key="test-key")
assert client._client == mock_client
@patch('mcp_template.llm_client.openai_client.AsyncOpenAI')
@pytest.mark.asyncio
async def test_openai_chat_completion(self, mock_openai_class):
"""Test OpenAI chat completion"""
# Mock the OpenAI client and response
mock_client = Mock()
mock_response = Mock()
mock_choice = Mock()
mock_message = Mock()
mock_message.role = "assistant"
mock_message.content = "Test response"
mock_message.tool_calls = None
mock_choice.message = mock_message
mock_response.choices = [mock_choice]
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
mock_openai_class.return_value = mock_client
client = OpenAIClient("gpt-4o", "test-key")
await client.initialize()
messages = [{"role": "user", "content": "Hello"}]
result = await client.chat_completion(messages)
# Verify the call was made correctly
mock_client.chat.completions.create.assert_called_once()
call_args = mock_client.chat.completions.create.call_args[1]
assert call_args["model"] == "gpt-4o"
assert call_args["messages"] == messages
assert call_args["temperature"] == 0.7
assert call_args["max_tokens"] == 1000
# Verify response format
assert "choices" in result
assert len(result["choices"]) == 1
assert result["choices"][0]["message"]["content"] == "Test response"
def test_openai_tool_formatting(self):
"""Test OpenAI tool formatting"""
client = OpenAIClient("gpt-4o", "test-key")
tools = [
{
"name": "add",
"description": "Add two numbers",
"inputSchema": {
"type": "object",
"properties": {
"a": {"type": "number"},
"b": {"type": "number"}
},
"required": ["a", "b"]
}
}
]
formatted = client._format_tools_for_provider(tools)
assert len(formatted) == 1
assert formatted[0]["type"] == "function"
assert formatted[0]["function"]["name"] == "add"
assert formatted[0]["function"]["description"] == "Add two numbers"
assert "parameters" in formatted[0]["function"]
class TestAIClientFactory:
"""Test AI client factory"""
def test_create_openai_client(self):
"""Test creating OpenAI client via factory"""
client = AIClientFactory.create_openai_client(
model_name="gpt-4o",
api_key="test-key"
)
assert isinstance(client, OpenAIClient)
assert client.model_name == "gpt-4o"
def test_create_client_by_provider(self):
"""Test creating client by provider name"""
client = AIClientFactory.create_client(
provider="openai",
model_name="gpt-4o",
api_key="test-key"
)
assert isinstance(client, OpenAIClient)
assert client.model_name == "gpt-4o"
def test_invalid_provider(self):
"""Test creating client with invalid provider"""
with pytest.raises(ValueError, match="Unsupported AI provider"):
AIClientFactory.create_client("invalid", "model", "key")
@patch('mcp_template.llm_client.client_factory.CONFIG_AVAILABLE', True)
@patch('mcp_template.llm_client.client_factory.Config')
def test_missing_api_key(self, mock_config):
"""Test creating client without API key"""
# Mock config to return None for API key
mock_config.OPENAI_API_KEY = None
with pytest.raises(ValueError, match="API key not provided and could not be loaded from config for provider: openai"):
AIClientFactory.create_client("openai", "gpt-4o")
def test_available_providers(self):
"""Test getting available providers"""
providers = AIClientFactory.get_available_providers()
assert "openai" in providers
assert "claude" in providers
assert "grok" in providers
def test_validate_provider(self):
"""Test provider validation"""
assert AIClientFactory.validate_provider("openai") is True
assert AIClientFactory.validate_provider("claude") is True
assert AIClientFactory.validate_provider("invalid") is False
+213
View File
@@ -0,0 +1,213 @@
"""
Unit tests for core MCP types
"""
import pytest
from mcp_template.core.types import (
TransportType,
MCPTool,
MCPResource,
MCPPrompt,
MCPServerConfig
)
class TestTransportType:
"""Test TransportType enum"""
def test_transport_type_values(self):
"""Test transport type enum values"""
assert TransportType.SSE.value == "sse"
assert TransportType.STDIO.value == "stdio"
def test_transport_type_from_string(self):
"""Test creating transport type from string"""
assert TransportType("sse") == TransportType.SSE
assert TransportType("stdio") == TransportType.STDIO
assert TransportType("SSE") == TransportType.SSE
class TestMCPTool:
"""Test MCPTool dataclass"""
def test_tool_creation(self):
"""Test creating a valid tool"""
async def sample_handler(a: int, b: int) -> int:
return a + b
tool = MCPTool(
name="add",
description="Add two numbers",
input_schema={
"type": "object",
"properties": {
"a": {"type": "number"},
"b": {"type": "number"}
},
"required": ["a", "b"]
},
handler=sample_handler
)
assert tool.name == "add"
assert tool.description == "Add two numbers"
assert "a" in tool.input_schema["properties"]
def test_tool_validation(self):
"""Test tool validation"""
async def handler():
pass
# Valid tool
tool = MCPTool(
name="test",
description="Test tool",
input_schema={},
handler=handler
)
assert tool.name == "test"
# Invalid tool - empty name
with pytest.raises(ValueError):
MCPTool(
name="",
description="Test tool",
input_schema={},
handler=handler
)
# Invalid tool - empty description
with pytest.raises(ValueError):
MCPTool(
name="test",
description="",
input_schema={},
handler=handler
)
class TestMCPResource:
"""Test MCPResource dataclass"""
def test_resource_creation(self):
"""Test creating a valid resource"""
resource = MCPResource(
uri="file:///test.txt",
name="Test File",
description="A test file",
mime_type="text/plain",
content="Hello, world!"
)
assert resource.uri == "file:///test.txt"
assert resource.name == "Test File"
assert resource.content == "Hello, world!"
def test_resource_validation(self):
"""Test resource validation"""
# Valid resource
resource = MCPResource(
uri="file:///test.txt",
name="Test",
description="Test resource",
mime_type="text/plain",
content=b"test"
)
assert resource.uri == "file:///test.txt"
# Invalid resource - empty URI
with pytest.raises(ValueError):
MCPResource(
uri="",
name="Test",
description="Test resource",
mime_type="text/plain",
content="test"
)
# Invalid resource - empty name
with pytest.raises(ValueError):
MCPResource(
uri="file:///test.txt",
name="",
description="Test resource",
mime_type="text/plain",
content="test"
)
class TestMCPPrompt:
"""Test MCPPrompt dataclass"""
def test_prompt_creation(self):
"""Test creating a valid prompt"""
prompt = MCPPrompt(
name="greeting",
description="A greeting prompt",
template="Hello, {name}! Welcome to {place}.",
arguments={
"name": {"type": "string", "description": "Person's name"},
"place": {"type": "string", "description": "Place name"}
}
)
assert prompt.name == "greeting"
assert "name" in prompt.template
assert prompt.arguments is not None
def test_prompt_validation(self):
"""Test prompt validation"""
# Valid prompt
prompt = MCPPrompt(
name="test",
description="Test prompt",
template="Hello, world!"
)
assert prompt.name == "test"
# Invalid prompt - empty name
with pytest.raises(ValueError):
MCPPrompt(
name="",
description="Test prompt",
template="Hello, world!"
)
# Invalid prompt - empty template
with pytest.raises(ValueError):
MCPPrompt(
name="test",
description="Test prompt",
template=""
)
class TestMCPServerConfig:
"""Test MCPServerConfig dataclass"""
def test_server_config_creation(self):
"""Test creating a valid server config"""
config = MCPServerConfig(
name="Test Server",
version="1.0.0",
transport=TransportType.SSE,
host="localhost",
port=8080,
stateless_http=True
)
assert config.name == "Test Server"
assert config.transport == TransportType.SSE
assert config.port == 8080
assert config.tools == []
assert config.resources == []
assert config.prompts == []
def test_server_config_defaults(self):
"""Test server config default values"""
config = MCPServerConfig(name="Test Server")
assert config.version == "1.0.0"
assert config.transport == TransportType.STDIO
assert config.host == "0.0.0.0"
assert config.port == 8050
assert config.stateless_http is True
+495
View File
@@ -0,0 +1,495 @@
"""
Unit tests for MCP server components
"""
import pytest
import asyncio
import tempfile
import os
from unittest.mock import Mock, AsyncMock, patch, MagicMock
from pathlib import Path
from mcp_template.core.types import MCPTool, MCPResource, MCPPrompt, MCPServerConfig, TransportType
# FastMCPServer not available, using ModularMCPServer for testing
from mcp_template.server.server_factory import MCPServerFactory
from mcp_template.server.modular_server import ModularMCPServer
from mcp_template.server.tools.tool_registry import ServerToolRegistry
from mcp_template.server.prompts.prompt_registry import ServerPromptRegistry
from mcp_template.server.resources.resource_registry import ServerResourceRegistry
from mcp_template.server.tools.base_tool import BaseServerTool, ContextAwareTool
from mcp_template.server.prompts.base_prompt import BaseServerPrompt
from mcp_template.server.resources.base_resource import BaseServerResource
class TestModularMCPServerBasic:
"""Test basic ModularMCP server functionality (replacing FastMCPServer tests)"""
@pytest.fixture
def sample_tool(self):
"""Create a sample tool for testing"""
async def add_handler(a: int, b: int) -> int:
return a + b
return MCPTool(
name="add",
description="Add two numbers",
input_schema={
"type": "object",
"properties": {
"a": {"type": "number"},
"b": {"type": "number"}
},
"required": ["a", "b"]
},
handler=add_handler
)
@pytest.fixture
def server_config(self):
"""Create a sample server config"""
return MCPServerConfig(
name="Test Server",
transport=TransportType.STDIO,
host="localhost",
port=8050
)
@pytest.mark.asyncio
async def test_server_initialization(self, server_config, sample_tool):
"""Test server initialization"""
# Use factory to create server
server = MCPServerFactory.create_server(
name="Test Server",
transport="stdio"
)
# Server should not be initialized yet
assert not server._initialized
# Initialize server
await server.initialize()
# Server should be initialized
assert server._initialized
@pytest.mark.asyncio
async def test_server_info(self, server_config):
"""Test getting server information"""
server = MCPServerFactory.create_server(
name="Test Server",
transport="stdio"
)
# Test server info before initialization
info = server.get_server_info()
assert info["name"] == "Test Server"
assert info["host"] == "0.0.0.0"
assert info["port"] == 8050
def test_server_run_method(self, server_config):
"""Test server run method"""
server = MCPServerFactory.create_server(
name="Test Server",
transport="stdio"
)
# Mock the run method to avoid actually starting the server
with patch.object(server.mcp, 'run') as mock_run:
server.run("stdio")
mock_run.assert_called_once_with(transport="stdio")
class TestMCPServerFactory:
"""Test MCP server factory"""
def test_create_server_basic(self):
"""Test creating a basic server"""
server = MCPServerFactory.create_server(
name="Test Server",
transport="stdio"
)
assert isinstance(server, ModularMCPServer)
assert server.name == "Test Server"
assert server.host == "0.0.0.0"
assert server.port == 8050
def test_create_calculator_server(self):
"""Test creating a calculator server"""
server = MCPServerFactory.create_basic_calculator_server(
name="Calculator Server",
transport="sse",
port=8080
)
assert isinstance(server, ModularMCPServer)
assert server.name == "Calculator Server"
assert server.host == "0.0.0.0"
assert server.port == 8080
# For ModularMCPServer, we can't easily test tool registration without mocking
# The tools would be registered through the registry system
assert server.tool_registry is not None
def test_create_knowledge_base_server(self):
"""Test creating a knowledge base server"""
kb_data = {
"policy": "Company vacation policy",
"faq": "Frequently asked questions"
}
server = MCPServerFactory.create_knowledge_base_server(
name="KB Server",
kb_data=kb_data
)
assert isinstance(server, ModularMCPServer)
assert server.name == "KB Server"
# For ModularMCPServer, we can't easily test tool registration without mocking
# The tools would be registered through the registry system
assert server.tool_registry is not None
class TestModularMCPServer:
"""Test ModularMCP server implementation"""
@pytest.fixture
def modular_server(self):
"""Create a modular server for testing"""
return ModularMCPServer(
name="Test Modular Server",
host="localhost",
port=8051
)
@pytest.mark.asyncio
async def test_modular_server_initialization(self, modular_server):
"""Test modular server initialization"""
assert modular_server.name == "Test Modular Server"
assert modular_server.host == "localhost"
assert modular_server.port == 8051
assert not modular_server._initialized
assert modular_server.tool_registry is not None
assert modular_server.prompt_registry is not None
assert modular_server.resource_registry is not None
@pytest.mark.asyncio
async def test_modular_server_initialize(self, modular_server):
"""Test modular server initialization process"""
# Mock the registry methods to avoid actual file discovery
with patch.object(modular_server.tool_registry, 'register_tools_with_server') as mock_tools, \
patch.object(modular_server.tool_registry, 'get_all_tools', return_value=[]), \
patch.object(modular_server.prompt_registry, 'register_prompts_with_server') as mock_prompts, \
patch.object(modular_server.prompt_registry, 'get_all_prompts', return_value=[]), \
patch.object(modular_server.resource_registry, 'register_resources_with_server') as mock_resources, \
patch.object(modular_server.resource_registry, 'get_all_resources', return_value=[]):
await modular_server.initialize()
assert modular_server._initialized
mock_tools.assert_called_once_with(modular_server.mcp)
mock_prompts.assert_called_once_with(modular_server.mcp)
mock_resources.assert_called_once_with(modular_server.mcp)
@pytest.mark.asyncio
async def test_modular_server_double_initialize(self, modular_server):
"""Test that double initialization doesn't cause issues"""
with patch.object(modular_server.tool_registry, 'register_tools_with_server') as mock_tools, \
patch.object(modular_server.tool_registry, 'get_all_tools', return_value=[]), \
patch.object(modular_server.prompt_registry, 'register_prompts_with_server') as mock_prompts, \
patch.object(modular_server.prompt_registry, 'get_all_prompts', return_value=[]), \
patch.object(modular_server.resource_registry, 'register_resources_with_server') as mock_resources, \
patch.object(modular_server.resource_registry, 'get_all_resources', return_value=[]):
# Initialize twice
await modular_server.initialize()
await modular_server.initialize()
# Should only be called once due to _initialized check
mock_tools.assert_called_once()
mock_prompts.assert_called_once()
mock_resources.assert_called_once()
def test_modular_server_get_server_info(self, modular_server):
"""Test getting server information"""
# Mock the registry methods
with patch.object(modular_server.tool_registry, 'get_all_tools', return_value=[Mock(name="tool1")]), \
patch.object(modular_server.tool_registry, 'get_tool_names', return_value=["tool1"]), \
patch.object(modular_server.prompt_registry, 'get_all_prompts', return_value=[Mock(name="prompt1")]), \
patch.object(modular_server.prompt_registry, 'get_prompt_names', return_value=["prompt1"]), \
patch.object(modular_server.resource_registry, 'get_all_resources', return_value=[Mock(uri="resource1")]), \
patch.object(modular_server.resource_registry, 'get_resource_uris', return_value=["resource1"]):
info = modular_server.get_server_info()
assert info["name"] == "Test Modular Server"
assert info["host"] == "localhost"
assert info["port"] == 8051
assert info["tools"]["count"] == 1
assert info["tools"]["names"] == ["tool1"]
assert info["prompts"]["count"] == 1
assert info["prompts"]["names"] == ["prompt1"]
assert info["resources"]["count"] == 1
assert info["resources"]["uris"] == ["resource1"]
def test_modular_server_run_without_initialize(self, modular_server):
"""Test that run method initializes server if not already initialized"""
with patch.object(modular_server, 'initialize') as mock_init, \
patch.object(modular_server.mcp, 'run') as mock_run:
modular_server.run("stdio")
mock_init.assert_called_once()
mock_run.assert_called_once_with(transport="stdio")
def test_modular_server_run_already_initialized(self, modular_server):
"""Test that run method doesn't reinitialize if already initialized"""
modular_server._initialized = True
with patch.object(modular_server, 'initialize') as mock_init, \
patch.object(modular_server.mcp, 'run') as mock_run:
modular_server.run("sse")
mock_init.assert_not_called()
mock_run.assert_called_once_with(transport="sse")
def test_modular_server_custom_directories(self):
"""Test modular server with custom directories"""
server = ModularMCPServer(
name="Custom Server",
tools_directory="/custom/tools",
prompts_directory="/custom/prompts",
resources_directory="/custom/resources"
)
assert server.tool_registry.directory == "/custom/tools"
assert server.prompt_registry.directory == "/custom/prompts"
assert server.resource_registry.directory == "/custom/resources"
class TestServerRegistries:
"""Test server registry classes"""
@pytest.fixture
def mock_tool(self):
"""Create a mock tool for testing"""
tool = Mock()
tool.name = "test_tool"
tool.description = "A test tool"
tool.input_schema = {"type": "object"}
return tool
@pytest.fixture
def mock_prompt(self):
"""Create a mock prompt for testing"""
prompt = Mock()
prompt.name = "test_prompt"
prompt.description = "A test prompt"
prompt.arguments = {"name": {"type": "string"}}
return prompt
@pytest.fixture
def mock_resource(self):
"""Create a mock resource for testing"""
resource = Mock()
resource.uri = "test://resource"
resource.name = "Test Resource"
resource.description = "A test resource"
return resource
def test_tool_registry_initialization(self):
"""Test tool registry initialization"""
registry = ServerToolRegistry("/test/tools")
assert registry.directory == "/test/tools"
assert registry._tools == {}
def test_tool_registry_register_tool(self, mock_tool):
"""Test tool registration"""
registry = ServerToolRegistry()
registry.register_tool(mock_tool)
assert "test_tool" in registry._tools
assert registry._tools["test_tool"] == mock_tool
def test_tool_registry_get_all_tools(self, mock_tool):
"""Test getting all tools"""
registry = ServerToolRegistry()
registry.register_tool(mock_tool)
tools = registry.get_all_tools()
assert len(tools) == 1
assert tools[0] == mock_tool
def test_tool_registry_get_tool_names(self, mock_tool):
"""Test getting tool names"""
registry = ServerToolRegistry()
registry.register_tool(mock_tool)
names = registry.get_tool_names()
assert names == ["test_tool"]
def test_prompt_registry_initialization(self):
"""Test prompt registry initialization"""
registry = ServerPromptRegistry("/test/prompts")
assert registry.directory == "/test/prompts"
assert registry._prompts == {}
def test_prompt_registry_register_prompt(self, mock_prompt):
"""Test prompt registration"""
registry = ServerPromptRegistry()
registry.register_prompt(mock_prompt)
assert "test_prompt" in registry._prompts
assert registry._prompts["test_prompt"] == mock_prompt
def test_prompt_registry_get_all_prompts(self, mock_prompt):
"""Test getting all prompts"""
registry = ServerPromptRegistry()
registry.register_prompt(mock_prompt)
prompts = registry.get_all_prompts()
assert len(prompts) == 1
assert prompts[0] == mock_prompt
def test_prompt_registry_get_prompt_names(self, mock_prompt):
"""Test getting prompt names"""
registry = ServerPromptRegistry()
registry.register_prompt(mock_prompt)
names = registry.get_prompt_names()
assert names == ["test_prompt"]
def test_resource_registry_initialization(self):
"""Test resource registry initialization"""
registry = ServerResourceRegistry("/test/resources")
assert registry.directory == "/test/resources"
assert registry._resources == {}
def test_resource_registry_register_resource(self, mock_resource):
"""Test resource registration"""
registry = ServerResourceRegistry()
registry.register_resource(mock_resource)
assert "test://resource" in registry._resources
assert registry._resources["test://resource"] == mock_resource
def test_resource_registry_get_all_resources(self, mock_resource):
"""Test getting all resources"""
registry = ServerResourceRegistry()
registry.register_resource(mock_resource)
resources = registry.get_all_resources()
assert len(resources) == 1
assert resources[0] == mock_resource
def test_resource_registry_get_resource_uris(self, mock_resource):
"""Test getting resource URIs"""
registry = ServerResourceRegistry()
registry.register_resource(mock_resource)
uris = registry.get_resource_uris()
assert uris == ["test://resource"]
class TestBaseServerTools:
"""Test base server tool classes"""
def test_base_server_tool_initialization(self):
"""Test BaseServerTool initialization"""
class TestTool(BaseServerTool):
async def execute(self, **kwargs):
return "test_result"
tool = TestTool(
name="test_tool",
description="A test tool",
input_schema={"type": "object"}
)
assert tool.name == "test_tool"
assert tool.description == "A test tool"
assert tool.input_schema == {"type": "object"}
def test_base_server_tool_get_tool_definition(self):
"""Test getting tool definition"""
class TestTool(BaseServerTool):
async def execute(self, **kwargs):
return "test_result"
tool = TestTool(
name="test_tool",
description="A test tool",
input_schema={"type": "object"}
)
definition = tool.get_tool_definition()
expected = {
"name": "test_tool",
"description": "A test tool",
"input_schema": {"type": "object"}
}
assert definition == expected
def test_base_server_tool_create_fastmcp_tool(self):
"""Test creating FastMCP tool wrapper"""
class TestTool(BaseServerTool):
async def execute(self, **kwargs):
return "test_result"
tool = TestTool(
name="test_tool",
description="A test tool",
input_schema={"type": "object"}
)
# Mock MCP server
mock_mcp = Mock()
mock_mcp.tool.return_value = lambda func: func
wrapper = tool.create_fastmcp_tool(mock_mcp)
assert wrapper.__name__ == "test_tool"
assert wrapper.__doc__ == "A test tool"
def test_context_aware_tool_initialization(self):
"""Test ContextAwareTool initialization"""
class TestContextTool(ContextAwareTool):
async def execute_with_context(self, ctx, **kwargs):
return "context_result"
tool = TestContextTool(
name="context_tool",
description="A context-aware tool",
input_schema={"type": "object"}
)
assert tool.name == "context_tool"
assert tool.description == "A context-aware tool"
assert tool.input_schema == {"type": "object"}
@pytest.mark.asyncio
async def test_context_aware_tool_execute_with_context(self):
"""Test ContextAwareTool execution with context"""
from unittest.mock import Mock
class TestContextTool(ContextAwareTool):
async def execute_with_context(self, ctx, **kwargs):
return f"Context: {self.context}, Args: {kwargs}"
tool = TestContextTool(
name="context_tool",
description="A context-aware tool",
input_schema={"type": "object"}
)
# Set context
tool.context = {"user_id": "123"}
# Mock the context
mock_ctx = Mock()
result = await tool.execute_with_context(mock_ctx, test_arg="value")
assert "Context: {'user_id': '123'}" in result
assert "Args: {'test_arg': 'value'}" in result