marketing assistant added
This commit is contained in:
@@ -0,0 +1,42 @@
|
||||
from langchain.retrievers import ContextualCompressionRetriever
|
||||
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
||||
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
||||
from langchain_groq import ChatGroq
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
from chroma_manager import ChromaManager
|
||||
from config import RERANKER_NAME, GROQ_MODEL, GROQ_API_KEY
|
||||
|
||||
class RAGSystem:
|
||||
def __init__(self):
|
||||
self.chroma_manager = ChromaManager()
|
||||
self.llm = ChatGroq(
|
||||
temperature=0.01,
|
||||
groq_api_key=GROQ_API_KEY,
|
||||
model_name=GROQ_MODEL
|
||||
)
|
||||
self._init_retriever()
|
||||
self._init_chain()
|
||||
|
||||
def _init_retriever(self):
|
||||
model = HuggingFaceCrossEncoder(model_name=RERANKER_NAME)
|
||||
reranker = CrossEncoderReranker(model=model, top_n=5)
|
||||
self.retriever = ContextualCompressionRetriever(
|
||||
base_compressor=reranker,
|
||||
base_retriever=self.chroma_manager.vector_store.as_retriever(search_kwargs={"k": 10})
|
||||
)
|
||||
|
||||
def _init_chain(self):
|
||||
template = """...""" # Your existing template here
|
||||
|
||||
prompt = ChatPromptTemplate.from_template(template)
|
||||
self.rag_chain = (
|
||||
{"context": self.retriever, "question": RunnablePassthrough()}
|
||||
| prompt
|
||||
| self.llm
|
||||
| StrOutputParser()
|
||||
)
|
||||
|
||||
def query(self, question: str):
|
||||
return self.rag_chain.invoke(question)
|
||||
Reference in New Issue
Block a user