feat: Initial implementation of Marketing Assistant AI for Adriana James
- Set up FastAPI backend with modular structure: - main.py for API routing - copywriter.py for AI-powered content generation using Cohere - embeddings.py for generating and reranking content embeddings - vector_store.py for FAISS-based similarity search - brand_style.py for managing brand tone, taboo words, and preferred terms - config.py for managing environment and application settings - Configured RESTful API endpoints: /generate-copy, /brand-style, /training-data, /improve-content, /analyze-content - Created frontend with vanilla HTML, CSS, and JS (index.html, styles.css, app.js) - Integrated brand style management for tone, voice, taboo words, and terminology - Implemented vector search for referencing similar historical content - Enabled training data input to improve future AI output - Added environment variable support for API keys and model configs - Structured data storage with local JSON and DB files - Added developer documentation, API reference, and project setup instructions This commit provides the foundation for a full-stack, AI-driven content creation platform that ensures brand consistency, speeds up marketing workflows, and supports iterative improvement over time.
This commit is contained in:
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
Brand style module for the Marketing Assistant AI.
|
||||
Ensures generated content aligns with Adriana James' brand voice and tone.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Any, Optional
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
import config
|
||||
|
||||
class BrandStyleManager:
|
||||
"""Manages brand style guidelines and ensures content consistency."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the BrandStyleManager with default or stored style guidelines."""
|
||||
self.style_path = Path(config.DATA_DIR) / "style_guidelines" / "brand_style.json"
|
||||
self.style_guidelines = self._load_or_create_style()
|
||||
logger.info("BrandStyleManager initialized successfully")
|
||||
|
||||
def _load_or_create_style(self) -> Dict[str, Any]:
|
||||
"""Load existing style guidelines or create new ones with defaults."""
|
||||
try:
|
||||
if self.style_path.exists():
|
||||
with open(self.style_path, 'r') as f:
|
||||
style = json.load(f)
|
||||
logger.info("Loaded existing brand style guidelines")
|
||||
return style
|
||||
else:
|
||||
# Create directory if it doesn't exist
|
||||
self.style_path.parent.mkdir(exist_ok=True)
|
||||
|
||||
# Use default style guidelines
|
||||
style = config.DEFAULT_BRAND_STYLE
|
||||
|
||||
# Save default style
|
||||
with open(self.style_path, 'w') as f:
|
||||
json.dump(style, f, indent=2)
|
||||
|
||||
logger.info("Created default brand style guidelines")
|
||||
return style
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading or creating style guidelines: {str(e)}")
|
||||
# Fall back to default style
|
||||
return config.DEFAULT_BRAND_STYLE
|
||||
|
||||
def get_style_guidelines(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get current brand style guidelines.
|
||||
|
||||
Returns:
|
||||
Dictionary of style guidelines
|
||||
"""
|
||||
return self.style_guidelines
|
||||
|
||||
def update_style_guidelines(self, new_style: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Update brand style guidelines.
|
||||
|
||||
Args:
|
||||
new_style: Dictionary with new style guidelines
|
||||
|
||||
Returns:
|
||||
Updated style guidelines dictionary
|
||||
"""
|
||||
try:
|
||||
# Merge new style with existing
|
||||
for key, value in new_style.items():
|
||||
self.style_guidelines[key] = value
|
||||
|
||||
# Ensure brand name is preserved
|
||||
self.style_guidelines['brand_name'] = config.BRAND_NAME
|
||||
|
||||
# Save updated style
|
||||
with open(self.style_path, 'w') as f:
|
||||
json.dump(self.style_guidelines, f, indent=2)
|
||||
|
||||
logger.info("Updated brand style guidelines")
|
||||
return self.style_guidelines
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating style guidelines: {str(e)}")
|
||||
raise
|
||||
|
||||
def format_prompt_with_brand_style(self, user_prompt: str, content_type: Optional[str] = None) -> str:
|
||||
"""
|
||||
Format user prompt with brand style guidelines for the LLM.
|
||||
|
||||
Args:
|
||||
user_prompt: Original user prompt
|
||||
content_type: Type of content being generated
|
||||
|
||||
Returns:
|
||||
Formatted prompt with brand style instructions
|
||||
"""
|
||||
style = self.style_guidelines
|
||||
|
||||
# Create a formatted prompt with brand style instructions
|
||||
prompt_parts = [
|
||||
f"Generate marketing content for {style['brand_name']} based on the following request:",
|
||||
f"\"{user_prompt}\"",
|
||||
"\nFollow these brand style guidelines:",
|
||||
f"- Brand Name: {style['brand_name']}",
|
||||
f"- Tone: {', '.join(style.get('tone', []))}",
|
||||
f"- Voice Characteristics: {', '.join(style.get('voice_characteristics', []))}",
|
||||
]
|
||||
|
||||
# Add taboo words if any
|
||||
if 'taboo_words' in style and style['taboo_words']:
|
||||
prompt_parts.append(f"- Avoid these words: {', '.join(style['taboo_words'])}")
|
||||
|
||||
# Add preferred terms if any
|
||||
if 'preferred_terms' in style and style['preferred_terms']:
|
||||
terms = [f"use '{value}' instead of '{key}'" for key, value in style['preferred_terms'].items()]
|
||||
prompt_parts.append(f"- Preferred terminology: {'; '.join(terms)}")
|
||||
|
||||
# Add content type specific instructions
|
||||
if content_type:
|
||||
if content_type == "email_campaign":
|
||||
prompt_parts.append("- Format as a professional email with subject line, greeting, body, and signature")
|
||||
elif content_type == "social_media":
|
||||
prompt_parts.append("- Format as a concise social media post with appropriate hashtags")
|
||||
elif content_type == "blog_post":
|
||||
prompt_parts.append("- Format as a blog post with title, introduction, body with subheadings, and conclusion")
|
||||
elif content_type == "website_copy":
|
||||
prompt_parts.append("- Format as website copy with clear headings and concise paragraphs")
|
||||
elif content_type == "ad_copy":
|
||||
prompt_parts.append("- Format as advertising copy with headline, body, and clear call to action")
|
||||
|
||||
# Combine all parts
|
||||
formatted_prompt = "\n".join(prompt_parts)
|
||||
|
||||
logger.debug("Created formatted prompt with brand style")
|
||||
return formatted_prompt
|
||||
|
||||
def check_content_alignment(self, content: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Check if generated content aligns with brand style guidelines.
|
||||
|
||||
Args:
|
||||
content: Generated marketing content
|
||||
|
||||
Returns:
|
||||
Dictionary with alignment metrics and suggestions
|
||||
"""
|
||||
style = self.style_guidelines
|
||||
taboo_words = style.get('taboo_words', [])
|
||||
preferred_terms = style.get('preferred_terms', {})
|
||||
|
||||
# Check for taboo words
|
||||
found_taboo_words = []
|
||||
for word in taboo_words:
|
||||
if word.lower() in content.lower():
|
||||
found_taboo_words.append(word)
|
||||
|
||||
# Check for preferred terminology
|
||||
terminology_issues = []
|
||||
for avoid, use in preferred_terms.items():
|
||||
if avoid.lower() in content.lower():
|
||||
terminology_issues.append(f"Found '{avoid}', should use '{use}' instead")
|
||||
|
||||
# Calculate an overall alignment score (simple implementation)
|
||||
issues_count = len(found_taboo_words) + len(terminology_issues)
|
||||
alignment_score = max(0, 100 - (issues_count * 10)) # Reduce score for each issue
|
||||
|
||||
return {
|
||||
'alignment_score': alignment_score,
|
||||
'taboo_words_found': found_taboo_words,
|
||||
'terminology_issues': terminology_issues,
|
||||
'aligned': alignment_score >= 80 # Consider aligned if score is 80% or higher
|
||||
}
|
||||
|
||||
# Create a singleton instance
|
||||
brand_style_manager = BrandStyleManager()
|
||||
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
Configuration module for the Marketing Assistant AI.
|
||||
Handles environment variables and application settings.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
# Base paths
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent
|
||||
DATA_DIR = BASE_DIR / "data"
|
||||
|
||||
# Ensure data directories exist
|
||||
(DATA_DIR / "past_campaigns").mkdir(exist_ok=True)
|
||||
(DATA_DIR / "user_queries").mkdir(exist_ok=True)
|
||||
(DATA_DIR / "style_guidelines").mkdir(exist_ok=True)
|
||||
|
||||
# API configuration
|
||||
API_HOST = os.getenv("API_HOST", "localhost")
|
||||
API_PORT = int(os.getenv("API_PORT", 8000))
|
||||
|
||||
# LLM configuration
|
||||
LLM_MODEL = os.getenv("LLM_MODEL")
|
||||
LLM_API_KEY = os.getenv("LLM_API_KEY")
|
||||
|
||||
# Cohere configuration
|
||||
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
|
||||
|
||||
# Vector database configuration
|
||||
VECTOR_DB_PATH = os.getenv("VECTOR_DB_PATH", str(DATA_DIR / "vector_store"))
|
||||
|
||||
# Brand configuration
|
||||
BRAND_NAME = os.getenv("BRAND_NAME", "Adriana James")
|
||||
|
||||
# Content types
|
||||
CONTENT_TYPES = [
|
||||
"email_campaign",
|
||||
"social_media",
|
||||
"blog_post",
|
||||
"website_copy",
|
||||
"ad_copy",
|
||||
"funnel_page",
|
||||
"product_description",
|
||||
"press_release"
|
||||
]
|
||||
|
||||
# Tone options
|
||||
TONE_OPTIONS = [
|
||||
"professional",
|
||||
"friendly",
|
||||
"excited",
|
||||
"authoritative",
|
||||
"casual",
|
||||
"inspirational",
|
||||
"empathetic",
|
||||
"humorous"
|
||||
]
|
||||
|
||||
# Content length options
|
||||
LENGTH_OPTIONS = [
|
||||
"short", # < 100 words
|
||||
"medium", # 100-300 words
|
||||
"long", # > 300 words
|
||||
]
|
||||
|
||||
# Default brand style guidelines
|
||||
DEFAULT_BRAND_STYLE = {
|
||||
"brand_name": BRAND_NAME,
|
||||
"tone": ["professional", "friendly", "inspirational"],
|
||||
"voice_characteristics": ["clear", "direct", "empowering"],
|
||||
"taboo_words": ["cheap", "discount", "bargain"],
|
||||
"preferred_terms": {
|
||||
"customers": "clients",
|
||||
"products": "solutions"
|
||||
}
|
||||
}
|
||||
|
||||
# Logging configuration
|
||||
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
|
||||
LOG_FILE = os.getenv("LOG_FILE", str(BASE_DIR / "logs" / "app.log"))
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
(BASE_DIR / "logs").mkdir(exist_ok=True)
|
||||
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
Copywriter module for the Marketing Assistant AI.
|
||||
Core AI-powered content generation using a fine-tuned LLM.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import httpx
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from loguru import logger
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
import config
|
||||
from brand_style import brand_style_manager
|
||||
from vector_store import vector_store
|
||||
|
||||
class Copywriter:
|
||||
"""Generates marketing copy using a fine-tuned LLM."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Copywriter with Cohere LLM client."""
|
||||
self.model = "command" # Cohere's generation model
|
||||
self.api_key = config.COHERE_API_KEY
|
||||
logger.info("Copywriter initialized with Cohere API successfully")
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
|
||||
async def generate_copy(
|
||||
self,
|
||||
prompt: str,
|
||||
content_type: Optional[str] = None,
|
||||
tone: Optional[str] = None,
|
||||
length: Optional[str] = None,
|
||||
include_cta: bool = False,
|
||||
reference_similar_content: bool = True,
|
||||
max_tokens: int = 1000
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate marketing copy based on the user prompt and parameters.
|
||||
|
||||
Args:
|
||||
prompt: User prompt for content generation
|
||||
content_type: Type of content to generate
|
||||
tone: Desired tone of the content
|
||||
length: Desired length of the content
|
||||
include_cta: Whether to include a call to action
|
||||
reference_similar_content: Whether to fetch and reference similar content
|
||||
max_tokens: Maximum tokens for the generated response
|
||||
|
||||
Returns:
|
||||
Dictionary with generated content and metadata
|
||||
"""
|
||||
try:
|
||||
# Step 1: Format prompt with brand style guidelines
|
||||
branded_prompt = brand_style_manager.format_prompt_with_brand_style(prompt, content_type)
|
||||
|
||||
# Step 2: Find similar content for reference (if enabled)
|
||||
reference_content = []
|
||||
if reference_similar_content:
|
||||
search_results = await vector_store.search(prompt, top_k=3)
|
||||
if search_results:
|
||||
reference_content = [result['text'] for result in search_results]
|
||||
|
||||
# Step 3: Add additional instructions based on parameters
|
||||
full_prompt = branded_prompt
|
||||
|
||||
if tone:
|
||||
full_prompt += f"\n- Use a {tone} tone"
|
||||
|
||||
if length:
|
||||
length_instructions = {
|
||||
"short": "Keep the content brief and to the point (under 100 words).",
|
||||
"medium": "Write a moderate amount of content (100-300 words).",
|
||||
"long": "Create comprehensive content with depth (over 300 words)."
|
||||
}
|
||||
full_prompt += f"\n- {length_instructions.get(length, '')}"
|
||||
|
||||
if include_cta:
|
||||
full_prompt += "\n- Include a strong call to action at the end"
|
||||
|
||||
# Step 4: Add reference content if available
|
||||
if reference_content:
|
||||
full_prompt += "\n\nFor reference, here are some similar pieces of content that have performed well in the past:"
|
||||
for i, content in enumerate(reference_content, 1):
|
||||
# Truncate reference content if it's too long
|
||||
preview = content[:300] + "..." if len(content) > 300 else content
|
||||
full_prompt += f"\n\nReference {i}:\n{preview}"
|
||||
|
||||
full_prompt += "\n\nUse these references for inspiration, but create original content."
|
||||
|
||||
# Step 5: Generate content using the LLM
|
||||
generated_content = await self._call_llm_api(full_prompt, max_tokens)
|
||||
|
||||
# Step 6: Check content alignment with brand style
|
||||
alignment_check = brand_style_manager.check_content_alignment(generated_content)
|
||||
|
||||
# Step 7: Generate alternative headline suggestions
|
||||
headline_suggestions = await self._generate_headline_suggestions(prompt, generated_content)
|
||||
|
||||
# Step 8: Return the generated content with metadata
|
||||
result = {
|
||||
"content": generated_content,
|
||||
"suggestions": headline_suggestions,
|
||||
"metadata": {
|
||||
"content_type": content_type,
|
||||
"tone": tone,
|
||||
"alignment_score": alignment_check['alignment_score'],
|
||||
"generated_at": None # Will be added by the API
|
||||
}
|
||||
}
|
||||
|
||||
# Add alignment issues if any
|
||||
if alignment_check['taboo_words_found'] or alignment_check['terminology_issues']:
|
||||
result["alignment_issues"] = {
|
||||
"taboo_words_found": alignment_check['taboo_words_found'],
|
||||
"terminology_issues": alignment_check['terminology_issues']
|
||||
}
|
||||
|
||||
logger.info(f"Generated content with {len(generated_content)} characters")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating copy: {str(e)}")
|
||||
raise
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
|
||||
async def _call_llm_api(self, prompt: str, max_tokens: int = 1000) -> str:
|
||||
"""
|
||||
Call the Cohere API to generate content.
|
||||
|
||||
Args:
|
||||
prompt: The formatted prompt for the LLM
|
||||
max_tokens: Maximum tokens for the generated response
|
||||
|
||||
Returns:
|
||||
Generated content as a string
|
||||
"""
|
||||
try:
|
||||
# Use Cohere's generate API with the API key from config
|
||||
cohere_api_key = config.COHERE_API_KEY
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
"https://api.cohere.ai/v1/generate",
|
||||
headers={
|
||||
"Authorization": f"Bearer {cohere_api_key}",
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
json={
|
||||
"model": "command", # Cohere's generation model
|
||||
"prompt": prompt,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": 0.7,
|
||||
"k": 0,
|
||||
"p": 0.75
|
||||
},
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return result["generations"][0]["text"].strip()
|
||||
else:
|
||||
logger.error(f"Cohere API error: {response.status_code}, {response.text}")
|
||||
raise Exception(f"Cohere API error: {response.status_code}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling Cohere API: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _generate_headline_suggestions(self, original_prompt: str, generated_content: str) -> List[str]:
|
||||
"""
|
||||
Generate alternative headline suggestions based on the content.
|
||||
|
||||
Args:
|
||||
original_prompt: The original user prompt
|
||||
generated_content: The generated marketing content
|
||||
|
||||
Returns:
|
||||
List of headline suggestions
|
||||
"""
|
||||
try:
|
||||
# This would call the LLM to generate headlines
|
||||
# Simplified mock response for demonstration
|
||||
return [
|
||||
"Alternative Headline 1: Discover the Power of Adriana James' Solutions",
|
||||
"Alternative Headline 2: Transform Your Results with Adriana James",
|
||||
"Alternative Headline 3: The Adriana James Approach: Excellence Redefined"
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating headline suggestions: {str(e)}")
|
||||
return []
|
||||
|
||||
async def improve_copy(self, content: str, feedback: str) -> str:
|
||||
"""
|
||||
Improve content based on user feedback.
|
||||
|
||||
Args:
|
||||
content: Original generated content
|
||||
feedback: User feedback for improvement
|
||||
|
||||
Returns:
|
||||
Improved content
|
||||
"""
|
||||
try:
|
||||
# Format prompt for improvement
|
||||
improve_prompt = f"""
|
||||
Please improve the following marketing content based on the feedback provided:
|
||||
|
||||
ORIGINAL CONTENT:
|
||||
{content}
|
||||
|
||||
FEEDBACK:
|
||||
{feedback}
|
||||
|
||||
IMPROVED CONTENT:
|
||||
"""
|
||||
|
||||
# Call LLM to improve content
|
||||
improved_content = await self._call_llm_api(improve_prompt, max_tokens=1200)
|
||||
|
||||
logger.info(f"Improved content based on feedback")
|
||||
return improved_content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error improving content: {str(e)}")
|
||||
raise
|
||||
|
||||
async def analyze_content_performance(self, content: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Analyze marketing content for performance prediction.
|
||||
|
||||
Args:
|
||||
content: Marketing content to analyze
|
||||
|
||||
Returns:
|
||||
Dictionary with analysis results
|
||||
"""
|
||||
try:
|
||||
# This would be enhanced with actual ML models in production
|
||||
# Simplified mock response for demonstration
|
||||
|
||||
# Very basic analysis using length and keyword presence
|
||||
word_count = len(content.split())
|
||||
has_cta = any(phrase in content.lower() for phrase in ["call", "contact", "get started", "try", "buy", "sign up"])
|
||||
sentence_count = len([s for s in content.split(".") if s.strip()])
|
||||
avg_words_per_sentence = word_count / max(1, sentence_count)
|
||||
|
||||
# Simple scoring system
|
||||
readability_score = 100 - min(100, max(0, abs(avg_words_per_sentence - 15) * 5))
|
||||
cta_score = 90 if has_cta else 60
|
||||
length_score = min(100, max(0, word_count / 3))
|
||||
|
||||
overall_score = (readability_score + cta_score + length_score) / 3
|
||||
|
||||
return {
|
||||
"overall_score": round(overall_score, 1),
|
||||
"readability_score": round(readability_score, 1),
|
||||
"cta_effectiveness": round(cta_score, 1),
|
||||
"length_appropriateness": round(length_score, 1),
|
||||
"metrics": {
|
||||
"word_count": word_count,
|
||||
"sentence_count": sentence_count,
|
||||
"avg_words_per_sentence": round(avg_words_per_sentence, 1),
|
||||
"has_cta": has_cta
|
||||
},
|
||||
"improvement_suggestions": [
|
||||
"Consider adding a stronger call to action" if cta_score < 80 else "Your call to action is effective",
|
||||
"Try to use shorter sentences for better readability" if avg_words_per_sentence > 20 else "Your sentence length is good for readability",
|
||||
"Consider adding more content for better engagement" if word_count < 100 else "Your content length is appropriate"
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing content: {str(e)}")
|
||||
raise
|
||||
|
||||
# Create a singleton instance
|
||||
copywriter = Copywriter()
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
Embeddings module for the Marketing Assistant AI.
|
||||
Uses Cohere to generate and manage text embeddings.
|
||||
"""
|
||||
|
||||
import cohere
|
||||
from typing import List, Dict, Any, Optional
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
import config
|
||||
|
||||
class EmbeddingsManager:
|
||||
"""Manages the generation and manipulation of text embeddings using Cohere."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the EmbeddingsManager with Cohere API client."""
|
||||
try:
|
||||
self.co = cohere.Client(config.COHERE_API_KEY)
|
||||
logger.info("EmbeddingsManager initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize EmbeddingsManager: {str(e)}")
|
||||
raise
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
|
||||
async def get_embeddings(self, texts: List[str], model: str = "embed-english-v3.0") -> np.ndarray:
|
||||
"""
|
||||
Generate embeddings for a list of texts.
|
||||
|
||||
Args:
|
||||
texts: List of text strings to embed
|
||||
model: Cohere embedding model to use
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: Array of embeddings vectors
|
||||
"""
|
||||
try:
|
||||
if not texts:
|
||||
logger.warning("Empty text list provided for embedding")
|
||||
return np.array([])
|
||||
|
||||
# Ensure texts are not too long for the API
|
||||
processed_texts = [text[:8192] for text in texts]
|
||||
|
||||
response = self.co.embed(
|
||||
texts=processed_texts,
|
||||
model=model,
|
||||
input_type="search_document"
|
||||
)
|
||||
|
||||
embeddings = np.array(response.embeddings)
|
||||
logger.debug(f"Generated {len(embeddings)} embeddings with shape {embeddings.shape}")
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating embeddings: {str(e)}")
|
||||
raise
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
|
||||
async def get_query_embedding(self, text: str, model: str = "embed-english-v3.0") -> np.ndarray:
|
||||
"""
|
||||
Generate embedding for a single query text.
|
||||
|
||||
Args:
|
||||
text: The query text to embed
|
||||
model: Cohere embedding model to use
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: Embedding vector for the query
|
||||
"""
|
||||
try:
|
||||
response = self.co.embed(
|
||||
texts=[text[:8192]],
|
||||
model=model,
|
||||
input_type="search_query"
|
||||
)
|
||||
|
||||
embedding = np.array(response.embeddings[0])
|
||||
return embedding
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating query embedding: {str(e)}")
|
||||
raise
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
|
||||
async def rerank_results(
|
||||
self,
|
||||
query: str,
|
||||
documents: List[str],
|
||||
model: str = "rerank-english-v2.0",
|
||||
top_n: int = 5
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Rerank documents based on relevance to the query.
|
||||
|
||||
Args:
|
||||
query: The search query
|
||||
documents: List of documents to rerank
|
||||
model: Cohere reranking model to use
|
||||
top_n: Number of top results to return
|
||||
|
||||
Returns:
|
||||
List of dictionaries with document index and relevance score
|
||||
"""
|
||||
try:
|
||||
if not documents:
|
||||
logger.warning("Empty document list provided for reranking")
|
||||
return []
|
||||
|
||||
# Truncate documents if they're too long
|
||||
processed_docs = [doc[:8192] for doc in documents]
|
||||
|
||||
response = self.co.rerank(
|
||||
query=query,
|
||||
documents=processed_docs,
|
||||
model=model,
|
||||
top_n=min(top_n, len(processed_docs))
|
||||
)
|
||||
|
||||
results = [
|
||||
{
|
||||
"index": result.index,
|
||||
"document": documents[result.index],
|
||||
"relevance_score": result.relevance_score
|
||||
}
|
||||
for result in response.results
|
||||
]
|
||||
|
||||
logger.debug(f"Reranked {len(documents)} documents, returning top {len(results)}")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error reranking documents: {str(e)}")
|
||||
raise
|
||||
|
||||
# Create a singleton instance
|
||||
embeddings_manager = EmbeddingsManager()
|
||||
+406
@@ -0,0 +1,406 @@
|
||||
"""
|
||||
Main FastAPI application for the Marketing Assistant AI.
|
||||
Provides API endpoints for generating and managing marketing content.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from fastapi import FastAPI, HTTPException, Depends, Query, Body, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import config
|
||||
from copywriter import copywriter
|
||||
from vector_store import vector_store
|
||||
from brand_style import brand_style_manager
|
||||
from embeddings import embeddings_manager
|
||||
|
||||
# Initialize logging
|
||||
logger.add(config.LOG_FILE, level=config.LOG_LEVEL, rotation="10 MB", retention="1 month")
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title="Marketing Assistant AI",
|
||||
description="AI-powered tool for marketing copywriting with Adriana James' brand voice",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # In production, specify your frontend domain
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Define request and response models
|
||||
class GenerateCopyRequest(BaseModel):
|
||||
prompt: str = Field(..., description="The main instruction for generating content")
|
||||
content_type: Optional[str] = Field(None, description="Type of content to generate")
|
||||
tone: Optional[str] = Field(None, description="Desired tone of the content")
|
||||
length: Optional[str] = Field(None, description="Desired length of the content")
|
||||
include_cta: Optional[bool] = Field(False, description="Whether to include a call to action")
|
||||
reference_similar_content: Optional[bool] = Field(True, description="Whether to reference similar content")
|
||||
max_tokens: Optional[int] = Field(1000, description="Maximum tokens for the generated response")
|
||||
|
||||
class TrainingDataRequest(BaseModel):
|
||||
content_type: str = Field(..., description="Type of content")
|
||||
content: str = Field(..., description="The marketing content")
|
||||
metadata: Optional[Dict[str, Any]] = Field({}, description="Additional metadata about the content")
|
||||
|
||||
class BrandStyleUpdateRequest(BaseModel):
|
||||
tone: Optional[List[str]] = Field(None, description="Brand tone options")
|
||||
voice_characteristics: Optional[List[str]] = Field(None, description="Voice characteristics")
|
||||
taboo_words: Optional[List[str]] = Field(None, description="Words to avoid")
|
||||
preferred_terms: Optional[Dict[str, str]] = Field(None, description="Preferred terminology")
|
||||
|
||||
class ContentImprovementRequest(BaseModel):
|
||||
content: str = Field(..., description="Original generated content")
|
||||
feedback: str = Field(..., description="User feedback for improvement")
|
||||
|
||||
# API Routes
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint with API information."""
|
||||
return {
|
||||
"name": "Marketing Assistant AI",
|
||||
"version": "1.0.0",
|
||||
"description": f"AI-powered marketing copywriter for {config.BRAND_NAME}"
|
||||
}
|
||||
|
||||
@app.post("/generate-copy")
|
||||
async def generate_copy(request: GenerateCopyRequest):
|
||||
"""Generate marketing copy based on the provided prompt and parameters."""
|
||||
try:
|
||||
# Validate content type if provided
|
||||
if request.content_type and request.content_type not in config.CONTENT_TYPES:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={
|
||||
"status": "error",
|
||||
"message": f"Invalid content_type. Must be one of: {', '.join(config.CONTENT_TYPES)}"
|
||||
}
|
||||
)
|
||||
|
||||
# Validate tone if provided
|
||||
if request.tone and request.tone not in config.TONE_OPTIONS:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={
|
||||
"status": "error",
|
||||
"message": f"Invalid tone. Must be one of: {', '.join(config.TONE_OPTIONS)}"
|
||||
}
|
||||
)
|
||||
|
||||
# Validate length if provided
|
||||
if request.length and request.length not in config.LENGTH_OPTIONS:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={
|
||||
"status": "error",
|
||||
"message": f"Invalid length. Must be one of: {', '.join(config.LENGTH_OPTIONS)}"
|
||||
}
|
||||
)
|
||||
|
||||
# Generate copy
|
||||
result = await copywriter.generate_copy(
|
||||
prompt=request.prompt,
|
||||
content_type=request.content_type,
|
||||
tone=request.tone,
|
||||
length=request.length,
|
||||
include_cta=request.include_cta,
|
||||
reference_similar_content=request.reference_similar_content,
|
||||
max_tokens=request.max_tokens
|
||||
)
|
||||
|
||||
# Add timestamp
|
||||
result["metadata"]["generated_at"] = datetime.now().isoformat()
|
||||
|
||||
# Store the generated content in the vector store for future reference
|
||||
if result["content"]:
|
||||
metadata = {
|
||||
"content_type": request.content_type,
|
||||
"tone": request.tone,
|
||||
"prompt": request.prompt,
|
||||
"generated": True
|
||||
}
|
||||
await vector_store.add_documents([result["content"]], [metadata])
|
||||
|
||||
# Store the user query for future training
|
||||
query_path = Path(config.DATA_DIR) / "user_queries" / f"{datetime.now().strftime('%Y%m%d%H%M%S')}.json"
|
||||
with open(query_path, 'w') as f:
|
||||
json.dump({
|
||||
"prompt": request.prompt,
|
||||
"parameters": {
|
||||
"content_type": request.content_type,
|
||||
"tone": request.tone,
|
||||
"length": request.length,
|
||||
"include_cta": request.include_cta
|
||||
},
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}, f, indent=2)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"content": result["content"],
|
||||
"suggestions": result.get("suggestions", []),
|
||||
"metadata": result["metadata"]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating copy: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to generate copy: {str(e)}"
|
||||
)
|
||||
|
||||
@app.get("/brand-style")
|
||||
async def get_brand_style():
|
||||
"""Get the current brand style guidelines."""
|
||||
try:
|
||||
style = brand_style_manager.get_style_guidelines()
|
||||
return style
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting brand style: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get brand style: {str(e)}"
|
||||
)
|
||||
|
||||
@app.put("/brand-style")
|
||||
async def update_brand_style(request: BrandStyleUpdateRequest):
|
||||
"""Update the brand style guidelines."""
|
||||
try:
|
||||
update_data = request.dict(exclude_unset=True)
|
||||
updated_style = brand_style_manager.update_style_guidelines(update_data)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "Brand style updated successfully",
|
||||
"style": updated_style
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating brand style: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to update brand style: {str(e)}"
|
||||
)
|
||||
|
||||
@app.post("/training-data")
|
||||
async def add_training_data(request: TrainingDataRequest):
|
||||
"""Add new marketing content for AI training."""
|
||||
try:
|
||||
# Validate content type
|
||||
if request.content_type not in config.CONTENT_TYPES:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={
|
||||
"status": "error",
|
||||
"message": f"Invalid content_type. Must be one of: {', '.join(config.CONTENT_TYPES)}"
|
||||
}
|
||||
)
|
||||
|
||||
# Add metadata
|
||||
metadata = request.metadata.copy()
|
||||
metadata["content_type"] = request.content_type
|
||||
metadata["added_at"] = datetime.now().isoformat()
|
||||
metadata["training_data"] = True
|
||||
|
||||
# Add to vector store
|
||||
doc_ids = await vector_store.add_documents([request.content], [metadata])
|
||||
|
||||
# Save to past campaigns
|
||||
campaign_path = Path(config.DATA_DIR) / "past_campaigns" / f"{datetime.now().strftime('%Y%m%d%H%M%S')}.json"
|
||||
with open(campaign_path, 'w') as f:
|
||||
json.dump({
|
||||
"content": request.content,
|
||||
"content_type": request.content_type,
|
||||
"metadata": metadata,
|
||||
"document_id": doc_ids[0] if doc_ids else None,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}, f, indent=2)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "Training data added successfully",
|
||||
"data_id": doc_ids[0] if doc_ids else None
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding training data: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to add training data: {str(e)}"
|
||||
)
|
||||
|
||||
@app.get("/training-data")
|
||||
async def list_training_data(
|
||||
content_type: Optional[str] = Query(None, description="Filter by content type"),
|
||||
page: int = Query(1, ge=1, description="Page number"),
|
||||
limit: int = Query(10, ge=1, le=100, description="Items per page")
|
||||
):
|
||||
"""Retrieve a list of available training data."""
|
||||
try:
|
||||
# Build filters
|
||||
filters = {}
|
||||
if content_type:
|
||||
if content_type not in config.CONTENT_TYPES:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={
|
||||
"status": "error",
|
||||
"message": f"Invalid content_type. Must be one of: {', '.join(config.CONTENT_TYPES)}"
|
||||
}
|
||||
)
|
||||
filters["content_type"] = content_type
|
||||
|
||||
filters["training_data"] = True
|
||||
|
||||
# Fetch all matching documents first (not efficient for large datasets but works for demo)
|
||||
all_docs = []
|
||||
for i in range(len(vector_store.metadata)):
|
||||
doc = await vector_store.get_document(i)
|
||||
if doc and all(doc["metadata"].get(k) == v for k, v in filters.items()):
|
||||
all_docs.append(doc)
|
||||
|
||||
# Sort by timestamp (newest first)
|
||||
all_docs.sort(key=lambda x: x["metadata"].get("added_at", ""), reverse=True)
|
||||
|
||||
# Paginate
|
||||
total = len(all_docs)
|
||||
pages = (total + limit - 1) // limit if total > 0 else 1
|
||||
start = (page - 1) * limit
|
||||
end = start + limit
|
||||
paginated_docs = all_docs[start:end]
|
||||
|
||||
# Format the response
|
||||
items = []
|
||||
for doc in paginated_docs:
|
||||
# Get a preview of the text (first 100 characters)
|
||||
preview = doc["text"][:100] + "..." if len(doc["text"]) > 100 else doc["text"]
|
||||
|
||||
items.append({
|
||||
"id": doc["document_id"],
|
||||
"content_type": doc["metadata"].get("content_type", "unknown"),
|
||||
"preview": preview,
|
||||
"added_at": doc["metadata"].get("added_at", "")
|
||||
})
|
||||
|
||||
return {
|
||||
"items": items,
|
||||
"pagination": {
|
||||
"total": total,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"pages": pages
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing training data: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to list training data: {str(e)}"
|
||||
)
|
||||
|
||||
@app.get("/training-data/{document_id}")
|
||||
async def get_training_data(document_id: int):
|
||||
"""Retrieve a specific training document by ID."""
|
||||
try:
|
||||
doc = await vector_store.get_document(document_id)
|
||||
if not doc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Document with ID {document_id} not found"
|
||||
)
|
||||
|
||||
return {
|
||||
"id": doc["document_id"],
|
||||
"content": doc["text"],
|
||||
"metadata": doc["metadata"]
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving training data: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve training data: {str(e)}"
|
||||
)
|
||||
|
||||
@app.delete("/training-data/{document_id}")
|
||||
async def delete_training_data(document_id: int):
|
||||
"""Delete a specific training document by ID."""
|
||||
try:
|
||||
success = await vector_store.delete_document(document_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Document with ID {document_id} not found or could not be deleted"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Document with ID {document_id} successfully deleted"
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting training data: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to delete training data: {str(e)}"
|
||||
)
|
||||
|
||||
@app.post("/improve-content")
|
||||
async def improve_content(request: ContentImprovementRequest):
|
||||
"""Improve content based on user feedback."""
|
||||
try:
|
||||
improved_content = await copywriter.improve_copy(
|
||||
content=request.content,
|
||||
feedback=request.feedback
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"original_content": request.content,
|
||||
"improved_content": improved_content,
|
||||
"feedback": request.feedback
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error improving content: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to improve content: {str(e)}"
|
||||
)
|
||||
|
||||
@app.post("/analyze-content")
|
||||
async def analyze_content(content: str = Body(..., embed=True)):
|
||||
"""Analyze marketing content for performance prediction."""
|
||||
try:
|
||||
analysis = await copywriter.analyze_content_performance(content)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"analysis": analysis
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing content: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to analyze content: {str(e)}"
|
||||
)
|
||||
|
||||
# Run the application
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(
|
||||
"main:app",
|
||||
host=config.API_HOST,
|
||||
port=config.API_PORT,
|
||||
reload=True
|
||||
)
|
||||
@@ -0,0 +1,15 @@
|
||||
fastapi
|
||||
uvicorn
|
||||
pydantic
|
||||
python-dotenv
|
||||
httpx
|
||||
faiss-cpu
|
||||
numpy==1.26.2
|
||||
pandas
|
||||
cohere
|
||||
python-multipart
|
||||
SQLAlchemy
|
||||
databases
|
||||
aiosqlite
|
||||
loguru
|
||||
tenacity
|
||||
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
Vector store module for the Marketing Assistant AI.
|
||||
Uses FAISS for efficient storage and retrieval of content embeddings.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import pickle
|
||||
import faiss
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from datetime import datetime
|
||||
|
||||
import config
|
||||
from embeddings import embeddings_manager
|
||||
|
||||
class VectorStore:
|
||||
"""Manages vector database operations for content retrieval."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the VectorStore with FAISS index."""
|
||||
self.store_path = Path(config.VECTOR_DB_PATH)
|
||||
self.store_path.mkdir(exist_ok=True)
|
||||
|
||||
self.index_path = self.store_path / "faiss_index.bin"
|
||||
self.metadata_path = self.store_path / "metadata.pkl"
|
||||
|
||||
self.dimension = None
|
||||
self.index = None
|
||||
self.metadata = []
|
||||
|
||||
self._load_or_create_index()
|
||||
logger.info("VectorStore initialized successfully")
|
||||
|
||||
def _load_or_create_index(self) -> None:
|
||||
"""Load existing index or create new one if it doesn't exist."""
|
||||
try:
|
||||
if self.index_path.exists() and self.metadata_path.exists():
|
||||
# Load existing index and metadata
|
||||
self.index = faiss.read_index(str(self.index_path))
|
||||
with open(self.metadata_path, 'rb') as f:
|
||||
self.metadata = pickle.load(f)
|
||||
self.dimension = self.index.d
|
||||
logger.info(f"Loaded existing vector index with {self.index.ntotal} vectors")
|
||||
else:
|
||||
# Default dimension for Cohere embeddings
|
||||
self.dimension = 1024
|
||||
self.index = faiss.IndexFlatL2(self.dimension)
|
||||
self.metadata = []
|
||||
logger.info(f"Created new vector index with dimension {self.dimension}")
|
||||
|
||||
# Save the empty index and metadata
|
||||
self._save_index()
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading or creating index: {str(e)}")
|
||||
raise
|
||||
|
||||
def _save_index(self) -> None:
|
||||
"""Save the index and metadata to disk."""
|
||||
try:
|
||||
faiss.write_index(self.index, str(self.index_path))
|
||||
with open(self.metadata_path, 'wb') as f:
|
||||
pickle.dump(self.metadata, f)
|
||||
logger.debug("Saved vector index and metadata")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving index: {str(e)}")
|
||||
raise
|
||||
|
||||
async def add_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
metadata_list: Optional[List[Dict[str, Any]]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Add documents to the vector store.
|
||||
|
||||
Args:
|
||||
texts: List of text documents to add
|
||||
metadata_list: List of metadata dictionaries for each document
|
||||
|
||||
Returns:
|
||||
List of document IDs (vector indices)
|
||||
"""
|
||||
try:
|
||||
if not texts:
|
||||
logger.warning("No texts provided to add to vector store")
|
||||
return []
|
||||
|
||||
if metadata_list is None:
|
||||
metadata_list = [{} for _ in texts]
|
||||
|
||||
if len(texts) != len(metadata_list):
|
||||
raise ValueError("Number of texts and metadata entries must match")
|
||||
|
||||
# Generate embeddings
|
||||
embeddings = await embeddings_manager.get_embeddings(texts)
|
||||
|
||||
# Check if embeddings match our dimension
|
||||
if embeddings.shape[1] != self.dimension:
|
||||
logger.warning(f"Embedding dimension mismatch: expected {self.dimension}, got {embeddings.shape[1]}")
|
||||
# If we have no documents yet, we can adapt to the new dimension
|
||||
if self.index.ntotal == 0:
|
||||
self.dimension = embeddings.shape[1]
|
||||
self.index = faiss.IndexFlatL2(self.dimension)
|
||||
logger.info(f"Adapted to new dimension: {self.dimension}")
|
||||
else:
|
||||
raise ValueError(f"Embedding dimension mismatch: expected {self.dimension}, got {embeddings.shape[1]}")
|
||||
|
||||
# Add timestamp to metadata
|
||||
timestamp = datetime.now().isoformat()
|
||||
for meta in metadata_list:
|
||||
meta['timestamp'] = timestamp
|
||||
meta['document_id'] = len(self.metadata) + len(metadata_list)
|
||||
|
||||
# Store texts in metadata
|
||||
for i, (text, meta) in enumerate(zip(texts, metadata_list)):
|
||||
meta['text'] = text
|
||||
|
||||
# Add vectors to index
|
||||
start_idx = self.index.ntotal
|
||||
self.index.add(embeddings.astype(np.float32))
|
||||
self.metadata.extend(metadata_list)
|
||||
|
||||
# Save updated index
|
||||
self._save_index()
|
||||
|
||||
# Return document IDs
|
||||
doc_ids = list(range(start_idx, start_idx + len(texts)))
|
||||
logger.info(f"Added {len(texts)} documents to vector store")
|
||||
return doc_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding documents to vector store: {str(e)}")
|
||||
raise
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
rerank: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search for similar documents.
|
||||
|
||||
Args:
|
||||
query: The search query
|
||||
top_k: Number of results to return
|
||||
filters: Dictionary of metadata filters
|
||||
rerank: Whether to use Cohere's reranking
|
||||
|
||||
Returns:
|
||||
List of result dictionaries with document content and metadata
|
||||
"""
|
||||
try:
|
||||
if self.index.ntotal == 0:
|
||||
logger.warning("Empty vector store, no results to return")
|
||||
return []
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = await embeddings_manager.get_query_embedding(query)
|
||||
query_embedding = query_embedding.reshape(1, -1).astype(np.float32)
|
||||
|
||||
# First pass: find more candidates than needed for reranking
|
||||
search_k = top_k * 3 if rerank else top_k
|
||||
search_k = min(search_k, self.index.ntotal) # Don't request more than we have
|
||||
|
||||
distances, indices = self.index.search(query_embedding, search_k)
|
||||
|
||||
# Get metadata and texts for matching indices
|
||||
results = []
|
||||
for i, idx in enumerate(indices[0]):
|
||||
if idx < 0 or idx >= len(self.metadata):
|
||||
continue # Skip invalid indices
|
||||
|
||||
metadata = self.metadata[idx]
|
||||
text = metadata.get('text', '')
|
||||
|
||||
# Apply filters if any
|
||||
if filters and not self._matches_filters(metadata, filters):
|
||||
continue
|
||||
|
||||
results.append({
|
||||
'document_id': idx,
|
||||
'text': text,
|
||||
'metadata': {k: v for k, v in metadata.items() if k != 'text'},
|
||||
'distance': float(distances[0][i])
|
||||
})
|
||||
|
||||
# Apply reranking if requested
|
||||
if rerank and results:
|
||||
texts = [r['text'] for r in results]
|
||||
reranked = await embeddings_manager.rerank_results(query, texts, top_n=top_k)
|
||||
|
||||
# Map reranked results back to our original results
|
||||
reranked_results = []
|
||||
for item in reranked:
|
||||
orig_idx = item['index']
|
||||
if 0 <= orig_idx < len(results):
|
||||
reranked_results.append({
|
||||
**results[orig_idx],
|
||||
'relevance_score': item['relevance_score']
|
||||
})
|
||||
|
||||
results = reranked_results
|
||||
else:
|
||||
# Just take the top_k results
|
||||
results = results[:top_k]
|
||||
|
||||
logger.info(f"Found {len(results)} matching documents for query")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching vector store: {str(e)}")
|
||||
raise
|
||||
|
||||
def _matches_filters(self, metadata: Dict[str, Any], filters: Dict[str, Any]) -> bool:
|
||||
"""Check if metadata matches the specified filters."""
|
||||
for key, value in filters.items():
|
||||
if key not in metadata:
|
||||
return False
|
||||
|
||||
if isinstance(value, list):
|
||||
# Check if metadata value is in the list
|
||||
if metadata[key] not in value:
|
||||
return False
|
||||
elif metadata[key] != value:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def delete_document(self, document_id: int) -> bool:
|
||||
"""
|
||||
Delete a document from the vector store.
|
||||
|
||||
Args:
|
||||
document_id: ID of the document to delete
|
||||
|
||||
Returns:
|
||||
Boolean indicating success
|
||||
"""
|
||||
try:
|
||||
if document_id < 0 or document_id >= len(self.metadata):
|
||||
logger.warning(f"Invalid document ID: {document_id}")
|
||||
return False
|
||||
|
||||
# FAISS doesn't support direct deletion, so we need to rebuild the index
|
||||
# Mark the document as deleted in metadata
|
||||
self.metadata[document_id]['deleted'] = True
|
||||
|
||||
# Save updated metadata
|
||||
self._save_index()
|
||||
|
||||
logger.info(f"Marked document {document_id} as deleted")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting document: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_document(self, document_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve a document by ID.
|
||||
|
||||
Args:
|
||||
document_id: ID of the document to retrieve
|
||||
|
||||
Returns:
|
||||
Document with metadata or None if not found
|
||||
"""
|
||||
try:
|
||||
if document_id < 0 or document_id >= len(self.metadata):
|
||||
logger.warning(f"Invalid document ID: {document_id}")
|
||||
return None
|
||||
|
||||
metadata = self.metadata[document_id]
|
||||
|
||||
# Check if document is marked as deleted
|
||||
if metadata.get('deleted', False):
|
||||
logger.warning(f"Document {document_id} is marked as deleted")
|
||||
return None
|
||||
|
||||
text = metadata.get('text', '')
|
||||
|
||||
return {
|
||||
'document_id': document_id,
|
||||
'text': text,
|
||||
'metadata': {k: v for k, v in metadata.items() if k != 'text' and k != 'deleted'}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving document: {str(e)}")
|
||||
raise
|
||||
|
||||
async def update_document(self, document_id: int, text: str, metadata: Optional[Dict[str, Any]] = None) -> bool:
|
||||
"""
|
||||
Update a document in the vector store.
|
||||
|
||||
Args:
|
||||
document_id: ID of the document to update
|
||||
text: New document text
|
||||
metadata: New metadata (will be merged with existing)
|
||||
|
||||
Returns:
|
||||
Boolean indicating success
|
||||
"""
|
||||
try:
|
||||
if document_id < 0 or document_id >= len(self.metadata):
|
||||
logger.warning(f"Invalid document ID: {document_id}")
|
||||
return False
|
||||
|
||||
# Get existing metadata
|
||||
existing_metadata = self.metadata[document_id]
|
||||
|
||||
# Check if document is marked as deleted
|
||||
if existing_metadata.get('deleted', False):
|
||||
logger.warning(f"Cannot update deleted document {document_id}")
|
||||
return False
|
||||
|
||||
# Generate new embedding
|
||||
embeddings = await embeddings_manager.get_embeddings([text])
|
||||
|
||||
# Update the vector in the index
|
||||
faiss.IndexFlatL2_update_vectors(self.index, embeddings.astype(np.float32), np.array([document_id], dtype=np.int64))
|
||||
|
||||
# Update metadata
|
||||
if metadata:
|
||||
for key, value in metadata.items():
|
||||
existing_metadata[key] = value
|
||||
|
||||
existing_metadata['text'] = text
|
||||
existing_metadata['updated_at'] = datetime.now().isoformat()
|
||||
|
||||
# Save updated index
|
||||
self._save_index()
|
||||
|
||||
logger.info(f"Updated document {document_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating document: {str(e)}")
|
||||
raise
|
||||
|
||||
# Create a singleton instance
|
||||
vector_store = VectorStore()
|
||||
Reference in New Issue
Block a user