2025-02-08 02:22:34 +06:00
from typing import List , Dict
import requests
2025-02-06 03:20:41 +06:00
from langchain_groq import ChatGroq
from langchain_core . prompts import ChatPromptTemplate
from langchain_core . output_parsers import StrOutputParser
2025-02-08 02:22:34 +06:00
from chromadb . api . types import Documents , EmbeddingFunction
from config import (
MODEL_NAME , RERANKER_NAME , API_KEY ,
SERVER_URL , GROQ_API_KEY , GROQ_MODEL
)
2025-02-06 03:20:41 +06:00
2025-02-08 02:22:34 +06:00
class CustomEmbeddingFunction(EmbeddingFunction):
    """Embedding function backed by a remote HTTP ``/embeddings`` endpoint.

    Each call POSTs the documents to ``{SERVER_URL}/embeddings`` with a bearer
    token taken from ``API_KEY`` and returns the ``embedding`` field of every
    item in the response's ``data`` list.
    """

    def __init__(self, model_name: str, timeout: float = 60.0):
        """Store the model name and the connection settings.

        Args:
            model_name: Model identifier sent in the request payload.
            timeout: Seconds to wait for the server before the request is
                aborted.
        """
        self.model_name = model_name
        self._api_key = API_KEY
        self._server_url = SERVER_URL
        self._timeout = timeout

    def __call__(self, input: Documents) -> List[List[float]]:
        """Embed ``input`` and return one vector per returned item.

        Args:
            input: Documents to embed. An empty value short-circuits to ``[]``
                without contacting the server.

        Returns:
            The ``embedding`` entries of the response's ``data`` list.

        Raises:
            Exception: Any failure while sending the request or reading the
                response is printed and then re-raised unchanged.
        """
        if not input:
            return []

        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model_name,
            "input": input,
        }

        try:
            response = requests.post(
                f"{self._server_url}/embeddings",
                json=payload,
                headers=headers,
                # requests never times out by default; without this a stalled
                # server would block the caller indefinitely.
                timeout=self._timeout,
            )
            response.raise_for_status()
            return [item['embedding'] for item in response.json()['data']]
        except Exception as e:
            print(f"Error in embedding: {str(e)}")
            raise
2025-02-07 19:24:57 +06:00
2025-02-08 02:22:34 +06:00
class CustomReranker:
    """Client for a remote HTTP ``/rerank`` endpoint."""

    def __init__(self, model_name: str, timeout: float = 60.0):
        """Store the model name and the connection settings.

        Args:
            model_name: Model identifier sent in the request payload.
            timeout: Seconds to wait for the server before the request is
                aborted.
        """
        self.model_name = model_name
        self._api_key = API_KEY
        self._server_url = SERVER_URL
        self._timeout = timeout

    def rerank(self, query: str, documents: List[Dict], top_k: int = 5) -> List[Dict]:
        """
        Rerank documents using the reranking model.

        Args:
            query: Query text the documents are scored against.
            documents: Dicts that each carry a ``'content'`` key; only the
                content is sent to the server.
            top_k: Maximum number of documents to return.

        Returns:
            Shallow copies of the best ``top_k`` documents, each with an added
            ``'relevance_score'`` key. If the request or the response handling
            fails, the first ``top_k`` documents in their original order are
            returned instead (without a score).
        """
        if not documents:
            return []

        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model_name,
            "query": query,
            "documents": [doc['content'] for doc in documents],
        }

        try:
            response = requests.post(
                f"{self._server_url}/rerank",
                json=payload,
                headers=headers,
                # requests never times out by default; without this a stalled
                # server would block the caller indefinitely.
                timeout=self._timeout,
            )
            response.raise_for_status()

            # Get reranked results
            reranked_results = response.json()['results']

            # Sort explicitly by score so the top_k slice is correct even if
            # the server does not return results best-first. The sort is
            # stable, so an already ordered response is left untouched.
            ranked = sorted(
                reranked_results,
                key=lambda result: result['relevance_score'],
                reverse=True,
            )

            reranked_docs = []
            for result in ranked[:top_k]:
                # Copy so the caller's document dicts are not mutated.
                doc = documents[result['index']].copy()
                doc['relevance_score'] = result['relevance_score']
                reranked_docs.append(doc)
            return reranked_docs
        except Exception as e:
            # Deliberate best-effort: any failure degrades to the original
            # ordering, still honouring top_k like the success path does.
            print(f"Error in reranking: {str(e)}")
            return documents[:top_k]
2025-02-07 19:24:57 +06:00
2025-02-08 02:22:34 +06:00
# Module-level instances, created once when this module is imported.
EMBED_FUNCTION = CustomEmbeddingFunction(
    model_name=MODEL_NAME,
)
RERANKER = CustomReranker(
    model_name=RERANKER_NAME,
)
LLM = ChatGroq(
    model_name=GROQ_MODEL,
    groq_api_key=GROQ_API_KEY,
    temperature=0.01,
)
2025-02-07 19:24:57 +06:00
2025-02-08 02:22:34 +06:00
def format_context(documents: List[Dict]) -> str:
    """Format retrieved documents into a context string.

    Every document is rendered as a ``[CATEGORY]`` header line (upper-cased)
    followed by its content and a newline; the rendered documents are then
    joined with a newline between them.

    Args:
        documents: Dicts with a ``'content'`` key and an optional
            ``'metadata'`` dict that may hold a ``'category'``.

    Returns:
        The combined context string, or ``""`` for an empty list.

    Raises:
        KeyError: If a document has no ``'content'`` key.
    """
    context_parts = []
    for doc in documents:
        # A document may carry no metadata at all (missing key or None);
        # treat that the same as metadata without a category.
        metadata = doc.get('metadata') or {}
        category = metadata.get('category', 'unknown')
        if category is None:
            category = 'unknown'
        # str() keeps non-string categories (e.g. ints) from breaking upper().
        header = str(category).upper()
        context_parts.append(f"[{header}]\n{doc['content']}\n")
    return "\n".join(context_parts)
2025-02-07 19:24:57 +06:00
2025-02-08 02:22:34 +06:00
# Template for marketing copy. {question} and {context} are the placeholders
# substituted when the prompt is rendered.
TEMPLATE = """
Act like you are Adriana James, write marketing copy in her signature style. Just mimic her style and provide the answer to the user's query. Make sure that you are Adriana James, and you are providing the answer to the user's query.

Query: {question}
Adriana James Resource Context: {context}

Note: Don't provide anything extra. Just give me the response no extra words nothing at all.
"""
2025-02-07 19:24:57 +06:00
2025-02-08 02:22:34 +06:00
# Chat prompt built once at import time from the TEMPLATE string.
PROMPT = ChatPromptTemplate.from_template(template=TEMPLATE)
2025-02-07 19:24:57 +06:00
2025-02-08 02:22:34 +06:00
def generate_marketing_response(query: str, context: str) -> str:
    """Produce the marketing answer for ``query`` given retrieved ``context``.

    Pipes the module-level ``PROMPT`` into ``LLM`` and a ``StrOutputParser``,
    then invokes that chain once.

    Args:
        query: Value supplied for the prompt's ``question`` variable.
        context: Value supplied for the prompt's ``context`` variable.

    Returns:
        Whatever the chain yields after the string output parser.
    """
    prompt_inputs = {"question": query, "context": context}
    pipeline = PROMPT | LLM | StrOutputParser()
    return pipeline.invoke(prompt_inputs)