Files
ds_tjc/src/marketing_assistant_ai/rag.py
T

132 lines
4.2 KiB
Python
Raw Normal View History

import logging
from typing import List, Dict

import requests
2025-02-06 03:20:41 +06:00
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from chromadb.api.types import Documents, EmbeddingFunction
from config import (
MODEL_NAME, RERANKER_NAME, API_KEY,
SERVER_URL, GROQ_API_KEY, GROQ_MODEL
)
2025-02-06 03:20:41 +06:00
class CustomEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function backed by an HTTP ``/embeddings`` endpoint.

    Each call posts the input texts to ``{SERVER_URL}/embeddings``,
    authenticated with ``API_KEY`` as a bearer token, and returns one
    vector per input text.
    """

    def __init__(self, model_name: str, timeout: float = 60.0):
        """
        Args:
            model_name: Embedding model name sent in the request payload.
            timeout: Seconds ``requests.post`` waits for the server. Without
                a timeout a stalled server would block the caller forever.
        """
        self.model_name = model_name
        self._api_key = API_KEY
        self._server_url = SERVER_URL
        self._timeout = timeout

    def __call__(self, input: Documents) -> List[List[float]]:
        """Embed ``input`` and return the vectors in input order.

        Args:
            input: Texts to embed. The name ``input`` (which shadows the
                builtin) matches chromadb's ``EmbeddingFunction.__call__``
                signature, so it is kept.

        Returns:
            One embedding per text, or ``[]`` for empty input (no request
            is made in that case).

        Raises:
            Any exception from the HTTP request, ``raise_for_status`` or
            response parsing; it is logged and then re-raised unchanged.
        """
        if not input:
            return []
        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model_name,
            "input": input,
        }
        try:
            response = requests.post(
                f"{self._server_url}/embeddings",
                json=payload,
                headers=headers,
                timeout=self._timeout,
            )
            response.raise_for_status()
            items = response.json()['data']
            # Assumes an OpenAI-compatible body, where each item may carry the
            # ``index`` of its input. Ordering by it keeps vectors aligned
            # with ``input``. Items without an index keep their response
            # order, because sorted() is stable.
            items = sorted(items, key=lambda item: item.get('index', 0))
            return [item['embedding'] for item in items]
        except Exception:
            logging.getLogger(__name__).exception("Error in embedding")
            raise
class CustomReranker:
    """HTTP client for a ``/rerank`` endpoint on ``SERVER_URL``.

    Sends a query and the text of candidate documents to the reranking
    model, then returns the documents ordered by the server-reported
    ``relevance_score``.
    """

    def __init__(self, model_name: str, timeout: float = 60.0):
        """
        Args:
            model_name: Reranking model name sent in the request payload.
            timeout: Seconds ``requests.post`` waits for the server. Without
                a timeout a stalled server would block retrieval forever.
        """
        self.model_name = model_name
        self._api_key = API_KEY
        self._server_url = SERVER_URL
        self._timeout = timeout

    def rerank(self, query: str, documents: List[Dict], top_k: int = 5) -> List[Dict]:
        """Rerank ``documents`` against ``query`` and keep the best ``top_k``.

        Args:
            query: Query the documents are scored against.
            documents: Candidate documents. Each must have a ``'content'`` key
                holding the text sent to the reranker.
            top_k: Maximum number of documents to return.

        Returns:
            Up to ``top_k`` shallow copies of the input documents, ordered by
            descending ``'relevance_score'``, which is added to each copy.
            If the request or response handling fails, the first ``top_k``
            input documents are returned in their original order, without
            scores.
        """
        if not documents:
            return []
        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model_name,
            "query": query,
            "documents": [doc['content'] for doc in documents],
        }
        try:
            response = requests.post(
                f"{self._server_url}/rerank",
                json=payload,
                headers=headers,
                timeout=self._timeout,
            )
            response.raise_for_status()
            return self._apply_results(documents, response.json()['results'], top_k)
        except Exception:
            # Deliberate best-effort: reranking only refines the order.
            # Any failure (network, HTTP status, malformed body) degrades to
            # the retrieval order, still honouring top_k.
            logging.getLogger(__name__).exception("Error in reranking")
            return documents[:top_k]

    @staticmethod
    def _apply_results(documents: List[Dict], results: List[Dict], top_k: int) -> List[Dict]:
        """Map reranker ``results`` (``index``/``relevance_score`` items) onto ``documents``."""
        # Sort explicitly instead of trusting the server's order. Sorting an
        # already-ordered list changes nothing, and ties keep server order
        # because sorted() is stable.
        ranked = sorted(results, key=lambda item: item['relevance_score'], reverse=True)
        top = []
        for item in ranked[:top_k]:
            doc = documents[item['index']].copy()  # never mutate caller's dicts
            doc['relevance_score'] = item['relevance_score']
            top.append(doc)
        return top
# Module-level singletons, constructed once when this module is imported.
EMBED_FUNCTION = CustomEmbeddingFunction(model_name=MODEL_NAME)
RERANKER = CustomReranker(model_name=RERANKER_NAME)
LLM = ChatGroq(
    temperature=0.01,  # very low; presumably chosen for near-deterministic copy
    groq_api_key=GROQ_API_KEY,
    model_name=GROQ_MODEL,
)
def format_context(documents: List[Dict]) -> str:
    """Render retrieved documents as one context string for the prompt.

    Each document becomes a ``[CATEGORY]`` header line followed by its
    content. Consecutive documents are separated by a blank line.

    Args:
        documents: Dicts with a ``'content'`` key and an optional
            ``'metadata'`` mapping whose ``'category'`` labels the section.

    Returns:
        The formatted context, or ``""`` when ``documents`` is empty.
    """
    sections = []
    for doc in documents:
        # A missing or None metadata/category falls back to 'unknown' instead
        # of raising. str() tolerates non-string metadata values such as ints.
        metadata = doc.get('metadata') or {}
        category = metadata.get('category')
        if category is None:
            category = 'unknown'
        sections.append(f"[{str(category).upper()}]\n{doc['content']}\n")
    return "\n".join(sections)
# Prompt asking the model to answer as Adriana James, in her marketing-copy
# style. Placeholders filled at invoke time:
#   {question} - the user's query
#   {context}  - reference text retrieved for that query
TEMPLATE = """
Act like you are Adriana James, write marketing copy in her signature style. Just mimic her style and provide the answer to the user's query. Make sure that you are Adriana James, and you are providing the answer to the user's query.
Query: {question}
Adriana James Resource Context: {context}
Note: Don't provide anything extra. Just give me the response no extra words nothing at all.
"""
PROMPT = ChatPromptTemplate.from_template(template=TEMPLATE)
def generate_marketing_response(query: str, context: str) -> str:
    """Answer ``query`` in the Adriana James voice, grounded in ``context``.

    Args:
        query: The user's request, substituted for ``{question}`` in PROMPT.
        context: Retrieved reference text, substituted for ``{context}``.

    Returns:
        The LLM's reply, converted to a plain string by ``StrOutputParser``.
    """
    prompt_inputs = {"question": query, "context": context}
    pipeline = PROMPT | LLM | StrOutputParser()
    return pipeline.invoke(prompt_inputs)