203 lines
6.9 KiB
Python
203 lines
6.9 KiB
Python
"""
Unit tests for AI client components.
"""
|
|
from unittest import async_case
|
|
import pytest
|
|
from unittest.mock import Mock, AsyncMock, patch
|
|
from mcp_template.llm_client.base_client import BaseAIClient
|
|
from mcp_template.llm_client.openai_client import OpenAIClient
|
|
from mcp_template.llm_client.client_factory import AIClientFactory
|
|
|
|
|
|
class TestBaseAIClient:
    """Checks for the ``BaseAIClient`` base class."""

    def test_base_client_creation(self):
        """A concrete subclass keeps its constructor arguments and starts uninitialized."""

        # BaseAIClient cannot be used directly, so define a throwaway subclass
        # that supplies ``chat_completion`` and ``_initialize_client``.
        class _StubClient(BaseAIClient):
            async def _initialize_client(self):
                return None

            async def chat_completion(self, messages, tools=None, **kwargs):
                canned_message = {"content": "test response"}
                return {"choices": [{"message": canned_message}]}

        init_kwargs = {
            "model_name": "test-model",
            "provider": "test",
            "api_key": "test-key",
        }
        stub = _StubClient(**init_kwargs)

        assert stub.model_name == "test-model"
        assert stub._api_key == "test-key"
        assert not stub._initialized

    def test_abstract_class_cannot_be_instantiated(self):
        """Constructing ``BaseAIClient`` itself must raise ``TypeError``."""
        expect_type_error = pytest.raises(
            TypeError, match="Can't instantiate abstract class"
        )
        init_kwargs = {
            "model_name": "test-model",
            "provider": "test",
            "api_key": "test-key",
        }

        with expect_type_error:
            BaseAIClient(**init_kwargs)
|
|
|
|
|
|
class TestOpenAIClient:
    """Checks for the ``OpenAIClient`` implementation."""

    def test_openai_client_creation(self):
        """Explicit constructor arguments are stored on the client."""
        settings = {
            "model_name": "gpt-4o",
            "api_key": "test-key",
            "temperature": 0.5,
            "max_tokens": 500,
        }
        openai_client = OpenAIClient(**settings)

        assert openai_client.model_name == "gpt-4o"
        assert openai_client._temperature == 0.5
        assert openai_client._max_tokens == 500

    def test_openai_client_defaults(self):
        """Omitting temperature and max_tokens yields 0.7 and 1000."""
        openai_client = OpenAIClient(model_name="gpt-4o", api_key="test-key")

        assert openai_client._temperature == 0.7
        assert openai_client._max_tokens == 1000

    @patch('mcp_template.llm_client.openai_client.AsyncOpenAI')
    @pytest.mark.asyncio
    async def test_openai_initialization(self, mock_openai_class):
        """``_initialize_client`` builds ``AsyncOpenAI`` with the API key and keeps the result."""
        sdk_instance = Mock()
        mock_openai_class.return_value = sdk_instance

        openai_client = OpenAIClient("gpt-4o", "test-key")
        await openai_client._initialize_client()

        mock_openai_class.assert_called_once_with(api_key="test-key")
        assert openai_client._client == sdk_instance

    @patch('mcp_template.llm_client.openai_client.AsyncOpenAI')
    @pytest.mark.asyncio
    async def test_openai_chat_completion(self, mock_openai_class):
        """``chat_completion`` forwards model settings and returns a dict with ``choices``."""
        # Fake SDK response: one choice carrying a plain assistant message.
        reply = Mock(role="assistant", content="Test response", tool_calls=None)
        api_response = Mock(choices=[Mock(message=reply)])

        # Fake SDK client whose ``chat.completions.create`` is awaitable.
        sdk_instance = Mock()
        sdk_instance.chat.completions.create = AsyncMock(return_value=api_response)
        mock_openai_class.return_value = sdk_instance

        openai_client = OpenAIClient("gpt-4o", "test-key")
        await openai_client.initialize()

        messages = [{"role": "user", "content": "Hello"}]
        result = await openai_client.chat_completion(messages)

        # The SDK must have been called exactly once with the expected keywords.
        create_mock = sdk_instance.chat.completions.create
        create_mock.assert_called_once()
        sent_kwargs = create_mock.call_args.kwargs

        expected_kwargs = {
            "model": "gpt-4o",
            "messages": messages,
            "temperature": 0.7,
            "max_tokens": 1000,
        }
        for key, expected_value in expected_kwargs.items():
            assert sent_kwargs[key] == expected_value

        # The returned structure is a dict holding a single choice.
        assert "choices" in result
        returned_choices = result["choices"]
        assert len(returned_choices) == 1
        assert returned_choices[0]["message"]["content"] == "Test response"

    def test_openai_tool_formatting(self):
        """``_format_tools_for_provider`` wraps a tool in a ``function`` entry."""
        openai_client = OpenAIClient("gpt-4o", "test-key")

        add_schema = {
            "type": "object",
            "properties": {
                "a": {"type": "number"},
                "b": {"type": "number"},
            },
            "required": ["a", "b"],
        }
        add_tool = {
            "name": "add",
            "description": "Add two numbers",
            "inputSchema": add_schema,
        }

        formatted = openai_client._format_tools_for_provider([add_tool])

        assert len(formatted) == 1
        entry = formatted[0]
        assert entry["type"] == "function"
        assert entry["function"]["name"] == "add"
        assert entry["function"]["description"] == "Add two numbers"
        assert "parameters" in entry["function"]
|
|
|
|
|
|
class TestAIClientFactory:
    """Checks for ``AIClientFactory``."""

    def test_create_openai_client(self):
        """``create_openai_client`` returns an ``OpenAIClient`` for the given model."""
        built = AIClientFactory.create_openai_client(
            model_name="gpt-4o", api_key="test-key"
        )

        assert isinstance(built, OpenAIClient)
        assert built.model_name == "gpt-4o"

    def test_create_client_by_provider(self):
        """``create_client`` with provider ``"openai"`` returns an ``OpenAIClient``."""
        request = {
            "provider": "openai",
            "model_name": "gpt-4o",
            "api_key": "test-key",
        }
        built = AIClientFactory.create_client(**request)

        assert isinstance(built, OpenAIClient)
        assert built.model_name == "gpt-4o"

    def test_invalid_provider(self):
        """An unknown provider name raises ``ValueError``."""
        expect_value_error = pytest.raises(
            ValueError, match="Unsupported AI provider"
        )
        with expect_value_error:
            AIClientFactory.create_client("invalid", "model", "key")

    @patch('mcp_template.llm_client.client_factory.CONFIG_AVAILABLE', True)
    @patch('mcp_template.llm_client.client_factory.Config')
    def test_missing_api_key(self, mock_config):
        """With no explicit key and a config key of ``None``, creation raises ``ValueError``."""
        # The patched Config exposes no OpenAI key.
        mock_config.OPENAI_API_KEY = None

        expected_message = "API key not provided and could not be loaded from config for provider: openai"
        with pytest.raises(ValueError, match=expected_message):
            AIClientFactory.create_client("openai", "gpt-4o")

    def test_available_providers(self):
        """The provider listing contains openai, claude and grok."""
        listed = AIClientFactory.get_available_providers()

        for provider_name in ("openai", "claude", "grok"):
            assert provider_name in listed

    def test_validate_provider(self):
        """``validate_provider`` is ``True`` for known names and ``False`` otherwise."""
        expectations = (
            ("openai", True),
            ("claude", True),
            ("invalid", False),
        )
        for provider_name, expected in expectations:
            assert AIClientFactory.validate_provider(provider_name) is expected
|