Files
ds_zagres_ai/ai_service/api.py
T
Iyeoluwa Akinrinola e82861a5db Latest fixxes
2025-05-16 13:23:35 +01:00

605 lines
20 KiB
Python

"""
FastAPI application for the AI service.
This service acts as a backend for OpenWebUI, providing OpenWebUI-compatible API endpoints.
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Optional
import uuid
from datetime import datetime, timezone
from ai_service.models.model_service import model_service
from ai_service.models.chat_service import chat_service
from ai_service.openwebui_api import router as openwebui_router
from ai_service.openwebui_channels import openwebui_channels
from ai_service.config import config
# Build the FastAPI application object that every route below attaches to.
app = FastAPI(
    title="AI Service API",
    description="Backend API for OpenWebUI",
    version="1.0.0",
)

# CORS is fully open: every origin, method and header is accepted.
# NOTE(review): a wildcard origin together with allow_credentials=True is
# discouraged by the FastAPI CORS documentation -- consider an explicit
# origin list once the real front-end origins are known.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Mount the OpenWebUI-compatible router under "/api" ...
app.include_router(openwebui_router, prefix="/api")
# ... and mount the very same router again under "/ollama".
# NOTE(review): this reuses openwebui_router rather than a dedicated proxy
# router -- confirm that serving identical routes on both prefixes is intended.
app.include_router(openwebui_router, prefix="/ollama")
# Startup hook: hand this service's webhook URL to openwebui_channels.
@app.on_event("startup")
async def startup_event():
    """
    Register the channel-message webhook when the application starts.

    The base URL is ``config.PUBLIC_URL`` when that value is truthy; otherwise
    it is assembled from ``config.API_HOST`` and ``config.API_PORT``. The
    result of ``openwebui_channels.register_webhook`` is only printed; a falsy
    result is reported but does not abort startup.
    """
    fallback_url = f"http://{config.API_HOST}:{config.API_PORT}"
    # An explicitly configured public URL wins over the host/port fallback.
    base_url = config.PUBLIC_URL or fallback_url
    webhook_url = f"{base_url}/webhooks/channel-message"
    print(f"Registering webhook for channel messages: {webhook_url}")
    if openwebui_channels.register_webhook(webhook_url):
        print("Successfully registered webhook for channel messages")
    else:
        print("Failed to register webhook for channel messages")
# Response schema used by the health check endpoint.
class HealthResponse(BaseModel):
    """Payload returned by the health check endpoint."""

    status: str = Field(default=..., description="Health status")
class ModelInfo(BaseModel):
    """Description of a single model; every field is required."""

    id: str = Field(default=..., description="Model ID")
    name: str = Field(default=..., description="Model name")
    description: str = Field(default=..., description="Model description")
    provider: str = Field(default=..., description="Model provider")
    max_tokens: int = Field(default=..., description="Maximum tokens")
    is_default: bool = Field(
        default=...,
        description="Whether this is the default model",
    )
class ChatRequest(BaseModel):
    """Body accepted when creating a chat; only ``user_id`` is mandatory."""

    user_id: str = Field(default=..., description="User ID")
    title: Optional[str] = Field(default=None, description="Chat title")
    model_id: Optional[str] = Field(default=None, description="Model ID")
    is_team_chat: bool = Field(
        default=False,
        description="Whether this is a team chat",
    )
class MessageRequest(BaseModel):
    """
    Body accepted when sending a message to a chat.

    ``message`` and ``user_id`` are required. Every generation setting is
    optional and defaults to ``None``.
    """

    message: str = Field(default=..., description="Message content")
    user_id: str = Field(default=..., description="User ID")
    use_rag: bool = Field(default=False, description="Whether to use RAG")

    # Core generation settings.
    temperature: Optional[float] = Field(
        default=None,
        description="Controls randomness: higher values mean more random completions",
    )
    max_tokens: Optional[int] = Field(
        default=None,
        description="Maximum number of tokens to generate",
    )
    top_p: Optional[float] = Field(
        default=None,
        description="Nucleus sampling parameter",
    )
    frequency_penalty: Optional[float] = Field(
        default=None,
        description="Penalizes repeated tokens",
    )
    presence_penalty: Optional[float] = Field(
        default=None,
        description="Penalizes repeated topics",
    )
    stop_sequences: Optional[List[str]] = Field(
        default=None,
        description="Sequences where the API will stop generating",
    )
    system_prompt: Optional[str] = Field(
        default=None,
        description="System prompt to guide the model's behavior",
    )

    # Less common, advanced settings.
    min_p: Optional[float] = Field(
        default=None,
        description="Minimum probability threshold for token selection",
    )
    top_k: Optional[int] = Field(
        default=None,
        description="Only sample from the top k tokens",
    )
    repeat_penalty: Optional[float] = Field(
        default=None,
        description="Penalty for repeating tokens",
    )
    function_calling: Optional[bool] = Field(
        default=None,
        description="Whether to enable function calling",
    )
class Message(BaseModel):
    """A single chat message; ``user_id`` is the only optional field."""

    id: str = Field(default=..., description="Message ID")
    content: str = Field(default=..., description="Message content")
    user_id: Optional[str] = Field(default=None, description="User ID")
    is_user_message: bool = Field(
        default=...,
        description="Whether this is a user message",
    )
    timestamp: str = Field(default=..., description="Message timestamp")
class Chat(BaseModel):
    """A chat with its messages and team members; every field is required."""

    id: str = Field(default=..., description="Chat ID")
    title: str = Field(default=..., description="Chat title")
    user_id: str = Field(default=..., description="User ID")
    model_id: str = Field(default=..., description="Model ID")
    is_team_chat: bool = Field(
        default=...,
        description="Whether this is a team chat",
    )
    created_at: str = Field(default=..., description="Creation timestamp")
    updated_at: str = Field(default=..., description="Update timestamp")
    messages: List[Message] = Field(default=..., description="Chat messages")
    team_members: List[str] = Field(default=..., description="Team members")
# API endpoints
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Report that the service is up.

    Returns:
        A constant payload whose ``status`` is ``"healthy"``.
    """
    payload = dict(status="healthy")
    return payload
@app.get("/test-ollama")
async def test_ollama_connection():
    """
    Test the connection to the Ollama API.

    Sends ``GET {config.OLLAMA_API_URL}/api/tags`` with a timeout of
    ``config.API_TIMEOUT``. The blocking ``requests`` call is executed in a
    worker thread so the event loop keeps serving other requests while the
    probe is waiting on the network.

    Returns:
        On success, a dict with ``status`` ``"success"``, a message, the
        Ollama URL and the decoded JSON body under ``models``. On any failure,
        a dict with ``status`` ``"error"``, a descriptive message and the
        Ollama URL. Failures are reported in the body rather than raised.
    """
    import requests
    from fastapi.concurrency import run_in_threadpool

    def failure(message):
        # Shared shape of every error payload returned by this endpoint.
        return {
            "status": "error",
            "message": message,
            "ollama_url": config.OLLAMA_API_URL
        }

    try:
        # requests is synchronous; awaiting it through the thread pool avoids
        # stalling the event loop for up to API_TIMEOUT seconds.
        response = await run_in_threadpool(
            requests.get,
            f"{config.OLLAMA_API_URL}/api/tags",
            timeout=config.API_TIMEOUT,
        )
        response.raise_for_status()
        # Return the models from Ollama
        return {
            "status": "success",
            "message": "Successfully connected to Ollama API",
            "ollama_url": config.OLLAMA_API_URL,
            "models": response.json()
        }
    except requests.exceptions.Timeout as e:
        return failure(
            f"Timeout connecting to Ollama API: {str(e)}. The request exceeded the {config.API_TIMEOUT} second timeout."
        )
    except requests.exceptions.ConnectionError as e:
        return failure(
            f"Connection error to Ollama API: {str(e)}. Please check if Ollama is running at {config.OLLAMA_API_URL}."
        )
    except Exception as e:
        # Diagnostic endpoint: any other problem (HTTP error status, invalid
        # JSON, ...) is reported to the caller instead of producing a 500.
        return failure(f"Failed to connect to Ollama API: {str(e)}")
@app.post("/test-chat")
async def test_chat_completion():
    """
    Run a fixed prompt through the default model as a smoke test.

    Returns:
        On success, a dict with ``status`` ``"success"``, the default model
        name, the value returned by ``model_service.generate_response`` and
        the Ollama URL. On any exception, a dict with ``status`` ``"error"``,
        a message containing the exception text and the Ollama URL.
    """
    try:
        # Call the model service directly with a canned greeting and no context.
        generation_args = {
            "model_id": config.DEFAULT_MODEL,
            "prompt": "Hello, how are you?",
            "context": [],
            "use_rag": False,
        }
        reply = model_service.generate_response(**generation_args)
        outcome = {
            "status": "success",
            "model": config.DEFAULT_MODEL,
            "response": reply,
            "ollama_url": config.OLLAMA_API_URL,
        }
        return outcome
    except Exception as e:
        failure = {
            "status": "error",
            "message": f"Failed to get chat completion: {str(e)}",
            "ollama_url": config.OLLAMA_API_URL,
        }
        return failure
@app.post("/test-ollama-direct")
async def test_ollama_direct():
    """
    Test the Ollama API directly with a simple chat request.

    Posts a fixed two-message, non-streaming conversation for
    ``config.DEFAULT_MODEL`` to ``{config.OLLAMA_API_URL}/api/chat`` with a
    timeout of ``config.API_TIMEOUT``. The blocking ``requests`` call is
    executed in a worker thread so the event loop keeps serving other requests
    while the model is generating.

    Returns:
        On success, a dict with ``status`` ``"success"``, the Ollama URL, the
        request body that was sent and the decoded JSON response. On any
        failure, a dict with ``status`` ``"error"``, a descriptive message and
        the Ollama URL. Failures are reported in the body rather than raised.
    """
    import requests
    from fastapi.concurrency import run_in_threadpool

    def failure(message):
        # Shared shape of every error payload returned by this endpoint.
        return {
            "status": "error",
            "message": message,
            "ollama_url": config.OLLAMA_API_URL
        }

    try:
        # Prepare a simple chat request
        request_json = {
            "model": config.DEFAULT_MODEL,
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello, how are you?"}
            ],
            "stream": False
        }
        print(f"Sending direct request to Ollama API at: {config.OLLAMA_API_URL}/api/chat")
        # requests is synchronous; awaiting it through the thread pool avoids
        # stalling the event loop for up to API_TIMEOUT seconds.
        response = await run_in_threadpool(
            requests.post,
            f"{config.OLLAMA_API_URL}/api/chat",
            headers={"Content-Type": "application/json"},
            json=request_json,
            timeout=config.API_TIMEOUT,
        )
        response.raise_for_status()
        result = response.json()
        return {
            "status": "success",
            "ollama_url": config.OLLAMA_API_URL,
            "request": request_json,
            "response": result
        }
    except requests.exceptions.Timeout as e:
        return failure(
            f"Timeout connecting to Ollama API: {str(e)}. The request exceeded the {config.API_TIMEOUT} second timeout."
        )
    except requests.exceptions.ConnectionError as e:
        return failure(
            f"Connection error to Ollama API: {str(e)}. Please check if Ollama is running at {config.OLLAMA_API_URL}."
        )
    except Exception as e:
        # Diagnostic endpoint: any other problem (HTTP error status, invalid
        # JSON, ...) is reported to the caller instead of producing a 500.
        return failure(f"Failed to connect to Ollama API: {str(e)}")
@app.get("/config")
async def get_config():
    """
    Expose a snapshot of the service configuration.

    Returns:
        A dict with the API host and port, the OpenWebUI and Ollama URLs, the
        default model, the API timeout and the keys of
        ``model_service.AVAILABLE_MODELS``.
    """
    snapshot = {
        "api_host": config.API_HOST,
        "api_port": config.API_PORT,
        "openwebui_url": config.OPENWEBUI_URL,
        "ollama_api_url": config.OLLAMA_API_URL,
        "default_model": config.DEFAULT_MODEL,
        "api_timeout": config.API_TIMEOUT,
    }
    # Only the model identifiers (the mapping's keys) are exposed.
    snapshot["available_models"] = [*model_service.AVAILABLE_MODELS.keys()]
    return snapshot
# Model endpoints
@app.get("/models", response_model=List[ModelInfo])
async def get_available_models():
    """
    List the models known to the model service.

    Returns:
        Whatever ``model_service.get_available_models()`` returns, serialized
        as a list of ``ModelInfo``.
    """
    available = model_service.get_available_models()
    # Debug output of the raw service result.
    print(f"API models: {available}")
    return available
@app.get("/models/{model_id}", response_model=ModelInfo)
async def get_model_info(model_id: str):
    """
    Look up a single model.

    Args:
        model_id: Model ID.

    Returns:
        The model information supplied by the model service.

    Raises:
        HTTPException: 404 when the service returns a falsy result.
    """
    details = model_service.get_model_info(model_id)
    if details:
        return details
    raise HTTPException(status_code=404, detail="Model not found")
# Chat endpoints
@app.post("/chats", response_model=Chat)
async def create_chat(request: ChatRequest):
    """
    Create a chat and return it.

    Args:
        request: Chat creation request.

    Returns:
        The chat fetched back from the chat service by its new ID.
    """
    creation_args = {
        "user_id": request.user_id,
        "title": request.title,
        "model_id": request.model_id,
        "is_team_chat": request.is_team_chat,
    }
    new_chat_id = chat_service.create_chat(**creation_args)
    # Read the chat back so the response carries the stored representation.
    created = chat_service.get_chat(new_chat_id)
    return created
@app.get("/chats/user/{user_id}", response_model=List[Chat])
async def get_user_chats(user_id: str):
    """
    List every chat of one user.

    Args:
        user_id: User ID.

    Returns:
        The result of ``chat_service.get_user_chats`` for that user.
    """
    chats_for_user = chat_service.get_user_chats(user_id)
    return chats_for_user
@app.get("/chats/{chat_id}", response_model=Chat)
async def get_chat(chat_id: str):
    """
    Fetch one chat.

    Args:
        chat_id: Chat ID.

    Returns:
        The chat supplied by the chat service.

    Raises:
        HTTPException: 404 when the service returns a falsy result.
    """
    found = chat_service.get_chat(chat_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Chat not found")
@app.post("/chats/{chat_id}/messages", response_model=Message)
async def send_message(chat_id: str, request: MessageRequest):
    """
    Post a user message to a chat and return the reply.

    Args:
        chat_id: Chat ID.
        request: Message request with optional model parameters.

    Returns:
        The value returned by ``chat_service.get_chat_response``. If an
        unexpected exception occurs, a synthetic non-user message whose
        content describes the error is returned instead.

    Raises:
        HTTPException: 404 when the chat service raises ``ValueError``.
    """
    # Optional generation settings that are passed through unchanged, in the
    # order they are handed to the chat service.
    tunable_fields = (
        "temperature",
        "max_tokens",
        "top_p",
        "frequency_penalty",
        "presence_penalty",
        "stop_sequences",
        "system_prompt",
        "min_p",
        "top_k",
        "repeat_penalty",
        "function_calling",
    )
    try:
        print(f"Processing message for chat {chat_id} from user {request.user_id}")
        # Only the first 50 characters of the message are printed.
        print(f"Message: {request.message[:50]}...")
        print(f"Using RAG: {request.use_rag}")
        print(f"Model parameters: temperature={request.temperature}, max_tokens={request.max_tokens}")
        generation_options = {
            field: getattr(request, field) for field in tunable_fields
        }
        reply = chat_service.get_chat_response(
            chat_id=chat_id,
            message=request.message,
            user_id=request.user_id,
            use_rag=request.use_rag,
            **generation_options,
        )
        print(f"Response received. Length: {len(reply.get('content', ''))}")
        return reply
    except ValueError as e:
        not_found_msg = f"Chat not found: {str(e)}"
        print(f"ERROR: {not_found_msg}")
        raise HTTPException(status_code=404, detail=not_found_msg)
    except Exception as e:
        failure_msg = f"Error processing message: {str(e)}"
        print(f"ERROR: {failure_msg}")
        # Deliberately answer with a well-formed message rather than raising,
        # so the client still receives a response it can render.
        fallback = {
            "id": str(uuid.uuid4()),
            "content": failure_msg,
            "user_id": None,
            "is_user_message": False,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        return fallback
@app.post("/chats/{chat_id}/members/{user_id}")
async def add_team_member(chat_id: str, user_id: str):
    """
    Add a user to a team chat.

    Args:
        chat_id: Chat ID.
        user_id: User ID.

    Returns:
        A success status dict.

    Raises:
        HTTPException: 400 when the chat service reports a falsy result.
    """
    if chat_service.add_team_member(chat_id, user_id):
        return {"status": "success", "message": "Team member added"}
    raise HTTPException(status_code=400, detail="Failed to add team member")
@app.delete("/chats/{chat_id}/members/{user_id}")
async def remove_team_member(chat_id: str, user_id: str):
    """
    Remove a user from a team chat.

    Args:
        chat_id: Chat ID.
        user_id: User ID.

    Returns:
        A success status dict.

    Raises:
        HTTPException: 400 when the chat service reports a falsy result.
    """
    if chat_service.remove_team_member(chat_id, user_id):
        return {"status": "success", "message": "Team member removed"}
    raise HTTPException(status_code=400, detail="Failed to remove team member")
@app.delete("/chats/{chat_id}")
async def delete_chat(chat_id: str):
    """
    Delete a chat.

    Args:
        chat_id: Chat ID.

    Returns:
        A success status dict.

    Raises:
        HTTPException: 404 when the chat service reports a falsy result.
    """
    if chat_service.delete_chat(chat_id):
        return {"status": "success", "message": "Chat deleted"}
    raise HTTPException(status_code=404, detail="Chat not found")
# OpenWebUI channel endpoints
@app.get("/channels")
async def get_openwebui_channels():
    """
    List the OpenWebUI channels.

    Returns:
        The result of ``openwebui_channels.get_channels()``.
    """
    return openwebui_channels.get_channels()
@app.get("/channels/{channel_id}")
async def get_openwebui_channel(channel_id: str):
    """
    Fetch one OpenWebUI channel.

    Args:
        channel_id: Channel ID.

    Returns:
        The channel supplied by ``openwebui_channels.get_channel``.

    Raises:
        HTTPException: 404 when the lookup returns a falsy result.
    """
    found = openwebui_channels.get_channel(channel_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Channel not found")
@app.post("/channels")
async def create_openwebui_channel(name: str, description: str = "", is_private: bool = False):
    """
    Create an OpenWebUI channel.

    Args:
        name: Channel name.
        description: Channel description.
        is_private: Whether the channel is private.

    Returns:
        The channel supplied by ``openwebui_channels.create_channel``.

    Raises:
        HTTPException: 400 when creation returns a falsy result.
    """
    created = openwebui_channels.create_channel(name, description, is_private)
    if created:
        return created
    raise HTTPException(status_code=400, detail="Failed to create channel")
# Request schema for the OpenWebUI channel-message webhook.
class ChannelMessageWebhook(BaseModel):
    """Body of a channel-message webhook call; ``timestamp`` is optional."""

    channel_id: str = Field(default=..., description="Channel ID")
    message: str = Field(default=..., description="Message content")
    user_id: str = Field(default=..., description="User ID")
    timestamp: Optional[str] = Field(
        default=None,
        description="Message timestamp",
    )
def _find_chat_id_for_channel(channel_id):
    """Return the ID of the first team chat linked to *channel_id*, or None."""
    for cid, chat in chat_service.chats.items():
        if chat.get('is_team_chat') and chat.get('openwebui_channel_id') == channel_id:
            return cid
    return None


def _first_matching_trigger(message, triggers):
    """Return the first trigger contained in *message*, ignoring case, or None."""
    message_lower = message.lower()
    return next(
        (trigger for trigger in triggers if trigger.lower() in message_lower),
        None,
    )


@app.post("/webhooks/channel-message")
async def channel_message_webhook(request: ChannelMessageWebhook):
    """
    Webhook endpoint for receiving messages from OpenWebUI channels.

    The message is routed to the team chat whose ``openwebui_channel_id``
    equals the incoming channel ID. Messages from the user ``"ai-assistant"``
    are ignored to avoid reply loops. Unless ``config.AI_RESPOND_TO_ALL`` is
    set, the message must contain one of ``config.AI_TRIGGERS`` (matched
    case-insensitively); that trigger is removed, also case-insensitively,
    before the text is passed to the chat service.

    Args:
        request: Channel message webhook request.

    Returns:
        A dict whose ``status`` is ``"success"`` (with the chat service
        ``response``), ``"skipped"`` (assistant message or no trigger) or
        ``"error"`` (no chat for the channel, or an exception). Errors are
        reported in the body rather than raised.
    """
    import re

    try:
        print(f"Received channel message webhook: {request.channel_id}, {request.user_id}, {request.message}")
        # Find the chat associated with this OpenWebUI channel
        chat_id = _find_chat_id_for_channel(request.channel_id)
        if not chat_id:
            print(f"No chat found for OpenWebUI channel {request.channel_id}")
            return {"status": "error", "message": "No chat found for this channel"}
        # Skip messages from the AI assistant to avoid loops
        if request.user_id == "ai-assistant":
            return {"status": "skipped", "message": "Skipping AI assistant message"}
        processed_message = request.message
        # Triggers only matter when the assistant answers mentions only.
        if not config.AI_RESPOND_TO_ALL:
            trigger = _first_matching_trigger(request.message, config.AI_TRIGGERS)
            if trigger is None:
                print(f"No AI mention found in message, skipping: {request.message[:50]}...")
                return {"status": "skipped", "message": "No AI mention found in message"}
            # Detection ignores case, so removal must as well; otherwise
            # "@AI hello" with trigger "@ai" would be forwarded unchanged.
            processed_message = re.sub(
                re.escape(trigger), "", request.message, flags=re.IGNORECASE
            ).strip()
        # Process the message and generate a response
        response = chat_service.get_chat_response(
            chat_id=chat_id,
            message=processed_message,
            user_id=request.user_id
        )
        return {"status": "success", "message": "Message processed", "response": response}
    except Exception as e:
        print(f"Error processing channel message webhook: {str(e)}")
        return {"status": "error", "message": f"Error processing message: {str(e)}"}