Files

93 lines
3.0 KiB
Python
Raw Permalink Normal View History

import logging
from typing import List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field

from copywriter import generate_marketing_copy
from vector_store import VectorStore
from embeddings import CohereEmbeddings
from brand_style import BrandStyleChecker
from config import Settings
from finetuned_model import finetuned_model
# FastAPI application; the title is used in the generated OpenAPI docs.
app = FastAPI(title="Marketing Assistant AI")
# Module-level singletons, constructed once at import time and shared across requests.
# NOTE(review): `settings` is not referenced by any handler visible here — confirm it is
# still needed (constructing it may be relied on to validate configuration at import).
settings = Settings()
vector_store = VectorStore()
embeddings = CohereEmbeddings()
brand_checker = BrandStyleChecker()
class CopyRequest(BaseModel):
    """Request body for the ``/generate-copy`` endpoint."""

    # User prompt; it is embedded for similarity search and passed to generate_marketing_copy.
    prompt: str
    # Kind of copy requested, forwarded to generate_marketing_copy without validation.
    # NOTE(review): presumably a closed set of values — consider an Enum/Literal.
    content_type: str
    # Optional tone hint forwarded to generate_marketing_copy.
    tone: Optional[str] = None
    # Optional target-audience hint forwarded to generate_marketing_copy.
    target_audience: Optional[str] = None
class CopyResponse(BaseModel):
    """Response body for the ``/generate-copy`` endpoint."""

    # Marketing copy produced by generate_marketing_copy.
    content: str
    # NOTE(review): the /generate-copy handler currently fills this with a hard-coded
    # 0.85, not a value computed from the model.
    confidence_score: float
    # Value returned by BrandStyleChecker.check_alignment for `content`.
    brand_alignment_score: float
class DirectModelRequest(BaseModel):
    """Request body for the ``/direct-model`` endpoint: a raw prompt plus sampling options.

    The numeric options arrive from untrusted clients. Values that cannot be meaningful
    for text generation are rejected during validation (HTTP 422) instead of reaching
    the model and surfacing as an HTTP 500. Defaults are unchanged, and an explicit
    ``null`` is still accepted and forwarded as ``None``.
    """

    prompt: str
    # Upper bound on generated length; must be positive.
    # NOTE(review): unit (tokens vs. characters) depends on finetuned_model.generate — confirm.
    max_length: Optional[int] = Field(200, ge=1)
    # Number of completions to return; at least one must be requested.
    num_return_sequences: Optional[int] = Field(1, ge=1)
    # Sampling temperature; negative values are rejected.
    # NOTE(review): whether exactly 0 is accepted (e.g. as greedy decoding) is up to
    # finetuned_model.generate — confirm.
    temperature: Optional[float] = Field(0.7, ge=0.0)
    # Nucleus-sampling probability mass, meaningful only in (0, 1].
    top_p: Optional[float] = Field(0.9, gt=0.0, le=1.0)
class DirectModelResponse(BaseModel):
    """Response body for the ``/direct-model`` endpoint."""

    # Texts returned by finetuned_model.generate, passed through unchanged
    # (presumably one per requested sequence — confirm against finetuned_model).
    generated_texts: List[str]
def _generate_copy_sync(request: CopyRequest) -> CopyResponse:
    """Run the synchronous retrieval, generation and brand-check pipeline for one request.

    Every call here is synchronous, so this helper is meant to be executed in a worker
    thread by ``create_marketing_copy`` rather than on the event loop.
    """
    # Embed the prompt and retrieve similar content to ground the generation.
    prompt_embedding = embeddings.generate(request.prompt)
    similar_content = vector_store.search(prompt_embedding)
    content = generate_marketing_copy(
        prompt=request.prompt,
        content_type=request.content_type,
        similar_content=similar_content,
        tone=request.tone,
        target_audience=request.target_audience,
    )
    brand_alignment = brand_checker.check_alignment(content)
    return CopyResponse(
        content=content,
        # TODO: placeholder constant; should be calculated from model confidence.
        confidence_score=0.85,
        brand_alignment_score=brand_alignment,
    )


@app.post("/generate-copy", response_model=CopyResponse)
async def create_marketing_copy(request: CopyRequest):
    """Generate marketing copy for the request and score its brand alignment.

    The prompt is embedded, similar content is retrieved from the vector store, and copy
    is generated with that context and then scored by the brand checker. Any failure is
    logged with its traceback and reported to the client as an HTTP 500.
    """
    try:
        # The pipeline calls are synchronous (none are awaited). Running them in a worker
        # thread keeps a slow request from blocking the event loop and, with it, every
        # other request, including /health.
        # NOTE(review): concurrent requests may now use the shared embeddings /
        # vector_store / brand_checker objects from several threads at once — confirm
        # they are safe for concurrent use.
        return await run_in_threadpool(_generate_copy_sync, request)
    except Exception as e:
        # Keep the full traceback server-side; chain so the cause is preserved.
        logging.getLogger(__name__).exception("Marketing copy generation failed")
        # NOTE(review): str(e) exposes internal error text to the client.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/direct-model", response_model=DirectModelResponse)
async def direct_model_inference(request: DirectModelRequest):
    """
    Direct inference using the finetuned model without using the vector store or other components.
    This endpoint is useful for testing the model directly.
    Any failure is logged with its traceback and reported to the client as an HTTP 500.
    """
    try:
        # finetuned_model.generate is synchronous (it is not awaited). Running it in a
        # worker thread keeps a long generation from blocking the event loop and, with
        # it, every other request, including /health.
        # NOTE(review): concurrent requests may now call the shared model from several
        # threads at once — confirm finetuned_model is safe for concurrent use.
        generated_texts = await run_in_threadpool(
            finetuned_model.generate,
            prompt=request.prompt,
            max_length=request.max_length,
            num_return_sequences=request.num_return_sequences,
            temperature=request.temperature,
            top_p=request.top_p,
        )
        return DirectModelResponse(generated_texts=generated_texts)
    except Exception as e:
        # Keep the full traceback server-side; chain so the cause is preserved.
        logging.getLogger(__name__).exception("Direct model inference failed")
        # NOTE(review): str(e) exposes internal error text to the client.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
async def health_check():
    """Report a constant status payload; no downstream component is probed."""
    payload = dict(status="healthy")
    return payload
if __name__ == "__main__":
    # Development entry point: serve on localhost:8000 with auto-reload enabled.
    # uvicorn requires an import string (not the app object) when reload is on.
    # NOTE(review): "main:app" assumes this module is importable as `main`.
    uvicorn.run(
        "main:app",
        reload=True,
        port=8000,
        host="localhost",
    )