Files

201 lines
5.7 KiB
Python
Raw Permalink Normal View History

"""
OpenWebUI-compatible API endpoints for the AI service.
"""
import json
import logging
import time
import uuid
from datetime import datetime, timezone
from typing import List, Dict, Any, Optional, Union

from fastapi import APIRouter, Depends, HTTPException, Header, Request
from fastapi.responses import StreamingResponse

from ai_service.models.model_service import model_service
from ai_service.models.chat_service import chat_service
from ai_service.models.model_parameters import ModelParameters
# Shared APIRouter instance; the endpoint functions in this module register
# themselves on it through their route decorators.
router: APIRouter = APIRouter()
# Models endpoint
@router.get("/models", response_model=List[Dict[str, Any]])
async def get_models():
    """
    Get available models in OpenWebUI-compatible format.

    Every entry returned by ``model_service.get_available_models()`` is mapped,
    via its ``"id"`` key, to an OpenAI-style model object. When the service
    reports no models (an empty result or ``None``), a single placeholder entry
    for ``"llama3.1"`` is returned so the client always receives a non-empty
    list.

    Returns:
        A list of dicts with the keys ``id``, ``object``, ``created``,
        ``owned_by``, ``permission``, ``root`` and ``parent``.
    """
    # One timestamp for the whole listing rather than one per model.
    created = int(time.time())

    def _to_openwebui(model_id):
        """Build one OpenWebUI/OpenAI model object for ``model_id``."""
        return {
            "id": model_id,
            "object": "model",
            "created": created,
            "owned_by": "user",
            "permission": [],
            "root": model_id,
            "parent": None,
        }

    # Treat a None result like an empty one instead of failing on iteration.
    models = model_service.get_available_models() or []
    openwebui_models = [_to_openwebui(model["id"]) for model in models]

    # Debug-level logging instead of print(): this runs on every request and
    # should not write to stdout unless debugging is enabled. The guard skips
    # the json.dumps cost when DEBUG is off.
    logger = logging.getLogger(__name__)
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("OpenWebUI models: %s", json.dumps(openwebui_models, indent=2))

    # Return at least one default model if none are found.
    return openwebui_models or [_to_openwebui("llama3.1")]
# Chat completions endpoint (OpenAI-compatible)
def _content_to_text(content):
    """Return message ``content`` as plain text.

    OpenAI-style clients may send ``content`` as a list of parts such as
    ``{"type": "text", "text": "..."}``. Text parts are joined with newlines
    and non-text parts are dropped. Any non-list value is returned unchanged.
    """
    if isinstance(content, list):
        return "\n".join(
            str(part.get("text") or "")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
    return content


def _sse_event(payload):
    """Serialize ``payload`` as one Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


@router.post("/chat/completions")
async def chat_completions(request: Request):
    """
    OpenAI-compatible chat completions endpoint.

    Only the most recent ``"user"`` message in ``messages`` is forwarded to
    ``chat_service.get_chat_response``; earlier turns are not sent. The
    sampling options ``temperature``, ``max_tokens``, ``top_p``, ``top_k``,
    ``frequency_penalty``, ``presence_penalty``, ``repeat_penalty`` and
    ``stop`` are passed through, and OpenWebUI's ``use_knowledge`` flag is
    mapped to ``use_rag``.

    When ``stream`` is truthy, the finished answer is sent as Server-Sent
    Events in the OpenAI streaming shape: ``chat.completion.chunk`` objects
    carrying a ``delta``, followed by ``data: [DONE]``. Otherwise a single
    ``chat.completion`` object is returned. Token usage is not tracked and is
    reported as ``-1``.

    Raises:
        HTTPException: 400 if the body is not a JSON object, ``messages`` is
            not a list, or no non-empty user message is present.
    """
    # Parse the request body; malformed input is a client error, not a 500.
    try:
        body = await request.json()
    except ValueError as exc:  # json.JSONDecodeError and UnicodeDecodeError
        raise HTTPException(status_code=400, detail="Invalid JSON body") from exc
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    # Extract parameters (an explicit null model falls back to the default).
    model_id = body.get("model") or "llama3.1"
    messages = body.get("messages") or []
    stream = body.get("stream", False)
    temperature = body.get("temperature")
    max_tokens = body.get("max_tokens")
    top_p = body.get("top_p")
    top_k = body.get("top_k")
    frequency_penalty = body.get("frequency_penalty")
    presence_penalty = body.get("presence_penalty")
    repeat_penalty = body.get("repeat_penalty")
    stop = body.get("stop")

    # Map OpenWebUI's use_knowledge flag to chat_service's use_rag parameter.
    use_rag = body.get("use_knowledge", False)

    if not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="'messages' must be a list")

    # Validate before creating anything, so a bad request leaves no orphan chat.
    # The last user message in the array is the one that is answered.
    user_message = next(
        (
            _content_to_text(msg.get("content"))
            for msg in reversed(messages)
            if isinstance(msg, dict) and msg.get("role") == "user"
        ),
        None,
    )
    if not user_message:
        raise HTTPException(status_code=400, detail="No user message found")

    # OpenAI accepts either a single stop string or a list of them.
    if isinstance(stop, list):
        stop_sequences = stop
    elif stop:
        stop_sequences = [stop]
    else:
        stop_sequences = None

    # Create a unique chat ID
    chat_id = str(uuid.uuid4())
    # Create a user ID (in a real implementation, this would come from authentication)
    user_id = "openwebui-user"

    # NOTE(review): create_chat's return value is discarded, and the locally
    # generated chat_id above is what get_chat_response receives. Confirm the
    # two are meant to be unrelated.
    chat_service.create_chat(user_id=user_id, title="API Chat", model_id=model_id)

    # NOTE(review): this is a synchronous call inside an async handler; if it
    # blocks on model inference it stalls the event loop. Consider offloading.
    response = chat_service.get_chat_response(
        chat_id=chat_id,
        message=user_message,
        user_id=user_id,
        use_rag=use_rag,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        top_k=top_k,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        repeat_penalty=repeat_penalty,
        stop_sequences=stop_sequences,
    )
    content = response.get("content", "")

    completion_id = f"chatcmpl-{str(uuid.uuid4())[:8]}"
    created = int(time.time())

    if stream:
        # Streaming clients read choices[].delta from chat.completion.chunk
        # objects, not choices[].message, so a plain completion object would
        # show up as an empty answer.
        chunk_header = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_id,
        }

        async def generate_stream():
            # The answer is already complete: send it as a single content delta,
            # then a terminating chunk carrying finish_reason, then [DONE].
            yield _sse_event({
                **chunk_header,
                "choices": [{
                    "index": 0,
                    "delta": {"role": "assistant", "content": content},
                    "finish_reason": None,
                }],
            })
            yield _sse_event({
                **chunk_header,
                "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
            })
            yield "data: [DONE]\n\n"

        return StreamingResponse(generate_stream(), media_type="text/event-stream")

    # Format response in OpenAI-compatible format
    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": model_id,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": content,
                },
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": -1,  # We don't track tokens
            "completion_tokens": -1,
            "total_tokens": -1,
        },
    }
# Health check endpoint
@router.get("/health")
async def health_check():
    """
    Liveness probe.

    Performs no checks of its own and unconditionally answers with the
    mapping ``{"status": "healthy"}``.
    """
    return dict(status="healthy")
# Ollama API proxy endpoints
@router.post("/ollama/api/generate")
async def ollama_generate(request: Request):
    """
    Ollama-style ``/api/generate`` endpoint.

    The request's ``prompt`` is answered by ``chat_service.get_chat_response``
    after a chat is created via ``chat_service.create_chat``. The answer is
    returned as one Ollama-shaped JSON object with ``model``, ``created_at``
    (UTC, ISO-8601 with microseconds), ``response`` and ``done`` (always
    ``True``). Only ``model`` and ``prompt`` are read from the body. Other
    fields, including ``stream``, are ignored, so the reply is never streamed.

    Raises:
        HTTPException: 400 if the body is not valid JSON or not a JSON object.
    """
    # Parse the request body; malformed input is a client error, not a 500.
    try:
        body = await request.json()
    except ValueError as exc:  # json.JSONDecodeError and UnicodeDecodeError
        raise HTTPException(status_code=400, detail="Invalid JSON body") from exc
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    # Extract parameters (explicit nulls fall back to the defaults).
    model_id = body.get("model") or "llama3.1"
    prompt = body.get("prompt") or ""

    # Create a unique chat ID
    chat_id = str(uuid.uuid4())
    # Create a user ID (in a real implementation, this would come from authentication)
    user_id = "openwebui-user"

    # NOTE(review): create_chat's return value is discarded, and the locally
    # generated chat_id above is what get_chat_response receives. Confirm the
    # two are meant to be unrelated.
    chat_service.create_chat(user_id=user_id, title="API Chat", model_id=model_id)

    response = chat_service.get_chat_response(
        chat_id=chat_id,
        message=prompt,
        user_id=user_id,
    )

    # Format response in Ollama-compatible format
    return {
        "model": model_id,
        # time.strftime does not implement %f (it is emitted literally or
        # rejected, depending on the platform); datetime.strftime supports it.
        "created_at": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
        "response": response.get("content", ""),
        "done": True,
    }