Added Rag Featured
This commit is contained in:
@@ -0,0 +1,73 @@
|
||||
# Document-Based Question Answering with OpenWebUI
|
||||
|
||||
This document explains how to use the document-based question answering (RAG) functionality in the AI service.
|
||||
|
||||
## Overview
|
||||
|
||||
The AI service now supports Retrieval Augmented Generation (RAG) by leveraging OpenWebUI's built-in knowledge database. This allows users to:
|
||||
|
||||
1. Upload documents to OpenWebUI
|
||||
2. Ask questions about those documents
|
||||
3. Receive responses that incorporate information from the documents
|
||||
|
||||
## How It Works
|
||||
|
||||
When RAG is enabled:
|
||||
|
||||
1. The AI service forwards the request to OpenWebUI with `use_knowledge=True`
|
||||
2. OpenWebUI searches its knowledge database for relevant information
|
||||
3. The retrieved information is used to augment the model's response
|
||||
4. The response is returned to the user
|
||||
|
||||
## Using RAG in API Requests
|
||||
|
||||
To enable RAG in your API requests, set the `use_rag` parameter to `true`:
|
||||
|
||||
```json
|
||||
POST /chats/{chat_id}/messages
|
||||
{
|
||||
"message": "What information do you have about project X?",
|
||||
"user_id": "user123",
|
||||
"use_rag": true
|
||||
}
|
||||
```
|
||||
|
||||
## Testing RAG Functionality
|
||||
|
||||
You can test the RAG functionality using the `/test-rag` endpoint:
|
||||
|
||||
```
|
||||
POST /test-rag?query=What information do you have about project X?
|
||||
```
|
||||
|
||||
This will return a response that includes information from documents in OpenWebUI's knowledge database.
|
||||
|
||||
## Uploading Documents to OpenWebUI
|
||||
|
||||
To use RAG effectively, you need to upload documents to OpenWebUI:
|
||||
|
||||
1. Log in to OpenWebUI at your configured URL (default: http://104.225.217.215:8080)
|
||||
2. Navigate to the Knowledge section
|
||||
3. Upload your documents (PDF, TXT, DOCX, etc.)
|
||||
4. OpenWebUI will automatically process and index the documents
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
If RAG is not working as expected:
|
||||
|
||||
1. Ensure OpenWebUI is running and accessible
|
||||
2. Check that documents are properly uploaded and indexed in OpenWebUI
|
||||
3. Verify that the `use_rag` parameter is set to `true` in your requests
|
||||
4. Check the logs for any errors related to OpenWebUI API calls
|
||||
|
||||
If there are connection issues with OpenWebUI, the AI service will automatically fall back to using the direct Ollama API without RAG.
|
||||
|
||||
## Configuration
|
||||
|
||||
The following configuration settings affect RAG functionality:
|
||||
|
||||
- `OPENWEBUI_URL`: URL of your OpenWebUI instance
|
||||
- `OPENWEBUI_API_KEY`: API key for OpenWebUI (if required)
|
||||
- `API_TIMEOUT`: Timeout for API requests (in seconds)
|
||||
|
||||
These can be set in your environment variables or in the `.env` file.
|
||||
@@ -1,6 +1,11 @@
|
||||
"""
|
||||
FastAPI application for the AI service.
|
||||
This service acts as a backend for OpenWebUI, providing OpenWebUI-compatible API endpoints.
|
||||
|
||||
The service supports document-based question answering using OpenWebUI's knowledge database:
|
||||
- Set use_rag=True in API requests to enable Retrieval Augmented Generation
|
||||
- When enabled, the service will use OpenWebUI's knowledge database to find relevant information
|
||||
- Documents uploaded to OpenWebUI will be used to augment the model's responses
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
@@ -203,6 +208,43 @@ async def test_chat_completion():
|
||||
"ollama_url": config.OLLAMA_API_URL
|
||||
}
|
||||
|
||||
@app.post("/test-rag")
|
||||
async def test_rag_completion(query: str = "What information do you have in your knowledge database?"):
|
||||
"""
|
||||
Test the RAG (Retrieval Augmented Generation) functionality with a query.
|
||||
|
||||
This endpoint tests the integration with OpenWebUI's knowledge database.
|
||||
|
||||
Args:
|
||||
query: The question to ask about documents in the knowledge database.
|
||||
|
||||
Returns:
|
||||
Model response using RAG.
|
||||
"""
|
||||
try:
|
||||
# Use the model service directly with RAG enabled
|
||||
response = model_service.generate_response(
|
||||
model_id=config.DEFAULT_MODEL,
|
||||
prompt=query,
|
||||
context=[],
|
||||
use_rag=True # Enable RAG
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"model": config.DEFAULT_MODEL,
|
||||
"query": query,
|
||||
"use_rag": True,
|
||||
"response": response,
|
||||
"openwebui_url": config.OPENWEBUI_URL
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Failed to get RAG completion: {str(e)}",
|
||||
"openwebui_url": config.OPENWEBUI_URL
|
||||
}
|
||||
|
||||
@app.post("/test-ollama-direct")
|
||||
async def test_ollama_direct():
|
||||
"""
|
||||
|
||||
@@ -198,5 +198,32 @@
|
||||
}
|
||||
],
|
||||
"team_members": []
|
||||
},
|
||||
"aefe133e-e85e-422f-9ca9-6441f24db74f": {
|
||||
"id": "aefe133e-e85e-422f-9ca9-6441f24db74f",
|
||||
"title": "RAG Test Chat",
|
||||
"user_id": "test_user",
|
||||
"model_id": "llama3.1",
|
||||
"is_team_chat": false,
|
||||
"created_at": "2025-05-16T14:41:08.281043",
|
||||
"updated_at": "2025-05-16T14:41:09.142922",
|
||||
"messages": [
|
||||
{
|
||||
"id": "35384001-0c6d-4ac6-a247-32fb53bb664c",
|
||||
"content": "What information do you have in your knowledge database?",
|
||||
"user_id": "test_user",
|
||||
"is_user_message": true,
|
||||
"timestamp": "2025-05-16T14:41:08.341940"
|
||||
},
|
||||
{
|
||||
"id": "4b82555f-76e2-43c0-a1aa-c91257536133",
|
||||
"content": "Connection error to Ollama API: HTTPConnectionPool(host='104.225.217.215', port=11434): Max retries exceeded with url: /api/chat (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x104bcc750>: Failed to establish a new connection: [Errno 61] Connection refused')). Please check if Ollama is running at http://104.225.217.215:11434.",
|
||||
"user_id": null,
|
||||
"is_user_message": false,
|
||||
"timestamp": "2025-05-16T14:41:09.142892"
|
||||
}
|
||||
],
|
||||
"team_members": [],
|
||||
"openwebui_channel_id": null
|
||||
}
|
||||
}
|
||||
@@ -1,261 +0,0 @@
|
||||
"""
|
||||
Service for document processing and chunking.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
import requests
|
||||
import base64
|
||||
from typing import List, Dict, Any, Optional
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
|
||||
from ai_service.config import config
|
||||
|
||||
class DocumentService:
|
||||
"""Service for document processing and chunking."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the document service."""
|
||||
self.chunk_size = config.CHUNK_SIZE
|
||||
self.chunk_overlap = config.CHUNK_OVERLAP
|
||||
self.text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=self.chunk_size,
|
||||
chunk_overlap=self.chunk_overlap,
|
||||
length_function=len
|
||||
)
|
||||
|
||||
# OpenWebUI configuration
|
||||
self.openwebui_url = config.OPENWEBUI_URL
|
||||
self.openwebui_api_key = config.OPENWEBUI_API_KEY
|
||||
|
||||
# Ensure data directory exists
|
||||
os.makedirs('ai_service/data', exist_ok=True)
|
||||
|
||||
# For now, we'll store document metadata in a simple JSON file
|
||||
self.metadata_file = 'ai_service/data/document_metadata.json'
|
||||
self._load_metadata()
|
||||
|
||||
def _load_metadata(self):
|
||||
"""Load document metadata from file."""
|
||||
if os.path.exists(self.metadata_file):
|
||||
try:
|
||||
with open(self.metadata_file, 'r') as f:
|
||||
self.documents = json.load(f)
|
||||
except Exception as e:
|
||||
print(f"Error loading document metadata: {str(e)}")
|
||||
self.documents = {}
|
||||
else:
|
||||
self.documents = {}
|
||||
|
||||
def _save_metadata(self):
|
||||
"""Save document metadata to file."""
|
||||
try:
|
||||
with open(self.metadata_file, 'w') as f:
|
||||
json.dump(self.documents, f, indent=2)
|
||||
except Exception as e:
|
||||
print(f"Error saving document metadata: {str(e)}")
|
||||
|
||||
def process_document(self, content: str, title: str,
|
||||
description: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None) -> str:
|
||||
"""
|
||||
Process a document for embedding.
|
||||
|
||||
Args:
|
||||
content: Document content.
|
||||
title: Document title.
|
||||
description: Optional document description.
|
||||
metadata: Optional additional metadata.
|
||||
|
||||
Returns:
|
||||
Document ID.
|
||||
"""
|
||||
# Generate a unique ID for the document
|
||||
doc_id = str(uuid.uuid4())
|
||||
|
||||
# Upload the document to OpenWebUI for RAG processing
|
||||
try:
|
||||
# Prepare headers
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.openwebui_api_key:
|
||||
headers["Authorization"] = f"Bearer {self.openwebui_api_key}"
|
||||
|
||||
# Prepare the document data
|
||||
document_data = {
|
||||
"filename": f"{title}.txt",
|
||||
"content": base64.b64encode(content.encode('utf-8')).decode('utf-8'),
|
||||
"description": description or title
|
||||
}
|
||||
|
||||
# Upload to OpenWebUI
|
||||
response = requests.post(
|
||||
f"{self.openwebui_url}/api/knowledge/upload",
|
||||
headers=headers,
|
||||
json=document_data,
|
||||
timeout=60
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
# Get the OpenWebUI document ID
|
||||
openwebui_doc_id = result.get('id', '')
|
||||
|
||||
# Store document metadata
|
||||
self.documents[doc_id] = {
|
||||
'id': doc_id,
|
||||
'title': title,
|
||||
'description': description or '',
|
||||
'openwebui_id': openwebui_doc_id,
|
||||
'metadata': metadata or {}
|
||||
}
|
||||
|
||||
# Save metadata to file
|
||||
self._save_metadata()
|
||||
|
||||
return doc_id
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error uploading document to OpenWebUI: {str(e)}")
|
||||
|
||||
# Fall back to local processing if OpenWebUI upload fails
|
||||
print("Falling back to local document processing")
|
||||
|
||||
# Split the document into chunks for local reference
|
||||
chunks = self.text_splitter.split_text(content)
|
||||
|
||||
# Store document metadata
|
||||
self.documents[doc_id] = {
|
||||
'id': doc_id,
|
||||
'title': title,
|
||||
'description': description or '',
|
||||
'chunk_count': len(chunks),
|
||||
'openwebui_upload_failed': True,
|
||||
'metadata': metadata or {}
|
||||
}
|
||||
|
||||
# Save metadata to file
|
||||
self._save_metadata()
|
||||
|
||||
return doc_id
|
||||
|
||||
def get_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get document metadata.
|
||||
|
||||
Args:
|
||||
doc_id: Document ID.
|
||||
|
||||
Returns:
|
||||
Document metadata if found, None otherwise.
|
||||
"""
|
||||
return self.documents.get(doc_id)
|
||||
|
||||
def get_all_documents(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all document metadata.
|
||||
|
||||
Returns:
|
||||
List of document metadata.
|
||||
"""
|
||||
# Get documents from local storage
|
||||
local_documents = list(self.documents.values())
|
||||
|
||||
# Try to get documents from OpenWebUI as well
|
||||
try:
|
||||
# Prepare headers
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.openwebui_api_key:
|
||||
headers["Authorization"] = f"Bearer {self.openwebui_api_key}"
|
||||
|
||||
# Get documents from OpenWebUI
|
||||
response = requests.get(
|
||||
f"{self.openwebui_url}/api/knowledge",
|
||||
headers=headers,
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
openwebui_docs = response.json()
|
||||
|
||||
# Update local documents with OpenWebUI information
|
||||
for doc in local_documents:
|
||||
if 'openwebui_id' in doc:
|
||||
for openwebui_doc in openwebui_docs:
|
||||
if openwebui_doc.get('id') == doc['openwebui_id']:
|
||||
doc['openwebui_status'] = 'active'
|
||||
doc['openwebui_info'] = openwebui_doc
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting documents from OpenWebUI: {str(e)}")
|
||||
|
||||
return local_documents
|
||||
|
||||
def delete_document(self, doc_id: str) -> bool:
|
||||
"""
|
||||
Delete a document and its chunks.
|
||||
|
||||
Args:
|
||||
doc_id: Document ID.
|
||||
|
||||
Returns:
|
||||
True if deletion was successful, False otherwise.
|
||||
"""
|
||||
if doc_id not in self.documents:
|
||||
return False
|
||||
|
||||
# Check if document was uploaded to OpenWebUI
|
||||
doc = self.documents[doc_id]
|
||||
openwebui_id = doc.get('openwebui_id')
|
||||
|
||||
if openwebui_id:
|
||||
try:
|
||||
# Prepare headers
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.openwebui_api_key:
|
||||
headers["Authorization"] = f"Bearer {self.openwebui_api_key}"
|
||||
|
||||
# Delete from OpenWebUI
|
||||
response = requests.delete(
|
||||
f"{self.openwebui_url}/api/knowledge/{openwebui_id}",
|
||||
headers=headers,
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
print(f"Warning: Failed to delete document from OpenWebUI: {response.text}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error deleting document from OpenWebUI: {str(e)}")
|
||||
|
||||
# Delete document metadata
|
||||
del self.documents[doc_id]
|
||||
|
||||
# Save metadata to file
|
||||
self._save_metadata()
|
||||
|
||||
return True
|
||||
|
||||
def search_documents(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search for documents similar to a query.
|
||||
|
||||
Args:
|
||||
query: Search query.
|
||||
top_k: Number of results to return.
|
||||
|
||||
Returns:
|
||||
List of similar document chunks with their metadata.
|
||||
"""
|
||||
# Note: We don't need to implement this method anymore since
|
||||
# RAG is handled directly by OpenWebUI when use_rag=True in the model service
|
||||
|
||||
# Return empty results - this is just a placeholder
|
||||
# The actual RAG functionality is in the model_service.generate_response method
|
||||
return []
|
||||
|
||||
|
||||
# Create a singleton instance
|
||||
document_service = DocumentService()
|
||||
@@ -135,9 +135,6 @@ class ModelService:
|
||||
model_id = self.default_model
|
||||
print(f" - Model not found, using default: {model_id}")
|
||||
|
||||
# Ensure we're using a valid model
|
||||
# (model_id is already validated above)
|
||||
|
||||
# Prepare the messages for the API call
|
||||
messages = []
|
||||
|
||||
@@ -178,6 +175,10 @@ class ModelService:
|
||||
openwebui_request['max_tokens'] = params['max_tokens']
|
||||
if 'top_p' in params:
|
||||
openwebui_request['top_p'] = params['top_p']
|
||||
if 'top_k' in params:
|
||||
openwebui_request['top_k'] = params['top_k']
|
||||
if 'repeat_penalty' in params:
|
||||
openwebui_request['repeat_penalty'] = params['repeat_penalty']
|
||||
|
||||
# Make the API call to OpenWebUI
|
||||
headers = {"Content-Type": "application/json"}
|
||||
@@ -201,10 +202,15 @@ class ModelService:
|
||||
result = response.json()
|
||||
|
||||
# Extract the response content
|
||||
if 'message' in result:
|
||||
if 'choices' in result and len(result['choices']) > 0 and 'message' in result['choices'][0]:
|
||||
# OpenAI-compatible format
|
||||
return result['choices'][0]['message']['content']
|
||||
elif 'message' in result and 'content' in result['message']:
|
||||
# OpenWebUI format
|
||||
return result['message']['content']
|
||||
else:
|
||||
return "Error: Unexpected response format from OpenWebUI"
|
||||
print(f"WARNING: Unexpected response format from OpenWebUI: {json.dumps(result, indent=2)}")
|
||||
return "Error: Unexpected response format from OpenWebUI. Falling back to direct model call."
|
||||
|
||||
except requests.exceptions.Timeout as e:
|
||||
error_msg = f"Timeout error connecting to OpenWebUI API: {str(e)}. The request exceeded the {self.api_timeout} second timeout."
|
||||
@@ -216,6 +222,11 @@ class ModelService:
|
||||
print(f"ERROR: {error_msg}")
|
||||
print("Falling back to direct Ollama call without RAG")
|
||||
# Continue to the Ollama API call below
|
||||
except requests.exceptions.HTTPError as e:
|
||||
error_msg = f"HTTP error from OpenWebUI API: {str(e)}."
|
||||
print(f"ERROR: {error_msg}")
|
||||
print("Falling back to direct Ollama call without RAG")
|
||||
# Continue to the Ollama API call below
|
||||
except Exception as e:
|
||||
error_msg = f"Error calling OpenWebUI API: {str(e)}"
|
||||
print(f"ERROR: {error_msg}")
|
||||
@@ -247,6 +258,8 @@ class ModelService:
|
||||
request_json['top_k'] = params['top_k']
|
||||
if 'max_tokens' in params:
|
||||
request_json['max_tokens'] = params['max_tokens']
|
||||
if 'repeat_penalty' in params:
|
||||
request_json['repeat_penalty'] = params['repeat_penalty']
|
||||
|
||||
# Make the API call to Ollama
|
||||
try:
|
||||
@@ -272,6 +285,7 @@ class ModelService:
|
||||
if 'message' in result and 'content' in result['message']:
|
||||
return result['message']['content']
|
||||
else:
|
||||
print(f"WARNING: Unexpected response format from Ollama: {json.dumps(result, indent=2)}")
|
||||
return "Error: Unexpected response format from Ollama"
|
||||
|
||||
except requests.exceptions.Timeout as e:
|
||||
|
||||
@@ -71,10 +71,16 @@ async def chat_completions(request: Request):
|
||||
temperature = body.get("temperature")
|
||||
max_tokens = body.get("max_tokens")
|
||||
top_p = body.get("top_p")
|
||||
top_k = body.get("top_k")
|
||||
frequency_penalty = body.get("frequency_penalty")
|
||||
presence_penalty = body.get("presence_penalty")
|
||||
repeat_penalty = body.get("repeat_penalty")
|
||||
stop = body.get("stop")
|
||||
|
||||
# Check if RAG should be used
|
||||
use_knowledge = body.get("use_knowledge", False)
|
||||
use_rag = use_knowledge # Map OpenWebUI's use_knowledge to our use_rag parameter
|
||||
|
||||
# Create a unique chat ID
|
||||
chat_id = str(uuid.uuid4())
|
||||
|
||||
@@ -99,11 +105,14 @@ async def chat_completions(request: Request):
|
||||
chat_id=chat_id,
|
||||
message=user_message,
|
||||
user_id=user_id,
|
||||
use_rag=use_rag,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
repeat_penalty=repeat_penalty,
|
||||
stop_sequences=stop if isinstance(stop, list) else [stop] if stop else None
|
||||
)
|
||||
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
"""
|
||||
Main application package for the chatbot application.
|
||||
"""
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from app.config.config import Config
|
||||
|
||||
def create_app(config_class=Config):
|
||||
"""
|
||||
Create and configure the Flask application.
|
||||
|
||||
Args:
|
||||
config_class: Configuration class to use.
|
||||
|
||||
Returns:
|
||||
Flask application instance.
|
||||
"""
|
||||
# Initialize Flask app
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config.from_object(config_class)
|
||||
|
||||
# Register Flask routes
|
||||
from app.api import routes as flask_routes
|
||||
flask_app.register_blueprint(flask_routes.bp)
|
||||
|
||||
# For now, we'll use only Flask routes and disable FastAPI integration
|
||||
# until we resolve the integration issues
|
||||
|
||||
# Initialize database
|
||||
from app.database import db
|
||||
db.init_app(flask_app)
|
||||
|
||||
return flask_app
|
||||
-110
@@ -1,110 +0,0 @@
|
||||
"""
|
||||
FastAPI routes for the application.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from app.services.chatbot_service import chatbot_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
class MessageRequest(BaseModel):
|
||||
"""Request model for sending a message."""
|
||||
message: str
|
||||
user_id: str = "default_user"
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
"""Response model for a message."""
|
||||
content: str
|
||||
is_user: bool
|
||||
timestamp: str
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""Response model for a chat."""
|
||||
chat_id: int
|
||||
messages: List[MessageResponse]
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""
|
||||
Health check endpoint.
|
||||
|
||||
Returns:
|
||||
JSON response with health status.
|
||||
"""
|
||||
return {"status": "healthy"}
|
||||
|
||||
@router.post("/chat", response_model=ChatResponse)
|
||||
async def create_chat(user_id: str = "default_user"):
|
||||
"""
|
||||
Create a new chat.
|
||||
|
||||
Args:
|
||||
user_id: ID of the user creating the chat.
|
||||
|
||||
Returns:
|
||||
Created chat.
|
||||
"""
|
||||
chat_id = chatbot_service.create_chat(user_id)
|
||||
|
||||
return {
|
||||
"chat_id": chat_id,
|
||||
"messages": []
|
||||
}
|
||||
|
||||
@router.post("/chat/{chat_id}/message", response_model=MessageResponse)
|
||||
async def send_message(chat_id: int, request: MessageRequest):
|
||||
"""
|
||||
Send a message to the chatbot.
|
||||
|
||||
Args:
|
||||
chat_id: ID of the chat.
|
||||
request: Message request.
|
||||
|
||||
Returns:
|
||||
Bot response.
|
||||
"""
|
||||
try:
|
||||
response = chatbot_service.get_response(chat_id, request.message)
|
||||
|
||||
# Get the last message (bot response)
|
||||
messages = chatbot_service.get_chat_messages(chat_id)
|
||||
last_message = messages[-1]
|
||||
|
||||
return last_message
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@router.get("/chat/{chat_id}", response_model=ChatResponse)
|
||||
async def get_chat(chat_id: int):
|
||||
"""
|
||||
Get a chat by ID.
|
||||
|
||||
Args:
|
||||
chat_id: ID of the chat.
|
||||
|
||||
Returns:
|
||||
Chat with messages.
|
||||
"""
|
||||
try:
|
||||
messages = chatbot_service.get_chat_messages(chat_id)
|
||||
|
||||
return {
|
||||
"chat_id": chat_id,
|
||||
"messages": messages
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
def init_app(app):
|
||||
"""
|
||||
Initialize FastAPI application with routes.
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance.
|
||||
"""
|
||||
app.include_router(router, prefix="/api")
|
||||
@@ -1,100 +0,0 @@
|
||||
"""
|
||||
Flask routes for the application.
|
||||
"""
|
||||
|
||||
from flask import Blueprint, jsonify, request, abort
|
||||
|
||||
from app.services.chatbot_service import chatbot_service
|
||||
|
||||
bp = Blueprint('main', __name__)
|
||||
|
||||
@bp.route('/')
|
||||
def index():
|
||||
"""
|
||||
Root endpoint.
|
||||
|
||||
Returns:
|
||||
JSON response with application information.
|
||||
"""
|
||||
return jsonify({
|
||||
'name': 'Chatbot Application',
|
||||
'version': '1.0.0',
|
||||
'status': 'running'
|
||||
})
|
||||
|
||||
@bp.route('/api/health')
|
||||
def health_check():
|
||||
"""
|
||||
Health check endpoint.
|
||||
|
||||
Returns:
|
||||
JSON response with health status.
|
||||
"""
|
||||
return jsonify({
|
||||
'status': 'healthy'
|
||||
})
|
||||
|
||||
@bp.route('/api/chat', methods=['POST'])
|
||||
def create_chat():
|
||||
"""
|
||||
Create a new chat.
|
||||
|
||||
Returns:
|
||||
JSON response with chat ID.
|
||||
"""
|
||||
user_id = request.json.get('user_id', 'default_user')
|
||||
chat_id = chatbot_service.create_chat(user_id)
|
||||
|
||||
return jsonify({
|
||||
'chat_id': chat_id,
|
||||
'messages': []
|
||||
})
|
||||
|
||||
@bp.route('/api/chat/<int:chat_id>/message', methods=['POST'])
|
||||
def send_message(chat_id):
|
||||
"""
|
||||
Send a message to the chatbot.
|
||||
|
||||
Args:
|
||||
chat_id: ID of the chat.
|
||||
|
||||
Returns:
|
||||
JSON response with bot response.
|
||||
"""
|
||||
if not request.json or 'message' not in request.json:
|
||||
abort(400, description="Message is required")
|
||||
|
||||
try:
|
||||
message = request.json['message']
|
||||
response = chatbot_service.get_response(chat_id, message)
|
||||
|
||||
# Get the last message (bot response)
|
||||
messages = chatbot_service.get_chat_messages(chat_id)
|
||||
last_message = messages[-1]
|
||||
|
||||
return jsonify(last_message)
|
||||
|
||||
except ValueError as e:
|
||||
abort(404, description=str(e))
|
||||
|
||||
@bp.route('/api/chat/<int:chat_id>', methods=['GET'])
|
||||
def get_chat(chat_id):
|
||||
"""
|
||||
Get a chat by ID.
|
||||
|
||||
Args:
|
||||
chat_id: ID of the chat.
|
||||
|
||||
Returns:
|
||||
JSON response with chat messages.
|
||||
"""
|
||||
try:
|
||||
messages = chatbot_service.get_chat_messages(chat_id)
|
||||
|
||||
return jsonify({
|
||||
'chat_id': chat_id,
|
||||
'messages': messages
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
abort(404, description=str(e))
|
||||
@@ -1,79 +0,0 @@
|
||||
"""
|
||||
Configuration settings for the application.
|
||||
"""
|
||||
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
class Config:
|
||||
"""Base configuration."""
|
||||
|
||||
# Flask configuration
|
||||
SECRET_KEY = os.environ.get('SECRET_KEY', 'dev-key-for-development-only')
|
||||
DEBUG = False
|
||||
TESTING = False
|
||||
|
||||
# Database configuration
|
||||
SQLALCHEMY_DATABASE_URI = os.environ.get(
|
||||
'DATABASE_URL',
|
||||
'sqlite:///chatbot.db'
|
||||
)
|
||||
SQLALCHEMY_TRACK_MODIFICATIONS = False
|
||||
INITIALIZE_DATABASE = os.environ.get('INITIALIZE_DATABASE', 'False').lower() == 'true'
|
||||
|
||||
# Pinecone configuration
|
||||
PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY', '')
|
||||
PINECONE_ENVIRONMENT = os.environ.get('PINECONE_ENVIRONMENT', '')
|
||||
PINECONE_INDEX_NAME = os.environ.get('PINECONE_INDEX_NAME', 'chatbot-index')
|
||||
|
||||
# Model configuration
|
||||
DEFAULT_MODEL = os.environ.get('DEFAULT_MODEL', 'gpt-3.5-turbo')
|
||||
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY', '')
|
||||
|
||||
|
||||
class DevelopmentConfig(Config):
|
||||
"""Development configuration."""
|
||||
|
||||
DEBUG = True
|
||||
|
||||
|
||||
class TestingConfig(Config):
|
||||
"""Testing configuration."""
|
||||
|
||||
TESTING = True
|
||||
SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:'
|
||||
|
||||
|
||||
class ProductionConfig(Config):
|
||||
"""Production configuration."""
|
||||
|
||||
# Ensure all required environment variables are set in production
|
||||
@classmethod
|
||||
def init_app(cls, app):
|
||||
"""Initialize production application."""
|
||||
# Check for required environment variables
|
||||
required_vars = [
|
||||
'SECRET_KEY',
|
||||
'DATABASE_URL',
|
||||
'PINECONE_API_KEY',
|
||||
'PINECONE_ENVIRONMENT',
|
||||
'OPENAI_API_KEY'
|
||||
]
|
||||
|
||||
missing_vars = [var for var in required_vars if not os.environ.get(var)]
|
||||
if missing_vars:
|
||||
raise RuntimeError(
|
||||
f"Missing required environment variables: {', '.join(missing_vars)}"
|
||||
)
|
||||
|
||||
|
||||
# Configuration dictionary
|
||||
config = {
|
||||
'development': DevelopmentConfig,
|
||||
'testing': TestingConfig,
|
||||
'production': ProductionConfig,
|
||||
'default': DevelopmentConfig
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
"""
|
||||
Database module for the application.
|
||||
"""
|
||||
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from sqlalchemy import MetaData
|
||||
|
||||
# Define naming convention for constraints
|
||||
convention = {
|
||||
"ix": 'ix_%(column_0_label)s',
|
||||
"uq": "uq_%(table_name)s_%(column_0_name)s",
|
||||
"ck": "ck_%(table_name)s_%(constraint_name)s",
|
||||
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
|
||||
"pk": "pk_%(table_name)s"
|
||||
}
|
||||
|
||||
# Create SQLAlchemy instance with naming convention
|
||||
db = SQLAlchemy(metadata=MetaData(naming_convention=convention))
|
||||
|
||||
def init_app(app):
|
||||
"""
|
||||
Initialize the database with the Flask application.
|
||||
|
||||
Args:
|
||||
app: Flask application instance.
|
||||
"""
|
||||
db.init_app(app)
|
||||
|
||||
# Only initialize database if configured to do so
|
||||
if app.config.get('INITIALIZE_DATABASE', False):
|
||||
# Import models to ensure they are registered with SQLAlchemy
|
||||
from app.models import user, chat, document
|
||||
|
||||
# Create tables if they don't exist
|
||||
with app.app_context():
|
||||
db.create_all()
|
||||
@@ -1,67 +0,0 @@
|
||||
"""
|
||||
Chat models for the application.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from app.database.db import db
|
||||
|
||||
class Chat(db.Model):
|
||||
"""Chat model representing a chat session."""
|
||||
|
||||
__tablename__ = 'chats'
|
||||
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
title = db.Column(db.String(100), nullable=True)
|
||||
is_team_chat = db.Column(db.Boolean, default=False)
|
||||
model_name = db.Column(db.String(50), nullable=False)
|
||||
created_at = db.Column(db.DateTime, default=datetime.utcnow)
|
||||
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Foreign keys
|
||||
user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
|
||||
|
||||
# Relationships
|
||||
messages = db.relationship('Message', backref='chat', lazy='dynamic', cascade='all, delete-orphan')
|
||||
team_members = db.relationship('TeamChatMember', backref='chat', lazy='dynamic', cascade='all, delete-orphan')
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Chat {self.id}: {self.title or "Untitled"}>'
|
||||
|
||||
|
||||
class Message(db.Model):
|
||||
"""Message model representing a single message in a chat."""
|
||||
|
||||
__tablename__ = 'messages'
|
||||
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
content = db.Column(db.Text, nullable=False)
|
||||
is_user_message = db.Column(db.Boolean, default=True)
|
||||
created_at = db.Column(db.DateTime, default=datetime.utcnow)
|
||||
|
||||
# Foreign keys
|
||||
chat_id = db.Column(db.Integer, db.ForeignKey('chats.id'), nullable=False)
|
||||
user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=True)
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Message {self.id}: {self.content[:20]}...>'
|
||||
|
||||
|
||||
class TeamChatMember(db.Model):
|
||||
"""Model representing a member of a team chat."""
|
||||
|
||||
__tablename__ = 'team_chat_members'
|
||||
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
joined_at = db.Column(db.DateTime, default=datetime.utcnow)
|
||||
|
||||
# Foreign keys
|
||||
chat_id = db.Column(db.Integer, db.ForeignKey('chats.id'), nullable=False)
|
||||
user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
|
||||
|
||||
# Ensure a user can only be added to a team chat once
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint('chat_id', 'user_id', name='uq_team_chat_member'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f'<TeamChatMember chat_id={self.chat_id}, user_id={self.user_id}>'
|
||||
@@ -1,59 +0,0 @@
|
||||
"""
|
||||
Document models for the application.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
import json
|
||||
from app.database.db import db
|
||||
|
||||
class Document(db.Model):
|
||||
"""Document model representing a document in the library."""
|
||||
|
||||
__tablename__ = 'documents'
|
||||
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
title = db.Column(db.String(255), nullable=False)
|
||||
description = db.Column(db.Text, nullable=True)
|
||||
file_path = db.Column(db.String(255), nullable=True)
|
||||
content_type = db.Column(db.String(50), nullable=False)
|
||||
status = db.Column(db.String(20), default='pending') # pending, processing, completed, error
|
||||
created_at = db.Column(db.DateTime, default=datetime.utcnow)
|
||||
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Foreign keys
|
||||
uploaded_by = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
|
||||
|
||||
# Relationships
|
||||
chunks = db.relationship('DocumentChunk', backref='document', lazy='dynamic', cascade='all, delete-orphan')
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Document {self.id}: {self.title}>'
|
||||
|
||||
|
||||
class DocumentChunk(db.Model):
|
||||
"""Model representing a chunk of a document for embedding."""
|
||||
|
||||
__tablename__ = 'document_chunks'
|
||||
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
content = db.Column(db.Text, nullable=False)
|
||||
chunk_index = db.Column(db.Integer, nullable=False)
|
||||
embedding_id = db.Column(db.String(100), nullable=True) # ID in Pinecone
|
||||
meta_data = db.Column(db.Text, nullable=True) # JSON string of metadata
|
||||
created_at = db.Column(db.DateTime, default=datetime.utcnow)
|
||||
|
||||
# Foreign keys
|
||||
document_id = db.Column(db.Integer, db.ForeignKey('documents.id'), nullable=False)
|
||||
|
||||
def set_metadata(self, metadata_dict):
|
||||
"""Set metadata as JSON string."""
|
||||
self.meta_data = json.dumps(metadata_dict)
|
||||
|
||||
def get_metadata(self):
|
||||
"""Get metadata as dictionary."""
|
||||
if self.meta_data:
|
||||
return json.loads(self.meta_data)
|
||||
return {}
|
||||
|
||||
def __repr__(self):
|
||||
return f'<DocumentChunk {self.id}: doc_id={self.document_id}, index={self.chunk_index}>'
|
||||
@@ -1,24 +0,0 @@
|
||||
"""
|
||||
User model for the application.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from app.database.db import db
|
||||
|
||||
class User(db.Model):
|
||||
"""User model representing application users."""
|
||||
|
||||
__tablename__ = 'users'
|
||||
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
username = db.Column(db.String(64), unique=True, nullable=False, index=True)
|
||||
email = db.Column(db.String(120), unique=True, nullable=False, index=True)
|
||||
password_hash = db.Column(db.String(128), nullable=False)
|
||||
created_at = db.Column(db.DateTime, default=datetime.utcnow)
|
||||
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
chats = db.relationship('Chat', backref='user', lazy='dynamic')
|
||||
|
||||
def __repr__(self):
|
||||
return f'<User {self.username}>'
|
||||
@@ -1,227 +0,0 @@
|
||||
"""
|
||||
Service for chat functionality.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from app.database.db import db
|
||||
from app.models.chat import Chat, Message, TeamChatMember
|
||||
from app.models.user import User
|
||||
|
||||
class ChatService:
|
||||
"""Service for chat functionality."""
|
||||
|
||||
def create_chat(self, user_id: int, title: Optional[str] = None,
|
||||
is_team_chat: bool = False, model_name: Optional[str] = None) -> Chat:
|
||||
"""
|
||||
Create a new chat.
|
||||
|
||||
Args:
|
||||
user_id: ID of the user creating the chat.
|
||||
title: Optional title for the chat.
|
||||
is_team_chat: Whether this is a team chat.
|
||||
model_name: Name of the model to use for this chat.
|
||||
|
||||
Returns:
|
||||
Created chat.
|
||||
"""
|
||||
from app.config.config import Config
|
||||
|
||||
chat = Chat(
|
||||
user_id=user_id,
|
||||
title=title,
|
||||
is_team_chat=is_team_chat,
|
||||
model_name=model_name or Config().DEFAULT_MODEL
|
||||
)
|
||||
|
||||
db.session.add(chat)
|
||||
db.session.commit()
|
||||
|
||||
# If it's a team chat, add the creator as a member
|
||||
if is_team_chat:
|
||||
self.add_team_member(chat.id, user_id)
|
||||
|
||||
return chat
|
||||
|
||||
def get_chat(self, chat_id: int) -> Optional[Chat]:
|
||||
"""
|
||||
Get a chat by ID.
|
||||
|
||||
Args:
|
||||
chat_id: ID of the chat.
|
||||
|
||||
Returns:
|
||||
Chat if found, None otherwise.
|
||||
"""
|
||||
return Chat.query.get(chat_id)
|
||||
|
||||
def get_user_chats(self, user_id: int) -> List[Chat]:
|
||||
"""
|
||||
Get all chats for a user.
|
||||
|
||||
Args:
|
||||
user_id: ID of the user.
|
||||
|
||||
Returns:
|
||||
List of chats.
|
||||
"""
|
||||
# Get private chats
|
||||
private_chats = Chat.query.filter_by(
|
||||
user_id=user_id,
|
||||
is_team_chat=False
|
||||
).order_by(Chat.updated_at.desc()).all()
|
||||
|
||||
# Get team chats where user is a member
|
||||
team_chat_ids = db.session.query(TeamChatMember.chat_id).filter_by(user_id=user_id).all()
|
||||
team_chat_ids = [chat_id for (chat_id,) in team_chat_ids]
|
||||
|
||||
team_chats = Chat.query.filter(
|
||||
Chat.id.in_(team_chat_ids)
|
||||
).order_by(Chat.updated_at.desc()).all()
|
||||
|
||||
# Combine and sort by updated_at
|
||||
all_chats = private_chats + team_chats
|
||||
all_chats.sort(key=lambda x: x.updated_at, reverse=True)
|
||||
|
||||
return all_chats
|
||||
|
||||
def add_message(self, chat_id: int, content: str,
|
||||
is_user_message: bool = True, user_id: Optional[int] = None) -> Message:
|
||||
"""
|
||||
Add a message to a chat.
|
||||
|
||||
Args:
|
||||
chat_id: ID of the chat.
|
||||
content: Message content.
|
||||
is_user_message: Whether this is a user message (vs. bot message).
|
||||
user_id: ID of the user sending the message (required for user messages).
|
||||
|
||||
Returns:
|
||||
Created message.
|
||||
"""
|
||||
message = Message(
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
is_user_message=is_user_message,
|
||||
user_id=user_id if is_user_message else None
|
||||
)
|
||||
|
||||
db.session.add(message)
|
||||
|
||||
# Update chat's updated_at timestamp
|
||||
chat = Chat.query.get(chat_id)
|
||||
if chat:
|
||||
chat.updated_at = message.created_at
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return message
|
||||
|
||||
def get_chat_messages(self, chat_id: int) -> List[Message]:
|
||||
"""
|
||||
Get all messages for a chat.
|
||||
|
||||
Args:
|
||||
chat_id: ID of the chat.
|
||||
|
||||
Returns:
|
||||
List of messages.
|
||||
"""
|
||||
return Message.query.filter_by(chat_id=chat_id).order_by(Message.created_at).all()
|
||||
|
||||
def add_team_member(self, chat_id: int, user_id: int) -> Optional[TeamChatMember]:
|
||||
"""
|
||||
Add a user to a team chat.
|
||||
|
||||
Args:
|
||||
chat_id: ID of the team chat.
|
||||
user_id: ID of the user to add.
|
||||
|
||||
Returns:
|
||||
Created team chat member if successful, None otherwise.
|
||||
"""
|
||||
chat = Chat.query.get(chat_id)
|
||||
if not chat or not chat.is_team_chat:
|
||||
return None
|
||||
|
||||
# Check if user is already a member
|
||||
existing_member = TeamChatMember.query.filter_by(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id
|
||||
).first()
|
||||
|
||||
if existing_member:
|
||||
return existing_member
|
||||
|
||||
member = TeamChatMember(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
db.session.add(member)
|
||||
db.session.commit()
|
||||
|
||||
return member
|
||||
|
||||
def get_team_members(self, chat_id: int) -> List[User]:
|
||||
"""
|
||||
Get all members of a team chat.
|
||||
|
||||
Args:
|
||||
chat_id: ID of the team chat.
|
||||
|
||||
Returns:
|
||||
List of users.
|
||||
"""
|
||||
member_ids = db.session.query(TeamChatMember.user_id).filter_by(chat_id=chat_id).all()
|
||||
member_ids = [user_id for (user_id,) in member_ids]
|
||||
|
||||
return User.query.filter(User.id.in_(member_ids)).all()
|
||||
|
||||
def remove_team_member(self, chat_id: int, user_id: int) -> bool:
|
||||
"""
|
||||
Remove a user from a team chat.
|
||||
|
||||
Args:
|
||||
chat_id: ID of the team chat.
|
||||
user_id: ID of the user to remove.
|
||||
|
||||
Returns:
|
||||
True if removal was successful, False otherwise.
|
||||
"""
|
||||
member = TeamChatMember.query.filter_by(
|
||||
chat_id=chat_id,
|
||||
user_id=user_id
|
||||
).first()
|
||||
|
||||
if not member:
|
||||
return False
|
||||
|
||||
db.session.delete(member)
|
||||
db.session.commit()
|
||||
|
||||
return True
|
||||
|
||||
def delete_chat(self, chat_id: int) -> bool:
|
||||
"""
|
||||
Delete a chat and all its messages.
|
||||
|
||||
Args:
|
||||
chat_id: ID of the chat to delete.
|
||||
|
||||
Returns:
|
||||
True if deletion was successful, False otherwise.
|
||||
"""
|
||||
chat = Chat.query.get(chat_id)
|
||||
if not chat:
|
||||
return False
|
||||
|
||||
try:
|
||||
db.session.delete(chat)
|
||||
db.session.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
# Log the error
|
||||
print(f"Error deleting chat {chat_id}: {str(e)}")
|
||||
db.session.rollback()
|
||||
return False
|
||||
@@ -1,105 +0,0 @@
|
||||
"""
|
||||
Service for chatbot functionality without database dependency.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
class ChatbotService:
|
||||
"""Service for chatbot functionality."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the chatbot service."""
|
||||
# In-memory storage for chat history
|
||||
self.chat_history = {}
|
||||
self.current_chat_id = 0
|
||||
|
||||
def create_chat(self, user_id: str) -> int:
|
||||
"""
|
||||
Create a new chat session.
|
||||
|
||||
Args:
|
||||
user_id: ID of the user creating the chat.
|
||||
|
||||
Returns:
|
||||
ID of the created chat.
|
||||
"""
|
||||
self.current_chat_id += 1
|
||||
chat_id = self.current_chat_id
|
||||
|
||||
self.chat_history[chat_id] = {
|
||||
'user_id': user_id,
|
||||
'messages': []
|
||||
}
|
||||
|
||||
return chat_id
|
||||
|
||||
def add_message(self, chat_id: int, content: str, is_user: bool = True) -> Dict[str, Any]:
|
||||
"""
|
||||
Add a message to a chat.
|
||||
|
||||
Args:
|
||||
chat_id: ID of the chat.
|
||||
content: Message content.
|
||||
is_user: Whether this is a user message (vs. bot message).
|
||||
|
||||
Returns:
|
||||
Added message.
|
||||
"""
|
||||
if chat_id not in self.chat_history:
|
||||
raise ValueError(f"Chat with ID {chat_id} not found")
|
||||
|
||||
message = {
|
||||
'content': content,
|
||||
'is_user': is_user,
|
||||
'timestamp': self._get_timestamp()
|
||||
}
|
||||
|
||||
self.chat_history[chat_id]['messages'].append(message)
|
||||
|
||||
return message
|
||||
|
||||
def get_chat_messages(self, chat_id: int) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all messages for a chat.
|
||||
|
||||
Args:
|
||||
chat_id: ID of the chat.
|
||||
|
||||
Returns:
|
||||
List of messages.
|
||||
"""
|
||||
if chat_id not in self.chat_history:
|
||||
raise ValueError(f"Chat with ID {chat_id} not found")
|
||||
|
||||
return self.chat_history[chat_id]['messages']
|
||||
|
||||
def get_response(self, chat_id: int, message: str) -> str:
|
||||
"""
|
||||
Get a response from the chatbot.
|
||||
|
||||
Args:
|
||||
chat_id: ID of the chat.
|
||||
message: User message.
|
||||
|
||||
Returns:
|
||||
Bot response.
|
||||
"""
|
||||
# Add user message to chat history
|
||||
self.add_message(chat_id, message, is_user=True)
|
||||
|
||||
# Simple echo response for now
|
||||
response = f"You said: {message}"
|
||||
|
||||
# Add bot response to chat history
|
||||
self.add_message(chat_id, response, is_user=False)
|
||||
|
||||
return response
|
||||
|
||||
def _get_timestamp(self) -> str:
|
||||
"""Get current timestamp."""
|
||||
from datetime import datetime
|
||||
return datetime.utcnow().isoformat()
|
||||
|
||||
|
||||
# Create a singleton instance
|
||||
chatbot_service = ChatbotService()
|
||||
@@ -1,165 +0,0 @@
|
||||
"""
|
||||
Service for document processing and embedding.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional
|
||||
import pinecone
|
||||
from app.database.db import db
|
||||
from app.models.document import Document, DocumentChunk
|
||||
from app.config.config import Config
|
||||
|
||||
class DocumentService:
|
||||
"""Service for document processing and embedding."""
|
||||
|
||||
def __init__(self, config: Config = None):
|
||||
"""
|
||||
Initialize the document service.
|
||||
|
||||
Args:
|
||||
config: Configuration object.
|
||||
"""
|
||||
self.config = config or Config()
|
||||
self._initialize_pinecone()
|
||||
|
||||
def _initialize_pinecone(self):
|
||||
"""Initialize Pinecone client."""
|
||||
pinecone.init(
|
||||
api_key=self.config.PINECONE_API_KEY,
|
||||
environment=self.config.PINECONE_ENVIRONMENT
|
||||
)
|
||||
|
||||
# Check if index exists, create if it doesn't
|
||||
if self.config.PINECONE_INDEX_NAME not in pinecone.list_indexes():
|
||||
pinecone.create_index(
|
||||
name=self.config.PINECONE_INDEX_NAME,
|
||||
dimension=768, # Default dimension for sentence-transformers
|
||||
metric="cosine"
|
||||
)
|
||||
|
||||
self.index = pinecone.Index(self.config.PINECONE_INDEX_NAME)
|
||||
|
||||
def create_document(self, title: str, file_path: str, content_type: str,
|
||||
description: Optional[str], user_id: int) -> Document:
|
||||
"""
|
||||
Create a new document record.
|
||||
|
||||
Args:
|
||||
title: Document title.
|
||||
file_path: Path to the document file.
|
||||
content_type: MIME type of the document.
|
||||
description: Optional description of the document.
|
||||
user_id: ID of the user who uploaded the document.
|
||||
|
||||
Returns:
|
||||
Created document.
|
||||
"""
|
||||
document = Document(
|
||||
title=title,
|
||||
file_path=file_path,
|
||||
content_type=content_type,
|
||||
description=description,
|
||||
uploaded_by=user_id,
|
||||
status='pending'
|
||||
)
|
||||
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
return document
|
||||
|
||||
def process_document(self, document_id: int) -> bool:
|
||||
"""
|
||||
Process a document for embedding.
|
||||
|
||||
Args:
|
||||
document_id: ID of the document to process.
|
||||
|
||||
Returns:
|
||||
True if processing was successful, False otherwise.
|
||||
"""
|
||||
document = Document.query.get(document_id)
|
||||
if not document:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Update status to processing
|
||||
document.status = 'processing'
|
||||
db.session.commit()
|
||||
|
||||
# TODO: Implement document parsing and chunking
|
||||
# This will be implemented in the next step
|
||||
|
||||
# Update status to completed
|
||||
document.status = 'completed'
|
||||
db.session.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
# Update status to error
|
||||
document.status = 'error'
|
||||
db.session.commit()
|
||||
# Log the error
|
||||
print(f"Error processing document {document_id}: {str(e)}")
|
||||
return False
|
||||
|
||||
def get_document(self, document_id: int) -> Optional[Document]:
|
||||
"""
|
||||
Get a document by ID.
|
||||
|
||||
Args:
|
||||
document_id: ID of the document.
|
||||
|
||||
Returns:
|
||||
Document if found, None otherwise.
|
||||
"""
|
||||
return Document.query.get(document_id)
|
||||
|
||||
def get_all_documents(self, user_id: Optional[int] = None) -> List[Document]:
|
||||
"""
|
||||
Get all documents, optionally filtered by user.
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID to filter by.
|
||||
|
||||
Returns:
|
||||
List of documents.
|
||||
"""
|
||||
query = Document.query
|
||||
if user_id:
|
||||
query = query.filter_by(uploaded_by=user_id)
|
||||
return query.order_by(Document.created_at.desc()).all()
|
||||
|
||||
def delete_document(self, document_id: int) -> bool:
|
||||
"""
|
||||
Delete a document and its chunks.
|
||||
|
||||
Args:
|
||||
document_id: ID of the document to delete.
|
||||
|
||||
Returns:
|
||||
True if deletion was successful, False otherwise.
|
||||
"""
|
||||
document = Document.query.get(document_id)
|
||||
if not document:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Delete document chunks from Pinecone
|
||||
chunks = DocumentChunk.query.filter_by(document_id=document_id).all()
|
||||
embedding_ids = [chunk.embedding_id for chunk in chunks if chunk.embedding_id]
|
||||
|
||||
if embedding_ids:
|
||||
self.index.delete(ids=embedding_ids)
|
||||
|
||||
# Delete document from database
|
||||
db.session.delete(document)
|
||||
db.session.commit()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
# Log the error
|
||||
print(f"Error deleting document {document_id}: {str(e)}")
|
||||
db.session.rollback()
|
||||
return False
|
||||
@@ -1,95 +0,0 @@
|
||||
"""
|
||||
Service for model management and interaction.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from app.config.config import Config
|
||||
|
||||
class ModelService:
|
||||
"""Service for model management and interaction."""
|
||||
|
||||
# Available models
|
||||
AVAILABLE_MODELS = {
|
||||
'gpt-3.5-turbo': {
|
||||
'name': 'GPT-3.5 Turbo',
|
||||
'description': 'OpenAI GPT-3.5 Turbo model',
|
||||
'provider': 'openai',
|
||||
'max_tokens': 4096
|
||||
},
|
||||
'gpt-4': {
|
||||
'name': 'GPT-4',
|
||||
'description': 'OpenAI GPT-4 model',
|
||||
'provider': 'openai',
|
||||
'max_tokens': 8192
|
||||
},
|
||||
# Add more models as needed
|
||||
}
|
||||
|
||||
def __init__(self, config: Config = None):
|
||||
"""
|
||||
Initialize the model service.
|
||||
|
||||
Args:
|
||||
config: Configuration object.
|
||||
"""
|
||||
self.config = config or Config()
|
||||
self.default_model = self.config.DEFAULT_MODEL
|
||||
|
||||
def get_available_models(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get a list of available models.
|
||||
|
||||
Returns:
|
||||
List of model information dictionaries.
|
||||
"""
|
||||
models = []
|
||||
for model_id, model_info in self.AVAILABLE_MODELS.items():
|
||||
model_data = {
|
||||
'id': model_id,
|
||||
'is_default': model_id == self.default_model,
|
||||
**model_info
|
||||
}
|
||||
models.append(model_data)
|
||||
|
||||
return models
|
||||
|
||||
def get_model_info(self, model_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get information about a specific model.
|
||||
|
||||
Args:
|
||||
model_id: ID of the model.
|
||||
|
||||
Returns:
|
||||
Model information dictionary if found, None otherwise.
|
||||
"""
|
||||
if model_id not in self.AVAILABLE_MODELS:
|
||||
return None
|
||||
|
||||
return {
|
||||
'id': model_id,
|
||||
'is_default': model_id == self.default_model,
|
||||
**self.AVAILABLE_MODELS[model_id]
|
||||
}
|
||||
|
||||
def generate_response(self, model_id: str, prompt: str,
|
||||
context: Optional[List[Dict[str, str]]] = None) -> str:
|
||||
"""
|
||||
Generate a response from the model.
|
||||
|
||||
Args:
|
||||
model_id: ID of the model to use.
|
||||
prompt: User prompt.
|
||||
context: Optional conversation context.
|
||||
|
||||
Returns:
|
||||
Generated response.
|
||||
"""
|
||||
# TODO: Implement actual model integration
|
||||
# This is a placeholder that will be implemented in the next steps
|
||||
|
||||
if model_id not in self.AVAILABLE_MODELS:
|
||||
model_id = self.default_model
|
||||
|
||||
# Placeholder response
|
||||
return f"This is a placeholder response from {self.AVAILABLE_MODELS[model_id]['name']}. The actual model integration will be implemented in the next steps."
|
||||
Executable
+22
@@ -0,0 +1,22 @@
|
||||
#!/bin/bash
|
||||
# Script to run the RAG test in a Python virtual environment
|
||||
|
||||
# Make the script executable
|
||||
chmod +x test_rag.py
|
||||
|
||||
# Check if virtual environment is activated
|
||||
if [[ -z "$VIRTUAL_ENV" ]]; then
|
||||
echo "Virtual environment is not activated."
|
||||
echo "Please activate your virtual environment first with:"
|
||||
echo "source /path/to/your/venv/bin/activate"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Running RAG tests against remote server..."
|
||||
python test_rag.py --remote --verbose
|
||||
|
||||
# If you want to test against a different server, uncomment and modify this line:
|
||||
# python test_rag.py --api-url "http://your-server-url:port" --verbose
|
||||
|
||||
# If you want to test with a specific query, uncomment and modify this line:
|
||||
# python test_rag.py --remote --query "What information do you have about project X?" --verbose
|
||||
-144
@@ -1,144 +0,0 @@
|
||||
"""
|
||||
Simple API for testing deployment.
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title="Simple AI Service API",
|
||||
description="Simple API for testing deployment",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Define API models
|
||||
class MessageRequest(BaseModel):
|
||||
"""Request model for sending a message."""
|
||||
message: str = Field(..., description="Message content")
|
||||
user_id: str = Field(..., description="User ID")
|
||||
|
||||
# Model parameters
|
||||
temperature: Optional[float] = Field(None, description="Controls randomness")
|
||||
max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate")
|
||||
top_p: Optional[float] = Field(None, description="Nucleus sampling parameter")
|
||||
frequency_penalty: Optional[float] = Field(None, description="Penalizes repeated tokens")
|
||||
presence_penalty: Optional[float] = Field(None, description="Penalizes repeated topics")
|
||||
system_prompt: Optional[str] = Field(None, description="System prompt")
|
||||
|
||||
class Message(BaseModel):
|
||||
"""Model for a message."""
|
||||
id: str = Field(..., description="Message ID")
|
||||
content: str = Field(..., description="Message content")
|
||||
user_id: Optional[str] = Field(None, description="User ID")
|
||||
is_user_message: bool = Field(..., description="Whether this is a user message")
|
||||
timestamp: str = Field(..., description="Message timestamp")
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
"""Request model for creating a chat."""
|
||||
user_id: str = Field(..., description="User ID")
|
||||
title: Optional[str] = Field(None, description="Chat title")
|
||||
model_id: Optional[str] = Field(None, description="Model ID")
|
||||
|
||||
class Chat(BaseModel):
|
||||
"""Model for a chat."""
|
||||
id: str = Field(..., description="Chat ID")
|
||||
title: str = Field(..., description="Chat title")
|
||||
user_id: str = Field(..., description="User ID")
|
||||
model_id: str = Field(..., description="Model ID")
|
||||
created_at: str = Field(..., description="Creation timestamp")
|
||||
updated_at: str = Field(..., description="Update timestamp")
|
||||
messages: List[Message] = Field(default=[], description="Chat messages")
|
||||
|
||||
# In-memory storage
|
||||
chats = {}
|
||||
|
||||
# API endpoints
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy"}
|
||||
|
||||
@app.post("/chats", response_model=Chat)
|
||||
async def create_chat(request: ChatRequest):
|
||||
"""Create a new chat."""
|
||||
chat_id = str(uuid.uuid4())
|
||||
|
||||
chat = {
|
||||
"id": chat_id,
|
||||
"title": request.title or f"Chat {len(chats) + 1}",
|
||||
"user_id": request.user_id,
|
||||
"model_id": request.model_id or "gpt-3.5-turbo",
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
"updated_at": datetime.utcnow().isoformat(),
|
||||
"messages": []
|
||||
}
|
||||
|
||||
chats[chat_id] = chat
|
||||
return chat
|
||||
|
||||
@app.get("/chats/{chat_id}", response_model=Chat)
|
||||
async def get_chat(chat_id: str):
|
||||
"""Get a chat by ID."""
|
||||
if chat_id not in chats:
|
||||
raise HTTPException(status_code=404, detail="Chat not found")
|
||||
|
||||
return chats[chat_id]
|
||||
|
||||
@app.post("/chats/{chat_id}/messages", response_model=Message)
|
||||
async def send_message(chat_id: str, request: MessageRequest):
|
||||
"""Send a message to a chat."""
|
||||
if chat_id not in chats:
|
||||
raise HTTPException(status_code=404, detail="Chat not found")
|
||||
|
||||
# Add user message
|
||||
user_message = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"content": request.message,
|
||||
"user_id": request.user_id,
|
||||
"is_user_message": True,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
chats[chat_id]["messages"].append(user_message)
|
||||
|
||||
# Generate bot response
|
||||
params_text = ""
|
||||
if request.temperature is not None:
|
||||
params_text += f" (temperature={request.temperature})"
|
||||
if request.max_tokens is not None:
|
||||
params_text += f" (max_tokens={request.max_tokens})"
|
||||
if request.system_prompt is not None:
|
||||
params_text += f" (using custom system prompt)"
|
||||
|
||||
bot_message = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"content": f"This is a test response to: '{request.message}'{params_text}",
|
||||
"user_id": None,
|
||||
"is_user_message": False,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
chats[chat_id]["messages"].append(bot_message)
|
||||
chats[chat_id]["updated_at"] = datetime.utcnow().isoformat()
|
||||
|
||||
return bot_message
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=5251)
|
||||
@@ -1,69 +0,0 @@
|
||||
"""
|
||||
Test script for sending a message with advanced parameters.
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
import uuid
|
||||
|
||||
# Create a new chat
|
||||
def create_chat():
|
||||
response = requests.post(
|
||||
"http://localhost:5251/chats",
|
||||
json={
|
||||
"user_id": "test_user",
|
||||
"title": "Test Chat",
|
||||
"model_id": "gpt-3.5-turbo",
|
||||
"is_team_chat": False
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()["id"]
|
||||
else:
|
||||
print(f"Error creating chat: {response.status_code}")
|
||||
print(response.text)
|
||||
return None
|
||||
|
||||
# Send a message with advanced parameters
|
||||
def send_message(chat_id):
|
||||
response = requests.post(
|
||||
f"http://localhost:5251/chats/{chat_id}/messages",
|
||||
json={
|
||||
"message": "Tell me about artificial intelligence",
|
||||
"user_id": "test_user",
|
||||
"use_rag": False,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 500,
|
||||
"top_p": 0.9,
|
||||
"frequency_penalty": 0.5,
|
||||
"presence_penalty": 0.5,
|
||||
"system_prompt": "You are a helpful AI assistant that provides concise responses."
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
print(f"Error sending message: {response.status_code}")
|
||||
print(response.text)
|
||||
return None
|
||||
|
||||
# Main function
|
||||
def main():
|
||||
print("Creating a new chat...")
|
||||
chat_id = create_chat()
|
||||
|
||||
if chat_id:
|
||||
print(f"Chat created with ID: {chat_id}")
|
||||
|
||||
print("\nSending a message with advanced parameters...")
|
||||
response = send_message(chat_id)
|
||||
|
||||
if response:
|
||||
print("\nResponse received:")
|
||||
print(f"Message ID: {response['id']}")
|
||||
print(f"Content: {response['content']}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Executable
+309
@@ -0,0 +1,309 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Test script for RAG functionality in the AI service.
|
||||
This script tests the document-based question answering capabilities
|
||||
by making requests to the API endpoints.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
import requests
|
||||
from typing import Dict, Any, Optional
|
||||
from pprint import pprint
|
||||
|
||||
# Default configuration
|
||||
DEFAULT_API_URL = "http://localhost:5252" # Local development server
|
||||
DEFAULT_REMOTE_API_URL = "http://157.157.221.29:5251" # Remote server
|
||||
DEFAULT_MODEL = "llama3.1"
|
||||
|
||||
class RAGTester:
|
||||
"""Test the RAG functionality of the AI service."""
|
||||
|
||||
def __init__(self, api_url: str, verbose: bool = False):
|
||||
"""
|
||||
Initialize the RAG tester.
|
||||
|
||||
Args:
|
||||
api_url: URL of the AI service API.
|
||||
verbose: Whether to print verbose output.
|
||||
"""
|
||||
self.api_url = api_url
|
||||
self.verbose = verbose
|
||||
self.session = requests.Session()
|
||||
|
||||
# Print configuration
|
||||
print(f"Testing RAG functionality against API at: {self.api_url}")
|
||||
|
||||
def _log(self, message: str):
|
||||
"""Log a message if verbose mode is enabled."""
|
||||
if self.verbose:
|
||||
print(f"[DEBUG] {message}")
|
||||
|
||||
def check_server_health(self) -> bool:
|
||||
"""
|
||||
Check if the server is healthy.
|
||||
|
||||
Returns:
|
||||
True if the server is healthy, False otherwise.
|
||||
"""
|
||||
try:
|
||||
response = self.session.get(f"{self.api_url}/health", timeout=10)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if result.get("status") == "healthy":
|
||||
print("✅ Server is healthy")
|
||||
return True
|
||||
else:
|
||||
print("❌ Server health check failed")
|
||||
print(f"Response: {result}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error checking server health: {str(e)}")
|
||||
return False
|
||||
|
||||
def check_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Check the server configuration.
|
||||
|
||||
Returns:
|
||||
Server configuration.
|
||||
"""
|
||||
try:
|
||||
response = self.session.get(f"{self.api_url}/config", timeout=10)
|
||||
response.raise_for_status()
|
||||
config = response.json()
|
||||
|
||||
print("✅ Server configuration:")
|
||||
print(f" - API Host: {config.get('api_host')}")
|
||||
print(f" - API Port: {config.get('api_port')}")
|
||||
print(f" - OpenWebUI URL: {config.get('openwebui_url')}")
|
||||
print(f" - Ollama API URL: {config.get('ollama_api_url')}")
|
||||
print(f" - Default Model: {config.get('default_model')}")
|
||||
print(f" - API Timeout: {config.get('api_timeout')} seconds")
|
||||
print(f" - Available Models: {', '.join(config.get('available_models', []))}")
|
||||
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error checking server configuration: {str(e)}")
|
||||
return {}
|
||||
|
||||
def test_ollama_connection(self) -> bool:
|
||||
"""
|
||||
Test the connection to Ollama.
|
||||
|
||||
Returns:
|
||||
True if the connection is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
response = self.session.get(f"{self.api_url}/test-ollama", timeout=30)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if result.get("status") == "success":
|
||||
print("✅ Successfully connected to Ollama API")
|
||||
print(f" - Ollama URL: {result.get('ollama_url')}")
|
||||
print(f" - Available models: {len(result.get('models', {}).get('models', []))}")
|
||||
return True
|
||||
else:
|
||||
print("❌ Failed to connect to Ollama API")
|
||||
print(f" - Error: {result.get('message')}")
|
||||
print(f" - Ollama URL: {result.get('ollama_url')}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error testing Ollama connection: {str(e)}")
|
||||
return False
|
||||
|
||||
def test_chat_completion(self, model_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Test basic chat completion without RAG.
|
||||
|
||||
Args:
|
||||
model_id: Optional model ID to use.
|
||||
|
||||
Returns:
|
||||
True if the test is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
response = self.session.post(f"{self.api_url}/test-chat", timeout=60)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if result.get("status") == "success":
|
||||
print("✅ Successfully tested chat completion")
|
||||
print(f" - Model: {result.get('model')}")
|
||||
print(f" - Response: {result.get('response')[:100]}...") # First 100 chars
|
||||
return True
|
||||
else:
|
||||
print("❌ Failed to test chat completion")
|
||||
print(f" - Error: {result.get('message')}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error testing chat completion: {str(e)}")
|
||||
return False
|
||||
|
||||
def test_rag_completion(self, query: str = "What information do you have in your knowledge database?") -> bool:
|
||||
"""
|
||||
Test RAG completion with a query.
|
||||
|
||||
Args:
|
||||
query: Query to test with RAG.
|
||||
|
||||
Returns:
|
||||
True if the test is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
response = self.session.post(
|
||||
f"{self.api_url}/test-rag",
|
||||
params={"query": query},
|
||||
timeout=120 # Longer timeout for RAG
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if result.get("status") == "success":
|
||||
print("✅ Successfully tested RAG completion")
|
||||
print(f" - Model: {result.get('model')}")
|
||||
print(f" - Query: {result.get('query')}")
|
||||
print(f" - OpenWebUI URL: {result.get('openwebui_url')}")
|
||||
print(f" - Response: {result.get('response')[:150]}...") # First 150 chars
|
||||
|
||||
# Print full response in verbose mode
|
||||
if self.verbose:
|
||||
print("\nFull response:")
|
||||
print(result.get('response'))
|
||||
|
||||
return True
|
||||
else:
|
||||
print("❌ Failed to test RAG completion")
|
||||
print(f" - Error: {result.get('message')}")
|
||||
print(f" - OpenWebUI URL: {result.get('openwebui_url')}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error testing RAG completion: {str(e)}")
|
||||
return False
|
||||
|
||||
def create_chat_and_test_rag(self, query: str, user_id: str = "test_user") -> bool:
|
||||
"""
|
||||
Create a chat and test RAG with a message.
|
||||
|
||||
Args:
|
||||
query: Query to test with RAG.
|
||||
user_id: User ID for the chat.
|
||||
|
||||
Returns:
|
||||
True if the test is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
# Create a chat
|
||||
create_response = self.session.post(
|
||||
f"{self.api_url}/chats",
|
||||
json={
|
||||
"user_id": user_id,
|
||||
"title": "RAG Test Chat",
|
||||
"model_id": DEFAULT_MODEL,
|
||||
"is_team_chat": False
|
||||
},
|
||||
timeout=30
|
||||
)
|
||||
create_response.raise_for_status()
|
||||
chat = create_response.json()
|
||||
chat_id = chat.get("id")
|
||||
|
||||
print(f"✅ Created chat with ID: {chat_id}")
|
||||
|
||||
# Send a message with RAG enabled
|
||||
message_response = self.session.post(
|
||||
f"{self.api_url}/chats/{chat_id}/messages",
|
||||
json={
|
||||
"message": query,
|
||||
"user_id": user_id,
|
||||
"use_rag": True,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 1000
|
||||
},
|
||||
timeout=120 # Longer timeout for RAG
|
||||
)
|
||||
message_response.raise_for_status()
|
||||
message = message_response.json()
|
||||
|
||||
print("✅ Successfully sent message with RAG")
|
||||
print(f" - Message ID: {message.get('id')}")
|
||||
print(f" - Response: {message.get('content')[:150]}...") # First 150 chars
|
||||
|
||||
# Print full response in verbose mode
|
||||
if self.verbose:
|
||||
print("\nFull response:")
|
||||
print(message.get('content'))
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error testing chat with RAG: {str(e)}")
|
||||
return False
|
||||
|
||||
def run_all_tests(self, query: str = "What information do you have in your knowledge database?"):
|
||||
"""
|
||||
Run all tests.
|
||||
|
||||
Args:
|
||||
query: Query to test with RAG.
|
||||
"""
|
||||
print("\n=== Running RAG Functionality Tests ===\n")
|
||||
|
||||
# Check server health
|
||||
if not self.check_server_health():
|
||||
print("❌ Server health check failed. Aborting tests.")
|
||||
return
|
||||
|
||||
# Check configuration
|
||||
config = self.check_config()
|
||||
if not config:
|
||||
print("❌ Failed to get server configuration. Continuing with tests...")
|
||||
|
||||
# Test Ollama connection
|
||||
if not self.test_ollama_connection():
|
||||
print("⚠️ Ollama connection test failed. Some tests may fail.")
|
||||
|
||||
# Test basic chat completion
|
||||
if not self.test_chat_completion():
|
||||
print("⚠️ Basic chat completion test failed. RAG tests may also fail.")
|
||||
|
||||
# Test RAG completion
|
||||
print("\n--- Testing RAG Completion ---\n")
|
||||
self.test_rag_completion(query)
|
||||
|
||||
# Test chat with RAG
|
||||
print("\n--- Testing Chat with RAG ---\n")
|
||||
self.create_chat_and_test_rag(query)
|
||||
|
||||
print("\n=== Tests Completed ===\n")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to run the tests."""
|
||||
parser = argparse.ArgumentParser(description="Test RAG functionality in the AI service")
|
||||
parser.add_argument("--api-url", default=DEFAULT_API_URL, help=f"URL of the AI service API (default: {DEFAULT_API_URL})")
|
||||
parser.add_argument("--remote", action="store_true", help=f"Use the remote API URL ({DEFAULT_REMOTE_API_URL})")
|
||||
parser.add_argument("--query", default="What information do you have in your knowledge database?", help="Query to test with RAG")
|
||||
parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose output")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Use remote URL if specified
|
||||
api_url = DEFAULT_REMOTE_API_URL if args.remote else args.api_url
|
||||
|
||||
# Create and run the tester
|
||||
tester = RAGTester(api_url=api_url, verbose=args.verbose)
|
||||
tester.run_all_tests(query=args.query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,2 @@
|
||||
requests>=2.28.0
|
||||
pydantic>=2.0.0
|
||||
Reference in New Issue
Block a user