From 480f6f06c23c462e9e8f726d7db4307f1cb2a1fd Mon Sep 17 00:00:00 2001
From: kowshik24
Date: Thu, 6 Feb 2025 03:20:41 +0600
Subject: [PATCH] marketing assistant added

---
Revised during review. The blob ids on the index lines are those of the
original submission and were not recomputed.

 src/marketing_assistant_ai/__init__.py       |   0
 src/marketing_assistant_ai/chroma_manager.py |  83 ++++++++++++++++
 src/marketing_assistant_ai/config.py         |  21 ++++
 src/marketing_assistant_ai/main.py           | 103 ++++++++++++++++++++
 src/marketing_assistant_ai/rag.py            |  72 ++++++++++++++
 src/marketing_assistant_ai/schemas.py        |  31 ++++++
 src/marketing_assistant_ai/utils.py          |  55 +++++++++++
 7 files changed, 365 insertions(+)
 create mode 100644 src/marketing_assistant_ai/__init__.py
 create mode 100644 src/marketing_assistant_ai/chroma_manager.py
 create mode 100644 src/marketing_assistant_ai/config.py
 create mode 100644 src/marketing_assistant_ai/main.py
 create mode 100644 src/marketing_assistant_ai/rag.py
 create mode 100644 src/marketing_assistant_ai/schemas.py
 create mode 100644 src/marketing_assistant_ai/utils.py

diff --git a/src/marketing_assistant_ai/__init__.py b/src/marketing_assistant_ai/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/marketing_assistant_ai/chroma_manager.py b/src/marketing_assistant_ai/chroma_manager.py
new file mode 100644
index 0000000..8f8c6c8
--- /dev/null
+++ b/src/marketing_assistant_ai/chroma_manager.py
@@ -0,0 +1,83 @@
+"""Thin wrapper around the persistent Chroma collection used by the assistant."""
+from typing import List, Optional
+from langchain_core.documents import Document
+from langchain_chroma import Chroma
+import uuid
+import chromadb
+from config import CHROMA_PATH, COLLECTION_NAME, MODEL_NAME
+from langchain_huggingface import HuggingFaceEmbeddings
+
+
+class ChromaManager:
+    """Create, inspect and mutate the vector store backing the RAG pipeline."""
+
+    def __init__(self):
+        """Build the embedding model and open the collection under CHROMA_PATH."""
+        self.embed_model = HuggingFaceEmbeddings(
+            model_name=MODEL_NAME,
+            encode_kwargs={'normalize_embeddings': True}
+        )
+        self.vector_store = Chroma(
+            collection_name=COLLECTION_NAME,
+            persist_directory=str(CHROMA_PATH),
+            embedding_function=self.embed_model
+        )
+
+    def get_collection_info(self) -> dict:
+        """Return the stored entry count and the collection name."""
+        return {
+            "document_count": self.vector_store._collection.count(),
+            "collection_name": COLLECTION_NAME
+        }
+
+    def add_documents(self, documents: List[Document]) -> List[str]:
+        """Store ``documents`` under fresh UUIDs and return those ids.
+
+        Raises:
+            ValueError: if the vector store rejects the insert.
+        """
+        if not documents:
+            # Nothing to embed; skip the store round-trip entirely.
+            return []
+        try:
+            ids = [str(uuid.uuid4()) for _ in documents]
+            texts = [doc.page_content for doc in documents]
+            metadatas = [doc.metadata for doc in documents]
+
+            self.vector_store.add_texts(
+                texts=texts,
+                metadatas=metadatas,
+                ids=ids
+            )
+            return ids
+        except Exception as e:
+            raise ValueError(f"Error adding documents: {str(e)}") from e
+
+    def delete_document(self, doc_id: str) -> bool:
+        """Delete the entry stored under ``doc_id`` and return True.
+
+        Raises:
+            ValueError: if the underlying delete call fails.
+        """
+        try:
+            self.vector_store._collection.delete(ids=[doc_id])
+            return True
+        except Exception as e:
+            raise ValueError(f"Delete error: {str(e)}") from e
+
+    def update_document(self, doc_id: str, new_content: str, metadata: dict) -> str:
+        """Replace the entry ``doc_id`` with new content and return the new id.
+
+        The replacement is written before the old entry is removed, so a
+        failed insert leaves the original entry in place instead of losing it.
+
+        Raises:
+            ValueError: if either the insert or the delete fails.
+        """
+        try:
+            new_doc = Document(page_content=new_content, metadata=metadata)
+            new_id = self.add_documents([new_doc])[0]
+            self.delete_document(doc_id)
+            return new_id
+        except Exception as e:
+            raise ValueError(f"Update error: {str(e)}") from e
\ No newline at end of file
diff --git a/src/marketing_assistant_ai/config.py b/src/marketing_assistant_ai/config.py
new file mode 100644
index 0000000..14e56af
--- /dev/null
+++ b/src/marketing_assistant_ai/config.py
@@ -0,0 +1,21 @@
+"""Central configuration: model names, storage paths and API credentials."""
+import os
+from pathlib import Path
+
+# Base directory
+BASE_DIR = Path(__file__).parent.parent
+
+# Configuration
+MODEL_NAME = "BAAI/bge-large-en-v1.5"
+RERANKER_NAME = "BAAI/bge-reranker-base"
+GROQ_MODEL = "llama-3.3-70b-versatile"
+DOCS_PATH = BASE_DIR / "client_assets"
+CHROMA_PATH = BASE_DIR / "chroma_index"
+COLLECTION_NAME = "marketing_docs"
+
+# Create directories if they don't exist
+DOCS_PATH.mkdir(parents=True, exist_ok=True)
+CHROMA_PATH.mkdir(parents=True, exist_ok=True)
+
+# Groq API Key (Set through environment variable)
+GROQ_API_KEY = os.getenv("GROQ_API_KEY")
\ No newline at end of file
diff --git a/src/marketing_assistant_ai/main.py b/src/marketing_assistant_ai/main.py
new file mode 100644
index 0000000..5897935
--- /dev/null
+++ b/src/marketing_assistant_ai/main.py
@@ -0,0 +1,103 @@
+"""FastAPI application exposing upload, update, delete and query endpoints."""
+from fastapi import FastAPI, UploadFile, File, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from schemas import DocumentUpload, DocumentUpdate, DocumentDelete, QueryRequest, ResponseSchema
+from utils import process_uploaded_files, save_upload_file
+from chroma_manager import ChromaManager
+from rag import RAGSystem
+from config import DOCS_PATH
+from pathlib import Path
+import uuid
+
+app = FastAPI(title="Marketing Assistant AI")
+
+# CORS Configuration
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Initialize components. The RAG system reuses the manager so the embedding
+# model is loaded once and both share one vector-store handle.
+chroma_manager = ChromaManager()
+rag_system = RAGSystem(chroma_manager)
+
+
+@app.post("/upload/", response_model=ResponseSchema)
+async def upload_document(file: UploadFile = File(...)):
+    """Save the uploaded file and index only that file."""
+    try:
+        # Keep just the final path component of the client-supplied name so
+        # it cannot point outside the documents directory.
+        safe_name = Path(file.filename or "upload").name
+        filename = f"{uuid.uuid4()}_{safe_name}"
+        saved_path = save_upload_file(file, filename)
+
+        # Index the new file only; re-reading the whole directory on every
+        # upload would store every earlier document again.
+        documents = process_uploaded_files(saved_path.name)
+        chroma_manager.add_documents(documents)
+
+        return {
+            "result": "Documents processed successfully",
+            "collection_info": chroma_manager.get_collection_info()
+        }
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e)) from e
+
+
+@app.put("/update/", response_model=ResponseSchema)
+async def update_document(update_data: DocumentUpdate):
+    """Replace a stored document and report the id of the replacement."""
+    try:
+        new_id = chroma_manager.update_document(
+            update_data.doc_id,
+            update_data.new_content,
+            update_data.metadata
+        )
+        return {
+            "result": f"Document updated with ID: {new_id}",
+            "collection_info": chroma_manager.get_collection_info()
+        }
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e)) from e
+
+
+@app.delete("/delete/", response_model=ResponseSchema)
+async def delete_document(delete_data: DocumentDelete):
+    """Delete a stored document by id."""
+    try:
+        chroma_manager.delete_document(delete_data.doc_id)
+        return {
+            "result": "Document deleted successfully",
+            "collection_info": chroma_manager.get_collection_info()
+        }
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e)) from e
+
+
+@app.post("/query/", response_model=ResponseSchema)
+async def process_query(query: QueryRequest):
+    """Answer a question through the RAG chain."""
+    try:
+        response = rag_system.query(query.question)
+        return {
+            "result": response,
+            "collection_info": chroma_manager.get_collection_info()
+        }
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e)) from e
+
+
+@app.get("/collection-info/", response_model=ResponseSchema)
+async def get_collection_info():
+    """Report the current collection name and entry count."""
+    try:
+        return {
+            "result": "Current collection status",
+            "collection_info": chroma_manager.get_collection_info()
+        }
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e)) from e
\ No newline at end of file
diff --git a/src/marketing_assistant_ai/rag.py b/src/marketing_assistant_ai/rag.py
new file mode 100644
index 0000000..7e0cb56
--- /dev/null
+++ b/src/marketing_assistant_ai/rag.py
@@ -0,0 +1,72 @@
+"""Retrieval-augmented generation chain: retrieve, rerank, prompt, answer."""
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain.retrievers.document_compressors import CrossEncoderReranker
+from langchain_community.cross_encoders import HuggingFaceCrossEncoder
+from langchain_groq import ChatGroq
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnablePassthrough
+from chroma_manager import ChromaManager
+from config import RERANKER_NAME, GROQ_MODEL, GROQ_API_KEY
+
+# Default prompt. It must reference {context} and {question}, the two keys
+# the chain supplies; replace the wording with the project's own template.
+PROMPT_TEMPLATE = """You are a marketing assistant.
+Answer the question using only the context below.
+If the context does not contain the answer, say that you do not know.
+
+Context:
+{context}
+
+Question: {question}
+"""
+
+
+def _format_docs(docs):
+    """Join the retrieved documents' text into one context string."""
+    return "\n\n".join(doc.page_content for doc in docs)
+
+
+class RAGSystem:
+    """Answer questions from the indexed documents with a Groq-hosted LLM."""
+
+    def __init__(self, chroma_manager=None):
+        """Wire up the LLM, the retriever and the chain.
+
+        Args:
+            chroma_manager: existing ChromaManager to reuse. When omitted a
+                new one is created, which loads the embedding model again.
+        """
+        if chroma_manager is None:
+            chroma_manager = ChromaManager()
+        self.chroma_manager = chroma_manager
+        self.llm = ChatGroq(
+            temperature=0.01,
+            groq_api_key=GROQ_API_KEY,
+            model_name=GROQ_MODEL
+        )
+        self._init_retriever()
+        self._init_chain()
+
+    def _init_retriever(self):
+        """Fetch 10 candidates from Chroma and keep the 5 best after reranking."""
+        model = HuggingFaceCrossEncoder(model_name=RERANKER_NAME)
+        reranker = CrossEncoderReranker(model=model, top_n=5)
+        self.retriever = ContextualCompressionRetriever(
+            base_compressor=reranker,
+            base_retriever=self.chroma_manager.vector_store.as_retriever(search_kwargs={"k": 10})
+        )
+
+    def _init_chain(self):
+        """Compose retriever, prompt, LLM and string parser into one chain."""
+        prompt = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
+        self.rag_chain = (
+            {"context": self.retriever | _format_docs, "question": RunnablePassthrough()}
+            | prompt
+            | self.llm
+            | StrOutputParser()
+        )
+
+    def query(self, question: str) -> str:
+        """Run ``question`` through the chain and return the answer text."""
+        return self.rag_chain.invoke(question)
\ No newline at end of file
diff --git a/src/marketing_assistant_ai/schemas.py b/src/marketing_assistant_ai/schemas.py
new file mode 100644
index 0000000..3a5aa9e
--- /dev/null
+++ b/src/marketing_assistant_ai/schemas.py
@@ -0,0 +1,31 @@
+"""Pydantic request and response models for the HTTP API."""
+from pydantic import BaseModel
+from typing import List, Optional
+
+
+class DocumentUpload(BaseModel):
+    """Optional metadata accompanying an upload."""
+    description: Optional[str] = None
+
+
+class DocumentUpdate(BaseModel):
+    """Payload for replacing a stored document."""
+    doc_id: str
+    new_content: str
+    metadata: dict
+
+
+class DocumentDelete(BaseModel):
+    """Payload identifying the document to delete."""
+    doc_id: str
+
+
+class QueryRequest(BaseModel):
+    """Payload carrying the user's question."""
+    question: str
+
+
+class ResponseSchema(BaseModel):
+    """Common response envelope returned by every endpoint."""
+    result: str
+    collection_info: dict
\ No newline at end of file
diff --git a/src/marketing_assistant_ai/utils.py b/src/marketing_assistant_ai/utils.py
new file mode 100644
index 0000000..d086f51
--- /dev/null
+++ b/src/marketing_assistant_ai/utils.py
@@ -0,0 +1,55 @@
+"""File helpers: persist uploads and turn stored files into text chunks."""
+import glob
+import shutil
+from pathlib import Path
+from typing import List, Optional
+from langchain_community.document_loaders import DirectoryLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_core.documents import Document
+from config import DOCS_PATH
+
+# Extensions picked up by the directory-wide scan; single-file processing
+# honours the same set so both paths index the same kinds of file.
+SUPPORTED_SUFFIXES = (".pdf", ".txt")
+
+
+def process_uploaded_files(filename: Optional[str] = None) -> List[Document]:
+    """Load documents from the upload directory and split them into chunks.
+
+    Args:
+        filename: name of one file inside DOCS_PATH to process. When omitted,
+            every PDF and text file under DOCS_PATH is processed.
+
+    Returns:
+        Chunks produced with chunk_size=1000 and chunk_overlap=200; an
+        empty list when ``filename`` has an unsupported extension.
+    """
+    if filename is None:
+        pattern = ["**/*.pdf", "**/*.txt"]
+    elif filename.endswith(SUPPORTED_SUFFIXES):
+        # Escape glob metacharacters so the name matches itself literally.
+        pattern = glob.escape(filename)
+    else:
+        return []
+
+    loader = DirectoryLoader(DOCS_PATH, glob=pattern)
+    documents = loader.load()
+
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=1000,
+        chunk_overlap=200
+    )
+    return text_splitter.split_documents(documents)
+
+
+def save_upload_file(file, filename: str) -> Path:
+    """Save an uploaded file into the documents directory and return its path.
+
+    Only the final component of ``filename`` is used, so the file is always
+    written directly inside DOCS_PATH.
+    """
+    file_path = DOCS_PATH / Path(filename).name
+    with open(file_path, "wb") as buffer:
+        # Stream in chunks instead of reading the whole upload into memory.
+        shutil.copyfileobj(file.file, buffer)
+    return file_path
\ No newline at end of file