Implement MarketingAssistant class for query classification and content generation; update query endpoint to utilize new functionality

This commit is contained in:
2025-02-08 04:09:59 +06:00
parent 65f12d7528
commit daf09de530
2 changed files with 105 additions and 2 deletions
+7 -2
View File
@@ -20,6 +20,7 @@ from utils import save_upload_file, load_and_split_documents
from chroma_manager import ChromaManager from chroma_manager import ChromaManager
from rag import generate_marketing_response,format_context, RERANKER from rag import generate_marketing_response,format_context, RERANKER
from config import UPLOAD_DIR from config import UPLOAD_DIR
from marketing_assistant import MarketingAssistant
app = FastAPI(title="Marketing Assistant AI") app = FastAPI(title="Marketing Assistant AI")
@@ -95,11 +96,15 @@ async def upload_document(
# return {"status": "success", "new_id": new_id} # return {"status": "success", "new_id": new_id}
@app.post("/query") @app.post("/query")
async def query_documents(request: QueryRequest, async def query_documents(request: QueryRequest):
category: CategoryEnum):
"""Query documents and generate marketing response""" """Query documents and generate marketing response"""
try: try:
# Initial retrieval from vector store # Initial retrieval from vector store
assistant = MarketingAssistant()
content_type = assistant.classify_query(request.query)
print(f"Query classified as: {content_type}")
category = content_type
initial_results = chroma_manager.query_documents( initial_results = chroma_manager.query_documents(
query=request.query, query=request.query,
category=category if category else None, category=category if category else None,
@@ -0,0 +1,98 @@
# models/marketing_assistant.py
from langchain.chains.llm import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_community.chat_models.openai import ChatOpenAI
from langchain_groq import ChatGroq
from typing import Dict
from dotenv import load_dotenv
import os
load_dotenv()
class MarketingAssistant:
def __init__(self):
self.templates = self._load_templates()
self.groq_llm = ChatGroq(
temperature=0.01,
groq_api_key=os.getenv("GROQ_API_KEY"),
model_name=os.getenv("GROQ_MODEL_NAME")
)
def _load_templates(self) -> Dict[str, PromptTemplate]:
"""Load prompt templates for different content types"""
return {
'email': PromptTemplate(
input_variables=["query", "topic", "style"],
template="Act like you are Adriana James (Adriana James is a woman of force and character, beauty and charm and an expert leader in the field of Neuro-Linguistic Programming (NLP), NLP Coaching and Time Line Therapy®.), write marketing copy in her signature style. Write an email newsletter about {topic} in {style} style. The query is: {query}"
),
'social': PromptTemplate(
input_variables=["query","topic", "platform"],
template="Act like you are Adriana James (Adriana James is a woman of force and character, beauty and charm and an expert leader in the field of Neuro-Linguistic Programming (NLP), NLP Coaching and Time Line Therapy®.), write marketing copy in her signature style. Create a {platform} post about {topic}... The query is: {query}"
),
'book': PromptTemplate(
input_variables=["query","topic", "style"],
template="Act like you are Adriana James (Adriana James is a woman of force and character, beauty and charm and an expert leader in the field of Neuro-Linguistic Programming (NLP), NLP Coaching and Time Line Therapy®.), write marketing copy in her signature style. Write a book blurb about {topic} in {style} style... The query is: {query}"
),
'article': PromptTemplate(
input_variables=["query","topic", "style"],
template="Act like you are Adriana James (Adriana James is a woman of force and character, beauty and charm and an expert leader in the field of Neuro-Linguistic Programming (NLP), NLP Coaching and Time Line Therapy®.), write marketing copy in her signature style. Write an article about {topic} in {style} style... The query is: {query}"
),
}
def generate_content(self, content_type: str, **kwargs) -> str:
"""Generate marketing content"""
template = self.templates.get(content_type)
if not template:
raise ValueError(f"Unknown content type: {content_type}")
chain = LLMChain(llm=self.groq_llm, prompt=template)
return chain.run(**kwargs)
def classify_query(self, query: str) -> str:
"""Classify the intent of a query"""
prompt = PromptTemplate(
input_variables=["query"],
template="Classify the following query into one of the following categories: book, article, email, social. Query: {query}. Just return the category name, nothing else."
)
chain = LLMChain(llm=self.groq_llm, prompt=prompt)
result = chain.run(query=query)
if "book" in result.lower():
return "book"
elif "article" in result.lower():
return "article"
elif "email" in result.lower():
return "email"
elif "social" in result.lower():
return "social"
else:
raise ValueError(f"Unknown query classification: {result}")
if __name__ == "__main__":
#query = "Write an email for a marketing campaign about Time Line Therapy®"
#query = "Write a book about Time Line Therapy®"
query = "Write a social media post about Time Line Therapy®"
#query = "Write an article about Time Line Therapy®"
assistant = MarketingAssistant(api_key="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")
try:
content_type = assistant.classify_query(query)
print(f"Query classified as: {content_type}")
content = assistant.generate_content(content_type, query = query, topic="Time Line Therapy®", style="casual", platform="Instagram")
print(f"Generated content for {content_type}:")
print(content)
except ValueError as e:
print(e)
exit(1)
# email = assistant.generate_content("social", topic="Time Line Therapy®", style="casual", platform="Instagram")
# social = assistant.generate_content("social", topic="Time Line Therapy®", style="casual", platform="Instagram")
# book = assistant.generate_content("book", topic="Time Line Therapy®", style="formal")
# article = assistant.generate_content("article", topic="Time Line Therapy®", style="casual")
# print("Generated content for book:")
# print(book)
# print("Generated content for social media post:")
# print(social)
# print("Generated content for article:")
# print(article)
# print("Generated content for email:")
# print(email)