Files
ds_zagres_ai/ai_service/api.py
T

647 lines
21 KiB
Python
Raw Normal View History

2025-05-09 15:41:16 +01:00
"""
FastAPI application for the AI service.

This service acts as a backend for OpenWebUI, providing OpenWebUI-compatible API endpoints.

The service supports document-based question answering using OpenWebUI's knowledge database:
- Set use_rag=True in API requests to enable Retrieval Augmented Generation
- When enabled, the service will use OpenWebUI's knowledge database to find relevant information
- Documents uploaded to OpenWebUI will be used to augment the model's responses
"""
from fastapi import FastAPI, HTTPException
2025-05-09 15:41:16 +01:00
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Optional
2025-05-12 16:30:35 +01:00
import uuid
from datetime import datetime, timezone
2025-05-09 15:41:16 +01:00
from ai_service.models.model_service import model_service
from ai_service.models.chat_service import chat_service
from ai_service.openwebui_api import router as openwebui_router
2025-05-16 13:23:35 +01:00
from ai_service.openwebui_channels import openwebui_channels
2025-05-12 16:10:45 +01:00
from ai_service.config import config
2025-05-09 15:41:16 +01:00
# Application object that exposes the backend API consumed by OpenWebUI.
app = FastAPI(
    title="AI Service API",
    description="Backend API for OpenWebUI",
    version="1.0.0",
)

# CORS is fully open: every origin, method and header is accepted and
# credentials are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended outside of development setups.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],  # Allow all origins
    allow_methods=["*"],  # Allow all methods
    allow_headers=["*"],  # Allow all headers
)

# The same router is mounted twice: once as the OpenWebUI-compatible API and
# once under the Ollama proxy prefix.
app.include_router(openwebui_router, prefix="/api")
app.include_router(openwebui_router, prefix="/ollama")
2025-05-16 13:23:35 +01:00
# Register webhook for channel messages on startup
@app.on_event("startup")
async def startup_event():
    """
    Register the channel-message webhook with OpenWebUI on startup.

    The callback base URL is ``config.PUBLIC_URL`` when that is set, otherwise
    it is built from ``config.API_HOST`` and ``config.API_PORT``.  The result of
    ``openwebui_channels.register_webhook`` is only reported via ``print``; a
    falsy result does not stop the application from starting.
    """
    # Prefer the explicitly configured public URL over the bind address.
    service_url = config.PUBLIC_URL or f"http://{config.API_HOST}:{config.API_PORT}"

    # A trailing slash on the base URL would otherwise yield
    # ".../" + "/webhooks/..." (a double slash that does not match the route).
    base_url = service_url.rstrip("/")
    webhook_url = f"{base_url}/webhooks/channel-message"

    print(f"Registering webhook for channel messages: {webhook_url}")
    success = openwebui_channels.register_webhook(webhook_url)
    if success:
        print("Successfully registered webhook for channel messages")
    else:
        print("Failed to register webhook for channel messages")
# Define API models for health check
class HealthResponse(BaseModel):
    """Payload returned by the health check endpoint."""

    # Required field (Ellipsis default means "no default").
    status: str = Field(default=..., description="Health status")
2025-05-09 15:41:16 +01:00
class ModelInfo(BaseModel):
    """Description of a single model; every field is required."""

    id: str = Field(default=..., description="Model ID")
    name: str = Field(default=..., description="Model name")
    description: str = Field(default=..., description="Model description")
    provider: str = Field(default=..., description="Model provider")
    max_tokens: int = Field(default=..., description="Maximum tokens")
    is_default: bool = Field(default=..., description="Whether this is the default model")
class ChatRequest(BaseModel):
    """Request body for creating a chat; only ``user_id`` is required."""

    user_id: str = Field(default=..., description="User ID")
    title: Optional[str] = Field(default=None, description="Chat title")
    model_id: Optional[str] = Field(default=None, description="Model ID")
    is_team_chat: bool = Field(default=False, description="Whether this is a team chat")
class MessageRequest(BaseModel):
    """
    Request body for sending a message to a chat.

    ``message`` and ``user_id`` are required; RAG is off by default and every
    generation parameter defaults to ``None``.
    """

    message: str = Field(default=..., description="Message content")
    user_id: str = Field(default=..., description="User ID")
    use_rag: bool = Field(default=False, description="Whether to use RAG")

    # Model parameters
    temperature: Optional[float] = Field(
        default=None,
        description="Controls randomness: higher values mean more random completions",
    )
    max_tokens: Optional[int] = Field(
        default=None, description="Maximum number of tokens to generate"
    )
    top_p: Optional[float] = Field(default=None, description="Nucleus sampling parameter")
    frequency_penalty: Optional[float] = Field(
        default=None, description="Penalizes repeated tokens"
    )
    presence_penalty: Optional[float] = Field(
        default=None, description="Penalizes repeated topics"
    )
    stop_sequences: Optional[List[str]] = Field(
        default=None, description="Sequences where the API will stop generating"
    )
    system_prompt: Optional[str] = Field(
        default=None, description="System prompt to guide the model's behavior"
    )

    # Additional advanced parameters
    min_p: Optional[float] = Field(
        default=None, description="Minimum probability threshold for token selection"
    )
    top_k: Optional[int] = Field(default=None, description="Only sample from the top k tokens")
    repeat_penalty: Optional[float] = Field(
        default=None, description="Penalty for repeating tokens"
    )
    function_calling: Optional[bool] = Field(
        default=None, description="Whether to enable function calling"
    )
class Message(BaseModel):
    """A single chat message; ``user_id`` is the only optional field."""

    id: str = Field(default=..., description="Message ID")
    content: str = Field(default=..., description="Message content")
    user_id: Optional[str] = Field(default=None, description="User ID")
    is_user_message: bool = Field(default=..., description="Whether this is a user message")
    timestamp: str = Field(default=..., description="Message timestamp")
class Chat(BaseModel):
    """A chat with its messages and team members; every field is required."""

    id: str = Field(default=..., description="Chat ID")
    title: str = Field(default=..., description="Chat title")
    user_id: str = Field(default=..., description="User ID")
    model_id: str = Field(default=..., description="Model ID")
    is_team_chat: bool = Field(default=..., description="Whether this is a team chat")
    created_at: str = Field(default=..., description="Creation timestamp")
    updated_at: str = Field(default=..., description="Update timestamp")
    messages: List[Message] = Field(default=..., description="Chat messages")
    team_members: List[str] = Field(default=..., description="Team members")
# Define API endpoints
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Report that the service is up.

    Returns:
        A payload whose ``status`` is always ``"healthy"``.
    """
    return dict(status="healthy")
2025-05-12 16:10:45 +01:00
@app.get("/test-ollama")
async def test_ollama_connection():
    """
    Test the connection to the Ollama API.

    Issues ``GET {OLLAMA_API_URL}/api/tags`` with ``config.API_TIMEOUT`` as the
    timeout.  Failures are reported in the response body (``status: "error"``)
    rather than as an HTTP error, so the endpoint can be used for diagnostics.

    Returns:
        Connection status and, on success, the JSON returned by Ollama.
    """
    import requests
    from fastapi.concurrency import run_in_threadpool

    try:
        # requests is a blocking client: run it in the threadpool so a slow or
        # unreachable Ollama does not stall the event loop for other requests.
        response = await run_in_threadpool(
            requests.get,
            f"{config.OLLAMA_API_URL}/api/tags",
            timeout=config.API_TIMEOUT,
        )
        response.raise_for_status()

        # Return the models from Ollama
        return {
            "status": "success",
            "message": "Successfully connected to Ollama API",
            "ollama_url": config.OLLAMA_API_URL,
            "models": response.json()
        }
    except requests.exceptions.Timeout as e:
        return {
            "status": "error",
            "message": f"Timeout connecting to Ollama API: {str(e)}. The request exceeded the {config.API_TIMEOUT} second timeout.",
            "ollama_url": config.OLLAMA_API_URL
        }
    except requests.exceptions.ConnectionError as e:
        return {
            "status": "error",
            "message": f"Connection error to Ollama API: {str(e)}. Please check if Ollama is running at {config.OLLAMA_API_URL}.",
            "ollama_url": config.OLLAMA_API_URL
        }
    except Exception as e:
        # Diagnostic endpoint: report any other failure (HTTP error status,
        # invalid JSON, ...) in the body instead of raising.
        return {
            "status": "error",
            "message": f"Failed to connect to Ollama API: {str(e)}",
            "ollama_url": config.OLLAMA_API_URL
        }
@app.post("/test-chat")
async def test_chat_completion():
    """
    Run a fixed greeting prompt through the default model, without RAG.

    Returns:
        A success payload with the model response, or an error payload
        describing the exception that was raised.
    """
    try:
        # Call the model service directly with an empty conversation context.
        reply = model_service.generate_response(
            model_id=config.DEFAULT_MODEL,
            prompt="Hello, how are you?",
            context=[],
            use_rag=False,
        )
    except Exception as exc:
        failure = {
            "status": "error",
            "message": f"Failed to get chat completion: {str(exc)}",
            "ollama_url": config.OLLAMA_API_URL,
        }
        return failure

    return {
        "status": "success",
        "model": config.DEFAULT_MODEL,
        "response": reply,
        "ollama_url": config.OLLAMA_API_URL,
    }
2025-05-16 15:24:01 +01:00
@app.post("/test-rag")
async def test_rag_completion(query: str = "What information do you have in your knowledge database?"):
    """
    Exercise the RAG (Retrieval Augmented Generation) path with a query.

    The query is sent to the default model with ``use_rag=True`` and an empty
    conversation context.

    Args:
        query: The question to ask about documents in the knowledge database.

    Returns:
        A success payload with the model response, or an error payload
        describing the exception that was raised.
    """
    try:
        answer = model_service.generate_response(
            model_id=config.DEFAULT_MODEL,
            prompt=query,
            context=[],
            use_rag=True,  # Enable RAG
        )
    except Exception as exc:
        failure = {
            "status": "error",
            "message": f"Failed to get RAG completion: {str(exc)}",
            "openwebui_url": config.OPENWEBUI_URL,
        }
        return failure

    return {
        "status": "success",
        "model": config.DEFAULT_MODEL,
        "query": query,
        "use_rag": True,
        "response": answer,
        "openwebui_url": config.OPENWEBUI_URL,
    }
2025-05-12 16:30:35 +01:00
@app.post("/test-ollama-direct")
async def test_ollama_direct():
    """
    Test the Ollama API directly with a simple chat request.

    Posts a fixed two-message, non-streaming conversation for
    ``config.DEFAULT_MODEL`` to ``{OLLAMA_API_URL}/api/chat``.  Failures are
    reported in the response body (``status: "error"``) rather than as an HTTP
    error, so the endpoint can be used for diagnostics.

    Returns:
        The request that was sent together with the raw Ollama API response,
        or an error payload.
    """
    import requests
    from fastapi.concurrency import run_in_threadpool

    try:
        # Prepare a simple chat request
        request_json = {
            "model": config.DEFAULT_MODEL,
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello, how are you?"}
            ],
            "stream": False
        }

        print(f"Sending direct request to Ollama API at: {config.OLLAMA_API_URL}/api/chat")
        # requests is a blocking client and a completion can take a long time:
        # run it in the threadpool so the event loop keeps serving requests.
        response = await run_in_threadpool(
            requests.post,
            f"{config.OLLAMA_API_URL}/api/chat",
            headers={"Content-Type": "application/json"},
            json=request_json,
            timeout=config.API_TIMEOUT,
        )
        response.raise_for_status()
        result = response.json()

        return {
            "status": "success",
            "ollama_url": config.OLLAMA_API_URL,
            "request": request_json,
            "response": result
        }
    except requests.exceptions.Timeout as e:
        return {
            "status": "error",
            "message": f"Timeout connecting to Ollama API: {str(e)}. The request exceeded the {config.API_TIMEOUT} second timeout.",
            "ollama_url": config.OLLAMA_API_URL
        }
    except requests.exceptions.ConnectionError as e:
        return {
            "status": "error",
            "message": f"Connection error to Ollama API: {str(e)}. Please check if Ollama is running at {config.OLLAMA_API_URL}.",
            "ollama_url": config.OLLAMA_API_URL
        }
    except Exception as e:
        # Diagnostic endpoint: report any other failure (HTTP error status,
        # invalid JSON, ...) in the body instead of raising.
        return {
            "status": "error",
            "message": f"Failed to connect to Ollama API: {str(e)}",
            "ollama_url": config.OLLAMA_API_URL
        }
@app.get("/config")
async def get_config():
    """
    Expose the configuration values the service is currently running with.

    Returns:
        Host, port, upstream URLs, default model, API timeout and the keys of
        ``model_service.AVAILABLE_MODELS``.
    """
    model_ids = list(model_service.AVAILABLE_MODELS.keys())
    settings = {
        "api_host": config.API_HOST,
        "api_port": config.API_PORT,
        "openwebui_url": config.OPENWEBUI_URL,
        "ollama_api_url": config.OLLAMA_API_URL,
        "default_model": config.DEFAULT_MODEL,
        "api_timeout": config.API_TIMEOUT,
        "available_models": model_ids,
    }
    return settings
2025-05-09 15:41:16 +01:00
# Model endpoints
@app.get("/models", response_model=List[ModelInfo])
async def get_available_models():
    """
    List the models known to the model service.

    Returns:
        List of model information.
    """
    available = model_service.get_available_models()
    print(f"API models: {available}")  # Debug log
    return available
2025-05-09 15:41:16 +01:00
@app.get("/models/{model_id}", response_model=ModelInfo)
async def get_model_info(model_id: str):
    """
    Look up a single model.

    Args:
        model_id: Model ID.

    Returns:
        Model information.

    Raises:
        HTTPException: 404 when the model service returns a falsy result.
    """
    info = model_service.get_model_info(model_id)
    if info:
        return info
    raise HTTPException(status_code=404, detail="Model not found")
# Chat endpoints
@app.post("/chats", response_model=Chat)
async def create_chat(request: ChatRequest):
    """
    Create a chat and return it as stored by the chat service.

    Args:
        request: Chat creation request.

    Returns:
        Created chat.
    """
    creation_args = dict(
        user_id=request.user_id,
        title=request.title,
        model_id=request.model_id,
        is_team_chat=request.is_team_chat,
    )
    new_chat_id = chat_service.create_chat(**creation_args)
    # Read the chat back so the response carries the stored representation.
    return chat_service.get_chat(new_chat_id)
@app.get("/chats/user/{user_id}", response_model=List[Chat])
async def get_user_chats(user_id: str):
    """
    List the chats the chat service returns for a user.

    Args:
        user_id: User ID.

    Returns:
        List of chats.
    """
    user_chats = chat_service.get_user_chats(user_id)
    return user_chats
@app.get("/chats/{chat_id}", response_model=Chat)
async def get_chat(chat_id: str):
    """
    Fetch a single chat.

    Args:
        chat_id: Chat ID.

    Returns:
        Chat information.

    Raises:
        HTTPException: 404 when the chat service returns a falsy result.
    """
    found = chat_service.get_chat(chat_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Chat not found")
@app.post("/chats/{chat_id}/messages", response_model=Message)
async def send_message(chat_id: str, request: MessageRequest):
    """
    Send a message to a chat and return the bot's reply.

    Args:
        chat_id: Chat ID.
        request: Message request with optional model parameters.

    Returns:
        Bot response message.  If an unexpected error occurs, a synthetic
        non-user message carrying the error text is returned instead of an
        HTTP error, so the client always receives a well-formed message.

    Raises:
        HTTPException: 404 when the chat service raises ``ValueError``.
    """
    try:
        print(f"Processing message for chat {chat_id} from user {request.user_id}")
        print(f"Message: {request.message[:50]}...")  # Print first 50 chars of message
        print(f"Using RAG: {request.use_rag}")
        print(f"Model parameters: temperature={request.temperature}, max_tokens={request.max_tokens}")

        # Forward the message together with every model parameter of the request.
        response = chat_service.get_chat_response(
            chat_id=chat_id,
            message=request.message,
            user_id=request.user_id,
            use_rag=request.use_rag,
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            top_p=request.top_p,
            frequency_penalty=request.frequency_penalty,
            presence_penalty=request.presence_penalty,
            stop_sequences=request.stop_sequences,
            system_prompt=request.system_prompt,
            min_p=request.min_p,
            top_k=request.top_k,
            repeat_penalty=request.repeat_penalty,
            function_calling=request.function_calling
        )

        print(f"Response received. Length: {len(response.get('content', ''))}")
        return response
    except ValueError as e:
        error_msg = f"Chat not found: {str(e)}"
        print(f"ERROR: {error_msg}")
        # Chain the cause so the original ValueError stays visible in tracebacks.
        raise HTTPException(status_code=404, detail=error_msg) from e
    except Exception as e:
        error_msg = f"Error processing message: {str(e)}"
        print(f"ERROR: {error_msg}")

        # Return an error message instead of raising an exception
        # This ensures the client gets a proper response
        return {
            "id": str(uuid.uuid4()),
            "content": error_msg,
            "user_id": None,
            "is_user_message": False,
            "timestamp": datetime.now(timezone.utc).isoformat()
        }
2025-05-09 15:41:16 +01:00
@app.post("/chats/{chat_id}/members/{user_id}")
async def add_team_member(chat_id: str, user_id: str):
    """
    Add a user to a team chat.

    Args:
        chat_id: Chat ID.
        user_id: User ID.

    Returns:
        Addition status.

    Raises:
        HTTPException: 400 when the chat service reports a falsy result.
    """
    if chat_service.add_team_member(chat_id, user_id):
        return {"status": "success", "message": "Team member added"}
    raise HTTPException(status_code=400, detail="Failed to add team member")
@app.delete("/chats/{chat_id}/members/{user_id}")
async def remove_team_member(chat_id: str, user_id: str):
    """
    Remove a user from a team chat.

    Args:
        chat_id: Chat ID.
        user_id: User ID.

    Returns:
        Removal status.

    Raises:
        HTTPException: 400 when the chat service reports a falsy result.
    """
    if chat_service.remove_team_member(chat_id, user_id):
        return {"status": "success", "message": "Team member removed"}
    raise HTTPException(status_code=400, detail="Failed to remove team member")
@app.delete("/chats/{chat_id}")
async def delete_chat(chat_id: str):
    """
    Delete a chat.

    Args:
        chat_id: Chat ID.

    Returns:
        Deletion status.

    Raises:
        HTTPException: 404 when the chat service reports a falsy result.
    """
    if chat_service.delete_chat(chat_id):
        return {"status": "success", "message": "Chat deleted"}
    raise HTTPException(status_code=404, detail="Chat not found")
2025-05-16 13:23:35 +01:00
# OpenWebUI Channels endpoints
@app.get("/channels")
async def get_openwebui_channels():
    """
    List the OpenWebUI channels.

    Returns:
        Whatever ``openwebui_channels.get_channels()`` returns.
    """
    return openwebui_channels.get_channels()
@app.get("/channels/{channel_id}")
async def get_openwebui_channel(channel_id: str):
    """
    Fetch a single OpenWebUI channel.

    Args:
        channel_id: Channel ID.

    Returns:
        Channel information.

    Raises:
        HTTPException: 404 when the lookup returns a falsy result.
    """
    found = openwebui_channels.get_channel(channel_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Channel not found")
@app.post("/channels")
async def create_openwebui_channel(name: str, description: str = "", is_private: bool = False):
    """
    Create a new OpenWebUI channel.

    Args:
        name: Channel name.
        description: Channel description.
        is_private: Whether the channel is private.

    Returns:
        Created channel.

    Raises:
        HTTPException: 400 when channel creation returns a falsy result.
    """
    created = openwebui_channels.create_channel(name, description, is_private)
    if created:
        return created
    raise HTTPException(status_code=400, detail="Failed to create channel")
# Webhook endpoint for OpenWebUI channel messages
class ChannelMessageWebhook(BaseModel):
    """Body of the channel-message webhook; only ``timestamp`` is optional."""

    channel_id: str = Field(default=..., description="Channel ID")
    message: str = Field(default=..., description="Message content")
    user_id: str = Field(default=..., description="User ID")
    timestamp: Optional[str] = Field(default=None, description="Message timestamp")
def _find_chat_id_for_channel(channel_id):
    """Return the ID of the first team chat linked to *channel_id*, or None."""
    for cid, chat in chat_service.chats.items():
        if chat.get('is_team_chat') and chat.get('openwebui_channel_id') == channel_id:
            return cid
    return None


def _strip_first_trigger(message, triggers):
    """
    Remove the first configured trigger found in *message*.

    Matching ignores case, and so does the removal, so "@AI hello" with the
    trigger "@ai" yields "hello".

    Returns:
        The message without the trigger (whitespace-stripped), or None when no
        trigger occurs in the message.
    """
    import re

    message_lower = message.lower()
    for trigger in triggers:
        if trigger.lower() in message_lower:
            # A case-sensitive str.replace would leave differently-cased
            # mentions in place even though they were detected above.
            pattern = re.escape(trigger)
            return re.sub(pattern, "", message, flags=re.IGNORECASE).strip()
    return None


@app.post("/webhooks/channel-message")
async def channel_message_webhook(request: ChannelMessageWebhook):
    """
    Webhook endpoint for receiving messages from OpenWebUI channels.

    The message is routed to the team chat whose ``openwebui_channel_id``
    matches the channel.  Messages from the ``"ai-assistant"`` user are skipped
    to avoid reply loops.  Unless ``config.AI_RESPOND_TO_ALL`` is set, only
    messages containing one of ``config.AI_TRIGGERS`` are processed, and the
    trigger is removed before the message is passed to the chat service.

    Args:
        request: Channel message webhook request.

    Returns:
        Processing status: ``"success"`` with the generated response,
        ``"skipped"``, or ``"error"``.  Errors are reported in the body; the
        endpoint does not raise.
    """
    try:
        print(f"Received channel message webhook: {request.channel_id}, {request.user_id}, {request.message}")

        # Find the chat associated with this OpenWebUI channel
        chat_id = _find_chat_id_for_channel(request.channel_id)
        if not chat_id:
            print(f"No chat found for OpenWebUI channel {request.channel_id}")
            return {"status": "error", "message": "No chat found for this channel"}

        # Skip messages from the AI assistant to avoid loops
        if request.user_id == "ai-assistant":
            return {"status": "skipped", "message": "Skipping AI assistant message"}

        processed_message = request.message
        # When not responding to everything, require a trigger and strip it.
        if not config.AI_RESPOND_TO_ALL:
            stripped = _strip_first_trigger(request.message, config.AI_TRIGGERS)
            if stripped is None:
                print(f"No AI mention found in message, skipping: {request.message[:50]}...")
                return {"status": "skipped", "message": "No AI mention found in message"}
            processed_message = stripped

        # Process the message and generate a response
        response = chat_service.get_chat_response(
            chat_id=chat_id,
            message=processed_message,
            user_id=request.user_id
        )
        return {"status": "success", "message": "Message processed", "response": response}
    except Exception as e:
        # Webhook boundary: report the failure to the caller instead of raising.
        print(f"Error processing channel message webhook: {str(e)}")
        return {"status": "error", "message": f"Error processing message: {str(e)}"}