Added Rag Featured
This commit is contained in:
@@ -1,6 +1,11 @@
|
||||
"""
|
||||
FastAPI application for the AI service.
|
||||
This service acts as a backend for OpenWebUI, providing OpenWebUI-compatible API endpoints.
|
||||
|
||||
The service supports document-based question answering using OpenWebUI's knowledge database:
|
||||
- Set use_rag=True in API requests to enable Retrieval Augmented Generation
|
||||
- When enabled, the service will use OpenWebUI's knowledge database to find relevant information
|
||||
- Documents uploaded to OpenWebUI will be used to augment the model's responses
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
@@ -203,6 +208,43 @@ async def test_chat_completion():
|
||||
"ollama_url": config.OLLAMA_API_URL
|
||||
}
|
||||
|
||||
@app.post("/test-rag")
|
||||
async def test_rag_completion(query: str = "What information do you have in your knowledge database?"):
|
||||
"""
|
||||
Test the RAG (Retrieval Augmented Generation) functionality with a query.
|
||||
|
||||
This endpoint tests the integration with OpenWebUI's knowledge database.
|
||||
|
||||
Args:
|
||||
query: The question to ask about documents in the knowledge database.
|
||||
|
||||
Returns:
|
||||
Model response using RAG.
|
||||
"""
|
||||
try:
|
||||
# Use the model service directly with RAG enabled
|
||||
response = model_service.generate_response(
|
||||
model_id=config.DEFAULT_MODEL,
|
||||
prompt=query,
|
||||
context=[],
|
||||
use_rag=True # Enable RAG
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"model": config.DEFAULT_MODEL,
|
||||
"query": query,
|
||||
"use_rag": True,
|
||||
"response": response,
|
||||
"openwebui_url": config.OPENWEBUI_URL
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Failed to get RAG completion: {str(e)}",
|
||||
"openwebui_url": config.OPENWEBUI_URL
|
||||
}
|
||||
|
||||
@app.post("/test-ollama-direct")
|
||||
async def test_ollama_direct():
|
||||
"""
|
||||
|
||||
@@ -198,5 +198,32 @@
|
||||
}
|
||||
],
|
||||
"team_members": []
|
||||
},
|
||||
"aefe133e-e85e-422f-9ca9-6441f24db74f": {
|
||||
"id": "aefe133e-e85e-422f-9ca9-6441f24db74f",
|
||||
"title": "RAG Test Chat",
|
||||
"user_id": "test_user",
|
||||
"model_id": "llama3.1",
|
||||
"is_team_chat": false,
|
||||
"created_at": "2025-05-16T14:41:08.281043",
|
||||
"updated_at": "2025-05-16T14:41:09.142922",
|
||||
"messages": [
|
||||
{
|
||||
"id": "35384001-0c6d-4ac6-a247-32fb53bb664c",
|
||||
"content": "What information do you have in your knowledge database?",
|
||||
"user_id": "test_user",
|
||||
"is_user_message": true,
|
||||
"timestamp": "2025-05-16T14:41:08.341940"
|
||||
},
|
||||
{
|
||||
"id": "4b82555f-76e2-43c0-a1aa-c91257536133",
|
||||
"content": "Connection error to Ollama API: HTTPConnectionPool(host='104.225.217.215', port=11434): Max retries exceeded with url: /api/chat (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x104bcc750>: Failed to establish a new connection: [Errno 61] Connection refused')). Please check if Ollama is running at http://104.225.217.215:11434.",
|
||||
"user_id": null,
|
||||
"is_user_message": false,
|
||||
"timestamp": "2025-05-16T14:41:09.142892"
|
||||
}
|
||||
],
|
||||
"team_members": [],
|
||||
"openwebui_channel_id": null
|
||||
}
|
||||
}
|
||||
@@ -1,261 +0,0 @@
|
||||
"""
|
||||
Service for document processing and chunking.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
import requests
|
||||
import base64
|
||||
from typing import List, Dict, Any, Optional
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
|
||||
from ai_service.config import config
|
||||
|
||||
class DocumentService:
|
||||
"""Service for document processing and chunking."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the document service."""
|
||||
self.chunk_size = config.CHUNK_SIZE
|
||||
self.chunk_overlap = config.CHUNK_OVERLAP
|
||||
self.text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=self.chunk_size,
|
||||
chunk_overlap=self.chunk_overlap,
|
||||
length_function=len
|
||||
)
|
||||
|
||||
# OpenWebUI configuration
|
||||
self.openwebui_url = config.OPENWEBUI_URL
|
||||
self.openwebui_api_key = config.OPENWEBUI_API_KEY
|
||||
|
||||
# Ensure data directory exists
|
||||
os.makedirs('ai_service/data', exist_ok=True)
|
||||
|
||||
# For now, we'll store document metadata in a simple JSON file
|
||||
self.metadata_file = 'ai_service/data/document_metadata.json'
|
||||
self._load_metadata()
|
||||
|
||||
def _load_metadata(self):
|
||||
"""Load document metadata from file."""
|
||||
if os.path.exists(self.metadata_file):
|
||||
try:
|
||||
with open(self.metadata_file, 'r') as f:
|
||||
self.documents = json.load(f)
|
||||
except Exception as e:
|
||||
print(f"Error loading document metadata: {str(e)}")
|
||||
self.documents = {}
|
||||
else:
|
||||
self.documents = {}
|
||||
|
||||
def _save_metadata(self):
|
||||
"""Save document metadata to file."""
|
||||
try:
|
||||
with open(self.metadata_file, 'w') as f:
|
||||
json.dump(self.documents, f, indent=2)
|
||||
except Exception as e:
|
||||
print(f"Error saving document metadata: {str(e)}")
|
||||
|
||||
def process_document(self, content: str, title: str,
|
||||
description: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None) -> str:
|
||||
"""
|
||||
Process a document for embedding.
|
||||
|
||||
Args:
|
||||
content: Document content.
|
||||
title: Document title.
|
||||
description: Optional document description.
|
||||
metadata: Optional additional metadata.
|
||||
|
||||
Returns:
|
||||
Document ID.
|
||||
"""
|
||||
# Generate a unique ID for the document
|
||||
doc_id = str(uuid.uuid4())
|
||||
|
||||
# Upload the document to OpenWebUI for RAG processing
|
||||
try:
|
||||
# Prepare headers
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.openwebui_api_key:
|
||||
headers["Authorization"] = f"Bearer {self.openwebui_api_key}"
|
||||
|
||||
# Prepare the document data
|
||||
document_data = {
|
||||
"filename": f"{title}.txt",
|
||||
"content": base64.b64encode(content.encode('utf-8')).decode('utf-8'),
|
||||
"description": description or title
|
||||
}
|
||||
|
||||
# Upload to OpenWebUI
|
||||
response = requests.post(
|
||||
f"{self.openwebui_url}/api/knowledge/upload",
|
||||
headers=headers,
|
||||
json=document_data,
|
||||
timeout=60
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
# Get the OpenWebUI document ID
|
||||
openwebui_doc_id = result.get('id', '')
|
||||
|
||||
# Store document metadata
|
||||
self.documents[doc_id] = {
|
||||
'id': doc_id,
|
||||
'title': title,
|
||||
'description': description or '',
|
||||
'openwebui_id': openwebui_doc_id,
|
||||
'metadata': metadata or {}
|
||||
}
|
||||
|
||||
# Save metadata to file
|
||||
self._save_metadata()
|
||||
|
||||
return doc_id
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error uploading document to OpenWebUI: {str(e)}")
|
||||
|
||||
# Fall back to local processing if OpenWebUI upload fails
|
||||
print("Falling back to local document processing")
|
||||
|
||||
# Split the document into chunks for local reference
|
||||
chunks = self.text_splitter.split_text(content)
|
||||
|
||||
# Store document metadata
|
||||
self.documents[doc_id] = {
|
||||
'id': doc_id,
|
||||
'title': title,
|
||||
'description': description or '',
|
||||
'chunk_count': len(chunks),
|
||||
'openwebui_upload_failed': True,
|
||||
'metadata': metadata or {}
|
||||
}
|
||||
|
||||
# Save metadata to file
|
||||
self._save_metadata()
|
||||
|
||||
return doc_id
|
||||
|
||||
def get_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get document metadata.
|
||||
|
||||
Args:
|
||||
doc_id: Document ID.
|
||||
|
||||
Returns:
|
||||
Document metadata if found, None otherwise.
|
||||
"""
|
||||
return self.documents.get(doc_id)
|
||||
|
||||
def get_all_documents(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all document metadata.
|
||||
|
||||
Returns:
|
||||
List of document metadata.
|
||||
"""
|
||||
# Get documents from local storage
|
||||
local_documents = list(self.documents.values())
|
||||
|
||||
# Try to get documents from OpenWebUI as well
|
||||
try:
|
||||
# Prepare headers
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.openwebui_api_key:
|
||||
headers["Authorization"] = f"Bearer {self.openwebui_api_key}"
|
||||
|
||||
# Get documents from OpenWebUI
|
||||
response = requests.get(
|
||||
f"{self.openwebui_url}/api/knowledge",
|
||||
headers=headers,
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
openwebui_docs = response.json()
|
||||
|
||||
# Update local documents with OpenWebUI information
|
||||
for doc in local_documents:
|
||||
if 'openwebui_id' in doc:
|
||||
for openwebui_doc in openwebui_docs:
|
||||
if openwebui_doc.get('id') == doc['openwebui_id']:
|
||||
doc['openwebui_status'] = 'active'
|
||||
doc['openwebui_info'] = openwebui_doc
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting documents from OpenWebUI: {str(e)}")
|
||||
|
||||
return local_documents
|
||||
|
||||
def delete_document(self, doc_id: str) -> bool:
|
||||
"""
|
||||
Delete a document and its chunks.
|
||||
|
||||
Args:
|
||||
doc_id: Document ID.
|
||||
|
||||
Returns:
|
||||
True if deletion was successful, False otherwise.
|
||||
"""
|
||||
if doc_id not in self.documents:
|
||||
return False
|
||||
|
||||
# Check if document was uploaded to OpenWebUI
|
||||
doc = self.documents[doc_id]
|
||||
openwebui_id = doc.get('openwebui_id')
|
||||
|
||||
if openwebui_id:
|
||||
try:
|
||||
# Prepare headers
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.openwebui_api_key:
|
||||
headers["Authorization"] = f"Bearer {self.openwebui_api_key}"
|
||||
|
||||
# Delete from OpenWebUI
|
||||
response = requests.delete(
|
||||
f"{self.openwebui_url}/api/knowledge/{openwebui_id}",
|
||||
headers=headers,
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
print(f"Warning: Failed to delete document from OpenWebUI: {response.text}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error deleting document from OpenWebUI: {str(e)}")
|
||||
|
||||
# Delete document metadata
|
||||
del self.documents[doc_id]
|
||||
|
||||
# Save metadata to file
|
||||
self._save_metadata()
|
||||
|
||||
return True
|
||||
|
||||
def search_documents(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search for documents similar to a query.
|
||||
|
||||
Args:
|
||||
query: Search query.
|
||||
top_k: Number of results to return.
|
||||
|
||||
Returns:
|
||||
List of similar document chunks with their metadata.
|
||||
"""
|
||||
# Note: We don't need to implement this method anymore since
|
||||
# RAG is handled directly by OpenWebUI when use_rag=True in the model service
|
||||
|
||||
# Return empty results - this is just a placeholder
|
||||
# The actual RAG functionality is in the model_service.generate_response method
|
||||
return []
|
||||
|
||||
|
||||
# Create a singleton instance
|
||||
document_service = DocumentService()
|
||||
@@ -135,9 +135,6 @@ class ModelService:
|
||||
model_id = self.default_model
|
||||
print(f" - Model not found, using default: {model_id}")
|
||||
|
||||
# Ensure we're using a valid model
|
||||
# (model_id is already validated above)
|
||||
|
||||
# Prepare the messages for the API call
|
||||
messages = []
|
||||
|
||||
@@ -178,6 +175,10 @@ class ModelService:
|
||||
openwebui_request['max_tokens'] = params['max_tokens']
|
||||
if 'top_p' in params:
|
||||
openwebui_request['top_p'] = params['top_p']
|
||||
if 'top_k' in params:
|
||||
openwebui_request['top_k'] = params['top_k']
|
||||
if 'repeat_penalty' in params:
|
||||
openwebui_request['repeat_penalty'] = params['repeat_penalty']
|
||||
|
||||
# Make the API call to OpenWebUI
|
||||
headers = {"Content-Type": "application/json"}
|
||||
@@ -201,10 +202,15 @@ class ModelService:
|
||||
result = response.json()
|
||||
|
||||
# Extract the response content
|
||||
if 'message' in result:
|
||||
if 'choices' in result and len(result['choices']) > 0 and 'message' in result['choices'][0]:
|
||||
# OpenAI-compatible format
|
||||
return result['choices'][0]['message']['content']
|
||||
elif 'message' in result and 'content' in result['message']:
|
||||
# OpenWebUI format
|
||||
return result['message']['content']
|
||||
else:
|
||||
return "Error: Unexpected response format from OpenWebUI"
|
||||
print(f"WARNING: Unexpected response format from OpenWebUI: {json.dumps(result, indent=2)}")
|
||||
return "Error: Unexpected response format from OpenWebUI. Falling back to direct model call."
|
||||
|
||||
except requests.exceptions.Timeout as e:
|
||||
error_msg = f"Timeout error connecting to OpenWebUI API: {str(e)}. The request exceeded the {self.api_timeout} second timeout."
|
||||
@@ -216,6 +222,11 @@ class ModelService:
|
||||
print(f"ERROR: {error_msg}")
|
||||
print("Falling back to direct Ollama call without RAG")
|
||||
# Continue to the Ollama API call below
|
||||
except requests.exceptions.HTTPError as e:
|
||||
error_msg = f"HTTP error from OpenWebUI API: {str(e)}."
|
||||
print(f"ERROR: {error_msg}")
|
||||
print("Falling back to direct Ollama call without RAG")
|
||||
# Continue to the Ollama API call below
|
||||
except Exception as e:
|
||||
error_msg = f"Error calling OpenWebUI API: {str(e)}"
|
||||
print(f"ERROR: {error_msg}")
|
||||
@@ -247,6 +258,8 @@ class ModelService:
|
||||
request_json['top_k'] = params['top_k']
|
||||
if 'max_tokens' in params:
|
||||
request_json['max_tokens'] = params['max_tokens']
|
||||
if 'repeat_penalty' in params:
|
||||
request_json['repeat_penalty'] = params['repeat_penalty']
|
||||
|
||||
# Make the API call to Ollama
|
||||
try:
|
||||
@@ -272,6 +285,7 @@ class ModelService:
|
||||
if 'message' in result and 'content' in result['message']:
|
||||
return result['message']['content']
|
||||
else:
|
||||
print(f"WARNING: Unexpected response format from Ollama: {json.dumps(result, indent=2)}")
|
||||
return "Error: Unexpected response format from Ollama"
|
||||
|
||||
except requests.exceptions.Timeout as e:
|
||||
|
||||
@@ -71,10 +71,16 @@ async def chat_completions(request: Request):
|
||||
temperature = body.get("temperature")
|
||||
max_tokens = body.get("max_tokens")
|
||||
top_p = body.get("top_p")
|
||||
top_k = body.get("top_k")
|
||||
frequency_penalty = body.get("frequency_penalty")
|
||||
presence_penalty = body.get("presence_penalty")
|
||||
repeat_penalty = body.get("repeat_penalty")
|
||||
stop = body.get("stop")
|
||||
|
||||
# Check if RAG should be used
|
||||
use_knowledge = body.get("use_knowledge", False)
|
||||
use_rag = use_knowledge # Map OpenWebUI's use_knowledge to our use_rag parameter
|
||||
|
||||
# Create a unique chat ID
|
||||
chat_id = str(uuid.uuid4())
|
||||
|
||||
@@ -99,11 +105,14 @@ async def chat_completions(request: Request):
|
||||
chat_id=chat_id,
|
||||
message=user_message,
|
||||
user_id=user_id,
|
||||
use_rag=use_rag,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
repeat_penalty=repeat_penalty,
|
||||
stop_sequences=stop if isinstance(stop, list) else [stop] if stop else None
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user