145 lines
4.7 KiB
Python
145 lines
4.7 KiB
Python
|
|
"""
Simple API for testing deployment.

Exposes a health-check endpoint and in-memory chat endpoints that return
canned test responses.
"""
|
||
|
|
|
||
|
|
import os
import uuid
from datetime import datetime, timezone
from typing import List, Dict, Any, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
|
||
|
|
|
||
|
|
# Create FastAPI app.
# Application metadata is collected in one mapping and unpacked into the
# FastAPI constructor.
_APP_METADATA = {
    "title": "Simple AI Service API",
    "description": "Simple API for testing deployment",
    "version": "1.0.0",
}
app = FastAPI(**_APP_METADATA)
|
||
|
|
|
||
|
|
# Add CORS middleware.
# Allowed origins may be restricted with the CORS_ALLOW_ORIGINS environment
# variable (comma-separated list). When it is unset or empty, the wildcard "*"
# is used, matching the original permissive setup for testing.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentials must not be combined with a wildcard origin: in that case
    # Starlette echoes the caller's Origin back, which would allow credentialed
    # cross-site requests from any website. Only enable credentials when the
    # allowed origins are listed explicitly.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
|
|
||
|
|
# Define API models
class MessageRequest(BaseModel):
    """Payload for posting a message into an existing chat.

    ``message`` and ``user_id`` are required. Every generation parameter is
    optional; ``None`` means the caller did not supply it.
    """

    message: str = Field(default=..., description="Message content")
    user_id: str = Field(default=..., description="User ID")

    # Optional model/generation parameters.
    temperature: Optional[float] = Field(default=None, description="Controls randomness")
    max_tokens: Optional[int] = Field(default=None, description="Maximum tokens to generate")
    top_p: Optional[float] = Field(default=None, description="Nucleus sampling parameter")
    frequency_penalty: Optional[float] = Field(default=None, description="Penalizes repeated tokens")
    presence_penalty: Optional[float] = Field(default=None, description="Penalizes repeated topics")
    system_prompt: Optional[str] = Field(default=None, description="System prompt")
|
||
|
|
|
||
|
|
class Message(BaseModel):
    """One message in a chat, authored either by a user or by the bot.

    ``user_id`` is left as ``None`` for bot-authored messages.
    """

    id: str = Field(default=..., description="Message ID")
    content: str = Field(default=..., description="Message content")
    user_id: Optional[str] = Field(default=None, description="User ID")
    # True for messages sent by a user, False for bot replies.
    is_user_message: bool = Field(default=..., description="Whether this is a user message")
    # Stored as a string (the endpoints fill it with an ISO-8601 timestamp).
    timestamp: str = Field(default=..., description="Message timestamp")
|
||
|
|
|
||
|
|
class ChatRequest(BaseModel):
    """Payload for creating a chat.

    Only ``user_id`` is required. When ``title`` or ``model_id`` is omitted,
    the create endpoint substitutes its own default.
    """

    user_id: str = Field(default=..., description="User ID")
    title: Optional[str] = Field(default=None, description="Chat title")
    model_id: Optional[str] = Field(default=None, description="Model ID")
|
||
|
|
|
||
|
|
class Chat(BaseModel):
    """A chat conversation together with its full message history."""

    id: str = Field(..., description="Chat ID")
    title: str = Field(..., description="Chat title")
    user_id: str = Field(..., description="User ID")
    model_id: str = Field(..., description="Model ID")
    # Timestamps are kept as strings (ISO-8601 as written by the endpoints).
    created_at: str = Field(..., description="Creation timestamp")
    updated_at: str = Field(..., description="Update timestamp")
    # default_factory builds a fresh list for every instance instead of
    # declaring a mutable list literal as the shared default value.
    messages: List[Message] = Field(default_factory=list, description="Chat messages")
|
||
|
|
|
||
|
|
# In-memory storage: maps a chat ID to its chat record, a plain dict with the
# same keys as the ``Chat`` model. Contents live only as long as this process.
chats: Dict[str, Dict[str, Any]] = dict()
|
||
|
|
|
||
|
|
# API endpoints
@app.get("/health")
async def health_check():
    """Report service liveness.

    Always returns ``{"status": "healthy"}``; no internal state is inspected.
    """
    return dict(status="healthy")
|
||
|
|
|
||
|
|
@app.post("/chats", response_model=Chat)
async def create_chat(request: ChatRequest):
    """Create a new, empty chat and store it in ``chats``.

    Args:
        request: Owner ID plus optional title and model ID.

    Returns:
        The stored chat record. A missing or empty title becomes
        ``"Chat <n>"`` (``n`` = number of chats once this one is added), and a
        missing or empty model ID becomes ``"gpt-3.5-turbo"``.
    """
    chat_id = str(uuid.uuid4())

    # Read the clock once so created_at and updated_at are identical for a new
    # chat (two separate reads could differ by a few microseconds).
    # datetime.utcnow() is deprecated since Python 3.12; this expression gives
    # the same naive ISO-8601 string without the deprecation warning.
    now = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    chat = {
        "id": chat_id,
        "title": request.title or f"Chat {len(chats) + 1}",
        "user_id": request.user_id,
        "model_id": request.model_id or "gpt-3.5-turbo",
        "created_at": now,
        "updated_at": now,
        "messages": [],
    }

    chats[chat_id] = chat
    return chat
|
||
|
|
|
||
|
|
@app.get("/chats/{chat_id}", response_model=Chat)
async def get_chat(chat_id: str):
    """Return the stored chat identified by ``chat_id``.

    Raises:
        HTTPException: 404 when no chat with that ID exists.
    """
    try:
        return chats[chat_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Chat not found") from None
|
||
|
|
|
||
|
|
def _describe_model_params(request: MessageRequest) -> str:
    """Build the suffix that echoes selected generation parameters.

    Only ``temperature``, ``max_tokens`` and ``system_prompt`` are echoed.
    ``top_p``, ``frequency_penalty`` and ``presence_penalty`` are accepted by
    the request model but are not reflected in the test response.
    """
    parts = []
    if request.temperature is not None:
        parts.append(f" (temperature={request.temperature})")
    if request.max_tokens is not None:
        parts.append(f" (max_tokens={request.max_tokens})")
    if request.system_prompt is not None:
        parts.append(" (using custom system prompt)")
    return "".join(parts)


@app.post("/chats/{chat_id}/messages", response_model=Message)
async def send_message(chat_id: str, request: MessageRequest):
    """Append a user message to a chat, followed by a canned bot reply.

    Args:
        chat_id: ID of an existing chat.
        request: The user's message plus optional generation parameters.

    Returns:
        The bot's reply, which is also appended to the chat.

    Raises:
        HTTPException: 404 when no chat with ``chat_id`` exists.
    """
    chat = chats.get(chat_id)
    if chat is None:
        raise HTTPException(status_code=404, detail="Chat not found")

    # Timestamps use naive-UTC ISO-8601 strings, the same format
    # datetime.utcnow().isoformat() produced (utcnow is deprecated since 3.12).

    # Add user message
    user_message = {
        "id": str(uuid.uuid4()),
        "content": request.message,
        "user_id": request.user_id,
        "is_user_message": True,
        "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
    }
    chat["messages"].append(user_message)

    # Generate bot response (a fixed test reply echoing the input).
    params_text = _describe_model_params(request)
    replied_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    bot_message = {
        "id": str(uuid.uuid4()),
        "content": f"This is a test response to: '{request.message}'{params_text}",
        "user_id": None,
        "is_user_message": False,
        "timestamp": replied_at,
    }
    chat["messages"].append(bot_message)

    # The last change to the chat is the reply just added, so reuse its
    # timestamp rather than reading the clock again.
    chat["updated_at"] = replied_at

    return bot_message
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    import uvicorn

    # Bind address and port can be overridden at deploy time through the HOST
    # and PORT environment variables; defaults match the original values.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "5251")),
    )
|