Added Rag Featured
This commit is contained in:
Executable
+309
@@ -0,0 +1,309 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Test script for RAG functionality in the AI service.
|
||||
This script tests the document-based question answering capabilities
|
||||
by making requests to the API endpoints.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
import requests
|
||||
from typing import Dict, Any, Optional
|
||||
from pprint import pprint
|
||||
|
||||
# Default configuration
|
||||
DEFAULT_API_URL = "http://localhost:5252" # Local development server
|
||||
DEFAULT_REMOTE_API_URL = "http://157.157.221.29:5251" # Remote server
|
||||
DEFAULT_MODEL = "llama3.1"
|
||||
|
||||
class RAGTester:
|
||||
"""Test the RAG functionality of the AI service."""
|
||||
|
||||
def __init__(self, api_url: str, verbose: bool = False):
|
||||
"""
|
||||
Initialize the RAG tester.
|
||||
|
||||
Args:
|
||||
api_url: URL of the AI service API.
|
||||
verbose: Whether to print verbose output.
|
||||
"""
|
||||
self.api_url = api_url
|
||||
self.verbose = verbose
|
||||
self.session = requests.Session()
|
||||
|
||||
# Print configuration
|
||||
print(f"Testing RAG functionality against API at: {self.api_url}")
|
||||
|
||||
def _log(self, message: str):
|
||||
"""Log a message if verbose mode is enabled."""
|
||||
if self.verbose:
|
||||
print(f"[DEBUG] {message}")
|
||||
|
||||
def check_server_health(self) -> bool:
|
||||
"""
|
||||
Check if the server is healthy.
|
||||
|
||||
Returns:
|
||||
True if the server is healthy, False otherwise.
|
||||
"""
|
||||
try:
|
||||
response = self.session.get(f"{self.api_url}/health", timeout=10)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if result.get("status") == "healthy":
|
||||
print("✅ Server is healthy")
|
||||
return True
|
||||
else:
|
||||
print("❌ Server health check failed")
|
||||
print(f"Response: {result}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error checking server health: {str(e)}")
|
||||
return False
|
||||
|
||||
def check_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Check the server configuration.
|
||||
|
||||
Returns:
|
||||
Server configuration.
|
||||
"""
|
||||
try:
|
||||
response = self.session.get(f"{self.api_url}/config", timeout=10)
|
||||
response.raise_for_status()
|
||||
config = response.json()
|
||||
|
||||
print("✅ Server configuration:")
|
||||
print(f" - API Host: {config.get('api_host')}")
|
||||
print(f" - API Port: {config.get('api_port')}")
|
||||
print(f" - OpenWebUI URL: {config.get('openwebui_url')}")
|
||||
print(f" - Ollama API URL: {config.get('ollama_api_url')}")
|
||||
print(f" - Default Model: {config.get('default_model')}")
|
||||
print(f" - API Timeout: {config.get('api_timeout')} seconds")
|
||||
print(f" - Available Models: {', '.join(config.get('available_models', []))}")
|
||||
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error checking server configuration: {str(e)}")
|
||||
return {}
|
||||
|
||||
def test_ollama_connection(self) -> bool:
|
||||
"""
|
||||
Test the connection to Ollama.
|
||||
|
||||
Returns:
|
||||
True if the connection is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
response = self.session.get(f"{self.api_url}/test-ollama", timeout=30)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if result.get("status") == "success":
|
||||
print("✅ Successfully connected to Ollama API")
|
||||
print(f" - Ollama URL: {result.get('ollama_url')}")
|
||||
print(f" - Available models: {len(result.get('models', {}).get('models', []))}")
|
||||
return True
|
||||
else:
|
||||
print("❌ Failed to connect to Ollama API")
|
||||
print(f" - Error: {result.get('message')}")
|
||||
print(f" - Ollama URL: {result.get('ollama_url')}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error testing Ollama connection: {str(e)}")
|
||||
return False
|
||||
|
||||
def test_chat_completion(self, model_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Test basic chat completion without RAG.
|
||||
|
||||
Args:
|
||||
model_id: Optional model ID to use.
|
||||
|
||||
Returns:
|
||||
True if the test is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
response = self.session.post(f"{self.api_url}/test-chat", timeout=60)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if result.get("status") == "success":
|
||||
print("✅ Successfully tested chat completion")
|
||||
print(f" - Model: {result.get('model')}")
|
||||
print(f" - Response: {result.get('response')[:100]}...") # First 100 chars
|
||||
return True
|
||||
else:
|
||||
print("❌ Failed to test chat completion")
|
||||
print(f" - Error: {result.get('message')}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error testing chat completion: {str(e)}")
|
||||
return False
|
||||
|
||||
def test_rag_completion(self, query: str = "What information do you have in your knowledge database?") -> bool:
|
||||
"""
|
||||
Test RAG completion with a query.
|
||||
|
||||
Args:
|
||||
query: Query to test with RAG.
|
||||
|
||||
Returns:
|
||||
True if the test is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
response = self.session.post(
|
||||
f"{self.api_url}/test-rag",
|
||||
params={"query": query},
|
||||
timeout=120 # Longer timeout for RAG
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if result.get("status") == "success":
|
||||
print("✅ Successfully tested RAG completion")
|
||||
print(f" - Model: {result.get('model')}")
|
||||
print(f" - Query: {result.get('query')}")
|
||||
print(f" - OpenWebUI URL: {result.get('openwebui_url')}")
|
||||
print(f" - Response: {result.get('response')[:150]}...") # First 150 chars
|
||||
|
||||
# Print full response in verbose mode
|
||||
if self.verbose:
|
||||
print("\nFull response:")
|
||||
print(result.get('response'))
|
||||
|
||||
return True
|
||||
else:
|
||||
print("❌ Failed to test RAG completion")
|
||||
print(f" - Error: {result.get('message')}")
|
||||
print(f" - OpenWebUI URL: {result.get('openwebui_url')}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error testing RAG completion: {str(e)}")
|
||||
return False
|
||||
|
||||
def create_chat_and_test_rag(self, query: str, user_id: str = "test_user") -> bool:
|
||||
"""
|
||||
Create a chat and test RAG with a message.
|
||||
|
||||
Args:
|
||||
query: Query to test with RAG.
|
||||
user_id: User ID for the chat.
|
||||
|
||||
Returns:
|
||||
True if the test is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
# Create a chat
|
||||
create_response = self.session.post(
|
||||
f"{self.api_url}/chats",
|
||||
json={
|
||||
"user_id": user_id,
|
||||
"title": "RAG Test Chat",
|
||||
"model_id": DEFAULT_MODEL,
|
||||
"is_team_chat": False
|
||||
},
|
||||
timeout=30
|
||||
)
|
||||
create_response.raise_for_status()
|
||||
chat = create_response.json()
|
||||
chat_id = chat.get("id")
|
||||
|
||||
print(f"✅ Created chat with ID: {chat_id}")
|
||||
|
||||
# Send a message with RAG enabled
|
||||
message_response = self.session.post(
|
||||
f"{self.api_url}/chats/{chat_id}/messages",
|
||||
json={
|
||||
"message": query,
|
||||
"user_id": user_id,
|
||||
"use_rag": True,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 1000
|
||||
},
|
||||
timeout=120 # Longer timeout for RAG
|
||||
)
|
||||
message_response.raise_for_status()
|
||||
message = message_response.json()
|
||||
|
||||
print("✅ Successfully sent message with RAG")
|
||||
print(f" - Message ID: {message.get('id')}")
|
||||
print(f" - Response: {message.get('content')[:150]}...") # First 150 chars
|
||||
|
||||
# Print full response in verbose mode
|
||||
if self.verbose:
|
||||
print("\nFull response:")
|
||||
print(message.get('content'))
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error testing chat with RAG: {str(e)}")
|
||||
return False
|
||||
|
||||
def run_all_tests(self, query: str = "What information do you have in your knowledge database?"):
|
||||
"""
|
||||
Run all tests.
|
||||
|
||||
Args:
|
||||
query: Query to test with RAG.
|
||||
"""
|
||||
print("\n=== Running RAG Functionality Tests ===\n")
|
||||
|
||||
# Check server health
|
||||
if not self.check_server_health():
|
||||
print("❌ Server health check failed. Aborting tests.")
|
||||
return
|
||||
|
||||
# Check configuration
|
||||
config = self.check_config()
|
||||
if not config:
|
||||
print("❌ Failed to get server configuration. Continuing with tests...")
|
||||
|
||||
# Test Ollama connection
|
||||
if not self.test_ollama_connection():
|
||||
print("⚠️ Ollama connection test failed. Some tests may fail.")
|
||||
|
||||
# Test basic chat completion
|
||||
if not self.test_chat_completion():
|
||||
print("⚠️ Basic chat completion test failed. RAG tests may also fail.")
|
||||
|
||||
# Test RAG completion
|
||||
print("\n--- Testing RAG Completion ---\n")
|
||||
self.test_rag_completion(query)
|
||||
|
||||
# Test chat with RAG
|
||||
print("\n--- Testing Chat with RAG ---\n")
|
||||
self.create_chat_and_test_rag(query)
|
||||
|
||||
print("\n=== Tests Completed ===\n")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to run the tests."""
|
||||
parser = argparse.ArgumentParser(description="Test RAG functionality in the AI service")
|
||||
parser.add_argument("--api-url", default=DEFAULT_API_URL, help=f"URL of the AI service API (default: {DEFAULT_API_URL})")
|
||||
parser.add_argument("--remote", action="store_true", help=f"Use the remote API URL ({DEFAULT_REMOTE_API_URL})")
|
||||
parser.add_argument("--query", default="What information do you have in your knowledge database?", help="Query to test with RAG")
|
||||
parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose output")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Use remote URL if specified
|
||||
api_url = DEFAULT_REMOTE_API_URL if args.remote else args.api_url
|
||||
|
||||
# Create and run the tester
|
||||
tester = RAGTester(api_url=api_url, verbose=args.verbose)
|
||||
tester.run_all_tests(query=args.query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user