Files
marketing-assistant-ai/backend/main.py
T
Michael Ikehi 12e0830ba6 feat(feedback): Add content improvement feedback system
Frontend (frontend/app.js):

- Add textarea for improvement feedback

- Add submit button with loading state

- Handle API response and display improved content

Backend (backend/copywriter.py):

- Add improve_copy() method using Cohere API

- Integrate retry mechanism for API calls

Backend (backend/main.py):

- Add /improve-content POST endpoint

- Implement error handling and return improved content with metadata

Testing:

- Verified feedback submission flow

- Confirmed improved content generation

- Tested error scenarios and loading states
2025-04-18 17:57:35 +01:00

468 lines
17 KiB
Python

"""
Main FastAPI application for the Marketing Assistant AI.
Provides API endpoints for generating and managing marketing content.
"""
import os
import json
import glob
from typing import Dict, List, Any, Optional
from datetime import datetime
from pathlib import Path
from fastapi import FastAPI, HTTPException, Depends, Query, Body, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from loguru import logger
from pydantic import BaseModel, Field
from sqlalchemy import select, desc, func
from sqlalchemy.sql import Select
import config
from copywriter import copywriter
from vector_store import vector_store
from brand_style import brand_style_manager
from embeddings import embeddings_manager
from models import database, training_data
# Also send log records to config.LOG_FILE (in addition to any existing sinks):
# the file rotates once it reaches 10 MB and rotated files are kept for a month.
logger.add(
    config.LOG_FILE,
    rotation="10 MB",
    retention="1 month",
    level=config.LOG_LEVEL,
)
# The ASGI application object; the __main__ block below serves it as "main:app".
app = FastAPI(
    version="1.0.0",
    title="Marketing Assistant AI",
    description="AI-powered tool for marketing copywriting with Adriana James' brand voice",
)
# Cross-origin access for the browser frontend.
# Allowed origins come from an optional ``config.CORS_ORIGINS`` list. When that
# setting is absent, every origin is allowed ("*"), which is the historical
# behaviour and only appropriate for local development.
# NOTE: combined with allow_credentials=True, a wildcard lets any site make
# credentialed requests (Starlette answers with the caller's Origin instead of
# "*" when cookies are present). Set CORS_ORIGINS to the real frontend
# domain(s) in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=getattr(config, "CORS_ORIGINS", ["*"]),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define request and response models
# Body of POST /generate-copy. Only ``prompt`` is required; every field is
# passed on to copywriter.generate_copy by the endpoint.
class GenerateCopyRequest(BaseModel):
    # Free-text instruction describing the copy to write.
    prompt: str = Field(default=..., description="The main instruction for generating content")
    # When given, the endpoint checks it against config.CONTENT_TYPES.
    content_type: Optional[str] = Field(default=None, description="Type of content to generate")
    length: Optional[str] = Field(default=None, description="Desired length of the content")
    include_cta: Optional[bool] = Field(default=False, description="Whether to include a call to action")
    reference_similar_content: Optional[bool] = Field(default=True, description="Whether to reference similar content")
    max_tokens: Optional[int] = Field(default=1000, description="Maximum tokens for the generated response")
# Body of POST /training-data: one piece of marketing copy plus optional,
# free-form metadata about it.
class TrainingDataRequest(BaseModel):
    # Validated against config.CONTENT_TYPES by the endpoint.
    content_type: str = Field(default=..., description="Type of content")
    content: str = Field(default=..., description="The marketing content")
    # pydantic copies mutable defaults per instance, so the ``{}`` literal is not shared.
    # Being Optional, an explicit JSON null is also accepted here.
    metadata: Optional[Dict[str, Any]] = Field(default={}, description="Additional metadata about the content")
# Body of PUT /brand-style. Every field is optional; the endpoint forwards only
# the fields the client actually sent (``exclude_unset=True``).
class BrandStyleUpdateRequest(BaseModel):
    tone: Optional[List[str]] = Field(default=None, description="Brand tone options")
    voice_characteristics: Optional[List[str]] = Field(default=None, description="Voice characteristics")
    taboo_words: Optional[List[str]] = Field(default=None, description="Words to avoid")
    # Mapping of term -> preferred replacement (presumably; confirm with brand_style_manager).
    preferred_terms: Optional[Dict[str, str]] = Field(default=None, description="Preferred terminology")
# Body of POST /improve-content: previously generated copy plus the user's
# instructions for revising it. Both fields are required.
class ContentImprovementRequest(BaseModel):
    content: str = Field(default=..., description="Original generated content")
    feedback: str = Field(default=..., description="User feedback for improvement")
# API Routes
@app.get("/")
async def root():
    """Describe the service: its name, version and the brand it writes for."""
    info = dict(
        name="Marketing Assistant AI",
        version="1.0.0",
    )
    info["description"] = f"AI-powered marketing copywriter for {config.BRAND_NAME}"
    return info
def _save_user_query(request: GenerateCopyRequest, now: datetime) -> None:
    """Persist a generation request under ``<DATA_DIR>/user_queries`` (best effort).

    The file is named after ``now`` down to microseconds
    (``%Y%m%d%H%M%S%f.json``), so two requests within the same second do not
    overwrite each other. The directory is created if missing. Any filesystem
    error is logged and swallowed: the copy has already been generated and
    should still reach the caller.
    """
    query_dir = Path(config.DATA_DIR) / "user_queries"
    query_path = query_dir / f"{now.strftime('%Y%m%d%H%M%S%f')}.json"
    record = {
        "prompt": request.prompt,
        "parameters": {
            "content_type": request.content_type,
            "length": request.length,
            "include_cta": request.include_cta,
        },
        "timestamp": now.isoformat(),
    }
    try:
        query_dir.mkdir(parents=True, exist_ok=True)
        with open(query_path, "w", encoding="utf-8") as f:
            json.dump(record, f, indent=2)
    except OSError as e:
        logger.warning(f"Could not save user query to {query_path}: {e}")


@app.post("/generate-copy")
async def generate_copy(request: GenerateCopyRequest):
    """Generate marketing copy for ``request.prompt``.

    The copy is produced by ``copywriter.generate_copy``. A non-empty result is
    also added to the vector store. The prompt and its parameters are then saved
    (best effort) as a JSON file under ``<DATA_DIR>/user_queries``.

    Returns:
        ``{"status", "content", "suggestions", "metadata"}`` on success, where
        ``metadata["generated_at"]`` is an ISO timestamp. Returns a 400 JSON
        body when ``content_type`` is given but not in ``config.CONTENT_TYPES``.

    Raises:
        HTTPException: 500 if generation or vector-store indexing fails.
    """
    try:
        # Validate content type if provided
        if request.content_type and request.content_type not in config.CONTENT_TYPES:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={
                    "status": "error",
                    "message": f"Invalid content_type. Must be one of: {', '.join(config.CONTENT_TYPES)}"
                }
            )

        result = await copywriter.generate_copy(
            prompt=request.prompt,
            content_type=request.content_type,
            length=request.length,
            include_cta=request.include_cta,
            reference_similar_content=request.reference_similar_content,
            max_tokens=request.max_tokens
        )

        # Read the clock once so generated_at and the saved query file agree.
        now = datetime.now()
        result["metadata"]["generated_at"] = now.isoformat()

        # Index the generated copy so later requests can reference it.
        if result["content"]:
            metadata = {
                "content_type": request.content_type,
                "prompt": request.prompt,
                "generated": True
            }
            await vector_store.add_documents([result["content"]], [metadata])
    except Exception as e:
        logger.exception(f"Error generating copy: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate copy: {str(e)}"
        ) from e

    # Saving the query is auxiliary bookkeeping. It runs outside the try above
    # so that a disk problem cannot turn a successful generation into a 500.
    _save_user_query(request, now)

    return {
        "status": "success",
        "content": result["content"],
        "suggestions": result.get("suggestions", []),
        "metadata": result["metadata"]
    }
@app.get("/brand-style")
async def get_brand_style():
    """Return the brand style guidelines held by ``brand_style_manager``.

    Any failure is logged and reported to the client as a 500.
    """
    try:
        return brand_style_manager.get_style_guidelines()
    except Exception as exc:
        logger.error(f"Error getting brand style: {str(exc)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get brand style: {str(exc)}",
        )
@app.put("/brand-style")
async def update_brand_style(request: BrandStyleUpdateRequest):
    """Apply a partial update to the brand style guidelines.

    Only fields the client explicitly sent are forwarded (``exclude_unset``).
    How they merge with the stored style is up to
    ``brand_style_manager.update_style_guidelines``. The resulting style is
    echoed back. Failures are logged and reported as a 500.
    """
    try:
        changes = request.dict(exclude_unset=True)
        new_style = brand_style_manager.update_style_guidelines(changes)
    except Exception as exc:
        logger.error(f"Error updating brand style: {str(exc)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to update brand style: {str(exc)}",
        )
    return {
        "status": "success",
        "message": "Brand style updated successfully",
        "style": new_style,
    }
@app.post("/training-data")
async def add_training_data(request: TrainingDataRequest):
    """Store a piece of marketing content as training data.

    The content is inserted into the ``training_data`` table and also indexed
    in the vector store for similarity search. Both copies carry the client's
    metadata plus ``content_type``, ``added_at`` and ``training_data=True``.

    Returns:
        ``{"status", "message", "data_id"}``, where ``data_id`` is the value
        ``database.execute`` returns for the insert. Returns a 400 JSON body
        when ``content_type`` is not in ``config.CONTENT_TYPES``.

    Raises:
        HTTPException: 500 if the database insert or vector-store indexing fails.
    """
    try:
        # Validate content type
        if request.content_type not in config.CONTENT_TYPES:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={
                    "status": "error",
                    "message": f"Invalid content_type. Must be one of: {', '.join(config.CONTENT_TYPES)}"
                }
            )

        # One timestamp for both the metadata copy and the DB column.
        added_at = datetime.now()

        # ``metadata`` is Optional, so an explicit JSON null arrives as None.
        # Build a fresh dict either way so the request object is never mutated.
        metadata = dict(request.metadata or {})
        metadata["content_type"] = request.content_type
        metadata["added_at"] = added_at.isoformat()
        metadata["training_data"] = True

        query = training_data.insert().values(
            content=request.content,
            content_type=request.content_type,
            metadata=metadata,
            added_at=added_at,
            is_training_data=True
        )
        data_id = await database.execute(query)

        # NOTE(review): the vector-store entry does not record data_id, yet
        # delete_training_data calls vector_store.delete_document(data_id).
        # Confirm the vector store can map the DB id to this document.
        await vector_store.add_documents([request.content], [metadata])

        return {
            "status": "success",
            "message": "Training data added successfully",
            "data_id": data_id
        }
    except Exception as e:
        logger.exception(f"Error adding training data: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to add training data: {str(e)}"
        ) from e
@app.get("/training-data")
async def list_training_data(
    content_type: Optional[str] = Query(None, description="Filter by content type"),
    page: int = Query(1, ge=1, description="Page number"),
    limit: int = Query(10, ge=1, le=100, description="Items per page")
):
    """Return one page of training records, newest ``added_at`` first.

    Args:
        content_type: Optional filter; must be one of ``config.CONTENT_TYPES``.
        page: 1-based page number.
        limit: Records per page (1-100).

    Returns:
        ``{"items": [...], "pagination": {...}}``. Each item has the record id,
        content type, ISO ``added_at`` and a preview of the first 100
        characters, with "..." appended when truncated. An unknown
        ``content_type`` yields a 400 JSON body instead.
    """
    try:
        if content_type and content_type not in config.CONTENT_TYPES:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={
                    "status": "error",
                    "message": f"Invalid content_type. Must be one of: {', '.join(config.CONTENT_TYPES)}"
                }
            )

        # The same WHERE criteria drive both the total count and the page query.
        filters = [training_data.c.is_training_data == True]  # noqa: E712 - SQL expression
        if content_type:
            filters.append(training_data.c.content_type == content_type)

        total = await database.fetch_val(
            select(func.count()).select_from(training_data).where(*filters)
        )

        page_query = (
            select(training_data)
            .where(*filters)
            .order_by(training_data.c.added_at.desc())
            .offset((page - 1) * limit)
            .limit(limit)
        )
        rows = await database.fetch_all(page_query)

        items = [
            {
                "id": row["id"],
                "content_type": row["content_type"],
                "preview": row["content"] if len(row["content"]) <= 100 else row["content"][:100] + "...",
                "added_at": row["added_at"].isoformat(),
            }
            for row in rows
        ]

        return {
            "items": items,
            "pagination": {
                "total": total,
                "page": page,
                "limit": limit,
                # Ceiling division: a partial last page still counts as a page.
                "pages": (total + limit - 1) // limit,
            },
        }
    except Exception as exc:
        logger.error(f"Error listing training data: {str(exc)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to list training data: {str(exc)}"
        )
@app.get("/training-data/{data_id}")
async def get_training_data(data_id: int):
    """Return the full training record with primary key ``data_id``.

    Returns:
        ``{"id", "content", "content_type", "metadata", "added_at"}``, with
        ``added_at`` ISO-formatted.

    Raises:
        HTTPException: 404 if no row has that id; 500 on any other failure.
    """
    try:
        # Pass the table directly, matching list_training_data. The legacy
        # ``select([table])`` list form is deprecated in SQLAlchemy 1.4 and
        # rejected by 2.0.
        query = select(training_data).where(training_data.c.id == data_id)
        record = await database.fetch_one(query)
        if record is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Document with ID {data_id} not found"
            )
        return {
            "id": record["id"],
            "content": record["content"],
            "content_type": record["content_type"],
            "metadata": record["metadata"],
            "added_at": record["added_at"].isoformat()
        }
    except HTTPException:
        # Propagate the 404 unchanged instead of re-wrapping it as a 500 below.
        raise
    except Exception as e:
        logger.exception(f"Error retrieving training data: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retrieve training data: {str(e)}"
        ) from e
@app.delete("/training-data/{data_id}")
async def delete_training_data(data_id: int):
    """Delete a training record from the database and from the vector store.

    Raises:
        HTTPException: 404 if no row has ``id == data_id``; 500 on any other
        failure.
    """
    try:
        # Check for the row explicitly rather than inspecting the DELETE's
        # return value. NOTE(review): assuming ``database`` is a
        # ``databases.Database``, execute() returns lastrowid (SQLite) or None
        # (asyncpg) for a DELETE, not an affected-row count, so testing that
        # result could not tell "deleted" from "missing".
        exists_query = select(training_data.c.id).where(training_data.c.id == data_id)
        if await database.fetch_one(exists_query) is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Document with ID {data_id} not found or could not be deleted"
            )

        await database.execute(training_data.delete().where(training_data.c.id == data_id))

        # Also remove from vector store (keyed by the DB id here; confirm that
        # vector_store indexes documents by this id).
        await vector_store.delete_document(data_id)

        return {
            "status": "success",
            "message": f"Document with ID {data_id} successfully deleted"
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error deleting training data: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete training data: {str(e)}"
        ) from e
@app.post("/improve-content")
async def improve_content(request: ContentImprovementRequest):
    """Revise previously generated copy according to user feedback.

    The response echoes the original content and the feedback alongside
    whatever ``copywriter.improve_copy`` returned. Failures are logged and
    reported as a 500.
    """
    original, feedback = request.content, request.feedback
    try:
        revised = await copywriter.improve_copy(content=original, feedback=feedback)
    except Exception as exc:
        logger.error(f"Error improving content: {str(exc)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to improve content: {str(exc)}",
        )
    return {
        "status": "success",
        "original_content": original,
        "improved_content": revised,
        "feedback": feedback,
    }
@app.post("/analyze-content")
async def analyze_content(content: str = Body(..., embed=True)):
    """Run ``copywriter.analyze_content_performance`` on a piece of copy.

    Because of ``embed=True``, the request body is a JSON object with a single
    ``content`` key. Failures are logged and reported as a 500.
    """
    try:
        report = await copywriter.analyze_content_performance(content)
    except Exception as exc:
        logger.error(f"Error analyzing content: {str(exc)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to analyze content: {str(exc)}",
        )
    return {"status": "success", "analysis": report}
@app.get("/user-queries")
async def list_user_queries(
    page: int = Query(1, ge=1, description="Page number"),
    limit: int = Query(10, ge=1, le=100, description="Items per page")
):
    """Return one page of saved user queries, newest file name first.

    Each item is the parsed JSON content of one file in
    ``<DATA_DIR>/user_queries``. ``total`` counts all such files. Failures
    (including an unreadable file on the page) are logged and reported as a 500.
    """
    try:
        pattern = str(Path(config.DATA_DIR) / "user_queries" / "*.json")
        # File names are timestamps, so reverse lexical order puts the newest first.
        all_files = sorted(glob.glob(pattern), reverse=True)

        offset = (page - 1) * limit
        items = []
        for path in all_files[offset:offset + limit]:
            with open(path, 'r') as fh:
                items.append(json.load(fh))

        return {
            "items": items,
            "total": len(all_files),
            "page": page,
            "limit": limit,
        }
    except Exception as exc:
        logger.error(f"Error listing user queries: {str(exc)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to list user queries: {str(exc)}",
        )
def _user_query_not_found(timestamp: str) -> HTTPException:
    """Build the 404 returned when no saved query matches ``timestamp``."""
    return HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Query with timestamp {timestamp} not found"
    )


def _user_query_path(timestamp: str) -> Path:
    """Map a URL ``timestamp`` to its JSON file in ``<DATA_DIR>/user_queries``.

    ``timestamp`` comes straight from the request path, so anything that is not
    a bare file name is rejected as not found. That covers anything containing
    a path separator, e.g. ``..\\config`` on Windows. Such input is never joined
    onto the directory.

    Raises:
        HTTPException: 404 if ``timestamp`` is not a plain file name.
    """
    if Path(timestamp).name != timestamp:
        raise _user_query_not_found(timestamp)
    return Path(config.DATA_DIR) / "user_queries" / f"{timestamp}.json"


@app.get("/user-queries/{timestamp}")
async def get_user_query(timestamp: str):
    """Return the saved user query stored under ``timestamp``.

    Raises:
        HTTPException: 404 if there is no such query; 500 on any other failure
        (e.g. an unreadable file or invalid JSON).
    """
    try:
        # The saved files are ASCII JSON (json.dump's default ensure_ascii), so
        # reading them as UTF-8 is always safe.
        with open(_user_query_path(timestamp), "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        raise _user_query_not_found(timestamp) from None
    except HTTPException:
        # Let the 404 through instead of re-wrapping it as a 500 below.
        raise
    except Exception as e:
        logger.exception(f"Error getting user query: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get user query: {str(e)}"
        ) from e


@app.delete("/user-queries/{timestamp}")
async def delete_user_query(timestamp: str):
    """Delete the saved user query stored under ``timestamp``.

    Raises:
        HTTPException: 404 if there is no such query; 500 on any other failure.
    """
    try:
        # EAFP: unlinking directly avoids the exists()/unlink() race, where a
        # concurrent delete would otherwise surface as a 500.
        _user_query_path(timestamp).unlink()
    except FileNotFoundError:
        raise _user_query_not_found(timestamp) from None
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error deleting user query: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete user query: {str(e)}"
        ) from e
    return {
        "status": "success",
        "message": f"Query with timestamp {timestamp} successfully deleted"
    }
# Start a development server when this file is executed directly.
if __name__ == "__main__":
    import uvicorn

    # The app is given as an import string ("module:attribute") because
    # uvicorn requires that form when reload=True.
    uvicorn.run(
        app="main:app",
        reload=True,
        host=config.API_HOST,
        port=config.API_PORT,
    )