"""
FastAPI application for the AI service.

Exposes HTTP endpoints for document ingestion and search, model discovery,
and chat sessions.
"""
|
|
|
|
import os
from typing import List, Dict, Any, Optional

from fastapi import FastAPI, HTTPException, Depends, Body, Query, Path
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from ai_service.config import config
from ai_service.embeddings.document_service import document_service
from ai_service.models.model_service import model_service
from ai_service.models.chat_service import chat_service
from ai_service.models.model_parameters import ModelParameters
|
|
|
|
# Create FastAPI app
app = FastAPI(
    title="AI Service API",
    description="API for the AI service",
    version="1.0.0",
)


def _cors_allowed_origins() -> List[str]:
    """Return the CORS origin allow-list.

    Origins are read from the comma-separated ``AI_SERVICE_CORS_ORIGINS``
    environment variable. Blank entries are ignored. When the variable is
    unset or empty, fall back to ``["*"]`` (any origin), which matches the
    previously hard-coded behaviour.
    """
    raw = os.environ.get("AI_SERVICE_CORS_ORIGINS", "")
    origins = [origin.strip() for origin in raw.split(",") if origin.strip()]
    return origins or ["*"]


_CORS_ORIGINS = _cors_allowed_origins()

# Add CORS middleware.
# FastAPI's CORS documentation states that allow_credentials must not be
# combined with a wildcard origin. Browsers reject "*" on credentialed
# requests, and Starlette instead reflects the caller's Origin, which would
# let any site send credentialed requests. Credentials are therefore only
# enabled when an explicit origin allow-list has been configured.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ORIGINS,
    allow_credentials="*" not in _CORS_ORIGINS,
    allow_methods=["*"],  # Allow all methods
    allow_headers=["*"],  # Allow all headers
)
|
|
|
|
# Define API models
|
|
class DocumentRequest(BaseModel):
    """Body accepted by ``POST /documents`` when ingesting a document."""

    # Mandatory fields.
    content: str = Field(default=..., description="Document content")
    title: str = Field(default=..., description="Document title")

    # Optional fields; both are None when the client omits them.
    description: Optional[str] = Field(default=None, description="Document description")
    metadata: Optional[Dict[str, Any]] = Field(default=None, description="Additional metadata")
|
|
|
|
class DocumentResponse(BaseModel):
    """Response model for document processing.

    ``description`` and ``metadata`` are optional here to mirror
    ``DocumentRequest``, where both default to None. ``process_document``
    forwards those values unchanged. If the document service stores and
    returns them as-is, a required field would fail response validation and
    turn a successful upload into an HTTP 500.
    """

    id: str = Field(..., description="Document ID")
    title: str = Field(..., description="Document title")
    # NOTE(review): relaxed from required str; confirm whether document_service
    # ever returns None here or normalises it (e.g. to "").
    description: Optional[str] = Field(None, description="Document description")
    chunk_count: int = Field(..., description="Number of chunks")
    # NOTE(review): relaxed from a required dict for the same reason as description.
    metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
|
|
|
|
class SearchRequest(BaseModel):
    """Request model for document search.

    ``top_k`` must be at least 1. Zero or negative values are rejected with
    a 422 validation error rather than being handed to the search backend.
    """

    query: str = Field(..., description="Search query")
    top_k: int = Field(5, ge=1, description="Number of results to return")
|
|
|
|
class SearchResult(BaseModel):
    """One hit returned by ``POST /documents/search``."""

    id: str = Field(default=..., description="Result ID")
    score: float = Field(default=..., description="Similarity score")
    metadata: Dict[str, Any] = Field(default=..., description="Result metadata")
|
|
|
|
class ModelInfo(BaseModel):
    """Description of one model, as served by the ``/models`` endpoints."""

    # Identity.
    id: str = Field(default=..., description="Model ID")
    name: str = Field(default=..., description="Model name")
    description: str = Field(default=..., description="Model description")
    provider: str = Field(default=..., description="Model provider")

    # Capabilities / selection.
    max_tokens: int = Field(default=..., description="Maximum tokens")
    is_default: bool = Field(default=..., description="Whether this is the default model")
|
|
|
|
class ChatRequest(BaseModel):
    """Body accepted by ``POST /chats`` to open a new chat.

    Only ``user_id`` is mandatory. ``title`` and ``model_id`` default to None,
    and ``is_team_chat`` defaults to False.
    """

    user_id: str = Field(default=..., description="User ID")
    title: Optional[str] = Field(default=None, description="Chat title")
    model_id: Optional[str] = Field(default=None, description="Model ID")
    is_team_chat: bool = Field(default=False, description="Whether this is a team chat")
|
|
|
|
class MessageRequest(BaseModel):
    """Body accepted by ``POST /chats/{chat_id}/messages``.

    Only ``message`` and ``user_id`` are mandatory. ``use_rag`` defaults to
    False, and every generation parameter below defaults to None when the
    client leaves it out.
    """

    message: str = Field(default=..., description="Message content")
    user_id: str = Field(default=..., description="User ID")
    use_rag: bool = Field(default=False, description="Whether to use RAG")

    # --- Model parameters ---
    temperature: Optional[float] = Field(
        default=None,
        description="Controls randomness: higher values mean more random completions",
    )
    max_tokens: Optional[int] = Field(
        default=None, description="Maximum number of tokens to generate"
    )
    top_p: Optional[float] = Field(default=None, description="Nucleus sampling parameter")
    frequency_penalty: Optional[float] = Field(
        default=None, description="Penalizes repeated tokens"
    )
    presence_penalty: Optional[float] = Field(
        default=None, description="Penalizes repeated topics"
    )
    stop_sequences: Optional[List[str]] = Field(
        default=None, description="Sequences where the API will stop generating"
    )
    system_prompt: Optional[str] = Field(
        default=None, description="System prompt to guide the model's behavior"
    )

    # --- Additional advanced parameters ---
    min_p: Optional[float] = Field(
        default=None, description="Minimum probability threshold for token selection"
    )
    top_k: Optional[int] = Field(default=None, description="Only sample from the top k tokens")
    repeat_penalty: Optional[float] = Field(
        default=None, description="Penalty for repeating tokens"
    )
    function_calling: Optional[bool] = Field(
        default=None, description="Whether to enable function calling"
    )
|
|
|
|
class Message(BaseModel):
    """A single chat message, from either a user or the bot.

    ``user_id`` is the only optional field and defaults to None.
    """

    id: str = Field(default=..., description="Message ID")
    content: str = Field(default=..., description="Message content")
    user_id: Optional[str] = Field(default=None, description="User ID")
    is_user_message: bool = Field(default=..., description="Whether this is a user message")
    timestamp: str = Field(default=..., description="Message timestamp")
|
|
|
|
class Chat(BaseModel):
    """Full chat record returned by the ``/chats`` endpoints.

    Every field is required. Timestamps are plain strings, and the chat
    carries its complete message list and team-member list.
    """

    # Identity and ownership.
    id: str = Field(default=..., description="Chat ID")
    title: str = Field(default=..., description="Chat title")
    user_id: str = Field(default=..., description="User ID")
    model_id: str = Field(default=..., description="Model ID")
    is_team_chat: bool = Field(default=..., description="Whether this is a team chat")

    # Timestamps.
    created_at: str = Field(default=..., description="Creation timestamp")
    updated_at: str = Field(default=..., description="Update timestamp")

    # Contents.
    messages: List[Message] = Field(default=..., description="Chat messages")
    team_members: List[str] = Field(default=..., description="Team members")
|
|
|
|
# Define API endpoints
|
|
@app.get("/health")
async def health_check():
    """
    Liveness probe.

    Returns:
        A dict whose ``status`` key is always ``"healthy"``.
    """
    payload = dict(status="healthy")
    return payload
|
|
|
|
# Document endpoints
|
|
@app.post("/documents", response_model=DocumentResponse)
async def process_document(request: DocumentRequest):
    """
    Ingest a document and return its stored record.

    Args:
        request: Content, title and optional description/metadata of the
            document to ingest.

    Returns:
        Whatever ``document_service.get_document`` reports for the newly
        created document ID.
    """
    # NOTE(review): the service is called synchronously inside an async
    # handler. If processing is slow (presumably it computes embeddings), this
    # blocks the event loop. Consider run_in_threadpool once the service is
    # confirmed thread-safe.
    new_doc_id = document_service.process_document(
        content=request.content,
        title=request.title,
        description=request.description,
        metadata=request.metadata,
    )
    stored_doc = document_service.get_document(new_doc_id)
    return stored_doc
|
|
|
|
@app.get("/documents", response_model=List[DocumentResponse])
async def get_all_documents():
    """
    List every stored document.

    Returns:
        The list returned by ``document_service.get_all_documents``.
    """
    documents = document_service.get_all_documents()
    return documents
|
|
|
|
@app.get("/documents/{doc_id}", response_model=DocumentResponse)
async def get_document(doc_id: str):
    """
    Fetch a single document.

    Args:
        doc_id: Identifier of the document to fetch.

    Returns:
        The document record from ``document_service``.

    Raises:
        HTTPException: 404 when the service returns a falsy value.
    """
    found = document_service.get_document(doc_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Document not found")
|
|
|
|
@app.delete("/documents/{doc_id}")
async def delete_document(doc_id: str):
    """
    Remove a document.

    Args:
        doc_id: Identifier of the document to remove.

    Returns:
        A success status dict.

    Raises:
        HTTPException: 404 when the service reports the deletion failed.
    """
    deleted = document_service.delete_document(doc_id)
    if deleted:
        return {"status": "success", "message": "Document deleted"}
    raise HTTPException(status_code=404, detail="Document not found")
|
|
|
|
@app.post("/documents/search", response_model=List[SearchResult])
async def search_documents(request: SearchRequest):
    """
    Run a similarity search over the stored documents.

    Args:
        request: The query text and the number of hits (``top_k``) wanted.

    Returns:
        The hits produced by ``document_service.search_documents``.
    """
    return document_service.search_documents(
        query=request.query,
        top_k=request.top_k,
    )
|
|
|
|
# Model endpoints
|
|
@app.get("/models", response_model=List[ModelInfo])
async def get_available_models():
    """
    List the models this service can use.

    Returns:
        The list returned by ``model_service.get_available_models``.
    """
    models = model_service.get_available_models()
    return models
|
|
|
|
@app.get("/models/{model_id}", response_model=ModelInfo)
async def get_model_info(model_id: str):
    """
    Describe a single model.

    Args:
        model_id: Identifier of the model to describe.

    Returns:
        The model record from ``model_service``.

    Raises:
        HTTPException: 404 when the service returns a falsy value.
    """
    info = model_service.get_model_info(model_id)
    if info:
        return info
    raise HTTPException(status_code=404, detail="Model not found")
|
|
|
|
# Chat endpoints
|
|
@app.post("/chats", response_model=Chat)
async def create_chat(request: ChatRequest):
    """
    Open a new chat and return it.

    Args:
        request: Owner, optional title/model and team-chat flag.

    Returns:
        The chat record that ``chat_service.get_chat`` reports for the new ID.
    """
    new_chat_id = chat_service.create_chat(
        user_id=request.user_id,
        title=request.title,
        model_id=request.model_id,
        is_team_chat=request.is_team_chat,
    )
    created = chat_service.get_chat(new_chat_id)
    return created
|
|
|
|
@app.get("/chats/user/{user_id}", response_model=List[Chat])
async def get_user_chats(user_id: str):
    """
    List the chats belonging to a user.

    Args:
        user_id: The user whose chats are listed.

    Returns:
        The list returned by ``chat_service.get_user_chats``.
    """
    chats = chat_service.get_user_chats(user_id)
    return chats
|
|
|
|
@app.get("/chats/{chat_id}", response_model=Chat)
async def get_chat(chat_id: str):
    """
    Fetch a single chat.

    Args:
        chat_id: Identifier of the chat to fetch.

    Returns:
        The chat record from ``chat_service``.

    Raises:
        HTTPException: 404 when the service returns a falsy value.
    """
    found = chat_service.get_chat(chat_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Chat not found")
|
|
|
|
@app.post("/chats/{chat_id}/messages", response_model=Message)
async def send_message(chat_id: str, request: MessageRequest):
    """
    Send a message to a chat and return the bot's reply.

    Args:
        chat_id: Chat ID.
        request: Message request with optional model parameters. Parameters
            the client omitted are forwarded to the chat service as None.

    Returns:
        Bot response message.

    Raises:
        HTTPException: 404 whenever ``chat_service.get_chat_response`` raises
            ``ValueError``. The original error is chained as the cause.
    """
    # NOTE(review): get_chat_response is called synchronously inside an async
    # handler, so a slow model call blocks the event loop. Consider
    # run_in_threadpool once chat_service is confirmed thread-safe.
    try:
        return chat_service.get_chat_response(
            chat_id=chat_id,
            message=request.message,
            user_id=request.user_id,
            use_rag=request.use_rag,
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            top_p=request.top_p,
            frequency_penalty=request.frequency_penalty,
            presence_penalty=request.presence_penalty,
            stop_sequences=request.stop_sequences,
            system_prompt=request.system_prompt,
            min_p=request.min_p,
            top_k=request.top_k,
            repeat_penalty=request.repeat_penalty,
            function_calling=request.function_calling,
        )
    except ValueError as e:
        # NOTE(review): every ValueError becomes a 404, including any that may
        # come from invalid model parameters rather than a missing chat.
        # Confirm which ValueErrors chat_service raises.
        raise HTTPException(status_code=404, detail=str(e)) from e
|
|
|
|
@app.post("/chats/{chat_id}/members/{user_id}")
async def add_team_member(chat_id: str, user_id: str):
    """
    Add a user to a team chat.

    Args:
        chat_id: The team chat to modify.
        user_id: The user to add.

    Returns:
        A success status dict.

    Raises:
        HTTPException: 400 when the service reports the addition failed.
    """
    added = chat_service.add_team_member(chat_id, user_id)
    if added:
        return {"status": "success", "message": "Team member added"}
    raise HTTPException(status_code=400, detail="Failed to add team member")
|
|
|
|
@app.delete("/chats/{chat_id}/members/{user_id}")
async def remove_team_member(chat_id: str, user_id: str):
    """
    Remove a user from a team chat.

    Args:
        chat_id: The team chat to modify.
        user_id: The user to remove.

    Returns:
        A success status dict.

    Raises:
        HTTPException: 400 when the service reports the removal failed.
    """
    removed = chat_service.remove_team_member(chat_id, user_id)
    if removed:
        return {"status": "success", "message": "Team member removed"}
    raise HTTPException(status_code=400, detail="Failed to remove team member")
|
|
|
|
@app.delete("/chats/{chat_id}")
async def delete_chat(chat_id: str):
    """
    Remove a chat.

    Args:
        chat_id: Identifier of the chat to remove.

    Returns:
        A success status dict.

    Raises:
        HTTPException: 404 when the service reports the deletion failed.
    """
    deleted = chat_service.delete_chat(chat_id)
    if deleted:
        return {"status": "success", "message": "Chat deleted"}
    raise HTTPException(status_code=404, detail="Chat not found")
|