""" FastAPI application for the AI service. This service acts as a backend for OpenWebUI, providing OpenWebUI-compatible API endpoints. """ from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field from typing import List, Optional import uuid from datetime import datetime, timezone from ai_service.models.model_service import model_service from ai_service.models.chat_service import chat_service from ai_service.openwebui_api import router as openwebui_router from ai_service.config import config # Create FastAPI app app = FastAPI( title="AI Service API", description="Backend API for OpenWebUI", version="1.0.0" ) # Add CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], # Allow all origins allow_credentials=True, allow_methods=["*"], # Allow all methods allow_headers=["*"], # Allow all headers ) # Include OpenWebUI-compatible API routes app.include_router(openwebui_router, prefix="/api") # Include Ollama proxy routes app.include_router(openwebui_router, prefix="/ollama") # Define API models for health check class HealthResponse(BaseModel): """Response model for health check.""" status: str = Field(..., description="Health status") class ModelInfo(BaseModel): """Model for model information.""" id: str = Field(..., description="Model ID") name: str = Field(..., description="Model name") description: str = Field(..., description="Model description") provider: str = Field(..., description="Model provider") max_tokens: int = Field(..., description="Maximum tokens") is_default: bool = Field(..., description="Whether this is the default model") class ChatRequest(BaseModel): """Request model for creating a chat.""" user_id: str = Field(..., description="User ID") title: Optional[str] = Field(None, description="Chat title") model_id: Optional[str] = Field(None, description="Model ID") is_team_chat: bool = Field(False, description="Whether this is a team chat") class MessageRequest(BaseModel): """Request model for sending a message.""" message: str = Field(..., description="Message content") user_id: str = Field(..., description="User ID") use_rag: bool = Field(False, description="Whether to use RAG") # Model parameters temperature: Optional[float] = Field(None, description="Controls randomness: higher values mean more random completions") max_tokens: Optional[int] = Field(None, description="Maximum number of tokens to generate") top_p: Optional[float] = Field(None, description="Nucleus sampling parameter") frequency_penalty: Optional[float] = Field(None, description="Penalizes repeated tokens") presence_penalty: Optional[float] = Field(None, description="Penalizes repeated topics") stop_sequences: Optional[List[str]] = Field(None, description="Sequences where the API will stop generating") system_prompt: Optional[str] = Field(None, description="System prompt to guide the model's behavior") # Additional advanced parameters min_p: Optional[float] = Field(None, description="Minimum probability threshold for token selection") top_k: Optional[int] = Field(None, description="Only sample from the top k tokens") repeat_penalty: Optional[float] = Field(None, description="Penalty for repeating tokens") function_calling: Optional[bool] = Field(None, description="Whether to enable function calling") class Message(BaseModel): """Model for a message.""" id: str = Field(..., description="Message ID") content: str = Field(..., description="Message content") user_id: Optional[str] = Field(None, description="User ID") is_user_message: bool = Field(..., description="Whether this is a user message") timestamp: str = Field(..., description="Message timestamp") class Chat(BaseModel): """Model for a chat.""" id: str = Field(..., description="Chat ID") title: str = Field(..., description="Chat title") user_id: str = Field(..., description="User ID") model_id: str = Field(..., description="Model ID") is_team_chat: bool = Field(..., description="Whether this is a team chat") created_at: str = Field(..., description="Creation timestamp") updated_at: str = Field(..., description="Update timestamp") messages: List[Message] = Field(..., description="Chat messages") team_members: List[str] = Field(..., description="Team members") # Define API endpoints @app.get("/health", response_model=HealthResponse) async def health_check(): """ Health check endpoint. Returns: Health status. """ return {"status": "healthy"} @app.get("/test-ollama") async def test_ollama_connection(): """ Test the connection to the Ollama API. Returns: Connection status and available models from Ollama. """ import requests try: # Try to connect to Ollama API response = requests.get(f"{config.OLLAMA_API_URL}/api/tags", timeout=config.API_TIMEOUT) response.raise_for_status() # Return the models from Ollama return { "status": "success", "message": "Successfully connected to Ollama API", "ollama_url": config.OLLAMA_API_URL, "models": response.json() } except requests.exceptions.Timeout as e: return { "status": "error", "message": f"Timeout connecting to Ollama API: {str(e)}. The request exceeded the {config.API_TIMEOUT} second timeout.", "ollama_url": config.OLLAMA_API_URL } except requests.exceptions.ConnectionError as e: return { "status": "error", "message": f"Connection error to Ollama API: {str(e)}. Please check if Ollama is running at {config.OLLAMA_API_URL}.", "ollama_url": config.OLLAMA_API_URL } except Exception as e: return { "status": "error", "message": f"Failed to connect to Ollama API: {str(e)}", "ollama_url": config.OLLAMA_API_URL } @app.post("/test-chat") async def test_chat_completion(): """ Test the chat completion with a simple prompt. Returns: Model response. """ try: # Use the model service directly response = model_service.generate_response( model_id=config.DEFAULT_MODEL, prompt="Hello, how are you?", context=[], use_rag=False ) return { "status": "success", "model": config.DEFAULT_MODEL, "response": response, "ollama_url": config.OLLAMA_API_URL } except Exception as e: return { "status": "error", "message": f"Failed to get chat completion: {str(e)}", "ollama_url": config.OLLAMA_API_URL } @app.post("/test-ollama-direct") async def test_ollama_direct(): """ Test the Ollama API directly with a simple chat request. Returns: Raw Ollama API response. """ import requests try: # Prepare a simple chat request request_json = { "model": config.DEFAULT_MODEL, "messages": [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello, how are you?"} ], "stream": False } # Make the API call to Ollama print(f"Sending direct request to Ollama API at: {config.OLLAMA_API_URL}/api/chat") response = requests.post( f"{config.OLLAMA_API_URL}/api/chat", headers={"Content-Type": "application/json"}, json=request_json, timeout=config.API_TIMEOUT ) response.raise_for_status() result = response.json() return { "status": "success", "ollama_url": config.OLLAMA_API_URL, "request": request_json, "response": result } except requests.exceptions.Timeout as e: return { "status": "error", "message": f"Timeout connecting to Ollama API: {str(e)}. The request exceeded the {config.API_TIMEOUT} second timeout.", "ollama_url": config.OLLAMA_API_URL } except requests.exceptions.ConnectionError as e: return { "status": "error", "message": f"Connection error to Ollama API: {str(e)}. Please check if Ollama is running at {config.OLLAMA_API_URL}.", "ollama_url": config.OLLAMA_API_URL } except Exception as e: return { "status": "error", "message": f"Failed to connect to Ollama API: {str(e)}", "ollama_url": config.OLLAMA_API_URL } @app.get("/config") async def get_config(): """ Get the current configuration. Returns: Current configuration settings. """ return { "api_host": config.API_HOST, "api_port": config.API_PORT, "openwebui_url": config.OPENWEBUI_URL, "ollama_api_url": config.OLLAMA_API_URL, "default_model": config.DEFAULT_MODEL, "api_timeout": config.API_TIMEOUT, "available_models": list(model_service.AVAILABLE_MODELS.keys()) } # Model endpoints @app.get("/models", response_model=List[ModelInfo]) async def get_available_models(): """ Get available models. Returns: List of model information. """ models = model_service.get_available_models() # Debug log print(f"API models: {models}") return models @app.get("/models/{model_id}", response_model=ModelInfo) async def get_model_info(model_id: str): """ Get information about a model. Args: model_id: Model ID. Returns: Model information. """ model_info = model_service.get_model_info(model_id) if not model_info: raise HTTPException(status_code=404, detail="Model not found") return model_info # Chat endpoints @app.post("/chats", response_model=Chat) async def create_chat(request: ChatRequest): """ Create a new chat. Args: request: Chat creation request. Returns: Created chat. """ chat_id = chat_service.create_chat( user_id=request.user_id, title=request.title, model_id=request.model_id, is_team_chat=request.is_team_chat ) return chat_service.get_chat(chat_id) @app.get("/chats/user/{user_id}", response_model=List[Chat]) async def get_user_chats(user_id: str): """ Get all chats for a user. Args: user_id: User ID. Returns: List of chats. """ return chat_service.get_user_chats(user_id) @app.get("/chats/{chat_id}", response_model=Chat) async def get_chat(chat_id: str): """ Get a chat by ID. Args: chat_id: Chat ID. Returns: Chat information. """ chat = chat_service.get_chat(chat_id) if not chat: raise HTTPException(status_code=404, detail="Chat not found") return chat @app.post("/chats/{chat_id}/messages", response_model=Message) async def send_message(chat_id: str, request: MessageRequest): """ Send a message to a chat. Args: chat_id: Chat ID. request: Message request with optional model parameters. Returns: Bot response message. """ try: print(f"Processing message for chat {chat_id} from user {request.user_id}") print(f"Message: {request.message[:50]}...") # Print first 50 chars of message print(f"Using RAG: {request.use_rag}") print(f"Model parameters: temperature={request.temperature}, max_tokens={request.max_tokens}") # Extract model parameters from the request response = chat_service.get_chat_response( chat_id=chat_id, message=request.message, user_id=request.user_id, use_rag=request.use_rag, temperature=request.temperature, max_tokens=request.max_tokens, top_p=request.top_p, frequency_penalty=request.frequency_penalty, presence_penalty=request.presence_penalty, stop_sequences=request.stop_sequences, system_prompt=request.system_prompt, min_p=request.min_p, top_k=request.top_k, repeat_penalty=request.repeat_penalty, function_calling=request.function_calling ) print(f"Response received. Length: {len(response.get('content', ''))}") return response except ValueError as e: error_msg = f"Chat not found: {str(e)}" print(f"ERROR: {error_msg}") raise HTTPException(status_code=404, detail=error_msg) except Exception as e: error_msg = f"Error processing message: {str(e)}" print(f"ERROR: {error_msg}") # Return an error message instead of raising an exception # This ensures the client gets a proper response return { "id": str(uuid.uuid4()), "content": f"Error processing message: {str(e)}", "user_id": None, "is_user_message": False, "timestamp": datetime.now(timezone.utc).isoformat() } @app.post("/chats/{chat_id}/members/{user_id}") async def add_team_member(chat_id: str, user_id: str): """ Add a user to a team chat. Args: chat_id: Chat ID. user_id: User ID. Returns: Addition status. """ success = chat_service.add_team_member(chat_id, user_id) if not success: raise HTTPException(status_code=400, detail="Failed to add team member") return {"status": "success", "message": "Team member added"} @app.delete("/chats/{chat_id}/members/{user_id}") async def remove_team_member(chat_id: str, user_id: str): """ Remove a user from a team chat. Args: chat_id: Chat ID. user_id: User ID. Returns: Removal status. """ success = chat_service.remove_team_member(chat_id, user_id) if not success: raise HTTPException(status_code=400, detail="Failed to remove team member") return {"status": "success", "message": "Team member removed"} @app.delete("/chats/{chat_id}") async def delete_chat(chat_id: str): """ Delete a chat. Args: chat_id: Chat ID. Returns: Deletion status. """ success = chat_service.delete_chat(chat_id) if not success: raise HTTPException(status_code=404, detail="Chat not found") return {"status": "success", "message": "Chat deleted"}