Initial commit for deployment
This commit is contained in:
@@ -0,0 +1,377 @@
|
||||
"""
|
||||
FastAPI application for the AI service.
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Depends, Body, Query, Path
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from ai_service.config import config
|
||||
from ai_service.embeddings.document_service import document_service
|
||||
from ai_service.models.model_service import model_service
|
||||
from ai_service.models.chat_service import chat_service
|
||||
from ai_service.models.model_parameters import ModelParameters
|
||||
|
||||
# Create FastAPI app
|
||||
# Module-level ASGI application object. Every route below is registered on it
# through the @app.<method> decorators; the title, description and version are
# the metadata passed to FastAPI for this service.
app = FastAPI(
    title="AI Service API",
    description="API for the AI service",
    version="1.0.0"
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins
    # Credentials (cookies, Authorization headers) must stay disabled while
    # origins are wildcarded: browsers refuse "Access-Control-Allow-Origin: *"
    # on credentialed requests, and echoing arbitrary origins back with
    # credentials enabled would let any website act on behalf of a logged-in
    # user. To support credentialed requests, replace "*" above with an
    # explicit list of trusted origins and set this back to True.
    allow_credentials=False,
    allow_methods=["*"],  # Allow all methods
    allow_headers=["*"],  # Allow all headers
)
|
||||
|
||||
# Define API models
|
||||
class DocumentRequest(BaseModel):
    """Body accepted by ``POST /documents``.

    ``content`` and ``title`` are required; ``description`` and ``metadata``
    may be omitted and then default to ``None``.
    """

    content: str = Field(
        ...,
        description="Document content",
    )
    title: str = Field(
        ...,
        description="Document title",
    )
    description: Optional[str] = Field(
        default=None,
        description="Document description",
    )
    metadata: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Additional metadata",
    )
|
||||
|
||||
class DocumentResponse(BaseModel):
    """Shape of a stored document as returned by the ``/documents`` routes.

    All fields are required.
    """

    # NOTE(review): ``description`` and ``metadata`` are optional on
    # DocumentRequest but required here - confirm the document service always
    # fills them in, otherwise response validation will fail.
    id: str = Field(
        ...,
        description="Document ID",
    )
    title: str = Field(
        ...,
        description="Document title",
    )
    description: str = Field(
        ...,
        description="Document description",
    )
    chunk_count: int = Field(
        ...,
        description="Number of chunks",
    )
    metadata: Dict[str, Any] = Field(
        ...,
        description="Additional metadata",
    )
|
||||
|
||||
class SearchRequest(BaseModel):
    """Request model for document search.

    ``query`` is required. ``top_k`` defaults to 5 and must be at least 1, so
    a zero or negative result count is rejected during request validation
    instead of being forwarded to the search service.
    """

    query: str = Field(..., description="Search query")
    top_k: int = Field(5, ge=1, description="Number of results to return")
|
||||
|
||||
class SearchResult(BaseModel):
    """One entry in the list returned by ``POST /documents/search``.

    All fields are required.
    """

    id: str = Field(
        ...,
        description="Result ID",
    )
    score: float = Field(
        ...,
        description="Similarity score",
    )
    metadata: Dict[str, Any] = Field(
        ...,
        description="Result metadata",
    )
|
||||
|
||||
class ModelInfo(BaseModel):
    """Description of a model as returned by the ``/models`` routes.

    All fields are required.
    """

    id: str = Field(
        ...,
        description="Model ID",
    )
    name: str = Field(
        ...,
        description="Model name",
    )
    description: str = Field(
        ...,
        description="Model description",
    )
    provider: str = Field(
        ...,
        description="Model provider",
    )
    max_tokens: int = Field(
        ...,
        description="Maximum tokens",
    )
    is_default: bool = Field(
        ...,
        description="Whether this is the default model",
    )
|
||||
|
||||
class ChatRequest(BaseModel):
    """Body accepted by ``POST /chats``.

    Only ``user_id`` is required. ``title`` and ``model_id`` default to
    ``None`` and ``is_team_chat`` defaults to ``False``.
    """

    user_id: str = Field(
        ...,
        description="User ID",
    )
    title: Optional[str] = Field(
        default=None,
        description="Chat title",
    )
    model_id: Optional[str] = Field(
        default=None,
        description="Model ID",
    )
    is_team_chat: bool = Field(
        default=False,
        description="Whether this is a team chat",
    )
|
||||
|
||||
class MessageRequest(BaseModel):
    """Body accepted by ``POST /chats/{chat_id}/messages``.

    ``message`` and ``user_id`` are required and ``use_rag`` defaults to
    ``False``. Every generation parameter is optional and defaults to
    ``None``; no value ranges are enforced by this model.
    """

    message: str = Field(
        ...,
        description="Message content",
    )
    user_id: str = Field(
        ...,
        description="User ID",
    )
    use_rag: bool = Field(
        default=False,
        description="Whether to use RAG",
    )

    # Model parameters
    temperature: Optional[float] = Field(
        default=None,
        description="Controls randomness: higher values mean more random completions",
    )
    max_tokens: Optional[int] = Field(
        default=None,
        description="Maximum number of tokens to generate",
    )
    top_p: Optional[float] = Field(
        default=None,
        description="Nucleus sampling parameter",
    )
    frequency_penalty: Optional[float] = Field(
        default=None,
        description="Penalizes repeated tokens",
    )
    presence_penalty: Optional[float] = Field(
        default=None,
        description="Penalizes repeated topics",
    )
    stop_sequences: Optional[List[str]] = Field(
        default=None,
        description="Sequences where the API will stop generating",
    )
    system_prompt: Optional[str] = Field(
        default=None,
        description="System prompt to guide the model's behavior",
    )

    # Additional advanced parameters
    min_p: Optional[float] = Field(
        default=None,
        description="Minimum probability threshold for token selection",
    )
    top_k: Optional[int] = Field(
        default=None,
        description="Only sample from the top k tokens",
    )
    repeat_penalty: Optional[float] = Field(
        default=None,
        description="Penalty for repeating tokens",
    )
    function_calling: Optional[bool] = Field(
        default=None,
        description="Whether to enable function calling",
    )
|
||||
|
||||
class Message(BaseModel):
    """A single chat message, used inside ``Chat.messages`` and as the
    response of ``POST /chats/{chat_id}/messages``.

    ``user_id`` is the only optional field and defaults to ``None``.
    """

    id: str = Field(
        ...,
        description="Message ID",
    )
    content: str = Field(
        ...,
        description="Message content",
    )
    user_id: Optional[str] = Field(
        default=None,
        description="User ID",
    )
    is_user_message: bool = Field(
        ...,
        description="Whether this is a user message",
    )
    # Timestamps are carried as plain strings; no format is enforced here.
    timestamp: str = Field(
        ...,
        description="Message timestamp",
    )
|
||||
|
||||
class Chat(BaseModel):
    """A chat with its messages and team members, as returned by the
    ``/chats`` routes.

    All fields are required.
    """

    # NOTE(review): ``title`` and ``model_id`` are optional on ChatRequest but
    # required here - presumably the chat service supplies defaults; confirm.
    id: str = Field(
        ...,
        description="Chat ID",
    )
    title: str = Field(
        ...,
        description="Chat title",
    )
    user_id: str = Field(
        ...,
        description="User ID",
    )
    model_id: str = Field(
        ...,
        description="Model ID",
    )
    is_team_chat: bool = Field(
        ...,
        description="Whether this is a team chat",
    )
    # Timestamps are carried as plain strings; no format is enforced here.
    created_at: str = Field(
        ...,
        description="Creation timestamp",
    )
    updated_at: str = Field(
        ...,
        description="Update timestamp",
    )
    messages: List[Message] = Field(
        ...,
        description="Chat messages",
    )
    team_members: List[str] = Field(
        ...,
        description="Team members",
    )
|
||||
|
||||
# Define API endpoints
|
||||
@app.get("/health")
async def health_check():
    """
    Report that the service process is up.

    Returns:
        A constant ``{"status": "healthy"}`` payload; no dependency is probed.
    """
    return dict(status="healthy")
|
||||
|
||||
# Document endpoints
|
||||
@app.post("/documents", response_model=DocumentResponse)
async def process_document(request: DocumentRequest):
    """
    Hand a document to the document service, then return its stored record.

    Args:
        request: Document content, title and optional description/metadata.

    Returns:
        The result of ``document_service.get_document`` for the ID that
        ``document_service.process_document`` returned.
    """
    # NOTE(review): the service call is synchronous inside an async handler,
    # so it presumably blocks the event loop while it runs - confirm.
    document_fields = {
        "content": request.content,
        "title": request.title,
        "description": request.description,
        "metadata": request.metadata,
    }
    new_doc_id = document_service.process_document(**document_fields)
    return document_service.get_document(new_doc_id)
|
||||
|
||||
@app.get("/documents", response_model=List[DocumentResponse])
async def get_all_documents():
    """
    List every document known to the document service.

    Returns:
        The value of ``document_service.get_all_documents()``, unmodified.
    """
    documents = document_service.get_all_documents()
    return documents
|
||||
|
||||
@app.get("/documents/{doc_id}", response_model=DocumentResponse)
async def get_document(doc_id: str):
    """
    Look up a single document.

    Args:
        doc_id: Document ID taken from the URL path.

    Returns:
        The record returned by ``document_service.get_document``.

    Raises:
        HTTPException: 404 when the service returns a falsy value.
    """
    found = document_service.get_document(doc_id)
    if found:
        return found

    raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
@app.delete("/documents/{doc_id}")
async def delete_document(doc_id: str):
    """
    Remove a document.

    Args:
        doc_id: Document ID taken from the URL path.

    Returns:
        A status/message dict confirming the deletion.

    Raises:
        HTTPException: 404 when ``document_service.delete_document`` returns
            a falsy value.
    """
    if document_service.delete_document(doc_id):
        return {"status": "success", "message": "Document deleted"}

    raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
@app.post("/documents/search", response_model=List[SearchResult])
async def search_documents(request: SearchRequest):
    """
    Run a search through the document service.

    Args:
        request: Query text and the number of results wanted.

    Returns:
        The value of ``document_service.search_documents``, unmodified.
    """
    return document_service.search_documents(
        query=request.query,
        top_k=request.top_k,
    )
|
||||
|
||||
# Model endpoints
|
||||
@app.get("/models", response_model=List[ModelInfo])
async def get_available_models():
    """
    List the models exposed by the model service.

    Returns:
        The value of ``model_service.get_available_models()``, unmodified.
    """
    models = model_service.get_available_models()
    return models
|
||||
|
||||
@app.get("/models/{model_id}", response_model=ModelInfo)
async def get_model_info(model_id: str):
    """
    Look up a single model.

    Args:
        model_id: Model ID taken from the URL path.

    Returns:
        The record returned by ``model_service.get_model_info``.

    Raises:
        HTTPException: 404 when the service returns a falsy value.
    """
    details = model_service.get_model_info(model_id)
    if details:
        return details

    raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
# Chat endpoints
|
||||
@app.post("/chats", response_model=Chat)
async def create_chat(request: ChatRequest):
    """
    Create a chat through the chat service, then return its stored record.

    Args:
        request: Owner user ID plus optional title, model ID and team flag.

    Returns:
        The result of ``chat_service.get_chat`` for the ID that
        ``chat_service.create_chat`` returned.
    """
    chat_fields = {
        "user_id": request.user_id,
        "title": request.title,
        "model_id": request.model_id,
        "is_team_chat": request.is_team_chat,
    }
    new_chat_id = chat_service.create_chat(**chat_fields)
    return chat_service.get_chat(new_chat_id)
|
||||
|
||||
@app.get("/chats/user/{user_id}", response_model=List[Chat])
async def get_user_chats(user_id: str):
    """
    List the chats the chat service reports for a user.

    Args:
        user_id: User ID taken from the URL path.

    Returns:
        The value of ``chat_service.get_user_chats(user_id)``, unmodified.
    """
    chats = chat_service.get_user_chats(user_id)
    return chats
|
||||
|
||||
@app.get("/chats/{chat_id}", response_model=Chat)
async def get_chat(chat_id: str):
    """
    Look up a single chat.

    Args:
        chat_id: Chat ID taken from the URL path.

    Returns:
        The record returned by ``chat_service.get_chat``.

    Raises:
        HTTPException: 404 when the service returns a falsy value.
    """
    found = chat_service.get_chat(chat_id)
    if found:
        return found

    raise HTTPException(status_code=404, detail="Chat not found")
|
||||
|
||||
@app.post("/chats/{chat_id}/messages", response_model=Message)
async def send_message(chat_id: str, request: MessageRequest):
    """
    Send a message to a chat.

    Every field of the request is forwarded individually, as a keyword
    argument, to ``chat_service.get_chat_response``; parameters the client
    omitted are passed as ``None``.

    Args:
        chat_id: Chat ID.
        request: Message request with optional model parameters.

    Returns:
        Bot response message.

    Raises:
        HTTPException: 404 carrying the error text when the chat service
            raises ``ValueError``.
    """
    # NOTE(review): every ValueError is reported as 404 - presumably the
    # service raises it for an unknown chat; confirm it is not also used for
    # invalid parameters, which would be better reported as 400/422.
    try:
        response = chat_service.get_chat_response(
            chat_id=chat_id,
            message=request.message,
            user_id=request.user_id,
            use_rag=request.use_rag,
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            top_p=request.top_p,
            frequency_penalty=request.frequency_penalty,
            presence_penalty=request.presence_penalty,
            stop_sequences=request.stop_sequences,
            system_prompt=request.system_prompt,
            min_p=request.min_p,
            top_k=request.top_k,
            repeat_penalty=request.repeat_penalty,
            function_calling=request.function_calling
        )
    except ValueError as e:
        # Chain the original error so tracebacks and logs keep the root cause.
        raise HTTPException(status_code=404, detail=str(e)) from e

    return response
|
||||
|
||||
@app.post("/chats/{chat_id}/members/{user_id}")
async def add_team_member(chat_id: str, user_id: str):
    """
    Add a user to a team chat.

    Args:
        chat_id: Chat ID taken from the URL path.
        user_id: ID of the user to add, taken from the URL path.

    Returns:
        A status/message dict confirming the addition.

    Raises:
        HTTPException: 400 when ``chat_service.add_team_member`` returns a
            falsy value; the cause of the failure is not distinguished.
    """
    if chat_service.add_team_member(chat_id, user_id):
        return {"status": "success", "message": "Team member added"}

    raise HTTPException(status_code=400, detail="Failed to add team member")
|
||||
|
||||
@app.delete("/chats/{chat_id}/members/{user_id}")
async def remove_team_member(chat_id: str, user_id: str):
    """
    Remove a user from a team chat.

    Args:
        chat_id: Chat ID taken from the URL path.
        user_id: ID of the user to remove, taken from the URL path.

    Returns:
        A status/message dict confirming the removal.

    Raises:
        HTTPException: 400 when ``chat_service.remove_team_member`` returns a
            falsy value; the cause of the failure is not distinguished.
    """
    if chat_service.remove_team_member(chat_id, user_id):
        return {"status": "success", "message": "Team member removed"}

    raise HTTPException(status_code=400, detail="Failed to remove team member")
|
||||
|
||||
@app.delete("/chats/{chat_id}")
async def delete_chat(chat_id: str):
    """
    Remove a chat.

    Args:
        chat_id: Chat ID taken from the URL path.

    Returns:
        A status/message dict confirming the deletion.

    Raises:
        HTTPException: 404 when ``chat_service.delete_chat`` returns a falsy
            value.
    """
    if chat_service.delete_chat(chat_id):
        return {"status": "success", "message": "Chat deleted"}

    raise HTTPException(status_code=404, detail="Chat not found")
|
||||
Reference in New Issue
Block a user