Files
ds_zagres_ai/test_rag.py
T

310 lines
11 KiB
Python
Raw Normal View History

2025-05-16 15:24:01 +01:00
#!/usr/bin/env python
"""
Test script for RAG functionality in the AI service.
This script tests the document-based question answering capabilities
by making requests to the API endpoints.
"""
import os
import sys
import json
import argparse
import requests
from typing import Dict, Any, Optional
from pprint import pprint
# Default configuration
# The URLs are base addresses; endpoint paths such as "/health" are appended to them.
DEFAULT_API_URL = "http://localhost:5252" # Local development server (default for --api-url)
DEFAULT_REMOTE_API_URL = "http://157.157.221.29:5251" # Remote server, selected by the --remote flag
DEFAULT_MODEL = "llama3.1" # Model ID used for the chat created by the RAG chat test
class RAGTester:
    """Test the RAG functionality of the AI service.

    Every check sends an HTTP request through a shared session, prints a
    human-readable result and returns a value telling the caller whether
    the check passed. Checks never raise: any exception is reported on
    stdout and turned into a failing result, because this is a diagnostic
    script where one broken endpoint must not stop the remaining checks.
    """

    # Query used by the RAG checks when the caller does not supply one.
    DEFAULT_QUERY = "What information do you have in your knowledge database?"

    def __init__(self, api_url: str, verbose: bool = False):
        """
        Initialize the RAG tester.

        Args:
            api_url: URL of the AI service API.
            verbose: Whether to print verbose output.
        """
        # Endpoint paths start with "/", so drop trailing slashes here to
        # avoid building URLs such as "http://host//health".
        self.api_url = api_url.rstrip("/")
        self.verbose = verbose
        self.session = requests.Session()
        # Print configuration
        print(f"Testing RAG functionality against API at: {self.api_url}")

    def _log(self, message: str):
        """Log a message if verbose mode is enabled."""
        if self.verbose:
            print(f"[DEBUG] {message}")

    @staticmethod
    def _preview(text: Any, limit: int) -> str:
        """
        Return the first ``limit`` characters of ``text`` followed by "...".

        A missing value (None) is treated as an empty string and non-string
        values are converted with ``str`` so that an unexpected payload
        cannot turn a successful check into a failure.
        """
        body = "" if text is None else str(text)
        return f"{body[:limit]}..."

    def _get_json(self, path: str, timeout: int) -> Any:
        """Send a GET request to ``path`` and return the decoded JSON body."""
        url = f"{self.api_url}{path}"
        self._log(f"GET {url}")
        response = self.session.get(url, timeout=timeout)
        response.raise_for_status()
        return response.json()

    def _post_json(self, path: str, timeout: int, **kwargs: Any) -> Any:
        """
        Send a POST request to ``path`` and return the decoded JSON body.

        Extra keyword arguments (``params``, ``json``) are handed to the
        session's ``post`` call unchanged.
        """
        url = f"{self.api_url}{path}"
        self._log(f"POST {url}")
        response = self.session.post(url, timeout=timeout, **kwargs)
        response.raise_for_status()
        return response.json()

    def check_server_health(self) -> bool:
        """
        Check if the server is healthy.

        Returns:
            True if the server is healthy, False otherwise.
        """
        try:
            result = self._get_json("/health", timeout=10)
            if result.get("status") == "healthy":
                print("✅ Server is healthy")
                return True
            print("❌ Server health check failed")
            print(f"Response: {result}")
            return False
        except Exception as e:
            print(f"❌ Error checking server health: {str(e)}")
            return False

    def check_config(self) -> Dict[str, Any]:
        """
        Check the server configuration.

        Returns:
            Server configuration, or an empty dict if it could not be read.
        """
        try:
            config = self._get_json("/config", timeout=10)
            # The key may be present but null, and entries may not be
            # strings; neither case should discard a valid configuration.
            models = config.get('available_models') or []
            print("✅ Server configuration:")
            print(f" - API Host: {config.get('api_host')}")
            print(f" - API Port: {config.get('api_port')}")
            print(f" - OpenWebUI URL: {config.get('openwebui_url')}")
            print(f" - Ollama API URL: {config.get('ollama_api_url')}")
            print(f" - Default Model: {config.get('default_model')}")
            print(f" - API Timeout: {config.get('api_timeout')} seconds")
            print(f" - Available Models: {', '.join(str(m) for m in models)}")
            return config
        except Exception as e:
            print(f"❌ Error checking server configuration: {str(e)}")
            return {}

    def test_ollama_connection(self) -> bool:
        """
        Test the connection to Ollama.

        Returns:
            True if the connection is successful, False otherwise.
        """
        try:
            result = self._get_json("/test-ollama", timeout=30)
            if result.get("status") == "success":
                # Tolerate a null "models" field in either nesting level.
                models = (result.get('models') or {}).get('models') or []
                print("✅ Successfully connected to Ollama API")
                print(f" - Ollama URL: {result.get('ollama_url')}")
                print(f" - Available models: {len(models)}")
                return True
            print("❌ Failed to connect to Ollama API")
            print(f" - Error: {result.get('message')}")
            print(f" - Ollama URL: {result.get('ollama_url')}")
            return False
        except Exception as e:
            print(f"❌ Error testing Ollama connection: {str(e)}")
            return False

    def test_chat_completion(self, model_id: Optional[str] = None) -> bool:
        """
        Test basic chat completion without RAG.

        Args:
            model_id: Optional model ID to use. When omitted the request
                carries no model information at all.

        Returns:
            True if the test is successful, False otherwise.
        """
        try:
            extra: Dict[str, Any] = {}
            if model_id is not None:
                # NOTE(review): assumes /test-chat accepts a "model_id"
                # query parameter — confirm against the server.
                extra["params"] = {"model_id": model_id}
            result = self._post_json("/test-chat", timeout=60, **extra)
            if result.get("status") == "success":
                print("✅ Successfully tested chat completion")
                print(f" - Model: {result.get('model')}")
                # First 100 chars
                print(f" - Response: {self._preview(result.get('response'), 100)}")
                return True
            print("❌ Failed to test chat completion")
            print(f" - Error: {result.get('message')}")
            return False
        except Exception as e:
            print(f"❌ Error testing chat completion: {str(e)}")
            return False

    def test_rag_completion(self, query: str = DEFAULT_QUERY) -> bool:
        """
        Test RAG completion with a query.

        Args:
            query: Query to test with RAG.

        Returns:
            True if the test is successful, False otherwise.
        """
        try:
            result = self._post_json(
                "/test-rag",
                timeout=120,  # Longer timeout for RAG
                params={"query": query},
            )
            if result.get("status") == "success":
                print("✅ Successfully tested RAG completion")
                print(f" - Model: {result.get('model')}")
                print(f" - Query: {result.get('query')}")
                print(f" - OpenWebUI URL: {result.get('openwebui_url')}")
                # First 150 chars
                print(f" - Response: {self._preview(result.get('response'), 150)}")
                # Print full response in verbose mode
                if self.verbose:
                    print("\nFull response:")
                    print(result.get('response'))
                return True
            print("❌ Failed to test RAG completion")
            print(f" - Error: {result.get('message')}")
            print(f" - OpenWebUI URL: {result.get('openwebui_url')}")
            return False
        except Exception as e:
            print(f"❌ Error testing RAG completion: {str(e)}")
            return False

    def create_chat_and_test_rag(
        self,
        query: str,
        user_id: str = "test_user",
        model_id: Optional[str] = None,
    ) -> bool:
        """
        Create a chat and test RAG with a message.

        Args:
            query: Query to test with RAG.
            user_id: User ID for the chat.
            model_id: Model ID for the new chat; DEFAULT_MODEL when omitted.

        Returns:
            True if the test is successful, False otherwise.
        """
        try:
            # Create a chat
            chat = self._post_json(
                "/chats",
                timeout=30,
                json={
                    "user_id": user_id,
                    "title": "RAG Test Chat",
                    "model_id": DEFAULT_MODEL if model_id is None else model_id,
                    "is_team_chat": False,
                },
            )
            chat_id = chat.get("id")
            if chat_id is None:
                # Without an ID the follow-up request would be sent to
                # ".../chats/None/messages"; stop with a clear message.
                print("❌ Chat creation response did not include an ID")
                print(f"Response: {chat}")
                return False
            print(f"✅ Created chat with ID: {chat_id}")
            # Send a message with RAG enabled
            message = self._post_json(
                f"/chats/{chat_id}/messages",
                timeout=120,  # Longer timeout for RAG
                json={
                    "message": query,
                    "user_id": user_id,
                    "use_rag": True,
                    "temperature": 0.7,
                    "max_tokens": 1000,
                },
            )
            print("✅ Successfully sent message with RAG")
            print(f" - Message ID: {message.get('id')}")
            # First 150 chars
            print(f" - Response: {self._preview(message.get('content'), 150)}")
            # Print full response in verbose mode
            if self.verbose:
                print("\nFull response:")
                print(message.get('content'))
            return True
        except Exception as e:
            print(f"❌ Error testing chat with RAG: {str(e)}")
            return False

    def run_all_tests(self, query: str = DEFAULT_QUERY) -> bool:
        """
        Run all tests.

        The run stops early only when the health check fails; every other
        failure is reported and the remaining checks still execute.

        Args:
            query: Query to test with RAG.

        Returns:
            True if every check passed, False otherwise.
        """
        print("\n=== Running RAG Functionality Tests ===\n")
        # Check server health
        if not self.check_server_health():
            print("❌ Server health check failed. Aborting tests.")
            return False
        all_passed = True
        # Check configuration
        if not self.check_config():
            print("❌ Failed to get server configuration. Continuing with tests...")
            all_passed = False
        # Test Ollama connection
        if not self.test_ollama_connection():
            print("⚠️ Ollama connection test failed. Some tests may fail.")
            all_passed = False
        # Test basic chat completion
        if not self.test_chat_completion():
            print("⚠️ Basic chat completion test failed. RAG tests may also fail.")
            all_passed = False
        # Test RAG completion
        print("\n--- Testing RAG Completion ---\n")
        if not self.test_rag_completion(query):
            all_passed = False
        # Test chat with RAG
        print("\n--- Testing Chat with RAG ---\n")
        if not self.create_chat_and_test_rag(query):
            all_passed = False
        print("\n=== Tests Completed ===\n")
        return all_passed
def main():
    """
    Parse the command-line options and run the RAG tests.

    The --remote flag takes precedence over --api-url. If the test run
    returns an explicit False the process exits with status 1 so shell
    scripts and CI jobs can detect the failure; any other return value
    (including None) leaves the exit status at 0.
    """
    parser = argparse.ArgumentParser(description="Test RAG functionality in the AI service")
    parser.add_argument(
        "--api-url",
        default=DEFAULT_API_URL,
        help=f"URL of the AI service API (default: {DEFAULT_API_URL})",
    )
    parser.add_argument(
        "--remote",
        action="store_true",
        help=f"Use the remote API URL ({DEFAULT_REMOTE_API_URL})",
    )
    parser.add_argument(
        "--query",
        default="What information do you have in your knowledge database?",
        help="Query to test with RAG",
    )
    parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose output")
    args = parser.parse_args()
    # Use remote URL if specified
    api_url = DEFAULT_REMOTE_API_URL if args.remote else args.api_url
    # Create and run the tester
    tester = RAGTester(api_url=api_url, verbose=args.verbose)
    outcome = tester.run_all_tests(query=args.query)
    # Only an explicit False is treated as a failed run; a run that reports
    # no verdict keeps the previous behaviour of exiting with status 0.
    if outcome is False:
        sys.exit(1)


# Run the tests only when executed as a script, not when imported.
if __name__ == "__main__":
    main()