310 lines
11 KiB
Python
310 lines
11 KiB
Python
|
|
#!/usr/bin/env python
|
||
|
|
"""
|
||
|
|
Test script for RAG functionality in the AI service.
|
||
|
|
This script tests the document-based question answering capabilities
|
||
|
|
by making requests to the API endpoints.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import os
|
||
|
|
import sys
|
||
|
|
import json
|
||
|
|
import argparse
|
||
|
|
import requests
|
||
|
|
from typing import Dict, Any, Optional
|
||
|
|
from pprint import pprint
|
||
|
|
|
||
|
|
# Default configuration
# Base URL used for --api-url when the option is not given on the command line.
DEFAULT_API_URL = "http://localhost:5252"  # Local development server
# Base URL selected by the --remote command-line flag.
DEFAULT_REMOTE_API_URL = "http://157.157.221.29:5251"  # Remote server
# Model ID sent as "model_id" when the tester creates a chat.
DEFAULT_MODEL = "llama3.1"
|
||
|
|
|
||
|
|
class RAGTester:
    """Test the RAG functionality of the AI service.

    Every ``check_*`` / ``test_*`` method sends one request (or a short
    sequence of requests) through a shared ``requests.Session``, prints a
    human-readable report, and returns a value describing the outcome.
    Failures are caught and reported rather than raised, so that the
    remaining checks can still run.
    """

    def __init__(self, api_url: str, verbose: bool = False):
        """
        Initialize the RAG tester.

        Args:
            api_url: URL of the AI service API.
            verbose: Whether to print verbose output.
        """
        # Endpoint paths are appended as "/name"; a trailing slash on the base
        # URL would otherwise produce URLs such as "http://host//health".
        self.api_url = api_url.rstrip("/")
        self.verbose = verbose
        self.session = requests.Session()

        # Print configuration
        print(f"Testing RAG functionality against API at: {self.api_url}")

    def _log(self, message: str):
        """Log a message if verbose mode is enabled."""
        if self.verbose:
            print(f"[DEBUG] {message}")

    def _request_json(self, method: str, path: str, **kwargs: Any) -> Any:
        """
        Send one request to the API and return the decoded JSON body.

        Args:
            method: Name of the session method to call ("get" or "post").
            path: Endpoint path, starting with "/", appended to ``api_url``.
            **kwargs: Passed through to the session method (timeout, json, params).

        Returns:
            The value produced by ``response.json()``.

        Raises:
            Whatever the session call, ``raise_for_status()`` or ``json()``
            raise; callers in this class catch and report these.
        """
        url = f"{self.api_url}{path}"
        self._log(f"{method.upper()} {url}")
        response = getattr(self.session, method)(url, **kwargs)
        response.raise_for_status()
        return response.json()

    @staticmethod
    def _preview(text: Any, limit: int) -> str:
        """Return at most ``limit`` leading characters of ``text`` ("" for None)."""
        # The payload field may be absent or null; slicing None would raise and
        # make a successful call look like a failure.
        if text is None:
            return ""
        return str(text)[:limit]

    def check_server_health(self) -> bool:
        """
        Check if the server is healthy.

        Returns:
            True if the ``/health`` endpoint reports status "healthy",
            False otherwise (including when the request itself fails).
        """
        try:
            result = self._request_json("get", "/health", timeout=10)

            if result.get("status") == "healthy":
                print("✅ Server is healthy")
                return True

            print("❌ Server health check failed")
            print(f"Response: {result}")
            return False

        except Exception as e:  # deliberate best-effort: report and carry on
            print(f"❌ Error checking server health: {str(e)}")
            return False

    def check_config(self) -> Dict[str, Any]:
        """
        Check the server configuration.

        Returns:
            The configuration returned by ``/config``, or an empty dict when
            the request fails.
        """
        try:
            config = self._request_json("get", "/config", timeout=10)

            # A null or non-string entry must not discard an otherwise valid
            # configuration, so normalise before joining.
            available = ", ".join(str(m) for m in (config.get("available_models") or []))

            print("✅ Server configuration:")
            print(f" - API Host: {config.get('api_host')}")
            print(f" - API Port: {config.get('api_port')}")
            print(f" - OpenWebUI URL: {config.get('openwebui_url')}")
            print(f" - Ollama API URL: {config.get('ollama_api_url')}")
            print(f" - Default Model: {config.get('default_model')}")
            print(f" - API Timeout: {config.get('api_timeout')} seconds")
            print(f" - Available Models: {available}")

            return config

        except Exception as e:  # deliberate best-effort: report and carry on
            print(f"❌ Error checking server configuration: {str(e)}")
            return {}

    def test_ollama_connection(self) -> bool:
        """
        Test the connection to Ollama.

        Returns:
            True if ``/test-ollama`` reports status "success", False otherwise.
        """
        try:
            result = self._request_json("get", "/test-ollama", timeout=30)

            if result.get("status") == "success":
                # The count is read from result["models"]["models"]; a null
                # "models" field or a bare list is tolerated as well.
                models_info = result.get("models") or {}
                if isinstance(models_info, dict):
                    model_list = models_info.get("models") or []
                else:
                    model_list = models_info

                print("✅ Successfully connected to Ollama API")
                print(f" - Ollama URL: {result.get('ollama_url')}")
                print(f" - Available models: {len(model_list)}")
                return True

            print("❌ Failed to connect to Ollama API")
            print(f" - Error: {result.get('message')}")
            print(f" - Ollama URL: {result.get('ollama_url')}")
            return False

        except Exception as e:  # deliberate best-effort: report and carry on
            print(f"❌ Error testing Ollama connection: {str(e)}")
            return False

    def test_chat_completion(self, model_id: Optional[str] = None) -> bool:
        """
        Test basic chat completion without RAG.

        Args:
            model_id: Optional model ID. It is accepted for interface
                compatibility but is not sent to ``/test-chat``.

        Returns:
            True if the test is successful, False otherwise.
        """
        try:
            result = self._request_json("post", "/test-chat", timeout=60)

            if result.get("status") == "success":
                print("✅ Successfully tested chat completion")
                print(f" - Model: {result.get('model')}")
                print(f" - Response: {self._preview(result.get('response'), 100)}...")  # First 100 chars
                return True

            print("❌ Failed to test chat completion")
            print(f" - Error: {result.get('message')}")
            return False

        except Exception as e:  # deliberate best-effort: report and carry on
            print(f"❌ Error testing chat completion: {str(e)}")
            return False

    def test_rag_completion(self, query: str = "What information do you have in your knowledge database?") -> bool:
        """
        Test RAG completion with a query.

        Args:
            query: Query to test with RAG; sent as the ``query`` URL parameter.

        Returns:
            True if the test is successful, False otherwise.
        """
        try:
            result = self._request_json(
                "post",
                "/test-rag",
                params={"query": query},
                timeout=120,  # Longer timeout for RAG
            )

            if result.get("status") == "success":
                print("✅ Successfully tested RAG completion")
                print(f" - Model: {result.get('model')}")
                print(f" - Query: {result.get('query')}")
                print(f" - OpenWebUI URL: {result.get('openwebui_url')}")
                print(f" - Response: {self._preview(result.get('response'), 150)}...")  # First 150 chars

                # Print full response in verbose mode
                if self.verbose:
                    print("\nFull response:")
                    print(result.get('response'))

                return True

            print("❌ Failed to test RAG completion")
            print(f" - Error: {result.get('message')}")
            print(f" - OpenWebUI URL: {result.get('openwebui_url')}")
            return False

        except Exception as e:  # deliberate best-effort: report and carry on
            print(f"❌ Error testing RAG completion: {str(e)}")
            return False

    def create_chat_and_test_rag(self, query: str, user_id: str = "test_user") -> bool:
        """
        Create a chat and test RAG with a message.

        Args:
            query: Query to test with RAG.
            user_id: User ID for the chat.

        Returns:
            True if the chat was created and the message request succeeded,
            False otherwise.
        """
        try:
            # Create a chat
            chat = self._request_json(
                "post",
                "/chats",
                json={
                    "user_id": user_id,
                    "title": "RAG Test Chat",
                    "model_id": DEFAULT_MODEL,
                    "is_team_chat": False,
                },
                timeout=30,
            )
            chat_id = chat.get("id")

            # Without an ID the follow-up request would target
            # "/chats/None/messages"; stop here with a clear message instead.
            if chat_id in (None, ""):
                print("❌ Chat creation response did not include an ID")
                print(f"Response: {chat}")
                return False

            print(f"✅ Created chat with ID: {chat_id}")

            # Send a message with RAG enabled
            message = self._request_json(
                "post",
                f"/chats/{chat_id}/messages",
                json={
                    "message": query,
                    "user_id": user_id,
                    "use_rag": True,
                    "temperature": 0.7,
                    "max_tokens": 1000,
                },
                timeout=120,  # Longer timeout for RAG
            )

            print("✅ Successfully sent message with RAG")
            print(f" - Message ID: {message.get('id')}")
            print(f" - Response: {self._preview(message.get('content'), 150)}...")  # First 150 chars

            # Print full response in verbose mode
            if self.verbose:
                print("\nFull response:")
                print(message.get('content'))

            return True

        except Exception as e:  # deliberate best-effort: report and carry on
            print(f"❌ Error testing chat with RAG: {str(e)}")
            return False

    def run_all_tests(self, query: str = "What information do you have in your knowledge database?") -> bool:
        """
        Run all tests.

        The health check is a hard prerequisite: if it fails, nothing else
        runs. Every later check runs regardless of earlier failures.

        Args:
            query: Query to test with RAG.

        Returns:
            True only if every check passed; False if the health check
            failed or any later check reported a failure.
        """
        print("\n=== Running RAG Functionality Tests ===\n")

        # Check server health
        if not self.check_server_health():
            print("❌ Server health check failed. Aborting tests.")
            return False

        all_passed = True

        # Check configuration
        if not self.check_config():
            print("❌ Failed to get server configuration. Continuing with tests...")
            all_passed = False

        # Test Ollama connection
        if not self.test_ollama_connection():
            print("⚠️ Ollama connection test failed. Some tests may fail.")
            all_passed = False

        # Test basic chat completion
        if not self.test_chat_completion():
            print("⚠️ Basic chat completion test failed. RAG tests may also fail.")
            all_passed = False

        # Test RAG completion
        print("\n--- Testing RAG Completion ---\n")
        if not self.test_rag_completion(query):
            all_passed = False

        # Test chat with RAG
        print("\n--- Testing Chat with RAG ---\n")
        if not self.create_chat_and_test_rag(query):
            all_passed = False

        print("\n=== Tests Completed ===\n")
        return all_passed
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Main function to run the tests.

    Parses the command-line options, builds a ``RAGTester`` for the selected
    API URL and runs the full test sequence. The process exits with status 1
    when the run explicitly reports failure.
    """
    parser = argparse.ArgumentParser(description="Test RAG functionality in the AI service")
    parser.add_argument("--api-url", default=DEFAULT_API_URL, help=f"URL of the AI service API (default: {DEFAULT_API_URL})")
    parser.add_argument("--remote", action="store_true", help=f"Use the remote API URL ({DEFAULT_REMOTE_API_URL})")
    parser.add_argument("--query", default="What information do you have in your knowledge database?", help="Query to test with RAG")
    parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose output")

    args = parser.parse_args()

    # Use remote URL if specified. Note that --remote takes precedence over
    # any value passed with --api-url.
    api_url = DEFAULT_REMOTE_API_URL if args.remote else args.api_url

    # Create and run the tester
    tester = RAGTester(api_url=api_url, verbose=args.verbose)
    outcome = tester.run_all_tests(query=args.query)

    # Only an explicit False is treated as failure; a run that reports no
    # verdict (None) keeps the exit status at 0.
    if outcome is False:
        sys.exit(1)
|
||
|
|
|
||
|
|
|
||
|
|
# Run the test suite only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|