initial mcp server setup
This commit is contained in:
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
Unit tests for AI client components
|
||||
"""
|
||||
from unittest import async_case
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from mcp_template.llm_client.base_client import BaseAIClient
|
||||
from mcp_template.llm_client.openai_client import OpenAIClient
|
||||
from mcp_template.llm_client.client_factory import AIClientFactory
|
||||
|
||||
|
||||
class TestBaseAIClient:
|
||||
"""Test base AI client functionality"""
|
||||
|
||||
def test_base_client_creation(self):
|
||||
"""Test creating a base client"""
|
||||
# Create a concrete implementation for testing
|
||||
class TestAIClient(BaseAIClient):
|
||||
async def chat_completion(self, messages, tools=None, **kwargs):
|
||||
return {"choices": [{"message": {"content": "test response"}}]}
|
||||
|
||||
async def _initialize_client(self):
|
||||
pass
|
||||
|
||||
client = TestAIClient(
|
||||
model_name="test-model",
|
||||
provider="test",
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
assert client.model_name == "test-model"
|
||||
assert client._api_key == "test-key"
|
||||
assert not client._initialized
|
||||
|
||||
def test_abstract_class_cannot_be_instantiated(self):
|
||||
"""Test that abstract BaseAIClient cannot be instantiated directly"""
|
||||
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
|
||||
BaseAIClient(
|
||||
model_name="test-model",
|
||||
provider="test",
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIClient:
|
||||
"""Test OpenAI client implementation"""
|
||||
|
||||
def test_openai_client_creation(self):
|
||||
"""Test creating an OpenAI client"""
|
||||
client = OpenAIClient(
|
||||
model_name="gpt-4o",
|
||||
api_key="test-key",
|
||||
temperature=0.5,
|
||||
max_tokens=500
|
||||
)
|
||||
|
||||
assert client.model_name == "gpt-4o"
|
||||
assert client._temperature == 0.5
|
||||
assert client._max_tokens == 500
|
||||
|
||||
def test_openai_client_defaults(self):
|
||||
"""Test OpenAI client default values"""
|
||||
client = OpenAIClient(
|
||||
model_name="gpt-4o",
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
assert client._temperature == 0.7
|
||||
assert client._max_tokens == 1000
|
||||
|
||||
@patch('mcp_template.llm_client.openai_client.AsyncOpenAI')
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_initialization(self, mock_openai_class):
|
||||
"""Test OpenAI client initialization"""
|
||||
mock_client = Mock()
|
||||
mock_openai_class.return_value = mock_client
|
||||
|
||||
client = OpenAIClient("gpt-4o", "test-key")
|
||||
await client._initialize_client()
|
||||
|
||||
mock_openai_class.assert_called_once_with(api_key="test-key")
|
||||
assert client._client == mock_client
|
||||
|
||||
@patch('mcp_template.llm_client.openai_client.AsyncOpenAI')
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_chat_completion(self, mock_openai_class):
|
||||
"""Test OpenAI chat completion"""
|
||||
# Mock the OpenAI client and response
|
||||
mock_client = Mock()
|
||||
mock_response = Mock()
|
||||
mock_choice = Mock()
|
||||
mock_message = Mock()
|
||||
|
||||
mock_message.role = "assistant"
|
||||
mock_message.content = "Test response"
|
||||
mock_message.tool_calls = None
|
||||
|
||||
mock_choice.message = mock_message
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||
mock_openai_class.return_value = mock_client
|
||||
|
||||
client = OpenAIClient("gpt-4o", "test-key")
|
||||
await client.initialize()
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await client.chat_completion(messages)
|
||||
|
||||
# Verify the call was made correctly
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
call_args = mock_client.chat.completions.create.call_args[1]
|
||||
|
||||
assert call_args["model"] == "gpt-4o"
|
||||
assert call_args["messages"] == messages
|
||||
assert call_args["temperature"] == 0.7
|
||||
assert call_args["max_tokens"] == 1000
|
||||
|
||||
# Verify response format
|
||||
assert "choices" in result
|
||||
assert len(result["choices"]) == 1
|
||||
assert result["choices"][0]["message"]["content"] == "Test response"
|
||||
|
||||
def test_openai_tool_formatting(self):
|
||||
"""Test OpenAI tool formatting"""
|
||||
client = OpenAIClient("gpt-4o", "test-key")
|
||||
|
||||
tools = [
|
||||
{
|
||||
"name": "add",
|
||||
"description": "Add two numbers",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number"},
|
||||
"b": {"type": "number"}
|
||||
},
|
||||
"required": ["a", "b"]
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
formatted = client._format_tools_for_provider(tools)
|
||||
|
||||
assert len(formatted) == 1
|
||||
assert formatted[0]["type"] == "function"
|
||||
assert formatted[0]["function"]["name"] == "add"
|
||||
assert formatted[0]["function"]["description"] == "Add two numbers"
|
||||
assert "parameters" in formatted[0]["function"]
|
||||
|
||||
|
||||
class TestAIClientFactory:
|
||||
"""Test AI client factory"""
|
||||
|
||||
def test_create_openai_client(self):
|
||||
"""Test creating OpenAI client via factory"""
|
||||
client = AIClientFactory.create_openai_client(
|
||||
model_name="gpt-4o",
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
assert isinstance(client, OpenAIClient)
|
||||
assert client.model_name == "gpt-4o"
|
||||
|
||||
def test_create_client_by_provider(self):
|
||||
"""Test creating client by provider name"""
|
||||
client = AIClientFactory.create_client(
|
||||
provider="openai",
|
||||
model_name="gpt-4o",
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
assert isinstance(client, OpenAIClient)
|
||||
assert client.model_name == "gpt-4o"
|
||||
|
||||
def test_invalid_provider(self):
|
||||
"""Test creating client with invalid provider"""
|
||||
with pytest.raises(ValueError, match="Unsupported AI provider"):
|
||||
AIClientFactory.create_client("invalid", "model", "key")
|
||||
|
||||
@patch('mcp_template.llm_client.client_factory.CONFIG_AVAILABLE', True)
|
||||
@patch('mcp_template.llm_client.client_factory.Config')
|
||||
def test_missing_api_key(self, mock_config):
|
||||
"""Test creating client without API key"""
|
||||
# Mock config to return None for API key
|
||||
mock_config.OPENAI_API_KEY = None
|
||||
|
||||
with pytest.raises(ValueError, match="API key not provided and could not be loaded from config for provider: openai"):
|
||||
AIClientFactory.create_client("openai", "gpt-4o")
|
||||
|
||||
def test_available_providers(self):
|
||||
"""Test getting available providers"""
|
||||
providers = AIClientFactory.get_available_providers()
|
||||
assert "openai" in providers
|
||||
assert "claude" in providers
|
||||
assert "grok" in providers
|
||||
|
||||
def test_validate_provider(self):
|
||||
"""Test provider validation"""
|
||||
assert AIClientFactory.validate_provider("openai") is True
|
||||
assert AIClientFactory.validate_provider("claude") is True
|
||||
assert AIClientFactory.validate_provider("invalid") is False
|
||||
Reference in New Issue
Block a user