Added Rag Featured
This commit is contained in:
@@ -0,0 +1,73 @@
|
|||||||
|
# Document-Based Question Answering with OpenWebUI
|
||||||
|
|
||||||
|
This document explains how to use the document-based question answering (RAG) functionality in the AI service.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The AI service now supports Retrieval Augmented Generation (RAG) by leveraging OpenWebUI's built-in knowledge database. This allows users to:
|
||||||
|
|
||||||
|
1. Upload documents to OpenWebUI
|
||||||
|
2. Ask questions about those documents
|
||||||
|
3. Receive responses that incorporate information from the documents
|
||||||
|
|
||||||
|
## How It Works
|
||||||
|
|
||||||
|
When RAG is enabled:
|
||||||
|
|
||||||
|
1. The AI service forwards the request to OpenWebUI with `use_knowledge=True`
|
||||||
|
2. OpenWebUI searches its knowledge database for relevant information
|
||||||
|
3. The retrieved information is used to augment the model's response
|
||||||
|
4. The response is returned to the user
|
||||||
|
|
||||||
|
## Using RAG in API Requests
|
||||||
|
|
||||||
|
To enable RAG in your API requests, set the `use_rag` parameter to `true`:
|
||||||
|
|
||||||
|
```http
|
||||||
|
POST /chats/{chat_id}/messages
|
||||||
|
{
|
||||||
|
"message": "What information do you have about project X?",
|
||||||
|
"user_id": "user123",
|
||||||
|
"use_rag": true
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing RAG Functionality
|
||||||
|
|
||||||
|
You can test the RAG functionality using the `/test-rag` endpoint:
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /test-rag?query=What%20information%20do%20you%20have%20about%20project%20X%3F
|
||||||
|
```
|
||||||
|
|
||||||
|
This will return a response that includes information from documents in OpenWebUI's knowledge database.
|
||||||
|
|
||||||
|
## Uploading Documents to OpenWebUI
|
||||||
|
|
||||||
|
To use RAG effectively, you need to upload documents to OpenWebUI:
|
||||||
|
|
||||||
|
1. Log in to OpenWebUI at your configured URL (default: http://104.225.217.215:8080)
|
||||||
|
2. Navigate to the Knowledge section
|
||||||
|
3. Upload your documents (PDF, TXT, DOCX, etc.)
|
||||||
|
4. OpenWebUI will automatically process and index the documents
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
If RAG is not working as expected:
|
||||||
|
|
||||||
|
1. Ensure OpenWebUI is running and accessible
|
||||||
|
2. Check that documents are properly uploaded and indexed in OpenWebUI
|
||||||
|
3. Verify that the `use_rag` parameter is set to `true` in your requests
|
||||||
|
4. Check the logs for any errors related to OpenWebUI API calls
|
||||||
|
|
||||||
|
If there are connection issues with OpenWebUI, the AI service will automatically fall back to using the direct Ollama API without RAG.
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
The following configuration settings affect RAG functionality:
|
||||||
|
|
||||||
|
- `OPENWEBUI_URL`: URL of your OpenWebUI instance
|
||||||
|
- `OPENWEBUI_API_KEY`: API key for OpenWebUI (if required)
|
||||||
|
- `API_TIMEOUT`: Timeout for API requests (in seconds)
|
||||||
|
|
||||||
|
These can be set in your environment variables or in the `.env` file.
|
||||||
@@ -1,6 +1,11 @@
|
|||||||
"""
|
"""
|
||||||
FastAPI application for the AI service.
|
FastAPI application for the AI service.
|
||||||
This service acts as a backend for OpenWebUI, providing OpenWebUI-compatible API endpoints.
|
This service acts as a backend for OpenWebUI, providing OpenWebUI-compatible API endpoints.
|
||||||
|
|
||||||
|
The service supports document-based question answering using OpenWebUI's knowledge database:
|
||||||
|
- Set use_rag=True in API requests to enable Retrieval Augmented Generation
|
||||||
|
- When enabled, the service will use OpenWebUI's knowledge database to find relevant information
|
||||||
|
- Documents uploaded to OpenWebUI will be used to augment the model's responses
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
@@ -203,6 +208,43 @@ async def test_chat_completion():
|
|||||||
"ollama_url": config.OLLAMA_API_URL
|
"ollama_url": config.OLLAMA_API_URL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@app.post("/test-rag")
async def test_rag_completion(query: str = "What information do you have in your knowledge database?"):
    """
    Test the RAG (Retrieval Augmented Generation) functionality with a query.

    This endpoint tests the integration with OpenWebUI's knowledge database.

    Args:
        query: The question to ask about documents in the knowledge database.

    Returns:
        On success, a dict with ``status`` set to "success", the model name,
        the query, ``use_rag`` and the model response. On any failure, a dict
        with ``status`` set to "error" and a message. Both shapes include the
        configured OpenWebUI URL.
    """
    # Imported here so the block does not rely on a new module-level import.
    from fastapi.concurrency import run_in_threadpool

    try:
        # generate_response is a plain synchronous call; running it in the
        # thread pool keeps this async handler from blocking the event loop
        # while the upstream HTTP request is in flight.
        response = await run_in_threadpool(
            model_service.generate_response,
            model_id=config.DEFAULT_MODEL,
            prompt=query,
            context=[],
            use_rag=True,  # Enable RAG
        )

        return {
            "status": "success",
            "model": config.DEFAULT_MODEL,
            "query": query,
            "use_rag": True,
            "response": response,
            "openwebui_url": config.OPENWEBUI_URL
        }
    except Exception as e:
        # Deliberate best-effort: this diagnostic endpoint reports the failure
        # in the payload instead of raising.
        return {
            "status": "error",
            "message": f"Failed to get RAG completion: {str(e)}",
            "openwebui_url": config.OPENWEBUI_URL
        }
|
||||||
|
|
||||||
@app.post("/test-ollama-direct")
|
@app.post("/test-ollama-direct")
|
||||||
async def test_ollama_direct():
|
async def test_ollama_direct():
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -198,5 +198,32 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"team_members": []
|
"team_members": []
|
||||||
|
},
|
||||||
|
"aefe133e-e85e-422f-9ca9-6441f24db74f": {
|
||||||
|
"id": "aefe133e-e85e-422f-9ca9-6441f24db74f",
|
||||||
|
"title": "RAG Test Chat",
|
||||||
|
"user_id": "test_user",
|
||||||
|
"model_id": "llama3.1",
|
||||||
|
"is_team_chat": false,
|
||||||
|
"created_at": "2025-05-16T14:41:08.281043",
|
||||||
|
"updated_at": "2025-05-16T14:41:09.142922",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"id": "35384001-0c6d-4ac6-a247-32fb53bb664c",
|
||||||
|
"content": "What information do you have in your knowledge database?",
|
||||||
|
"user_id": "test_user",
|
||||||
|
"is_user_message": true,
|
||||||
|
"timestamp": "2025-05-16T14:41:08.341940"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "4b82555f-76e2-43c0-a1aa-c91257536133",
|
||||||
|
"content": "Connection error to Ollama API: HTTPConnectionPool(host='104.225.217.215', port=11434): Max retries exceeded with url: /api/chat (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x104bcc750>: Failed to establish a new connection: [Errno 61] Connection refused')). Please check if Ollama is running at http://104.225.217.215:11434.",
|
||||||
|
"user_id": null,
|
||||||
|
"is_user_message": false,
|
||||||
|
"timestamp": "2025-05-16T14:41:09.142892"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"team_members": [],
|
||||||
|
"openwebui_channel_id": null
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,261 +0,0 @@
|
|||||||
"""
|
|
||||||
Service for document processing and chunking.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import uuid
|
|
||||||
import requests
|
|
||||||
import base64
|
|
||||||
from typing import List, Dict, Any, Optional
|
|
||||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
||||||
|
|
||||||
from ai_service.config import config
|
|
||||||
|
|
||||||
class DocumentService:
    """Service for document processing and chunking.

    Documents are uploaded to OpenWebUI's knowledge API. Per-document metadata
    is kept in memory in ``self.documents`` (document ID -> metadata dict) and
    mirrored to a JSON file on every change.
    """

    def __init__(self):
        """Initialize the document service from the shared configuration."""
        self.chunk_size = config.CHUNK_SIZE
        self.chunk_overlap = config.CHUNK_OVERLAP
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=len
        )

        # OpenWebUI configuration
        self.openwebui_url = config.OPENWEBUI_URL
        self.openwebui_api_key = config.OPENWEBUI_API_KEY

        # Ensure data directory exists
        os.makedirs('ai_service/data', exist_ok=True)

        # For now, document metadata is stored in a simple JSON file
        self.metadata_file = 'ai_service/data/document_metadata.json'
        self._load_metadata()

    def _auth_headers(self) -> Dict[str, str]:
        """Build the JSON request headers, adding a bearer token when an API key is set."""
        headers = {"Content-Type": "application/json"}
        if self.openwebui_api_key:
            headers["Authorization"] = f"Bearer {self.openwebui_api_key}"
        return headers

    def _load_metadata(self):
        """Load document metadata from file.

        ``self.documents`` ends up as an empty dict when the file is missing,
        unreadable, or does not contain a JSON object.
        """
        self.documents = {}
        if not os.path.exists(self.metadata_file):
            return

        try:
            with open(self.metadata_file, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except Exception as e:
            print(f"Error loading document metadata: {str(e)}")
            return

        # Every other method indexes self.documents by document ID, so any
        # other JSON value (list, string, ...) is rejected here.
        if isinstance(loaded, dict):
            self.documents = loaded
        else:
            print("Error loading document metadata: expected a JSON object")

    def _save_metadata(self):
        """Save document metadata to file; failures are reported, not raised."""
        try:
            with open(self.metadata_file, 'w', encoding='utf-8') as f:
                json.dump(self.documents, f, indent=2)
        except Exception as e:
            print(f"Error saving document metadata: {str(e)}")

    def process_document(self, content: str, title: str,
                         description: Optional[str] = None,
                         metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Process a document for embedding.

        The content is uploaded to OpenWebUI. If anything in that path fails,
        the document is split locally instead and recorded with
        ``openwebui_upload_failed`` set.

        Args:
            content: Document content.
            title: Document title.
            description: Optional document description.
            metadata: Optional additional metadata.

        Returns:
            Document ID.
        """
        # Generate a unique ID for the document
        doc_id = str(uuid.uuid4())

        # Upload the document to OpenWebUI for RAG processing
        try:
            document_data = {
                "filename": f"{title}.txt",
                "content": base64.b64encode(content.encode('utf-8')).decode('utf-8'),
                "description": description or title
            }

            response = requests.post(
                f"{self.openwebui_url}/api/knowledge/upload",
                headers=self._auth_headers(),
                json=document_data,
                timeout=60
            )
            response.raise_for_status()
            result = response.json()

            # Get the OpenWebUI document ID
            openwebui_doc_id = result.get('id', '')

            # Store document metadata
            self.documents[doc_id] = {
                'id': doc_id,
                'title': title,
                'description': description or '',
                'openwebui_id': openwebui_doc_id,
                'metadata': metadata or {}
            }
            self._save_metadata()

            return doc_id

        except Exception as e:
            # Deliberate best-effort: any upload problem degrades to local
            # bookkeeping instead of failing the caller.
            print(f"Error uploading document to OpenWebUI: {str(e)}")
            print("Falling back to local document processing")

            # Split the document into chunks for local reference
            chunks = self.text_splitter.split_text(content)

            self.documents[doc_id] = {
                'id': doc_id,
                'title': title,
                'description': description or '',
                'chunk_count': len(chunks),
                'openwebui_upload_failed': True,
                'metadata': metadata or {}
            }
            self._save_metadata()

            return doc_id

    def get_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
        """
        Get document metadata.

        Args:
            doc_id: Document ID.

        Returns:
            Document metadata if found, None otherwise.
        """
        return self.documents.get(doc_id)

    def get_all_documents(self) -> List[Dict[str, Any]]:
        """
        Get all document metadata.

        Local entries that carry an ``openwebui_id`` are annotated with
        ``openwebui_status`` and ``openwebui_info`` when OpenWebUI lists a
        matching document. If OpenWebUI cannot be queried, the local metadata
        is returned unannotated.

        Returns:
            List of document metadata.
        """
        local_documents = list(self.documents.values())

        # Try to get documents from OpenWebUI as well
        try:
            response = requests.get(
                f"{self.openwebui_url}/api/knowledge",
                headers=self._auth_headers(),
                timeout=30
            )

            if response.status_code == 200:
                # Index the remote listing once instead of rescanning it for
                # every local document. setdefault keeps the first entry per
                # ID, matching a first-match linear scan.
                remote_by_id = {}
                for remote in response.json():
                    if isinstance(remote, dict):
                        remote_by_id.setdefault(remote.get('id'), remote)

                for doc in local_documents:
                    if 'openwebui_id' not in doc:
                        continue
                    match = remote_by_id.get(doc['openwebui_id'])
                    if match is not None:
                        doc['openwebui_status'] = 'active'
                        doc['openwebui_info'] = match

        except Exception as e:
            print(f"Error getting documents from OpenWebUI: {str(e)}")

        return local_documents

    def delete_document(self, doc_id: str) -> bool:
        """
        Delete a document and its chunks.

        The remote delete is best-effort: a failure is reported but the local
        metadata entry is removed regardless.

        Args:
            doc_id: Document ID.

        Returns:
            True if deletion was successful, False otherwise.
        """
        if doc_id not in self.documents:
            return False

        # Check if document was uploaded to OpenWebUI
        openwebui_id = self.documents[doc_id].get('openwebui_id')

        if openwebui_id:
            try:
                response = requests.delete(
                    f"{self.openwebui_url}/api/knowledge/{openwebui_id}",
                    headers=self._auth_headers(),
                    timeout=30
                )

                # Any 2xx (e.g. 204 No Content) is a successful delete.
                if not response.ok:
                    print(f"Warning: Failed to delete document from OpenWebUI: {response.text}")

            except Exception as e:
                print(f"Error deleting document from OpenWebUI: {str(e)}")

        # Delete document metadata and persist the change
        del self.documents[doc_id]
        self._save_metadata()

        return True

    def search_documents(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Search for documents similar to a query.

        Placeholder: always returns an empty list. Retrieval is handled by
        OpenWebUI when use_rag=True in the model service.

        Args:
            query: Search query.
            top_k: Number of results to return.

        Returns:
            List of similar document chunks with their metadata.
        """
        return []
|
|
||||||
|
|
||||||
|
|
||||||
# Create a singleton instance
|
|
||||||
document_service = DocumentService()
|
|
||||||
@@ -135,9 +135,6 @@ class ModelService:
|
|||||||
model_id = self.default_model
|
model_id = self.default_model
|
||||||
print(f" - Model not found, using default: {model_id}")
|
print(f" - Model not found, using default: {model_id}")
|
||||||
|
|
||||||
# Ensure we're using a valid model
|
|
||||||
# (model_id is already validated above)
|
|
||||||
|
|
||||||
# Prepare the messages for the API call
|
# Prepare the messages for the API call
|
||||||
messages = []
|
messages = []
|
||||||
|
|
||||||
@@ -178,6 +175,10 @@ class ModelService:
|
|||||||
openwebui_request['max_tokens'] = params['max_tokens']
|
openwebui_request['max_tokens'] = params['max_tokens']
|
||||||
if 'top_p' in params:
|
if 'top_p' in params:
|
||||||
openwebui_request['top_p'] = params['top_p']
|
openwebui_request['top_p'] = params['top_p']
|
||||||
|
if 'top_k' in params:
|
||||||
|
openwebui_request['top_k'] = params['top_k']
|
||||||
|
if 'repeat_penalty' in params:
|
||||||
|
openwebui_request['repeat_penalty'] = params['repeat_penalty']
|
||||||
|
|
||||||
# Make the API call to OpenWebUI
|
# Make the API call to OpenWebUI
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
@@ -201,10 +202,15 @@ class ModelService:
|
|||||||
result = response.json()
|
result = response.json()
|
||||||
|
|
||||||
# Extract the response content
|
# Extract the response content
|
||||||
if 'message' in result:
|
if 'choices' in result and len(result['choices']) > 0 and 'message' in result['choices'][0]:
|
||||||
|
# OpenAI-compatible format
|
||||||
|
return result['choices'][0]['message']['content']
|
||||||
|
elif 'message' in result and 'content' in result['message']:
|
||||||
|
# OpenWebUI format
|
||||||
return result['message']['content']
|
return result['message']['content']
|
||||||
else:
|
else:
|
||||||
return "Error: Unexpected response format from OpenWebUI"
|
print(f"WARNING: Unexpected response format from OpenWebUI: {json.dumps(result, indent=2)}")
|
||||||
|
return "Error: Unexpected response format from OpenWebUI. Falling back to direct model call."
|
||||||
|
|
||||||
except requests.exceptions.Timeout as e:
|
except requests.exceptions.Timeout as e:
|
||||||
error_msg = f"Timeout error connecting to OpenWebUI API: {str(e)}. The request exceeded the {self.api_timeout} second timeout."
|
error_msg = f"Timeout error connecting to OpenWebUI API: {str(e)}. The request exceeded the {self.api_timeout} second timeout."
|
||||||
@@ -216,6 +222,11 @@ class ModelService:
|
|||||||
print(f"ERROR: {error_msg}")
|
print(f"ERROR: {error_msg}")
|
||||||
print("Falling back to direct Ollama call without RAG")
|
print("Falling back to direct Ollama call without RAG")
|
||||||
# Continue to the Ollama API call below
|
# Continue to the Ollama API call below
|
||||||
|
except requests.exceptions.HTTPError as e:
|
||||||
|
error_msg = f"HTTP error from OpenWebUI API: {str(e)}."
|
||||||
|
print(f"ERROR: {error_msg}")
|
||||||
|
print("Falling back to direct Ollama call without RAG")
|
||||||
|
# Continue to the Ollama API call below
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error calling OpenWebUI API: {str(e)}"
|
error_msg = f"Error calling OpenWebUI API: {str(e)}"
|
||||||
print(f"ERROR: {error_msg}")
|
print(f"ERROR: {error_msg}")
|
||||||
@@ -247,6 +258,8 @@ class ModelService:
|
|||||||
request_json['top_k'] = params['top_k']
|
request_json['top_k'] = params['top_k']
|
||||||
if 'max_tokens' in params:
|
if 'max_tokens' in params:
|
||||||
request_json['max_tokens'] = params['max_tokens']
|
request_json['max_tokens'] = params['max_tokens']
|
||||||
|
if 'repeat_penalty' in params:
|
||||||
|
request_json['repeat_penalty'] = params['repeat_penalty']
|
||||||
|
|
||||||
# Make the API call to Ollama
|
# Make the API call to Ollama
|
||||||
try:
|
try:
|
||||||
@@ -272,6 +285,7 @@ class ModelService:
|
|||||||
if 'message' in result and 'content' in result['message']:
|
if 'message' in result and 'content' in result['message']:
|
||||||
return result['message']['content']
|
return result['message']['content']
|
||||||
else:
|
else:
|
||||||
|
print(f"WARNING: Unexpected response format from Ollama: {json.dumps(result, indent=2)}")
|
||||||
return "Error: Unexpected response format from Ollama"
|
return "Error: Unexpected response format from Ollama"
|
||||||
|
|
||||||
except requests.exceptions.Timeout as e:
|
except requests.exceptions.Timeout as e:
|
||||||
|
|||||||
@@ -71,10 +71,16 @@ async def chat_completions(request: Request):
|
|||||||
temperature = body.get("temperature")
|
temperature = body.get("temperature")
|
||||||
max_tokens = body.get("max_tokens")
|
max_tokens = body.get("max_tokens")
|
||||||
top_p = body.get("top_p")
|
top_p = body.get("top_p")
|
||||||
|
top_k = body.get("top_k")
|
||||||
frequency_penalty = body.get("frequency_penalty")
|
frequency_penalty = body.get("frequency_penalty")
|
||||||
presence_penalty = body.get("presence_penalty")
|
presence_penalty = body.get("presence_penalty")
|
||||||
|
repeat_penalty = body.get("repeat_penalty")
|
||||||
stop = body.get("stop")
|
stop = body.get("stop")
|
||||||
|
|
||||||
|
# Check if RAG should be used
|
||||||
|
use_knowledge = body.get("use_knowledge", False)
|
||||||
|
use_rag = use_knowledge # Map OpenWebUI's use_knowledge to our use_rag parameter
|
||||||
|
|
||||||
# Create a unique chat ID
|
# Create a unique chat ID
|
||||||
chat_id = str(uuid.uuid4())
|
chat_id = str(uuid.uuid4())
|
||||||
|
|
||||||
@@ -99,11 +105,14 @@ async def chat_completions(request: Request):
|
|||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
message=user_message,
|
message=user_message,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
use_rag=use_rag,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
|
repeat_penalty=repeat_penalty,
|
||||||
stop_sequences=stop if isinstance(stop, list) else [stop] if stop else None
|
stop_sequences=stop if isinstance(stop, list) else [stop] if stop else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,34 +0,0 @@
|
|||||||
"""
|
|
||||||
Main application package for the chatbot application.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from flask import Flask
|
|
||||||
|
|
||||||
from app.config.config import Config
|
|
||||||
|
|
||||||
def create_app(config_class=Config):
    """
    Build the Flask application.

    Args:
        config_class: Object whose attributes are loaded as the app config.

    Returns:
        The configured Flask application instance.
    """
    # Imports are kept inside the factory, as in the original layout.
    from app.api import routes as flask_routes
    from app.database import db

    application = Flask(__name__)
    application.config.from_object(config_class)

    # Only the Flask blueprint is wired up; FastAPI integration stays
    # disabled until the integration issues are resolved.
    application.register_blueprint(flask_routes.bp)

    # Bind the SQLAlchemy instance to this application
    db.init_app(application)

    return application
|
|
||||||
-110
@@ -1,110 +0,0 @@
|
|||||||
"""
|
|
||||||
FastAPI routes for the application.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import List, Dict, Any, Optional
|
|
||||||
|
|
||||||
from app.services.chatbot_service import chatbot_service
|
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
class MessageRequest(BaseModel):
    """Body of a request that sends one message to a chat."""

    # Text of the message to send
    message: str
    # Sender identifier; falls back to a shared default when omitted
    user_id: str = "default_user"
|
|
||||||
|
|
||||||
class MessageResponse(BaseModel):
    """A single chat message as returned by the API."""

    # Message text
    content: str
    # Whether the message is flagged as a user message
    is_user: bool
    # Timestamp of the message, as a string
    timestamp: str
|
|
||||||
|
|
||||||
class ChatResponse(BaseModel):
    """A chat together with its messages."""

    # Numeric identifier of the chat
    chat_id: int
    # Messages belonging to the chat
    messages: List[MessageResponse]
|
|
||||||
|
|
||||||
@router.get("/health")
async def health_check():
    """
    Report that the service is up.

    Returns:
        A dict with ``status`` set to "healthy".
    """
    payload = {"status": "healthy"}
    return payload
|
|
||||||
|
|
||||||
@router.post("/chat", response_model=ChatResponse)
async def create_chat(user_id: str = "default_user"):
    """
    Open a new chat for a user.

    Args:
        user_id: ID of the user creating the chat.

    Returns:
        The new chat's ID with an empty message list.
    """
    new_chat_id = chatbot_service.create_chat(user_id)
    return {"chat_id": new_chat_id, "messages": []}
|
|
||||||
|
|
||||||
@router.post("/chat/{chat_id}/message", response_model=MessageResponse)
async def send_message(chat_id: int, request: MessageRequest):
    """
    Send a message to the chatbot and return the latest stored message.

    Args:
        chat_id: ID of the chat.
        request: Message request.

    Returns:
        The last message of the chat after the bot has answered.

    Raises:
        HTTPException: 404 when the chatbot service raises ValueError.
    """
    try:
        # The return value is not used; the reply is read back from the
        # chat's message list (presumably stored by get_response; confirm
        # against chatbot_service).
        chatbot_service.get_response(chat_id, request.message)
        history = chatbot_service.get_chat_messages(chat_id)
        return history[-1]
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
|
|
||||||
|
|
||||||
@router.get("/chat/{chat_id}", response_model=ChatResponse)
async def get_chat(chat_id: int):
    """
    Fetch a chat and its messages.

    Args:
        chat_id: ID of the chat.

    Returns:
        The chat ID together with its messages.

    Raises:
        HTTPException: 404 when the chatbot service raises ValueError.
    """
    try:
        history = chatbot_service.get_chat_messages(chat_id)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))

    return {"chat_id": chat_id, "messages": history}
|
|
||||||
|
|
||||||
def init_app(app):
    """
    Attach this module's router to a FastAPI application.

    Args:
        app: FastAPI application instance.
    """
    # All routes of this module are served under the /api prefix
    app.include_router(router, prefix="/api")
|
|
||||||
@@ -1,100 +0,0 @@
|
|||||||
"""
|
|
||||||
Flask routes for the application.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from flask import Blueprint, jsonify, request, abort
|
|
||||||
|
|
||||||
from app.services.chatbot_service import chatbot_service
|
|
||||||
|
|
||||||
bp = Blueprint('main', __name__)
|
|
||||||
|
|
||||||
@bp.route('/')
def index():
    """
    Root endpoint.

    Returns:
        JSON response with the application name, version and status.
    """
    app_info = {
        'name': 'Chatbot Application',
        'version': '1.0.0',
        'status': 'running'
    }
    return jsonify(app_info)
|
|
||||||
|
|
||||||
@bp.route('/api/health')
def health_check():
    """
    Report that the service is up.

    Returns:
        JSON response with ``status`` set to "healthy".
    """
    return jsonify(status='healthy')
|
|
||||||
|
|
||||||
@bp.route('/api/chat', methods=['POST'])
def create_chat():
    """
    Create a new chat.

    The JSON body is optional. When it is missing, is not JSON, or has no
    ``user_id``, the chat is created for 'default_user'.

    Returns:
        JSON response with the new chat ID and an empty message list.
    """
    # silent=True returns None for a missing or unparsable body instead of
    # raising, so the documented default user can apply.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}

    user_id = payload.get('user_id', 'default_user')
    chat_id = chatbot_service.create_chat(user_id)

    return jsonify({
        'chat_id': chat_id,
        'messages': []
    })
|
|
||||||
|
|
||||||
@bp.route('/api/chat/<int:chat_id>/message', methods=['POST'])
def send_message(chat_id):
    """
    Send a message to the chatbot.

    Args:
        chat_id: ID of the chat.

    Returns:
        JSON response with the last message of the chat after the bot has
        answered. Responds 400 when the body is not a JSON object containing
        ``message``, and 404 when the chatbot service raises ValueError.
    """
    # silent=True yields None for a missing or non-JSON body, so the client
    # gets the 400 below rather than an error raised by the body parser.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'message' not in payload:
        abort(400, description="Message is required")

    try:
        # The return value is not used; the reply is read back from the
        # chat's message list (presumably stored by get_response; confirm
        # against chatbot_service).
        chatbot_service.get_response(chat_id, payload['message'])

        # Get the last message (bot response)
        messages = chatbot_service.get_chat_messages(chat_id)
        return jsonify(messages[-1])

    except ValueError as e:
        abort(404, description=str(e))
|
|
||||||
|
|
||||||
@bp.route('/api/chat/<int:chat_id>', methods=['GET'])
def get_chat(chat_id):
    """
    Fetch a chat and its messages.

    Args:
        chat_id: ID of the chat.

    Returns:
        JSON response with the chat ID and its messages; responds 404 when
        the chatbot service raises ValueError.
    """
    try:
        history = chatbot_service.get_chat_messages(chat_id)
        body = {'chat_id': chat_id, 'messages': history}
        return jsonify(body)
    except ValueError as e:
        abort(404, description=str(e))
|
|
||||||
@@ -1,79 +0,0 @@
|
|||||||
"""
|
|
||||||
Configuration settings for the application.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
# Load environment variables from .env file
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Base configuration."""
|
|
||||||
|
|
||||||
# Flask configuration
|
|
||||||
SECRET_KEY = os.environ.get('SECRET_KEY', 'dev-key-for-development-only')
|
|
||||||
DEBUG = False
|
|
||||||
TESTING = False
|
|
||||||
|
|
||||||
# Database configuration
|
|
||||||
SQLALCHEMY_DATABASE_URI = os.environ.get(
|
|
||||||
'DATABASE_URL',
|
|
||||||
'sqlite:///chatbot.db'
|
|
||||||
)
|
|
||||||
SQLALCHEMY_TRACK_MODIFICATIONS = False
|
|
||||||
INITIALIZE_DATABASE = os.environ.get('INITIALIZE_DATABASE', 'False').lower() == 'true'
|
|
||||||
|
|
||||||
# Pinecone configuration
|
|
||||||
PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY', '')
|
|
||||||
PINECONE_ENVIRONMENT = os.environ.get('PINECONE_ENVIRONMENT', '')
|
|
||||||
PINECONE_INDEX_NAME = os.environ.get('PINECONE_INDEX_NAME', 'chatbot-index')
|
|
||||||
|
|
||||||
# Model configuration
|
|
||||||
DEFAULT_MODEL = os.environ.get('DEFAULT_MODEL', 'gpt-3.5-turbo')
|
|
||||||
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY', '')
|
|
||||||
|
|
||||||
|
|
||||||
class DevelopmentConfig(Config):
|
|
||||||
"""Development configuration."""
|
|
||||||
|
|
||||||
DEBUG = True
|
|
||||||
|
|
||||||
|
|
||||||
class TestingConfig(Config):
|
|
||||||
"""Testing configuration."""
|
|
||||||
|
|
||||||
TESTING = True
|
|
||||||
SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:'
|
|
||||||
|
|
||||||
|
|
||||||
class ProductionConfig(Config):
    """Settings for production deployments."""

    @classmethod
    def init_app(cls, app):
        """Verify that every required environment variable is present.

        Args:
            app: Flask application instance (not used by the check itself).

        Raises:
            RuntimeError: If one or more required variables are unset or
                empty; the message lists them in declaration order.
        """
        missing_vars = []
        for name in (
            'SECRET_KEY',
            'DATABASE_URL',
            'PINECONE_API_KEY',
            'PINECONE_ENVIRONMENT',
            'OPENAI_API_KEY',
        ):
            # An empty string counts as missing, same as an unset variable
            if not os.environ.get(name):
                missing_vars.append(name)

        if not missing_vars:
            return

        raise RuntimeError(
            f"Missing required environment variables: {', '.join(missing_vars)}"
        )
|
|
||||||
|
|
||||||
|
|
||||||
# Maps an environment name to its configuration class; 'default' resolves
# to the development settings.
config = dict(
    development=DevelopmentConfig,
    testing=TestingConfig,
    production=ProductionConfig,
    default=DevelopmentConfig,
)
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
"""
|
|
||||||
Database module for the application.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from flask_sqlalchemy import SQLAlchemy
|
|
||||||
from sqlalchemy import MetaData
|
|
||||||
|
|
||||||
# Name templates for indexes (ix), unique (uq), check (ck), foreign-key (fk)
# and primary-key (pk) constraints
convention = dict(
    ix='ix_%(column_0_label)s',
    uq="uq_%(table_name)s_%(column_0_name)s",
    ck="ck_%(table_name)s_%(constraint_name)s",
    fk="fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    pk="pk_%(table_name)s",
)

# SQLAlchemy instance whose metadata applies the naming convention above
db = SQLAlchemy(metadata=MetaData(naming_convention=convention))
|
|
||||||
|
|
||||||
def init_app(app):
    """Bind the database to a Flask application and optionally create tables.

    Tables are created only when the application's ``INITIALIZE_DATABASE``
    setting is truthy.

    Args:
        app: Flask application instance.
    """
    db.init_app(app)

    # Table creation is opt-in
    if not app.config.get('INITIALIZE_DATABASE', False):
        return

    # Importing the model modules registers their tables with SQLAlchemy
    from app.models import user, chat, document

    # create_all() only adds tables that do not exist yet
    with app.app_context():
        db.create_all()
|
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
"""
|
|
||||||
Chat models for the application.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
from app.database.db import db
|
|
||||||
|
|
||||||
class Chat(db.Model):
    """A chat session owned by a user; may be flagged as a team chat."""

    __tablename__ = 'chats'

    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(100), nullable=True)
    is_team_chat = db.Column(db.Boolean, default=False)
    model_name = db.Column(db.String(50), nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(
        db.DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
    )

    # Owner of the chat
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)

    # Messages and team memberships are removed together with the chat
    messages = db.relationship(
        'Message',
        backref='chat',
        lazy='dynamic',
        cascade='all, delete-orphan',
    )
    team_members = db.relationship(
        'TeamChatMember',
        backref='chat',
        lazy='dynamic',
        cascade='all, delete-orphan',
    )

    def __repr__(self):
        """Return a short debug representation with the id and title."""
        label = self.title or "Untitled"
        return f'<Chat {self.id}: {label}>'
|
|
||||||
|
|
||||||
|
|
||||||
class Message(db.Model):
    """One message inside a chat, sent either by a user or by the bot."""

    __tablename__ = 'messages'

    id = db.Column(db.Integer, primary_key=True)
    content = db.Column(db.Text, nullable=False)
    is_user_message = db.Column(db.Boolean, default=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)

    # Parent chat (required) and author (nullable)
    chat_id = db.Column(db.Integer, db.ForeignKey('chats.id'), nullable=False)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=True)

    def __repr__(self):
        """Return a debug representation with the first 20 characters of content."""
        preview = self.content[:20]
        return f'<Message {self.id}: {preview}...>'
|
|
||||||
|
|
||||||
|
|
||||||
class TeamChatMember(db.Model):
    """Association row linking a user to a team chat."""

    __tablename__ = 'team_chat_members'

    id = db.Column(db.Integer, primary_key=True)
    joined_at = db.Column(db.DateTime, default=datetime.utcnow)

    # Both sides of the membership are mandatory
    chat_id = db.Column(db.Integer, db.ForeignKey('chats.id'), nullable=False)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)

    # A given (chat, user) pair may appear only once
    __table_args__ = (
        db.UniqueConstraint('chat_id', 'user_id', name='uq_team_chat_member'),
    )

    def __repr__(self):
        """Return a debug representation with both foreign keys."""
        return f'<TeamChatMember chat_id={self.chat_id}, user_id={self.user_id}>'
|
|
||||||
@@ -1,59 +0,0 @@
|
|||||||
"""
|
|
||||||
Document models for the application.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
import json
|
|
||||||
from app.database.db import db
|
|
||||||
|
|
||||||
class Document(db.Model):
    """A document in the library, together with its processing status."""

    __tablename__ = 'documents'

    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=True)
    file_path = db.Column(db.String(255), nullable=True)
    content_type = db.Column(db.String(50), nullable=False)
    # One of: pending, processing, completed, error
    status = db.Column(db.String(20), default='pending')
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(
        db.DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
    )

    # User who uploaded the document
    uploaded_by = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)

    # Chunks are removed together with the document
    chunks = db.relationship(
        'DocumentChunk',
        backref='document',
        lazy='dynamic',
        cascade='all, delete-orphan',
    )

    def __repr__(self):
        """Return a short debug representation with the id and title."""
        return f'<Document {self.id}: {self.title}>'
|
|
||||||
|
|
||||||
|
|
||||||
class DocumentChunk(db.Model):
    """A chunk of a document's text, stored for embedding."""

    __tablename__ = 'document_chunks'

    id = db.Column(db.Integer, primary_key=True)
    content = db.Column(db.Text, nullable=False)
    chunk_index = db.Column(db.Integer, nullable=False)
    # Identifier of the matching vector in Pinecone, when one exists
    embedding_id = db.Column(db.String(100), nullable=True)
    # Metadata serialized as a JSON string; use set_metadata()/get_metadata()
    meta_data = db.Column(db.Text, nullable=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)

    # Parent document
    document_id = db.Column(db.Integer, db.ForeignKey('documents.id'), nullable=False)

    def set_metadata(self, metadata_dict):
        """Serialize ``metadata_dict`` to JSON and store it in ``meta_data``."""
        self.meta_data = json.dumps(metadata_dict)

    def get_metadata(self):
        """Return the stored metadata as a dict, or ``{}`` when none is stored."""
        return json.loads(self.meta_data) if self.meta_data else {}

    def __repr__(self):
        """Return a debug representation with the id, document id and index."""
        return f'<DocumentChunk {self.id}: doc_id={self.document_id}, index={self.chunk_index}>'
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
"""
|
|
||||||
User model for the application.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
from app.database.db import db
|
|
||||||
|
|
||||||
class User(db.Model):
    """An application user account."""

    __tablename__ = 'users'

    id = db.Column(db.Integer, primary_key=True)
    # Username and email are both unique and indexed
    username = db.Column(db.String(64), unique=True, nullable=False, index=True)
    email = db.Column(db.String(120), unique=True, nullable=False, index=True)
    # NOTE(review): 128 characters may be too short for some password-hash
    # formats; confirm against the hashing scheme in use.
    password_hash = db.Column(db.String(128), nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(
        db.DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
    )

    # Chats owned by this user (exposed on Chat as ``chat.user``)
    chats = db.relationship('Chat', backref='user', lazy='dynamic')

    def __repr__(self):
        """Return a short debug representation with the username."""
        return f'<User {self.username}>'
|
|
||||||
@@ -1,227 +0,0 @@
|
|||||||
"""
|
|
||||||
Service for chat functionality.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import List, Dict, Any, Optional
|
|
||||||
from app.database.db import db
|
|
||||||
from app.models.chat import Chat, Message, TeamChatMember
|
|
||||||
from app.models.user import User
|
|
||||||
|
|
||||||
class ChatService:
    """Service for chat functionality.

    Groups the persistence operations for chats, their messages and
    team-chat membership on top of the shared SQLAlchemy session.
    """

    def create_chat(self, user_id: int, title: Optional[str] = None,
                    is_team_chat: bool = False, model_name: Optional[str] = None) -> Chat:
        """
        Create a new chat.

        Args:
            user_id: ID of the user creating the chat.
            title: Optional title for the chat.
            is_team_chat: Whether this is a team chat.
            model_name: Name of the model to use for this chat; falls back
                to the configured default model when omitted or empty.

        Returns:
            Created chat.
        """
        # Imported at call time rather than at module import
        from app.config.config import Config

        chat = Chat(
            user_id=user_id,
            title=title,
            is_team_chat=is_team_chat,
            model_name=model_name or Config().DEFAULT_MODEL
        )

        db.session.add(chat)
        db.session.commit()

        # If it's a team chat, add the creator as a member
        if is_team_chat:
            self.add_team_member(chat.id, user_id)

        return chat

    def get_chat(self, chat_id: int) -> Optional[Chat]:
        """
        Get a chat by ID.

        Args:
            chat_id: ID of the chat.

        Returns:
            Chat if found, None otherwise.
        """
        return Chat.query.get(chat_id)

    def get_user_chats(self, user_id: int) -> List[Chat]:
        """
        Get all chats for a user.

        Combines the user's own non-team chats with every chat the user is
        a team member of.

        Args:
            user_id: ID of the user.

        Returns:
            List of chats, most recently updated first.
        """
        # Get private chats
        private_chats = Chat.query.filter_by(
            user_id=user_id,
            is_team_chat=False
        ).order_by(Chat.updated_at.desc()).all()

        # Get team chats where user is a member
        team_chat_ids = db.session.query(TeamChatMember.chat_id).filter_by(user_id=user_id).all()
        team_chat_ids = [chat_id for (chat_id,) in team_chat_ids]

        team_chats = Chat.query.filter(
            Chat.id.in_(team_chat_ids)
        ).order_by(Chat.updated_at.desc()).all()

        # Combine and sort by updated_at
        all_chats = private_chats + team_chats
        all_chats.sort(key=lambda x: x.updated_at, reverse=True)

        return all_chats

    def add_message(self, chat_id: int, content: str,
                    is_user_message: bool = True, user_id: Optional[int] = None) -> Message:
        """
        Add a message to a chat.

        Args:
            chat_id: ID of the chat.
            content: Message content.
            is_user_message: Whether this is a user message (vs. bot message).
            user_id: ID of the user sending the message; stored only for
                user messages.

        Returns:
            Created message.
        """
        message = Message(
            chat_id=chat_id,
            content=content,
            is_user_message=is_user_message,
            user_id=user_id if is_user_message else None
        )

        db.session.add(message)

        # created_at comes from the column default, which is only applied
        # when the row is inserted. Flush first so the value exists before
        # it is copied; otherwise the chat's updated_at could be set to None.
        db.session.flush()

        # Update chat's updated_at timestamp
        chat = Chat.query.get(chat_id)
        if chat:
            chat.updated_at = message.created_at

        db.session.commit()

        return message

    def get_chat_messages(self, chat_id: int) -> List[Message]:
        """
        Get all messages for a chat.

        Args:
            chat_id: ID of the chat.

        Returns:
            List of messages ordered by creation time, oldest first.
        """
        return Message.query.filter_by(chat_id=chat_id).order_by(Message.created_at).all()

    def add_team_member(self, chat_id: int, user_id: int) -> Optional[TeamChatMember]:
        """
        Add a user to a team chat.

        Args:
            chat_id: ID of the team chat.
            user_id: ID of the user to add.

        Returns:
            The membership row (the existing one if the user is already a
            member), or None when the chat does not exist or is not a
            team chat.
        """
        chat = Chat.query.get(chat_id)
        if not chat or not chat.is_team_chat:
            return None

        # Check if user is already a member
        existing_member = TeamChatMember.query.filter_by(
            chat_id=chat_id,
            user_id=user_id
        ).first()

        if existing_member:
            return existing_member

        member = TeamChatMember(
            chat_id=chat_id,
            user_id=user_id
        )

        db.session.add(member)
        db.session.commit()

        return member

    def get_team_members(self, chat_id: int) -> List[User]:
        """
        Get all members of a team chat.

        Args:
            chat_id: ID of the team chat.

        Returns:
            List of users.
        """
        member_ids = db.session.query(TeamChatMember.user_id).filter_by(chat_id=chat_id).all()
        member_ids = [user_id for (user_id,) in member_ids]

        return User.query.filter(User.id.in_(member_ids)).all()

    def remove_team_member(self, chat_id: int, user_id: int) -> bool:
        """
        Remove a user from a team chat.

        Args:
            chat_id: ID of the team chat.
            user_id: ID of the user to remove.

        Returns:
            True if a membership was removed, False if none existed.
        """
        member = TeamChatMember.query.filter_by(
            chat_id=chat_id,
            user_id=user_id
        ).first()

        if not member:
            return False

        db.session.delete(member)
        db.session.commit()

        return True

    def delete_chat(self, chat_id: int) -> bool:
        """
        Delete a chat and all its messages.

        Args:
            chat_id: ID of the chat to delete.

        Returns:
            True if deletion was successful, False if the chat does not
            exist or the deletion failed (the session is rolled back).
        """
        chat = Chat.query.get(chat_id)
        if not chat:
            return False

        try:
            db.session.delete(chat)
            db.session.commit()
            return True

        except Exception as e:
            # Log the error
            print(f"Error deleting chat {chat_id}: {str(e)}")
            db.session.rollback()
            return False
|
|
||||||
@@ -1,105 +0,0 @@
|
|||||||
"""
|
|
||||||
Service for chatbot functionality without database dependency.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import List, Dict, Any, Optional
|
|
||||||
|
|
||||||
class ChatbotService:
    """In-memory chatbot service; all state lives on the instance."""

    def __init__(self):
        """Start with no chats and the id counter at zero."""
        # chat id -> {'user_id': ..., 'messages': [...]}
        self.chat_history = {}
        self.current_chat_id = 0

    def create_chat(self, user_id: str) -> int:
        """
        Open a new chat session for a user.

        Args:
            user_id: ID of the user creating the chat.

        Returns:
            ID of the created chat (ids count up from 1).
        """
        chat_id = self.current_chat_id + 1
        self.current_chat_id = chat_id
        self.chat_history[chat_id] = {'user_id': user_id, 'messages': []}
        return chat_id

    def add_message(self, chat_id: int, content: str, is_user: bool = True) -> Dict[str, Any]:
        """
        Append a message to a chat.

        Args:
            chat_id: ID of the chat.
            content: Message content.
            is_user: Whether this is a user message (vs. bot message).

        Returns:
            The stored message dict (content, is_user, timestamp).

        Raises:
            ValueError: If no chat with ``chat_id`` exists.
        """
        chat = self.chat_history.get(chat_id)
        if chat is None:
            raise ValueError(f"Chat with ID {chat_id} not found")

        message = {
            'content': content,
            'is_user': is_user,
            'timestamp': self._get_timestamp(),
        }
        chat['messages'].append(message)
        return message

    def get_chat_messages(self, chat_id: int) -> List[Dict[str, Any]]:
        """
        Return the messages of a chat.

        Args:
            chat_id: ID of the chat.

        Returns:
            The chat's message list (the stored list itself, not a copy).

        Raises:
            ValueError: If no chat with ``chat_id`` exists.
        """
        chat = self.chat_history.get(chat_id)
        if chat is None:
            raise ValueError(f"Chat with ID {chat_id} not found")
        return chat['messages']

    def get_response(self, chat_id: int, message: str) -> str:
        """
        Record a user message and produce the bot's reply.

        Args:
            chat_id: ID of the chat.
            message: User message.

        Returns:
            Bot response (currently an echo of the user message).

        Raises:
            ValueError: If no chat with ``chat_id`` exists.
        """
        # Store the incoming message first
        self.add_message(chat_id, message, is_user=True)

        # Simple echo response for now
        reply = f"You said: {message}"

        # Store the reply as a bot message
        self.add_message(chat_id, reply, is_user=False)
        return reply

    def _get_timestamp(self) -> str:
        """Return the current UTC time as an ISO 8601 string."""
        from datetime import datetime
        return datetime.utcnow().isoformat()


# Module-level instance of the service
chatbot_service = ChatbotService()
|
|
||||||
@@ -1,165 +0,0 @@
|
|||||||
"""
|
|
||||||
Service for document processing and embedding.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import List, Dict, Any, Optional
|
|
||||||
import pinecone
|
|
||||||
from app.database.db import db
|
|
||||||
from app.models.document import Document, DocumentChunk
|
|
||||||
from app.config.config import Config
|
|
||||||
|
|
||||||
class DocumentService:
    """Service for document processing and embedding.

    Keeps document records in the database and their vectors in a
    Pinecone index.
    """

    def __init__(self, config: Optional[Config] = None):
        """
        Initialize the document service.

        Args:
            config: Configuration object; a fresh ``Config`` is used when
                omitted.
        """
        self.config = config or Config()
        self._initialize_pinecone()

    def _initialize_pinecone(self):
        """Initialize the Pinecone client and open (or create) the index."""
        pinecone.init(
            api_key=self.config.PINECONE_API_KEY,
            environment=self.config.PINECONE_ENVIRONMENT
        )

        # Check if index exists, create if it doesn't
        if self.config.PINECONE_INDEX_NAME not in pinecone.list_indexes():
            pinecone.create_index(
                name=self.config.PINECONE_INDEX_NAME,
                dimension=768,  # Default dimension for sentence-transformers
                metric="cosine"
            )

        self.index = pinecone.Index(self.config.PINECONE_INDEX_NAME)

    def create_document(self, title: str, file_path: str, content_type: str,
                        description: Optional[str], user_id: int) -> Document:
        """
        Create a new document record.

        Args:
            title: Document title.
            file_path: Path to the document file.
            content_type: MIME type of the document.
            description: Optional description of the document.
            user_id: ID of the user who uploaded the document.

        Returns:
            Created document, with status 'pending'.
        """
        document = Document(
            title=title,
            file_path=file_path,
            content_type=content_type,
            description=description,
            uploaded_by=user_id,
            status='pending'
        )

        db.session.add(document)
        db.session.commit()

        return document

    def process_document(self, document_id: int) -> bool:
        """
        Process a document for embedding.

        Args:
            document_id: ID of the document to process.

        Returns:
            True if processing was successful, False if the document does
            not exist or processing failed (status is then set to 'error').
        """
        document = Document.query.get(document_id)
        if not document:
            return False

        try:
            # Update status to processing
            document.status = 'processing'
            db.session.commit()

            # TODO: Implement document parsing and chunking
            # This will be implemented in the next step

            # Update status to completed
            document.status = 'completed'
            db.session.commit()
            return True

        except Exception as e:
            # The failure may have come from the session itself; roll back
            # first so the commit of the error status below can succeed.
            db.session.rollback()
            # Update status to error
            document.status = 'error'
            db.session.commit()
            # Log the error
            print(f"Error processing document {document_id}: {str(e)}")
            return False

    def get_document(self, document_id: int) -> Optional[Document]:
        """
        Get a document by ID.

        Args:
            document_id: ID of the document.

        Returns:
            Document if found, None otherwise.
        """
        return Document.query.get(document_id)

    def get_all_documents(self, user_id: Optional[int] = None) -> List[Document]:
        """
        Get all documents, optionally filtered by user.

        Args:
            user_id: Optional user ID to filter by; None means no filter.

        Returns:
            List of documents, newest first.
        """
        query = Document.query
        # Compare against None so that an ID of 0 still applies the filter
        if user_id is not None:
            query = query.filter_by(uploaded_by=user_id)
        return query.order_by(Document.created_at.desc()).all()

    def delete_document(self, document_id: int) -> bool:
        """
        Delete a document and its chunks.

        Args:
            document_id: ID of the document to delete.

        Returns:
            True if deletion was successful, False if the document does
            not exist or the deletion failed (the session is rolled back).
        """
        document = Document.query.get(document_id)
        if not document:
            return False

        try:
            # Delete document chunks from Pinecone
            chunks = DocumentChunk.query.filter_by(document_id=document_id).all()
            embedding_ids = [chunk.embedding_id for chunk in chunks if chunk.embedding_id]

            if embedding_ids:
                self.index.delete(ids=embedding_ids)

            # Delete document from database
            db.session.delete(document)
            db.session.commit()

            return True

        except Exception as e:
            # Log the error
            print(f"Error deleting document {document_id}: {str(e)}")
            db.session.rollback()
            return False
|
|
||||||
@@ -1,95 +0,0 @@
|
|||||||
"""
|
|
||||||
Service for model management and interaction.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import List, Dict, Any, Optional
|
|
||||||
from app.config.config import Config
|
|
||||||
|
|
||||||
class ModelService:
    """Service for model management and interaction."""

    # Available models, keyed by model ID
    AVAILABLE_MODELS = {
        'gpt-3.5-turbo': {
            'name': 'GPT-3.5 Turbo',
            'description': 'OpenAI GPT-3.5 Turbo model',
            'provider': 'openai',
            'max_tokens': 4096
        },
        'gpt-4': {
            'name': 'GPT-4',
            'description': 'OpenAI GPT-4 model',
            'provider': 'openai',
            'max_tokens': 8192
        },
        # Add more models as needed
    }

    def __init__(self, config: Optional["Config"] = None):
        """
        Initialize the model service.

        Args:
            config: Configuration object; a fresh ``Config`` is used when
                omitted.
        """
        self.config = config or Config()
        self.default_model = self.config.DEFAULT_MODEL

    def get_available_models(self) -> List[Dict[str, Any]]:
        """
        Get a list of available models.

        Returns:
            List of model information dictionaries, each carrying the
            model's ``id`` and an ``is_default`` flag in addition to its
            catalogue entry.
        """
        return [
            {
                'id': model_id,
                'is_default': model_id == self.default_model,
                **model_info
            }
            for model_id, model_info in self.AVAILABLE_MODELS.items()
        ]

    def get_model_info(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get information about a specific model.

        Args:
            model_id: ID of the model.

        Returns:
            Model information dictionary if found, None otherwise.
        """
        if model_id not in self.AVAILABLE_MODELS:
            return None

        return {
            'id': model_id,
            'is_default': model_id == self.default_model,
            **self.AVAILABLE_MODELS[model_id]
        }

    def generate_response(self, model_id: str, prompt: str,
                          context: Optional[List[Dict[str, str]]] = None) -> str:
        """
        Generate a response from the model.

        Args:
            model_id: ID of the model to use; unknown IDs fall back to the
                default model.
            prompt: User prompt (not used by the placeholder yet).
            context: Optional conversation context (not used yet).

        Returns:
            Generated response.
        """
        # TODO: Implement actual model integration
        # This is a placeholder that will be implemented in the next steps

        if model_id not in self.AVAILABLE_MODELS:
            model_id = self.default_model

        # The configured default may itself be missing from the catalogue
        # (e.g. DEFAULT_MODEL set to a custom value); use the raw ID as the
        # display name instead of raising KeyError.
        model_info = self.AVAILABLE_MODELS.get(model_id)
        model_name = model_info['name'] if model_info else model_id

        # Placeholder response
        return f"This is a placeholder response from {model_name}. The actual model integration will be implemented in the next steps."
|
|
||||||
Executable
+22
@@ -0,0 +1,22 @@
|
|||||||
|
#!/bin/bash
# Script to run the RAG test in a Python virtual environment.
# Requires an activated virtual environment; runs test_rag.py against the
# remote server in verbose mode.

# When test_rag.py is not in the current directory, fall back to the
# directory this script lives in so it can be launched from anywhere.
if [[ ! -f test_rag.py ]]; then
    cd "$(dirname "${BASH_SOURCE[0]}")" || exit 1
fi

# Make the script executable
chmod +x test_rag.py

# Check if virtual environment is activated; diagnostics go to stderr
if [[ -z "$VIRTUAL_ENV" ]]; then
    echo "Virtual environment is not activated." >&2
    echo "Please activate your virtual environment first with:" >&2
    echo "source /path/to/your/venv/bin/activate" >&2
    exit 1
fi

echo "Running RAG tests against remote server..."
python test_rag.py --remote --verbose

# If you want to test against a different server, uncomment and modify this line:
# python test_rag.py --api-url "http://your-server-url:port" --verbose

# If you want to test with a specific query, uncomment and modify this line:
# python test_rag.py --remote --query "What information do you have about project X?" --verbose
|
||||||
-144
@@ -1,144 +0,0 @@
|
|||||||
"""
|
|
||||||
Simple API for testing deployment.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import List, Dict, Any, Optional
|
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
# Create FastAPI app
app = FastAPI(
    title="Simple AI Service API",
    description="Simple API for testing deployment",
    version="1.0.0"
)

# Allowed CORS origins: a comma-separated list in CORS_ALLOW_ORIGINS, or
# every origin ("*") when the variable is unset or empty.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# Add CORS middleware
# NOTE(review): a wildcard origin combined with allow_credentials=True lets
# any site make credentialed requests; set CORS_ALLOW_ORIGINS to explicit
# origins for anything beyond deployment testing.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
||||||
|
|
||||||
# Define API models
|
|
||||||
class MessageRequest(BaseModel):
    """Request body for sending a message to a chat.

    ``message`` and ``user_id`` are required; every model parameter is
    optional and defaults to ``None``.
    """

    message: str = Field(default=..., description="Message content")
    user_id: str = Field(default=..., description="User ID")

    # Optional generation parameters
    temperature: Optional[float] = Field(
        default=None, description="Controls randomness"
    )
    max_tokens: Optional[int] = Field(
        default=None, description="Maximum tokens to generate"
    )
    top_p: Optional[float] = Field(
        default=None, description="Nucleus sampling parameter"
    )
    frequency_penalty: Optional[float] = Field(
        default=None, description="Penalizes repeated tokens"
    )
    presence_penalty: Optional[float] = Field(
        default=None, description="Penalizes repeated topics"
    )
    system_prompt: Optional[str] = Field(
        default=None, description="System prompt"
    )
|
|
||||||
|
|
||||||
class Message(BaseModel):
    """A single chat message as returned by the API.

    All fields are required except ``user_id``, which defaults to ``None``.
    """

    id: str = Field(default=..., description="Message ID")
    content: str = Field(default=..., description="Message content")
    user_id: Optional[str] = Field(default=None, description="User ID")
    is_user_message: bool = Field(
        default=..., description="Whether this is a user message"
    )
    timestamp: str = Field(default=..., description="Message timestamp")
|
|
||||||
|
|
||||||
class ChatRequest(BaseModel):
    """Request body for creating a chat; only ``user_id`` is required."""

    user_id: str = Field(default=..., description="User ID")
    title: Optional[str] = Field(default=None, description="Chat title")
    model_id: Optional[str] = Field(default=None, description="Model ID")
|
|
||||||
|
|
||||||
class Chat(BaseModel):
    """A chat as returned by the API, including its messages."""

    id: str = Field(default=..., description="Chat ID")
    title: str = Field(default=..., description="Chat title")
    user_id: str = Field(default=..., description="User ID")
    model_id: str = Field(default=..., description="Model ID")
    created_at: str = Field(default=..., description="Creation timestamp")
    updated_at: str = Field(default=..., description="Update timestamp")
    # Empty list when no messages are supplied
    messages: List[Message] = Field(
        default_factory=list, description="Chat messages"
    )
|
|
||||||
|
|
||||||
# In-memory chat store, keyed by chat ID
chats: Dict[str, Dict[str, Any]] = {}
|
|
||||||
|
|
||||||
# API endpoints
|
|
||||||
@app.get("/health")
async def health_check():
    """Report service liveness with a fixed ``{"status": "healthy"}`` payload."""
    return dict(status="healthy")
|
|
||||||
|
|
||||||
@app.post("/chats", response_model=Chat)
async def create_chat(request: ChatRequest):
    """Create a new chat."""
    chat_id = str(uuid.uuid4())

    # Take a single timestamp so a freshly created chat always has
    # created_at == updated_at (two separate utcnow() calls could differ).
    now = datetime.utcnow().isoformat()

    chat = {
        "id": chat_id,
        # Fall back to a sequential title based on how many chats exist.
        "title": request.title or f"Chat {len(chats) + 1}",
        "user_id": request.user_id,
        "model_id": request.model_id or "gpt-3.5-turbo",
        "created_at": now,
        "updated_at": now,
        "messages": [],
    }

    chats[chat_id] = chat
    return chat
|
|
||||||
|
|
||||||
@app.get("/chats/{chat_id}", response_model=Chat)
async def get_chat(chat_id: str):
    """Get a chat by ID."""
    # Single lookup; stored records are always dicts, so None means "absent".
    found = chats.get(chat_id)
    if found is None:
        raise HTTPException(status_code=404, detail="Chat not found")
    return found
|
|
||||||
|
|
||||||
@app.post("/chats/{chat_id}/messages", response_model=Message)
async def send_message(chat_id: str, request: MessageRequest):
    """Send a message to a chat."""
    chat = chats.get(chat_id)
    if chat is None:
        raise HTTPException(status_code=404, detail="Chat not found")

    # Record the user's message first so the history keeps request order.
    user_message = {
        "id": str(uuid.uuid4()),
        "content": request.message,
        "user_id": request.user_id,
        "is_user_message": True,
        "timestamp": datetime.utcnow().isoformat(),
    }
    chat["messages"].append(user_message)

    # Build the canned bot reply. Only these three parameters are echoed back;
    # each suffix keeps its leading space so the joined text is unchanged.
    echoed = []
    if request.temperature is not None:
        echoed.append(f" (temperature={request.temperature})")
    if request.max_tokens is not None:
        echoed.append(f" (max_tokens={request.max_tokens})")
    if request.system_prompt is not None:
        echoed.append(" (using custom system prompt)")
    params_text = "".join(echoed)

    bot_message = {
        "id": str(uuid.uuid4()),
        "content": f"This is a test response to: '{request.message}'{params_text}",
        "user_id": None,
        "is_user_message": False,
        "timestamp": datetime.utcnow().isoformat(),
    }
    chat["messages"].append(bot_message)
    chat["updated_at"] = datetime.utcnow().isoformat()

    # Only the bot's reply is returned; the user message lives in the history.
    return bot_message
|
|
||||||
|
|
||||||
# Start the server only when this file is executed directly.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when running as a script.
    import uvicorn
    # NOTE(review): 0.0.0.0 listens on all interfaces — confirm this is intended.
    uvicorn.run(app, host="0.0.0.0", port=5251)
|
|
||||||
@@ -1,69 +0,0 @@
|
|||||||
"""
|
|
||||||
Test script for sending a message with advanced parameters.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import requests
|
|
||||||
import json
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
# Create a new chat
|
|
||||||
def create_chat(base_url="http://localhost:5251", timeout=30):
    """Create a test chat on the service and return its ID.

    Args:
        base_url: Root URL of the service (no trailing slash).
        timeout: Seconds to wait for the HTTP request; without a timeout
            ``requests`` would block indefinitely on an unresponsive server.

    Returns:
        The new chat's ID on HTTP 200, otherwise ``None`` (after printing
        the status code and response body).
    """
    response = requests.post(
        f"{base_url}/chats",
        json={
            "user_id": "test_user",
            "title": "Test Chat",
            "model_id": "gpt-3.5-turbo",
            "is_team_chat": False,
        },
        timeout=timeout,
    )

    if response.status_code == 200:
        return response.json()["id"]

    print(f"Error creating chat: {response.status_code}")
    print(response.text)
    return None
|
|
||||||
|
|
||||||
# Send a message with advanced parameters
|
|
||||||
def send_message(chat_id, base_url="http://localhost:5251", timeout=60):
    """Send a message with advanced generation parameters to a chat.

    Args:
        chat_id: ID of the chat to post the message to.
        base_url: Root URL of the service (no trailing slash).
        timeout: Seconds to wait for the HTTP request; without a timeout
            ``requests`` would block indefinitely on an unresponsive server.

    Returns:
        The decoded JSON reply on HTTP 200, otherwise ``None`` (after
        printing the status code and response body).
    """
    response = requests.post(
        f"{base_url}/chats/{chat_id}/messages",
        json={
            "message": "Tell me about artificial intelligence",
            "user_id": "test_user",
            "use_rag": False,
            "temperature": 0.7,
            "max_tokens": 500,
            "top_p": 0.9,
            "frequency_penalty": 0.5,
            "presence_penalty": 0.5,
            "system_prompt": "You are a helpful AI assistant that provides concise responses.",
        },
        timeout=timeout,
    )

    if response.status_code == 200:
        return response.json()

    print(f"Error sending message: {response.status_code}")
    print(response.text)
    return None
|
|
||||||
|
|
||||||
# Main function
|
|
||||||
def main():
    """Create a chat, send one parameterised message, and print the reply."""
    print("Creating a new chat...")
    new_chat_id = create_chat()
    if not new_chat_id:
        # create_chat already reported the failure.
        return

    print(f"Chat created with ID: {new_chat_id}")
    print("\nSending a message with advanced parameters...")

    reply = send_message(new_chat_id)
    if not reply:
        # send_message already reported the failure.
        return

    print("\nResponse received:")
    print(f"Message ID: {reply['id']}")
    print(f"Content: {reply['content']}")
|
|
||||||
|
|
||||||
# Run the scenario only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
||||||
Executable
+309
@@ -0,0 +1,309 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""
|
||||||
|
Test script for RAG functionality in the AI service.
|
||||||
|
This script tests the document-based question answering capabilities
|
||||||
|
by making requests to the API endpoints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import requests
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
# Default configuration
# Base URL used when --api-url is not given.
DEFAULT_API_URL = "http://localhost:5252"  # Local development server
# Base URL selected by the --remote flag.
DEFAULT_REMOTE_API_URL = "http://157.157.221.29:5251"  # Remote server
# Model ID sent when the tester creates a chat.
DEFAULT_MODEL = "llama3.1"
|
||||||
|
|
||||||
|
class RAGTester:
    """Test the RAG functionality of the AI service.

    Each ``check_*`` / ``test_*`` method issues one or two HTTP requests
    against the service, prints a human-readable report, and returns a
    success indicator. Failures are reported, never raised, so that
    ``run_all_tests`` can always run the full sequence.
    """

    def __init__(self, api_url: str, verbose: bool = False):
        """
        Initialize the RAG tester.

        Args:
            api_url: URL of the AI service API.
            verbose: Whether to print verbose output.
        """
        # Strip trailing slashes so f"{api_url}/health" never yields "//health".
        self.api_url = api_url.rstrip("/")
        self.verbose = verbose
        # One session for all requests so connections are reused.
        self.session = requests.Session()

        # Print configuration
        print(f"Testing RAG functionality against API at: {self.api_url}")

    def _log(self, message: str):
        """Log a message if verbose mode is enabled."""
        if self.verbose:
            print(f"[DEBUG] {message}")

    @staticmethod
    def _preview(text: Optional[str], limit: int) -> str:
        """Return the first ``limit`` characters of ``text`` followed by "...".

        A missing or empty value yields "<empty>" instead of raising, so an
        otherwise successful response with no body text is still reported.
        """
        if not text:
            return "<empty>"
        return f"{str(text)[:limit]}..."

    def check_server_health(self) -> bool:
        """
        Check if the server is healthy.

        Returns:
            True if the server is healthy, False otherwise.
        """
        try:
            response = self.session.get(f"{self.api_url}/health", timeout=10)
            response.raise_for_status()
            result = response.json()
        except Exception as e:  # diagnostic boundary: report, never crash
            print(f"❌ Error checking server health: {str(e)}")
            return False

        if result.get("status") == "healthy":
            print("✅ Server is healthy")
            return True

        print("❌ Server health check failed")
        print(f"Response: {result}")
        return False

    def check_config(self) -> Dict[str, Any]:
        """
        Check the server configuration.

        Returns:
            Server configuration, or an empty dict if it could not be read.
        """
        try:
            response = self.session.get(f"{self.api_url}/config", timeout=10)
            response.raise_for_status()
            config = response.json()

            # Tolerate a null/missing model list and non-string entries.
            models = config.get("available_models") or []

            print("✅ Server configuration:")
            print(f" - API Host: {config.get('api_host')}")
            print(f" - API Port: {config.get('api_port')}")
            print(f" - OpenWebUI URL: {config.get('openwebui_url')}")
            print(f" - Ollama API URL: {config.get('ollama_api_url')}")
            print(f" - Default Model: {config.get('default_model')}")
            print(f" - API Timeout: {config.get('api_timeout')} seconds")
            print(f" - Available Models: {', '.join(str(m) for m in models)}")

            return config
        except Exception as e:  # diagnostic boundary: report, never crash
            print(f"❌ Error checking server configuration: {str(e)}")
            return {}

    def test_ollama_connection(self) -> bool:
        """
        Test the connection to Ollama.

        Returns:
            True if the connection is successful, False otherwise.
        """
        try:
            response = self.session.get(f"{self.api_url}/test-ollama", timeout=30)
            response.raise_for_status()
            result = response.json()
        except Exception as e:  # diagnostic boundary: report, never crash
            print(f"❌ Error testing Ollama connection: {str(e)}")
            return False

        if result.get("status") == "success":
            # The payload is expected to look like {"models": {"models": [...]}};
            # fall back to an empty list for any other shape.
            models = result.get("models") or {}
            if isinstance(models, dict):
                models = models.get("models") or []
            print("✅ Successfully connected to Ollama API")
            print(f" - Ollama URL: {result.get('ollama_url')}")
            print(f" - Available models: {len(models)}")
            return True

        print("❌ Failed to connect to Ollama API")
        print(f" - Error: {result.get('message')}")
        print(f" - Ollama URL: {result.get('ollama_url')}")
        return False

    def test_chat_completion(self, model_id: Optional[str] = None) -> bool:
        """
        Test basic chat completion without RAG.

        Args:
            model_id: Optional model ID to use.
                NOTE(review): currently not forwarded to the endpoint; the
                parameter /test-chat expects is not known here — confirm
                before wiring it through.

        Returns:
            True if the test is successful, False otherwise.
        """
        try:
            response = self.session.post(f"{self.api_url}/test-chat", timeout=60)
            response.raise_for_status()
            result = response.json()
        except Exception as e:  # diagnostic boundary: report, never crash
            print(f"❌ Error testing chat completion: {str(e)}")
            return False

        if result.get("status") == "success":
            print("✅ Successfully tested chat completion")
            print(f" - Model: {result.get('model')}")
            print(f" - Response: {self._preview(result.get('response'), 100)}")  # First 100 chars
            return True

        print("❌ Failed to test chat completion")
        print(f" - Error: {result.get('message')}")
        return False

    def test_rag_completion(self, query: str = "What information do you have in your knowledge database?") -> bool:
        """
        Test RAG completion with a query.

        Args:
            query: Query to test with RAG.

        Returns:
            True if the test is successful, False otherwise.
        """
        try:
            response = self.session.post(
                f"{self.api_url}/test-rag",
                params={"query": query},
                timeout=120,  # Longer timeout for RAG
            )
            response.raise_for_status()
            result = response.json()
        except Exception as e:  # diagnostic boundary: report, never crash
            print(f"❌ Error testing RAG completion: {str(e)}")
            return False

        if result.get("status") == "success":
            print("✅ Successfully tested RAG completion")
            print(f" - Model: {result.get('model')}")
            print(f" - Query: {result.get('query')}")
            print(f" - OpenWebUI URL: {result.get('openwebui_url')}")
            print(f" - Response: {self._preview(result.get('response'), 150)}")  # First 150 chars

            # Print full response in verbose mode
            if self.verbose:
                print("\nFull response:")
                print(result.get('response'))

            return True

        print("❌ Failed to test RAG completion")
        print(f" - Error: {result.get('message')}")
        print(f" - OpenWebUI URL: {result.get('openwebui_url')}")
        return False

    def create_chat_and_test_rag(self, query: str, user_id: str = "test_user") -> bool:
        """
        Create a chat and test RAG with a message.

        Args:
            query: Query to test with RAG.
            user_id: User ID for the chat.

        Returns:
            True if the test is successful, False otherwise.
        """
        try:
            # Create a chat
            create_response = self.session.post(
                f"{self.api_url}/chats",
                json={
                    "user_id": user_id,
                    "title": "RAG Test Chat",
                    "model_id": DEFAULT_MODEL,
                    "is_team_chat": False,
                },
                timeout=30,
            )
            create_response.raise_for_status()
            chat_id = create_response.json().get("id")

            # Without an ID the follow-up request would target /chats/None/...
            if not chat_id:
                print("❌ Error testing chat with RAG: chat creation returned no ID")
                return False

            print(f"✅ Created chat with ID: {chat_id}")

            # Send a message with RAG enabled
            message_response = self.session.post(
                f"{self.api_url}/chats/{chat_id}/messages",
                json={
                    "message": query,
                    "user_id": user_id,
                    "use_rag": True,
                    "temperature": 0.7,
                    "max_tokens": 1000,
                },
                timeout=120,  # Longer timeout for RAG
            )
            message_response.raise_for_status()
            message = message_response.json()

            print("✅ Successfully sent message with RAG")
            print(f" - Message ID: {message.get('id')}")
            print(f" - Response: {self._preview(message.get('content'), 150)}")  # First 150 chars

            # Print full response in verbose mode
            if self.verbose:
                print("\nFull response:")
                print(message.get('content'))

            return True
        except Exception as e:  # diagnostic boundary: report, never crash
            print(f"❌ Error testing chat with RAG: {str(e)}")
            return False

    def run_all_tests(self, query: str = "What information do you have in your knowledge database?") -> bool:
        """
        Run all tests.

        Only a failed health check aborts the run; every later failure is
        reported and the remaining tests still execute.

        Args:
            query: Query to test with RAG.

        Returns:
            True if every step succeeded, False otherwise.
        """
        print("\n=== Running RAG Functionality Tests ===\n")

        # Check server health
        if not self.check_server_health():
            print("❌ Server health check failed. Aborting tests.")
            return False

        # Check configuration
        config_ok = bool(self.check_config())
        if not config_ok:
            print("❌ Failed to get server configuration. Continuing with tests...")

        # Test Ollama connection
        ollama_ok = self.test_ollama_connection()
        if not ollama_ok:
            print("⚠️ Ollama connection test failed. Some tests may fail.")

        # Test basic chat completion
        chat_ok = self.test_chat_completion()
        if not chat_ok:
            print("⚠️ Basic chat completion test failed. RAG tests may also fail.")

        # Test RAG completion
        print("\n--- Testing RAG Completion ---\n")
        rag_ok = self.test_rag_completion(query)

        # Test chat with RAG
        print("\n--- Testing Chat with RAG ---\n")
        chat_rag_ok = self.create_chat_and_test_rag(query)

        print("\n=== Tests Completed ===\n")
        return all((config_ok, ollama_ok, chat_ok, rag_ok, chat_rag_ok))
|
||||||
|
|
||||||
|
print("\n=== Tests Completed ===\n")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse command-line options and run the full RAG test sequence."""
    cli = argparse.ArgumentParser(description="Test RAG functionality in the AI service")
    cli.add_argument(
        "--api-url",
        default=DEFAULT_API_URL,
        help=f"URL of the AI service API (default: {DEFAULT_API_URL})",
    )
    cli.add_argument(
        "--remote",
        action="store_true",
        help=f"Use the remote API URL ({DEFAULT_REMOTE_API_URL})",
    )
    cli.add_argument(
        "--query",
        default="What information do you have in your knowledge database?",
        help="Query to test with RAG",
    )
    cli.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Enable verbose output",
    )
    options = cli.parse_args()

    # --remote takes precedence over whatever --api-url was given.
    if options.remote:
        target_url = DEFAULT_REMOTE_API_URL
    else:
        target_url = options.api_url

    # Create and run the tester
    RAGTester(api_url=target_url, verbose=options.verbose).run_all_tests(query=options.query)
|
||||||
|
|
||||||
|
|
||||||
|
# Run the test suite only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
requests>=2.28.0
|
||||||
|
pydantic>=2.0.0
|
||||||
Reference in New Issue
Block a user