Added RAG feature

This commit is contained in:
Iyeoluwa Akinrinola
2025-05-16 15:24:01 +01:00
parent e82861a5db
commit 1896298a18
30 changed files with 503 additions and 1580 deletions
+73
View File
@@ -0,0 +1,73 @@
# Document-Based Question Answering with OpenWebUI
This document explains how to use the document-based question answering (RAG) functionality in the AI service.
## Overview
The AI service now supports Retrieval Augmented Generation (RAG) by leveraging OpenWebUI's built-in knowledge database. This allows users to:
1. Upload documents to OpenWebUI
2. Ask questions about those documents
3. Receive responses that incorporate information from the documents
## How It Works
When RAG is enabled:
1. The AI service forwards the request to OpenWebUI with `use_knowledge=True`
2. OpenWebUI searches its knowledge database for relevant information
3. The retrieved information is used to augment the model's response
4. The response is returned to the user
## Using RAG in API Requests
To enable RAG in your API requests, set the `use_rag` parameter to `true`:
```json
POST /chats/{chat_id}/messages
{
"message": "What information do you have about project X?",
"user_id": "user123",
"use_rag": true
}
```
## Testing RAG Functionality
You can test the RAG functionality using the `/test-rag` endpoint:
```
POST /test-rag?query=What%20information%20do%20you%20have%20about%20project%20X%3F
```
This will return a response that includes information from documents in OpenWebUI's knowledge database.
## Uploading Documents to OpenWebUI
To use RAG effectively, you need to upload documents to OpenWebUI:
1. Log in to OpenWebUI at your configured URL (default: http://104.225.217.215:8080)
2. Navigate to the Knowledge section
3. Upload your documents (PDF, TXT, DOCX, etc.)
4. OpenWebUI will automatically process and index the documents
## Troubleshooting
If RAG is not working as expected:
1. Ensure OpenWebUI is running and accessible
2. Check that documents are properly uploaded and indexed in OpenWebUI
3. Verify that the `use_rag` parameter is set to `true` in your requests
4. Check the logs for any errors related to OpenWebUI API calls
If there are connection issues with OpenWebUI, the AI service will automatically fall back to using the direct Ollama API without RAG.
## Configuration
The following configuration settings affect RAG functionality:
- `OPENWEBUI_URL`: URL of your OpenWebUI instance
- `OPENWEBUI_API_KEY`: API key for OpenWebUI (if required)
- `API_TIMEOUT`: Timeout for API requests (in seconds)
These can be set in your environment variables or in the `.env` file.
+42
View File
@@ -1,6 +1,11 @@
""" """
FastAPI application for the AI service. FastAPI application for the AI service.
This service acts as a backend for OpenWebUI, providing OpenWebUI-compatible API endpoints. This service acts as a backend for OpenWebUI, providing OpenWebUI-compatible API endpoints.
The service supports document-based question answering using OpenWebUI's knowledge database:
- Set use_rag=True in API requests to enable Retrieval Augmented Generation
- When enabled, the service will use OpenWebUI's knowledge database to find relevant information
- Documents uploaded to OpenWebUI will be used to augment the model's responses
""" """
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
@@ -203,6 +208,43 @@ async def test_chat_completion():
"ollama_url": config.OLLAMA_API_URL "ollama_url": config.OLLAMA_API_URL
} }
@app.post("/test-rag")
async def test_rag_completion(query: str = "What information do you have in your knowledge database?"):
    """
    Test the RAG (Retrieval Augmented Generation) functionality with a query.

    Sends the query to the default model through the model service with
    ``use_rag=True`` and an empty context, then reports the outcome.

    Args:
        query: The question to ask about documents in the knowledge database.

    Returns:
        A dict with ``status`` "success" plus the model response, or
        ``status`` "error" plus a message when the call raised.
    """
    # Function-scope imports keep this block self-contained.
    import asyncio
    import functools

    try:
        # generate_response is invoked as a plain synchronous call; running it
        # in the default executor keeps a slow model backend from stalling
        # the event loop for every other request.
        generate = functools.partial(
            model_service.generate_response,
            model_id=config.DEFAULT_MODEL,
            prompt=query,
            context=[],
            use_rag=True,  # Enable RAG
        )
        response = await asyncio.get_running_loop().run_in_executor(None, generate)

        return {
            "status": "success",
            "model": config.DEFAULT_MODEL,
            "query": query,
            "use_rag": True,
            "response": response,
            "openwebui_url": config.OPENWEBUI_URL
        }
    except Exception as e:
        # Diagnostic endpoint: report the failure in the body instead of raising.
        return {
            "status": "error",
            "message": f"Failed to get RAG completion: {str(e)}",
            "openwebui_url": config.OPENWEBUI_URL
        }
@app.post("/test-ollama-direct") @app.post("/test-ollama-direct")
async def test_ollama_direct(): async def test_ollama_direct():
""" """
+27
View File
@@ -198,5 +198,32 @@
} }
], ],
"team_members": [] "team_members": []
},
"aefe133e-e85e-422f-9ca9-6441f24db74f": {
"id": "aefe133e-e85e-422f-9ca9-6441f24db74f",
"title": "RAG Test Chat",
"user_id": "test_user",
"model_id": "llama3.1",
"is_team_chat": false,
"created_at": "2025-05-16T14:41:08.281043",
"updated_at": "2025-05-16T14:41:09.142922",
"messages": [
{
"id": "35384001-0c6d-4ac6-a247-32fb53bb664c",
"content": "What information do you have in your knowledge database?",
"user_id": "test_user",
"is_user_message": true,
"timestamp": "2025-05-16T14:41:08.341940"
},
{
"id": "4b82555f-76e2-43c0-a1aa-c91257536133",
"content": "Connection error to Ollama API: HTTPConnectionPool(host='104.225.217.215', port=11434): Max retries exceeded with url: /api/chat (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x104bcc750>: Failed to establish a new connection: [Errno 61] Connection refused')). Please check if Ollama is running at http://104.225.217.215:11434.",
"user_id": null,
"is_user_message": false,
"timestamp": "2025-05-16T14:41:09.142892"
}
],
"team_members": [],
"openwebui_channel_id": null
} }
} }
-261
View File
@@ -1,261 +0,0 @@
"""
Service for document processing and chunking.
"""
import os
import json
import uuid
import requests
import base64
from typing import List, Dict, Any, Optional
from langchain_text_splitters import RecursiveCharacterTextSplitter
from ai_service.config import config
class DocumentService:
    """
    Service for document processing and chunking.

    Documents are uploaded to OpenWebUI's knowledge API. A local JSON file
    keeps per-document metadata keyed by a locally generated UUID.
    """

    def __init__(self):
        """Initialize the document service."""
        self.chunk_size = config.CHUNK_SIZE
        self.chunk_overlap = config.CHUNK_OVERLAP
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=len
        )

        # OpenWebUI configuration
        self.openwebui_url = config.OPENWEBUI_URL
        self.openwebui_api_key = config.OPENWEBUI_API_KEY

        # Ensure data directory exists
        os.makedirs('ai_service/data', exist_ok=True)

        # For now, we'll store document metadata in a simple JSON file
        self.metadata_file = 'ai_service/data/document_metadata.json'
        self._load_metadata()

    def _auth_headers(self) -> Dict[str, str]:
        """Build JSON request headers, adding a bearer token when an API key is set."""
        headers = {"Content-Type": "application/json"}
        if self.openwebui_api_key:
            headers["Authorization"] = f"Bearer {self.openwebui_api_key}"
        return headers

    def _load_metadata(self):
        """
        Load document metadata from file.

        ``self.documents`` always ends up a dict: it is empty when the file
        is missing, unreadable, or does not hold a JSON object.
        """
        self.documents = {}
        if not os.path.exists(self.metadata_file):
            return
        try:
            with open(self.metadata_file, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except Exception as e:
            print(f"Error loading document metadata: {str(e)}")
            return
        # Every other method uses dict operations (.get, del, .values), so
        # anything other than a JSON object is rejected here.
        if isinstance(loaded, dict):
            self.documents = loaded
        else:
            print(f"Error loading document metadata: expected a JSON object, got {type(loaded).__name__}")

    def _save_metadata(self):
        """Save document metadata to file (errors are printed, not raised)."""
        try:
            with open(self.metadata_file, 'w', encoding='utf-8') as f:
                json.dump(self.documents, f, indent=2)
        except Exception as e:
            print(f"Error saving document metadata: {str(e)}")

    def process_document(self, content: str, title: str,
                         description: Optional[str] = None,
                         metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Process a document for embedding.

        Tries to upload the document to OpenWebUI. If anything in the upload
        fails, the document is only chunked locally and its record is flagged
        with ``openwebui_upload_failed``. A metadata record is stored either way.

        Args:
            content: Document content.
            title: Document title.
            description: Optional document description.
            metadata: Optional additional metadata.

        Returns:
            Document ID.
        """
        # Generate a unique ID for the document
        doc_id = str(uuid.uuid4())
        record = {
            'id': doc_id,
            'title': title,
            'description': description or '',
        }

        try:
            # The content is sent base64-encoded inside a JSON body.
            document_data = {
                "filename": f"{title}.txt",
                "content": base64.b64encode(content.encode('utf-8')).decode('utf-8'),
                "description": description or title
            }
            response = requests.post(
                f"{self.openwebui_url}/api/knowledge/upload",
                headers=self._auth_headers(),
                json=document_data,
                timeout=60
            )
            response.raise_for_status()
            record['openwebui_id'] = response.json().get('id', '')
        except Exception as e:
            # Deliberate best-effort: any upload failure degrades to local-only.
            print(f"Error uploading document to OpenWebUI: {str(e)}")
            print("Falling back to local document processing")
            chunks = self.text_splitter.split_text(content)
            record['chunk_count'] = len(chunks)
            record['openwebui_upload_failed'] = True

        record['metadata'] = metadata or {}
        self.documents[doc_id] = record
        self._save_metadata()
        return doc_id

    def get_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
        """
        Get document metadata.

        Args:
            doc_id: Document ID.

        Returns:
            Document metadata if found, None otherwise.
        """
        return self.documents.get(doc_id)

    def get_all_documents(self) -> List[Dict[str, Any]]:
        """
        Get all document metadata.

        Records that have an ``openwebui_id`` matching an entry returned by
        OpenWebUI get ``openwebui_status`` and ``openwebui_info`` added.
        NOTE: the stored dicts themselves are annotated, so these keys are
        written to disk by the next metadata save.

        Returns:
            List of document metadata.
        """
        local_documents = list(self.documents.values())

        try:
            response = requests.get(
                f"{self.openwebui_url}/api/knowledge",
                headers=self._auth_headers(),
                timeout=30
            )
            if response.status_code == 200:
                openwebui_docs = response.json()
                for doc in local_documents:
                    if 'openwebui_id' not in doc:
                        continue
                    # First remote entry with a matching id wins.
                    match = next(
                        (remote for remote in openwebui_docs
                         if remote.get('id') == doc['openwebui_id']),
                        None,
                    )
                    if match is not None:
                        doc['openwebui_status'] = 'active'
                        doc['openwebui_info'] = match
        except Exception as e:
            print(f"Error getting documents from OpenWebUI: {str(e)}")

        return local_documents

    def delete_document(self, doc_id: str) -> bool:
        """
        Delete a document and its chunks.

        Remote deletion is best-effort: the local record is removed even
        when the OpenWebUI call fails.

        Args:
            doc_id: Document ID.

        Returns:
            True if deletion was successful, False otherwise.
        """
        if doc_id not in self.documents:
            return False

        openwebui_id = self.documents[doc_id].get('openwebui_id')
        if openwebui_id:
            try:
                response = requests.delete(
                    f"{self.openwebui_url}/api/knowledge/{openwebui_id}",
                    headers=self._auth_headers(),
                    timeout=30
                )
                # Any 2xx counts as success; DELETE often answers 204.
                if not 200 <= response.status_code < 300:
                    print(f"Warning: Failed to delete document from OpenWebUI: {response.text}")
            except Exception as e:
                print(f"Error deleting document from OpenWebUI: {str(e)}")

        del self.documents[doc_id]
        self._save_metadata()
        return True

    def search_documents(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Search for documents similar to a query.

        Placeholder: always returns an empty list, because retrieval is
        handled by OpenWebUI when use_rag=True in the model service.

        Args:
            query: Search query (unused).
            top_k: Number of results to return (unused).

        Returns:
            An empty list.
        """
        return []
# Create a singleton instance. Importing this module runs DocumentService.__init__,
# which reads config, creates the data directory and loads any metadata file.
document_service = DocumentService()
+19 -5
View File
@@ -135,9 +135,6 @@ class ModelService:
model_id = self.default_model model_id = self.default_model
print(f" - Model not found, using default: {model_id}") print(f" - Model not found, using default: {model_id}")
# Ensure we're using a valid model
# (model_id is already validated above)
# Prepare the messages for the API call # Prepare the messages for the API call
messages = [] messages = []
@@ -178,6 +175,10 @@ class ModelService:
openwebui_request['max_tokens'] = params['max_tokens'] openwebui_request['max_tokens'] = params['max_tokens']
if 'top_p' in params: if 'top_p' in params:
openwebui_request['top_p'] = params['top_p'] openwebui_request['top_p'] = params['top_p']
if 'top_k' in params:
openwebui_request['top_k'] = params['top_k']
if 'repeat_penalty' in params:
openwebui_request['repeat_penalty'] = params['repeat_penalty']
# Make the API call to OpenWebUI # Make the API call to OpenWebUI
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
@@ -201,10 +202,15 @@ class ModelService:
result = response.json() result = response.json()
# Extract the response content # Extract the response content
if 'message' in result: if 'choices' in result and len(result['choices']) > 0 and 'message' in result['choices'][0]:
# OpenAI-compatible format
return result['choices'][0]['message']['content']
elif 'message' in result and 'content' in result['message']:
# OpenWebUI format
return result['message']['content'] return result['message']['content']
else: else:
return "Error: Unexpected response format from OpenWebUI" print(f"WARNING: Unexpected response format from OpenWebUI: {json.dumps(result, indent=2)}")
return "Error: Unexpected response format from OpenWebUI. Falling back to direct model call."
except requests.exceptions.Timeout as e: except requests.exceptions.Timeout as e:
error_msg = f"Timeout error connecting to OpenWebUI API: {str(e)}. The request exceeded the {self.api_timeout} second timeout." error_msg = f"Timeout error connecting to OpenWebUI API: {str(e)}. The request exceeded the {self.api_timeout} second timeout."
@@ -216,6 +222,11 @@ class ModelService:
print(f"ERROR: {error_msg}") print(f"ERROR: {error_msg}")
print("Falling back to direct Ollama call without RAG") print("Falling back to direct Ollama call without RAG")
# Continue to the Ollama API call below # Continue to the Ollama API call below
except requests.exceptions.HTTPError as e:
error_msg = f"HTTP error from OpenWebUI API: {str(e)}."
print(f"ERROR: {error_msg}")
print("Falling back to direct Ollama call without RAG")
# Continue to the Ollama API call below
except Exception as e: except Exception as e:
error_msg = f"Error calling OpenWebUI API: {str(e)}" error_msg = f"Error calling OpenWebUI API: {str(e)}"
print(f"ERROR: {error_msg}") print(f"ERROR: {error_msg}")
@@ -247,6 +258,8 @@ class ModelService:
request_json['top_k'] = params['top_k'] request_json['top_k'] = params['top_k']
if 'max_tokens' in params: if 'max_tokens' in params:
request_json['max_tokens'] = params['max_tokens'] request_json['max_tokens'] = params['max_tokens']
if 'repeat_penalty' in params:
request_json['repeat_penalty'] = params['repeat_penalty']
# Make the API call to Ollama # Make the API call to Ollama
try: try:
@@ -272,6 +285,7 @@ class ModelService:
if 'message' in result and 'content' in result['message']: if 'message' in result and 'content' in result['message']:
return result['message']['content'] return result['message']['content']
else: else:
print(f"WARNING: Unexpected response format from Ollama: {json.dumps(result, indent=2)}")
return "Error: Unexpected response format from Ollama" return "Error: Unexpected response format from Ollama"
except requests.exceptions.Timeout as e: except requests.exceptions.Timeout as e:
+9
View File
@@ -71,10 +71,16 @@ async def chat_completions(request: Request):
temperature = body.get("temperature") temperature = body.get("temperature")
max_tokens = body.get("max_tokens") max_tokens = body.get("max_tokens")
top_p = body.get("top_p") top_p = body.get("top_p")
top_k = body.get("top_k")
frequency_penalty = body.get("frequency_penalty") frequency_penalty = body.get("frequency_penalty")
presence_penalty = body.get("presence_penalty") presence_penalty = body.get("presence_penalty")
repeat_penalty = body.get("repeat_penalty")
stop = body.get("stop") stop = body.get("stop")
# Check if RAG should be used
use_knowledge = body.get("use_knowledge", False)
use_rag = use_knowledge # Map OpenWebUI's use_knowledge to our use_rag parameter
# Create a unique chat ID # Create a unique chat ID
chat_id = str(uuid.uuid4()) chat_id = str(uuid.uuid4())
@@ -99,11 +105,14 @@ async def chat_completions(request: Request):
chat_id=chat_id, chat_id=chat_id,
message=user_message, message=user_message,
user_id=user_id, user_id=user_id,
use_rag=use_rag,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
top_p=top_p, top_p=top_p,
top_k=top_k,
frequency_penalty=frequency_penalty, frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty, presence_penalty=presence_penalty,
repeat_penalty=repeat_penalty,
stop_sequences=stop if isinstance(stop, list) else [stop] if stop else None stop_sequences=stop if isinstance(stop, list) else [stop] if stop else None
) )
-34
View File
@@ -1,34 +0,0 @@
"""
Main application package for the chatbot application.
"""
from flask import Flask
from app.config.config import Config
def create_app(config_class=Config):
    """
    Build the Flask application and wire up its routes and database.

    Args:
        config_class: Configuration class whose settings are loaded into the app.

    Returns:
        Flask application instance.
    """
    application = Flask(__name__)
    application.config.from_object(config_class)

    # Route module is imported inside the factory, then its blueprint attached
    from app.api import routes as flask_routes
    application.register_blueprint(flask_routes.bp)

    # Only Flask routes are active; FastAPI integration stays disabled
    # until the integration issues are resolved

    # Hook the database up to this application
    from app.database import db
    db.init_app(application)

    return application
View File
-110
View File
@@ -1,110 +0,0 @@
"""
FastAPI routes for the application.
"""
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
from app.services.chatbot_service import chatbot_service
router = APIRouter()
class MessageRequest(BaseModel):
    """Request model for sending a message."""
    # Text of the message to send (required).
    message: str
    # Sender identifier; "default_user" is used when the field is omitted.
    user_id: str = "default_user"
class MessageResponse(BaseModel):
    """Response model for a message."""
    # Message text.
    content: str
    # presumably True for user-authored messages, False for bot replies — confirm with chatbot_service
    is_user: bool
    # Timestamp carried as a string; the exact format is not visible here.
    timestamp: str
class ChatResponse(BaseModel):
    """Response model for a chat."""
    # Integer identifier of the chat.
    chat_id: int
    # Messages belonging to the chat, each shaped like MessageResponse.
    messages: List[MessageResponse]
@router.get("/health")
async def health_check():
    """
    Liveness probe.

    Returns:
        A dict whose "status" key is "healthy".
    """
    payload = {"status": "healthy"}
    return payload
@router.post("/chat", response_model=ChatResponse)
async def create_chat(user_id: str = "default_user"):
    """
    Open a new chat for a user.

    Args:
        user_id: ID of the user creating the chat.

    Returns:
        The new chat's ID together with an empty message list.
    """
    new_chat_id = chatbot_service.create_chat(user_id)
    # A freshly created chat has no messages yet
    created = {"chat_id": new_chat_id, "messages": []}
    return created
@router.post("/chat/{chat_id}/message", response_model=MessageResponse)
async def send_message(chat_id: int, request: MessageRequest):
    """
    Send a message to the chatbot.

    Args:
        chat_id: ID of the chat.
        request: Message request.

    Returns:
        Bot response (the last message in the chat after the exchange).

    Raises:
        HTTPException: 404 when the service raises ValueError, 500 when the
            chat holds no messages after the exchange.
    """
    try:
        # The return value is not needed: the reply is read back from the
        # chat history (presumably get_response records it there — confirm).
        chatbot_service.get_response(chat_id, request.message)
        messages = chatbot_service.get_chat_messages(chat_id)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    if not messages:
        # Explicit error instead of an IndexError on messages[-1]
        raise HTTPException(status_code=500, detail="No messages found after generating a response")

    # Get the last message (bot response)
    return messages[-1]
@router.get("/chat/{chat_id}", response_model=ChatResponse)
async def get_chat(chat_id: int):
    """
    Fetch a chat and its messages.

    Args:
        chat_id: ID of the chat.

    Returns:
        Chat with messages.

    Raises:
        HTTPException: 404 when the service raises ValueError.
    """
    try:
        history = chatbot_service.get_chat_messages(chat_id)
        chat_payload = {"chat_id": chat_id, "messages": history}
        return chat_payload
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
def init_app(app):
    """
    Attach this module's router to a FastAPI application.

    Args:
        app: FastAPI application instance.
    """
    # All routes defined above are mounted under this prefix
    api_prefix = "/api"
    app.include_router(router, prefix=api_prefix)
-100
View File
@@ -1,100 +0,0 @@
"""
Flask routes for the application.
"""
from flask import Blueprint, jsonify, request, abort
from app.services.chatbot_service import chatbot_service
bp = Blueprint('main', __name__)
@bp.route('/')
def index():
    """
    Root endpoint.

    Returns:
        JSON response with the application's name, version and status.
    """
    app_info = {
        'name': 'Chatbot Application',
        'version': '1.0.0',
        'status': 'running'
    }
    return jsonify(app_info)
@bp.route('/api/health')
def health_check():
    """
    Liveness probe.

    Returns:
        JSON response whose 'status' key is 'healthy'.
    """
    health = {'status': 'healthy'}
    return jsonify(health)
@bp.route('/api/chat', methods=['POST'])
def create_chat():
    """
    Create a new chat.

    The optional JSON body may carry a 'user_id'. A missing, non-JSON or
    non-object body falls back to 'default_user'.

    Returns:
        JSON response with chat ID.
    """
    # silent=True yields None instead of raising on a missing or invalid
    # JSON body, so the documented default user is actually reachable.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}

    user_id = payload.get('user_id', 'default_user')
    chat_id = chatbot_service.create_chat(user_id)
    return jsonify({
        'chat_id': chat_id,
        'messages': []
    })
@bp.route('/api/chat/<int:chat_id>/message', methods=['POST'])
def send_message(chat_id):
    """
    Send a message to the chatbot.

    Args:
        chat_id: ID of the chat.

    Returns:
        JSON response with bot response (the last message in the chat).
        Aborts with 400 when the body is not a JSON object containing
        'message', and with 404 when the service raises ValueError.
    """
    # silent=True turns a missing or malformed JSON body into None, so every
    # bad-body case gets the same 400 below. The dict check also covers JSON
    # arrays and scalars, which cannot be indexed by key.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'message' not in payload:
        abort(400, description="Message is required")

    try:
        # The return value is not needed: the reply is read back from the
        # chat history (presumably get_response records it there — confirm).
        chatbot_service.get_response(chat_id, payload['message'])
        # Get the last message (bot response)
        messages = chatbot_service.get_chat_messages(chat_id)
        return jsonify(messages[-1])
    except ValueError as e:
        abort(404, description=str(e))
@bp.route('/api/chat/<int:chat_id>', methods=['GET'])
def get_chat(chat_id):
    """
    Fetch a chat and its messages.

    Args:
        chat_id: ID of the chat.

    Returns:
        JSON response with chat messages; aborts with 404 when the service
        raises ValueError.
    """
    try:
        history = chatbot_service.get_chat_messages(chat_id)
        chat_payload = {'chat_id': chat_id, 'messages': history}
        return jsonify(chat_payload)
    except ValueError as e:
        abort(404, description=str(e))
View File
-79
View File
@@ -1,79 +0,0 @@
"""
Configuration settings for the application.
"""
import os
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
class Config:
    """Base configuration shared by every environment; values come from env vars with defaults."""

    # --- Flask ---
    SECRET_KEY = os.getenv('SECRET_KEY', 'dev-key-for-development-only')
    DEBUG = False
    TESTING = False

    # --- Database ---
    SQLALCHEMY_DATABASE_URI = os.getenv('DATABASE_URL', 'sqlite:///chatbot.db')
    SQLALCHEMY_TRACK_MODIFICATIONS = False
    # Only the (case-insensitive) string "true" enables this flag
    INITIALIZE_DATABASE = os.getenv('INITIALIZE_DATABASE', 'False').lower() == 'true'

    # --- Pinecone ---
    PINECONE_API_KEY = os.getenv('PINECONE_API_KEY', '')
    PINECONE_ENVIRONMENT = os.getenv('PINECONE_ENVIRONMENT', '')
    PINECONE_INDEX_NAME = os.getenv('PINECONE_INDEX_NAME', 'chatbot-index')

    # --- Model ---
    DEFAULT_MODEL = os.getenv('DEFAULT_MODEL', 'gpt-3.5-turbo')
    OPENAI_API_KEY = os.getenv('OPENAI_API_KEY', '')
class DevelopmentConfig(Config):
    """Settings for local development: base config with debug mode switched on."""

    DEBUG = True
class TestingConfig(Config):
    """Settings for test runs: testing flag on, in-memory SQLite database."""

    TESTING = True
    SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:'
class ProductionConfig(Config):
    """Production configuration."""

    @classmethod
    def init_app(cls, app):
        """
        Verify that every required environment variable is set.

        Args:
            app: Application instance (not used by the check itself).

        Raises:
            RuntimeError: If any required variable is unset or empty.
        """
        required_vars = (
            'SECRET_KEY',
            'DATABASE_URL',
            'PINECONE_API_KEY',
            'PINECONE_ENVIRONMENT',
            'OPENAI_API_KEY',
        )

        # An unset variable and an empty one both count as missing
        missing_vars = []
        for name in required_vars:
            if not os.environ.get(name):
                missing_vars.append(name)

        if not missing_vars:
            return
        raise RuntimeError(
            f"Missing required environment variables: {', '.join(missing_vars)}"
        )
# Lookup table from environment name to configuration class;
# 'default' resolves to the development settings
config = dict(
    development=DevelopmentConfig,
    testing=TestingConfig,
    production=ProductionConfig,
    default=DevelopmentConfig,
)
View File
-36
View File
@@ -1,36 +0,0 @@
"""
Database module for the application.
"""
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData
# Name templates for indexes (ix), unique (uq), check (ck),
# foreign-key (fk) and primary-key (pk) constraints
convention = {
    "ix": "ix_%(column_0_label)s",
    "uq": "uq_%(table_name)s_%(column_0_name)s",
    "ck": "ck_%(table_name)s_%(constraint_name)s",
    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    "pk": "pk_%(table_name)s",
}

# Shared SQLAlchemy handle; its metadata carries the naming convention above
_metadata = MetaData(naming_convention=convention)
db = SQLAlchemy(metadata=_metadata)
def init_app(app):
    """
    Bind the database to a Flask application and optionally create tables.

    Args:
        app: Flask application instance.
    """
    db.init_app(app)

    # Table creation is opt-in through the INITIALIZE_DATABASE setting
    if not app.config.get('INITIALIZE_DATABASE', False):
        return

    # Importing the model modules registers their tables with SQLAlchemy
    from app.models import user, chat, document

    # create_all needs an active application context
    with app.app_context():
        db.create_all()
View File
-67
View File
@@ -1,67 +0,0 @@
"""
Chat models for the application.
"""
from datetime import datetime
from app.database.db import db
class Chat(db.Model):
    """Chat model representing a chat session (table 'chats')."""
    __tablename__ = 'chats'

    id = db.Column(db.Integer, primary_key=True)
    # Optional display title; __repr__ shows "Untitled" when it is empty.
    title = db.Column(db.String(100), nullable=True)
    is_team_chat = db.Column(db.Boolean, default=False)
    # Name of the model used for this chat (required).
    model_name = db.Column(db.String(50), nullable=False)
    # datetime.utcnow is passed as a callable, so the value is produced per row;
    # it yields naive datetimes (no tzinfo).
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Foreign keys
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)

    # Relationships: 'all, delete-orphan' cascades deletion of a chat to its
    # messages and team-member rows.
    messages = db.relationship('Message', backref='chat', lazy='dynamic', cascade='all, delete-orphan')
    team_members = db.relationship('TeamChatMember', backref='chat', lazy='dynamic', cascade='all, delete-orphan')

    def __repr__(self):
        return f'<Chat {self.id}: {self.title or "Untitled"}>'
class Message(db.Model):
    """Message model representing a single message in a chat (table 'messages')."""
    __tablename__ = 'messages'

    id = db.Column(db.Integer, primary_key=True)
    content = db.Column(db.Text, nullable=False)
    # True for messages written by a user; defaults to True.
    is_user_message = db.Column(db.Boolean, default=True)
    # Column default: the value is supplied when the row is inserted, so the
    # attribute may be None on an instance that has not been flushed yet.
    created_at = db.Column(db.DateTime, default=datetime.utcnow)

    # Foreign keys
    chat_id = db.Column(db.Integer, db.ForeignKey('chats.id'), nullable=False)
    # Nullable: a message need not be tied to a user.
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=True)

    def __repr__(self):
        # Only the first 20 characters of the content are shown.
        return f'<Message {self.id}: {self.content[:20]}...>'
class TeamChatMember(db.Model):
    """Model representing a member of a team chat (table 'team_chat_members')."""
    __tablename__ = 'team_chat_members'

    id = db.Column(db.Integer, primary_key=True)
    # Time the membership row was created (naive datetime from datetime.utcnow).
    joined_at = db.Column(db.DateTime, default=datetime.utcnow)

    # Foreign keys
    chat_id = db.Column(db.Integer, db.ForeignKey('chats.id'), nullable=False)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)

    # Ensure a user can only be added to a team chat once:
    # the (chat_id, user_id) pair must be unique.
    __table_args__ = (
        db.UniqueConstraint('chat_id', 'user_id', name='uq_team_chat_member'),
    )

    def __repr__(self):
        return f'<TeamChatMember chat_id={self.chat_id}, user_id={self.user_id}>'
-59
View File
@@ -1,59 +0,0 @@
"""
Document models for the application.
"""
from datetime import datetime
import json
from app.database.db import db
class Document(db.Model):
    """Document model representing a document in the library (table 'documents')."""
    __tablename__ = 'documents'

    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=True)
    # Optional path to the stored file.
    file_path = db.Column(db.String(255), nullable=True)
    content_type = db.Column(db.String(50), nullable=False)
    # Processing state; one of: pending, processing, completed, error.
    # The column is a free-form string, so the set is not enforced here.
    status = db.Column(db.String(20), default='pending')  # pending, processing, completed, error
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Foreign keys
    uploaded_by = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)

    # Relationships: deleting a document cascades to its chunks.
    chunks = db.relationship('DocumentChunk', backref='document', lazy='dynamic', cascade='all, delete-orphan')

    def __repr__(self):
        return f'<Document {self.id}: {self.title}>'
class DocumentChunk(db.Model):
    """Model representing a chunk of a document for embedding (table 'document_chunks')."""
    __tablename__ = 'document_chunks'

    id = db.Column(db.Integer, primary_key=True)
    content = db.Column(db.Text, nullable=False)
    # Index of this chunk within its document (required).
    chunk_index = db.Column(db.Integer, nullable=False)
    embedding_id = db.Column(db.String(100), nullable=True)  # ID in Pinecone
    # Stored as text; use set_metadata/get_metadata to convert to and from a dict.
    meta_data = db.Column(db.Text, nullable=True)  # JSON string of metadata
    created_at = db.Column(db.DateTime, default=datetime.utcnow)

    # Foreign keys
    document_id = db.Column(db.Integer, db.ForeignKey('documents.id'), nullable=False)

    def set_metadata(self, metadata_dict):
        """Serialize metadata_dict to JSON and store it in meta_data."""
        self.meta_data = json.dumps(metadata_dict)

    def get_metadata(self):
        """Return the stored metadata parsed from JSON, or {} when meta_data is empty or None."""
        if self.meta_data:
            return json.loads(self.meta_data)
        return {}

    def __repr__(self):
        return f'<DocumentChunk {self.id}: doc_id={self.document_id}, index={self.chunk_index}>'
-24
View File
@@ -1,24 +0,0 @@
"""
User model for the application.
"""
from datetime import datetime
from app.database.db import db
class User(db.Model):
    """User model representing application users (table 'users')."""
    __tablename__ = 'users'

    id = db.Column(db.Integer, primary_key=True)
    # username and email are each unique, required and indexed.
    username = db.Column(db.String(64), unique=True, nullable=False, index=True)
    email = db.Column(db.String(120), unique=True, nullable=False, index=True)
    # Hash only; limited to 128 characters.
    password_hash = db.Column(db.String(128), nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Relationships: chats owned by this user; each Chat gets a 'user' backref.
    chats = db.relationship('Chat', backref='user', lazy='dynamic')

    def __repr__(self):
        return f'<User {self.username}>'
View File
-227
View File
@@ -1,227 +0,0 @@
"""
Service for chat functionality.
"""
from typing import List, Dict, Any, Optional
from app.database.db import db
from app.models.chat import Chat, Message, TeamChatMember
from app.models.user import User
class ChatService:
"""Service for chat functionality."""
def create_chat(self, user_id: int, title: Optional[str] = None,
                is_team_chat: bool = False, model_name: Optional[str] = None) -> Chat:
    """
    Persist a new chat and return it.

    Args:
        user_id: ID of the user creating the chat.
        title: Optional title for the chat.
        is_team_chat: Whether this is a team chat.
        model_name: Name of the model to use for this chat; the configured
            default model is used when omitted.

    Returns:
        Created chat.
    """
    from app.config.config import Config

    new_chat = Chat(
        user_id=user_id,
        title=title,
        is_team_chat=is_team_chat,
        model_name=model_name or Config().DEFAULT_MODEL,
    )
    session = db.session
    session.add(new_chat)
    session.commit()

    # The creator of a team chat becomes its first member
    if is_team_chat:
        self.add_team_member(new_chat.id, user_id)
    return new_chat
def get_chat(self, chat_id: int) -> Optional[Chat]:
    """
    Look up a chat by its primary key.

    Args:
        chat_id: ID of the chat.

    Returns:
        Chat if found, None otherwise.
    """
    found = Chat.query.get(chat_id)
    return found
def get_user_chats(self, user_id: int) -> List[Chat]:
    """
    List a user's chats, most recently updated first.

    Covers the user's own non-team chats plus every chat in which the user
    has a team-membership row.

    Args:
        user_id: ID of the user.

    Returns:
        List of chats.
    """
    newest_first = Chat.updated_at.desc()

    # Non-team chats owned by the user
    owned = Chat.query.filter_by(user_id=user_id, is_team_chat=False).order_by(newest_first).all()

    # Chats the user belongs to through a membership row
    membership_rows = db.session.query(TeamChatMember.chat_id).filter_by(user_id=user_id).all()
    joined_ids = [row[0] for row in membership_rows]
    shared = Chat.query.filter(Chat.id.in_(joined_ids)).order_by(newest_first).all()

    # Merge both groups and order the result by last update
    return sorted(owned + shared, key=lambda chat: chat.updated_at, reverse=True)
def add_message(self, chat_id: int, content: str,
                is_user_message: bool = True, user_id: Optional[int] = None) -> Message:
    """
    Add a message to a chat.

    Also copies the message's creation time onto the chat's updated_at.

    Args:
        chat_id: ID of the chat.
        content: Message content.
        is_user_message: Whether this is a user message (vs. bot message).
        user_id: ID of the user sending the message (required for user messages);
            ignored and stored as None for bot messages.

    Returns:
        Created message.
    """
    message = Message(
        chat_id=chat_id,
        content=content,
        is_user_message=is_user_message,
        user_id=user_id if is_user_message else None
    )
    db.session.add(message)

    # created_at comes from a column default that is applied when the row is
    # inserted. Flush first so it holds a real timestamp; otherwise
    # chat.updated_at could be overwritten with None.
    db.session.flush()

    # Update chat's updated_at timestamp
    chat = Chat.query.get(chat_id)
    if chat:
        chat.updated_at = message.created_at

    db.session.commit()
    return message
def get_chat_messages(self, chat_id: int) -> List[Message]:
    """
    Fetch the messages of one chat in creation order.

    Args:
        chat_id: ID of the chat whose messages are wanted.

    Returns:
        Messages sorted by ``created_at`` ascending.
    """
    in_chat = Message.query.filter_by(chat_id=chat_id)
    chronological = in_chat.order_by(Message.created_at)
    return chronological.all()
def add_team_member(self, chat_id: int, user_id: int) -> Optional[TeamChatMember]:
    """
    Make a user a member of a team chat.

    The call is a no-op for users who already belong to the chat: the
    existing membership row is returned instead of creating a duplicate.

    Args:
        chat_id: ID of the team chat.
        user_id: ID of the user to add.

    Returns:
        The membership row (new or pre-existing), or None when the chat
        does not exist or is not a team chat.
    """
    team_chat = Chat.query.get(chat_id)
    if not (team_chat and team_chat.is_team_chat):
        return None

    membership_key = {"chat_id": chat_id, "user_id": user_id}

    # Reuse the current membership if there is one.
    current = TeamChatMember.query.filter_by(**membership_key).first()
    if current:
        return current

    created = TeamChatMember(**membership_key)
    db.session.add(created)
    db.session.commit()
    return created
def get_team_members(self, chat_id: int) -> List[User]:
    """
    List the users that belong to a team chat.

    Args:
        chat_id: ID of the team chat.

    Returns:
        User records for every membership row of the chat.
    """
    # Each row is a one-element tuple holding a user ID.
    rows = db.session.query(TeamChatMember.user_id).filter_by(chat_id=chat_id).all()
    member_ids = [row[0] for row in rows]
    members_query = User.query.filter(User.id.in_(member_ids))
    return members_query.all()
def remove_team_member(self, chat_id: int, user_id: int) -> bool:
    """
    Drop a user's membership of a team chat.

    Args:
        chat_id: ID of the team chat.
        user_id: ID of the user to remove.

    Returns:
        True when a membership row was found and deleted, False when the
        user was not a member.
    """
    lookup = TeamChatMember.query.filter_by(chat_id=chat_id, user_id=user_id)
    membership = lookup.first()
    if membership:
        db.session.delete(membership)
        db.session.commit()
        return True
    return False
def delete_chat(self, chat_id: int) -> bool:
    """
    Delete a chat.

    NOTE(review): whether the chat's messages are removed as well depends
    on the cascade settings of the Chat model, which are not visible here.

    Args:
        chat_id: ID of the chat to delete.

    Returns:
        True when the chat was deleted; False when it does not exist or
        the deletion failed (the transaction is rolled back in that case).
    """
    target = Chat.query.get(chat_id)
    if not target:
        return False

    try:
        db.session.delete(target)
        db.session.commit()
    except Exception as e:
        # Report the failure, then undo the partial transaction.
        print(f"Error deleting chat {chat_id}: {str(e)}")
        db.session.rollback()
        return False
    else:
        return True
-105
View File
@@ -1,105 +0,0 @@
"""
Service for chatbot functionality without database dependency.
"""
from typing import List, Dict, Any, Optional
class ChatbotService:
    """In-memory chatbot service that keeps a message history per chat."""

    def __init__(self):
        """Start with no chats and the chat ID counter at zero."""
        # Maps chat ID -> {'user_id': ..., 'messages': [...]}
        self.chat_history = {}
        self.current_chat_id = 0

    def create_chat(self, user_id: str) -> int:
        """
        Open a new, empty chat session.

        Args:
            user_id: ID of the user who owns the chat.

        Returns:
            The ID assigned to the new chat (IDs count up from 1).
        """
        new_id = self.current_chat_id + 1
        self.current_chat_id = new_id
        self.chat_history[new_id] = {'user_id': user_id, 'messages': []}
        return new_id

    def add_message(self, chat_id: int, content: str, is_user: bool = True) -> Dict[str, Any]:
        """
        Append a message to a chat's history.

        Args:
            chat_id: ID of the chat.
            content: Message text.
            is_user: True for a user message, False for a bot message.

        Returns:
            The stored message dictionary.

        Raises:
            ValueError: If no chat with ``chat_id`` exists.
        """
        if chat_id not in self.chat_history:
            raise ValueError(f"Chat with ID {chat_id} not found")

        entry = dict(
            content=content,
            is_user=is_user,
            timestamp=self._get_timestamp(),
        )
        history = self.chat_history[chat_id]['messages']
        history.append(entry)
        return entry

    def get_chat_messages(self, chat_id: int) -> List[Dict[str, Any]]:
        """
        Return the message history of a chat.

        Args:
            chat_id: ID of the chat.

        Returns:
            The chat's list of message dictionaries, oldest first.

        Raises:
            ValueError: If no chat with ``chat_id`` exists.
        """
        if chat_id not in self.chat_history:
            raise ValueError(f"Chat with ID {chat_id} not found")
        conversation = self.chat_history[chat_id]
        return conversation['messages']

    def get_response(self, chat_id: int, message: str) -> str:
        """
        Record a user message and produce the bot's reply.

        Both the user message and the reply are appended to the history.

        Args:
            chat_id: ID of the chat.
            message: Text sent by the user.

        Returns:
            The bot reply (currently an echo of the user message).
        """
        self.add_message(chat_id, message, is_user=True)
        # Simple echo response for now
        reply = f"You said: {message}"
        self.add_message(chat_id, reply, is_user=False)
        return reply

    def _get_timestamp(self) -> str:
        """Return the current UTC time as an ISO 8601 string."""
        from datetime import datetime
        now = datetime.utcnow()
        return now.isoformat()


# Create a singleton instance
chatbot_service = ChatbotService()
-165
View File
@@ -1,165 +0,0 @@
"""
Service for document processing and embedding.
"""
import os
from typing import List, Dict, Any, Optional
import pinecone
from app.database.db import db
from app.models.document import Document, DocumentChunk
from app.config.config import Config
class DocumentService:
    """Service for document records and their Pinecone embeddings."""

    def __init__(self, config: Optional[Config] = None):
        """
        Initialize the document service and connect to Pinecone.

        Args:
            config: Configuration object; a fresh ``Config()`` is created
                when omitted.
        """
        self.config = config or Config()
        self._initialize_pinecone()

    def _initialize_pinecone(self):
        """Initialize the Pinecone client and make sure the index exists."""
        pinecone.init(
            api_key=self.config.PINECONE_API_KEY,
            environment=self.config.PINECONE_ENVIRONMENT
        )
        # Check if index exists, create if it doesn't
        if self.config.PINECONE_INDEX_NAME not in pinecone.list_indexes():
            pinecone.create_index(
                name=self.config.PINECONE_INDEX_NAME,
                dimension=768,  # Default dimension for sentence-transformers
                metric="cosine"
            )
        self.index = pinecone.Index(self.config.PINECONE_INDEX_NAME)

    def create_document(self, title: str, file_path: str, content_type: str,
                        description: Optional[str], user_id: int) -> Document:
        """
        Create a new document record in the 'pending' state.

        Args:
            title: Document title.
            file_path: Path to the document file.
            content_type: MIME type of the document.
            description: Optional description of the document.
            user_id: ID of the user who uploaded the document.

        Returns:
            Created document.
        """
        document = Document(
            title=title,
            file_path=file_path,
            content_type=content_type,
            description=description,
            uploaded_by=user_id,
            status='pending'
        )
        db.session.add(document)
        db.session.commit()
        return document

    def process_document(self, document_id: int) -> bool:
        """
        Process a document for embedding.

        The status moves 'processing' -> 'completed', or to 'error' when
        anything fails along the way.

        Args:
            document_id: ID of the document to process.

        Returns:
            True if processing was successful, False otherwise (including
            when the document does not exist).
        """
        document = Document.query.get(document_id)
        if not document:
            return False

        try:
            # Update status to processing
            document.status = 'processing'
            db.session.commit()

            # TODO: Implement document parsing and chunking
            # This will be implemented in the next step

            # Update status to completed
            document.status = 'completed'
            db.session.commit()
            return True
        except Exception as e:
            # A failed commit leaves the session unusable until it is rolled
            # back, so roll back before trying to record the error status.
            db.session.rollback()
            # Log the error
            print(f"Error processing document {document_id}: {str(e)}")
            try:
                # Update status to error
                document.status = 'error'
                db.session.commit()
            except Exception as status_error:
                # Recording the failure is best-effort; keep the session clean.
                db.session.rollback()
                print(f"Error recording failure for document {document_id}: {str(status_error)}")
            return False

    def get_document(self, document_id: int) -> Optional[Document]:
        """
        Get a document by ID.

        Args:
            document_id: ID of the document.

        Returns:
            Document if found, None otherwise.
        """
        return Document.query.get(document_id)

    def get_all_documents(self, user_id: Optional[int] = None) -> List[Document]:
        """
        Get all documents, optionally filtered by user, newest first.

        Args:
            user_id: Optional user ID to filter by.

        Returns:
            List of documents ordered by ``created_at`` descending.
        """
        query = Document.query
        # Compare against None so that a user ID of 0 still filters.
        if user_id is not None:
            query = query.filter_by(uploaded_by=user_id)
        return query.order_by(Document.created_at.desc()).all()

    def delete_document(self, document_id: int) -> bool:
        """
        Delete a document and the Pinecone vectors of its chunks.

        NOTE(review): the DocumentChunk rows themselves are only removed if
        the Document model cascades deletes; that is not visible here.

        Args:
            document_id: ID of the document to delete.

        Returns:
            True if deletion was successful, False otherwise.
        """
        document = Document.query.get(document_id)
        if not document:
            return False

        try:
            # Delete document chunks from Pinecone
            chunks = DocumentChunk.query.filter_by(document_id=document_id).all()
            embedding_ids = [chunk.embedding_id for chunk in chunks if chunk.embedding_id]
            if embedding_ids:
                self.index.delete(ids=embedding_ids)

            # Delete document from database
            db.session.delete(document)
            db.session.commit()
            return True
        except Exception as e:
            # Log the error
            print(f"Error deleting document {document_id}: {str(e)}")
            db.session.rollback()
            return False
-95
View File
@@ -1,95 +0,0 @@
"""
Service for model management and interaction.
"""
from typing import List, Dict, Any, Optional
from app.config.config import Config
class ModelService:
    """Service for model management and interaction."""

    # Available models
    AVAILABLE_MODELS = {
        'gpt-3.5-turbo': {
            'name': 'GPT-3.5 Turbo',
            'description': 'OpenAI GPT-3.5 Turbo model',
            'provider': 'openai',
            'max_tokens': 4096
        },
        'gpt-4': {
            'name': 'GPT-4',
            'description': 'OpenAI GPT-4 model',
            'provider': 'openai',
            'max_tokens': 8192
        },
        # Add more models as needed
    }

    def __init__(self, config: "Optional[Config]" = None):
        """
        Initialize the model service.

        Args:
            config: Configuration object; a fresh ``Config()`` is created
                when omitted. Its ``DEFAULT_MODEL`` becomes the default model.
        """
        self.config = config or Config()
        self.default_model = self.config.DEFAULT_MODEL

    def _describe(self, model_id: str) -> Dict[str, Any]:
        """Build the public info dictionary for a model known to the catalogue."""
        return {
            'id': model_id,
            'is_default': model_id == self.default_model,
            **self.AVAILABLE_MODELS[model_id]
        }

    def get_available_models(self) -> List[Dict[str, Any]]:
        """
        Get a list of available models.

        Returns:
            List of model information dictionaries, in catalogue order.
        """
        return [self._describe(model_id) for model_id in self.AVAILABLE_MODELS]

    def get_model_info(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get information about a specific model.

        Args:
            model_id: ID of the model.

        Returns:
            Model information dictionary if found, None otherwise.
        """
        if model_id not in self.AVAILABLE_MODELS:
            return None
        return self._describe(model_id)

    def generate_response(self, model_id: str, prompt: str,
                          context: Optional[List[Dict[str, str]]] = None) -> str:
        """
        Generate a response from the model.

        Args:
            model_id: ID of the model to use; unknown IDs fall back to the
                default model.
            prompt: User prompt (unused by the placeholder implementation).
            context: Optional conversation context (unused by the placeholder).

        Returns:
            Generated response.
        """
        # TODO: Implement actual model integration
        # This is a placeholder that will be implemented in the next steps
        if model_id not in self.AVAILABLE_MODELS:
            model_id = self.default_model

        # The configured default may itself be missing from the catalogue;
        # fall back to the raw model ID instead of raising KeyError.
        model_name = self.AVAILABLE_MODELS.get(model_id, {}).get('name', model_id)

        # Placeholder response
        return f"This is a placeholder response from {model_name}. The actual model integration will be implemented in the next steps."
View File
+22
View File
@@ -0,0 +1,22 @@
#!/bin/bash
# Script to run the RAG test in a Python virtual environment.
#
# Run it from the directory that contains test_rag.py, with the virtual
# environment already activated. The exit status is that of the test run.

# Abort on the first failing command, unset variable or failed pipeline stage.
set -euo pipefail

# Check if virtual environment is activated.
# ${VIRTUAL_ENV:-} keeps the check safe under `set -u` when it is unset.
if [[ -z "${VIRTUAL_ENV:-}" ]]; then
    echo "Virtual environment is not activated." >&2
    echo "Please activate your virtual environment first with:" >&2
    echo "source /path/to/your/venv/bin/activate" >&2
    exit 1
fi

# Fail early with a clear message when started from the wrong directory.
if [[ ! -f test_rag.py ]]; then
    echo "test_rag.py not found in $(pwd)." >&2
    exit 1
fi

# Make the script executable (best effort; it is run through python below).
chmod +x test_rag.py || echo "Warning: could not make test_rag.py executable." >&2

echo "Running RAG tests against remote server..."
python test_rag.py --remote --verbose

# If you want to test against a different server, uncomment and modify this line:
# python test_rag.py --api-url "http://your-server-url:port" --verbose

# If you want to test with a specific query, uncomment and modify this line:
# python test_rag.py --remote --query "What information do you have about project X?" --verbose
-144
View File
@@ -1,144 +0,0 @@
"""
Simple API for testing deployment.
"""
import os
import uuid
from datetime import datetime
from typing import List, Dict, Any, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
# Create FastAPI app
app = FastAPI(
    title="Simple AI Service API",
    description="Simple API for testing deployment",
    version="1.0.0"
)
# Add CORS middleware
# NOTE(review): wildcard origins, methods and headers together with
# allow_credentials=True is very permissive - confirm this is acceptable
# for the environments this test API is deployed to.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define API models
class MessageRequest(BaseModel):
    """
    Request model for sending a message.

    Only ``message`` and ``user_id`` are required; every model parameter
    below is optional and defaults to None.
    """
    message: str = Field(..., description="Message content")
    user_id: str = Field(..., description="User ID")
    # Model parameters
    temperature: Optional[float] = Field(None, description="Controls randomness")
    max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
    top_p: Optional[float] = Field(None, description="Nucleus sampling parameter")
    frequency_penalty: Optional[float] = Field(None, description="Penalizes repeated tokens")
    presence_penalty: Optional[float] = Field(None, description="Penalizes repeated topics")
    system_prompt: Optional[str] = Field(None, description="System prompt")
class Message(BaseModel):
    """
    Model for a single chat message, from either the user or the bot.

    ``user_id`` is optional and defaults to None; all other fields are
    required.
    """
    id: str = Field(..., description="Message ID")
    content: str = Field(..., description="Message content")
    user_id: Optional[str] = Field(None, description="User ID")
    is_user_message: bool = Field(..., description="Whether this is a user message")
    timestamp: str = Field(..., description="Message timestamp")
class ChatRequest(BaseModel):
    """
    Request model for creating a chat.

    Only ``user_id`` is required; ``title`` and ``model_id`` default to None.
    """
    user_id: str = Field(..., description="User ID")
    title: Optional[str] = Field(None, description="Chat title")
    model_id: Optional[str] = Field(None, description="Model ID")
class Chat(BaseModel):
    """
    Model for a chat, including its full message list.

    Timestamps are carried as strings; ``messages`` defaults to an empty list.
    """
    id: str = Field(..., description="Chat ID")
    title: str = Field(..., description="Chat title")
    user_id: str = Field(..., description="User ID")
    model_id: str = Field(..., description="Model ID")
    created_at: str = Field(..., description="Creation timestamp")
    updated_at: str = Field(..., description="Update timestamp")
    messages: List[Message] = Field(default=[], description="Chat messages")
# In-memory storage: maps chat ID -> chat dictionary. Being a plain
# module-level dict, its contents live only as long as the process.
chats: Dict[str, Dict[str, Any]] = {}
# API endpoints
@app.get("/health")
async def health_check():
    """Report that the service is up with a fixed status payload."""
    payload = {"status": "healthy"}
    return payload
@app.post("/chats", response_model=Chat)
async def create_chat(request: ChatRequest):
    """
    Create a new chat and store it in memory.

    The title defaults to "Chat N" (N = number of stored chats + 1) and the
    model to "gpt-3.5-turbo" when the request leaves them out.

    Returns:
        The stored chat dictionary.
    """
    chat_id = str(uuid.uuid4())
    # Read the clock once so a new chat has created_at == updated_at.
    now = datetime.utcnow().isoformat()
    chat = {
        "id": chat_id,
        "title": request.title or f"Chat {len(chats) + 1}",
        "user_id": request.user_id,
        "model_id": request.model_id or "gpt-3.5-turbo",
        "created_at": now,
        "updated_at": now,
        "messages": []
    }
    chats[chat_id] = chat
    return chat
@app.get("/chats/{chat_id}", response_model=Chat)
async def get_chat(chat_id: str):
    """
    Return the stored chat with the given ID.

    Raises:
        HTTPException: 404 when no chat with that ID is stored.
    """
    stored = chats.get(chat_id)
    if stored is None:
        raise HTTPException(status_code=404, detail="Chat not found")
    return stored
@app.post("/chats/{chat_id}/messages", response_model=Message)
async def send_message(chat_id: str, request: MessageRequest):
    """
    Send a message to a chat and return a canned bot reply.

    Both the user message and the reply are appended to the chat. The reply
    echoes the message and mentions which of temperature, max_tokens and
    system_prompt were supplied; the other parameters are accepted but ignored.

    Raises:
        HTTPException: 404 when no chat with that ID is stored.
    """
    if chat_id not in chats:
        raise HTTPException(status_code=404, detail="Chat not found")
    chat = chats[chat_id]

    # Add user message
    user_message = {
        "id": str(uuid.uuid4()),
        "content": request.message,
        "user_id": request.user_id,
        "is_user_message": True,
        "timestamp": datetime.utcnow().isoformat()
    }
    chat["messages"].append(user_message)

    # Generate bot response
    params_text = ""
    if request.temperature is not None:
        params_text += f" (temperature={request.temperature})"
    if request.max_tokens is not None:
        params_text += f" (max_tokens={request.max_tokens})"
    if request.system_prompt is not None:
        params_text += " (using custom system prompt)"

    # One clock read for the reply and the chat, so updated_at matches the
    # timestamp of the latest message exactly.
    replied_at = datetime.utcnow().isoformat()
    bot_message = {
        "id": str(uuid.uuid4()),
        "content": f"This is a test response to: '{request.message}'{params_text}",
        "user_id": None,
        "is_user_message": False,
        "timestamp": replied_at
    }
    chat["messages"].append(bot_message)
    chat["updated_at"] = replied_at
    return bot_message
# Serve the app with uvicorn when this file is executed directly.
if __name__ == "__main__":
    import uvicorn
    # host 0.0.0.0 listens on all network interfaces, port 5251.
    uvicorn.run(app, host="0.0.0.0", port=5251)
-69
View File
@@ -1,69 +0,0 @@
"""
Test script for sending a message with advanced parameters.
"""
import requests
import json
import uuid
# Create a new chat
def create_chat():
    """
    Create a chat on the local test API.

    Returns:
        The new chat's ID, or None when the request fails or the server
        answers with a status other than 200.
    """
    try:
        response = requests.post(
            "http://localhost:5251/chats",
            json={
                "user_id": "test_user",
                "title": "Test Chat",
                "model_id": "gpt-3.5-turbo",
                "is_team_chat": False
            },
            # Without a timeout, requests waits indefinitely on a stalled server.
            timeout=30
        )
    except requests.RequestException as exc:
        # Connection errors and timeouts are reported like any other failure.
        print(f"Error creating chat: {exc}")
        return None

    if response.status_code == 200:
        return response.json()["id"]
    else:
        print(f"Error creating chat: {response.status_code}")
        print(response.text)
        return None
# Send a message with advanced parameters
def send_message(chat_id):
    """
    Send a test message with advanced model parameters to a chat.

    Args:
        chat_id: ID of the chat to post the message to.

    Returns:
        The decoded JSON response, or None when the request fails or the
        server answers with a status other than 200.
    """
    try:
        response = requests.post(
            f"http://localhost:5251/chats/{chat_id}/messages",
            json={
                "message": "Tell me about artificial intelligence",
                "user_id": "test_user",
                "use_rag": False,
                "temperature": 0.7,
                "max_tokens": 500,
                "top_p": 0.9,
                "frequency_penalty": 0.5,
                "presence_penalty": 0.5,
                "system_prompt": "You are a helpful AI assistant that provides concise responses."
            },
            # Generation can be slow, but never wait forever.
            timeout=120
        )
    except requests.RequestException as exc:
        # Connection errors and timeouts are reported like any other failure.
        print(f"Error sending message: {exc}")
        return None

    if response.status_code == 200:
        return response.json()
    else:
        print(f"Error sending message: {response.status_code}")
        print(response.text)
        return None
# Main function
def main():
    """Create a chat, send one message to it and print the reply."""
    print("Creating a new chat...")
    chat_id = create_chat()
    if not chat_id:
        # Chat creation already reported its own error.
        return

    print(f"Chat created with ID: {chat_id}")
    print("\nSending a message with advanced parameters...")
    reply = send_message(chat_id)
    if not reply:
        return

    print("\nResponse received:")
    print(f"Message ID: {reply['id']}")
    print(f"Content: {reply['content']}")
if __name__ == "__main__":
main()
Executable
+309
View File
@@ -0,0 +1,309 @@
#!/usr/bin/env python
"""
Test script for RAG functionality in the AI service.
This script tests the document-based question answering capabilities
by making requests to the API endpoints.
"""
import os
import sys
import json
import argparse
import requests
from typing import Dict, Any, Optional
from pprint import pprint
# Default configuration
# Base URLs of the AI service API; --remote on the command line selects the
# remote one instead of the --api-url value.
DEFAULT_API_URL: str = "http://localhost:5252"  # Local development server
DEFAULT_REMOTE_API_URL: str = "http://157.157.221.29:5251"  # Remote server
# Model ID sent when the chat-with-RAG test creates its chat.
DEFAULT_MODEL: str = "llama3.1"
class RAGTester:
    """Test the RAG functionality of the AI service."""

    def __init__(self, api_url: str, verbose: bool = False):
        """
        Initialize the RAG tester.

        Args:
            api_url: URL of the AI service API.
            verbose: Whether to print verbose output.
        """
        self.api_url = api_url
        self.verbose = verbose
        self.session = requests.Session()
        # Print configuration
        print(f"Testing RAG functionality against API at: {self.api_url}")

    def _log(self, message: str):
        """Log a message if verbose mode is enabled."""
        if self.verbose:
            print(f"[DEBUG] {message}")

    @staticmethod
    def _preview(text: Any, limit: int) -> str:
        """Return at most ``limit`` characters of ``text`` followed by '...'.

        A missing (None) text yields just '...' instead of raising, so a
        successful call without a response body is not reported as an error.
        """
        shown = "" if text is None else str(text)
        return f"{shown[:limit]}..."

    def check_server_health(self) -> bool:
        """
        Check if the server is healthy.

        Returns:
            True if the server is healthy, False otherwise.
        """
        try:
            response = self.session.get(f"{self.api_url}/health", timeout=10)
            response.raise_for_status()
            result = response.json()
            if result.get("status") == "healthy":
                print("✅ Server is healthy")
                return True
            else:
                print("❌ Server health check failed")
                print(f"Response: {result}")
                return False
        except Exception as e:
            print(f"❌ Error checking server health: {str(e)}")
            return False

    def check_config(self) -> Dict[str, Any]:
        """
        Check the server configuration.

        Returns:
            Server configuration, or an empty dict when it cannot be fetched.
        """
        try:
            response = self.session.get(f"{self.api_url}/config", timeout=10)
            response.raise_for_status()
            config = response.json()
            print("✅ Server configuration:")
            print(f" - API Host: {config.get('api_host')}")
            print(f" - API Port: {config.get('api_port')}")
            print(f" - OpenWebUI URL: {config.get('openwebui_url')}")
            print(f" - Ollama API URL: {config.get('ollama_api_url')}")
            print(f" - Default Model: {config.get('default_model')}")
            print(f" - API Timeout: {config.get('api_timeout')} seconds")
            print(f" - Available Models: {', '.join(config.get('available_models', []))}")
            return config
        except Exception as e:
            print(f"❌ Error checking server configuration: {str(e)}")
            return {}

    def test_ollama_connection(self) -> bool:
        """
        Test the connection to Ollama.

        Returns:
            True if the connection is successful, False otherwise.
        """
        try:
            response = self.session.get(f"{self.api_url}/test-ollama", timeout=30)
            response.raise_for_status()
            result = response.json()
            if result.get("status") == "success":
                print("✅ Successfully connected to Ollama API")
                print(f" - Ollama URL: {result.get('ollama_url')}")
                print(f" - Available models: {len(result.get('models', {}).get('models', []))}")
                return True
            else:
                print("❌ Failed to connect to Ollama API")
                print(f" - Error: {result.get('message')}")
                print(f" - Ollama URL: {result.get('ollama_url')}")
                return False
        except Exception as e:
            print(f"❌ Error testing Ollama connection: {str(e)}")
            return False

    def test_chat_completion(self, model_id: Optional[str] = None) -> bool:
        """
        Test basic chat completion without RAG.

        Args:
            model_id: Optional model ID to use. NOTE(review): currently not
                sent to the server; the /test-chat request carries no
                parameters.

        Returns:
            True if the test is successful, False otherwise.
        """
        try:
            response = self.session.post(f"{self.api_url}/test-chat", timeout=60)
            response.raise_for_status()
            result = response.json()
            if result.get("status") == "success":
                print("✅ Successfully tested chat completion")
                print(f" - Model: {result.get('model')}")
                # First 100 chars
                print(f" - Response: {self._preview(result.get('response'), 100)}")
                return True
            else:
                print("❌ Failed to test chat completion")
                print(f" - Error: {result.get('message')}")
                return False
        except Exception as e:
            print(f"❌ Error testing chat completion: {str(e)}")
            return False

    def test_rag_completion(self, query: str = "What information do you have in your knowledge database?") -> bool:
        """
        Test RAG completion with a query.

        Args:
            query: Query to test with RAG.

        Returns:
            True if the test is successful, False otherwise.
        """
        try:
            response = self.session.post(
                f"{self.api_url}/test-rag",
                params={"query": query},
                timeout=120  # Longer timeout for RAG
            )
            response.raise_for_status()
            result = response.json()
            if result.get("status") == "success":
                print("✅ Successfully tested RAG completion")
                print(f" - Model: {result.get('model')}")
                print(f" - Query: {result.get('query')}")
                print(f" - OpenWebUI URL: {result.get('openwebui_url')}")
                # First 150 chars
                print(f" - Response: {self._preview(result.get('response'), 150)}")
                # Print full response in verbose mode
                if self.verbose:
                    print("\nFull response:")
                    print(result.get('response'))
                return True
            else:
                print("❌ Failed to test RAG completion")
                print(f" - Error: {result.get('message')}")
                print(f" - OpenWebUI URL: {result.get('openwebui_url')}")
                return False
        except Exception as e:
            print(f"❌ Error testing RAG completion: {str(e)}")
            return False

    def create_chat_and_test_rag(self, query: str, user_id: str = "test_user") -> bool:
        """
        Create a chat and test RAG with a message.

        Args:
            query: Query to test with RAG.
            user_id: User ID for the chat.

        Returns:
            True if the test is successful, False otherwise.
        """
        try:
            # Create a chat
            create_response = self.session.post(
                f"{self.api_url}/chats",
                json={
                    "user_id": user_id,
                    "title": "RAG Test Chat",
                    "model_id": DEFAULT_MODEL,
                    "is_team_chat": False
                },
                timeout=30
            )
            create_response.raise_for_status()
            chat = create_response.json()
            chat_id = chat.get("id")
            print(f"✅ Created chat with ID: {chat_id}")

            # Send a message with RAG enabled
            message_response = self.session.post(
                f"{self.api_url}/chats/{chat_id}/messages",
                json={
                    "message": query,
                    "user_id": user_id,
                    "use_rag": True,
                    "temperature": 0.7,
                    "max_tokens": 1000
                },
                timeout=120  # Longer timeout for RAG
            )
            message_response.raise_for_status()
            message = message_response.json()
            print("✅ Successfully sent message with RAG")
            print(f" - Message ID: {message.get('id')}")
            # First 150 chars
            print(f" - Response: {self._preview(message.get('content'), 150)}")
            # Print full response in verbose mode
            if self.verbose:
                print("\nFull response:")
                print(message.get('content'))
            return True
        except Exception as e:
            print(f"❌ Error testing chat with RAG: {str(e)}")
            return False

    def run_all_tests(self, query: str = "What information do you have in your knowledge database?") -> bool:
        """
        Run all tests.

        A failed health check aborts the run. Failures of the configuration,
        Ollama and basic chat checks are reported as warnings only.

        Args:
            query: Query to test with RAG.

        Returns:
            True when both RAG tests passed, False otherwise.
        """
        print("\n=== Running RAG Functionality Tests ===\n")

        # Check server health
        if not self.check_server_health():
            print("❌ Server health check failed. Aborting tests.")
            return False

        # Check configuration
        config = self.check_config()
        if not config:
            print("❌ Failed to get server configuration. Continuing with tests...")

        # Test Ollama connection
        if not self.test_ollama_connection():
            print("⚠️ Ollama connection test failed. Some tests may fail.")

        # Test basic chat completion
        if not self.test_chat_completion():
            print("⚠️ Basic chat completion test failed. RAG tests may also fail.")

        # Test RAG completion
        print("\n--- Testing RAG Completion ---\n")
        rag_ok = self.test_rag_completion(query)

        # Test chat with RAG
        print("\n--- Testing Chat with RAG ---\n")
        chat_rag_ok = self.create_chat_and_test_rag(query)

        print("\n=== Tests Completed ===\n")
        return rag_ok and chat_rag_ok


def main():
    """
    Parse the command line, run the tests and set the exit status.

    The process exits with status 1 when the test run fails, so calling
    scripts can detect the failure.
    """
    parser = argparse.ArgumentParser(description="Test RAG functionality in the AI service")
    parser.add_argument("--api-url", default=DEFAULT_API_URL, help=f"URL of the AI service API (default: {DEFAULT_API_URL})")
    parser.add_argument("--remote", action="store_true", help=f"Use the remote API URL ({DEFAULT_REMOTE_API_URL})")
    parser.add_argument("--query", default="What information do you have in your knowledge database?", help="Query to test with RAG")
    parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose output")
    args = parser.parse_args()

    # Use remote URL if specified (--remote takes precedence over --api-url)
    api_url = DEFAULT_REMOTE_API_URL if args.remote else args.api_url

    # Create and run the tester
    tester = RAGTester(api_url=api_url, verbose=args.verbose)
    passed = tester.run_all_tests(query=args.query)
    if not passed:
        sys.exit(1)
if __name__ == "__main__":
main()
+2
View File
@@ -0,0 +1,2 @@
requests>=2.28.0
pydantic>=2.0.0