feat(feedback): Add content improvement feedback system
Frontend (frontend/app.js):
- Add textarea for improvement feedback
- Add submit button with loading state
- Handle API response and display improved content

Backend (backend/copywriter.py):
- Add improve_copy() method using Cohere API
- Integrate retry mechanism for API calls

Backend (backend/main.py):
- Add /improve-content POST endpoint
- Implement error handling and return improved content with metadata

Testing:
- Verified feedback submission flow
- Confirmed improved content generation
- Tested error scenarios and loading states
This commit is contained in:
+39
-35
@@ -13,7 +13,7 @@ import config
|
||||
|
||||
class BrandStyleManager:
|
||||
"""Manages brand style guidelines and ensures content consistency."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the BrandStyleManager with default or stored style guidelines."""
|
||||
self.style_path = Path(config.DATA_DIR) / "style_guidelines" / "brand_style.json"
|
||||
@@ -117,7 +117,7 @@ class BrandStyleManager:
|
||||
"""
|
||||
}
|
||||
logger.info("BrandStyleManager initialized successfully")
|
||||
|
||||
|
||||
def _load_or_create_style(self) -> Dict[str, Any]:
|
||||
"""Load existing style guidelines or create new ones with defaults."""
|
||||
try:
|
||||
@@ -129,37 +129,37 @@ class BrandStyleManager:
|
||||
else:
|
||||
# Create directory if it doesn't exist
|
||||
self.style_path.parent.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
# Use default style guidelines
|
||||
style = config.DEFAULT_BRAND_STYLE
|
||||
|
||||
|
||||
# Save default style
|
||||
with open(self.style_path, 'w') as f:
|
||||
json.dump(style, f, indent=2)
|
||||
|
||||
|
||||
logger.info("Created default brand style guidelines")
|
||||
return style
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading or creating style guidelines: {str(e)}")
|
||||
# Fall back to default style
|
||||
return config.DEFAULT_BRAND_STYLE
|
||||
|
||||
|
||||
def get_style_guidelines(self) -> Dict[str, Any]:
    """
    Return the brand style guidelines currently held by this manager.

    Returns:
        The ``style_guidelines`` dictionary stored on the instance. The same
        object is returned, not a copy.
    """
    current_guidelines = self.style_guidelines
    return current_guidelines
|
||||
|
||||
|
||||
def update_style_guidelines(self, new_style: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Update brand style guidelines.
|
||||
|
||||
|
||||
Args:
|
||||
new_style: Dictionary with new style guidelines
|
||||
|
||||
|
||||
Returns:
|
||||
Updated style guidelines dictionary
|
||||
"""
|
||||
@@ -167,37 +167,41 @@ class BrandStyleManager:
|
||||
# Merge new style with existing
|
||||
for key, value in new_style.items():
|
||||
self.style_guidelines[key] = value
|
||||
|
||||
|
||||
# Ensure brand name is preserved
|
||||
self.style_guidelines['brand_name'] = config.BRAND_NAME
|
||||
|
||||
|
||||
# Save updated style
|
||||
with open(self.style_path, 'w') as f:
|
||||
json.dump(self.style_guidelines, f, indent=2)
|
||||
|
||||
|
||||
logger.info("Updated brand style guidelines")
|
||||
return self.style_guidelines
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating style guidelines: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def format_prompt_with_brand_style(self, user_prompt: str, content_type: Optional[str] = None) -> str:
|
||||
"""Format user prompt to match the established writing style."""
|
||||
|
||||
"""Format user prompt to match the distinctive communication style."""
|
||||
|
||||
style_instructions = [
|
||||
"Follow these writing style guidelines:",
|
||||
"- Use direct commands that empower the reader",
|
||||
"- Address the reader directly using 'you' and 'your'",
|
||||
"- Create rhythmic, repetitive patterns in key messages",
|
||||
"- Maintain a clear, confident, and authoritative tone",
|
||||
"- Use simple, practical language without jargon",
|
||||
"- Acknowledge challenges while focusing on solutions",
|
||||
"- Include empowering phrases that emphasize reader's control and choice"
|
||||
"Follow these distinctive communication style guidelines:",
|
||||
"- Use empowering, assertive language that inspires action",
|
||||
"- Address the reader directly using 'you' and 'your' with conviction",
|
||||
"- Create rhythmic, repetitive patterns in key messages for emphasis",
|
||||
"- Maintain a clear, confident, and conversational teaching tone",
|
||||
"- Use simple, practical language that communicates profound ideas",
|
||||
"- Use embedded commands (e.g., 'Decide now to change your thinking')",
|
||||
"- Include cause-effect statements (e.g., 'Because you understand this, you will now take action')",
|
||||
"- Speak with conviction and clarity rather than hesitation",
|
||||
"- Replace tentative phrases with confident declarations",
|
||||
"- Use a motivational coach-like clarity in all communications",
|
||||
"- IMPORTANT: Do not mention any specific person's name in the content"
|
||||
]
|
||||
|
||||
# Content type specific formatting
|
||||
content_format = self._get_content_format(content_type) if content_type else ""
|
||||
|
||||
|
||||
return "\n".join([
|
||||
f"Generate content based on this request:",
|
||||
f"\"{user_prompt}\"",
|
||||
@@ -205,37 +209,37 @@ class BrandStyleManager:
|
||||
"\n".join(style_instructions),
|
||||
content_format
|
||||
])
|
||||
|
||||
|
||||
def check_content_alignment(self, content: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Check if generated content aligns with brand style guidelines.
|
||||
|
||||
|
||||
Args:
|
||||
content: Generated marketing content
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary with alignment metrics and suggestions
|
||||
"""
|
||||
style = self.style_guidelines
|
||||
taboo_words = style.get('taboo_words', [])
|
||||
preferred_terms = style.get('preferred_terms', {})
|
||||
|
||||
|
||||
# Check for taboo words
|
||||
found_taboo_words = []
|
||||
for word in taboo_words:
|
||||
if word.lower() in content.lower():
|
||||
found_taboo_words.append(word)
|
||||
|
||||
|
||||
# Check for preferred terminology
|
||||
terminology_issues = []
|
||||
for avoid, use in preferred_terms.items():
|
||||
if avoid.lower() in content.lower():
|
||||
terminology_issues.append(f"Found '{avoid}', should use '{use}' instead")
|
||||
|
||||
|
||||
# Calculate an overall alignment score (simple implementation)
|
||||
issues_count = len(found_taboo_words) + len(terminology_issues)
|
||||
alignment_score = max(0, 100 - (issues_count * 10)) # Reduce score for each issue
|
||||
|
||||
|
||||
return {
|
||||
'alignment_score': alignment_score,
|
||||
'taboo_words_found': found_taboo_words,
|
||||
@@ -246,16 +250,16 @@ class BrandStyleManager:
|
||||
def _get_content_format(self, content_type: str) -> str:
|
||||
"""
|
||||
Get formatting instructions for specific content type.
|
||||
|
||||
|
||||
Args:
|
||||
content_type: Type of content to generate
|
||||
|
||||
|
||||
Returns:
|
||||
Formatting instructions as string
|
||||
"""
|
||||
if not content_type:
|
||||
return ""
|
||||
|
||||
|
||||
format_instructions = self.content_formats.get(content_type, "")
|
||||
if format_instructions:
|
||||
return f"\nContent type specific instructions:\n{format_instructions.strip()}"
|
||||
|
||||
+13
-10
@@ -52,12 +52,12 @@ CONTENT_TYPES = [
|
||||
"newsletter"
|
||||
]
|
||||
|
||||
# Tone options - simplified to match the core style
|
||||
# Tone options - specifically matching Adriana James' communication style
|
||||
TONE_OPTIONS = [
    # Each option is listed once; every entry ends with a comma so adjacent
    # string literals can never be silently concatenated.
    "empowering",
    "assertive",
    "inspirational",
    "direct",
]
|
||||
|
||||
# Content length options
|
||||
@@ -67,19 +67,22 @@ LENGTH_OPTIONS = [
|
||||
"long", # > 300 words
|
||||
]
|
||||
|
||||
# Default brand style guidelines
|
||||
# Default brand style guidelines - fixed to match Adriana James' distinct communication style
|
||||
DEFAULT_BRAND_STYLE = {
    # Overall tone labels for generated content.
    "tone": ["empowering", "assertive", "inspirational", "direct"],
    "voice_characteristics": ["clear", "confident", "conversational", "teaching"],
    "writing_patterns": [
        "direct commands",
        "personal pronouns",
        "repetitive rhythms",
        "embedded commands",
        "cause-effect statements",
    ],
    # Words that should not appear in generated content.
    "taboo_words": [
        "cheap",
        "discount",
        "bargain",
        "failure",
        "impossible",
        "difficult",
        "might",
        "try",
        "consider",
    ],
    # Maps a phrase to avoid (key) to the phrase to use instead (value).
    "preferred_terms": {
        "problems": "challenges",
        "try": "take action",
        "difficult": "ready for growth",
        "failure": "learning opportunity",
        "hope": "know",
        "maybe": "will",
        "might help you": "you can do this",
        "consider doing this": "decide now to change your thinking",
        "this could work": "this works because",
    },
}
|
||||
|
||||
|
||||
+97
-42
@@ -16,13 +16,13 @@ from vector_store import vector_store
|
||||
|
||||
class Copywriter:
|
||||
"""Generates marketing copy using a fine-tuned LLM."""
|
||||
|
||||
|
||||
def __init__(self):
    """Set up the Copywriter with its Cohere API key and model name."""
    # The API key is read from the project's config module.
    self.api_key = config.COHERE_API_KEY
    # Name of Cohere's generation model.
    self.model = "command"
    logger.info("Copywriter initialized with Cohere API successfully")
|
||||
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
|
||||
async def generate_copy(
|
||||
self,
|
||||
@@ -40,34 +40,43 @@ class Copywriter:
|
||||
try:
|
||||
# Step 1: Format prompt with brand style guidelines
|
||||
branded_prompt = brand_style_manager.format_prompt_with_brand_style(prompt, content_type)
|
||||
|
||||
|
||||
# Step 2: Find similar content for reference (if enabled)
|
||||
reference_content = []
|
||||
if reference_similar_content:
|
||||
logger.info(f"Searching for similar content to reference for prompt: {prompt[:50]}...")
|
||||
search_results = await vector_store.search(prompt, top_k=3)
|
||||
if search_results:
|
||||
reference_content = [result['text'] for result in search_results]
|
||||
|
||||
logger.info(f"Found {len(reference_content)} similar content items to reference")
|
||||
for i, content in enumerate(reference_content):
|
||||
logger.debug(f"Reference content {i+1}: {content[:100]}...")
|
||||
else:
|
||||
logger.warning("No similar content found in vector store for reference")
|
||||
|
||||
# Step 3: Add length and CTA instructions if needed
|
||||
if length:
|
||||
branded_prompt += f"\n- Generate {length} content"
|
||||
if include_cta:
|
||||
branded_prompt += "\n- Include a direct, empowering call to action"
|
||||
|
||||
|
||||
# Step 4: Add reference content if available
|
||||
if reference_content:
|
||||
branded_prompt += "\n\nReference these successful examples for tone and style:\n"
|
||||
branded_prompt += "\n---\n".join(reference_content)
|
||||
|
||||
|
||||
# Step 5: Generate content using the LLM
|
||||
generated_content = await self._call_llm_api(branded_prompt, max_tokens)
|
||||
|
||||
# Step 6: Check content alignment with brand style
|
||||
|
||||
# Step 6: Post-process to remove any mentions of Adriana James
|
||||
generated_content = self._remove_name_mentions(generated_content)
|
||||
|
||||
# Step 7: Check content alignment with brand style
|
||||
alignment_check = brand_style_manager.check_content_alignment(generated_content)
|
||||
|
||||
|
||||
# Step 7: Generate alternative headline suggestions
|
||||
headline_suggestions = await self._generate_headline_suggestions(prompt, generated_content)
|
||||
|
||||
|
||||
# Step 8: Return the generated content with metadata
|
||||
result = {
|
||||
"content": generated_content,
|
||||
@@ -79,21 +88,21 @@ class Copywriter:
|
||||
"generated_at": None # Will be added by the API
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Add alignment issues if any
|
||||
if alignment_check['taboo_words_found'] or alignment_check['terminology_issues']:
|
||||
result["alignment_issues"] = {
|
||||
"taboo_words_found": alignment_check['taboo_words_found'],
|
||||
"terminology_issues": alignment_check['terminology_issues']
|
||||
}
|
||||
|
||||
|
||||
logger.info(f"Generated content with {len(generated_content)} characters")
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating copy: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
|
||||
async def _call_llm_api(self, prompt: str, max_tokens: int = 1000) -> str:
|
||||
"""
|
||||
@@ -102,12 +111,11 @@ class Copywriter:
|
||||
Args:
|
||||
prompt: The formatted prompt for the LLM
|
||||
max_tokens: Maximum tokens for the generated response
|
||||
|
||||
|
||||
Returns:
|
||||
Generated content as a string
|
||||
Generated content as a string with preserved formatting
|
||||
"""
|
||||
try:
|
||||
# Use Cohere's generate API with the API key from config
|
||||
cohere_api_key = config.COHERE_API_KEY
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
@@ -118,19 +126,30 @@ class Copywriter:
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
json={
|
||||
"model": "command", # Cohere's generation model
|
||||
"prompt": prompt,
|
||||
"model": "command",
|
||||
"prompt": f"{prompt}\n\nNote: Please preserve formatting with proper paragraphs, line breaks, and bullet points where appropriate.",
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": 0.7,
|
||||
"k": 0,
|
||||
"p": 0.75
|
||||
"p": 0.75,
|
||||
"return_likelihoods": "NONE"
|
||||
},
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return result["generations"][0]["text"].strip()
|
||||
generated_text = result["generations"][0]["text"].strip()
|
||||
|
||||
# Preserve paragraph breaks and formatting
|
||||
formatted_text = (
|
||||
generated_text
|
||||
.replace("\n\n", "<paragraph-break>") # Preserve paragraph breaks
|
||||
.replace("\n- ", "\n• ") # Convert hyphens to bullets
|
||||
.replace("<paragraph-break>", "\n\n") # Restore paragraph breaks
|
||||
)
|
||||
|
||||
return formatted_text
|
||||
else:
|
||||
logger.error(f"Cohere API error: {response.status_code}, {response.text}")
|
||||
raise Exception(f"Cohere API error: {response.status_code}")
|
||||
@@ -138,24 +157,25 @@ class Copywriter:
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling Cohere API: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def _generate_headline_suggestions(self, original_prompt: str, generated_content: str) -> List[str]:
|
||||
"""
|
||||
Generate alternative headline suggestions based on the content.
|
||||
|
||||
|
||||
Args:
|
||||
original_prompt: The original user prompt
|
||||
generated_content: The generated marketing content
|
||||
|
||||
|
||||
Returns:
|
||||
List of headline suggestions
|
||||
"""
|
||||
try:
|
||||
# Create a prompt for headline generation
|
||||
headline_prompt = f"""
|
||||
Generate 3 alternative marketing headlines for the following content.
|
||||
Generate 3 alternative marketing headlines for the following content.
|
||||
Make headlines compelling, concise, and aligned with the content's message.
|
||||
Each headline should be unique and capture attention.
|
||||
IMPORTANT: Do not mention any specific person's name in the headlines.
|
||||
|
||||
ORIGINAL PROMPT:
|
||||
{original_prompt}
|
||||
@@ -179,6 +199,9 @@ class Copywriter:
|
||||
if headline.strip() and not headline.lower().startswith(('headline', 'title', '-', '*', '•'))
|
||||
]
|
||||
|
||||
# Remove any mentions of Adriana James from headlines
|
||||
headlines = [self._remove_name_mentions(headline) for headline in headlines]
|
||||
|
||||
# Ensure we have exactly 3 headlines
|
||||
if len(headlines) > 3:
|
||||
headlines = headlines[:3]
|
||||
@@ -192,15 +215,15 @@ class Copywriter:
|
||||
logger.error(f"Error generating headline suggestions: {str(e)}")
|
||||
# Return empty list instead of mock response on error
|
||||
return []
|
||||
|
||||
|
||||
async def improve_copy(self, content: str, feedback: str) -> str:
|
||||
"""
|
||||
Improve content based on user feedback.
|
||||
|
||||
|
||||
Args:
|
||||
content: Original generated content
|
||||
feedback: User feedback for improvement
|
||||
|
||||
|
||||
Returns:
|
||||
Improved content
|
||||
"""
|
||||
@@ -208,53 +231,57 @@ class Copywriter:
|
||||
# Format prompt for improvement
|
||||
improve_prompt = f"""
|
||||
Please improve the following marketing content based on the feedback provided:
|
||||
|
||||
IMPORTANT: Do not mention any specific person's name in the content.
|
||||
|
||||
ORIGINAL CONTENT:
|
||||
{content}
|
||||
|
||||
|
||||
FEEDBACK:
|
||||
{feedback}
|
||||
|
||||
|
||||
IMPROVED CONTENT:
|
||||
"""
|
||||
|
||||
|
||||
# Call LLM to improve content
|
||||
improved_content = await self._call_llm_api(improve_prompt, max_tokens=1200)
|
||||
|
||||
|
||||
# Remove any mentions of Adriana James from improved content
|
||||
improved_content = self._remove_name_mentions(improved_content)
|
||||
|
||||
logger.info(f"Improved content based on feedback")
|
||||
return improved_content
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error improving content: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def analyze_content_performance(self, content: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Analyze marketing content for performance prediction.
|
||||
|
||||
|
||||
Args:
|
||||
content: Marketing content to analyze
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary with analysis results
|
||||
"""
|
||||
try:
|
||||
# This would be enhanced with actual ML models in production
|
||||
# Simplified mock response for demonstration
|
||||
|
||||
|
||||
# Very basic analysis using length and keyword presence
|
||||
word_count = len(content.split())
|
||||
has_cta = any(phrase in content.lower() for phrase in ["call", "contact", "get started", "try", "buy", "sign up"])
|
||||
sentence_count = len([s for s in content.split(".") if s.strip()])
|
||||
avg_words_per_sentence = word_count / max(1, sentence_count)
|
||||
|
||||
|
||||
# Simple scoring system
|
||||
readability_score = 100 - min(100, max(0, abs(avg_words_per_sentence - 15) * 5))
|
||||
cta_score = 90 if has_cta else 60
|
||||
length_score = min(100, max(0, word_count / 3))
|
||||
|
||||
|
||||
overall_score = (readability_score + cta_score + length_score) / 3
|
||||
|
||||
|
||||
return {
|
||||
"overall_score": round(overall_score, 1),
|
||||
"readability_score": round(readability_score, 1),
|
||||
@@ -272,10 +299,38 @@ class Copywriter:
|
||||
"Consider adding more content for better engagement" if word_count < 100 else "Your content length is appropriate"
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing content: {str(e)}")
|
||||
raise
|
||||
|
||||
def _remove_name_mentions(self, content: str) -> str:
    """
    Remove mentions of "Adriana James" from generated content, keeping its layout.

    The name is matched case-insensitively, with any whitespace (including a
    line break) between the two words. Only lines that actually contained the
    name are tidied; every other line, including the blank lines that separate
    paragraphs, is returned untouched. A line that contained nothing but the
    name is dropped.

    Args:
        content: The generated content to process

    Returns:
        Content with name mentions removed. If processing fails, the error is
        logged and the content is returned exactly as it was passed in.
    """
    # Function-scope import, as in the original implementation of this method.
    import re

    original_content = content
    try:
        name_pattern = re.compile(r'\bAdriana\s+James\b', re.IGNORECASE)
        # A mention may straddle a line break; pull it onto a single line so
        # the per-line pass below can see it.
        split_name_pattern = re.compile(r'\b(Adriana)\s*\n\s*(James)\b', re.IGNORECASE)
        joined = split_name_pattern.sub(r'\1 \2', content)

        kept_lines = []
        for line in joined.split('\n'):
            if not name_pattern.search(line):
                # Untouched line: keep verbatim so paragraph breaks, bullets
                # and indentation survive.
                kept_lines.append(line)
                continue

            indent = line[:len(line) - len(line.lstrip())]
            body = name_pattern.sub('', line)
            # Close the gaps the removal left behind. Only horizontal
            # whitespace is collapsed; r'\s+' would also swallow newlines.
            body = re.sub(r'[^\S\n]{2,}', ' ', body)
            body = re.sub(r'[^\S\n]+([,.;:!?])', r'\1', body)
            body = body.strip()
            if body:
                kept_lines.append(indent + body)
            # else: the line held only the name, so it is dropped.

        content = '\n'.join(kept_lines)

        logger.info("Removed any name mentions from generated content")
        return content
    except Exception as e:
        logger.error(f"Error removing name mentions: {str(e)}")
        return original_content
|
||||
|
||||
# Create a singleton instance
|
||||
copywriter = Copywriter()
|
||||
|
||||
+143
-70
@@ -18,22 +18,27 @@ from embeddings import embeddings_manager
|
||||
|
||||
class VectorStore:
|
||||
"""Manages vector database operations for content retrieval."""
|
||||
|
||||
|
||||
def __init__(self):
    """
    Initialize the VectorStore with FAISS index.

    Creates the storage directory given by ``config.VECTOR_DB_PATH`` if needed,
    loads or creates the index via ``_load_or_create_index()``, and, when the
    resulting index holds no vectors, loads sample data via
    ``_load_sample_data()``.
    """
    self.store_path = Path(config.VECTOR_DB_PATH)
    # parents=True so a missing parent directory does not raise
    # FileNotFoundError the first time the store is created.
    self.store_path.mkdir(parents=True, exist_ok=True)

    self.index_path = self.store_path / "faiss_index.bin"
    self.metadata_path = self.store_path / "metadata.pkl"

    # Placeholders; _load_or_create_index() below fills these in.
    self.dimension = None
    self.index = None
    self.metadata = []

    self._load_or_create_index()
    logger.info("VectorStore initialized successfully")

    # Check if the index is empty and load sample data if needed
    if self.index.ntotal == 0:
        logger.warning("Vector store is empty. Loading sample data...")
        self._load_sample_data()
|
||||
|
||||
def _load_or_create_index(self) -> None:
|
||||
"""Load existing index or create new one if it doesn't exist."""
|
||||
try:
|
||||
@@ -46,17 +51,17 @@ class VectorStore:
|
||||
logger.info(f"Loaded existing vector index with {self.index.ntotal} vectors")
|
||||
else:
|
||||
# Default dimension for Cohere embeddings
|
||||
self.dimension = 1024
|
||||
self.dimension = 1024
|
||||
self.index = faiss.IndexFlatL2(self.dimension)
|
||||
self.metadata = []
|
||||
logger.info(f"Created new vector index with dimension {self.dimension}")
|
||||
|
||||
|
||||
# Save the empty index and metadata
|
||||
self._save_index()
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading or creating index: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def _save_index(self) -> None:
|
||||
"""Save the index and metadata to disk."""
|
||||
try:
|
||||
@@ -67,19 +72,19 @@ class VectorStore:
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving index: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def add_documents(
|
||||
self,
|
||||
texts: List[str],
|
||||
self,
|
||||
texts: List[str],
|
||||
metadata_list: Optional[List[Dict[str, Any]]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Add documents to the vector store.
|
||||
|
||||
|
||||
Args:
|
||||
texts: List of text documents to add
|
||||
metadata_list: List of metadata dictionaries for each document
|
||||
|
||||
|
||||
Returns:
|
||||
List of document IDs (vector indices)
|
||||
"""
|
||||
@@ -87,16 +92,16 @@ class VectorStore:
|
||||
if not texts:
|
||||
logger.warning("No texts provided to add to vector store")
|
||||
return []
|
||||
|
||||
|
||||
if metadata_list is None:
|
||||
metadata_list = [{} for _ in texts]
|
||||
|
||||
|
||||
if len(texts) != len(metadata_list):
|
||||
raise ValueError("Number of texts and metadata entries must match")
|
||||
|
||||
|
||||
# Generate embeddings
|
||||
embeddings = await embeddings_manager.get_embeddings(texts)
|
||||
|
||||
|
||||
# Check if embeddings match our dimension
|
||||
if embeddings.shape[1] != self.dimension:
|
||||
logger.warning(f"Embedding dimension mismatch: expected {self.dimension}, got {embeddings.shape[1]}")
|
||||
@@ -107,93 +112,97 @@ class VectorStore:
|
||||
logger.info(f"Adapted to new dimension: {self.dimension}")
|
||||
else:
|
||||
raise ValueError(f"Embedding dimension mismatch: expected {self.dimension}, got {embeddings.shape[1]}")
|
||||
|
||||
|
||||
# Add timestamp to metadata
|
||||
timestamp = datetime.now().isoformat()
|
||||
for meta in metadata_list:
|
||||
meta['timestamp'] = timestamp
|
||||
meta['document_id'] = len(self.metadata) + len(metadata_list)
|
||||
|
||||
|
||||
# Store texts in metadata
|
||||
for i, (text, meta) in enumerate(zip(texts, metadata_list)):
|
||||
meta['text'] = text
|
||||
|
||||
|
||||
# Add vectors to index
|
||||
start_idx = self.index.ntotal
|
||||
self.index.add(embeddings.astype(np.float32))
|
||||
self.metadata.extend(metadata_list)
|
||||
|
||||
|
||||
# Save updated index
|
||||
self._save_index()
|
||||
|
||||
|
||||
# Return document IDs
|
||||
doc_ids = list(range(start_idx, start_idx + len(texts)))
|
||||
logger.info(f"Added {len(texts)} documents to vector store")
|
||||
return doc_ids
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding documents to vector store: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
rerank: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search for similar documents.
|
||||
|
||||
|
||||
Args:
|
||||
query: The search query
|
||||
top_k: Number of results to return
|
||||
filters: Dictionary of metadata filters
|
||||
rerank: Whether to use Cohere's reranking
|
||||
|
||||
|
||||
Returns:
|
||||
List of result dictionaries with document content and metadata
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Searching vector store with query: {query[:50]}... (top_k={top_k})")
|
||||
|
||||
if self.index.ntotal == 0:
|
||||
logger.warning("Empty vector store, no results to return")
|
||||
return []
|
||||
|
||||
|
||||
logger.info(f"Vector store contains {self.index.ntotal} documents")
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = await embeddings_manager.get_query_embedding(query)
|
||||
query_embedding = query_embedding.reshape(1, -1).astype(np.float32)
|
||||
|
||||
|
||||
# First pass: find more candidates than needed for reranking
|
||||
search_k = top_k * 3 if rerank else top_k
|
||||
search_k = min(search_k, self.index.ntotal) # Don't request more than we have
|
||||
|
||||
|
||||
distances, indices = self.index.search(query_embedding, search_k)
|
||||
|
||||
|
||||
# Get metadata and texts for matching indices
|
||||
results = []
|
||||
for i, idx in enumerate(indices[0]):
|
||||
if idx < 0 or idx >= len(self.metadata):
|
||||
continue # Skip invalid indices
|
||||
|
||||
|
||||
metadata = self.metadata[idx]
|
||||
text = metadata.get('text', '')
|
||||
|
||||
|
||||
# Apply filters if any
|
||||
if filters and not self._matches_filters(metadata, filters):
|
||||
continue
|
||||
|
||||
|
||||
results.append({
|
||||
'document_id': idx,
|
||||
'text': text,
|
||||
'metadata': {k: v for k, v in metadata.items() if k != 'text'},
|
||||
'distance': float(distances[0][i])
|
||||
})
|
||||
|
||||
|
||||
# Apply reranking if requested
|
||||
if rerank and results:
|
||||
texts = [r['text'] for r in results]
|
||||
reranked = await embeddings_manager.rerank_results(query, texts, top_n=top_k)
|
||||
|
||||
|
||||
# Map reranked results back to our original results
|
||||
reranked_results = []
|
||||
for item in reranked:
|
||||
@@ -203,41 +212,41 @@ class VectorStore:
|
||||
**results[orig_idx],
|
||||
'relevance_score': item['relevance_score']
|
||||
})
|
||||
|
||||
|
||||
results = reranked_results
|
||||
else:
|
||||
# Just take the top_k results
|
||||
results = results[:top_k]
|
||||
|
||||
|
||||
logger.info(f"Found {len(results)} matching documents for query")
|
||||
return results
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching vector store: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def _matches_filters(self, metadata: Dict[str, Any], filters: Dict[str, Any]) -> bool:
|
||||
"""Check if metadata matches the specified filters."""
|
||||
for key, value in filters.items():
|
||||
if key not in metadata:
|
||||
return False
|
||||
|
||||
|
||||
if isinstance(value, list):
|
||||
# Check if metadata value is in the list
|
||||
if metadata[key] not in value:
|
||||
return False
|
||||
elif metadata[key] != value:
|
||||
return False
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def delete_document(self, document_id: int) -> bool:
|
||||
"""
|
||||
Delete a document from the vector store.
|
||||
|
||||
|
||||
Args:
|
||||
document_id: ID of the document to delete
|
||||
|
||||
|
||||
Returns:
|
||||
Boolean indicating success
|
||||
"""
|
||||
@@ -245,28 +254,28 @@ class VectorStore:
|
||||
if document_id < 0 or document_id >= len(self.metadata):
|
||||
logger.warning(f"Invalid document ID: {document_id}")
|
||||
return False
|
||||
|
||||
|
||||
# FAISS doesn't support direct deletion, so we need to rebuild the index
|
||||
# Mark the document as deleted in metadata
|
||||
self.metadata[document_id]['deleted'] = True
|
||||
|
||||
|
||||
# Save updated metadata
|
||||
self._save_index()
|
||||
|
||||
|
||||
logger.info(f"Marked document {document_id} as deleted")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting document: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_document(self, document_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve a document by ID.
|
||||
|
||||
|
||||
Args:
|
||||
document_id: ID of the document to retrieve
|
||||
|
||||
|
||||
Returns:
|
||||
Document with metadata or None if not found
|
||||
"""
|
||||
@@ -274,35 +283,35 @@ class VectorStore:
|
||||
if document_id < 0 or document_id >= len(self.metadata):
|
||||
logger.warning(f"Invalid document ID: {document_id}")
|
||||
return None
|
||||
|
||||
|
||||
metadata = self.metadata[document_id]
|
||||
|
||||
|
||||
# Check if document is marked as deleted
|
||||
if metadata.get('deleted', False):
|
||||
logger.warning(f"Document {document_id} is marked as deleted")
|
||||
return None
|
||||
|
||||
|
||||
text = metadata.get('text', '')
|
||||
|
||||
|
||||
return {
|
||||
'document_id': document_id,
|
||||
'text': text,
|
||||
'metadata': {k: v for k, v in metadata.items() if k != 'text' and k != 'deleted'}
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving document: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def update_document(self, document_id: int, text: str, metadata: Optional[Dict[str, Any]] = None) -> bool:
|
||||
"""
|
||||
Update a document in the vector store.
|
||||
|
||||
|
||||
Args:
|
||||
document_id: ID of the document to update
|
||||
text: New document text
|
||||
metadata: New metadata (will be merged with existing)
|
||||
|
||||
|
||||
Returns:
|
||||
Boolean indicating success
|
||||
"""
|
||||
@@ -310,38 +319,102 @@ class VectorStore:
|
||||
if document_id < 0 or document_id >= len(self.metadata):
|
||||
logger.warning(f"Invalid document ID: {document_id}")
|
||||
return False
|
||||
|
||||
|
||||
# Get existing metadata
|
||||
existing_metadata = self.metadata[document_id]
|
||||
|
||||
|
||||
# Check if document is marked as deleted
|
||||
if existing_metadata.get('deleted', False):
|
||||
logger.warning(f"Cannot update deleted document {document_id}")
|
||||
return False
|
||||
|
||||
|
||||
# Generate new embedding
|
||||
embeddings = await embeddings_manager.get_embeddings([text])
|
||||
|
||||
|
||||
# Update the vector in the index
|
||||
faiss.IndexFlatL2_update_vectors(self.index, embeddings.astype(np.float32), np.array([document_id], dtype=np.int64))
|
||||
|
||||
|
||||
# Update metadata
|
||||
if metadata:
|
||||
for key, value in metadata.items():
|
||||
existing_metadata[key] = value
|
||||
|
||||
|
||||
existing_metadata['text'] = text
|
||||
existing_metadata['updated_at'] = datetime.now().isoformat()
|
||||
|
||||
|
||||
# Save updated index
|
||||
self._save_index()
|
||||
|
||||
|
||||
logger.info(f"Updated document {document_id}")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating document: {str(e)}")
|
||||
raise
|
||||
|
||||
def _load_sample_data(self) -> None:
    """
    Load sample data from past campaigns into the vector store.

    Reads every ``*.json`` file in ``<config.DATA_DIR>/past_campaigns``. Each
    file that has a ``content`` entry contributes one document, with metadata
    built from its ``content_type`` and ``metadata`` entries. Files that cannot
    be read or parsed are logged and skipped.

    This method is best-effort: any failure is logged and swallowed, so it
    never raises to the caller.
    """
    # Function-scope imports, matching the original's local ``import asyncio``.
    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    try:
        # Path to past campaigns directory
        campaigns_dir = Path(config.DATA_DIR) / "past_campaigns"

        if not campaigns_dir.is_dir():
            logger.warning(f"Past campaigns directory not found: {campaigns_dir}")
            return

        # Find all JSON files in the directory. Sorted so that document IDs
        # are assigned in a stable order across runs.
        campaign_files = sorted(campaigns_dir.glob("*.json"))
        if not campaign_files:
            logger.warning("No campaign files found in past_campaigns directory")
            return

        # Load and process each campaign file
        texts = []
        metadata_list = []

        for file_path in campaign_files:
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    campaign_data = json.load(f)

                if 'content' not in campaign_data:
                    continue

                campaign_meta = campaign_data.get('metadata', {})
                metadata = {
                    'content_type': campaign_data.get('content_type', 'unknown'),
                    'campaign_name': campaign_meta.get('campaign_name', file_path.stem),
                    'source': 'past_campaign',
                    'file_path': str(file_path)
                }

                # Add performance metrics if available
                if 'performance_metrics' in campaign_meta:
                    metadata['performance_metrics'] = campaign_meta['performance_metrics']

                # Append text and metadata together, only after both were
                # built, so the two lists can never get out of step when a
                # malformed file raises part-way through.
                texts.append(campaign_data['content'])
                metadata_list.append(metadata)
                logger.debug(f"Loaded campaign from {file_path.name}")
            except Exception as e:
                logger.error(f"Error loading campaign file {file_path}: {str(e)}")
                continue

        if not texts:
            logger.warning("No valid campaign content found in files")
            return

        # Run the async add_documents() to completion on a worker thread.
        # asyncio.run() there uses its own event loop, so this works even if
        # an event loop is already running in the calling thread, and the
        # calling thread's current loop is never replaced or left closed.
        with ThreadPoolExecutor(max_workers=1) as pool:
            doc_ids = pool.submit(
                asyncio.run, self.add_documents(texts, metadata_list)
            ).result()
        logger.info(f"Added {len(doc_ids)} past campaigns to vector store")

    except Exception as e:
        logger.error(f"Error loading sample data: {str(e)}")
|
||||
|
||||
# Create a singleton instance
|
||||
vector_store = VectorStore()
|
||||
Reference in New Issue
Block a user