176 lines
5.0 KiB
Python
176 lines
5.0 KiB
Python
"""
OpenWebUI-compatible API endpoints for the AI service.
"""
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Header, Request
|
|
from fastapi.responses import StreamingResponse
|
|
from typing import List, Dict, Any, Optional, Union
|
|
import json
|
|
import time
|
|
import uuid
|
|
|
|
from ai_service.models.model_service import model_service
|
|
from ai_service.models.chat_service import chat_service
|
|
from ai_service.models.model_parameters import ModelParameters
|
|
|
|
# Create router
|
|
# Shared router for this module: each endpoint below registers itself on it
# through an @router.get / @router.post decorator.
# NOTE(review): presumably mounted on the FastAPI app elsewhere (e.g. via
# include_router) — that wiring is not visible in this file; confirm.
router = APIRouter()
|
|
|
|
# Models endpoint
|
|
@router.get("/models", response_model=List[Dict[str, Any]])
async def get_models():
    """
    List the models reported by ``model_service`` as OpenWebUI-style entries.

    Returns:
        A list with one dict per available model. Only the model's ``"id"``
        is taken from the service; it is used for both ``"id"`` and
        ``"root"``. Every other field is a fixed placeholder.
    """
    available = model_service.get_available_models()

    return [
        {
            "id": entry["id"],
            "object": "model",
            # Stamped with the time of this request, not the model's actual
            # creation time (the service entry is only read for its id).
            "created": int(time.time()),
            "owned_by": "user",
            "permission": [],
            "root": entry["id"],
            "parent": None,
        }
        for entry in available
    ]
|
|
|
|
# Chat completions endpoint (OpenAI-compatible)
|
|
@router.post("/chat/completions")
async def chat_completions(request: Request):
    """
    OpenAI-compatible chat completions endpoint.

    Reads an OpenAI-style JSON body, forwards the most recent ``user``
    message to ``chat_service`` and wraps the reply in an OpenAI-shaped
    payload.

    Body fields read: ``model`` (defaults to ``"llama3.1"``), ``messages``,
    ``stream``, ``temperature``, ``max_tokens``, ``top_p``,
    ``frequency_penalty``, ``presence_penalty`` and ``stop`` (a string or a
    list of strings).

    Returns:
        * ``stream`` falsy: a ``chat.completion`` dict.
        * ``stream`` truthy: a ``text/event-stream`` ``StreamingResponse``
          carrying ``chat.completion.chunk`` frames followed by
          ``data: [DONE]``. The reply is already complete when streaming
          starts, so the whole text is sent in a single content chunk.

    Raises:
        HTTPException: 400 when the body is not a JSON object, when
            ``messages`` is not a list, or when no user message with
            non-empty content is present.
    """
    # Malformed input is a client error; without this guard a bad body would
    # surface as an unhandled exception (HTTP 500).
    try:
        body = await request.json()
    except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON"
        ) from exc
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    model_id = body.get("model", "llama3.1")
    stream = body.get("stream", False)
    stop = body.get("stop")

    # `"messages": null` is treated like a missing list.
    messages = body.get("messages") or []
    if not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="'messages' must be a list")

    # Only the last user message is forwarded; earlier turns are ignored.
    # Non-dict entries are skipped instead of crashing on `.get`.
    user_message = next(
        (
            msg.get("content")
            for msg in reversed(messages)
            if isinstance(msg, dict) and msg.get("role") == "user"
        ),
        None,
    )
    if not user_message:
        raise HTTPException(status_code=400, detail="No user message found")

    # OpenAI allows `stop` to be a single string or a list of strings.
    if isinstance(stop, list):
        stop_sequences = stop
    elif stop:
        stop_sequences = [stop]
    else:
        stop_sequences = None

    # In a real implementation the user ID would come from authentication.
    user_id = "openwebui-user"
    chat_id = str(uuid.uuid4())

    # The chat is created only after validation so that rejected requests do
    # not leave an empty chat behind.
    # NOTE(review): `chat_id` is generated here but never handed to
    # create_chat, and create_chat's return value is discarded. Confirm that
    # get_chat_response accepts an ID the service has not seen before.
    chat_service.create_chat(user_id=user_id, title="API Chat", model_id=model_id)

    response = chat_service.get_chat_response(
        chat_id=chat_id,
        message=user_message,
        user_id=user_id,
        temperature=body.get("temperature"),
        max_tokens=body.get("max_tokens"),
        top_p=body.get("top_p"),
        frequency_penalty=body.get("frequency_penalty"),
        presence_penalty=body.get("presence_penalty"),
        stop_sequences=stop_sequences,
    )

    content = response.get("content", "")
    completion_id = f"chatcmpl-{str(uuid.uuid4())[:8]}"
    created = int(time.time())

    if stream:
        # Streaming clients expect "chat.completion.chunk" objects whose
        # choices carry a `delta`, not a full "chat.completion" with a
        # `message`.
        def _chunk(delta, finish_reason):
            return {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model_id,
                "choices": [
                    {"index": 0, "delta": delta, "finish_reason": finish_reason}
                ],
            }

        chunks = (
            _chunk({"role": "assistant", "content": content}, None),
            # Terminal chunk: empty delta plus the finish reason.
            _chunk({}, "stop"),
        )

        async def generate_stream():
            for chunk in chunks:
                yield f"data: {json.dumps(chunk)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(generate_stream(), media_type="text/event-stream")

    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": model_id,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": content},
                "finish_reason": "stop",
            }
        ],
        # Tokens are not tracked; -1 marks the counts as unavailable.
        "usage": {
            "prompt_tokens": -1,
            "completion_tokens": -1,
            "total_tokens": -1,
        },
    }
|
|
|
|
# Health check endpoint
|
|
@router.get("/health")
async def health_check():
    """
    Liveness probe.

    Always answers with the same constant payload; no backing service is
    consulted before replying.
    """
    status_payload = {"status": "healthy"}
    return status_payload
|
|
|
|
# Ollama API proxy endpoints
|
|
@router.post("/ollama/api/generate")
async def ollama_generate(request: Request):
    """
    Ollama-style ``generate`` endpoint backed by ``chat_service``.

    Body fields read: ``model`` (defaults to ``"llama3.1"``) and ``prompt``
    (defaults to ``""``). Any other field in the body is ignored.

    Returns:
        A single, non-streamed dict with ``model``, ``created_at`` (UTC
        timestamp with microseconds), ``response`` and ``done: True``.

    Raises:
        HTTPException: 400 when the body is not a JSON object.
    """
    # Imported locally so this block does not depend on the file-level
    # imports being extended.
    from datetime import datetime, timezone

    # Malformed input is a client error; without this guard a bad body would
    # surface as an unhandled exception (HTTP 500).
    try:
        body = await request.json()
    except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON"
        ) from exc
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    model_id = body.get("model", "llama3.1")
    prompt = body.get("prompt", "")

    # In a real implementation the user ID would come from authentication.
    user_id = "openwebui-user"
    chat_id = str(uuid.uuid4())

    # NOTE(review): `chat_id` is generated here but never handed to
    # create_chat, and create_chat's return value is discarded. Confirm that
    # get_chat_response accepts an ID the service has not seen before.
    chat_service.create_chat(user_id=user_id, title="API Chat", model_id=model_id)

    response = chat_service.get_chat_response(
        chat_id=chat_id,
        message=prompt,
        user_id=user_id,
    )

    # time.strftime() has no %f directive, so the fractional seconds were
    # never filled in; datetime.strftime() supports %f (microseconds).
    created_at = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ")

    return {
        "model": model_id,
        "created_at": created_at,
        "response": response.get("content", ""),
        "done": True,
    }
|