Files
ds_zagres_ai/ai_service/api.py
T

378 lines
12 KiB
Python
Raw Normal View History

2025-05-09 15:41:16 +01:00
"""
FastAPI application for the AI service.
"""
from fastapi import FastAPI, HTTPException, Depends, Body, Query, Path
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional
from ai_service.config import config
from ai_service.embeddings.document_service import document_service
from ai_service.models.model_service import model_service
from ai_service.models.chat_service import chat_service
from ai_service.models.model_parameters import ModelParameters
# Create FastAPI app
# Application object that every route decorator in this module attaches to.
app = FastAPI(
    title="AI Service API",
    description="API for the AI service",
    version="1.0.0",
)
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    # NOTE(review): every CORS dimension below is a wildcard while
    # allow_credentials is True. FastAPI's CORS documentation says wildcard
    # origins/methods/headers should not be combined with credentials; if
    # cookies or Authorization headers are really needed cross-origin, list
    # the allowed origins explicitly -- TODO confirm the intended origins.
    allow_origins=["*"],  # Allow all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allow all methods
    allow_headers=["*"],  # Allow all headers
)
# Define API models
class DocumentRequest(BaseModel):
    """Body accepted by the document-processing endpoint.

    ``content`` and ``title`` are required; ``description`` and ``metadata``
    may be omitted, in which case they default to ``None``.
    """

    content: str = Field(
        ...,
        description="Document content",
    )
    title: str = Field(
        ...,
        description="Document title",
    )
    description: Optional[str] = Field(
        None,
        description="Document description",
    )
    metadata: Optional[Dict[str, Any]] = Field(
        None,
        description="Additional metadata",
    )
class DocumentResponse(BaseModel):
    """Shape of a document record returned by the document endpoints.

    All five fields are required by this schema.
    """

    # NOTE(review): ``description`` and ``metadata`` are required here but
    # optional on DocumentRequest -- confirm the document service always
    # fills them in.
    id: str = Field(
        ...,
        description="Document ID",
    )
    title: str = Field(
        ...,
        description="Document title",
    )
    description: str = Field(
        ...,
        description="Document description",
    )
    chunk_count: int = Field(
        ...,
        description="Number of chunks",
    )
    metadata: Dict[str, Any] = Field(
        ...,
        description="Additional metadata",
    )
class SearchRequest(BaseModel):
    """Body accepted by the document-search endpoint.

    ``query`` is required. ``top_k`` defaults to 5 and must be at least 1,
    so a zero or negative result count is rejected during request
    validation instead of being forwarded to the search backend.
    """

    query: str = Field(
        ...,
        description="Search query",
    )
    top_k: int = Field(
        5,
        ge=1,
        description="Number of results to return",
    )
class SearchResult(BaseModel):
    """One entry in the list returned by the document-search endpoint.

    Carries a result identifier, its similarity score and a metadata
    mapping; all three are required.
    """

    id: str = Field(
        ...,
        description="Result ID",
    )
    score: float = Field(
        ...,
        description="Similarity score",
    )
    metadata: Dict[str, Any] = Field(
        ...,
        description="Result metadata",
    )
class ModelInfo(BaseModel):
    """Descriptor of a model as exposed by the ``/models`` endpoints.

    Every field is required.
    """

    id: str = Field(
        ...,
        description="Model ID",
    )
    name: str = Field(
        ...,
        description="Model name",
    )
    description: str = Field(
        ...,
        description="Model description",
    )
    provider: str = Field(
        ...,
        description="Model provider",
    )
    max_tokens: int = Field(
        ...,
        description="Maximum tokens",
    )
    is_default: bool = Field(
        ...,
        description="Whether this is the default model",
    )
class ChatRequest(BaseModel):
    """Body accepted when creating a chat.

    Only ``user_id`` is required. ``title`` and ``model_id`` default to
    ``None`` and ``is_team_chat`` defaults to ``False``.
    """

    user_id: str = Field(
        ...,
        description="User ID",
    )
    title: Optional[str] = Field(
        None,
        description="Chat title",
    )
    model_id: Optional[str] = Field(
        None,
        description="Model ID",
    )
    is_team_chat: bool = Field(
        False,
        description="Whether this is a team chat",
    )
class MessageRequest(BaseModel):
    """Body accepted when posting a message to a chat.

    ``message`` and ``user_id`` are required and ``use_rag`` defaults to
    ``False``. Every generation parameter is optional and defaults to
    ``None``.
    """

    message: str = Field(
        ...,
        description="Message content",
    )
    user_id: str = Field(
        ...,
        description="User ID",
    )
    use_rag: bool = Field(
        False,
        description="Whether to use RAG",
    )

    # Model parameters
    temperature: Optional[float] = Field(
        None,
        description="Controls randomness: higher values mean more random completions",
    )
    max_tokens: Optional[int] = Field(
        None,
        description="Maximum number of tokens to generate",
    )
    top_p: Optional[float] = Field(
        None,
        description="Nucleus sampling parameter",
    )
    frequency_penalty: Optional[float] = Field(
        None,
        description="Penalizes repeated tokens",
    )
    presence_penalty: Optional[float] = Field(
        None,
        description="Penalizes repeated topics",
    )
    stop_sequences: Optional[List[str]] = Field(
        None,
        description="Sequences where the API will stop generating",
    )
    system_prompt: Optional[str] = Field(
        None,
        description="System prompt to guide the model's behavior",
    )

    # Additional advanced parameters
    min_p: Optional[float] = Field(
        None,
        description="Minimum probability threshold for token selection",
    )
    top_k: Optional[int] = Field(
        None,
        description="Only sample from the top k tokens",
    )
    repeat_penalty: Optional[float] = Field(
        None,
        description="Penalty for repeating tokens",
    )
    function_calling: Optional[bool] = Field(
        None,
        description="Whether to enable function calling",
    )
class Message(BaseModel):
    """A single chat message as serialized in API responses.

    ``user_id`` is the only optional field and defaults to ``None``.
    """

    id: str = Field(
        ...,
        description="Message ID",
    )
    content: str = Field(
        ...,
        description="Message content",
    )
    user_id: Optional[str] = Field(
        None,
        description="User ID",
    )
    is_user_message: bool = Field(
        ...,
        description="Whether this is a user message",
    )
    timestamp: str = Field(
        ...,
        description="Message timestamp",
    )
class Chat(BaseModel):
    """A chat as serialized in API responses.

    Every field is required, including the ``messages`` list and the
    ``team_members`` list. Timestamps are carried as strings.
    """

    id: str = Field(
        ...,
        description="Chat ID",
    )
    title: str = Field(
        ...,
        description="Chat title",
    )
    user_id: str = Field(
        ...,
        description="User ID",
    )
    model_id: str = Field(
        ...,
        description="Model ID",
    )
    is_team_chat: bool = Field(
        ...,
        description="Whether this is a team chat",
    )
    created_at: str = Field(
        ...,
        description="Creation timestamp",
    )
    updated_at: str = Field(
        ...,
        description="Update timestamp",
    )
    messages: List[Message] = Field(
        ...,
        description="Chat messages",
    )
    team_members: List[str] = Field(
        ...,
        description="Team members",
    )
# Define API endpoints
@app.get("/health")
async def health_check():
    """
    Report that the service is up.

    Returns:
        A dict whose only key, ``"status"``, is set to ``"healthy"``.
    """
    status_payload = {"status": "healthy"}
    return status_payload
# Document endpoints
@app.post("/documents", response_model=DocumentResponse)
async def process_document(request: DocumentRequest):
    """
    Submit a document to the document service and return its stored record.

    Args:
        request: Content, title and optional description/metadata of the
            document.

    Returns:
        Whatever ``document_service.get_document`` yields for the ID that
        ``document_service.process_document`` returned.
    """
    new_doc_id = document_service.process_document(
        content=request.content,
        title=request.title,
        description=request.description,
        metadata=request.metadata,
    )
    # Look the document up again so the response reflects the stored record.
    stored_document = document_service.get_document(new_doc_id)
    return stored_document
@app.get("/documents", response_model=List[DocumentResponse])
async def get_all_documents():
    """
    List every document known to the document service.

    Returns:
        The result of ``document_service.get_all_documents()``.
    """
    documents = document_service.get_all_documents()
    return documents
@app.get("/documents/{doc_id}", response_model=DocumentResponse)
async def get_document(doc_id: str):
    """
    Fetch a single document.

    Args:
        doc_id: Identifier taken from the URL path.

    Returns:
        The record returned by ``document_service.get_document``.

    Raises:
        HTTPException: 404 when the service returns a falsy value.
    """
    found = document_service.get_document(doc_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Document not found")
@app.delete("/documents/{doc_id}")
async def delete_document(doc_id: str):
    """
    Remove a document.

    Args:
        doc_id: Identifier taken from the URL path.

    Returns:
        A status dict confirming the deletion.

    Raises:
        HTTPException: 404 when ``document_service.delete_document``
            returns a falsy value.
    """
    if document_service.delete_document(doc_id):
        return {"status": "success", "message": "Document deleted"}
    raise HTTPException(status_code=404, detail="Document not found")
@app.post("/documents/search", response_model=List[SearchResult])
async def search_documents(request: SearchRequest):
    """
    Run a search through the document service.

    Args:
        request: The query text and the number of results wanted.

    Returns:
        The result of ``document_service.search_documents`` for that query
        and ``top_k``.
    """
    return document_service.search_documents(
        query=request.query,
        top_k=request.top_k,
    )
# Model endpoints
@app.get("/models", response_model=List[ModelInfo])
async def get_available_models():
    """
    List the models the model service reports as available.

    Returns:
        The result of ``model_service.get_available_models()``.
    """
    available = model_service.get_available_models()
    return available
@app.get("/models/{model_id}", response_model=ModelInfo)
async def get_model_info(model_id: str):
    """
    Fetch the description of one model.

    Args:
        model_id: Identifier taken from the URL path.

    Returns:
        The record returned by ``model_service.get_model_info``.

    Raises:
        HTTPException: 404 when the service returns a falsy value.
    """
    info = model_service.get_model_info(model_id)
    if info:
        return info
    raise HTTPException(status_code=404, detail="Model not found")
# Chat endpoints
@app.post("/chats", response_model=Chat)
async def create_chat(request: ChatRequest):
    """
    Create a chat through the chat service and return it.

    Args:
        request: Owner, optional title and model, and the team-chat flag.

    Returns:
        Whatever ``chat_service.get_chat`` yields for the ID that
        ``chat_service.create_chat`` returned.
    """
    new_chat_id = chat_service.create_chat(
        user_id=request.user_id,
        title=request.title,
        model_id=request.model_id,
        is_team_chat=request.is_team_chat,
    )
    # Look the chat up again so the response reflects the stored record.
    created = chat_service.get_chat(new_chat_id)
    return created
@app.get("/chats/user/{user_id}", response_model=List[Chat])
async def get_user_chats(user_id: str):
    """
    List the chats the chat service returns for a user.

    Args:
        user_id: Identifier taken from the URL path.

    Returns:
        The result of ``chat_service.get_user_chats(user_id)``.
    """
    chats = chat_service.get_user_chats(user_id)
    return chats
@app.get("/chats/{chat_id}", response_model=Chat)
async def get_chat(chat_id: str):
    """
    Fetch a single chat.

    Args:
        chat_id: Identifier taken from the URL path.

    Returns:
        The record returned by ``chat_service.get_chat``.

    Raises:
        HTTPException: 404 when the service returns a falsy value.
    """
    found = chat_service.get_chat(chat_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Chat not found")
@app.post("/chats/{chat_id}/messages", response_model=Message)
async def send_message(chat_id: str, request: MessageRequest):
    """
    Send a message to a chat and return the reply produced by the chat service.

    Args:
        chat_id: Identifier of the target chat, taken from the URL path.
        request: Message text, sender, RAG flag and optional generation
            parameters.

    Returns:
        The value returned by ``chat_service.get_chat_response``.

    Raises:
        HTTPException: 404 carrying the error text when the chat service
            raises ``ValueError``.
    """
    try:
        # Forward each generation parameter as-is; unset ones arrive as None.
        response = chat_service.get_chat_response(
            chat_id=chat_id,
            message=request.message,
            user_id=request.user_id,
            use_rag=request.use_rag,
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            top_p=request.top_p,
            frequency_penalty=request.frequency_penalty,
            presence_penalty=request.presence_penalty,
            stop_sequences=request.stop_sequences,
            system_prompt=request.system_prompt,
            min_p=request.min_p,
            top_k=request.top_k,
            repeat_penalty=request.repeat_penalty,
            function_calling=request.function_calling,
        )
    except ValueError as exc:
        # NOTE(review): every ValueError becomes a 404 -- confirm the service
        # only raises it for a missing chat and not for bad parameters.
        # Chain explicitly so the original error stays attached as the cause.
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    return response
@app.post("/chats/{chat_id}/members/{user_id}")
async def add_team_member(chat_id: str, user_id: str):
    """
    Add a user to a chat's team.

    Args:
        chat_id: Chat identifier taken from the URL path.
        user_id: User identifier taken from the URL path.

    Returns:
        A status dict confirming the addition.

    Raises:
        HTTPException: 400 when ``chat_service.add_team_member`` returns a
            falsy value.
    """
    if chat_service.add_team_member(chat_id, user_id):
        return {"status": "success", "message": "Team member added"}
    raise HTTPException(status_code=400, detail="Failed to add team member")
@app.delete("/chats/{chat_id}/members/{user_id}")
async def remove_team_member(chat_id: str, user_id: str):
    """
    Remove a user from a chat's team.

    Args:
        chat_id: Chat identifier taken from the URL path.
        user_id: User identifier taken from the URL path.

    Returns:
        A status dict confirming the removal.

    Raises:
        HTTPException: 400 when ``chat_service.remove_team_member`` returns
            a falsy value.
    """
    if chat_service.remove_team_member(chat_id, user_id):
        return {"status": "success", "message": "Team member removed"}
    raise HTTPException(status_code=400, detail="Failed to remove team member")
@app.delete("/chats/{chat_id}")
async def delete_chat(chat_id: str):
    """
    Remove a chat.

    Args:
        chat_id: Identifier taken from the URL path.

    Returns:
        A status dict confirming the deletion.

    Raises:
        HTTPException: 404 when ``chat_service.delete_chat`` returns a
            falsy value.
    """
    if chat_service.delete_chat(chat_id):
        return {"status": "success", "message": "Chat deleted"}
    raise HTTPException(status_code=404, detail="Chat not found")