Implement MarketingAssistant class for query classification and content generation; update query endpoint to utilize new functionality
This commit is contained in:
@@ -20,6 +20,7 @@ from utils import save_upload_file, load_and_split_documents
|
|||||||
from chroma_manager import ChromaManager
|
from chroma_manager import ChromaManager
|
||||||
from rag import generate_marketing_response,format_context, RERANKER
|
from rag import generate_marketing_response,format_context, RERANKER
|
||||||
from config import UPLOAD_DIR
|
from config import UPLOAD_DIR
|
||||||
|
from marketing_assistant import MarketingAssistant
|
||||||
|
|
||||||
app = FastAPI(title="Marketing Assistant AI")
|
app = FastAPI(title="Marketing Assistant AI")
|
||||||
|
|
||||||
@@ -95,11 +96,15 @@ async def upload_document(
|
|||||||
# return {"status": "success", "new_id": new_id}
|
# return {"status": "success", "new_id": new_id}
|
||||||
|
|
||||||
@app.post("/query")
|
@app.post("/query")
|
||||||
async def query_documents(request: QueryRequest,
|
async def query_documents(request: QueryRequest):
|
||||||
category: CategoryEnum):
|
|
||||||
"""Query documents and generate marketing response"""
|
"""Query documents and generate marketing response"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initial retrieval from vector store
|
# Initial retrieval from vector store
|
||||||
|
assistant = MarketingAssistant()
|
||||||
|
content_type = assistant.classify_query(request.query)
|
||||||
|
print(f"Query classified as: {content_type}")
|
||||||
|
category = content_type
|
||||||
initial_results = chroma_manager.query_documents(
|
initial_results = chroma_manager.query_documents(
|
||||||
query=request.query,
|
query=request.query,
|
||||||
category=category if category else None,
|
category=category if category else None,
|
||||||
|
|||||||
@@ -0,0 +1,98 @@
|
|||||||
|
# models/marketing_assistant.py
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
from langchain_community.chat_models.openai import ChatOpenAI
|
||||||
|
from langchain_groq import ChatGroq
|
||||||
|
from typing import Dict
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import os
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
class MarketingAssistant:
|
||||||
|
def __init__(self):
|
||||||
|
self.templates = self._load_templates()
|
||||||
|
self.groq_llm = ChatGroq(
|
||||||
|
temperature=0.01,
|
||||||
|
groq_api_key=os.getenv("GROQ_API_KEY"),
|
||||||
|
model_name=os.getenv("GROQ_MODEL_NAME")
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_templates(self) -> Dict[str, PromptTemplate]:
|
||||||
|
"""Load prompt templates for different content types"""
|
||||||
|
return {
|
||||||
|
'email': PromptTemplate(
|
||||||
|
input_variables=["query", "topic", "style"],
|
||||||
|
template="Act like you are Adriana James (Adriana James is a woman of force and character, beauty and charm and an expert leader in the field of Neuro-Linguistic Programming (NLP), NLP Coaching and Time Line Therapy®.), write marketing copy in her signature style. Write an email newsletter about {topic} in {style} style. The query is: {query}"
|
||||||
|
),
|
||||||
|
'social': PromptTemplate(
|
||||||
|
input_variables=["query","topic", "platform"],
|
||||||
|
template="Act like you are Adriana James (Adriana James is a woman of force and character, beauty and charm and an expert leader in the field of Neuro-Linguistic Programming (NLP), NLP Coaching and Time Line Therapy®.), write marketing copy in her signature style. Create a {platform} post about {topic}... The query is: {query}"
|
||||||
|
),
|
||||||
|
'book': PromptTemplate(
|
||||||
|
input_variables=["query","topic", "style"],
|
||||||
|
template="Act like you are Adriana James (Adriana James is a woman of force and character, beauty and charm and an expert leader in the field of Neuro-Linguistic Programming (NLP), NLP Coaching and Time Line Therapy®.), write marketing copy in her signature style. Write a book blurb about {topic} in {style} style... The query is: {query}"
|
||||||
|
),
|
||||||
|
'article': PromptTemplate(
|
||||||
|
input_variables=["query","topic", "style"],
|
||||||
|
template="Act like you are Adriana James (Adriana James is a woman of force and character, beauty and charm and an expert leader in the field of Neuro-Linguistic Programming (NLP), NLP Coaching and Time Line Therapy®.), write marketing copy in her signature style. Write an article about {topic} in {style} style... The query is: {query}"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def generate_content(self, content_type: str, **kwargs) -> str:
|
||||||
|
"""Generate marketing content"""
|
||||||
|
template = self.templates.get(content_type)
|
||||||
|
if not template:
|
||||||
|
raise ValueError(f"Unknown content type: {content_type}")
|
||||||
|
|
||||||
|
chain = LLMChain(llm=self.groq_llm, prompt=template)
|
||||||
|
return chain.run(**kwargs)
|
||||||
|
|
||||||
|
def classify_query(self, query: str) -> str:
|
||||||
|
"""Classify the intent of a query"""
|
||||||
|
|
||||||
|
prompt = PromptTemplate(
|
||||||
|
input_variables=["query"],
|
||||||
|
template="Classify the following query into one of the following categories: book, article, email, social. Query: {query}. Just return the category name, nothing else."
|
||||||
|
)
|
||||||
|
chain = LLMChain(llm=self.groq_llm, prompt=prompt)
|
||||||
|
result = chain.run(query=query)
|
||||||
|
if "book" in result.lower():
|
||||||
|
return "book"
|
||||||
|
elif "article" in result.lower():
|
||||||
|
return "article"
|
||||||
|
elif "email" in result.lower():
|
||||||
|
return "email"
|
||||||
|
elif "social" in result.lower():
|
||||||
|
return "social"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown query classification: {result}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
#query = "Write an email for a marketing campaign about Time Line Therapy®"
|
||||||
|
#query = "Write a book about Time Line Therapy®"
|
||||||
|
query = "Write a social media post about Time Line Therapy®"
|
||||||
|
#query = "Write an article about Time Line Therapy®"
|
||||||
|
assistant = MarketingAssistant(api_key="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")
|
||||||
|
try:
|
||||||
|
content_type = assistant.classify_query(query)
|
||||||
|
print(f"Query classified as: {content_type}")
|
||||||
|
content = assistant.generate_content(content_type, query = query, topic="Time Line Therapy®", style="casual", platform="Instagram")
|
||||||
|
print(f"Generated content for {content_type}:")
|
||||||
|
print(content)
|
||||||
|
except ValueError as e:
|
||||||
|
print(e)
|
||||||
|
exit(1)
|
||||||
|
# email = assistant.generate_content("social", topic="Time Line Therapy®", style="casual", platform="Instagram")
|
||||||
|
# social = assistant.generate_content("social", topic="Time Line Therapy®", style="casual", platform="Instagram")
|
||||||
|
# book = assistant.generate_content("book", topic="Time Line Therapy®", style="formal")
|
||||||
|
# article = assistant.generate_content("article", topic="Time Line Therapy®", style="casual")
|
||||||
|
# print("Generated content for book:")
|
||||||
|
# print(book)
|
||||||
|
# print("Generated content for social media post:")
|
||||||
|
# print(social)
|
||||||
|
# print("Generated content for article:")
|
||||||
|
# print(article)
|
||||||
|
# print("Generated content for email:")
|
||||||
|
# print(email)
|
||||||
Reference in New Issue
Block a user