"""End-to-end tests for complete MCP workflow"""

import asyncio
import json
import os
import tempfile
from pathlib import Path
from unittest.mock import Mock, AsyncMock, patch

import pytest

from mcp_template.config.config_manager import ConfigManager
from mcp_template.core.types import MCPTool, MCPServerConfig, TransportType
from mcp_template.llm_client.openai_client import OpenAIClient
from mcp_template.server.fastmcp_server import FastMCPServer


class TestEndToEnd:
    """End-to-end tests for complete MCP workflow.

    The tests build an in-process ``FastMCPServer`` that exposes a small
    knowledge base through two tools, drive it through ``OpenAIClient`` with
    a mocked OpenAI SDK client, and exercise ``ConfigManager`` against a
    temporary JSON file.
    """

    # NOTE(review): this is an ``async def`` fixture under plain
    # ``@pytest.fixture``; that presumably relies on pytest-asyncio running in
    # auto mode -- confirm against the project's pytest configuration.
    @pytest.fixture
    async def knowledge_base_server(self):
        """Create a knowledge base server for e2e testing.

        Returns:
            A ``FastMCPServer`` built from a config listing the
            ``get_knowledge_base`` and ``search_kb`` tools, after
            ``initialize()`` has been awaited.
        """
        kb_data = {
            "company_policy": "Our company offers 20 days of paid vacation annually.",
            "benefits": "We provide health insurance, 401k matching, and flexible work hours.",
            "faq": "Common questions about our policies and procedures.",
        }

        def format_entry(key, value):
            """Render one entry as a bold, title-cased heading plus its text."""
            return f"**{key.replace('_', ' ').title()}:**\n{value}"

        async def get_kb():
            """Return every entry, each followed by a blank line, under a header."""
            body = "".join(
                f"{format_entry(key, value)}\n\n" for key, value in kb_data.items()
            )
            return f"Knowledge Base:\n\n{body}"

        async def search_kb(query: str):
            """Return entries whose key or text contains ``query`` (case-insensitive)."""
            needle = query.lower()
            matches = [
                format_entry(key, value)
                for key, value in kb_data.items()
                if needle in key.lower() or needle in value.lower()
            ]
            if not matches:
                return f"No information found for: {query}"
            return "\n\n".join(matches)

        tools = [
            MCPTool(
                name="get_knowledge_base",
                description="Retrieve the entire knowledge base",
                input_schema={"type": "object", "properties": {}},
                handler=get_kb,
            ),
            MCPTool(
                name="search_kb",
                description="Search the knowledge base for specific information",
                input_schema={
                    "type": "object",
                    "properties": {
                        "query": {"type": "string", "description": "Search query"},
                    },
                    "required": ["query"],
                },
                handler=search_kb,
            ),
        ]

        config = MCPServerConfig(
            name="Knowledge Base Server",
            tools=tools,
        )

        server = FastMCPServer(config)
        await server.initialize()
        return server

    @pytest.fixture
    def mock_openai_response(self):
        """Mock OpenAI API response with tool calls, as a plain dict."""
        return {
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": "I'll search the knowledge base for vacation policy information.",
                        "tool_calls": [
                            {
                                "id": "call_123",
                                "function": {
                                    "name": "search_kb",
                                    "arguments": '{"query": "vacation"}',
                                },
                            },
                        ],
                    },
                },
            ],
        }

    # Patch the SDK class in the module OpenAIClient is imported from; a patch
    # on any other module path cannot affect the class under test.
    # NOTE(review): assumes that module binds ``AsyncOpenAI`` by name
    # (``from openai import AsyncOpenAI``) -- confirm.
    @patch("mcp_template.llm_client.openai_client.AsyncOpenAI")
    @pytest.mark.asyncio
    async def test_complete_mcp_workflow(self, mock_openai_class, knowledge_base_server, mock_openai_response):
        """Test complete MCP workflow from query to response."""
        # First completion: the model asks for the ``search_kb`` tool.
        # ``name`` is consumed by the Mock constructor (it labels the mock
        # itself), so the ``.name`` attribute must be assigned afterwards.
        function_call = Mock(arguments='{"query": "vacation"}')
        function_call.name = "search_kb"
        tool_call = Mock(id="call_123", function=function_call)

        mock_response = Mock()
        mock_response.choices = [
            Mock(
                message=Mock(
                    role="assistant",
                    content="I'll search the knowledge base for vacation policy information.",
                    tool_calls=[tool_call],
                )
            )
        ]

        # Second completion (after tool execution): plain text, no tool calls.
        # ``tool_calls`` is set explicitly because an unset Mock attribute is
        # a truthy child Mock, which would look like another tool request.
        mock_final_response = Mock()
        mock_final_response.choices = [
            Mock(
                message=Mock(
                    content="Based on the knowledge base, our company offers 20 days of paid vacation annually.",
                    tool_calls=None,
                )
            )
        ]

        mock_client = Mock()
        mock_client.chat.completions.create = AsyncMock(
            side_effect=[mock_response, mock_final_response]
        )
        mock_openai_class.return_value = mock_client

        # Create AI client
        ai_client = OpenAIClient("gpt-4o", "test-key")
        await ai_client.initialize()

        # Create mock MCP client that uses the real server
        class MockMCPClient:
            """Thin adapter forwarding MCP client calls to the in-process server."""

            def __init__(self, server):
                self.server = server

            async def call_tool(self, name, arguments):
                return await self.server.call_tool(name, arguments)

            async def list_tools(self):
                return await self.server.list_tools()

        mcp_client = MockMCPClient(knowledge_base_server)

        # Test the complete workflow
        query = "What's the company vacation policy?"
        response = await ai_client.process_with_tools(query, [], mcp_client)

        # Verify the workflow
        assert "vacation" in response.lower() or "20 days" in response

        # Verify OpenAI was called twice (initial + after tool execution)
        assert mock_client.chat.completions.create.call_count == 2

        # Verify the tool itself answers the query the model issued
        tool_result = await mcp_client.call_tool("search_kb", {"query": "vacation"})
        assert "vacation" in tool_result.lower()

    @pytest.mark.asyncio
    async def test_configuration_workflow(self, knowledge_base_server):
        """Test configuration loading and usage."""
        config_data = {
            "server": {
                "name": "Test Server",
                "port": 8080,
                "transport": "sse",
            },
            "client": {
                "provider": "openai",
                "model": "gpt-4o",
            },
            "api_keys": {
                "openai": "test-key",
            },
        }

        # The temporary directory removes the config file on exit, including
        # when an assertion below fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            config_file = Path(tmp_dir) / "config.json"
            config_file.write_text(json.dumps(config_data), encoding="utf-8")

            config_manager = ConfigManager(str(config_file))

            server_config = await config_manager.get_server_config()
            assert server_config["name"] == "Test Server"
            assert server_config["port"] == 8080

            client_config = await config_manager.get_client_config()
            assert client_config["provider"] == "openai"
            assert client_config["model"] == "gpt-4o"

            api_key = await config_manager.get_api_key("openai")
            assert api_key == "test-key"

    @pytest.mark.asyncio
    async def test_error_recovery(self, knowledge_base_server):
        """Test error recovery in MCP operations."""
        # Unknown tool: must raise ValueError mentioning "not found"
        # (case-insensitive).
        with pytest.raises(ValueError, match=r"(?i)not found"):
            await knowledge_base_server.call_tool("nonexistent_tool", {})

        # Valid tool, invalid arguments: some error must be raised. The
        # exception type is not pinned because the server's validation error
        # type is not visible from this file.
        with pytest.raises(Exception):
            await knowledge_base_server.call_tool("search_kb", {"invalid": "args"})

        # Verify server still works after errors
        result = await knowledge_base_server.call_tool("get_knowledge_base", {})
        assert "Knowledge Base" in result

    @pytest.mark.asyncio
    async def test_multiple_queries_workflow(self, knowledge_base_server):
        """Test multiple queries in sequence."""
        # (query, text that only the matching entry contains). Checking a
        # fragment of the entry -- not the query itself -- matters because the
        # not-found message echoes the query back.
        expectations = [
            ("vacation", "20 days of paid vacation"),
            ("benefits", "health insurance"),
            ("faq", "Common questions"),
        ]

        for query, fragment in expectations:
            # Direct tool call (simulating what AI would do)
            result = await knowledge_base_server.call_tool("search_kb", {"query": query})
            assert "No information found" not in result
            assert fragment in result

        # Search is a plain substring match, so a phrase that appears in no
        # key or value yields the not-found message.
        miss = await knowledge_base_server.call_tool("search_kb", {"query": "vacation policy"})
        assert "No information found for: vacation policy" in miss

        # Full dump: headings are rendered title-cased with spaces, not as
        # the raw dictionary keys.
        full_kb = await knowledge_base_server.call_tool("get_knowledge_base", {})
        for heading in ("Company Policy", "Benefits", "Faq"):
            assert f"**{heading}:**" in full_kb