update endpoints
This commit is contained in:
+220
-14
@@ -1,24 +1,230 @@
|
||||
# backend/copywriter.py
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
import openai
|
||||
from .vector_store import VectorStore
|
||||
from .brand_style import BrandStyle
|
||||
import openai
|
||||
from .config import Config
|
||||
|
||||
class Copywriter:
|
||||
def __init__(self):
|
||||
self.vector_store = VectorStore()
|
||||
self.brand_style = BrandStyle()
|
||||
self.user_queries_path = Path("data/user_queries")
|
||||
self.user_queries_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Initialize OpenAI
|
||||
openai.api_key = Config.OPENAI_API_KEY
|
||||
|
||||
def generate_copy(self, request):
|
||||
# Move the generation logic from main.py here
|
||||
similar = self.vector_store.search(request.prompt, request.content_type)
|
||||
def generate_copy(self, request) -> Dict:
|
||||
"""Generate marketing copy and log user interaction"""
|
||||
try:
|
||||
# Log the user query first
|
||||
query_log = self._log_user_query(request)
|
||||
|
||||
# Get similar content from vector store
|
||||
similar = self.vector_store.search(request.prompt, request.content_type)
|
||||
|
||||
# Format similar content for context
|
||||
similar_content = ""
|
||||
if similar:
|
||||
similar_content = "\n\nSimilar past campaigns for reference:\n"
|
||||
for i, campaign in enumerate(similar[:3], 1):
|
||||
similar_content += f"{i}. {campaign.get('content', '')}\n"
|
||||
|
||||
# Generate with OpenAI
|
||||
system_prompt = self.brand_style.get_prompt(request)
|
||||
user_prompt = f"Create marketing copy for: {request.prompt}{similar_content}"
|
||||
|
||||
response = openai.ChatCompletion.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
temperature=0.7,
|
||||
max_tokens=500
|
||||
)
|
||||
|
||||
generated_copy = response.choices[0].message.content
|
||||
|
||||
# Update the query log with the result
|
||||
self._update_query_log(query_log["query_id"], generated_copy, True)
|
||||
|
||||
# Store the generated copy for future reference
|
||||
new_campaign = {
|
||||
"content": generated_copy,
|
||||
"content_type": request.content_type,
|
||||
"metadata": {
|
||||
"prompt": request.prompt,
|
||||
"tone": request.tone,
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"query_id": query_log["query_id"]
|
||||
}
|
||||
}
|
||||
|
||||
# Add to vector store for future similarity searches
|
||||
self.vector_store.add_campaign(new_campaign)
|
||||
|
||||
return {
|
||||
"result": generated_copy,
|
||||
"query_id": query_log["query_id"],
|
||||
"similar_campaigns_used": len(similar)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# Log the error in user queries
|
||||
if 'query_log' in locals():
|
||||
self._update_query_log(query_log["query_id"], str(e), False)
|
||||
raise e
|
||||
|
||||
def _log_user_query(self, request) -> Dict:
|
||||
"""Log user query for AI training purposes"""
|
||||
query_id = f"query_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
|
||||
|
||||
response = openai.ChatCompletion.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[
|
||||
{"role": "system", "content": self.brand_style.get_prompt(request)},
|
||||
{"role": "user", "content": f"Prompt: {request.prompt}\n\nSimilar examples:\n{similar}"}
|
||||
],
|
||||
temperature=0.7
|
||||
)
|
||||
# Handle both dict and Pydantic model objects
|
||||
if hasattr(request, 'dict'):
|
||||
request_data = request.dict()
|
||||
else:
|
||||
request_data = request
|
||||
|
||||
return response.choices[0].message.content
|
||||
query_log = {
|
||||
"query_id": query_id,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"user_input": {
|
||||
"prompt": request_data.get("prompt"),
|
||||
"content_type": request_data.get("content_type"),
|
||||
"tone": request_data.get("tone")
|
||||
},
|
||||
"status": "processing",
|
||||
"generated_output": None,
|
||||
"success": None,
|
||||
"processing_time": None,
|
||||
"similar_campaigns_count": 0
|
||||
}
|
||||
|
||||
# Save to user_queries folder
|
||||
query_file = self.user_queries_path / f"{query_id}.json"
|
||||
with open(query_file, 'w') as f:
|
||||
json.dump(query_log, f, indent=2)
|
||||
|
||||
return query_log
|
||||
|
||||
def _update_query_log(self, query_id: str, output: str, success: bool):
|
||||
"""Update the query log with results"""
|
||||
query_file = self.user_queries_path / f"{query_id}.json"
|
||||
|
||||
if query_file.exists():
|
||||
with open(query_file, 'r') as f:
|
||||
query_log = json.load(f)
|
||||
|
||||
query_log.update({
|
||||
"generated_output": output,
|
||||
"success": success,
|
||||
"status": "completed" if success else "failed",
|
||||
"completed_at": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
with open(query_file, 'w') as f:
|
||||
json.dump(query_log, f, indent=2)
|
||||
|
||||
def get_query_history(self, limit: int = 10) -> List[Dict]:
|
||||
"""Get recent user queries for analysis"""
|
||||
query_files = list(self.user_queries_path.glob("*.json"))
|
||||
query_files.sort(key=lambda x: x.stat().st_mtime, reverse=True)
|
||||
|
||||
queries = []
|
||||
for file in query_files[:limit]:
|
||||
with open(file, 'r') as f:
|
||||
queries.append(json.load(f))
|
||||
|
||||
return queries
|
||||
|
||||
def get_query_analytics(self) -> Dict:
|
||||
"""Get analytics on user queries for AI improvement"""
|
||||
query_files = list(self.user_queries_path.glob("*.json"))
|
||||
|
||||
total_queries = len(query_files)
|
||||
successful_queries = 0
|
||||
failed_queries = 0
|
||||
content_type_counts = {}
|
||||
tone_counts = {}
|
||||
|
||||
for file in query_files:
|
||||
with open(file, 'r') as f:
|
||||
query = json.load(f)
|
||||
|
||||
if query.get("success"):
|
||||
successful_queries += 1
|
||||
elif query.get("success") is False:
|
||||
failed_queries += 1
|
||||
|
||||
content_type = query.get("user_input", {}).get("content_type", "unknown")
|
||||
content_type_counts[content_type] = content_type_counts.get(content_type, 0) + 1
|
||||
|
||||
tone = query.get("user_input", {}).get("tone", "default")
|
||||
tone_counts[tone] = tone_counts.get(tone, 0) + 1
|
||||
|
||||
return {
|
||||
"total_queries": total_queries,
|
||||
"successful_queries": successful_queries,
|
||||
"failed_queries": failed_queries,
|
||||
"success_rate": round(successful_queries / total_queries * 100, 2) if total_queries > 0 else 0,
|
||||
"content_type_distribution": content_type_counts,
|
||||
"tone_distribution": tone_counts
|
||||
}
|
||||
|
||||
def log_user_feedback(self, query_id: str, feedback: Dict):
|
||||
"""Log user feedback on generated copy for training"""
|
||||
query_file = self.user_queries_path / f"{query_id}.json"
|
||||
|
||||
if query_file.exists():
|
||||
with open(query_file, 'r') as f:
|
||||
query_log = json.load(f)
|
||||
|
||||
query_log["user_feedback"] = {
|
||||
"rating": feedback.get("rating"), # 1-5 scale
|
||||
"comments": feedback.get("comments"),
|
||||
"used_output": feedback.get("used_output", False),
|
||||
"modifications_made": feedback.get("modifications_made"),
|
||||
"feedback_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
with open(query_file, 'w') as f:
|
||||
json.dump(query_log, f, indent=2)
|
||||
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def export_training_data(self, output_file: str = None) -> str:
|
||||
"""Export user queries and feedback for model training"""
|
||||
if not output_file:
|
||||
output_file = f"training_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
||||
|
||||
query_files = list(self.user_queries_path.glob("*.json"))
|
||||
training_data = []
|
||||
|
||||
for file in query_files:
|
||||
with open(file, 'r') as f:
|
||||
query = json.load(f)
|
||||
|
||||
# Only include successful queries with feedback for training
|
||||
if query.get("success") and query.get("user_feedback"):
|
||||
training_example = {
|
||||
"input": query["user_input"]["prompt"],
|
||||
"content_type": query["user_input"]["content_type"],
|
||||
"tone": query["user_input"]["tone"],
|
||||
"output": query["generated_output"],
|
||||
"rating": query["user_feedback"]["rating"],
|
||||
"used": query["user_feedback"]["used_output"]
|
||||
}
|
||||
training_data.append(training_example)
|
||||
|
||||
output_path = Path("data") / output_file
|
||||
with open(output_path, 'w') as f:
|
||||
json.dump(training_data, f, indent=2)
|
||||
|
||||
return str(output_path)
|
||||
+62
-3
@@ -4,6 +4,7 @@ from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from .vector_store import VectorStore
|
||||
from .brand_style import BrandStyle
|
||||
from .config import Config
|
||||
@@ -18,6 +19,10 @@ app = FastAPI(title="Marketing Assistant AI", version="0.1.0")
|
||||
vector_store = VectorStore()
|
||||
brand_style = BrandStyle()
|
||||
|
||||
# Create user_queries directory
|
||||
user_queries_path = Path("data/user_queries")
|
||||
user_queries_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@@ -26,22 +31,63 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# Models
|
||||
class CampaignRequest(BaseModel):
|
||||
prompt: str
|
||||
content_type: str = "general"
|
||||
tone: Optional[str] = None
|
||||
|
||||
|
||||
class Campaign(BaseModel):
|
||||
content: str
|
||||
content_type: str
|
||||
metadata: dict = {}
|
||||
|
||||
|
||||
class UserFeedback(BaseModel):
|
||||
query_id: str
|
||||
rating: int # 1-5 scale
|
||||
comments: Optional[str] = None
|
||||
used_output: bool = False
|
||||
modifications_made: Optional[str] = None
|
||||
|
||||
|
||||
# Helper function to log user queries
|
||||
def log_user_query(request: CampaignRequest, generated_copy: str = None, success: bool = None):
|
||||
"""Log user query to data/user_queries/"""
|
||||
query_id = f"query_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
|
||||
|
||||
query_log = {
|
||||
"query_id": query_id,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"user_input": {
|
||||
"prompt": request.prompt,
|
||||
"content_type": request.content_type,
|
||||
"tone": request.tone
|
||||
},
|
||||
"status": "completed" if success else "failed" if success is False else "processing",
|
||||
"generated_output": generated_copy,
|
||||
"success": success
|
||||
}
|
||||
|
||||
# Save to user_queries folder
|
||||
query_file = user_queries_path / f"{query_id}.json"
|
||||
with open(query_file, 'w') as f:
|
||||
json.dump(query_log, f, indent=2)
|
||||
|
||||
return query_id
|
||||
|
||||
|
||||
# Routes
|
||||
@app.post("/generate")
|
||||
async def generate_copy(request: CampaignRequest):
|
||||
"""Generate marketing copy based on prompt and brand guidelines"""
|
||||
query_id = None
|
||||
try:
|
||||
# Log the initial query
|
||||
query_id = log_user_query(request)
|
||||
|
||||
# Get similar content from vector store
|
||||
similar = vector_store.search(request.prompt, request.content_type)
|
||||
|
||||
@@ -68,6 +114,9 @@ async def generate_copy(request: CampaignRequest):
|
||||
|
||||
generated_copy = response.choices[0].message.content
|
||||
|
||||
# Update query log with success
|
||||
log_user_query(request, generated_copy, True)
|
||||
|
||||
# Store the generated copy for future reference
|
||||
new_campaign = {
|
||||
"content": generated_copy,
|
||||
@@ -75,16 +124,23 @@ async def generate_copy(request: CampaignRequest):
|
||||
"metadata": {
|
||||
"prompt": request.prompt,
|
||||
"tone": request.tone,
|
||||
"generated_at": datetime.now().isoformat()
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"query_id": query_id
|
||||
}
|
||||
}
|
||||
|
||||
# Add to vector store for future similarity searches
|
||||
vector_store.add_campaign(new_campaign)
|
||||
|
||||
return {"result": generated_copy}
|
||||
return {
|
||||
"result": generated_copy,
|
||||
"query_id": query_id
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# Log the error
|
||||
if query_id:
|
||||
log_user_query(request, str(e), False)
|
||||
print(f"Error in generate_copy: {str(e)}") # For debugging
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -112,11 +168,13 @@ async def search_campaigns(query: str, limit: int = 5):
|
||||
print(f"Error in search_campaigns: {str(e)}") # For debugging
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Health check endpoint"""
|
||||
return {"message": "Marketing Assistant AI is running", "version": "0.1.0"}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Detailed health check"""
|
||||
@@ -126,6 +184,7 @@ async def health_check():
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
||||
+112
-13
@@ -9,29 +9,128 @@ from pathlib import Path
|
||||
class VectorStore:
|
||||
def __init__(self):
|
||||
self.model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
self.index = faiss.IndexFlatL2(384)
|
||||
self.index = faiss.IndexFlatL2(384) # all-MiniLM-L6-v2 produces 384-dim embeddings
|
||||
self.campaigns = []
|
||||
self._load_data()
|
||||
|
||||
def _load_data(self):
|
||||
"""Load existing campaigns and rebuild the index"""
|
||||
if Path(Config.DATA_PATH).exists():
|
||||
with open(Config.DATA_PATH) as f:
|
||||
self.campaigns = json.load(f)
|
||||
try:
|
||||
with open(Config.DATA_PATH) as f:
|
||||
self.campaigns = json.load(f)
|
||||
|
||||
if self.campaigns:
|
||||
embeddings = self.model.encode([c["content"] for c in self.campaigns])
|
||||
self.index.add(embeddings)
|
||||
# Extract content and generate embeddings
|
||||
contents = [c.get("content", "") for c in self.campaigns]
|
||||
# Filter out empty content
|
||||
contents = [c for c in contents if c.strip()]
|
||||
|
||||
if contents:
|
||||
embeddings = self.model.encode(contents)
|
||||
# Ensure embeddings are float32 for FAISS
|
||||
embeddings = embeddings.astype('float32')
|
||||
self.index.add(embeddings)
|
||||
print(f"Loaded {len(contents)} campaigns into vector store")
|
||||
else:
|
||||
print("No valid content found in campaigns")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading data: {e}")
|
||||
self.campaigns = []
|
||||
|
||||
def add_campaign(self, campaign: dict):
|
||||
"""Add a single campaign to the vector store"""
|
||||
if not campaign.get("content", "").strip():
|
||||
print("Warning: Empty content provided, skipping campaign")
|
||||
return
|
||||
|
||||
self.campaigns.append(campaign)
|
||||
embedding = self.model.encode([campaign["content"]])
|
||||
self.index.add(embedding)
|
||||
self._save_data()
|
||||
|
||||
try:
|
||||
# Generate embedding for the new campaign
|
||||
embedding = self.model.encode([campaign["content"]])
|
||||
embedding = embedding.astype('float32')
|
||||
self.index.add(embedding)
|
||||
self._save_data()
|
||||
print(f"Added campaign to vector store. Total campaigns: {len(self.campaigns)}")
|
||||
except Exception as e:
|
||||
print(f"Error adding campaign: {e}")
|
||||
# Remove the campaign if embedding failed
|
||||
self.campaigns.pop()
|
||||
|
||||
def search(self, query: str, content_type: str = None, k: int = 3):
|
||||
query_embedding = self.model.encode([query])
|
||||
distances, indices = self.index.search(query_embedding, k)
|
||||
return [self.campaigns[i] for i in indices[0]]
|
||||
"""Search for similar campaigns"""
|
||||
if not query.strip():
|
||||
print("Empty query provided")
|
||||
return []
|
||||
|
||||
if len(self.campaigns) == 0:
|
||||
print("No campaigns in vector store")
|
||||
return []
|
||||
|
||||
try:
|
||||
# Generate query embedding
|
||||
query_embedding = self.model.encode([query])
|
||||
query_embedding = query_embedding.astype('float32')
|
||||
|
||||
# Adjust k to not exceed available campaigns
|
||||
k = min(k, len(self.campaigns))
|
||||
|
||||
# Search the index
|
||||
distances, indices = self.index.search(query_embedding, k)
|
||||
|
||||
# Filter out invalid indices (-1 means no match found)
|
||||
valid_results = []
|
||||
for i, (distance, idx) in enumerate(zip(distances[0], indices[0])):
|
||||
if idx >= 0 and idx < len(self.campaigns):
|
||||
campaign = self.campaigns[idx].copy()
|
||||
campaign['_search_distance'] = float(distance) # Add distance for debugging
|
||||
|
||||
# Filter by content_type if specified
|
||||
if content_type is None or campaign.get('content_type') == content_type:
|
||||
valid_results.append(campaign)
|
||||
|
||||
print(f"Search for '{query}' returned {len(valid_results)} results")
|
||||
for i, result in enumerate(valid_results):
|
||||
print(f" {i+1}. Distance: {result['_search_distance']:.4f}, Content: {result['content'][:100]}...")
|
||||
|
||||
return valid_results
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during search: {e}")
|
||||
return []
|
||||
|
||||
def _save_data(self):
|
||||
with open(Config.DATA_PATH, 'w') as f:
|
||||
json.dump(self.campaigns, f)
|
||||
"""Save campaigns to disk"""
|
||||
try:
|
||||
Path(Config.DATA_PATH).parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(Config.DATA_PATH, 'w') as f:
|
||||
json.dump(self.campaigns, f, indent=2)
|
||||
except Exception as e:
|
||||
print(f"Error saving data: {e}")
|
||||
|
||||
def get_stats(self):
|
||||
"""Get vector store statistics"""
|
||||
return {
|
||||
"total_campaigns": len(self.campaigns),
|
||||
"index_size": self.index.ntotal,
|
||||
"embedding_dimension": self.index.d
|
||||
}
|
||||
|
||||
def rebuild_index(self):
|
||||
"""Rebuild the entire index from scratch"""
|
||||
print("Rebuilding vector store index...")
|
||||
self.index = faiss.IndexFlatL2(384)
|
||||
|
||||
if self.campaigns:
|
||||
contents = [c.get("content", "") for c in self.campaigns if c.get("content", "").strip()]
|
||||
if contents:
|
||||
embeddings = self.model.encode(contents)
|
||||
embeddings = embeddings.astype('float32')
|
||||
self.index.add(embeddings)
|
||||
print(f"Rebuilt index with {len(contents)} campaigns")
|
||||
else:
|
||||
print("No valid content found to rebuild index")
|
||||
else:
|
||||
print("No campaigns to rebuild index")
|
||||
Reference in New Issue
Block a user