42 lines
1.6 KiB
Python
42 lines
1.6 KiB
Python
|
|
from langchain.retrievers import ContextualCompressionRetriever
|
||
|
|
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
||
|
|
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
||
|
|
from langchain_groq import ChatGroq
|
||
|
|
from langchain_core.prompts import ChatPromptTemplate
|
||
|
|
from langchain_core.output_parsers import StrOutputParser
|
||
|
|
from langchain_core.runnables import RunnablePassthrough
|
||
|
|
from chroma_manager import ChromaManager
|
||
|
|
from config import RERANKER_NAME, GROQ_MODEL, GROQ_API_KEY
|
||
|
|
|
||
|
|
class RAGSystem:
    """Retrieval-augmented question answering over a Chroma vector store.

    Candidate documents come from ``ChromaManager().vector_store``. They are
    re-ranked with a HuggingFace cross-encoder (``RERANKER_NAME``), and the
    surviving passages are given as context to a Groq chat model (``GROQ_MODEL``).
    """

    # Default prompt. It must expose the ``{context}`` and ``{question}``
    # placeholders that the chain fills in; a template without them would
    # silently drop the retrieved passages and/or the user's question.
    DEFAULT_TEMPLATE = (
        "Answer the question using only the context below. "
        "If the context does not contain the answer, say that you don't know.\n\n"
        "Context:\n{context}\n\n"
        "Question: {question}\n\n"
        "Answer:"
    )

    def __init__(
        self,
        *,
        k: int = 10,
        top_n: int = 5,
        temperature: float = 0.01,
        template: "str | None" = None,
    ):
        """Build the vector store, LLM client, retriever and RAG chain.

        Args:
            k: Number of candidates fetched from the vector store before
                re-ranking.
            top_n: Number of candidates kept after cross-encoder re-ranking.
            temperature: Sampling temperature passed to the Groq chat model.
            template: Prompt template containing ``{context}`` and
                ``{question}`` placeholders; ``None`` uses ``DEFAULT_TEMPLATE``.

        Raises:
            ValueError: If the prompt template lacks ``{context}`` or
                ``{question}``.
        """
        self.chroma_manager = ChromaManager()
        self.llm = ChatGroq(
            temperature=temperature,
            groq_api_key=GROQ_API_KEY,
            model_name=GROQ_MODEL,
        )
        self._init_retriever(k=k, top_n=top_n)
        self._init_chain(template=template)

    @staticmethod
    def _format_docs(docs) -> str:
        """Join the ``page_content`` of retrieved documents into one string.

        The retriever yields document objects; formatting the list directly
        into the prompt would insert its Python repr (metadata included)
        instead of the plain passage text.
        """
        return "\n\n".join(doc.page_content for doc in docs)

    def _init_retriever(self, k: int = 10, top_n: int = 5) -> None:
        """Create ``self.retriever``: fetch ``k`` vector hits, keep the best ``top_n``.

        Args:
            k: Candidates requested from the Chroma retriever.
            top_n: Candidates retained by the cross-encoder reranker.
        """
        cross_encoder = HuggingFaceCrossEncoder(model_name=RERANKER_NAME)
        reranker = CrossEncoderReranker(model=cross_encoder, top_n=top_n)
        base_retriever = self.chroma_manager.vector_store.as_retriever(
            search_kwargs={"k": k}
        )
        self.retriever = ContextualCompressionRetriever(
            base_compressor=reranker,
            base_retriever=base_retriever,
        )

    def _init_chain(self, template: "str | None" = None) -> None:
        """Create ``self.rag_chain``: retrieve -> format -> prompt -> LLM -> text.

        Args:
            template: Prompt template; ``None`` uses ``DEFAULT_TEMPLATE``.

        Raises:
            ValueError: If the template does not use both ``{context}`` and
                ``{question}``.
        """
        if template is None:
            template = self.DEFAULT_TEMPLATE
        prompt = ChatPromptTemplate.from_template(template)

        # Fail fast on a misconfigured template instead of answering without
        # the retrieved context or without the user's question.
        missing = {"context", "question"} - set(prompt.input_variables)
        if missing:
            raise ValueError(
                "prompt template must contain {context} and {question}; "
                f"missing: {sorted(missing)}"
            )

        self.rag_chain = (
            {
                # Retrieved documents are flattened to plain text before
                # being substituted into the prompt.
                "context": self.retriever | self._format_docs,
                # The raw question string is forwarded unchanged.
                "question": RunnablePassthrough(),
            }
            | prompt
            | self.llm
            | StrOutputParser()
        )

    def query(self, question: str) -> str:
        """Answer ``question`` using retrieved context; returns the model's text."""
        return self.rag_chain.invoke(question)
|