496 lines
18 KiB
Python
496 lines
18 KiB
Python
"""
|
|
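Unit tests for MCP server components: the server factory, ModularMCPServer, the tool/prompt/resource registries, and the base server tool classes.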
|
|
"""
|
|
import pytest
|
|
import asyncio
|
|
import tempfile
|
|
import os
|
|
from unittest.mock import Mock, AsyncMock, patch, MagicMock
|
|
from pathlib import Path
|
|
from mcp_template.core.types import MCPTool, MCPResource, MCPPrompt, MCPServerConfig, TransportType
|
|
# FastMCPServer is not available, so ModularMCPServer is tested in its place
|
|
from mcp_template.server.server_factory import MCPServerFactory
|
|
from mcp_template.server.modular_server import ModularMCPServer
|
|
from mcp_template.server.tools.tool_registry import ServerToolRegistry
|
|
from mcp_template.server.prompts.prompt_registry import ServerPromptRegistry
|
|
from mcp_template.server.resources.resource_registry import ServerResourceRegistry
|
|
from mcp_template.server.tools.base_tool import BaseServerTool, ContextAwareTool
|
|
from mcp_template.server.prompts.base_prompt import BaseServerPrompt
|
|
from mcp_template.server.resources.base_resource import BaseServerResource
|
|
|
|
|
|
class TestModularMCPServerBasic:
    """Basic ModularMCPServer tests (these replace the old FastMCPServer tests).

    Every server here is built with ``MCPServerFactory.create_server`` passing
    only ``name`` and ``transport="stdio"``; the host/port assertions therefore
    pin the factory's defaults (``"0.0.0.0"`` and ``8050``).
    """

    @pytest.fixture
    def sample_tool(self):
        """Return an ``MCPTool`` named ``add`` whose handler adds two numbers.

        NOTE(review): no test in this class currently requests this fixture.
        """

        async def add_handler(a: int, b: int) -> int:
            return a + b

        return MCPTool(
            name="add",
            description="Add two numbers",
            input_schema={
                "type": "object",
                "properties": {
                    "a": {"type": "number"},
                    "b": {"type": "number"},
                },
                "required": ["a", "b"],
            },
            handler=add_handler,
        )

    @pytest.fixture
    def server_config(self):
        """Return a stdio ``MCPServerConfig`` for localhost:8050.

        NOTE(review): no test in this class currently requests this fixture.
        """
        return MCPServerConfig(
            name="Test Server",
            transport=TransportType.STDIO,
            host="localhost",
            port=8050,
        )

    @pytest.mark.asyncio
    async def test_server_initialization(self):
        """``initialize()`` flips the private ``_initialized`` flag."""
        server = MCPServerFactory.create_server(name="Test Server", transport="stdio")

        # A freshly created server has not been initialized yet.
        assert not server._initialized

        await server.initialize()

        assert server._initialized

    def test_server_info(self):
        """``get_server_info()`` reports the name and default host/port before init.

        Nothing is awaited here, so this is a plain synchronous test.
        """
        server = MCPServerFactory.create_server(name="Test Server", transport="stdio")

        info = server.get_server_info()
        assert info["name"] == "Test Server"
        assert info["host"] == "0.0.0.0"
        assert info["port"] == 8050

    def test_server_run_method(self):
        """``run("stdio")`` forwards the transport to ``server.mcp.run``."""
        server = MCPServerFactory.create_server(name="Test Server", transport="stdio")

        # Patch the underlying run so no transport is actually started.
        with patch.object(server.mcp, 'run') as mock_run:
            server.run("stdio")
            mock_run.assert_called_once_with(transport="stdio")
|
|
|
|
|
|
|
|
class TestMCPServerFactory:
    """Tests for the MCPServerFactory construction helpers."""

    @staticmethod
    def _check_identity(srv, expected_name):
        """Assert *srv* is a ModularMCPServer carrying *expected_name*."""
        assert isinstance(srv, ModularMCPServer)
        assert srv.name == expected_name

    def test_create_server_basic(self):
        """``create_server`` yields a ModularMCPServer with default host/port."""
        srv = MCPServerFactory.create_server(name="Test Server", transport="stdio")

        self._check_identity(srv, "Test Server")
        assert srv.host == "0.0.0.0"
        assert srv.port == 8050

    def test_create_calculator_server(self):
        """The calculator preset honours an explicit port and keeps the default host."""
        srv = MCPServerFactory.create_basic_calculator_server(
            name="Calculator Server", transport="sse", port=8080
        )

        self._check_identity(srv, "Calculator Server")
        assert srv.host == "0.0.0.0"
        assert srv.port == 8080

        # Individual tool registration goes through the registry system and is
        # not inspected here without further mocking; only its presence is checked.
        assert srv.tool_registry is not None

    def test_create_knowledge_base_server(self):
        """The knowledge-base preset accepts ``kb_data`` and exposes a tool registry."""
        knowledge = {
            "policy": "Company vacation policy",
            "faq": "Frequently asked questions",
        }

        srv = MCPServerFactory.create_knowledge_base_server(name="KB Server", kb_data=knowledge)

        self._check_identity(srv, "KB Server")

        # As above: only the registry's presence is verified, not its contents.
        assert srv.tool_registry is not None
|
|
|
|
|
|
class TestModularMCPServer:
    """Tests for ModularMCPServer construction, initialization and ``run()``."""

    @pytest.fixture
    def modular_server(self):
        """Return a ModularMCPServer on localhost:8051 that is not yet initialized."""
        return ModularMCPServer(
            name="Test Modular Server",
            host="localhost",
            port=8051,
        )

    def test_modular_server_initialization(self, modular_server):
        """The constructor stores name/host/port and creates all three registries.

        Nothing is awaited here, so the test is synchronous.
        """
        assert modular_server.name == "Test Modular Server"
        assert modular_server.host == "localhost"
        assert modular_server.port == 8051
        assert not modular_server._initialized
        assert modular_server.tool_registry is not None
        assert modular_server.prompt_registry is not None
        assert modular_server.resource_registry is not None

    @pytest.mark.asyncio
    async def test_modular_server_initialize(self, modular_server):
        """``initialize()`` hands ``modular_server.mcp`` to each registry exactly once."""
        # Mock the registry methods to avoid actual file discovery
        with patch.object(modular_server.tool_registry, 'register_tools_with_server') as mock_tools, \
                patch.object(modular_server.tool_registry, 'get_all_tools', return_value=[]), \
                patch.object(modular_server.prompt_registry, 'register_prompts_with_server') as mock_prompts, \
                patch.object(modular_server.prompt_registry, 'get_all_prompts', return_value=[]), \
                patch.object(modular_server.resource_registry, 'register_resources_with_server') as mock_resources, \
                patch.object(modular_server.resource_registry, 'get_all_resources', return_value=[]):

            await modular_server.initialize()

            assert modular_server._initialized
            mock_tools.assert_called_once_with(modular_server.mcp)
            mock_prompts.assert_called_once_with(modular_server.mcp)
            mock_resources.assert_called_once_with(modular_server.mcp)

    @pytest.mark.asyncio
    async def test_modular_server_double_initialize(self, modular_server):
        """A second ``initialize()`` call does not register anything again."""
        with patch.object(modular_server.tool_registry, 'register_tools_with_server') as mock_tools, \
                patch.object(modular_server.tool_registry, 'get_all_tools', return_value=[]), \
                patch.object(modular_server.prompt_registry, 'register_prompts_with_server') as mock_prompts, \
                patch.object(modular_server.prompt_registry, 'get_all_prompts', return_value=[]), \
                patch.object(modular_server.resource_registry, 'register_resources_with_server') as mock_resources, \
                patch.object(modular_server.resource_registry, 'get_all_resources', return_value=[]):

            await modular_server.initialize()
            await modular_server.initialize()

            # Should only be called once due to the _initialized check.
            mock_tools.assert_called_once()
            mock_prompts.assert_called_once()
            mock_resources.assert_called_once()

    def test_modular_server_get_server_info(self, modular_server):
        """``get_server_info()`` reports per-registry counts and names/URIs."""
        # ``Mock(name=...)`` only sets the mock's repr name; it does NOT create
        # a ``.name`` attribute. Assign ``.name`` explicitly so the stand-ins
        # really carry the names the assertions below refer to.
        tool = Mock()
        tool.name = "tool1"
        prompt = Mock()
        prompt.name = "prompt1"
        resource = Mock(uri="resource1")  # ``uri`` is an ordinary attribute kwarg

        with patch.object(modular_server.tool_registry, 'get_all_tools', return_value=[tool]), \
                patch.object(modular_server.tool_registry, 'get_tool_names', return_value=["tool1"]), \
                patch.object(modular_server.prompt_registry, 'get_all_prompts', return_value=[prompt]), \
                patch.object(modular_server.prompt_registry, 'get_prompt_names', return_value=["prompt1"]), \
                patch.object(modular_server.resource_registry, 'get_all_resources', return_value=[resource]), \
                patch.object(modular_server.resource_registry, 'get_resource_uris', return_value=["resource1"]):

            info = modular_server.get_server_info()

            assert info["name"] == "Test Modular Server"
            assert info["host"] == "localhost"
            assert info["port"] == 8051
            assert info["tools"]["count"] == 1
            assert info["tools"]["names"] == ["tool1"]
            assert info["prompts"]["count"] == 1
            assert info["prompts"]["names"] == ["prompt1"]
            assert info["resources"]["count"] == 1
            assert info["resources"]["uris"] == ["resource1"]

    def test_modular_server_run_without_initialize(self, modular_server):
        """``run()`` calls ``initialize()`` first when the server is not initialized."""
        with patch.object(modular_server, 'initialize') as mock_init, \
                patch.object(modular_server.mcp, 'run') as mock_run:

            modular_server.run("stdio")

            mock_init.assert_called_once()
            mock_run.assert_called_once_with(transport="stdio")

    def test_modular_server_run_already_initialized(self, modular_server):
        """``run()`` skips ``initialize()`` when ``_initialized`` is already set."""
        modular_server._initialized = True

        with patch.object(modular_server, 'initialize') as mock_init, \
                patch.object(modular_server.mcp, 'run') as mock_run:

            modular_server.run("sse")

            mock_init.assert_not_called()
            mock_run.assert_called_once_with(transport="sse")

    def test_modular_server_custom_directories(self):
        """Custom tools/prompts/resources directories reach the matching registries."""
        server = ModularMCPServer(
            name="Custom Server",
            tools_directory="/custom/tools",
            prompts_directory="/custom/prompts",
            resources_directory="/custom/resources",
        )

        assert server.tool_registry.directory == "/custom/tools"
        assert server.prompt_registry.directory == "/custom/prompts"
        assert server.resource_registry.directory == "/custom/resources"
|
|
|
|
|
|
class TestServerRegistries:
    """Tests for the server-side tool, prompt and resource registries."""

    @pytest.fixture
    def mock_tool(self):
        """A Mock standing in for a tool named ``test_tool``."""
        stub = Mock(description="A test tool", input_schema={"type": "object"})
        # ``name`` cannot be passed to Mock() as an attribute, so assign it afterwards.
        stub.name = "test_tool"
        return stub

    @pytest.fixture
    def mock_prompt(self):
        """A Mock standing in for a prompt named ``test_prompt``."""
        stub = Mock(description="A test prompt", arguments={"name": {"type": "string"}})
        stub.name = "test_prompt"
        return stub

    @pytest.fixture
    def mock_resource(self):
        """A Mock standing in for a resource at ``test://resource``."""
        stub = Mock(uri="test://resource", description="A test resource")
        stub.name = "Test Resource"
        return stub

    # ---- tool registry ----------------------------------------------------

    def test_tool_registry_initialization(self):
        """A new tool registry remembers its directory and starts empty."""
        reg = ServerToolRegistry("/test/tools")
        assert reg.directory == "/test/tools"
        assert reg._tools == {}

    def test_tool_registry_register_tool(self, mock_tool):
        """Registered tools are stored under their name."""
        reg = ServerToolRegistry()
        reg.register_tool(mock_tool)

        stored = reg._tools
        assert "test_tool" in stored
        assert stored["test_tool"] == mock_tool

    def test_tool_registry_get_all_tools(self, mock_tool):
        """``get_all_tools`` returns the single registered tool."""
        reg = ServerToolRegistry()
        reg.register_tool(mock_tool)

        listed = reg.get_all_tools()
        assert len(listed) == 1
        assert listed[0] == mock_tool

    def test_tool_registry_get_tool_names(self, mock_tool):
        """``get_tool_names`` lists the registered tool's name."""
        reg = ServerToolRegistry()
        reg.register_tool(mock_tool)

        assert reg.get_tool_names() == ["test_tool"]

    # ---- prompt registry --------------------------------------------------

    def test_prompt_registry_initialization(self):
        """A new prompt registry remembers its directory and starts empty."""
        reg = ServerPromptRegistry("/test/prompts")
        assert reg.directory == "/test/prompts"
        assert reg._prompts == {}

    def test_prompt_registry_register_prompt(self, mock_prompt):
        """Registered prompts are stored under their name."""
        reg = ServerPromptRegistry()
        reg.register_prompt(mock_prompt)

        stored = reg._prompts
        assert "test_prompt" in stored
        assert stored["test_prompt"] == mock_prompt

    def test_prompt_registry_get_all_prompts(self, mock_prompt):
        """``get_all_prompts`` returns the single registered prompt."""
        reg = ServerPromptRegistry()
        reg.register_prompt(mock_prompt)

        listed = reg.get_all_prompts()
        assert len(listed) == 1
        assert listed[0] == mock_prompt

    def test_prompt_registry_get_prompt_names(self, mock_prompt):
        """``get_prompt_names`` lists the registered prompt's name."""
        reg = ServerPromptRegistry()
        reg.register_prompt(mock_prompt)

        assert reg.get_prompt_names() == ["test_prompt"]

    # ---- resource registry ------------------------------------------------

    def test_resource_registry_initialization(self):
        """A new resource registry remembers its directory and starts empty."""
        reg = ServerResourceRegistry("/test/resources")
        assert reg.directory == "/test/resources"
        assert reg._resources == {}

    def test_resource_registry_register_resource(self, mock_resource):
        """Registered resources are stored under their URI."""
        reg = ServerResourceRegistry()
        reg.register_resource(mock_resource)

        stored = reg._resources
        assert "test://resource" in stored
        assert stored["test://resource"] == mock_resource

    def test_resource_registry_get_all_resources(self, mock_resource):
        """``get_all_resources`` returns the single registered resource."""
        reg = ServerResourceRegistry()
        reg.register_resource(mock_resource)

        listed = reg.get_all_resources()
        assert len(listed) == 1
        assert listed[0] == mock_resource

    def test_resource_registry_get_resource_uris(self, mock_resource):
        """``get_resource_uris`` lists the registered resource's URI."""
        reg = ServerResourceRegistry()
        reg.register_resource(mock_resource)

        assert reg.get_resource_uris() == ["test://resource"]
|
|
|
|
|
|
class TestBaseServerTools:
    """Tests for the BaseServerTool and ContextAwareTool base classes."""

    @staticmethod
    def _make_tool():
        """Return a concrete BaseServerTool named ``test_tool``.

        BaseServerTool is presumably abstract (every test here subclasses it to
        supply ``execute``), so a minimal subclass returning ``"test_result"``
        is defined locally and instantiated with a fixed name, description and
        schema shared by the BaseServerTool tests below.
        """

        class _SampleTool(BaseServerTool):
            async def execute(self, **kwargs):
                return "test_result"

        return _SampleTool(
            name="test_tool",
            description="A test tool",
            input_schema={"type": "object"},
        )

    def test_base_server_tool_initialization(self):
        """Constructor arguments are exposed as ``name``/``description``/``input_schema``."""
        tool = self._make_tool()

        assert tool.name == "test_tool"
        assert tool.description == "A test tool"
        assert tool.input_schema == {"type": "object"}

    def test_base_server_tool_get_tool_definition(self):
        """``get_tool_definition()`` returns exactly name, description and input_schema."""
        tool = self._make_tool()

        definition = tool.get_tool_definition()
        expected = {
            "name": "test_tool",
            "description": "A test tool",
            "input_schema": {"type": "object"},
        }

        assert definition == expected

    def test_base_server_tool_create_fastmcp_tool(self):
        """``create_fastmcp_tool()`` returns a wrapper named and documented after the tool."""
        tool = self._make_tool()

        # Stand-in MCP server whose ``tool(...)`` decorator returns the function unchanged.
        mock_mcp = Mock()
        mock_mcp.tool.return_value = lambda func: func

        wrapper = tool.create_fastmcp_tool(mock_mcp)

        assert wrapper.__name__ == "test_tool"
        assert wrapper.__doc__ == "A test tool"

    def test_context_aware_tool_initialization(self):
        """ContextAwareTool accepts the same constructor arguments as BaseServerTool."""

        class _ContextTool(ContextAwareTool):
            async def execute_with_context(self, ctx, **kwargs):
                return "context_result"

        tool = _ContextTool(
            name="context_tool",
            description="A context-aware tool",
            input_schema={"type": "object"},
        )

        assert tool.name == "context_tool"
        assert tool.description == "A context-aware tool"
        assert tool.input_schema == {"type": "object"}

    @pytest.mark.asyncio
    async def test_context_aware_tool_execute_with_context(self):
        """A subclass's ``execute_with_context`` sees ``self.context`` and its kwargs."""
        # ``Mock`` comes from the module-level unittest.mock import.

        class _ContextTool(ContextAwareTool):
            async def execute_with_context(self, ctx, **kwargs):
                return f"Context: {self.context}, Args: {kwargs}"

        tool = _ContextTool(
            name="context_tool",
            description="A context-aware tool",
            input_schema={"type": "object"},
        )

        tool.context = {"user_id": "123"}

        # The ctx argument is an opaque stand-in; the subclass never reads it.
        mock_ctx = Mock()

        result = await tool.execute_with_context(mock_ctx, test_arg="value")
        assert "Context: {'user_id': '123'}" in result
        assert "Args: {'test_arg': 'value'}" in result
|