"""HTTP API for the Marketing Assistant AI: copy generation and direct model inference."""

import logging
from typing import List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field

from copywriter import generate_marketing_copy
from vector_store import VectorStore
from embeddings import CohereEmbeddings
from brand_style import BrandStyleChecker
from config import Settings
from finetuned_model import finetuned_model

logger = logging.getLogger(__name__)

app = FastAPI(title="Marketing Assistant AI")

# Service objects are created once at import time and shared by every request.
settings = Settings()
vector_store = VectorStore()
embeddings = CohereEmbeddings()
brand_checker = BrandStyleChecker()

# TODO: derive this from the model's actual confidence; it is a fixed placeholder.
_PLACEHOLDER_CONFIDENCE_SCORE = 0.85


class CopyRequest(BaseModel):
    """Request body for ``POST /generate-copy``."""

    prompt: str
    content_type: str
    tone: Optional[str] = None
    target_audience: Optional[str] = None


class CopyResponse(BaseModel):
    """Response body for ``POST /generate-copy``."""

    content: str
    confidence_score: float
    brand_alignment_score: float


class DirectModelRequest(BaseModel):
    """Request body for ``POST /direct-model``; options are forwarded to the model."""

    prompt: str
    # ge=1 rejects zero/negative sizes with a 422 up front instead of letting
    # them reach the model and surface as an opaque 500. None is still allowed.
    max_length: Optional[int] = Field(200, ge=1)
    num_return_sequences: Optional[int] = Field(1, ge=1)
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9


class DirectModelResponse(BaseModel):
    """Response body for ``POST /direct-model``."""

    generated_texts: List[str]


@app.post("/generate-copy", response_model=CopyResponse)
async def create_marketing_copy(request: CopyRequest):
    """Generate marketing copy grounded in similar stored content.

    Embeds the prompt, retrieves similar content from the vector store,
    generates copy from the prompt plus that context, and scores the result
    for brand alignment.

    Raises:
        HTTPException: any HTTPException raised by a helper is re-raised
            unchanged; every other exception is logged and converted to a 500.
    """
    try:
        # None of these helpers are awaited, i.e. they are synchronous. Running
        # them in the threadpool keeps them from blocking the event loop, so
        # other requests (including /health) are still served meanwhile.
        # NOTE(review): this allows concurrent calls into the shared clients
        # across requests — confirm they are thread-safe.
        prompt_embedding = await run_in_threadpool(embeddings.generate, request.prompt)

        similar_content = await run_in_threadpool(vector_store.search, prompt_embedding)

        content = await run_in_threadpool(
            generate_marketing_copy,
            prompt=request.prompt,
            content_type=request.content_type,
            similar_content=similar_content,
            tone=request.tone,
            target_audience=request.target_audience,
        )

        brand_alignment = await run_in_threadpool(brand_checker.check_alignment, content)

        return CopyResponse(
            content=content,
            confidence_score=_PLACEHOLDER_CONFIDENCE_SCORE,
            brand_alignment_score=brand_alignment,
        )
    except HTTPException:
        # Preserve deliberate status codes instead of masking them as 500s.
        raise
    except Exception as e:
        logger.exception("Marketing copy generation failed")
        # NOTE(review): str(e) exposes internal error text to API clients;
        # consider a generic message once callers no longer rely on it.
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/direct-model", response_model=DirectModelResponse)
async def direct_model_inference(request: DirectModelRequest):
    """
    Direct inference using the finetuned model without using the vector store or other components.
    This endpoint is useful for testing the model directly.
    """
    try:
        # finetuned_model.generate is synchronous; run it in the threadpool so
        # that inference does not block the event loop for other requests.
        # NOTE(review): confirm the model object tolerates concurrent calls.
        generated_texts = await run_in_threadpool(
            finetuned_model.generate,
            prompt=request.prompt,
            max_length=request.max_length,
            num_return_sequences=request.num_return_sequences,
            temperature=request.temperature,
            top_p=request.top_p,
        )

        return DirectModelResponse(generated_texts=generated_texts)
    except HTTPException:
        # Preserve deliberate status codes instead of masking them as 500s.
        raise
    except Exception as e:
        logger.exception("Direct model inference failed")
        # NOTE(review): str(e) exposes internal error text to API clients.
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/health")
async def health_check():
    """Liveness probe; always reports healthy while the process serves requests."""
    return {"status": "healthy"}


if __name__ == "__main__":
    # The import string "main:app" assumes this file is main.py; reload=True
    # re-imports the app by that name in a child process.
    uvicorn.run("main:app", host="localhost", port=8000, reload=True)