132 lines
4.2 KiB
Python
132 lines
4.2 KiB
Python
from typing import List, Dict
|
|
import requests
|
|
from langchain_groq import ChatGroq
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain_core.output_parsers import StrOutputParser
|
|
from chromadb.api.types import Documents, EmbeddingFunction
|
|
from config import (
|
|
MODEL_NAME, RERANKER_NAME, API_KEY,
|
|
SERVER_URL, GROQ_API_KEY, GROQ_MODEL
|
|
)
|
|
|
|
class CustomEmbeddingFunction(EmbeddingFunction):
    """Embedding function backed by a remote HTTP embeddings endpoint.

    Documents are POSTed to ``{SERVER_URL}/embeddings`` with a bearer token
    taken from ``API_KEY``; both values are read from the project config.
    """

    def __init__(self, model_name: str, timeout: float = 60.0):
        """Store the model name and connection settings.

        Args:
            model_name: Model identifier sent in the request payload.
            timeout: Seconds to wait for the server before the request is
                aborted.
        """
        self.model_name = model_name
        self._api_key = API_KEY
        self._server_url = SERVER_URL
        self._timeout = timeout

    def __call__(self, input: Documents) -> List[List[float]]:
        """Return one embedding per document in ``input``.

        Args:
            input: Documents to embed. An empty value short-circuits to an
                empty list without contacting the server.

        Returns:
            The ``embedding`` field of every item in the response's ``data``
            list, in the order the server returned them.

        Raises:
            Exception: Any failure while sending the request or reading the
                response is printed and then re-raised unchanged.
        """
        if not input:
            return []

        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": self.model_name,
            "input": input
        }

        try:
            response = requests.post(
                f"{self._server_url}/embeddings",
                json=payload,
                headers=headers,
                # Without an explicit timeout the call can block forever on
                # a stalled server.
                timeout=self._timeout,
            )
            response.raise_for_status()
            return [item['embedding'] for item in response.json()['data']]
        except Exception as e:
            print(f"Error in embedding: {str(e)}")
            raise
|
|
|
|
class CustomReranker:
    """Client for a remote reranking endpoint.

    Requests are POSTed to ``{SERVER_URL}/rerank`` with a bearer token taken
    from ``API_KEY``; both values are read from the project config.
    """

    def __init__(self, model_name: str, timeout: float = 60.0):
        """Store the model name and connection settings.

        Args:
            model_name: Model identifier sent in the request payload.
            timeout: Seconds to wait for the server before the request is
                aborted.
        """
        self.model_name = model_name
        self._api_key = API_KEY
        self._server_url = SERVER_URL
        self._timeout = timeout

    def rerank(self, query: str, documents: List[Dict], top_k: int = 5) -> List[Dict]:
        """Rerank ``documents`` against ``query`` and keep the best ``top_k``.

        Args:
            query: Text the documents are scored against.
            documents: Dicts that each carry a ``content`` key.
            top_k: Maximum number of documents to return.

        Returns:
            Shallow copies of the highest-scoring documents, each with an
            added ``relevance_score`` key, best first. If the request or the
            response handling fails, the first ``top_k`` documents in their
            original order (without ``relevance_score``) are returned instead.
        """
        if not documents:
            return []

        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json"
        }

        payload = {
            "model": self.model_name,
            "query": query,
            "documents": [doc['content'] for doc in documents]
        }

        try:
            response = requests.post(
                f"{self._server_url}/rerank",
                json=payload,
                headers=headers,
                # Without an explicit timeout the call can block forever on
                # a stalled server.
                timeout=self._timeout,
            )
            response.raise_for_status()

            reranked_results = response.json()['results']

            # Sort explicitly instead of trusting the server's ordering. The
            # sort is stable, so an already-sorted response is left as is.
            ordered_results = sorted(
                reranked_results,
                key=lambda result: result['relevance_score'],
                reverse=True,
            )

            reranked_docs = []
            for result in ordered_results[:top_k]:
                # Copy so the caller's dicts are not mutated.
                doc = documents[result['index']].copy()
                doc['relevance_score'] = result['relevance_score']
                reranked_docs.append(doc)

            return reranked_docs

        except Exception as e:
            # Deliberate best-effort: a reranker failure must not break
            # retrieval, so degrade to the original ordering.
            print(f"Error in reranking: {str(e)}")
            # Honour top_k on the fallback path too.
            return documents[:top_k]
|
|
|
|
# Shared module-level instances, created once when the module is imported.
EMBED_FUNCTION = CustomEmbeddingFunction(model_name=MODEL_NAME)
RERANKER = CustomReranker(model_name=RERANKER_NAME)
LLM = ChatGroq(
    temperature=0.01,
    groq_api_key=GROQ_API_KEY,
    model_name=GROQ_MODEL,
)
|
|
|
|
def format_context(documents: List[Dict]) -> str:
    """Format retrieved documents into a single context string.

    Each document becomes ``[CATEGORY]\\n<content>\\n``; the pieces are joined
    with a newline, so consecutive documents are separated by a blank line.

    Args:
        documents: Dicts with a ``content`` key and, optionally, a
            ``metadata`` dict that may hold a ``category``.

    Returns:
        The formatted context, or an empty string when ``documents`` is empty.

    Raises:
        KeyError: If a document has no ``content`` key.
    """
    context_parts = []
    for doc in documents:
        # Treat missing or None metadata as empty rather than crashing.
        metadata = doc.get('metadata') or {}
        category = metadata.get('category', 'unknown')
        if category is None:
            category = 'unknown'
        content = doc['content']
        # str() keeps non-string categories (e.g. numbers) from breaking
        # the upper-casing below.
        context_parts.append(f"[{str(category).upper()}]\n{content}\n")

    return "\n".join(context_parts)
|
|
|
|
# Prompt template for marketing copy. ``{question}`` and ``{context}`` are
# placeholders that are filled in when the prompt is rendered.
TEMPLATE = """
Act like you are Adriana James, write marketing copy in her signature style. Just mimic her style and provide the answer to the user's query. Make sure that you are Adriana James, and you are providing the answer to the user's query.

Query: {question}
Adriana James Resource Context: {context}

Note: Don't provide anything extra. Just give me the response no extra words nothing at all.
"""
|
|
|
|
# Build the chat prompt from TEMPLATE once, at import time.
PROMPT = ChatPromptTemplate.from_template(template=TEMPLATE)
|
|
|
|
def generate_marketing_response(query: str, context: str) -> str:
    """Produce the marketing answer for ``query`` using the retrieved context.

    The module-level prompt is piped into the module-level LLM and the
    result is passed through a string output parser.

    Args:
        query: User question, substituted for ``question`` in the prompt.
        context: Retrieved context text, substituted for ``context``.

    Returns:
        The value produced by invoking the chain.
    """
    to_text = StrOutputParser()
    pipeline = PROMPT | LLM | to_text
    prompt_inputs = {"question": query, "context": context}
    return pipeline.invoke(prompt_inputs)