marketing assistant added

This commit is contained in:
kowshik24
2025-02-06 03:20:41 +06:00
commit 480f6f06c2
7 changed files with 246 additions and 0 deletions
+42
View File
@@ -0,0 +1,42 @@
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from chroma_manager import ChromaManager
from config import RERANKER_NAME, GROQ_MODEL, GROQ_API_KEY
class RAGSystem:
def __init__(self):
self.chroma_manager = ChromaManager()
self.llm = ChatGroq(
temperature=0.01,
groq_api_key=GROQ_API_KEY,
model_name=GROQ_MODEL
)
self._init_retriever()
self._init_chain()
def _init_retriever(self):
model = HuggingFaceCrossEncoder(model_name=RERANKER_NAME)
reranker = CrossEncoderReranker(model=model, top_n=5)
self.retriever = ContextualCompressionRetriever(
base_compressor=reranker,
base_retriever=self.chroma_manager.vector_store.as_retriever(search_kwargs={"k": 10})
)
def _init_chain(self):
template = """...""" # Your existing template here
prompt = ChatPromptTemplate.from_template(template)
self.rag_chain = (
{"context": self.retriever, "question": RunnablePassthrough()}
| prompt
| self.llm
| StrOutputParser()
)
def query(self, question: str):
return self.rag_chain.invoke(question)