Refactor BrandStyleManager and main.py to integrate database functionality for campaign management. Replace JSON file handling with database calls for loading, saving, updating, and deleting campaigns. Update sample campaign loading to retrieve data from the database instead of JSON files.

This commit is contained in:
boladeE
2025-04-21 15:31:44 +01:00
parent 3c0dd1d972
commit b74180e595
4 changed files with 98 additions and 69 deletions
+9 -28
View File
@@ -6,6 +6,7 @@ from PyPDF2 import PdfReader
from embeddings import CohereEmbeddings
from vector_store import VectorStore
from config import settings
from database import Database
class BrandStyleManager:
def __init__(self, embeddings: CohereEmbeddings, vector_store: VectorStore):
@@ -13,6 +14,7 @@ class BrandStyleManager:
self.embeddings = embeddings
self.vector_store = vector_store
self.brand_voice = self._load_brand_voice()
self.db = Database()
self.sample_campaigns = self._load_sample_campaigns()
def _load_brand_voice(self) -> Dict[str, Any]:
@@ -24,14 +26,8 @@ class BrandStyleManager:
return {}
def _load_sample_campaigns(self) -> List[Dict[str, Any]]:
"""Load sample campaigns from JSON."""
file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "past_campaigns", "sample_campaigns.json"))
if os.path.exists(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
return data.get("campaigns", [])
return []
"""Load sample campaigns from the database."""
return self.db.get_all_campaigns()
def _extract_text_from_pdf(self, pdf_path: str) -> str:
"""Extract text from a PDF file."""
@@ -39,9 +35,7 @@ class BrandStyleManager:
try:
reader = PdfReader(pdf_path)
for page in reader.pages:
page_text = page.extract_text()
if page_text:
text += page_text + "\n\n"
text += page.extract_text() + "\n"
except Exception as e:
print(f"Error extracting text from PDF: {e}")
return text
@@ -83,23 +77,10 @@ class BrandStyleManager:
else:
print("No content found to add to vector store")
def get_relevant_context(self, prompt: str, k: int = 5) -> List[Dict]:
"""Get relevant context for a given prompt from book excerpts."""
# Generate embedding for the prompt
prompt_embedding = self.embeddings.generate_embedding(prompt)
# Search for similar content in book excerpts
results = self.vector_store.search(prompt_embedding, k=k)
# Optionally rerank results
if results:
texts = [result["text"] for result in results]
reranked = self.embeddings.rerank_results(prompt, texts, top_n=k)
# Convert reranked results to the expected format
return [{"text": text} for text in reranked]
# If no results, return empty list
return []
def get_relevant_context(self, query: str, top_k: int = 3) -> List[Dict[str, str]]:
"""Get relevant context from the vector store based on the query."""
query_embedding = self.embeddings.generate_embedding(query)
return self.vector_store.search(query_embedding, k=top_k)
def get_brand_voice(self) -> Dict[str, Any]:
"""Get brand voice guidelines."""
+80
View File
@@ -0,0 +1,80 @@
import sqlite3
import os
from typing import List, Dict, Any
import datetime
class Database:
def __init__(self):
data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data"))
self.db_path = os.path.join(data_dir, "campaigns.db")
self._init_db()
def _init_db(self):
"""Initialize the database with required tables."""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS campaigns (
id INTEGER PRIMARY KEY AUTOINCREMENT,
prompt TEXT NOT NULL,
content TEXT NOT NULL,
timestamp TEXT NOT NULL
)
''')
conn.commit()
def get_all_campaigns(self) -> List[Dict[str, Any]]:
"""Retrieve all campaigns from the database."""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute('SELECT prompt, content, timestamp FROM campaigns ORDER BY timestamp DESC')
rows = cursor.fetchall()
return [
{
"prompt": row[0],
"content": row[1],
"timestamp": row[2]
}
for row in rows
]
def add_campaign(self, prompt: str, content: str) -> None:
"""Add a new campaign to the database."""
timestamp = datetime.datetime.now().isoformat()
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(
'INSERT INTO campaigns (prompt, content, timestamp) VALUES (?, ?, ?)',
(prompt, content, timestamp)
)
conn.commit()
def update_campaign(self, index: int, content: str) -> bool:
"""Update an existing campaign's content."""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute('SELECT id FROM campaigns ORDER BY timestamp DESC LIMIT 1 OFFSET ?', (index,))
result = cursor.fetchone()
if result:
campaign_id = result[0]
timestamp = datetime.datetime.now().isoformat()
cursor.execute(
'UPDATE campaigns SET content = ?, timestamp = ? WHERE id = ?',
(content, timestamp, campaign_id)
)
conn.commit()
return True
return False
def delete_campaign(self, index: int) -> bool:
"""Delete a campaign by its index."""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute('SELECT id FROM campaigns ORDER BY timestamp DESC LIMIT 1 OFFSET ?', (index,))
result = cursor.fetchone()
if result:
campaign_id = result[0]
cursor.execute('DELETE FROM campaigns WHERE id = ?', (campaign_id,))
conn.commit()
return True
return False
+9 -41
View File
@@ -4,30 +4,17 @@ from typing import Optional, List, Dict, Any
from copywriter import generate_marketing_copy
from brand_style import BrandStyleManager
from config import settings
from database import Database
import os
import json
import datetime
app = Flask(__name__)
# Initialize brand style manager
# Initialize brand style manager and database
data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data"))
campaign_prompt = []
def load_campaigns():
campaigns_file = os.path.join(data_dir, "past_campaigns", "campaigns.json")
if os.path.exists(campaigns_file):
with open(campaigns_file, 'r', encoding='utf-8') as f:
data = json.load(f)
campaigns = data.get("campaigns", [])
return campaigns
return []
def save_campaigns(campaigns):
campaigns_file = os.path.join(data_dir, "past_campaigns", "campaigns.json")
os.makedirs(os.path.dirname(campaigns_file), exist_ok=True)
with open(campaigns_file, 'w', encoding='utf-8') as f:
json.dump({"campaigns": campaigns}, f, indent=4)
db = Database()
@app.route('/', methods=['GET', 'POST'])
def root():
@@ -41,7 +28,7 @@ def root():
@app.route('/campaigns')
def view_campaigns():
campaigns = load_campaigns()
campaigns = db.get_all_campaigns()
return render_template('campaigns.html', campaigns=campaigns)
@app.route('/save-edit', methods=['POST'])
@@ -49,40 +36,21 @@ def save_edit():
edited_copy = request.form.get('editedCopy')
global campaign_prompt
prompt = campaign_prompt[-1]
campaigns = load_campaigns()
new_campaign = {
"prompt": prompt,
"content": edited_copy,
"timestamp": datetime.datetime.now().isoformat()
}
campaigns.append(new_campaign)
save_campaigns(campaigns)
db.add_campaign(prompt, edited_copy)
return render_template('index.html', generated_copy="Campaign saved successfully")
@app.route('/update-campaign', methods=['POST'])
def update_campaign():
index = int(request.form.get('index'))
edited_copy = request.form.get('editedCopy')
campaigns = load_campaigns()
if 0 <= index < len(campaigns):
campaigns[index]['content'] = edited_copy
campaigns[index]['timestamp'] = datetime.datetime.now().isoformat()
save_campaigns(campaigns)
db.update_campaign(index, edited_copy)
return redirect(url_for('view_campaigns'))
@app.route('/delete-campaign', methods=['POST'])
def delete_campaign():
index = int(request.form.get('index'))
campaigns = load_campaigns()
if 0 <= index < len(campaigns):
campaigns.pop(index)
save_campaigns(campaigns)
db.delete_campaign(index)
return redirect(url_for('view_campaigns'))
if __name__ == "__main__":
app.run(host='localhost', port=8000, debug=True)
if __name__ == '__main__':
app.run(debug=True)