diff --git a/src/marketing_assistant_ai/main.py b/src/marketing_assistant_ai/main.py index 7a4f7bc..f36f33f 100644 --- a/src/marketing_assistant_ai/main.py +++ b/src/marketing_assistant_ai/main.py @@ -20,6 +20,7 @@ from utils import save_upload_file, load_and_split_documents from chroma_manager import ChromaManager from rag import generate_marketing_response,format_context, RERANKER from config import UPLOAD_DIR +from marketing_assistant import MarketingAssistant app = FastAPI(title="Marketing Assistant AI") @@ -95,11 +96,15 @@ async def upload_document( # return {"status": "success", "new_id": new_id} @app.post("/query") -async def query_documents(request: QueryRequest, - category: CategoryEnum): +async def query_documents(request: QueryRequest): """Query documents and generate marketing response""" + try: # Initial retrieval from vector store + assistant = MarketingAssistant() + content_type = assistant.classify_query(request.query) + print(f"Query classified as: {content_type}") + category = content_type initial_results = chroma_manager.query_documents( query=request.query, category=category if category else None, diff --git a/src/marketing_assistant_ai/marketing_assistant.py b/src/marketing_assistant_ai/marketing_assistant.py new file mode 100644 index 0000000..3692930 --- /dev/null +++ b/src/marketing_assistant_ai/marketing_assistant.py @@ -0,0 +1,98 @@ +# models/marketing_assistant.py +from langchain.chains.llm import LLMChain +from langchain_core.prompts import PromptTemplate +from langchain_community.chat_models.openai import ChatOpenAI +from langchain_groq import ChatGroq +from typing import Dict +from dotenv import load_dotenv +import os + +load_dotenv() + +class MarketingAssistant: + def __init__(self): + self.templates = self._load_templates() + self.groq_llm = ChatGroq( + temperature=0.01, + groq_api_key=os.getenv("GROQ_API_KEY"), + model_name=os.getenv("GROQ_MODEL_NAME") + ) + + def _load_templates(self) -> Dict[str, PromptTemplate]: + """Load prompt templates for different content types""" + return { + 'email': PromptTemplate( + input_variables=["query", "topic", "style"], + template="Act like you are Adriana James (Adriana James is a woman of force and character, beauty and charm and an expert leader in the field of Neuro-Linguistic Programming (NLP), NLP Coaching and Time Line Therapy®.), write marketing copy in her signature style. Write an email newsletter about {topic} in {style} style. The query is: {query}" + ), + 'social': PromptTemplate( + input_variables=["query","topic", "platform"], + template="Act like you are Adriana James (Adriana James is a woman of force and character, beauty and charm and an expert leader in the field of Neuro-Linguistic Programming (NLP), NLP Coaching and Time Line Therapy®.), write marketing copy in her signature style. Create a {platform} post about {topic}... The query is: {query}" + ), + 'book': PromptTemplate( + input_variables=["query","topic", "style"], + template="Act like you are Adriana James (Adriana James is a woman of force and character, beauty and charm and an expert leader in the field of Neuro-Linguistic Programming (NLP), NLP Coaching and Time Line Therapy®.), write marketing copy in her signature style. Write a book blurb about {topic} in {style} style... The query is: {query}" + ), + 'article': PromptTemplate( + input_variables=["query","topic", "style"], + template="Act like you are Adriana James (Adriana James is a woman of force and character, beauty and charm and an expert leader in the field of Neuro-Linguistic Programming (NLP), NLP Coaching and Time Line Therapy®.), write marketing copy in her signature style. Write an article about {topic} in {style} style... The query is: {query}" + ), + } + + def generate_content(self, content_type: str, **kwargs) -> str: + """Generate marketing content""" + template = self.templates.get(content_type) + if not template: + raise ValueError(f"Unknown content type: {content_type}") + + chain = LLMChain(llm=self.groq_llm, prompt=template) + return chain.run(**kwargs) + + def classify_query(self, query: str) -> str: + """Classify the intent of a query""" + + prompt = PromptTemplate( + input_variables=["query"], + template="Classify the following query into one of the following categories: book, article, email, social. Query: {query}. Just return the category name, nothing else." + ) + chain = LLMChain(llm=self.groq_llm, prompt=prompt) + result = chain.run(query=query) + if "book" in result.lower(): + return "book" + elif "article" in result.lower(): + return "article" + elif "email" in result.lower(): + return "email" + elif "social" in result.lower(): + return "social" + else: + raise ValueError(f"Unknown query classification: {result}") + + +if __name__ == "__main__": + #query = "Write an email for a marketing campaign about Time Line Therapy®" + #query = "Write a book about Time Line Therapy®" + query = "Write a social media post about Time Line Therapy®" + #query = "Write an article about Time Line Therapy®" + assistant = MarketingAssistant(api_key="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") + try: + content_type = assistant.classify_query(query) + print(f"Query classified as: {content_type}") + content = assistant.generate_content(content_type, query = query, topic="Time Line Therapy®", style="casual", platform="Instagram") + print(f"Generated content for {content_type}:") + print(content) + except ValueError as e: + print(e) + exit(1) + # email = assistant.generate_content("social", topic="Time Line Therapy®", style="casual", platform="Instagram") + # social = assistant.generate_content("social", topic="Time Line Therapy®", style="casual", platform="Instagram") + # book = assistant.generate_content("book", topic="Time Line Therapy®", style="formal") + # article = assistant.generate_content("article", topic="Time Line Therapy®", style="casual") + # print("Generated content for book:") + # print(book) + # print("Generated content for social media post:") + # print(social) + # print("Generated content for article:") + # print(article) + # print("Generated content for email:") + # print(email) \ No newline at end of file