marketing assistant added
This commit is contained in:
@@ -0,0 +1,55 @@
|
||||
from typing import List, Optional
|
||||
from langchain_core.documents import Document
|
||||
from langchain_chroma import Chroma
|
||||
import uuid
|
||||
import chromadb
|
||||
from config import CHROMA_PATH, COLLECTION_NAME, MODEL_NAME
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
|
||||
class ChromaManager:
    """Manage the persistent Chroma collection used for document retrieval.

    Wraps a ``Chroma`` vector store configured from the project ``config``
    module and exposes add / delete / update helpers. Failures in those
    helpers are reported as ``ValueError`` with the original exception
    attached as ``__cause__``.
    """

    def __init__(self):
        """Load the embedding model and open the persistent collection."""
        # Ask the encoder to normalize the embeddings it produces.
        self.embed_model = HuggingFaceEmbeddings(
            model_name=MODEL_NAME,
            encode_kwargs={'normalize_embeddings': True}
        )
        self.vector_store = Chroma(
            collection_name=COLLECTION_NAME,
            persist_directory=str(CHROMA_PATH),
            embedding_function=self.embed_model
        )

    def get_collection_info(self):
        """Return the entry count and name of the collection as a dict."""
        return {
            "document_count": self.vector_store._collection.count(),
            "collection_name": COLLECTION_NAME
        }

    def add_documents(self, documents: "List[Document]"):
        """Store ``documents`` in the collection under freshly generated IDs.

        Args:
            documents: Objects exposing ``page_content`` and ``metadata``.

        Returns:
            The list of generated UUID strings, one per document and in the
            same order. An empty input returns ``[]`` without touching the
            vector store.

        Raises:
            ValueError: If the documents cannot be read or stored.
        """
        try:
            docs = list(documents)
            if not docs:
                # Nothing to index; avoid sending an empty batch to the store.
                return []

            ids = [str(uuid.uuid4()) for _ in docs]
            texts = [doc.page_content for doc in docs]
            metadatas = [doc.metadata for doc in docs]

            self.vector_store.add_texts(
                texts=texts,
                metadatas=metadatas,
                ids=ids
            )
            return ids
        except Exception as e:
            raise ValueError(f"Error adding documents: {str(e)}") from e

    def delete_document(self, doc_id: str):
        """Delete the entry stored under ``doc_id``.

        Returns:
            ``True`` when the delete call completes.

        Raises:
            ValueError: If the underlying delete call fails.
        """
        try:
            self.vector_store._collection.delete(ids=[doc_id])
            return True
        except Exception as e:
            raise ValueError(f"Delete error: {str(e)}") from e

    def update_document(self, doc_id: str, new_content: str, metadata: dict):
        """Replace the entry ``doc_id`` with new content and metadata.

        The replacement is stored first and the old entry is removed only
        afterwards, so a failure while adding leaves the existing entry
        untouched instead of losing it.

        Returns:
            The ID generated for the replacement entry.

        Raises:
            ValueError: If adding the replacement or deleting the old entry
                fails.
        """
        try:
            new_doc = Document(page_content=new_content, metadata=metadata)
            new_id = self.add_documents([new_doc])[0]
            self.delete_document(doc_id)
            return new_id
        except Exception as e:
            raise ValueError(f"Update error: {str(e)}") from e
|
||||
@@ -0,0 +1,20 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Project root: the directory two levels above this module.
BASE_DIR = Path(__file__).parent.parent

# Model identifiers
MODEL_NAME = "BAAI/bge-large-en-v1.5"
RERANKER_NAME = "BAAI/bge-reranker-base"
GROQ_MODEL = "llama-3.3-70b-versatile"

# Storage locations and the vector-store collection name
DOCS_PATH = BASE_DIR.joinpath("client_assets")
CHROMA_PATH = BASE_DIR.joinpath("chroma_index")
COLLECTION_NAME = "marketing_docs"

# Make sure both storage directories exist as soon as this module is imported.
for _storage_dir in (DOCS_PATH, CHROMA_PATH):
    _storage_dir.mkdir(exist_ok=True)
del _storage_dir

# Groq API key, read from the environment (None when the variable is unset).
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
|
||||
@@ -0,0 +1,87 @@
|
||||
from fastapi import FastAPI, UploadFile, File, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from schemas import DocumentUpload, DocumentUpdate, DocumentDelete, QueryRequest, ResponseSchema
|
||||
from utils import process_uploaded_files, save_upload_file
|
||||
from chroma_manager import ChromaManager
|
||||
from rag import RAGSystem
|
||||
from config import DOCS_PATH
|
||||
import uuid
|
||||
|
||||
# ASGI application object on which the routes below are registered.
app = FastAPI(title="Marketing Assistant AI")

# CORS: accept requests from any origin, using any method and any header.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Service objects, built once when this module is imported.
chroma_manager = ChromaManager()
rag_system = RAGSystem()
|
||||
|
||||
@app.post("/upload/", response_model=ResponseSchema)
async def upload_document(file: UploadFile = File(...)):
    """Save an uploaded file and index its contents in the vector store.

    Args:
        file: The multipart upload to store and index.

    Returns:
        A dict with a status message and the current collection info.

    Raises:
        HTTPException: 400 when no indexable chunks were produced for the
            uploaded file; 500 for any other failure.
    """
    try:
        # Keep only the last path component of the client-supplied name so it
        # cannot address a location outside the upload directory.
        client_name = (file.filename or "").replace("\\", "/").rsplit("/", 1)[-1]
        filename = f"{uuid.uuid4()}_{client_name}"
        save_upload_file(file, filename)

        # process_uploaded_files() scans the whole upload directory. Keep only
        # the chunks that came from the file saved above; otherwise every
        # earlier upload would be indexed again on each request.
        # NOTE(review): relies on the loader recording the file path under
        # metadata["source"] - confirm for the configured loader.
        documents = [
            doc for doc in process_uploaded_files()
            if str(doc.metadata.get("source", "")).endswith(filename)
        ]
        if not documents:
            raise HTTPException(
                status_code=400,
                detail="No indexable content found in the uploaded file",
            )
        chroma_manager.add_documents(documents)

        return {
            "result": "Documents processed successfully",
            "collection_info": chroma_manager.get_collection_info()
        }
    except HTTPException:
        # Deliberate HTTP errors keep their own status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
@app.put("/update/", response_model=ResponseSchema)
async def update_document(update_data: DocumentUpdate):
    """Replace a stored document with new content and metadata.

    Args:
        update_data: Request body carrying ``doc_id``, ``new_content`` and
            ``metadata``.

    Returns:
        A dict whose ``result`` names the ID of the replacement entry, plus
        the current collection info.

    Raises:
        HTTPException: 500 with the error text when the update fails.
    """
    try:
        new_id = chroma_manager.update_document(
            update_data.doc_id,
            update_data.new_content,
            update_data.metadata
        )
        return {
            "result": f"Document updated with ID: {new_id}",
            "collection_info": chroma_manager.get_collection_info()
        }
    except Exception as e:
        # Chain the cause so the original traceback is kept for server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
@app.delete("/delete/", response_model=ResponseSchema)
async def delete_document(delete_data: DocumentDelete):
    """Delete the stored document identified by ``delete_data.doc_id``.

    Returns:
        A dict with a confirmation message and the current collection info.

    Raises:
        HTTPException: 500 with the error text when the delete fails.
    """
    try:
        chroma_manager.delete_document(delete_data.doc_id)
        return {
            "result": "Document deleted successfully",
            "collection_info": chroma_manager.get_collection_info()
        }
    except Exception as e:
        # Chain the cause so the original traceback is kept for server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
@app.post("/query/", response_model=ResponseSchema)
async def process_query(query: QueryRequest):
    """Answer ``query.question`` using the RAG system.

    Returns:
        A dict with the generated answer under ``result`` and the current
        collection info.

    Raises:
        HTTPException: 500 with the error text when answering fails.
    """
    try:
        response = rag_system.query(query.question)
        return {
            "result": response,
            "collection_info": chroma_manager.get_collection_info()
        }
    except Exception as e:
        # Chain the cause so the original traceback is kept for server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
@app.get("/collection-info/", response_model=ResponseSchema)
async def get_collection_info():
    """Report the current state of the vector-store collection.

    Returns:
        A dict with a fixed status message and the current collection info.

    Raises:
        HTTPException: 500 with the error text when the lookup fails.
    """
    try:
        return {
            "result": "Current collection status",
            "collection_info": chroma_manager.get_collection_info()
        }
    except Exception as e:
        # Chain the cause so the original traceback is kept for server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
@@ -0,0 +1,42 @@
|
||||
from langchain.retrievers import ContextualCompressionRetriever
|
||||
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
||||
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
||||
from langchain_groq import ChatGroq
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
from chroma_manager import ChromaManager
|
||||
from config import RERANKER_NAME, GROQ_MODEL, GROQ_API_KEY
|
||||
|
||||
class RAGSystem:
    """Retrieval-augmented generation pipeline: retrieve, rerank, prompt, generate."""

    def __init__(self, chroma_manager=None):
        """Build the LLM client, the reranking retriever and the chain.

        Args:
            chroma_manager: Optional existing ``ChromaManager`` to retrieve
                from. When omitted a new one is created, exactly as before.
                Passing one lets the caller reuse a manager it already holds
                instead of loading a second embedding model and opening a
                second vector-store client.
        """
        if chroma_manager is None:
            chroma_manager = ChromaManager()
        self.chroma_manager = chroma_manager
        self.llm = ChatGroq(
            temperature=0.01,
            groq_api_key=GROQ_API_KEY,
            model_name=GROQ_MODEL
        )
        self._init_retriever()
        self._init_chain()

    def _init_retriever(self):
        """Create a retriever that fetches 10 candidates and reranks to the top 5."""
        model = HuggingFaceCrossEncoder(model_name=RERANKER_NAME)
        reranker = CrossEncoderReranker(model=model, top_n=5)
        self.retriever = ContextualCompressionRetriever(
            base_compressor=reranker,
            base_retriever=self.chroma_manager.vector_store.as_retriever(search_kwargs={"k": 10})
        )

    def _init_chain(self):
        """Assemble the retriever -> prompt -> LLM -> string-output chain."""
        # NOTE(review): the template below is a placeholder. It contains no
        # {context} or {question} fields, so the values produced by the first
        # chain step are not inserted into the prompt until the real template
        # is put in place.
        template = """..."""

        prompt = ChatPromptTemplate.from_template(template)
        self.rag_chain = (
            {"context": self.retriever, "question": RunnablePassthrough()}
            | prompt
            | self.llm
            | StrOutputParser()
        )

    def query(self, question: str):
        """Run ``question`` through the chain and return its output."""
        return self.rag_chain.invoke(question)
|
||||
@@ -0,0 +1,20 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
|
||||
# Optional information supplied together with an uploaded document.
class DocumentUpload(BaseModel):
    # Free-text description; may be omitted.
    description: Optional[str] = None
|
||||
|
||||
# Request body for replacing a stored document.
class DocumentUpdate(BaseModel):
    # ID of the entry to replace.
    doc_id: str
    # Replacement text.
    new_content: str
    # Metadata stored with the replacement.
    metadata: dict
|
||||
|
||||
# Request body for deleting a stored document.
class DocumentDelete(BaseModel):
    # ID of the entry to delete.
    doc_id: str
|
||||
|
||||
# Request body for asking a question.
class QueryRequest(BaseModel):
    # The question text.
    question: str
|
||||
|
||||
# Common response envelope: a result string plus collection details.
class ResponseSchema(BaseModel):
    # Outcome message or generated answer.
    result: str
    # Details about the collection, as a dict.
    collection_info: dict
|
||||
@@ -0,0 +1,22 @@
|
||||
from langchain_community.document_loaders import DirectoryLoader
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from langchain_core.documents import Document
|
||||
from config import DOCS_PATH
|
||||
|
||||
def process_uploaded_files(docs_path=None, chunk_size=1000, chunk_overlap=200):
    """Load the PDF and text files under a directory and split them into chunks.

    Args:
        docs_path: Directory to scan. Defaults to ``DOCS_PATH`` when omitted.
        chunk_size: ``chunk_size`` forwarded to the text splitter.
        chunk_overlap: ``chunk_overlap`` forwarded to the text splitter.

    Returns:
        The chunks produced by ``split_documents`` for every loaded file.
    """
    directory = DOCS_PATH if docs_path is None else docs_path
    loader = DirectoryLoader(directory, glob=["**/*.pdf", "**/*.txt"])
    documents = loader.load()

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap
    )
    return text_splitter.split_documents(documents)
|
||||
|
||||
def save_upload_file(file, filename: str):
    """Write an uploaded file into the documents directory.

    Args:
        file: Upload object whose ``file`` attribute is a readable binary
            stream.
        filename: Name to store the file under, relative to ``DOCS_PATH``.

    Returns:
        The path the file was written to (``DOCS_PATH / filename``).

    Raises:
        ValueError: If ``filename`` would resolve to a location outside the
            documents directory.
    """
    import shutil  # local import: only needed by this function

    file_path = DOCS_PATH / filename

    # Refuse names such as "../x" or absolute paths that would escape the
    # documents directory once resolved.
    if DOCS_PATH.resolve() not in file_path.resolve().parents:
        raise ValueError(f"Invalid upload filename: {filename!r}")

    with open(file_path, "wb") as buffer:
        # Copy in chunks instead of reading the whole upload into memory.
        shutil.copyfileobj(file.file, buffer)
    return file_path
|
||||
Reference in New Issue
Block a user