marketing assistant added
This commit is contained in:
@@ -0,0 +1,55 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from langchain_chroma import Chroma
|
||||||
|
import uuid
|
||||||
|
import chromadb
|
||||||
|
from config import CHROMA_PATH, COLLECTION_NAME, MODEL_NAME
|
||||||
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
|
|
||||||
|
class ChromaManager:
|
||||||
|
def __init__(self):
|
||||||
|
self.embed_model = HuggingFaceEmbeddings(
|
||||||
|
model_name=MODEL_NAME,
|
||||||
|
encode_kwargs={'normalize_embeddings': True}
|
||||||
|
)
|
||||||
|
self.vector_store = Chroma(
|
||||||
|
collection_name=COLLECTION_NAME,
|
||||||
|
persist_directory=str(CHROMA_PATH),
|
||||||
|
embedding_function=self.embed_model
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_collection_info(self):
|
||||||
|
return {
|
||||||
|
"document_count": self.vector_store._collection.count(),
|
||||||
|
"collection_name": COLLECTION_NAME
|
||||||
|
}
|
||||||
|
|
||||||
|
def add_documents(self, documents: List[Document]):
|
||||||
|
try:
|
||||||
|
ids = [str(uuid.uuid4()) for _ in documents]
|
||||||
|
texts = [doc.page_content for doc in documents]
|
||||||
|
metadatas = [doc.metadata for doc in documents]
|
||||||
|
|
||||||
|
self.vector_store.add_texts(
|
||||||
|
texts=texts,
|
||||||
|
metadatas=metadatas,
|
||||||
|
ids=ids
|
||||||
|
)
|
||||||
|
return ids
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error adding documents: {str(e)}")
|
||||||
|
|
||||||
|
def delete_document(self, doc_id: str):
|
||||||
|
try:
|
||||||
|
self.vector_store._collection.delete(ids=[doc_id])
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Delete error: {str(e)}")
|
||||||
|
|
||||||
|
def update_document(self, doc_id: str, new_content: str, metadata: dict):
|
||||||
|
try:
|
||||||
|
new_doc = Document(page_content=new_content, metadata=metadata)
|
||||||
|
self.delete_document(doc_id)
|
||||||
|
return self.add_documents([new_doc])[0]
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Update error: {str(e)}")
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Base directory
|
||||||
|
BASE_DIR = Path(__file__).parent.parent
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
MODEL_NAME = "BAAI/bge-large-en-v1.5"
|
||||||
|
RERANKER_NAME = "BAAI/bge-reranker-base"
|
||||||
|
GROQ_MODEL = "llama-3.3-70b-versatile"
|
||||||
|
DOCS_PATH = BASE_DIR / "client_assets"
|
||||||
|
CHROMA_PATH = BASE_DIR / "chroma_index"
|
||||||
|
COLLECTION_NAME = "marketing_docs"
|
||||||
|
|
||||||
|
# Create directories if they don't exist
|
||||||
|
DOCS_PATH.mkdir(exist_ok=True)
|
||||||
|
CHROMA_PATH.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# Groq API Key (Set through environment variable)
|
||||||
|
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
||||||
@@ -0,0 +1,87 @@
|
|||||||
|
from fastapi import FastAPI, UploadFile, File, HTTPException
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from schemas import DocumentUpload, DocumentUpdate, DocumentDelete, QueryRequest, ResponseSchema
|
||||||
|
from utils import process_uploaded_files, save_upload_file
|
||||||
|
from chroma_manager import ChromaManager
|
||||||
|
from rag import RAGSystem
|
||||||
|
from config import DOCS_PATH
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
app = FastAPI(title="Marketing Assistant AI")
|
||||||
|
|
||||||
|
# CORS Configuration
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize components
|
||||||
|
chroma_manager = ChromaManager()
|
||||||
|
rag_system = RAGSystem()
|
||||||
|
|
||||||
|
@app.post("/upload/", response_model=ResponseSchema)
|
||||||
|
async def upload_document(file: UploadFile = File(...)):
|
||||||
|
try:
|
||||||
|
# Save file
|
||||||
|
filename = f"{uuid.uuid4()}_{file.filename}"
|
||||||
|
save_upload_file(file, filename)
|
||||||
|
|
||||||
|
# Process and add to Chroma
|
||||||
|
documents = process_uploaded_files()
|
||||||
|
chroma_manager.add_documents(documents)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"result": "Documents processed successfully",
|
||||||
|
"collection_info": chroma_manager.get_collection_info()
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
@app.put("/update/", response_model=ResponseSchema)
|
||||||
|
async def update_document(update_data: DocumentUpdate):
|
||||||
|
try:
|
||||||
|
new_id = chroma_manager.update_document(
|
||||||
|
update_data.doc_id,
|
||||||
|
update_data.new_content,
|
||||||
|
update_data.metadata
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"result": f"Document updated with ID: {new_id}",
|
||||||
|
"collection_info": chroma_manager.get_collection_info()
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
@app.delete("/delete/", response_model=ResponseSchema)
|
||||||
|
async def delete_document(delete_data: DocumentDelete):
|
||||||
|
try:
|
||||||
|
chroma_manager.delete_document(delete_data.doc_id)
|
||||||
|
return {
|
||||||
|
"result": "Document deleted successfully",
|
||||||
|
"collection_info": chroma_manager.get_collection_info()
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
@app.post("/query/", response_model=ResponseSchema)
|
||||||
|
async def process_query(query: QueryRequest):
|
||||||
|
try:
|
||||||
|
response = rag_system.query(query.question)
|
||||||
|
return {
|
||||||
|
"result": response,
|
||||||
|
"collection_info": chroma_manager.get_collection_info()
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
@app.get("/collection-info/", response_model=ResponseSchema)
|
||||||
|
async def get_collection_info():
|
||||||
|
try:
|
||||||
|
return {
|
||||||
|
"result": "Current collection status",
|
||||||
|
"collection_info": chroma_manager.get_collection_info()
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
from langchain.retrievers import ContextualCompressionRetriever
|
||||||
|
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
||||||
|
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
||||||
|
from langchain_groq import ChatGroq
|
||||||
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
from langchain_core.output_parsers import StrOutputParser
|
||||||
|
from langchain_core.runnables import RunnablePassthrough
|
||||||
|
from chroma_manager import ChromaManager
|
||||||
|
from config import RERANKER_NAME, GROQ_MODEL, GROQ_API_KEY
|
||||||
|
|
||||||
|
class RAGSystem:
|
||||||
|
def __init__(self):
|
||||||
|
self.chroma_manager = ChromaManager()
|
||||||
|
self.llm = ChatGroq(
|
||||||
|
temperature=0.01,
|
||||||
|
groq_api_key=GROQ_API_KEY,
|
||||||
|
model_name=GROQ_MODEL
|
||||||
|
)
|
||||||
|
self._init_retriever()
|
||||||
|
self._init_chain()
|
||||||
|
|
||||||
|
def _init_retriever(self):
|
||||||
|
model = HuggingFaceCrossEncoder(model_name=RERANKER_NAME)
|
||||||
|
reranker = CrossEncoderReranker(model=model, top_n=5)
|
||||||
|
self.retriever = ContextualCompressionRetriever(
|
||||||
|
base_compressor=reranker,
|
||||||
|
base_retriever=self.chroma_manager.vector_store.as_retriever(search_kwargs={"k": 10})
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_chain(self):
|
||||||
|
template = """...""" # Your existing template here
|
||||||
|
|
||||||
|
prompt = ChatPromptTemplate.from_template(template)
|
||||||
|
self.rag_chain = (
|
||||||
|
{"context": self.retriever, "question": RunnablePassthrough()}
|
||||||
|
| prompt
|
||||||
|
| self.llm
|
||||||
|
| StrOutputParser()
|
||||||
|
)
|
||||||
|
|
||||||
|
def query(self, question: str):
|
||||||
|
return self.rag_chain.invoke(question)
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
class DocumentUpload(BaseModel):
|
||||||
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
class DocumentUpdate(BaseModel):
|
||||||
|
doc_id: str
|
||||||
|
new_content: str
|
||||||
|
metadata: dict
|
||||||
|
|
||||||
|
class DocumentDelete(BaseModel):
|
||||||
|
doc_id: str
|
||||||
|
|
||||||
|
class QueryRequest(BaseModel):
|
||||||
|
question: str
|
||||||
|
|
||||||
|
class ResponseSchema(BaseModel):
|
||||||
|
result: str
|
||||||
|
collection_info: dict
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
from langchain_community.document_loaders import DirectoryLoader
|
||||||
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from config import DOCS_PATH
|
||||||
|
|
||||||
|
def process_uploaded_files():
|
||||||
|
"""Process documents in the upload directory"""
|
||||||
|
loader = DirectoryLoader(DOCS_PATH, glob=["**/*.pdf", "**/*.txt"])
|
||||||
|
documents = loader.load()
|
||||||
|
|
||||||
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=1000,
|
||||||
|
chunk_overlap=200
|
||||||
|
)
|
||||||
|
return text_splitter.split_documents(documents)
|
||||||
|
|
||||||
|
def save_upload_file(file, filename: str):
|
||||||
|
"""Save uploaded file to documents directory"""
|
||||||
|
file_path = DOCS_PATH / filename
|
||||||
|
with open(file_path, "wb") as buffer:
|
||||||
|
buffer.write(file.file.read())
|
||||||
|
return file_path
|
||||||
Reference in New Issue
Block a user