"""
FastAPI application for the AI service.

This service acts as a backend for OpenWebUI, providing OpenWebUI-compatible API endpoints.

The service supports document-based question answering using OpenWebUI's knowledge database:
- Set use_rag=True in API requests to enable Retrieval Augmented Generation
- When enabled, the service will use OpenWebUI's knowledge database to find relevant information
- Documents uploaded to OpenWebUI will be used to augment the model's responses
"""

import re
import uuid
from datetime import datetime, timezone
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from ai_service.models.model_service import model_service
from ai_service.models.chat_service import chat_service
from ai_service.openwebui_api import router as openwebui_router
from ai_service.openwebui_channels import openwebui_channels
from ai_service.config import config

# Create FastAPI app
app = FastAPI(
    title="AI Service API",
    description="Backend API for OpenWebUI",
    version="1.0.0"
)

# Add CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive - confirm this is intended outside of local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allow all methods
    allow_headers=["*"],  # Allow all headers
)

# Include OpenWebUI-compatible API routes
app.include_router(openwebui_router, prefix="/api")

# Include Ollama proxy routes (the same router, mounted under a second prefix)
app.include_router(openwebui_router, prefix="/ollama")


# Register webhook for channel messages on startup
@app.on_event("startup")
async def startup_event():
    """
    Register webhook for channel messages on startup.

    The webhook URL is built from config.PUBLIC_URL when it is set, otherwise
    from config.API_HOST and config.API_PORT. The outcome of the registration
    is only printed; a falsy result does not abort startup.
    """
    # Get the public URL of this service
    service_url = config.PUBLIC_URL or f"http://{config.API_HOST}:{config.API_PORT}"

    # Strip trailing slashes so a PUBLIC_URL such as "https://host/" does not
    # yield "https://host//webhooks/channel-message".
    webhook_url = f"{str(service_url).rstrip('/')}/webhooks/channel-message"

    # Register webhook
    print(f"Registering webhook for channel messages: {webhook_url}")
    success = openwebui_channels.register_webhook(webhook_url)

    if success:
        print("Successfully registered webhook for channel messages")
    else:
        print("Failed to register webhook for channel messages")


# Define API models for health check
class HealthResponse(BaseModel):
    """Response model for health check."""
    status: str = Field(..., description="Health status")


class ModelInfo(BaseModel):
    """Model for model information."""
    id: str = Field(..., description="Model ID")
    name: str = Field(..., description="Model name")
    description: str = Field(..., description="Model description")
    provider: str = Field(..., description="Model provider")
    max_tokens: int = Field(..., description="Maximum tokens")
    is_default: bool = Field(..., description="Whether this is the default model")


class ChatRequest(BaseModel):
    """Request model for creating a chat."""
    user_id: str = Field(..., description="User ID")
    title: Optional[str] = Field(None, description="Chat title")
    model_id: Optional[str] = Field(None, description="Model ID")
    is_team_chat: bool = Field(False, description="Whether this is a team chat")


class MessageRequest(BaseModel):
    """Request model for sending a message."""
    message: str = Field(..., description="Message content")
    user_id: str = Field(..., description="User ID")
    use_rag: bool = Field(False, description="Whether to use RAG")

    # Model parameters
    temperature: Optional[float] = Field(None, description="Controls randomness: higher values mean more random completions")
    max_tokens: Optional[int] = Field(None, description="Maximum number of tokens to generate")
    top_p: Optional[float] = Field(None, description="Nucleus sampling parameter")
    frequency_penalty: Optional[float] = Field(None, description="Penalizes repeated tokens")
    presence_penalty: Optional[float] = Field(None, description="Penalizes repeated topics")
    stop_sequences: Optional[List[str]] = Field(None, description="Sequences where the API will stop generating")
    system_prompt: Optional[str] = Field(None, description="System prompt to guide the model's behavior")

    # Additional advanced parameters
    min_p: Optional[float] = Field(None, description="Minimum probability threshold for token selection")
    top_k: Optional[int] = Field(None, description="Only sample from the top k tokens")
    repeat_penalty: Optional[float] = Field(None, description="Penalty for repeating tokens")
    function_calling: Optional[bool] = Field(None, description="Whether to enable function calling")


class Message(BaseModel):
    """Model for a message."""
    id: str = Field(..., description="Message ID")
    content: str = Field(..., description="Message content")
    user_id: Optional[str] = Field(None, description="User ID")
    is_user_message: bool = Field(..., description="Whether this is a user message")
    timestamp: str = Field(..., description="Message timestamp")


class Chat(BaseModel):
    """Model for a chat."""
    id: str = Field(..., description="Chat ID")
    title: str = Field(..., description="Chat title")
    user_id: str = Field(..., description="User ID")
    model_id: str = Field(..., description="Model ID")
    is_team_chat: bool = Field(..., description="Whether this is a team chat")
    created_at: str = Field(..., description="Creation timestamp")
    updated_at: str = Field(..., description="Update timestamp")
    messages: List[Message] = Field(..., description="Chat messages")
    team_members: List[str] = Field(..., description="Team members")


def _ollama_error_response(exc: Exception) -> dict:
    """
    Build the error payload shared by the Ollama diagnostic endpoints.

    Args:
        exc: Exception raised while talking to the Ollama API.

    Returns:
        Dict with "status", "message" and "ollama_url" keys.
    """
    # Imported lazily, like the endpoints that call this helper.
    import requests

    # Order matters: Timeout is checked before ConnectionError, mirroring the
    # precedence of the except clauses this helper replaces.
    if isinstance(exc, requests.exceptions.Timeout):
        message = f"Timeout connecting to Ollama API: {exc}. The request exceeded the {config.API_TIMEOUT} second timeout."
    elif isinstance(exc, requests.exceptions.ConnectionError):
        message = f"Connection error to Ollama API: {exc}. Please check if Ollama is running at {config.OLLAMA_API_URL}."
    else:
        message = f"Failed to connect to Ollama API: {exc}"

    return {
        "status": "error",
        "message": message,
        "ollama_url": config.OLLAMA_API_URL
    }


# Define API endpoints
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Health check endpoint.

    Returns:
        Health status.
    """
    return {"status": "healthy"}


@app.get("/test-ollama")
async def test_ollama_connection():
    """
    Test the connection to the Ollama API.

    Failures are reported in the response body ("status": "error") rather
    than as an HTTP error.

    Returns:
        Connection status and available models from Ollama.
    """
    import requests

    # NOTE(review): requests is synchronous, so this call blocks the event
    # loop for up to config.API_TIMEOUT seconds - confirm that is acceptable.
    try:
        # Try to connect to Ollama API
        response = requests.get(f"{config.OLLAMA_API_URL}/api/tags", timeout=config.API_TIMEOUT)
        response.raise_for_status()

        # Return the models from Ollama
        return {
            "status": "success",
            "message": "Successfully connected to Ollama API",
            "ollama_url": config.OLLAMA_API_URL,
            "models": response.json()
        }
    except Exception as e:
        return _ollama_error_response(e)


@app.post("/test-chat")
async def test_chat_completion():
    """
    Test the chat completion with a simple prompt.

    Returns:
        Model response, or an error payload if generation fails.
    """
    try:
        # Use the model service directly
        response = model_service.generate_response(
            model_id=config.DEFAULT_MODEL,
            prompt="Hello, how are you?",
            context=[],
            use_rag=False
        )

        return {
            "status": "success",
            "model": config.DEFAULT_MODEL,
            "response": response,
            "ollama_url": config.OLLAMA_API_URL
        }
    except Exception as e:
        return {
            "status": "error",
            "message": f"Failed to get chat completion: {str(e)}",
            "ollama_url": config.OLLAMA_API_URL
        }


@app.post("/test-rag")
async def test_rag_completion(query: str = "What information do you have in your knowledge database?"):
    """
    Test the RAG (Retrieval Augmented Generation) functionality with a query.

    This endpoint tests the integration with OpenWebUI's knowledge database.

    Args:
        query: The question to ask about documents in the knowledge database.

    Returns:
        Model response using RAG, or an error payload if generation fails.
    """
    try:
        # Use the model service directly with RAG enabled
        response = model_service.generate_response(
            model_id=config.DEFAULT_MODEL,
            prompt=query,
            context=[],
            use_rag=True  # Enable RAG
        )

        return {
            "status": "success",
            "model": config.DEFAULT_MODEL,
            "query": query,
            "use_rag": True,
            "response": response,
            "openwebui_url": config.OPENWEBUI_URL
        }
    except Exception as e:
        return {
            "status": "error",
            "message": f"Failed to get RAG completion: {str(e)}",
            "openwebui_url": config.OPENWEBUI_URL
        }


@app.post("/test-ollama-direct")
async def test_ollama_direct():
    """
    Test the Ollama API directly with a simple chat request.

    Failures are reported in the response body ("status": "error") rather
    than as an HTTP error.

    Returns:
        Raw Ollama API response together with the request that was sent.
    """
    import requests

    try:
        # Prepare a simple chat request
        request_json = {
            "model": config.DEFAULT_MODEL,
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello, how are you?"}
            ],
            "stream": False
        }

        # Make the API call to Ollama
        print(f"Sending direct request to Ollama API at: {config.OLLAMA_API_URL}/api/chat")
        response = requests.post(
            f"{config.OLLAMA_API_URL}/api/chat",
            headers={"Content-Type": "application/json"},
            json=request_json,
            timeout=config.API_TIMEOUT
        )
        response.raise_for_status()
        result = response.json()

        return {
            "status": "success",
            "ollama_url": config.OLLAMA_API_URL,
            "request": request_json,
            "response": result
        }
    except Exception as e:
        return _ollama_error_response(e)


@app.get("/config")
async def get_config():
    """
    Get the current configuration.

    Returns:
        Current configuration settings.
    """
    return {
        "api_host": config.API_HOST,
        "api_port": config.API_PORT,
        "openwebui_url": config.OPENWEBUI_URL,
        "ollama_api_url": config.OLLAMA_API_URL,
        "default_model": config.DEFAULT_MODEL,
        "api_timeout": config.API_TIMEOUT,
        "available_models": list(model_service.AVAILABLE_MODELS.keys())
    }


# Model endpoints
@app.get("/models", response_model=List[ModelInfo])
async def get_available_models():
    """
    Get available models.

    Returns:
        List of model information.
    """
    models = model_service.get_available_models()
    # Debug log
    print(f"API models: {models}")
    return models


@app.get("/models/{model_id}", response_model=ModelInfo)
async def get_model_info(model_id: str):
    """
    Get information about a model.

    Args:
        model_id: Model ID.

    Returns:
        Model information.

    Raises:
        HTTPException: 404 if the model service returns nothing for model_id.
    """
    model_info = model_service.get_model_info(model_id)
    if not model_info:
        raise HTTPException(status_code=404, detail="Model not found")
    return model_info


# Chat endpoints
@app.post("/chats", response_model=Chat)
async def create_chat(request: ChatRequest):
    """
    Create a new chat.

    Args:
        request: Chat creation request.

    Returns:
        Created chat.
    """
    chat_id = chat_service.create_chat(
        user_id=request.user_id,
        title=request.title,
        model_id=request.model_id,
        is_team_chat=request.is_team_chat
    )
    return chat_service.get_chat(chat_id)


@app.get("/chats/user/{user_id}", response_model=List[Chat])
async def get_user_chats(user_id: str):
    """
    Get all chats for a user.

    Args:
        user_id: User ID.

    Returns:
        List of chats.
    """
    return chat_service.get_user_chats(user_id)


@app.get("/chats/{chat_id}", response_model=Chat)
async def get_chat(chat_id: str):
    """
    Get a chat by ID.

    Args:
        chat_id: Chat ID.

    Returns:
        Chat information.

    Raises:
        HTTPException: 404 if the chat service returns nothing for chat_id.
    """
    chat = chat_service.get_chat(chat_id)
    if not chat:
        raise HTTPException(status_code=404, detail="Chat not found")
    return chat


@app.post("/chats/{chat_id}/messages", response_model=Message)
async def send_message(chat_id: str, request: MessageRequest):
    """
    Send a message to a chat.

    Args:
        chat_id: Chat ID.
        request: Message request with optional model parameters.

    Returns:
        Bot response message. If processing fails with anything other than a
        ValueError, a synthetic message carrying the error text is returned.

    Raises:
        HTTPException: 404 when the chat service raises ValueError.
    """
    try:
        print(f"Processing message for chat {chat_id} from user {request.user_id}")
        print(f"Message: {request.message[:50]}...")  # Print first 50 chars of message
        print(f"Using RAG: {request.use_rag}")
        print(f"Model parameters: temperature={request.temperature}, max_tokens={request.max_tokens}")

        # Extract model parameters from the request
        response = chat_service.get_chat_response(
            chat_id=chat_id,
            message=request.message,
            user_id=request.user_id,
            use_rag=request.use_rag,
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            top_p=request.top_p,
            frequency_penalty=request.frequency_penalty,
            presence_penalty=request.presence_penalty,
            stop_sequences=request.stop_sequences,
            system_prompt=request.system_prompt,
            min_p=request.min_p,
            top_k=request.top_k,
            repeat_penalty=request.repeat_penalty,
            function_calling=request.function_calling
        )

        print(f"Response received. Length: {len(response.get('content', ''))}")
        return response
    except ValueError as e:
        error_msg = f"Chat not found: {str(e)}"
        print(f"ERROR: {error_msg}")
        raise HTTPException(status_code=404, detail=error_msg) from e
    except Exception as e:
        error_msg = f"Error processing message: {str(e)}"
        print(f"ERROR: {error_msg}")
        # Return an error message instead of raising an exception
        # This ensures the client gets a proper response
        return {
            "id": str(uuid.uuid4()),
            "content": f"Error processing message: {str(e)}",
            "user_id": None,
            "is_user_message": False,
            "timestamp": datetime.now(timezone.utc).isoformat()
        }


@app.post("/chats/{chat_id}/members/{user_id}")
async def add_team_member(chat_id: str, user_id: str):
    """
    Add a user to a team chat.

    Args:
        chat_id: Chat ID.
        user_id: User ID.

    Returns:
        Addition status.

    Raises:
        HTTPException: 400 if the chat service reports failure.
    """
    success = chat_service.add_team_member(chat_id, user_id)
    if not success:
        raise HTTPException(status_code=400, detail="Failed to add team member")
    return {"status": "success", "message": "Team member added"}


@app.delete("/chats/{chat_id}/members/{user_id}")
async def remove_team_member(chat_id: str, user_id: str):
    """
    Remove a user from a team chat.

    Args:
        chat_id: Chat ID.
        user_id: User ID.

    Returns:
        Removal status.

    Raises:
        HTTPException: 400 if the chat service reports failure.
    """
    success = chat_service.remove_team_member(chat_id, user_id)
    if not success:
        raise HTTPException(status_code=400, detail="Failed to remove team member")
    return {"status": "success", "message": "Team member removed"}


@app.delete("/chats/{chat_id}")
async def delete_chat(chat_id: str):
    """
    Delete a chat.

    Args:
        chat_id: Chat ID.

    Returns:
        Deletion status.

    Raises:
        HTTPException: 404 if the chat service reports failure.
    """
    success = chat_service.delete_chat(chat_id)
    if not success:
        raise HTTPException(status_code=404, detail="Chat not found")
    return {"status": "success", "message": "Chat deleted"}


# OpenWebUI Channels endpoints
@app.get("/channels")
async def get_openwebui_channels():
    """
    Get all OpenWebUI channels.

    Returns:
        List of channels.
    """
    channels = openwebui_channels.get_channels()
    return channels


@app.get("/channels/{channel_id}")
async def get_openwebui_channel(channel_id: str):
    """
    Get an OpenWebUI channel by ID.

    Args:
        channel_id: Channel ID.

    Returns:
        Channel information.

    Raises:
        HTTPException: 404 if no channel is returned for channel_id.
    """
    channel = openwebui_channels.get_channel(channel_id)
    if not channel:
        raise HTTPException(status_code=404, detail="Channel not found")
    return channel


@app.post("/channels")
async def create_openwebui_channel(name: str, description: str = "", is_private: bool = False):
    """
    Create a new OpenWebUI channel.

    Args:
        name: Channel name.
        description: Channel description.
        is_private: Whether the channel is private.

    Returns:
        Created channel.

    Raises:
        HTTPException: 400 if channel creation returns nothing.
    """
    channel = openwebui_channels.create_channel(name, description, is_private)
    if not channel:
        raise HTTPException(status_code=400, detail="Failed to create channel")
    return channel


# Webhook endpoint for OpenWebUI channel messages
class ChannelMessageWebhook(BaseModel):
    """Model for channel message webhook."""
    channel_id: str = Field(..., description="Channel ID")
    message: str = Field(..., description="Message content")
    user_id: str = Field(..., description="User ID")
    timestamp: Optional[str] = Field(None, description="Message timestamp")


@app.post("/webhooks/channel-message")
async def channel_message_webhook(request: ChannelMessageWebhook):
    """
    Webhook endpoint for receiving messages from OpenWebUI channels.

    This endpoint is called by OpenWebUI when a message is sent in a channel.
    The AI service will process the message and respond in the channel.

    Unless config.AI_RESPOND_TO_ALL is set, only messages containing one of
    config.AI_TRIGGERS (matched case-insensitively) are processed, and the
    matched trigger is removed from the text before it is passed on.

    Args:
        request: Channel message webhook request.

    Returns:
        Processing status: "success" with the response, "skipped", or "error".
        Exceptions are reported in the body rather than raised.
    """
    try:
        print(f"Received channel message webhook: {request.channel_id}, {request.user_id}, {request.message}")

        # Find the chat associated with this OpenWebUI channel
        chat_id = None
        for cid, chat in chat_service.chats.items():
            if chat.get('is_team_chat') and chat.get('openwebui_channel_id') == request.channel_id:
                chat_id = cid
                break

        if not chat_id:
            print(f"No chat found for OpenWebUI channel {request.channel_id}")
            return {"status": "error", "message": "No chat found for this channel"}

        # Skip messages from the AI assistant to avoid loops
        if request.user_id == "ai-assistant":
            return {"status": "skipped", "message": "Skipping AI assistant message"}

        processed_message = request.message

        # Check if we should respond to all messages or only to mentions
        if not config.AI_RESPOND_TO_ALL:
            # Find the first configured trigger mentioned in the message
            message_lower = request.message.lower()
            matched_trigger = next(
                (trigger for trigger in config.AI_TRIGGERS if trigger.lower() in message_lower),
                None
            )

            # If no trigger is found, skip processing
            if matched_trigger is None:
                print(f"No AI mention found in message, skipping: {request.message[:50]}...")
                return {"status": "skipped", "message": "No AI mention found in message"}

            # Extract the actual message content (remove the trigger).
            # Detection above ignores case, so removal must as well; a
            # case-sensitive replace would leave "@AI" in place for "@ai".
            processed_message = re.sub(
                re.escape(matched_trigger), "", request.message, flags=re.IGNORECASE
            ).strip()

        # Process the message and generate a response
        response = chat_service.get_chat_response(
            chat_id=chat_id,
            message=processed_message,
            user_id=request.user_id
        )

        return {"status": "success", "message": "Message processed", "response": response}
    except Exception as e:
        print(f"Error processing channel message webhook: {str(e)}")
        return {"status": "error", "message": f"Error processing message: {str(e)}"}