initial mcp server setup
This commit is contained in:
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
Unit tests for AI client components
|
||||
"""
|
||||
from unittest import async_case
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from mcp_template.llm_client.base_client import BaseAIClient
|
||||
from mcp_template.llm_client.openai_client import OpenAIClient
|
||||
from mcp_template.llm_client.client_factory import AIClientFactory
|
||||
|
||||
|
||||
class TestBaseAIClient:
|
||||
"""Test base AI client functionality"""
|
||||
|
||||
def test_base_client_creation(self):
|
||||
"""Test creating a base client"""
|
||||
# Create a concrete implementation for testing
|
||||
class TestAIClient(BaseAIClient):
|
||||
async def chat_completion(self, messages, tools=None, **kwargs):
|
||||
return {"choices": [{"message": {"content": "test response"}}]}
|
||||
|
||||
async def _initialize_client(self):
|
||||
pass
|
||||
|
||||
client = TestAIClient(
|
||||
model_name="test-model",
|
||||
provider="test",
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
assert client.model_name == "test-model"
|
||||
assert client._api_key == "test-key"
|
||||
assert not client._initialized
|
||||
|
||||
def test_abstract_class_cannot_be_instantiated(self):
|
||||
"""Test that abstract BaseAIClient cannot be instantiated directly"""
|
||||
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
|
||||
BaseAIClient(
|
||||
model_name="test-model",
|
||||
provider="test",
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIClient:
|
||||
"""Test OpenAI client implementation"""
|
||||
|
||||
def test_openai_client_creation(self):
|
||||
"""Test creating an OpenAI client"""
|
||||
client = OpenAIClient(
|
||||
model_name="gpt-4o",
|
||||
api_key="test-key",
|
||||
temperature=0.5,
|
||||
max_tokens=500
|
||||
)
|
||||
|
||||
assert client.model_name == "gpt-4o"
|
||||
assert client._temperature == 0.5
|
||||
assert client._max_tokens == 500
|
||||
|
||||
def test_openai_client_defaults(self):
|
||||
"""Test OpenAI client default values"""
|
||||
client = OpenAIClient(
|
||||
model_name="gpt-4o",
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
assert client._temperature == 0.7
|
||||
assert client._max_tokens == 1000
|
||||
|
||||
@patch('mcp_template.llm_client.openai_client.AsyncOpenAI')
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_initialization(self, mock_openai_class):
|
||||
"""Test OpenAI client initialization"""
|
||||
mock_client = Mock()
|
||||
mock_openai_class.return_value = mock_client
|
||||
|
||||
client = OpenAIClient("gpt-4o", "test-key")
|
||||
await client._initialize_client()
|
||||
|
||||
mock_openai_class.assert_called_once_with(api_key="test-key")
|
||||
assert client._client == mock_client
|
||||
|
||||
@patch('mcp_template.llm_client.openai_client.AsyncOpenAI')
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_chat_completion(self, mock_openai_class):
|
||||
"""Test OpenAI chat completion"""
|
||||
# Mock the OpenAI client and response
|
||||
mock_client = Mock()
|
||||
mock_response = Mock()
|
||||
mock_choice = Mock()
|
||||
mock_message = Mock()
|
||||
|
||||
mock_message.role = "assistant"
|
||||
mock_message.content = "Test response"
|
||||
mock_message.tool_calls = None
|
||||
|
||||
mock_choice.message = mock_message
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||
mock_openai_class.return_value = mock_client
|
||||
|
||||
client = OpenAIClient("gpt-4o", "test-key")
|
||||
await client.initialize()
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await client.chat_completion(messages)
|
||||
|
||||
# Verify the call was made correctly
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
call_args = mock_client.chat.completions.create.call_args[1]
|
||||
|
||||
assert call_args["model"] == "gpt-4o"
|
||||
assert call_args["messages"] == messages
|
||||
assert call_args["temperature"] == 0.7
|
||||
assert call_args["max_tokens"] == 1000
|
||||
|
||||
# Verify response format
|
||||
assert "choices" in result
|
||||
assert len(result["choices"]) == 1
|
||||
assert result["choices"][0]["message"]["content"] == "Test response"
|
||||
|
||||
def test_openai_tool_formatting(self):
|
||||
"""Test OpenAI tool formatting"""
|
||||
client = OpenAIClient("gpt-4o", "test-key")
|
||||
|
||||
tools = [
|
||||
{
|
||||
"name": "add",
|
||||
"description": "Add two numbers",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number"},
|
||||
"b": {"type": "number"}
|
||||
},
|
||||
"required": ["a", "b"]
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
formatted = client._format_tools_for_provider(tools)
|
||||
|
||||
assert len(formatted) == 1
|
||||
assert formatted[0]["type"] == "function"
|
||||
assert formatted[0]["function"]["name"] == "add"
|
||||
assert formatted[0]["function"]["description"] == "Add two numbers"
|
||||
assert "parameters" in formatted[0]["function"]
|
||||
|
||||
|
||||
class TestAIClientFactory:
|
||||
"""Test AI client factory"""
|
||||
|
||||
def test_create_openai_client(self):
|
||||
"""Test creating OpenAI client via factory"""
|
||||
client = AIClientFactory.create_openai_client(
|
||||
model_name="gpt-4o",
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
assert isinstance(client, OpenAIClient)
|
||||
assert client.model_name == "gpt-4o"
|
||||
|
||||
def test_create_client_by_provider(self):
|
||||
"""Test creating client by provider name"""
|
||||
client = AIClientFactory.create_client(
|
||||
provider="openai",
|
||||
model_name="gpt-4o",
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
assert isinstance(client, OpenAIClient)
|
||||
assert client.model_name == "gpt-4o"
|
||||
|
||||
def test_invalid_provider(self):
|
||||
"""Test creating client with invalid provider"""
|
||||
with pytest.raises(ValueError, match="Unsupported AI provider"):
|
||||
AIClientFactory.create_client("invalid", "model", "key")
|
||||
|
||||
@patch('mcp_template.llm_client.client_factory.CONFIG_AVAILABLE', True)
|
||||
@patch('mcp_template.llm_client.client_factory.Config')
|
||||
def test_missing_api_key(self, mock_config):
|
||||
"""Test creating client without API key"""
|
||||
# Mock config to return None for API key
|
||||
mock_config.OPENAI_API_KEY = None
|
||||
|
||||
with pytest.raises(ValueError, match="API key not provided and could not be loaded from config for provider: openai"):
|
||||
AIClientFactory.create_client("openai", "gpt-4o")
|
||||
|
||||
def test_available_providers(self):
|
||||
"""Test getting available providers"""
|
||||
providers = AIClientFactory.get_available_providers()
|
||||
assert "openai" in providers
|
||||
assert "claude" in providers
|
||||
assert "grok" in providers
|
||||
|
||||
def test_validate_provider(self):
|
||||
"""Test provider validation"""
|
||||
assert AIClientFactory.validate_provider("openai") is True
|
||||
assert AIClientFactory.validate_provider("claude") is True
|
||||
assert AIClientFactory.validate_provider("invalid") is False
|
||||
@@ -0,0 +1,213 @@
|
||||
"""
|
||||
Unit tests for core MCP types
|
||||
"""
|
||||
import pytest
|
||||
from mcp_template.core.types import (
|
||||
TransportType,
|
||||
MCPTool,
|
||||
MCPResource,
|
||||
MCPPrompt,
|
||||
MCPServerConfig
|
||||
)
|
||||
|
||||
|
||||
class TestTransportType:
|
||||
"""Test TransportType enum"""
|
||||
|
||||
def test_transport_type_values(self):
|
||||
"""Test transport type enum values"""
|
||||
assert TransportType.SSE.value == "sse"
|
||||
assert TransportType.STDIO.value == "stdio"
|
||||
|
||||
def test_transport_type_from_string(self):
|
||||
"""Test creating transport type from string"""
|
||||
assert TransportType("sse") == TransportType.SSE
|
||||
assert TransportType("stdio") == TransportType.STDIO
|
||||
assert TransportType("SSE") == TransportType.SSE
|
||||
|
||||
|
||||
class TestMCPTool:
|
||||
"""Test MCPTool dataclass"""
|
||||
|
||||
def test_tool_creation(self):
|
||||
"""Test creating a valid tool"""
|
||||
async def sample_handler(a: int, b: int) -> int:
|
||||
return a + b
|
||||
|
||||
tool = MCPTool(
|
||||
name="add",
|
||||
description="Add two numbers",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number"},
|
||||
"b": {"type": "number"}
|
||||
},
|
||||
"required": ["a", "b"]
|
||||
},
|
||||
handler=sample_handler
|
||||
)
|
||||
|
||||
assert tool.name == "add"
|
||||
assert tool.description == "Add two numbers"
|
||||
assert "a" in tool.input_schema["properties"]
|
||||
|
||||
def test_tool_validation(self):
|
||||
"""Test tool validation"""
|
||||
async def handler():
|
||||
pass
|
||||
|
||||
# Valid tool
|
||||
tool = MCPTool(
|
||||
name="test",
|
||||
description="Test tool",
|
||||
input_schema={},
|
||||
handler=handler
|
||||
)
|
||||
assert tool.name == "test"
|
||||
|
||||
# Invalid tool - empty name
|
||||
with pytest.raises(ValueError):
|
||||
MCPTool(
|
||||
name="",
|
||||
description="Test tool",
|
||||
input_schema={},
|
||||
handler=handler
|
||||
)
|
||||
|
||||
# Invalid tool - empty description
|
||||
with pytest.raises(ValueError):
|
||||
MCPTool(
|
||||
name="test",
|
||||
description="",
|
||||
input_schema={},
|
||||
handler=handler
|
||||
)
|
||||
|
||||
|
||||
class TestMCPResource:
|
||||
"""Test MCPResource dataclass"""
|
||||
|
||||
def test_resource_creation(self):
|
||||
"""Test creating a valid resource"""
|
||||
resource = MCPResource(
|
||||
uri="file:///test.txt",
|
||||
name="Test File",
|
||||
description="A test file",
|
||||
mime_type="text/plain",
|
||||
content="Hello, world!"
|
||||
)
|
||||
|
||||
assert resource.uri == "file:///test.txt"
|
||||
assert resource.name == "Test File"
|
||||
assert resource.content == "Hello, world!"
|
||||
|
||||
def test_resource_validation(self):
|
||||
"""Test resource validation"""
|
||||
# Valid resource
|
||||
resource = MCPResource(
|
||||
uri="file:///test.txt",
|
||||
name="Test",
|
||||
description="Test resource",
|
||||
mime_type="text/plain",
|
||||
content=b"test"
|
||||
)
|
||||
assert resource.uri == "file:///test.txt"
|
||||
|
||||
# Invalid resource - empty URI
|
||||
with pytest.raises(ValueError):
|
||||
MCPResource(
|
||||
uri="",
|
||||
name="Test",
|
||||
description="Test resource",
|
||||
mime_type="text/plain",
|
||||
content="test"
|
||||
)
|
||||
|
||||
# Invalid resource - empty name
|
||||
with pytest.raises(ValueError):
|
||||
MCPResource(
|
||||
uri="file:///test.txt",
|
||||
name="",
|
||||
description="Test resource",
|
||||
mime_type="text/plain",
|
||||
content="test"
|
||||
)
|
||||
|
||||
|
||||
class TestMCPPrompt:
|
||||
"""Test MCPPrompt dataclass"""
|
||||
|
||||
def test_prompt_creation(self):
|
||||
"""Test creating a valid prompt"""
|
||||
prompt = MCPPrompt(
|
||||
name="greeting",
|
||||
description="A greeting prompt",
|
||||
template="Hello, {name}! Welcome to {place}.",
|
||||
arguments={
|
||||
"name": {"type": "string", "description": "Person's name"},
|
||||
"place": {"type": "string", "description": "Place name"}
|
||||
}
|
||||
)
|
||||
|
||||
assert prompt.name == "greeting"
|
||||
assert "name" in prompt.template
|
||||
assert prompt.arguments is not None
|
||||
|
||||
def test_prompt_validation(self):
|
||||
"""Test prompt validation"""
|
||||
# Valid prompt
|
||||
prompt = MCPPrompt(
|
||||
name="test",
|
||||
description="Test prompt",
|
||||
template="Hello, world!"
|
||||
)
|
||||
assert prompt.name == "test"
|
||||
|
||||
# Invalid prompt - empty name
|
||||
with pytest.raises(ValueError):
|
||||
MCPPrompt(
|
||||
name="",
|
||||
description="Test prompt",
|
||||
template="Hello, world!"
|
||||
)
|
||||
|
||||
# Invalid prompt - empty template
|
||||
with pytest.raises(ValueError):
|
||||
MCPPrompt(
|
||||
name="test",
|
||||
description="Test prompt",
|
||||
template=""
|
||||
)
|
||||
|
||||
|
||||
class TestMCPServerConfig:
|
||||
"""Test MCPServerConfig dataclass"""
|
||||
|
||||
def test_server_config_creation(self):
|
||||
"""Test creating a valid server config"""
|
||||
config = MCPServerConfig(
|
||||
name="Test Server",
|
||||
version="1.0.0",
|
||||
transport=TransportType.SSE,
|
||||
host="localhost",
|
||||
port=8080,
|
||||
stateless_http=True
|
||||
)
|
||||
|
||||
assert config.name == "Test Server"
|
||||
assert config.transport == TransportType.SSE
|
||||
assert config.port == 8080
|
||||
assert config.tools == []
|
||||
assert config.resources == []
|
||||
assert config.prompts == []
|
||||
|
||||
def test_server_config_defaults(self):
|
||||
"""Test server config default values"""
|
||||
config = MCPServerConfig(name="Test Server")
|
||||
|
||||
assert config.version == "1.0.0"
|
||||
assert config.transport == TransportType.STDIO
|
||||
assert config.host == "0.0.0.0"
|
||||
assert config.port == 8050
|
||||
assert config.stateless_http is True
|
||||
@@ -0,0 +1,495 @@
|
||||
"""
|
||||
Unit tests for MCP server components
|
||||
"""
|
||||
import pytest
|
||||
import asyncio
|
||||
import tempfile
|
||||
import os
|
||||
from unittest.mock import Mock, AsyncMock, patch, MagicMock
|
||||
from pathlib import Path
|
||||
from mcp_template.core.types import MCPTool, MCPResource, MCPPrompt, MCPServerConfig, TransportType
|
||||
# FastMCPServer not available, using ModularMCPServer for testing
|
||||
from mcp_template.server.server_factory import MCPServerFactory
|
||||
from mcp_template.server.modular_server import ModularMCPServer
|
||||
from mcp_template.server.tools.tool_registry import ServerToolRegistry
|
||||
from mcp_template.server.prompts.prompt_registry import ServerPromptRegistry
|
||||
from mcp_template.server.resources.resource_registry import ServerResourceRegistry
|
||||
from mcp_template.server.tools.base_tool import BaseServerTool, ContextAwareTool
|
||||
from mcp_template.server.prompts.base_prompt import BaseServerPrompt
|
||||
from mcp_template.server.resources.base_resource import BaseServerResource
|
||||
|
||||
|
||||
class TestModularMCPServerBasic:
|
||||
"""Test basic ModularMCP server functionality (replacing FastMCPServer tests)"""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tool(self):
|
||||
"""Create a sample tool for testing"""
|
||||
async def add_handler(a: int, b: int) -> int:
|
||||
return a + b
|
||||
|
||||
return MCPTool(
|
||||
name="add",
|
||||
description="Add two numbers",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number"},
|
||||
"b": {"type": "number"}
|
||||
},
|
||||
"required": ["a", "b"]
|
||||
},
|
||||
handler=add_handler
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def server_config(self):
|
||||
"""Create a sample server config"""
|
||||
return MCPServerConfig(
|
||||
name="Test Server",
|
||||
transport=TransportType.STDIO,
|
||||
host="localhost",
|
||||
port=8050
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_initialization(self, server_config, sample_tool):
|
||||
"""Test server initialization"""
|
||||
# Use factory to create server
|
||||
server = MCPServerFactory.create_server(
|
||||
name="Test Server",
|
||||
transport="stdio"
|
||||
)
|
||||
|
||||
# Server should not be initialized yet
|
||||
assert not server._initialized
|
||||
|
||||
# Initialize server
|
||||
await server.initialize()
|
||||
|
||||
# Server should be initialized
|
||||
assert server._initialized
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_info(self, server_config):
|
||||
"""Test getting server information"""
|
||||
server = MCPServerFactory.create_server(
|
||||
name="Test Server",
|
||||
transport="stdio"
|
||||
)
|
||||
|
||||
# Test server info before initialization
|
||||
info = server.get_server_info()
|
||||
assert info["name"] == "Test Server"
|
||||
assert info["host"] == "0.0.0.0"
|
||||
assert info["port"] == 8050
|
||||
|
||||
def test_server_run_method(self, server_config):
|
||||
"""Test server run method"""
|
||||
server = MCPServerFactory.create_server(
|
||||
name="Test Server",
|
||||
transport="stdio"
|
||||
)
|
||||
|
||||
# Mock the run method to avoid actually starting the server
|
||||
with patch.object(server.mcp, 'run') as mock_run:
|
||||
server.run("stdio")
|
||||
mock_run.assert_called_once_with(transport="stdio")
|
||||
|
||||
|
||||
|
||||
class TestMCPServerFactory:
|
||||
"""Test MCP server factory"""
|
||||
|
||||
def test_create_server_basic(self):
|
||||
"""Test creating a basic server"""
|
||||
server = MCPServerFactory.create_server(
|
||||
name="Test Server",
|
||||
transport="stdio"
|
||||
)
|
||||
|
||||
assert isinstance(server, ModularMCPServer)
|
||||
assert server.name == "Test Server"
|
||||
assert server.host == "0.0.0.0"
|
||||
assert server.port == 8050
|
||||
|
||||
def test_create_calculator_server(self):
|
||||
"""Test creating a calculator server"""
|
||||
server = MCPServerFactory.create_basic_calculator_server(
|
||||
name="Calculator Server",
|
||||
transport="sse",
|
||||
port=8080
|
||||
)
|
||||
|
||||
assert isinstance(server, ModularMCPServer)
|
||||
assert server.name == "Calculator Server"
|
||||
assert server.host == "0.0.0.0"
|
||||
assert server.port == 8080
|
||||
|
||||
# For ModularMCPServer, we can't easily test tool registration without mocking
|
||||
# The tools would be registered through the registry system
|
||||
assert server.tool_registry is not None
|
||||
|
||||
def test_create_knowledge_base_server(self):
|
||||
"""Test creating a knowledge base server"""
|
||||
kb_data = {
|
||||
"policy": "Company vacation policy",
|
||||
"faq": "Frequently asked questions"
|
||||
}
|
||||
|
||||
server = MCPServerFactory.create_knowledge_base_server(
|
||||
name="KB Server",
|
||||
kb_data=kb_data
|
||||
)
|
||||
|
||||
assert isinstance(server, ModularMCPServer)
|
||||
assert server.name == "KB Server"
|
||||
|
||||
# For ModularMCPServer, we can't easily test tool registration without mocking
|
||||
# The tools would be registered through the registry system
|
||||
assert server.tool_registry is not None
|
||||
|
||||
|
||||
class TestModularMCPServer:
|
||||
"""Test ModularMCP server implementation"""
|
||||
|
||||
@pytest.fixture
|
||||
def modular_server(self):
|
||||
"""Create a modular server for testing"""
|
||||
return ModularMCPServer(
|
||||
name="Test Modular Server",
|
||||
host="localhost",
|
||||
port=8051
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_modular_server_initialization(self, modular_server):
|
||||
"""Test modular server initialization"""
|
||||
assert modular_server.name == "Test Modular Server"
|
||||
assert modular_server.host == "localhost"
|
||||
assert modular_server.port == 8051
|
||||
assert not modular_server._initialized
|
||||
assert modular_server.tool_registry is not None
|
||||
assert modular_server.prompt_registry is not None
|
||||
assert modular_server.resource_registry is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_modular_server_initialize(self, modular_server):
|
||||
"""Test modular server initialization process"""
|
||||
# Mock the registry methods to avoid actual file discovery
|
||||
with patch.object(modular_server.tool_registry, 'register_tools_with_server') as mock_tools, \
|
||||
patch.object(modular_server.tool_registry, 'get_all_tools', return_value=[]), \
|
||||
patch.object(modular_server.prompt_registry, 'register_prompts_with_server') as mock_prompts, \
|
||||
patch.object(modular_server.prompt_registry, 'get_all_prompts', return_value=[]), \
|
||||
patch.object(modular_server.resource_registry, 'register_resources_with_server') as mock_resources, \
|
||||
patch.object(modular_server.resource_registry, 'get_all_resources', return_value=[]):
|
||||
|
||||
await modular_server.initialize()
|
||||
|
||||
assert modular_server._initialized
|
||||
mock_tools.assert_called_once_with(modular_server.mcp)
|
||||
mock_prompts.assert_called_once_with(modular_server.mcp)
|
||||
mock_resources.assert_called_once_with(modular_server.mcp)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_modular_server_double_initialize(self, modular_server):
|
||||
"""Test that double initialization doesn't cause issues"""
|
||||
with patch.object(modular_server.tool_registry, 'register_tools_with_server') as mock_tools, \
|
||||
patch.object(modular_server.tool_registry, 'get_all_tools', return_value=[]), \
|
||||
patch.object(modular_server.prompt_registry, 'register_prompts_with_server') as mock_prompts, \
|
||||
patch.object(modular_server.prompt_registry, 'get_all_prompts', return_value=[]), \
|
||||
patch.object(modular_server.resource_registry, 'register_resources_with_server') as mock_resources, \
|
||||
patch.object(modular_server.resource_registry, 'get_all_resources', return_value=[]):
|
||||
|
||||
# Initialize twice
|
||||
await modular_server.initialize()
|
||||
await modular_server.initialize()
|
||||
|
||||
# Should only be called once due to _initialized check
|
||||
mock_tools.assert_called_once()
|
||||
mock_prompts.assert_called_once()
|
||||
mock_resources.assert_called_once()
|
||||
|
||||
def test_modular_server_get_server_info(self, modular_server):
|
||||
"""Test getting server information"""
|
||||
# Mock the registry methods
|
||||
with patch.object(modular_server.tool_registry, 'get_all_tools', return_value=[Mock(name="tool1")]), \
|
||||
patch.object(modular_server.tool_registry, 'get_tool_names', return_value=["tool1"]), \
|
||||
patch.object(modular_server.prompt_registry, 'get_all_prompts', return_value=[Mock(name="prompt1")]), \
|
||||
patch.object(modular_server.prompt_registry, 'get_prompt_names', return_value=["prompt1"]), \
|
||||
patch.object(modular_server.resource_registry, 'get_all_resources', return_value=[Mock(uri="resource1")]), \
|
||||
patch.object(modular_server.resource_registry, 'get_resource_uris', return_value=["resource1"]):
|
||||
|
||||
info = modular_server.get_server_info()
|
||||
|
||||
assert info["name"] == "Test Modular Server"
|
||||
assert info["host"] == "localhost"
|
||||
assert info["port"] == 8051
|
||||
assert info["tools"]["count"] == 1
|
||||
assert info["tools"]["names"] == ["tool1"]
|
||||
assert info["prompts"]["count"] == 1
|
||||
assert info["prompts"]["names"] == ["prompt1"]
|
||||
assert info["resources"]["count"] == 1
|
||||
assert info["resources"]["uris"] == ["resource1"]
|
||||
|
||||
def test_modular_server_run_without_initialize(self, modular_server):
|
||||
"""Test that run method initializes server if not already initialized"""
|
||||
with patch.object(modular_server, 'initialize') as mock_init, \
|
||||
patch.object(modular_server.mcp, 'run') as mock_run:
|
||||
|
||||
modular_server.run("stdio")
|
||||
|
||||
mock_init.assert_called_once()
|
||||
mock_run.assert_called_once_with(transport="stdio")
|
||||
|
||||
def test_modular_server_run_already_initialized(self, modular_server):
|
||||
"""Test that run method doesn't reinitialize if already initialized"""
|
||||
modular_server._initialized = True
|
||||
|
||||
with patch.object(modular_server, 'initialize') as mock_init, \
|
||||
patch.object(modular_server.mcp, 'run') as mock_run:
|
||||
|
||||
modular_server.run("sse")
|
||||
|
||||
mock_init.assert_not_called()
|
||||
mock_run.assert_called_once_with(transport="sse")
|
||||
|
||||
def test_modular_server_custom_directories(self):
|
||||
"""Test modular server with custom directories"""
|
||||
server = ModularMCPServer(
|
||||
name="Custom Server",
|
||||
tools_directory="/custom/tools",
|
||||
prompts_directory="/custom/prompts",
|
||||
resources_directory="/custom/resources"
|
||||
)
|
||||
|
||||
assert server.tool_registry.directory == "/custom/tools"
|
||||
assert server.prompt_registry.directory == "/custom/prompts"
|
||||
assert server.resource_registry.directory == "/custom/resources"
|
||||
|
||||
|
||||
class TestServerRegistries:
|
||||
"""Test server registry classes"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool(self):
|
||||
"""Create a mock tool for testing"""
|
||||
tool = Mock()
|
||||
tool.name = "test_tool"
|
||||
tool.description = "A test tool"
|
||||
tool.input_schema = {"type": "object"}
|
||||
return tool
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prompt(self):
|
||||
"""Create a mock prompt for testing"""
|
||||
prompt = Mock()
|
||||
prompt.name = "test_prompt"
|
||||
prompt.description = "A test prompt"
|
||||
prompt.arguments = {"name": {"type": "string"}}
|
||||
return prompt
|
||||
|
||||
@pytest.fixture
|
||||
def mock_resource(self):
|
||||
"""Create a mock resource for testing"""
|
||||
resource = Mock()
|
||||
resource.uri = "test://resource"
|
||||
resource.name = "Test Resource"
|
||||
resource.description = "A test resource"
|
||||
return resource
|
||||
|
||||
def test_tool_registry_initialization(self):
|
||||
"""Test tool registry initialization"""
|
||||
registry = ServerToolRegistry("/test/tools")
|
||||
assert registry.directory == "/test/tools"
|
||||
assert registry._tools == {}
|
||||
|
||||
def test_tool_registry_register_tool(self, mock_tool):
|
||||
"""Test tool registration"""
|
||||
registry = ServerToolRegistry()
|
||||
registry.register_tool(mock_tool)
|
||||
|
||||
assert "test_tool" in registry._tools
|
||||
assert registry._tools["test_tool"] == mock_tool
|
||||
|
||||
def test_tool_registry_get_all_tools(self, mock_tool):
|
||||
"""Test getting all tools"""
|
||||
registry = ServerToolRegistry()
|
||||
registry.register_tool(mock_tool)
|
||||
|
||||
tools = registry.get_all_tools()
|
||||
assert len(tools) == 1
|
||||
assert tools[0] == mock_tool
|
||||
|
||||
def test_tool_registry_get_tool_names(self, mock_tool):
|
||||
"""Test getting tool names"""
|
||||
registry = ServerToolRegistry()
|
||||
registry.register_tool(mock_tool)
|
||||
|
||||
names = registry.get_tool_names()
|
||||
assert names == ["test_tool"]
|
||||
|
||||
def test_prompt_registry_initialization(self):
|
||||
"""Test prompt registry initialization"""
|
||||
registry = ServerPromptRegistry("/test/prompts")
|
||||
assert registry.directory == "/test/prompts"
|
||||
assert registry._prompts == {}
|
||||
|
||||
def test_prompt_registry_register_prompt(self, mock_prompt):
|
||||
"""Test prompt registration"""
|
||||
registry = ServerPromptRegistry()
|
||||
registry.register_prompt(mock_prompt)
|
||||
|
||||
assert "test_prompt" in registry._prompts
|
||||
assert registry._prompts["test_prompt"] == mock_prompt
|
||||
|
||||
def test_prompt_registry_get_all_prompts(self, mock_prompt):
|
||||
"""Test getting all prompts"""
|
||||
registry = ServerPromptRegistry()
|
||||
registry.register_prompt(mock_prompt)
|
||||
|
||||
prompts = registry.get_all_prompts()
|
||||
assert len(prompts) == 1
|
||||
assert prompts[0] == mock_prompt
|
||||
|
||||
def test_prompt_registry_get_prompt_names(self, mock_prompt):
|
||||
"""Test getting prompt names"""
|
||||
registry = ServerPromptRegistry()
|
||||
registry.register_prompt(mock_prompt)
|
||||
|
||||
names = registry.get_prompt_names()
|
||||
assert names == ["test_prompt"]
|
||||
|
||||
def test_resource_registry_initialization(self):
|
||||
"""Test resource registry initialization"""
|
||||
registry = ServerResourceRegistry("/test/resources")
|
||||
assert registry.directory == "/test/resources"
|
||||
assert registry._resources == {}
|
||||
|
||||
def test_resource_registry_register_resource(self, mock_resource):
|
||||
"""Test resource registration"""
|
||||
registry = ServerResourceRegistry()
|
||||
registry.register_resource(mock_resource)
|
||||
|
||||
assert "test://resource" in registry._resources
|
||||
assert registry._resources["test://resource"] == mock_resource
|
||||
|
||||
def test_resource_registry_get_all_resources(self, mock_resource):
|
||||
"""Test getting all resources"""
|
||||
registry = ServerResourceRegistry()
|
||||
registry.register_resource(mock_resource)
|
||||
|
||||
resources = registry.get_all_resources()
|
||||
assert len(resources) == 1
|
||||
assert resources[0] == mock_resource
|
||||
|
||||
def test_resource_registry_get_resource_uris(self, mock_resource):
|
||||
"""Test getting resource URIs"""
|
||||
registry = ServerResourceRegistry()
|
||||
registry.register_resource(mock_resource)
|
||||
|
||||
uris = registry.get_resource_uris()
|
||||
assert uris == ["test://resource"]
|
||||
|
||||
|
||||
class TestBaseServerTools:
|
||||
"""Test base server tool classes"""
|
||||
|
||||
def test_base_server_tool_initialization(self):
|
||||
"""Test BaseServerTool initialization"""
|
||||
class TestTool(BaseServerTool):
|
||||
async def execute(self, **kwargs):
|
||||
return "test_result"
|
||||
|
||||
tool = TestTool(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
input_schema={"type": "object"}
|
||||
)
|
||||
|
||||
assert tool.name == "test_tool"
|
||||
assert tool.description == "A test tool"
|
||||
assert tool.input_schema == {"type": "object"}
|
||||
|
||||
def test_base_server_tool_get_tool_definition(self):
|
||||
"""Test getting tool definition"""
|
||||
class TestTool(BaseServerTool):
|
||||
async def execute(self, **kwargs):
|
||||
return "test_result"
|
||||
|
||||
tool = TestTool(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
input_schema={"type": "object"}
|
||||
)
|
||||
|
||||
definition = tool.get_tool_definition()
|
||||
expected = {
|
||||
"name": "test_tool",
|
||||
"description": "A test tool",
|
||||
"input_schema": {"type": "object"}
|
||||
}
|
||||
|
||||
assert definition == expected
|
||||
|
||||
def test_base_server_tool_create_fastmcp_tool(self):
|
||||
"""Test creating FastMCP tool wrapper"""
|
||||
class TestTool(BaseServerTool):
|
||||
async def execute(self, **kwargs):
|
||||
return "test_result"
|
||||
|
||||
tool = TestTool(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
input_schema={"type": "object"}
|
||||
)
|
||||
|
||||
# Mock MCP server
|
||||
mock_mcp = Mock()
|
||||
mock_mcp.tool.return_value = lambda func: func
|
||||
|
||||
wrapper = tool.create_fastmcp_tool(mock_mcp)
|
||||
|
||||
assert wrapper.__name__ == "test_tool"
|
||||
assert wrapper.__doc__ == "A test tool"
|
||||
|
||||
def test_context_aware_tool_initialization(self):
|
||||
"""Test ContextAwareTool initialization"""
|
||||
class TestContextTool(ContextAwareTool):
|
||||
async def execute_with_context(self, ctx, **kwargs):
|
||||
return "context_result"
|
||||
|
||||
tool = TestContextTool(
|
||||
name="context_tool",
|
||||
description="A context-aware tool",
|
||||
input_schema={"type": "object"}
|
||||
)
|
||||
|
||||
assert tool.name == "context_tool"
|
||||
assert tool.description == "A context-aware tool"
|
||||
assert tool.input_schema == {"type": "object"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_aware_tool_execute_with_context(self):
|
||||
"""Test ContextAwareTool execution with context"""
|
||||
from unittest.mock import Mock
|
||||
|
||||
class TestContextTool(ContextAwareTool):
|
||||
async def execute_with_context(self, ctx, **kwargs):
|
||||
return f"Context: {self.context}, Args: {kwargs}"
|
||||
|
||||
tool = TestContextTool(
|
||||
name="context_tool",
|
||||
description="A context-aware tool",
|
||||
input_schema={"type": "object"}
|
||||
)
|
||||
|
||||
# Set context
|
||||
tool.context = {"user_id": "123"}
|
||||
|
||||
# Mock the context
|
||||
mock_ctx = Mock()
|
||||
|
||||
result = await tool.execute_with_context(mock_ctx, test_arg="value")
|
||||
assert "Context: {'user_id': '123'}" in result
|
||||
assert "Args: {'test_arg': 'value'}" in result
|
||||
Reference in New Issue
Block a user