Initial commit
This commit is contained in:
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
Brand style module for the Marketing Assistant AI.
|
||||
Ensures generated content aligns with Adriana James' brand voice and tone.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Any, Optional
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
import config
|
||||
|
||||
class BrandStyleManager:
    """Manages brand style guidelines and ensures content consistency.

    The guidelines are persisted as JSON under
    ``config.DATA_DIR / "style_guidelines" / "brand_style.json"`` and cached
    on the instance as ``style_guidelines``. ``content_formats`` maps each
    supported content type to the formatting instructions that are appended
    to LLM prompts.
    """

    def __init__(self):
        """Initialize the BrandStyleManager with default or stored style guidelines."""
        self.style_path = Path(config.DATA_DIR) / "style_guidelines" / "brand_style.json"
        self.style_guidelines = self._load_or_create_style()
        self.content_formats = {
            "website_copy": """
                Generate engaging website copy for a brand or business.
                - Start with a strong headline and supporting subheadline
                - Write in a clear, benefit-driven tone
                - Use SEO-friendly keywords naturally
                - Structure content with short paragraphs and bullet points
                - Include a clear call-to-action at the end
            """,
            "email": """
                Create a marketing or sales email for a target audience.
                - Start with a compelling subject line
                - Use a warm, conversational tone
                - Keep the message focused and value-driven
                - Personalize where possible (name, context)
                - End with a clear and persuasive CTA
            """,
            "social_media": """
                Write social media content tailored to a specific platform.
                - Hook the reader within the first sentence
                - Keep the message concise and engaging
                - Use platform-appropriate tone and emojis (if applicable)
                - Add relevant hashtags and tag accounts when needed
                - Include a prompt or CTA to drive interaction
            """,
            "blog_post": """
                Generate a blog article on a given topic or keyword.
                - Begin with a strong hook or introduction
                - Organize content with subheadings and logical flow
                - Use examples, data, and storytelling
                - Optimize for SEO with keywords and meta description
                - Conclude with a summary or actionable insight
            """,
            "sales_copy": """
                Write persuasive sales copy for a product or service.
                - Lead with a strong value proposition
                - Address specific pain points and offer solutions
                - Highlight features, benefits, and outcomes
                - Include social proof (testimonials, stats, etc.)
                - End with a direct and compelling CTA
            """,
            "ad_copy": """
                Create short, punchy ad copy for digital or print campaigns.
                - Capture attention in the first line
                - Use emotional or benefit-driven language
                - Keep it brief and persuasive
                - Align copy with the target audience
                - Include a CTA or promotional message
            """,
            "video_script": """
                Generate a short video script for a marketing video.
                - Hook the viewer in the first few seconds
                - Introduce the problem and present the solution
                - Keep the tone conversational and natural
                - Include visual cues and on-screen text ideas
                - Wrap up with a strong CTA
            """,
            "case_study": """
                Write a case study that highlights a customer success story.
                - Start with a quick summary of the results
                - Describe the client and their initial problem
                - Explain how the product/service helped
                - Include measurable outcomes or metrics
                - End with a quote and a CTA to learn more
            """,
            "product_description": """
                Generate a product description that drives interest and conversions.
                - Begin with the most attractive benefit
                - Mention key features and what makes the product unique
                - Use sensory and persuasive language
                - Include important specs or FAQs
                - End with a micro-CTA (e.g., "Shop now", "View details")
            """,
            "landing_page": """
                Write copy for a focused landing page.
                - Use a bold, attention-grabbing headline
                - Describe the offer clearly and simply
                - Include supporting details that reinforce value
                - Remove distractions and focus on a single goal
                - Add a CTA above the fold and at the end
            """,
            "press_release": """
                Create a professional press release for an announcement.
                - Begin with a headline that summarizes the news
                - Use a journalistic tone and structure
                - Provide key facts in the first paragraph
                - Add quotes from relevant leaders or stakeholders
                - End with boilerplate company info and contact details
            """,
            "newsletter": """
                Write a newsletter update for subscribers.
                - Start with a warm greeting or short intro
                - Highlight the most important news or offer first
                - Use engaging sub-sections or article teasers
                - Maintain consistent tone with the brand
                - Include CTAs to drive clicks or traffic
            """
        }
        logger.info("BrandStyleManager initialized successfully")

    def _load_or_create_style(self) -> Dict[str, Any]:
        """Load existing style guidelines or create new ones with defaults.

        Returns:
            The guidelines read from ``self.style_path`` when that file
            exists; otherwise a private copy of ``config.DEFAULT_BRAND_STYLE``
            (which is also written to ``self.style_path``). Any error is
            logged and a copy of the defaults is returned instead.
        """
        import copy  # local import: only needed here

        try:
            if self.style_path.exists():
                with open(self.style_path, 'r', encoding='utf-8') as f:
                    style = json.load(f)
                logger.info("Loaded existing brand style guidelines")
                return style

            # Create the directory (and any missing parents) before writing.
            self.style_path.parent.mkdir(parents=True, exist_ok=True)

            # Copy the defaults so later in-place updates of the instance's
            # guidelines never leak into the module-level constant.
            style = copy.deepcopy(config.DEFAULT_BRAND_STYLE)

            with open(self.style_path, 'w', encoding='utf-8') as f:
                json.dump(style, f, indent=2)

            logger.info("Created default brand style guidelines")
            return style
        except Exception as e:
            logger.error(f"Error loading or creating style guidelines: {str(e)}")
            # Fall back to (a copy of) the default style
            return copy.deepcopy(config.DEFAULT_BRAND_STYLE)

    def get_style_guidelines(self) -> Dict[str, Any]:
        """
        Get current brand style guidelines.

        Returns:
            Dictionary of style guidelines (the instance's own dictionary,
            not a copy)
        """
        return self.style_guidelines

    def update_style_guidelines(self, new_style: Dict[str, Any]) -> Dict[str, Any]:
        """
        Update brand style guidelines.

        Top-level keys of ``new_style`` overwrite the stored ones, the
        ``brand_name`` entry is then forced to ``config.BRAND_NAME``, and the
        result is written to ``self.style_path``.

        Args:
            new_style: Dictionary with new style guidelines

        Returns:
            Updated style guidelines dictionary

        Raises:
            Exception: Any error raised while merging or saving is logged
                and re-raised.
        """
        try:
            # Merge new style with existing (shallow, key by key)
            self.style_guidelines.update(new_style)

            # Ensure brand name is preserved
            self.style_guidelines['brand_name'] = config.BRAND_NAME

            # Save updated style; the directory may be missing if the initial
            # load fell back to the defaults.
            self.style_path.parent.mkdir(parents=True, exist_ok=True)
            with open(self.style_path, 'w', encoding='utf-8') as f:
                json.dump(self.style_guidelines, f, indent=2)

            logger.info("Updated brand style guidelines")
            return self.style_guidelines
        except Exception as e:
            logger.error(f"Error updating style guidelines: {str(e)}")
            raise

    def format_prompt_with_brand_style(self, user_prompt: str, content_type: Optional[str] = None) -> str:
        """Format user prompt to match the distinctive communication style.

        Args:
            user_prompt: The raw request from the user
            content_type: Optional key of ``self.content_formats`` whose
                instructions are appended to the prompt

        Returns:
            The quoted request followed by the style guidelines and, when
            available, the content-type specific instructions
        """
        style_instructions = [
            "Follow these distinctive communication style guidelines:",
            "- Use empowering, assertive language that inspires action",
            "- Address the reader directly using 'you' and 'your' with conviction",
            "- Create rhythmic, repetitive patterns in key messages for emphasis",
            "- Maintain a clear, confident, and conversational teaching tone",
            "- Use simple, practical language that communicates profound ideas",
            "- Use embedded commands (e.g., 'Decide now to change your thinking')",
            "- Include cause-effect statements (e.g., 'Because you understand this, you will now take action')",
            "- Speak with conviction and clarity rather than hesitation",
            "- Replace tentative phrases with confident declarations",
            "- Use a motivational coach-like clarity in all communications",
            "- IMPORTANT: Do not mention any specific person's name in the content"
        ]

        # Content type specific formatting ("" when no/unknown content type)
        content_format = self._get_content_format(content_type) if content_type else ""

        return "\n".join([
            "Generate content based on this request:",
            f"\"{user_prompt}\"",
            "",
            "\n".join(style_instructions),
            content_format
        ])

    def check_content_alignment(self, content: str) -> Dict[str, Any]:
        """
        Check if generated content aligns with brand style guidelines.

        Terms are matched case-insensitively as whole words/phrases, so a
        taboo word such as "try" is not reported for "country" or "industry".

        Args:
            content: Generated marketing content

        Returns:
            Dictionary with alignment metrics and suggestions:
            ``alignment_score`` (100 minus 10 per issue, floored at 0),
            ``taboo_words_found``, ``terminology_issues`` and ``aligned``
            (True when the score is at least 80)
        """
        import re  # local import: only this method needs it

        def mentions(term: str) -> bool:
            """Return True when *term* occurs in the content as a whole word/phrase."""
            term = term.strip()
            if not term:
                # An empty entry would otherwise "match" everywhere.
                return False
            # Lookarounds instead of \b so terms that start or end with
            # punctuation are still delimited correctly.
            pattern = r"(?<!\w)" + re.escape(term) + r"(?!\w)"
            return re.search(pattern, content, re.IGNORECASE) is not None

        style = self.style_guidelines
        taboo_words = style.get('taboo_words', [])
        preferred_terms = style.get('preferred_terms', {})

        # Check for taboo words
        found_taboo_words = [word for word in taboo_words if mentions(word)]

        # Check for preferred terminology
        terminology_issues = [
            f"Found '{avoid}', should use '{use}' instead"
            for avoid, use in preferred_terms.items()
            if mentions(avoid)
        ]

        # Calculate an overall alignment score (simple implementation):
        # every issue costs 10 points.
        issues_count = len(found_taboo_words) + len(terminology_issues)
        alignment_score = max(0, 100 - (issues_count * 10))

        return {
            'alignment_score': alignment_score,
            'taboo_words_found': found_taboo_words,
            'terminology_issues': terminology_issues,
            'aligned': alignment_score >= 80  # Consider aligned if score is 80% or higher
        }

    def _get_content_format(self, content_type: str) -> str:
        """
        Get formatting instructions for specific content type.

        Args:
            content_type: Type of content to generate

        Returns:
            Formatting instructions as string, prefixed with a heading line;
            an empty string when the type is empty or unknown
        """
        if not content_type:
            return ""

        format_instructions = self.content_formats.get(content_type, "")
        if format_instructions:
            return f"\nContent type specific instructions:\n{format_instructions.strip()}"
        return ""
|
||||
|
||||
# Module-level shared instance. It is constructed at import time, so importing
# this module runs BrandStyleManager.__init__, which loads (or creates) the
# brand style JSON file.
brand_style_manager = BrandStyleManager()
|
||||
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
Configuration module for the Marketing Assistant AI.
|
||||
Handles environment variables and application settings.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
load_dotenv()

# Base paths: BASE_DIR is the directory above the one containing this file.
BASE_DIR = Path(__file__).resolve().parent.parent
DATA_DIR = BASE_DIR / "data"

# Ensure data directories exist. parents=True also creates DATA_DIR itself,
# which may be missing on a fresh checkout.
for _subdir in ("past_campaigns", "user_queries", "style_guidelines"):
    (DATA_DIR / _subdir).mkdir(parents=True, exist_ok=True)

# API configuration
API_HOST = os.getenv("API_HOST", "localhost")
API_PORT = int(os.getenv("API_PORT", 8000))

# LLM configuration (None when the variable is not set)
LLM_MODEL = os.getenv("LLM_MODEL")
LLM_API_KEY = os.getenv("LLM_API_KEY")

# Cohere configuration (None when the variable is not set)
COHERE_API_KEY = os.getenv("COHERE_API_KEY")

# Vector database configuration
VECTOR_DB_PATH = os.getenv("VECTOR_DB_PATH", str(DATA_DIR / "vector_store"))

# Brand configuration
BRAND_NAME = os.getenv("BRAND_NAME", "Adriana James")

# Content types
CONTENT_TYPES = [
    "website_copy",
    "email",
    "social_media",
    "blog_post",
    "sales_copy",
    "ad_copy",
    "video_script",
    "case_study",
    "product_description",
    "landing_page",
    "press_release",
    "newsletter"
]

# Tone options - specifically matching Adriana James' communication style
TONE_OPTIONS = [
    "empowering",
    "assertive",
    "inspirational",
    "direct"
]

# Content length options
LENGTH_OPTIONS = [
    "short",   # < 100 words
    "medium",  # 100-300 words
    "long",    # > 300 words
]

# Default brand style guidelines - fixed to match Adriana James' distinct communication style.
# "taboo_words" are words to avoid; "preferred_terms" maps a phrase to avoid
# onto the phrase to use instead.
DEFAULT_BRAND_STYLE = {
    "tone": ["empowering", "assertive", "inspirational", "direct"],
    "voice_characteristics": ["clear", "confident", "conversational", "teaching"],
    "writing_patterns": ["direct commands", "personal pronouns", "repetitive rhythms", "embedded commands", "cause-effect statements"],
    "taboo_words": ["cheap", "discount", "bargain", "failure", "impossible", "difficult", "might", "try", "consider"],
    "preferred_terms": {
        "problems": "challenges",
        "try": "take action",
        "difficult": "ready for growth",
        "failure": "learning opportunity",
        "hope": "know",
        "maybe": "will",
        "might help you": "you can do this",
        "consider doing this": "decide now to change your thinking",
        "this could work": "this works because"
    }
}

# Logging configuration
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
LOG_FILE = os.getenv("LOG_FILE", str(BASE_DIR / "logs" / "app.log"))

# Create logs directory if it doesn't exist
(BASE_DIR / "logs").mkdir(parents=True, exist_ok=True)
|
||||
@@ -0,0 +1,336 @@
|
||||
"""
|
||||
Copywriter module for the Marketing Assistant AI.
|
||||
Core AI-powered content generation using a fine-tuned LLM.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import httpx
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from loguru import logger
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
import config
|
||||
from brand_style import brand_style_manager
|
||||
from vector_store import vector_store
|
||||
|
||||
class Copywriter:
    """Generates marketing copy by calling Cohere's text-generation API.

    Prompts are wrapped with the brand style instructions from
    ``brand_style_manager`` and, optionally, with similar past content
    retrieved from ``vector_store``.
    """

    def __init__(self):
        """Initialize the Copywriter with Cohere LLM client."""
        self.model = "command"  # Cohere's generation model
        self.api_key = config.COHERE_API_KEY
        logger.info("Copywriter initialized with Cohere API successfully")

    # NOTE(review): _call_llm_api is itself wrapped in @retry, so a persistent
    # LLM failure is retried at both levels (up to 3 x 3 API calls).
    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
    async def generate_copy(
        self,
        prompt: str,
        content_type: Optional[str] = None,
        length: Optional[str] = None,
        include_cta: bool = False,
        reference_similar_content: bool = True,
        max_tokens: int = 1000
    ) -> Dict[str, Any]:
        """
        Generate marketing copy based on the user prompt and parameters.
        Note: Removed tone parameter as we always use the established style

        Args:
            prompt: The user's request
            content_type: Optional content type used to pick format instructions
            length: Optional length label inserted into the prompt
            include_cta: Whether to ask for a call to action
            reference_similar_content: Whether to look up similar content in
                the vector store and include it as reference examples
            max_tokens: Maximum tokens for the generated response

        Returns:
            Dictionary with ``content``, ``suggestions`` (headline ideas) and
            ``metadata``; ``alignment_issues`` is added when the alignment
            check reports taboo words or terminology issues

        Raises:
            Exception: Any error is logged and re-raised (and then subject to
                the retry policy of the decorator).
        """
        try:
            # Step 1: Format prompt with brand style guidelines
            branded_prompt = brand_style_manager.format_prompt_with_brand_style(prompt, content_type)

            # Step 2: Find similar content for reference (if enabled)
            reference_content = []
            if reference_similar_content:
                logger.info(f"Searching for similar content to reference for prompt: {prompt[:50]}...")
                search_results = await vector_store.search(prompt, top_k=3)
                if search_results:
                    reference_content = [result['text'] for result in search_results]
                    logger.info(f"Found {len(reference_content)} similar content items to reference")
                    for i, content in enumerate(reference_content):
                        logger.debug(f"Reference content {i+1}: {content[:100]}...")
                else:
                    logger.warning("No similar content found in vector store for reference")

            # Step 3: Add length and CTA instructions if needed
            if length:
                branded_prompt += f"\n- Generate {length} content"
            if include_cta:
                branded_prompt += "\n- Include a direct, empowering call to action"

            # Step 4: Add reference content if available
            if reference_content:
                branded_prompt += "\n\nReference these successful examples for tone and style:\n"
                branded_prompt += "\n---\n".join(reference_content)

            # Step 5: Generate content using the LLM
            generated_content = await self._call_llm_api(branded_prompt, max_tokens)

            # Step 6: Post-process to remove any mentions of Adriana James
            generated_content = self._remove_name_mentions(generated_content)

            # Step 7: Check content alignment with brand style
            alignment_check = brand_style_manager.check_content_alignment(generated_content)

            # Step 8: Generate alternative headline suggestions
            headline_suggestions = await self._generate_headline_suggestions(prompt, generated_content)

            # Step 9: Return the generated content with metadata
            result = {
                "content": generated_content,
                "suggestions": headline_suggestions,
                "metadata": {
                    "content_type": content_type,
                    "tone": None,  # Removed tone parameter
                    "alignment_score": alignment_check['alignment_score'],
                    "generated_at": None  # Will be added by the API
                }
            }

            # Add alignment issues if any
            if alignment_check['taboo_words_found'] or alignment_check['terminology_issues']:
                result["alignment_issues"] = {
                    "taboo_words_found": alignment_check['taboo_words_found'],
                    "terminology_issues": alignment_check['terminology_issues']
                }

            logger.info(f"Generated content with {len(generated_content)} characters")
            return result

        except Exception as e:
            logger.error(f"Error generating copy: {str(e)}")
            raise

    @staticmethod
    def _format_generated_text(text: str) -> str:
        """Convert every line-leading "- " into a "• " bullet, leaving all line breaks intact."""
        import re

        # (?m) anchors ^ at each line start, so bullets at the very beginning
        # of the text and directly after a blank line are converted as well.
        return re.sub(r'(?m)^- ', '• ', text)

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
    async def _call_llm_api(self, prompt: str, max_tokens: int = 1000) -> str:
        """
        Call the Cohere API to generate content.

        Args:
            prompt: The formatted prompt for the LLM
            max_tokens: Maximum tokens for the generated response

        Returns:
            Generated content as a string with preserved formatting
            (hyphen bullets are converted to "•")

        Raises:
            Exception: On a non-200 response or any transport error; the
                error is logged and re-raised (and retried by the decorator).
        """
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    "https://api.cohere.ai/v1/generate",
                    headers={
                        "Authorization": f"Bearer {self.api_key}",
                        "Content-Type": "application/json"
                    },
                    json={
                        "model": self.model,
                        "prompt": f"{prompt}\n\nNote: Please preserve formatting with proper paragraphs, line breaks, and bullet points where appropriate.",
                        "max_tokens": max_tokens,
                        "temperature": 0.7,
                        "k": 0,
                        "p": 0.75,
                        "return_likelihoods": "NONE"
                    },
                    timeout=30.0
                )

            if response.status_code != 200:
                logger.error(f"Cohere API error: {response.status_code}, {response.text}")
                raise Exception(f"Cohere API error: {response.status_code}")

            result = response.json()
            generated_text = result["generations"][0]["text"].strip()
            return self._format_generated_text(generated_text)

        except Exception as e:
            logger.error(f"Error calling Cohere API: {str(e)}")
            raise

    @staticmethod
    def _parse_headlines(raw: str) -> List[str]:
        """Split an LLM response into headline strings, one per non-empty line."""
        import re

        headlines = []
        for line in raw.splitlines():
            # Strip list markers ("-", "*", "•", "1.", "2)") the model may add
            # despite the instructions, instead of discarding the headline.
            text = re.sub(r'^\s*(?:[-*•]+|\d+[.)])\s*', '', line).strip()
            # Skip blanks and label lines such as "Headlines:" / "Title 1: ...".
            if not text or text.lower().startswith(('headline', 'title')):
                continue
            headlines.append(text)
        return headlines

    async def _generate_headline_suggestions(self, original_prompt: str, generated_content: str) -> List[str]:
        """
        Generate alternative headline suggestions based on the content.

        Args:
            original_prompt: The original user prompt
            generated_content: The generated marketing content

        Returns:
            List of exactly 3 headline suggestions (padded with
            "Headline Option N" placeholders when the model returns fewer),
            or an empty list when generation fails
        """
        try:
            # Create a prompt for headline generation
            headline_prompt = f"""
            Generate 3 alternative marketing headlines for the following content.
            Make headlines compelling, concise, and aligned with the content's message.
            Each headline should be unique and capture attention.
            IMPORTANT: Do not mention any specific person's name in the headlines.

            ORIGINAL PROMPT:
            {original_prompt}

            CONTENT:
            {generated_content}

            Generate exactly 3 headlines, one per line, without numbering or prefixes.
            """

            # Call LLM to generate headlines
            response = await self._call_llm_api(
                prompt=headline_prompt,
                max_tokens=100  # Shorter limit for headlines
            )

            # Process the response into a list of headlines
            headlines = self._parse_headlines(response)

            # Remove any mentions of Adriana James from headlines and drop
            # headlines that end up empty.
            headlines = [self._remove_name_mentions(headline) for headline in headlines]
            headlines = [headline for headline in headlines if headline.strip()]

            # Ensure we have exactly 3 headlines
            headlines = headlines[:3]
            while len(headlines) < 3:
                headlines.append(f"Headline Option {len(headlines) + 1}")

            logger.info(f"Generated {len(headlines)} headline suggestions")
            return headlines

        except Exception as e:
            logger.error(f"Error generating headline suggestions: {str(e)}")
            # Return empty list instead of mock response on error
            return []

    async def improve_copy(self, content: str, feedback: str) -> str:
        """
        Improve content based on user feedback.

        Args:
            content: Original generated content
            feedback: User feedback for improvement

        Returns:
            Improved content

        Raises:
            Exception: Any error is logged and re-raised.
        """
        try:
            # Format prompt for improvement
            improve_prompt = f"""
            Please improve the following marketing content based on the feedback provided:
            IMPORTANT: Do not mention any specific person's name in the content.

            ORIGINAL CONTENT:
            {content}

            FEEDBACK:
            {feedback}

            IMPROVED CONTENT:
            """

            # Call LLM to improve content
            improved_content = await self._call_llm_api(improve_prompt, max_tokens=1200)

            # Remove any mentions of Adriana James from improved content
            improved_content = self._remove_name_mentions(improved_content)

            logger.info("Improved content based on feedback")
            return improved_content

        except Exception as e:
            logger.error(f"Error improving content: {str(e)}")
            raise

    async def analyze_content_performance(self, content: str) -> Dict[str, Any]:
        """
        Analyze marketing content for performance prediction.

        This is a heuristic placeholder based only on word count, sentence
        length and the presence of a few call-to-action phrases.

        Args:
            content: Marketing content to analyze

        Returns:
            Dictionary with analysis results (scores, raw metrics and
            improvement suggestions)

        Raises:
            Exception: Any error is logged and re-raised.
        """
        try:
            # This would be enhanced with actual ML models in production
            # Simplified mock response for demonstration

            # Very basic analysis using length and keyword presence
            word_count = len(content.split())
            has_cta = any(phrase in content.lower() for phrase in ["call", "contact", "get started", "try", "buy", "sign up"])
            sentence_count = len([s for s in content.split(".") if s.strip()])
            avg_words_per_sentence = word_count / max(1, sentence_count)

            # Simple scoring system: readability peaks at 15 words per
            # sentence and loses 5 points per word of deviation.
            readability_score = 100 - min(100, max(0, abs(avg_words_per_sentence - 15) * 5))
            cta_score = 90 if has_cta else 60
            length_score = min(100, max(0, word_count / 3))

            overall_score = (readability_score + cta_score + length_score) / 3

            return {
                "overall_score": round(overall_score, 1),
                "readability_score": round(readability_score, 1),
                "cta_effectiveness": round(cta_score, 1),
                "length_appropriateness": round(length_score, 1),
                "metrics": {
                    "word_count": word_count,
                    "sentence_count": sentence_count,
                    "avg_words_per_sentence": round(avg_words_per_sentence, 1),
                    "has_cta": has_cta
                },
                "improvement_suggestions": [
                    "Consider adding a stronger call to action" if cta_score < 80 else "Your call to action is effective",
                    "Try to use shorter sentences for better readability" if avg_words_per_sentence > 20 else "Your sentence length is good for readability",
                    "Consider adding more content for better engagement" if word_count < 100 else "Your content length is appropriate"
                ]
            }

        except Exception as e:
            logger.error(f"Error analyzing content: {str(e)}")
            raise

    @staticmethod
    def _strip_name(content: str) -> str:
        """Delete "Adriana James" (any case) while keeping the text's line structure."""
        import re

        marker = "\x00"  # placeholder marking where a name was removed
        pattern = re.compile(r'\bAdriana\s+James\b', re.IGNORECASE)

        kept_lines = []
        for line in pattern.sub(marker, content).split('\n'):
            if marker in line:
                # Only lines that actually contained the name are tidied up:
                # collapse the double spaces the removal left behind ...
                line = re.sub(r'[ \t]{2,}', ' ', line.replace(marker, '')).rstrip()
                # ... and drop the line if nothing else was on it.
                if not line.strip():
                    continue
            kept_lines.append(line)
        return '\n'.join(kept_lines)

    def _remove_name_mentions(self, content: str) -> str:
        """
        Remove any mentions of specific names from the generated content.

        Paragraph breaks, bullets and other line structure are preserved;
        only lines that contained the name are modified.

        Args:
            content: The generated content to process

        Returns:
            Content with name mentions removed (the unmodified content if
            the removal itself fails)
        """
        try:
            cleaned = self._strip_name(content)
            logger.info("Removed any name mentions from generated content")
            return cleaned
        except Exception as e:
            logger.error(f"Error removing name mentions: {str(e)}")
            return content
|
||||
|
||||
# Module-level shared instance, constructed at import time; __init__ reads
# config.COHERE_API_KEY once and keeps it on the instance.
copywriter = Copywriter()
|
||||
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
Embeddings module for the Marketing Assistant AI.
|
||||
Uses Cohere to generate and manage text embeddings.
|
||||
"""
|
||||
|
||||
import cohere
|
||||
from typing import List, Dict, Any, Optional
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
import config
|
||||
|
||||
class EmbeddingsManager:
    """Manages the generation and manipulation of text embeddings using Cohere.

    The Cohere client used here is synchronous, so every API call made from
    the async methods is executed in the event loop's default executor to
    avoid blocking other coroutines while the HTTP request is in flight.
    """

    def __init__(self):
        """Initialize the EmbeddingsManager with Cohere API client.

        Raises:
            Exception: Any error raised while creating the client is logged
                and re-raised.
        """
        try:
            self.co = cohere.Client(config.COHERE_API_KEY)
            logger.info("EmbeddingsManager initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize EmbeddingsManager: {str(e)}")
            raise

    @staticmethod
    async def _run_blocking(func, **kwargs):
        """Run the blocking callable ``func(**kwargs)`` in the default executor and await its result."""
        import asyncio
        from functools import partial

        loop = asyncio.get_running_loop()
        # run_in_executor only forwards positional arguments, hence partial().
        return await loop.run_in_executor(None, partial(func, **kwargs))

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
    async def get_embeddings(self, texts: List[str], model: str = "embed-english-v3.0") -> np.ndarray:
        """
        Generate embeddings for a list of texts.

        Args:
            texts: List of text strings to embed
            model: Cohere embedding model to use

        Returns:
            numpy.ndarray: Array of embeddings vectors (an empty array when
            ``texts`` is empty)

        Raises:
            Exception: Any error is logged and re-raised (and retried by the
                decorator).
        """
        try:
            if not texts:
                logger.warning("Empty text list provided for embedding")
                return np.array([])

            # Ensure texts are not too long for the API (character-based cut)
            processed_texts = [text[:8192] for text in texts]

            response = await self._run_blocking(
                self.co.embed,
                texts=processed_texts,
                model=model,
                input_type="search_document"
            )

            embeddings = np.array(response.embeddings)
            logger.debug(f"Generated {len(embeddings)} embeddings with shape {embeddings.shape}")
            return embeddings

        except Exception as e:
            logger.error(f"Error generating embeddings: {str(e)}")
            raise

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
    async def get_query_embedding(self, text: str, model: str = "embed-english-v3.0") -> np.ndarray:
        """
        Generate embedding for a single query text.

        Args:
            text: The query text to embed
            model: Cohere embedding model to use

        Returns:
            numpy.ndarray: Embedding vector for the query

        Raises:
            Exception: Any error is logged and re-raised (and retried by the
                decorator).
        """
        try:
            response = await self._run_blocking(
                self.co.embed,
                texts=[text[:8192]],
                model=model,
                input_type="search_query"
            )

            return np.array(response.embeddings[0])

        except Exception as e:
            logger.error(f"Error generating query embedding: {str(e)}")
            raise

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
    async def rerank_results(
        self,
        query: str,
        documents: List[str],
        model: str = "rerank-v3.5",
        top_n: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Rerank documents based on relevance to the query.

        Args:
            query: The search query
            documents: List of documents to rerank
            model: Cohere reranking model to use
            top_n: Number of top results to return

        Returns:
            List of dictionaries with document index, the original
            (untruncated) document and relevance score; an empty list when
            ``documents`` is empty

        Raises:
            Exception: Any error is logged and re-raised (and retried by the
                decorator).
        """
        try:
            if not documents:
                logger.warning("Empty document list provided for reranking")
                return []

            # Truncate documents if they're too long
            processed_docs = [doc[:8192] for doc in documents]

            response = await self._run_blocking(
                self.co.rerank,
                query=query,
                documents=processed_docs,
                model=model,
                top_n=min(top_n, len(processed_docs))
            )

            results = [
                {
                    "index": result.index,
                    # Return the caller's full document, not the truncated copy.
                    "document": documents[result.index],
                    "relevance_score": result.relevance_score
                }
                for result in response.results
            ]

            logger.debug(f"Reranked {len(documents)} documents, returning top {len(results)}")
            return results

        except Exception as e:
            logger.error(f"Error reranking documents: {str(e)}")
            raise
|
||||
|
||||
# Module-level shared instance, constructed at import time; a failure to
# create the Cohere client therefore surfaces when this module is imported.
embeddings_manager = EmbeddingsManager()
|
||||
+431
@@ -0,0 +1,431 @@
|
||||
"""
|
||||
Main FastAPI application for the Marketing Assistant AI.
|
||||
Provides API endpoints for generating and managing marketing content.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import glob
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from fastapi import FastAPI, HTTPException, Depends, Query, Body, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, desc, func
|
||||
from sqlalchemy.sql import Select
|
||||
|
||||
import config
|
||||
from copywriter import copywriter
|
||||
from vector_store import vector_store
|
||||
from brand_style import brand_style_manager
|
||||
from embeddings import embeddings_manager
|
||||
from models import database, training_data
|
||||
|
||||
# Initialize logging: add a rotating file sink configured from config.
logger.add(config.LOG_FILE, level=config.LOG_LEVEL, rotation="10 MB", retention="1 month")

# Create FastAPI app
app = FastAPI(
    title="Marketing Assistant AI",
    description="AI-powered tool for marketing copywriting with Adriana James' brand voice",
    version="1.0.0",
)

# Add CORS middleware.
# Allowed origins come from the optional ``config.CORS_ORIGINS`` setting and
# fall back to the wildcard when it is not defined.  Credentialed requests
# (cookies, Authorization headers) are only enabled when the origins are listed
# explicitly: combining a wildcard origin with credentials effectively lets any
# website issue authenticated requests on behalf of a visitor.
_cors_origins = list(getattr(config, "CORS_ORIGINS", ["*"]))
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
# Define request and response models
|
||||
class GenerateCopyRequest(BaseModel):
    """Request body accepted by ``POST /generate-copy``.

    Only ``prompt`` is required; every other field has a default.
    """

    prompt: str = Field(default=..., description="The main instruction for generating content")
    content_type: Optional[str] = Field(default=None, description="Type of content to generate")
    length: Optional[str] = Field(default=None, description="Desired length of the content")
    include_cta: Optional[bool] = Field(default=False, description="Whether to include a call to action")
    reference_similar_content: Optional[bool] = Field(
        default=True,
        description="Whether to reference similar content",
    )
    max_tokens: Optional[int] = Field(default=1000, description="Maximum tokens for the generated response")
|
||||
|
||||
class TrainingDataRequest(BaseModel):
    """Request body accepted by ``POST /training-data``.

    ``metadata`` is optional and defaults to an empty mapping.
    """

    content_type: str = Field(default=..., description="Type of content")
    content: str = Field(default=..., description="The marketing content")
    metadata: Optional[Dict[str, Any]] = Field(
        default={},
        description="Additional metadata about the content",
    )
|
||||
|
||||
class BrandStyleUpdateRequest(BaseModel):
    """Request body accepted by ``PUT /brand-style``.

    Every field is optional and defaults to ``None``.
    """

    tone: Optional[List[str]] = Field(default=None, description="Brand tone options")
    voice_characteristics: Optional[List[str]] = Field(default=None, description="Voice characteristics")
    taboo_words: Optional[List[str]] = Field(default=None, description="Words to avoid")
    preferred_terms: Optional[Dict[str, str]] = Field(default=None, description="Preferred terminology")
|
||||
|
||||
class ContentImprovementRequest(BaseModel):
    """Request body accepted by ``POST /improve-content``; both fields are required."""

    content: str = Field(default=..., description="Original generated content")
    feedback: str = Field(default=..., description="User feedback for improvement")
|
||||
|
||||
# API Routes
|
||||
@app.get("/")
async def root():
    """Describe the service: its name, version and a one-line summary."""
    summary = f"AI-powered marketing copywriter for {config.BRAND_NAME}"
    return dict(
        name="Marketing Assistant AI",
        version="1.0.0",
        description=summary,
    )
|
||||
|
||||
@app.post("/generate-copy")
async def generate_copy(request: GenerateCopyRequest):
    """Generate marketing copy based on the provided prompt and parameters.

    Steps:
      1. Reject an unknown ``content_type`` with a 400 JSON response.
      2. Ask ``copywriter.generate_copy`` for the content.
      3. Stamp the result metadata with the generation time.
      4. Index non-empty content in the vector store for future reference.
      5. Persist the user query as a JSON file under ``DATA_DIR/user_queries``.

    Returns:
        A dict with ``status``, ``content``, ``suggestions`` and ``metadata``.

    Raises:
        HTTPException: 500 when any step fails unexpectedly.
    """
    try:
        # Validate content type if provided
        if request.content_type and request.content_type not in config.CONTENT_TYPES:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={
                    "status": "error",
                    "message": f"Invalid content_type. Must be one of: {', '.join(config.CONTENT_TYPES)}"
                }
            )

        # Generate copy
        result = await copywriter.generate_copy(
            prompt=request.prompt,
            content_type=request.content_type,
            length=request.length,
            include_cta=request.include_cta,
            reference_similar_content=request.reference_similar_content,
            max_tokens=request.max_tokens
        )

        # One timestamp for the response metadata, the log entry and the file
        # name, so the three always agree.
        generated_at = datetime.now()
        result["metadata"]["generated_at"] = generated_at.isoformat()

        # Store the generated content in the vector store for future reference
        if result["content"]:
            metadata = {
                "content_type": request.content_type,
                "prompt": request.prompt,
                "generated": True
            }
            await vector_store.add_documents([result["content"]], [metadata])

        # Store the user query for future training.
        # The directory is created on demand: nothing else guarantees it exists
        # before the first request arrives.
        query_dir = Path(config.DATA_DIR) / "user_queries"
        query_dir.mkdir(parents=True, exist_ok=True)
        # Microsecond resolution keeps two requests handled within the same
        # second from overwriting each other's file; names still sort
        # chronologically.
        query_path = query_dir / f"{generated_at.strftime('%Y%m%d%H%M%S%f')}.json"
        with open(query_path, 'w', encoding='utf-8') as f:
            json.dump({
                "prompt": request.prompt,
                "parameters": {
                    "content_type": request.content_type,
                    "length": request.length,
                    "include_cta": request.include_cta
                },
                "timestamp": generated_at.isoformat()
            }, f, indent=2)

        return {
            "status": "success",
            "content": result["content"],
            "suggestions": result.get("suggestions", []),
            "metadata": result["metadata"]
        }

    except HTTPException:
        # Keep the status code of errors that are already HTTP errors.
        raise
    except Exception as e:
        logger.error(f"Error generating copy: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate copy: {str(e)}"
        )
|
||||
|
||||
@app.get("/brand-style")
async def get_brand_style():
    """Return the style guidelines currently held by the brand style manager.

    Raises:
        HTTPException: 500 when the guidelines cannot be retrieved.
    """
    try:
        return brand_style_manager.get_style_guidelines()
    except Exception as exc:
        logger.error(f"Error getting brand style: {str(exc)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get brand style: {str(exc)}",
        )
|
||||
|
||||
@app.put("/brand-style")
async def update_brand_style(request: BrandStyleUpdateRequest):
    """Apply a partial update to the brand style guidelines.

    Only the fields the client actually sent are forwarded to the style
    manager; the updated guidelines are echoed back in the response.

    Raises:
        HTTPException: 500 when the update fails.
    """
    try:
        changes = request.dict(exclude_unset=True)
        response = {
            "status": "success",
            "message": "Brand style updated successfully",
            "style": brand_style_manager.update_style_guidelines(changes),
        }
        return response
    except Exception as exc:
        logger.error(f"Error updating brand style: {str(exc)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to update brand style: {str(exc)}",
        )
|
||||
|
||||
@app.post("/training-data")
async def add_training_data(request: TrainingDataRequest):
    """Add new marketing content for AI training.

    The content is inserted into the ``training_data`` table and then indexed
    in the vector store so it can be found by similarity search.

    Returns:
        A dict with ``status``, ``message`` and the database ``data_id``.

    Raises:
        HTTPException: 500 when storing the content fails.
    """
    try:
        # Validate content type
        if request.content_type not in config.CONTENT_TYPES:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={
                    "status": "error",
                    "message": f"Invalid content_type. Must be one of: {', '.join(config.CONTENT_TYPES)}"
                }
            )

        # Single timestamp so the column and the metadata entry agree.
        added_at = datetime.now()

        # Prepare metadata. The field is Optional, so a client may send an
        # explicit null; treat that like "no metadata" instead of crashing.
        metadata = dict(request.metadata or {})
        metadata["content_type"] = request.content_type
        metadata["added_at"] = added_at.isoformat()
        metadata["training_data"] = True

        # Add to database
        query = training_data.insert().values(
            content=request.content,
            content_type=request.content_type,
            metadata=metadata,
            added_at=added_at,
            is_training_data=True
        )
        data_id = await database.execute(query)

        # Add to vector store for search functionality. A copy is handed over
        # because the store adds bookkeeping keys to the dict it receives.
        await vector_store.add_documents([request.content], [dict(metadata)])

        return {
            "status": "success",
            "message": "Training data added successfully",
            "data_id": data_id
        }
    except HTTPException:
        # Keep the status code of errors that are already HTTP errors.
        raise
    except Exception as e:
        logger.error(f"Error adding training data: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to add training data: {str(e)}"
        )
|
||||
|
||||
@app.get("/training-data")
async def list_training_data(
    content_type: Optional[str] = Query(None, description="Filter by content type"),
    page: int = Query(1, ge=1, description="Page number"),
    limit: int = Query(10, ge=1, le=100, description="Items per page")
):
    """Return one page of training documents, newest first.

    Each item carries the record id, its content type, a preview of at most
    100 characters (followed by "..." when truncated) and the insertion time.
    An unknown ``content_type`` filter yields a 400 JSON response.

    Raises:
        HTTPException: 500 when the database query fails.
    """
    try:
        if content_type and content_type not in config.CONTENT_TYPES:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={
                    "status": "error",
                    "message": f"Invalid content_type. Must be one of: {', '.join(config.CONTENT_TYPES)}"
                }
            )

        # The same conditions drive both the listing and the count.
        conditions = [training_data.c.is_training_data == True]  # noqa: E712 (SQL expression)
        if content_type:
            conditions.append(training_data.c.content_type == content_type)

        listing = select(training_data)
        counting = select(func.count()).select_from(training_data)
        for condition in conditions:
            listing = listing.where(condition)
            counting = counting.where(condition)

        total = await database.fetch_val(counting)

        first_row = (page - 1) * limit
        listing = (
            listing.order_by(training_data.c.added_at.desc())
            .offset(first_row)
            .limit(limit)
        )
        records = await database.fetch_all(listing)

        def summarize(record):
            body = record["content"]
            return {
                "id": record["id"],
                "content_type": record["content_type"],
                "preview": body if len(body) <= 100 else body[:100] + "...",
                "added_at": record["added_at"].isoformat(),
            }

        return {
            "items": [summarize(record) for record in records],
            "pagination": {
                "total": total,
                "page": page,
                "limit": limit,
                "pages": -(-total // limit),  # ceiling division
            },
        }
    except Exception as exc:
        logger.error(f"Error listing training data: {str(exc)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to list training data: {str(exc)}",
        )
|
||||
|
||||
@app.get("/training-data/{data_id}")
async def get_training_data(data_id: int):
    """Retrieve a specific training document by ID.

    Returns:
        A dict with the record's ``id``, ``content``, ``content_type``,
        ``metadata`` and ISO-formatted ``added_at``.

    Raises:
        HTTPException: 404 when no record has this ID, 500 on other failures.
    """
    try:
        # ``select(table)`` is the calling form accepted by both SQLAlchemy 1.4
        # and 2.0; the list form ``select([table])`` was removed in 2.0.
        query = select(training_data).where(training_data.c.id == data_id)
        record = await database.fetch_one(query)

        if not record:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Document with ID {data_id} not found"
            )

        return {
            "id": record["id"],
            "content": record["content"],
            "content_type": record["content_type"],
            "metadata": record["metadata"],
            "added_at": record["added_at"].isoformat()
        }
    except HTTPException:
        # Let the 404 above through unchanged.
        raise
    except Exception as e:
        logger.error(f"Error retrieving training data: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retrieve training data: {str(e)}"
        )
|
||||
|
||||
@app.delete("/training-data/{data_id}")
async def delete_training_data(data_id: int):
    """Delete a specific training document by ID.

    The record is looked up first so a missing ID reliably yields a 404; the
    value returned by ``database.execute`` for a DELETE is not a dependable
    "row existed" signal. After the row is removed, the matching vector-store
    entry is soft-deleted as well.

    Raises:
        HTTPException: 404 when no record has this ID, 500 on other failures.
    """
    try:
        lookup = select(training_data).where(training_data.c.id == data_id)
        record = await database.fetch_one(lookup)

        if record is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Document with ID {data_id} not found or could not be deleted"
            )

        await database.execute(
            training_data.delete().where(training_data.c.id == data_id)
        )

        # Also remove from vector store.
        # Vector-store IDs are positions in the index, not database primary
        # keys, so passing ``data_id`` could hit an unrelated document. Find
        # the live training entry holding the same text instead.
        content = record["content"]
        vector_id = next(
            (
                position
                for position, meta in enumerate(vector_store.metadata)
                if meta.get("training_data")
                and not meta.get("deleted")
                and meta.get("text") == content
            ),
            None,
        )
        if vector_id is None:
            logger.warning(f"No vector store entry found for training document {data_id}")
        else:
            await vector_store.delete_document(vector_id)

        return {
            "status": "success",
            "message": f"Document with ID {data_id} successfully deleted"
        }
    except HTTPException:
        # Let the 404 above through unchanged.
        raise
    except Exception as e:
        logger.error(f"Error deleting training data: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete training data: {str(e)}"
        )
|
||||
|
||||
@app.post("/improve-content")
async def improve_content(request: ContentImprovementRequest):
    """Rewrite previously generated content according to user feedback.

    The response echoes the original content and the feedback next to the
    improved version.

    Raises:
        HTTPException: 500 when the copywriter call fails.
    """
    try:
        rewritten = await copywriter.improve_copy(
            content=request.content,
            feedback=request.feedback,
        )
        payload = {
            "status": "success",
            "original_content": request.content,
            "improved_content": rewritten,
            "feedback": request.feedback,
        }
        return payload
    except Exception as exc:
        logger.error(f"Error improving content: {str(exc)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to improve content: {str(exc)}",
        )
|
||||
|
||||
@app.post("/analyze-content")
async def analyze_content(content: str = Body(..., embed=True)):
    """Run the copywriter's performance analysis on a piece of marketing content.

    The content is read from the ``content`` key of the JSON body.

    Raises:
        HTTPException: 500 when the analysis fails.
    """
    try:
        report = await copywriter.analyze_content_performance(content)
        return dict(status="success", analysis=report)
    except Exception as exc:
        logger.error(f"Error analyzing content: {str(exc)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to analyze content: {str(exc)}",
        )
|
||||
|
||||
@app.get("/user-queries")
async def list_user_queries(
    page: int = Query(1, ge=1, description="Page number"),
    limit: int = Query(10, ge=1, le=100, description="Items per page")
):
    """List stored user queries with pagination, newest first.

    Queries are read from the JSON files in ``DATA_DIR/user_queries``; file
    names are timestamps, so reverse name order is newest-first. A file that
    cannot be read or parsed is skipped with a warning rather than failing the
    whole listing, so a page may hold fewer than ``limit`` items.

    Raises:
        HTTPException: 500 when the directory cannot be listed.
    """
    try:
        # Calculate offset
        offset = (page - 1) * limit

        # Get files from user_queries directory. ``parents=True`` so this also
        # works when the data directory itself has not been created yet.
        query_dir = Path(config.DATA_DIR) / "user_queries"
        query_dir.mkdir(parents=True, exist_ok=True)

        # List all JSON files and sort by name (timestamp) in descending order
        files = sorted(query_dir.glob("*.json"), reverse=True)
        total = len(files)

        items = []
        for file in files[offset:offset + limit]:
            try:
                with open(file, 'r', encoding='utf-8') as f:
                    items.append(json.load(f))
            except (OSError, ValueError) as read_error:
                # ValueError covers both JSON and text-decoding errors.
                logger.warning(f"Skipping unreadable user query file {file.name}: {read_error}")

        return {
            "items": items,
            "pagination": {
                "total": total,
                "page": page,
                "limit": limit,
                "pages": (total + limit - 1) // limit
            }
        }
    except Exception as e:
        logger.error(f"Error listing user queries: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to list user queries: {str(e)}"
        )
|
||||
|
||||
# Script entry point: serve the app with uvicorn, auto-reload enabled.
if __name__ == "__main__":
    import uvicorn

    server_options = {
        "host": config.API_HOST,
        "port": config.API_PORT,
        "reload": True,
    }
    uvicorn.run("main:app", **server_options)
|
||||
@@ -0,0 +1,23 @@
|
||||
from datetime import datetime
from pathlib import Path

from sqlalchemy import Column, Integer, String, JSON, DateTime, Boolean, MetaData, Table, create_engine
from databases import Database
from config import DATA_DIR

# SQLite creates the database file on demand but not its parent directory, and
# the tables are created at import time below, so make sure the directory
# exists first.
Path(DATA_DIR).mkdir(parents=True, exist_ok=True)

DATABASE_URL = f"sqlite:///{DATA_DIR}/training_data.db"

# Async handle used for queries at runtime.
database = Database(DATABASE_URL)
metadata = MetaData()

# One row per stored piece of marketing content.
training_data = Table(
    "training_data",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("content", String, nullable=False),
    Column("content_type", String, nullable=False),
    Column("metadata", JSON, nullable=False),
    Column("added_at", DateTime, nullable=False, default=datetime.utcnow),
    Column("is_training_data", Boolean, nullable=False, default=True)
)

# Create tables (synchronous engine, used only for this schema bootstrap).
engine = create_engine(DATABASE_URL)
metadata.create_all(engine)
|
||||
@@ -0,0 +1,15 @@
|
||||
fastapi
|
||||
uvicorn
|
||||
pydantic
|
||||
python-dotenv
|
||||
httpx
|
||||
faiss-cpu
|
||||
numpy==1.26.2
|
||||
pandas
|
||||
cohere
|
||||
python-multipart
|
||||
SQLAlchemy
|
||||
databases
|
||||
aiosqlite
|
||||
loguru
|
||||
tenacity
|
||||
@@ -0,0 +1,420 @@
|
||||
"""
|
||||
Vector store module for the Marketing Assistant AI.
|
||||
Uses FAISS for efficient storage and retrieval of content embeddings.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import pickle
|
||||
import faiss
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from datetime import datetime
|
||||
|
||||
import config
|
||||
from embeddings import embeddings_manager
|
||||
|
||||
class VectorStore:
    """Manages vector database operations for content retrieval.

    Documents are embedded through ``embeddings_manager`` and stored in a flat
    L2 FAISS index. ``self.metadata`` is a list kept parallel to the index:
    the entry at position ``i`` describes the vector with FAISS id ``i`` and
    also holds the document text under the ``'text'`` key. Deletion is a soft
    delete: the entry is flagged ``'deleted'`` and the vector stays in the
    index.
    """

    def __init__(self):
        """Initialize the VectorStore with FAISS index.

        Loads the persisted index if there is one, otherwise creates an empty
        one, and seeds an empty store from the past-campaign sample files.
        """
        self.store_path = Path(config.VECTOR_DB_PATH)
        # parents=True: the configured path may be nested in a directory that
        # has not been created yet.
        self.store_path.mkdir(parents=True, exist_ok=True)

        self.index_path = self.store_path / "faiss_index.bin"
        self.metadata_path = self.store_path / "metadata.pkl"

        self.dimension = None
        self.index = None
        self.metadata = []

        self._load_or_create_index()
        logger.info("VectorStore initialized successfully")

        # Check if the index is empty and load sample data if needed
        if self.index.ntotal == 0:
            logger.warning("Vector store is empty. Loading sample data...")
            self._load_sample_data()

    def _load_or_create_index(self) -> None:
        """Load existing index or create new one if it doesn't exist.

        Raises:
            Exception: whatever reading or writing the index files raised; the
                error is logged and re-raised.
        """
        try:
            if self.index_path.exists() and self.metadata_path.exists():
                # Load existing index and metadata
                self.index = faiss.read_index(str(self.index_path))
                # NOTE(security): pickle executes code on load. This is only
                # acceptable because the file is written by this class; never
                # point it at a file from an untrusted source.
                with open(self.metadata_path, 'rb') as f:
                    self.metadata = pickle.load(f)
                self.dimension = self.index.d
                logger.info(f"Loaded existing vector index with {self.index.ntotal} vectors")
            else:
                # Default dimension for Cohere embeddings
                self.dimension = 1024
                self.index = faiss.IndexFlatL2(self.dimension)
                self.metadata = []
                logger.info(f"Created new vector index with dimension {self.dimension}")

                # Save the empty index and metadata
                self._save_index()
        except Exception as e:
            logger.error(f"Error loading or creating index: {str(e)}")
            raise

    def _save_index(self) -> None:
        """Save the index and metadata to disk.

        Raises:
            Exception: whatever writing the files raised; logged and re-raised.
        """
        try:
            faiss.write_index(self.index, str(self.index_path))
            with open(self.metadata_path, 'wb') as f:
                pickle.dump(self.metadata, f)
            logger.debug("Saved vector index and metadata")
        except Exception as e:
            logger.error(f"Error saving index: {str(e)}")
            raise

    async def add_documents(
        self,
        texts: List[str],
        metadata_list: Optional[List[Dict[str, Any]]] = None
    ) -> List[int]:
        """
        Add documents to the vector store.

        Args:
            texts: List of text documents to add
            metadata_list: List of metadata dictionaries for each document.
                The dictionaries are copied, not modified.

        Returns:
            List of document IDs (vector indices)

        Raises:
            ValueError: if the two lists differ in length, or the embedding
                dimension does not match a non-empty index.
        """
        try:
            if not texts:
                logger.warning("No texts provided to add to vector store")
                return []

            if metadata_list is None:
                metadata_list = [{} for _ in texts]

            if len(texts) != len(metadata_list):
                raise ValueError("Number of texts and metadata entries must match")

            # Generate embeddings
            embeddings = await embeddings_manager.get_embeddings(texts)

            # Check if embeddings match our dimension
            if embeddings.shape[1] != self.dimension:
                logger.warning(f"Embedding dimension mismatch: expected {self.dimension}, got {embeddings.shape[1]}")
                # If we have no documents yet, we can adapt to the new dimension
                if self.index.ntotal == 0:
                    self.dimension = embeddings.shape[1]
                    self.index = faiss.IndexFlatL2(self.dimension)
                    logger.info(f"Adapted to new dimension: {self.dimension}")
                else:
                    raise ValueError(f"Embedding dimension mismatch: expected {self.dimension}, got {embeddings.shape[1]}")

            # The flat index assigns ids sequentially, so the new documents get
            # start_idx, start_idx + 1, ...
            start_idx = self.index.ntotal
            timestamp = datetime.now().isoformat()

            # Build the stored entries on copies so the caller's dictionaries
            # are left untouched, and give each entry its own document id.
            entries = []
            for offset, (text, meta) in enumerate(zip(texts, metadata_list)):
                entry = dict(meta)
                entry['timestamp'] = timestamp
                entry['document_id'] = start_idx + offset
                entry['text'] = text
                entries.append(entry)

            # Add vectors to index
            self.index.add(embeddings.astype(np.float32))
            self.metadata.extend(entries)

            # Save updated index
            self._save_index()

            # Return document IDs
            doc_ids = list(range(start_idx, start_idx + len(texts)))
            logger.info(f"Added {len(texts)} documents to vector store")
            return doc_ids

        except Exception as e:
            logger.error(f"Error adding documents to vector store: {str(e)}")
            raise

    async def search(
        self,
        query: str,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
        rerank: bool = True
    ) -> List[Dict[str, Any]]:
        """
        Search for similar documents.

        Documents flagged as deleted are never returned. Filters are applied
        after the nearest-neighbour lookup, so fewer than ``top_k`` results may
        come back.

        Args:
            query: The search query
            top_k: Number of results to return
            filters: Dictionary of metadata filters
            rerank: Whether to use Cohere's reranking

        Returns:
            List of result dictionaries with document content and metadata.
            Each has ``document_id``, ``text``, ``metadata`` and ``distance``,
            plus ``relevance_score`` when reranking was applied.
        """
        try:
            logger.info(f"Searching vector store with query: {query[:50]}... (top_k={top_k})")

            if self.index.ntotal == 0:
                logger.warning("Empty vector store, no results to return")
                return []

            logger.info(f"Vector store contains {self.index.ntotal} documents")

            # Generate query embedding
            query_embedding = await embeddings_manager.get_query_embedding(query)
            query_embedding = query_embedding.reshape(1, -1).astype(np.float32)

            # First pass: find more candidates than needed for reranking
            search_k = top_k * 3 if rerank else top_k
            search_k = min(search_k, self.index.ntotal)  # Don't request more than we have

            distances, indices = self.index.search(query_embedding, search_k)

            # Get metadata and texts for matching indices
            results = []
            for rank, idx in enumerate(indices[0]):
                if idx < 0 or idx >= len(self.metadata):
                    continue  # Skip invalid indices

                metadata = self.metadata[idx]

                # Soft-deleted documents keep their vector in the index, so
                # they have to be dropped here.
                if metadata.get('deleted', False):
                    continue

                # Apply filters if any
                if filters and not self._matches_filters(metadata, filters):
                    continue

                results.append({
                    # Plain int: FAISS returns numpy integers, which are not
                    # JSON serializable.
                    'document_id': int(idx),
                    'text': metadata.get('text', ''),
                    'metadata': {k: v for k, v in metadata.items() if k != 'text'},
                    'distance': float(distances[0][rank])
                })

            # Apply reranking if requested
            if rerank and results:
                texts = [r['text'] for r in results]
                reranked = await embeddings_manager.rerank_results(query, texts, top_n=top_k)

                # Map reranked results back to our original results
                reranked_results = []
                for item in reranked:
                    orig_idx = item['index']
                    if 0 <= orig_idx < len(results):
                        reranked_results.append({
                            **results[orig_idx],
                            'relevance_score': item['relevance_score']
                        })

                results = reranked_results
            else:
                # Just take the top_k results
                results = results[:top_k]

            logger.info(f"Found {len(results)} matching documents for query")
            return results

        except Exception as e:
            logger.error(f"Error searching vector store: {str(e)}")
            raise

    def _matches_filters(self, metadata: Dict[str, Any], filters: Dict[str, Any]) -> bool:
        """Check if metadata matches the specified filters.

        Every filter key must be present in ``metadata``. A list-valued filter
        matches when the metadata value is one of the list items; any other
        filter value must be equal to the metadata value.
        """
        for key, value in filters.items():
            if key not in metadata:
                return False

            if isinstance(value, list):
                # Check if metadata value is in the list
                if metadata[key] not in value:
                    return False
            elif metadata[key] != value:
                return False

        return True

    async def delete_document(self, document_id: int) -> bool:
        """
        Delete a document from the vector store.

        This is a soft delete: the metadata entry is flagged ``'deleted'`` and
        the vector itself stays in the index.

        Args:
            document_id: ID of the document to delete

        Returns:
            Boolean indicating success (False for an out-of-range ID)
        """
        try:
            if document_id < 0 or document_id >= len(self.metadata):
                logger.warning(f"Invalid document ID: {document_id}")
                return False

            # Removing a vector would shift the ids of all later documents, so
            # the document is only marked as deleted in metadata.
            self.metadata[document_id]['deleted'] = True

            # Save updated metadata
            self._save_index()

            logger.info(f"Marked document {document_id} as deleted")
            return True

        except Exception as e:
            logger.error(f"Error deleting document: {str(e)}")
            raise

    async def get_document(self, document_id: int) -> Optional[Dict[str, Any]]:
        """
        Retrieve a document by ID.

        Args:
            document_id: ID of the document to retrieve

        Returns:
            Document with metadata or None if not found (out-of-range ID or
            a document flagged as deleted)
        """
        try:
            if document_id < 0 or document_id >= len(self.metadata):
                logger.warning(f"Invalid document ID: {document_id}")
                return None

            metadata = self.metadata[document_id]

            # Check if document is marked as deleted
            if metadata.get('deleted', False):
                logger.warning(f"Document {document_id} is marked as deleted")
                return None

            text = metadata.get('text', '')

            return {
                'document_id': document_id,
                'text': text,
                'metadata': {k: v for k, v in metadata.items() if k != 'text' and k != 'deleted'}
            }

        except Exception as e:
            logger.error(f"Error retrieving document: {str(e)}")
            raise

    async def update_document(self, document_id: int, text: str, metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Update a document in the vector store.

        Args:
            document_id: ID of the document to update
            text: New document text
            metadata: New metadata (will be merged with existing)

        Returns:
            Boolean indicating success (False for an out-of-range or deleted
            document)

        Raises:
            ValueError: if the new embedding does not have the index dimension.
        """
        try:
            if document_id < 0 or document_id >= len(self.metadata):
                logger.warning(f"Invalid document ID: {document_id}")
                return False

            # Get existing metadata
            existing_metadata = self.metadata[document_id]

            # Check if document is marked as deleted
            if existing_metadata.get('deleted', False):
                logger.warning(f"Cannot update deleted document {document_id}")
                return False

            # Generate new embedding
            embeddings = await embeddings_manager.get_embeddings([text])
            new_vector = np.asarray(embeddings, dtype=np.float32).reshape(1, -1)
            if new_vector.shape[1] != self.dimension:
                raise ValueError(f"Embedding dimension mismatch: expected {self.dimension}, got {new_vector.shape[1]}")

            # Update the vector in the index. A flat index cannot overwrite one
            # vector in place, so read all stored vectors back, swap the row
            # and rebuild. Ids stay the same because the order is preserved.
            vectors = self.index.reconstruct_n(0, self.index.ntotal)
            vectors[document_id] = new_vector[0]
            rebuilt_index = faiss.IndexFlatL2(self.dimension)
            rebuilt_index.add(vectors)
            self.index = rebuilt_index

            # Update metadata
            if metadata:
                existing_metadata.update(metadata)

            existing_metadata['text'] = text
            existing_metadata['updated_at'] = datetime.now().isoformat()

            # Save updated index
            self._save_index()

            logger.info(f"Updated document {document_id}")
            return True

        except Exception as e:
            logger.error(f"Error updating document: {str(e)}")
            raise

    @staticmethod
    def _run_coroutine_blocking(coro):
        """Run *coro* to completion from synchronous code and return its result.

        With no event loop running in this thread the coroutine is run
        directly. If a loop is already running here, for example when the
        module is imported from code that is itself executing inside a loop,
        the coroutine is run on a fresh loop in a helper thread, because a
        second loop cannot be started in the same thread.
        """
        import asyncio
        from concurrent.futures import ThreadPoolExecutor

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)

        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    def _load_sample_data(self) -> None:
        """Load sample data from past campaigns into the vector store.

        Reads every ``*.json`` file in ``DATA_DIR/past_campaigns`` and indexes
        those that have a ``content`` key. This is best-effort: a file that
        cannot be loaded is logged and skipped, and any other failure is
        logged without being raised.
        """
        try:
            # Path to past campaigns directory
            campaigns_dir = Path(config.DATA_DIR) / "past_campaigns"

            if not campaigns_dir.exists() or not campaigns_dir.is_dir():
                logger.warning(f"Past campaigns directory not found: {campaigns_dir}")
                return

            # Find all JSON files in the directory
            campaign_files = list(campaigns_dir.glob("*.json"))
            if not campaign_files:
                logger.warning("No campaign files found in past_campaigns directory")
                return

            # Load and process each campaign file
            texts = []
            metadata_list = []

            for file_path in campaign_files:
                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        campaign_data = json.load(f)

                    # Extract content and metadata
                    if 'content' in campaign_data:
                        texts.append(campaign_data['content'])

                        # Create metadata entry
                        metadata = {
                            'content_type': campaign_data.get('content_type', 'unknown'),
                            'campaign_name': campaign_data.get('metadata', {}).get('campaign_name', file_path.stem),
                            'source': 'past_campaign',
                            'file_path': str(file_path)
                        }

                        # Add performance metrics if available
                        if 'metadata' in campaign_data and 'performance_metrics' in campaign_data['metadata']:
                            metadata['performance_metrics'] = campaign_data['metadata']['performance_metrics']

                        metadata_list.append(metadata)
                        logger.debug(f"Loaded campaign from {file_path.name}")
                except Exception as e:
                    logger.error(f"Error loading campaign file {file_path}: {str(e)}")
                    continue

            if not texts:
                logger.warning("No valid campaign content found in files")
                return

            # Add documents to vector store. add_documents is a coroutine and
            # this method is called from the synchronous constructor.
            doc_ids = self._run_coroutine_blocking(self.add_documents(texts, metadata_list))
            logger.info(f"Added {len(doc_ids)} past campaigns to vector store")

        except Exception as e:
            logger.error(f"Error loading sample data: {str(e)}")
|
||||
|
||||
# Shared module-level instance, built once at import time; constructing it
# loads or creates the on-disk index.
vector_store = VectorStore()
|
||||
Reference in New Issue
Block a user