859c17aad8
- Update config.py with Pinecone settings and model configurations
- Implement VectorStore class with Pinecone backend
- Add comprehensive vector operations (add, search, delete)
- Set up proper error handling and metadata management
- Add .gitignore for Python project
93 lines
3.0 KiB
Python
93 lines
3.0 KiB
Python
import logging
from typing import Optional, List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn

from copywriter import generate_marketing_copy
from vector_store import VectorStore
from embeddings import CohereEmbeddings
from brand_style import BrandStyleChecker
from config import Settings
from finetuned_model import finetuned_model
|
|
|
|
# FastAPI application; the title appears in the generated OpenAPI docs.
app = FastAPI(title="Marketing Assistant AI")

# Shared service objects, constructed once when this module is imported and
# used by every request handler in this module. Keep this order: the
# constructors may have side effects (config loading, client setup).
# NOTE(review): `settings` is not referenced by any handler in this module —
# confirm whether it is needed (e.g. to validate configuration at startup).
settings = Settings()

# Vector index queried for content similar to the request prompt.
vector_store = VectorStore()

# Embedding client used to turn the request prompt into a query vector.
embeddings = CohereEmbeddings()

# Scores generated copy for brand alignment.
brand_checker = BrandStyleChecker()
|
|
|
|
# Request body for POST /generate-copy.
class CopyRequest(BaseModel):
    # Free-text brief. It is embedded for the similarity search and also passed
    # to the copy generator.
    prompt: str
    # Kind of copy to produce, forwarded to generate_marketing_copy.
    # NOTE(review): unconstrained string — the accepted values are whatever the
    # generator understands; confirm and consider a Literal/Enum.
    content_type: str
    # Optional style hints, forwarded as-is (None when omitted by the client).
    tone: Optional[str] = None
    target_audience: Optional[str] = None
|
|
|
|
# Response body for POST /generate-copy.
class CopyResponse(BaseModel):
    # Generated marketing copy.
    content: str
    # NOTE(review): the /generate-copy handler currently fills this with a fixed
    # placeholder value rather than a model-derived confidence.
    confidence_score: float
    # Value returned by BrandStyleChecker.check_alignment for `content`.
    brand_alignment_score: float
|
|
|
|
# Request body for POST /direct-model. The sampling options are forwarded
# unchanged to finetuned_model.generate.
class DirectModelRequest(BaseModel):
    # Text fed to the fine-tuned model.
    prompt: str
    # These fields are declared Optional, so an explicit JSON null is accepted
    # and forwarded to generate() as None.
    # NOTE(review): confirm finetuned_model.generate handles None, or drop Optional.
    # NOTE(review): no bounds are enforced here, so very large max_length /
    # num_return_sequences values from clients reach the model unchecked.
    max_length: Optional[int] = 200
    num_return_sequences: Optional[int] = 1
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
|
|
|
|
# Response body for POST /direct-model.
class DirectModelResponse(BaseModel):
    # Texts returned by finetuned_model.generate.
    # Presumably one entry per requested sequence — confirm against generate().
    generated_texts: List[str]
|
|
|
|
@app.post("/generate-copy", response_model=CopyResponse)
async def create_marketing_copy(request: CopyRequest):
    """
    Generate marketing copy for a prompt, grounded in similar stored content.

    The prompt is embedded, similar content is retrieved from the vector store,
    copy is generated with that context, and the result is scored for brand
    alignment.
    """
    # NOTE(review): the service calls below are synchronous and run on the
    # event-loop thread, so they block every other request (including /health)
    # while they execute. Consider a plain `def` endpoint or run_in_threadpool
    # once the shared clients are confirmed safe to use from multiple threads.
    try:
        # Embed the prompt so it can be used as a similarity query.
        prompt_embedding = embeddings.generate(request.prompt)

        # Retrieve similar content to give the generator relevant context.
        similar_content = vector_store.search(prompt_embedding)

        content = generate_marketing_copy(
            prompt=request.prompt,
            content_type=request.content_type,
            similar_content=similar_content,
            tone=request.tone,
            target_audience=request.target_audience,
        )

        brand_alignment = brand_checker.check_alignment(content)

        return CopyResponse(
            content=content,
            # TODO: placeholder; should be calculated from model confidence.
            confidence_score=0.85,
            brand_alignment_score=brand_alignment,
        )
    except HTTPException:
        # Downstream code raised a deliberate HTTP error: keep its status code
        # and detail instead of rewrapping it as a 500.
        raise
    except Exception as e:
        # Log the traceback server-side; without this it is lost once the error
        # is converted into an HTTP response.
        logging.getLogger(__name__).exception("Failed to generate marketing copy")
        # NOTE(review): str(e) exposes internal error text to API clients. It is
        # kept for backward compatibility; consider a generic message.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
@app.post("/direct-model", response_model=DirectModelResponse)
async def direct_model_inference(request: DirectModelRequest):
    """
    Direct inference using the finetuned model without using the vector store or other components.
    This endpoint is useful for testing the model directly.
    """
    # NOTE(review): finetuned_model.generate is called synchronously inside an
    # async endpoint, so it runs on the event-loop thread and blocks other
    # requests for its whole duration. Consider run_in_threadpool or a plain
    # `def` endpoint once the model is confirmed safe to call concurrently.
    try:
        # Sampling options from the request are forwarded unchanged.
        generated_texts = finetuned_model.generate(
            prompt=request.prompt,
            max_length=request.max_length,
            num_return_sequences=request.num_return_sequences,
            temperature=request.temperature,
            top_p=request.top_p,
        )
        return DirectModelResponse(generated_texts=generated_texts)
    except HTTPException:
        # Preserve deliberate HTTP errors instead of turning them into a 500.
        raise
    except Exception as e:
        # Keep the traceback in the server log; the client only sees the 500.
        logging.getLogger(__name__).exception("Direct model inference failed")
        # NOTE(review): str(e) exposes internal error text to API clients. It is
        # kept for backward compatibility; consider a generic message.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
@app.get("/health")
async def health_check():
    # Static health response: it reports healthy unconditionally and does not
    # probe the vector store, embedding client, or model.
    status_payload = dict(status="healthy")
    return status_payload
|
|
|
|
if __name__ == "__main__":
    # Development entry point: serve the app with uvicorn and auto-reload.
    # NOTE: the import string assumes this module is importable as `main`.
    uvicorn.run(
        "main:app",
        host="localhost",
        port=8000,
        reload=True,
    )