296 lines
10 KiB
Python
296 lines
10 KiB
Python
|
|
import base64
import io
import json
import logging
import os
import tempfile
from typing import Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from viral_velocity_scorer import ViralVelocityScorer
from image_enhancer import ImageEnhancer
|
||
|
|
|
||
|
|
# Set up logging: INFO and above goes both to 'api.log' (opened relative to
# the current working directory) and to the console stream handler.
_log_handlers = [
    logging.FileHandler('api.log'),
    logging.StreamHandler(),
]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
)

# Module-level logger shared by every endpoint in this file.
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
# Request models
class UserPreferences(BaseModel):
    """Optional personalization hints; every field defaults to an empty string."""

    aesthetic: str = ""  # e.g. one of /available-preferences "aesthetics"
    niche: str = ""  # e.g. one of "niches"
    target_audience: str = ""  # e.g. one of "target_audiences"
    content_type: str = ""  # e.g. one of "content_types"
    brand_voice: str = ""  # e.g. one of "brand_voices"


class ScoreImageRequest(BaseModel):
    """Body shared by /score-image and /enhance-image."""

    image: str  # base64 encoded image
    user_preferences: Optional[UserPreferences] = None  # omitted -> no personalization
|
||
|
|
|
||
|
|
# The FastAPI application; title/description surface in the generated API docs.
app = FastAPI(
    title="Viral Velocity API",
    description="AI-powered social media image scoring with personalization",
)

# Add CORS middleware so browser frontends on other origins can call the API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# very permissive (in production, specify actual origins); confirm how the
# installed Starlette version treats this combination before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static files (frontend) under /static, but only when the directory
# exists relative to the current working directory.
_FRONTEND_DIR = "frontend"
if os.path.exists(_FRONTEND_DIR):
    app.mount("/static", StaticFiles(directory=_FRONTEND_DIR), name="static")
|
||
|
|
|
||
|
|
# Initialize the scorer and enhancer
def _init_service(name, factory):
    """Construct one backend service, returning None if construction fails.

    Each service is initialised on its own so that a failure in one (e.g. the
    image enhancer) does not also disable the other; previously a single
    try block set both to None when either constructor raised.

    Args:
        name: Human-readable service name used in log messages.
        factory: Zero-argument callable that builds the service.

    Returns:
        The constructed service, or None on any exception (logged with traceback).
    """
    try:
        logger.info(f"Initializing {name} for API...")
        service = factory()
    except Exception as e:
        logger.exception(f"Failed to initialize {name}: {e}")
        print(f"Failed to initialize {name}: {e}")
        return None
    logger.info(f"{name} initialized successfully")
    return service


# Endpoints check these for None and answer 500 when a service is unavailable.
scorer = _init_service("ViralVelocityScorer", ViralVelocityScorer)
enhancer = _init_service("ImageEnhancer", ImageEnhancer)
|
||
|
|
|
||
|
|
@app.get("/")
def read_root():
    """Serve frontend/index.html when present, otherwise a JSON banner message."""
    logger.info("Root endpoint accessed")
    index_page = "frontend/index.html"
    if not os.path.exists(index_page):
        # No bundled frontend: fall back to a plain API description.
        return {"message": "Viral Velocity API - Social Media Image Scorer with Personalization"}
    return FileResponse(index_page)
|
||
|
|
|
||
|
|
@app.post("/score-image")
async def score_image(request: ScoreImageRequest):
    """
    Score an image using the 4-pillar Viral Velocity scoring system with personalization

    Request body:
    {
        "image": "base64 string",
        "user_preferences": {
            "aesthetic": "",
            "niche": "",
            "target_audience": "",
            "content_type": "",
            "brand_voice": ""
        }
    }

    Returns the dict produced by ``scorer.analyze_image_efficient``. A result
    with ``status == 'rejected'`` (content moderation) is returned as-is.

    Raises:
        HTTPException(400): ``image`` is not valid base64.
        HTTPException(500): the scorer is not initialised, the analysis result
            contains an ``error`` key, or analysis raised unexpectedly.
    """
    logger.info("Received image scoring request")

    if not scorer:
        logger.error("Scorer not initialized, returning 500 error")
        raise HTTPException(status_code=500, detail="Scorer not initialized")

    temp_path = None
    try:
        # Decode base64 image; malformed input is a client error, not a 500.
        logger.info("Decoding base64 image...")
        try:
            image_data = base64.b64decode(request.image)
        except ValueError as e:  # binascii.Error is a ValueError subclass
            logger.warning(f"Invalid base64 image payload: {e}")
            raise HTTPException(status_code=400, detail=f"Invalid base64 image: {e}") from e

        # Save to a unique temporary file: a fixed filename would let
        # concurrent requests (or workers sharing a cwd) clobber each other.
        fd, temp_path = tempfile.mkstemp(suffix=".jpg")
        with os.fdopen(fd, "wb") as f:
            f.write(image_data)
        logger.info("Image decoded and saved temporarily")

        # Analyze the image
        logger.info("Starting efficient image analysis...")
        user_prefs_dict = request.user_preferences.dict() if request.user_preferences else None
        result = scorer.analyze_image_efficient(temp_path, user_prefs_dict)

        # Handle content moderation rejections specifically
        if result.get('status') == 'rejected':
            logger.warning(f"Content rejected: {result.get('rejection_reason')}")
            return result

        # Handle other errors
        if 'error' in result:
            logger.error(f"Analysis returned error: {result['error']}")
            raise HTTPException(status_code=500, detail=result['error'])

        logger.info(f"Analysis completed successfully. Final score: {result.get('final_score', 'N/A')}")
        return result

    except HTTPException:
        # Already carries the intended status and detail; letting the generic
        # handler below catch it would re-wrap it as "Analysis failed: ...".
        raise
    except Exception as e:
        logger.error(f"Error processing image: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
    finally:
        # Clean up the temporary file even when decoding/analysis failed.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError as e:
                logger.warning(f"Could not remove temporary file {temp_path}: {e}")
|
||
|
|
|
||
|
|
@app.get("/health")
def health_check():
    """Health check endpoint.

    Always reports status "healthy" while the process is serving; the
    *_initialized flags say whether each backend service is usable.
    ``enhancer_initialized`` is an additive key: existing clients that read
    only ``status``/``scorer_initialized`` are unaffected.
    """
    logger.info("Health check endpoint accessed")
    return {
        "status": "healthy",
        "scorer_initialized": scorer is not None,
        "enhancer_initialized": enhancer is not None,
    }
|
||
|
|
|
||
|
|
@app.get("/scoring-weights")
def get_scoring_weights():
    """Expose the scorer's ``weights`` attribute with a short description.

    Raises HTTPException(500) when the scorer failed to initialise at startup.
    """
    logger.info("Scoring weights endpoint accessed")

    if scorer:
        payload = {
            "weights": scorer.weights,
            "description": "4-pillar scoring system weights",
        }
        return payload

    logger.error("Scorer not initialized, returning 500 error")
    raise HTTPException(status_code=500, detail="Scorer not initialized")
|
||
|
|
|
||
|
|
@app.get("/available-preferences")
def get_available_preferences():
    """List the selectable values for each user-preference field.

    Keys are the pluralised UserPreferences field names. Fresh lists are
    built on every call, so callers may mutate the response safely.
    """
    logger.info("Available preferences endpoint accessed")

    aesthetics = [
        "Y2K", "Maximalist", "Minimalist", "Ethereal Grunge",
        "Cottagecore", "Dark Academia", "Cyberpunk", "Vintage",
        "Modern", "Boho", "Streetwear", "High Fashion",
        "Luxury", "Casual", "Professional", "Artistic",
    ]
    niches = [
        "Fashion Influencer", "Food Blogger", "Travel Photographer",
        "Fitness Influencer", "Tech Professional", "Artist",
        "Business Professional", "Lifestyle Blogger", "Beauty Influencer",
        "Parenting Blogger", "Pet Influencer", "Gaming Content Creator",
    ]
    audiences = [
        "Gen Z", "Millennials", "Gen X", "Boomers",
        "Teenagers", "Young Adults", "Professionals", "Parents",
        "Students", "Entrepreneurs", "Creative Professionals",
    ]
    content_kinds = [
        "Instagram Post", "Instagram Story", "TikTok Video",
        "YouTube Thumbnail", "LinkedIn Post", "Twitter Post",
        "Facebook Post", "Pinterest Pin", "Blog Post",
    ]
    voices = [
        "Playful and Trendy", "Professional and Trustworthy",
        "Casual and Relatable", "Luxury and Sophisticated",
        "Fun and Energetic", "Calm and Minimalist",
        "Bold and Confident", "Warm and Friendly",
    ]

    return {
        "aesthetics": aesthetics,
        "niches": niches,
        "target_audiences": audiences,
        "content_types": content_kinds,
        "brand_voices": voices,
    }
|
||
|
|
|
||
|
|
@app.get("/moderation-status")
def get_moderation_status():
    """Describe the moderation backend and report its current status.

    The ``status`` value is whatever
    ``scorer.content_moderator.get_moderation_status()`` returns.
    Raises HTTPException(500) when the scorer failed to initialise.
    """
    logger.info("Moderation status endpoint accessed")

    if scorer:
        moderator = scorer.content_moderator
        current_status = moderator.get_moderation_status()
        return {
            "moderation_system": "Google Cloud Vision SafeSearch",
            "status": current_status,
            "description": "Content safety and moderation system for detecting inappropriate content",
        }

    logger.error("Scorer not initialized, returning 500 error")
    raise HTTPException(status_code=500, detail="Scorer not initialized")
|
||
|
|
|
||
|
|
@app.post("/enhance-image")
async def enhance_image(request: ScoreImageRequest):
    """
    Generate enhanced versions of an image using AI (originally documented as
    5; the actual number is decided by ImageEnhancer)

    Request body:
    {
        "image": "base64 string",
        "user_preferences": {
            "aesthetic": "",
            "niche": "",
            "target_audience": "",
            "content_type": "",
            "brand_voice": ""
        }
    }

    Response keys: status, message, enhanced_images (each with version,
    base64 image, prompt, image_path), total_generated (number of images
    successfully encoded), and original_analysis (taken from the enhancer's
    ``original_image.analysis``, or None if absent).

    Raises:
        HTTPException(400): ``image`` is not valid base64.
        HTTPException(500): the enhancer is not initialised, the enhancer
            reported ``status == 'error'``, or enhancement raised unexpectedly.
    """
    logger.info("Received image enhancement request")

    if not enhancer:
        logger.error("Enhancer not initialized, returning 500 error")
        raise HTTPException(status_code=500, detail="Enhancer not initialized")

    temp_path = None
    try:
        # Decode base64 image; malformed input is a client error, not a 500.
        logger.info("Decoding base64 image for enhancement...")
        try:
            image_data = base64.b64decode(request.image)
        except ValueError as e:  # binascii.Error is a ValueError subclass
            logger.warning(f"Invalid base64 image payload: {e}")
            raise HTTPException(status_code=400, detail=f"Invalid base64 image: {e}") from e

        # Save to a unique temporary file so concurrent requests cannot
        # overwrite each other's input.
        fd, temp_path = tempfile.mkstemp(suffix=".jpg")
        with os.fdopen(fd, "wb") as f:
            f.write(image_data)
        logger.info("Image decoded and saved temporarily for enhancement")

        # Prepare user preferences
        user_prefs_dict = request.user_preferences.dict() if request.user_preferences else None

        # Generate enhanced images
        logger.info("Starting AI image enhancement...")
        result = enhancer.enhance_image(temp_path, user_prefs_dict)

        if result.get('status') == 'error':
            error_detail = result.get('error', 'Unknown enhancement error')
            logger.error(f"Enhancement failed: {error_detail}")
            raise HTTPException(status_code=500, detail=error_detail)

        logger.info(f"Enhancement completed successfully. Generated {result.get('total_generated')} images")

        # Convert enhanced images to base64 for frontend. Best effort: an
        # image that cannot be read/encoded is logged and skipped.
        enhanced_images_data = []
        for enhanced_img in result.get('enhanced_images', []):
            try:
                with open(enhanced_img['image_path'], 'rb') as f:
                    img_base64 = base64.b64encode(f.read()).decode()
                enhanced_images_data.append({
                    'version': enhanced_img['version'],
                    'image': img_base64,
                    'prompt': enhanced_img['prompt'],
                    'image_path': enhanced_img['image_path'],
                })
            except Exception as e:
                # .get so a missing 'version' can't raise inside the handler.
                logger.error(f"Failed to convert enhanced image {enhanced_img.get('version')} to base64: {e}")

        original_image = result.get('original_image') or {}
        return {
            'status': 'success',
            'message': f'Successfully generated {len(enhanced_images_data)} enhanced images',
            'enhanced_images': enhanced_images_data,
            'total_generated': len(enhanced_images_data),
            'original_analysis': original_image.get('analysis'),
        }

    except HTTPException:
        # Keep the intended status/detail instead of re-wrapping as a generic 500.
        raise
    except Exception as e:
        logger.error(f"Error processing enhancement request: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Enhancement failed: {str(e)}")
    finally:
        # Clean up the temporary input file even when enhancement failed.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError as e:
                logger.warning(f"Could not remove temporary file {temp_path}: {e}")
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Running this file directly starts uvicorn on all interfaces, port 5300.
    server_host, server_port = "0.0.0.0", 5300
    logger.info("Starting Viral Velocity API server...")
    uvicorn.run(app, host=server_host, port=server_port)
|