Using the ai_service as backend for openwebui
This commit is contained in:
+14
-112
@@ -1,22 +1,21 @@
|
||||
"""
|
||||
FastAPI application for the AI service.
|
||||
This service acts as a backend for OpenWebUI, providing OpenWebUI-compatible API endpoints.
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Depends, Body, Query, Path
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from ai_service.config import config
|
||||
from ai_service.embeddings.document_service import document_service
|
||||
from ai_service.models.model_service import model_service
|
||||
from ai_service.models.chat_service import chat_service
|
||||
from ai_service.models.model_parameters import ModelParameters
|
||||
from ai_service.openwebui_api import router as openwebui_router
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title="AI Service API",
|
||||
description="API for the AI service",
|
||||
description="Backend API for OpenWebUI",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
@@ -29,32 +28,16 @@ app.add_middleware(
|
||||
allow_headers=["*"], # Allow all headers
|
||||
)
|
||||
|
||||
# Define API models
|
||||
class DocumentRequest(BaseModel):
|
||||
"""Request model for document processing."""
|
||||
content: str = Field(..., description="Document content")
|
||||
title: str = Field(..., description="Document title")
|
||||
description: Optional[str] = Field(None, description="Document description")
|
||||
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
|
||||
# Include OpenWebUI-compatible API routes
|
||||
app.include_router(openwebui_router, prefix="/api")
|
||||
|
||||
class DocumentResponse(BaseModel):
|
||||
"""Response model for document processing."""
|
||||
id: str = Field(..., description="Document ID")
|
||||
title: str = Field(..., description="Document title")
|
||||
description: str = Field(..., description="Document description")
|
||||
chunk_count: int = Field(..., description="Number of chunks")
|
||||
metadata: Dict[str, Any] = Field(..., description="Additional metadata")
|
||||
# Include Ollama proxy routes
|
||||
app.include_router(openwebui_router, prefix="/ollama")
|
||||
|
||||
class SearchRequest(BaseModel):
|
||||
"""Request model for document search."""
|
||||
query: str = Field(..., description="Search query")
|
||||
top_k: int = Field(5, description="Number of results to return")
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
"""Model for a search result."""
|
||||
id: str = Field(..., description="Result ID")
|
||||
score: float = Field(..., description="Similarity score")
|
||||
metadata: Dict[str, Any] = Field(..., description="Result metadata")
|
||||
# Define API models for health check
|
||||
class HealthResponse(BaseModel):
|
||||
"""Response model for health check."""
|
||||
status: str = Field(..., description="Health status")
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""Model for model information."""
|
||||
@@ -114,7 +97,7 @@ class Chat(BaseModel):
|
||||
team_members: List[str] = Field(..., description="Team members")
|
||||
|
||||
# Define API endpoints
|
||||
@app.get("/health")
|
||||
@app.get("/health", response_model=HealthResponse)
|
||||
async def health_check():
|
||||
"""
|
||||
Health check endpoint.
|
||||
@@ -124,88 +107,7 @@ async def health_check():
|
||||
"""
|
||||
return {"status": "healthy"}
|
||||
|
||||
# Document endpoints
|
||||
@app.post("/documents", response_model=DocumentResponse)
|
||||
async def process_document(request: DocumentRequest):
|
||||
"""
|
||||
Process a document for embedding.
|
||||
|
||||
Args:
|
||||
request: Document processing request.
|
||||
|
||||
Returns:
|
||||
Processed document information.
|
||||
"""
|
||||
doc_id = document_service.process_document(
|
||||
content=request.content,
|
||||
title=request.title,
|
||||
description=request.description,
|
||||
metadata=request.metadata
|
||||
)
|
||||
|
||||
return document_service.get_document(doc_id)
|
||||
|
||||
@app.get("/documents", response_model=List[DocumentResponse])
|
||||
async def get_all_documents():
|
||||
"""
|
||||
Get all documents.
|
||||
|
||||
Returns:
|
||||
List of document information.
|
||||
"""
|
||||
return document_service.get_all_documents()
|
||||
|
||||
@app.get("/documents/{doc_id}", response_model=DocumentResponse)
|
||||
async def get_document(doc_id: str):
|
||||
"""
|
||||
Get a document by ID.
|
||||
|
||||
Args:
|
||||
doc_id: Document ID.
|
||||
|
||||
Returns:
|
||||
Document information.
|
||||
"""
|
||||
doc = document_service.get_document(doc_id)
|
||||
if not doc:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
return doc
|
||||
|
||||
@app.delete("/documents/{doc_id}")
|
||||
async def delete_document(doc_id: str):
|
||||
"""
|
||||
Delete a document.
|
||||
|
||||
Args:
|
||||
doc_id: Document ID.
|
||||
|
||||
Returns:
|
||||
Deletion status.
|
||||
"""
|
||||
success = document_service.delete_document(doc_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
return {"status": "success", "message": "Document deleted"}
|
||||
|
||||
@app.post("/documents/search", response_model=List[SearchResult])
|
||||
async def search_documents(request: SearchRequest):
|
||||
"""
|
||||
Search for documents.
|
||||
|
||||
Args:
|
||||
request: Search request.
|
||||
|
||||
Returns:
|
||||
Search results.
|
||||
"""
|
||||
results = document_service.search_documents(
|
||||
query=request.query,
|
||||
top_k=request.top_k
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
# Model endpoints
|
||||
@app.get("/models", response_model=List[ModelInfo])
|
||||
|
||||
Reference in New Issue
Block a user