Refactor schemas and configuration for marketing assistant; add new test endpoints and utility functions for file handling and document processing
This commit is contained in:
@@ -1,73 +1,270 @@
|
|||||||
|
from typing import List, Optional, Dict
|
||||||
|
import chromadb
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
import uuid
|
import uuid
|
||||||
import os
|
import os
|
||||||
from typing import List
|
from datetime import datetime
|
||||||
from langchain_chroma import Chroma
|
from config import CHROMA_PATH, COLLECTION_NAME
|
||||||
from langchain_core.documents import Document
|
from rag import EMBED_FUNCTION
|
||||||
from config import settings
|
|
||||||
from utils import CustomEmbeddings
|
|
||||||
|
|
||||||
class ChromaManager:
|
class ChromaManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.vector_store = Chroma(
|
self.client = chromadb.PersistentClient(path=str(CHROMA_PATH))
|
||||||
collection_name=settings.COLLECTION_NAME,
|
self.collection = self.client.get_or_create_collection(
|
||||||
persist_directory=settings.CHROMA_PATH,
|
name=COLLECTION_NAME,
|
||||||
embedding_function=CustomEmbeddings(settings.MODEL_NAME)
|
embedding_function=EMBED_FUNCTION
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_collection_info(self):
|
def _calculate_index_size(self) -> str:
|
||||||
index_size = 0
|
"""Calculate the total size of the Chroma index directory"""
|
||||||
if os.path.exists(settings.CHROMA_PATH):
|
total_size = 0
|
||||||
for dirpath, _, filenames in os.walk(settings.CHROMA_PATH):
|
for dirpath, _, filenames in os.walk(CHROMA_PATH):
|
||||||
for f in filenames:
|
for f in filenames:
|
||||||
fp = os.path.join(dirpath, f)
|
fp = os.path.join(dirpath, f)
|
||||||
index_size += os.path.getsize(fp)
|
total_size += os.path.getsize(fp)
|
||||||
|
return f"{total_size / (1024 * 1024):.2f} MB"
|
||||||
|
|
||||||
|
def add_documents(self, documents: List[Document], category: str) -> List[str]:
|
||||||
|
"""Add documents with category metadata"""
|
||||||
|
ids = [str(uuid.uuid4()) for _ in documents]
|
||||||
|
texts = [doc.page_content for doc in documents]
|
||||||
|
metadatas = [
|
||||||
|
{**doc.metadata, "category": category}
|
||||||
|
for doc in documents
|
||||||
|
]
|
||||||
|
|
||||||
return {
|
|
||||||
"collection_name": settings.COLLECTION_NAME,
|
|
||||||
"document_count": self.vector_store._collection.count(),
|
|
||||||
"index_size": f"{index_size/1024/1024:.2f} MB"
|
|
||||||
}
|
|
||||||
|
|
||||||
def add_documents(self, documents: List[Document]):
|
|
||||||
try:
|
try:
|
||||||
ids = [str(uuid.uuid4()) for _ in documents]
|
self.collection.add(
|
||||||
self.vector_store.add_documents(documents, ids=ids)
|
documents=texts,
|
||||||
|
metadatas=metadatas,
|
||||||
|
ids=ids
|
||||||
|
)
|
||||||
return ids
|
return ids
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Error adding documents: {str(e)}")
|
print(f"Error adding documents: {str(e)}")
|
||||||
|
raise
|
||||||
def delete_document(self, doc_id: str):
|
|
||||||
|
def delete_document(self, doc_id: str) -> bool:
|
||||||
|
"""Delete document by ID"""
|
||||||
try:
|
try:
|
||||||
self.vector_store._collection.delete(ids=[doc_id])
|
self.collection.delete(ids=[doc_id])
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Delete error: {str(e)}")
|
print(f"Error deleting document: {str(e)}")
|
||||||
|
return False
|
||||||
def get_all_files_metadata(self):
|
|
||||||
try:
|
def update_document(self, doc_id: str, new_content: str, metadata: dict) -> Optional[str]:
|
||||||
print(len(self.vector_store.get()["ids"]))
|
"""Update document content and metadata"""
|
||||||
for x in range(len(self.vector_store.get()["ids"])):
|
|
||||||
# print(db.get()["metadatas"][x])
|
|
||||||
doc = self.vector_store.get()["metadatas"][x]
|
|
||||||
source = doc["source"]
|
|
||||||
print(source)
|
|
||||||
|
|
||||||
return self.vector_store.get()
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Error retrieving files metadata: {str(e)}")
|
|
||||||
|
|
||||||
def update_document(self, doc_id: str, new_content: str, metadata: dict):
|
|
||||||
try:
|
try:
|
||||||
self.delete_document(doc_id)
|
self.delete_document(doc_id)
|
||||||
new_doc = Document(page_content=new_content, metadata=metadata)
|
new_id = str(uuid.uuid4())
|
||||||
return self.add_documents([new_doc])[0]
|
self.collection.add(
|
||||||
|
documents=[new_content],
|
||||||
|
metadatas=[metadata],
|
||||||
|
ids=[new_id]
|
||||||
|
)
|
||||||
|
return new_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Update error: {str(e)}")
|
print(f"Error updating document: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
# if __name__ == "__main__":
|
def get_collection_info(self):
|
||||||
# chroma_manager = ChromaManager()
|
"""Get collection statistics including category counts"""
|
||||||
# #chroma_manager.create_collection()
|
try:
|
||||||
# #chroma_manager.add_documents([Document(page_content="Test document", metadata={"source": "test"})])
|
count = self.collection.count()
|
||||||
# print(chroma_manager.get_all_files_metadata())
|
all_metadata = self.collection.get()
|
||||||
# print(chroma_manager.get_collection_info())
|
metadatas = all_metadata["metadatas"] if all_metadata else []
|
||||||
|
|
||||||
|
category_counts = {}
|
||||||
|
for metadata in metadatas:
|
||||||
|
category = metadata.get("category", "unknown")
|
||||||
|
category_counts[category] = category_counts.get(category, 0) + 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
"collection_name": COLLECTION_NAME,
|
||||||
|
"document_count": count,
|
||||||
|
"index_size": self._calculate_index_size(),
|
||||||
|
"category_counts": category_counts
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error getting collection info: {str(e)}")
|
||||||
|
raise
|
||||||
|
def query_documents(self, query: str, category: Optional[str] = None, top_k: int = 5) -> List[Dict]:
    """Run a similarity query against the collection.

    Args:
        query: Free-text query, sent as the single entry of ``query_texts``.
        category: When truthy, only records whose ``category`` metadata
            equals this value are considered; otherwise no filter is used.
        top_k: Number of results requested from the collection.

    Returns:
        One dict per hit, in the order the collection returned them, with
        the keys ``content``, ``metadata`` and ``relevance_score``. An
        empty list when the collection returns no documents.

    Raises:
        Exception: Any error is printed and then re-raised unchanged.
    """
    try:
        where = {"category": category} if category else None

        results = self.collection.query(
            query_texts=[query],
            n_results=top_k,
            where=where,
            include=["documents", "metadatas", "distances"],
        )

        if not (results and results['documents'] and results['documents'][0]):
            return []

        # A single query text is sent, so only the first result row is used.
        return [
            {
                'content': content,
                # A record stored without metadata may come back as None;
                # callers in this project call ``.get`` on this value, so
                # always hand them a dict.
                'metadata': metadata if metadata is not None else {},
                # NOTE(review): this maps a distance in [0, 2] onto [0, 1],
                # which presumably assumes a cosine-distance collection;
                # confirm against the collection's configured metric.
                'relevance_score': 1 - (distance / 2),
            }
            for content, metadata, distance in zip(
                results['documents'][0],
                results['metadatas'][0],
                results['distances'][0],
            )
        ]

    except Exception as e:
        print(f"Error querying documents: {str(e)}")
        raise
|
||||||
|
def get_files_by_category(self, category: Optional[str] = None) -> Dict[str, dict]:
    """Group the collection's chunks into files, keyed by category.

    Args:
        category: When truthy, only chunks whose ``category`` metadata
            equals this value are reported and the result holds exactly
            one key (``category``), even if nothing matched.

    Returns:
        A mapping ``category -> {'category', 'files', 'total_files'}``.
        Each entry of ``files`` has ``filename``, ``category``,
        ``upload_date`` (taken from the first chunk seen for that file)
        and ``doc_ids``.

    Raises:
        Exception: Any error is printed and then re-raised unchanged.
    """
    try:
        results = self.collection.get(include=['metadatas'])
        metadatas = (results['metadatas'] if results else None) or []

        # category -> filename -> file entry. Plain dicts keep first-seen
        # order, so files are listed in the order their chunks appear.
        grouped: Dict[str, Dict[str, dict]] = {}

        for metadata in metadatas:
            # A chunk stored without metadata may be reported as None.
            metadata = metadata or {}
            doc_category = metadata.get('category', 'unknown')

            if category and doc_category != category:
                continue

            # Chunks in this project are written with either ``source`` or
            # only ``filename``, and with either ``document_id`` or
            # ``doc_id``; accept both so no chunk is lumped under
            # 'unknown' or loses its id.
            filename = metadata.get('source') or metadata.get('filename') or 'unknown'
            doc_id = metadata.get('document_id') or metadata.get('doc_id')

            files_in_category = grouped.setdefault(doc_category, {})
            if filename not in files_in_category:
                files_in_category[filename] = {
                    'filename': filename,
                    'category': doc_category,
                    'upload_date': metadata.get('upload_date'),
                    'doc_ids': [],
                }
            if doc_id:
                files_in_category[filename]['doc_ids'].append(doc_id)

        result = {}
        for cat, file_map in grouped.items():
            files = list(file_map.values())
            result[cat] = {
                'category': cat,
                'files': files,
                'total_files': len(files),
            }

        if category:
            # Always answer with the requested category, empty if needed.
            return {
                category: result.get(category, {
                    'category': category,
                    'files': [],
                    'total_files': 0,
                })
            }

        return result

    except Exception as e:
        print(f"Error getting files by category: {str(e)}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
def delete_file(self, filename: str, category: str, base_dir=None) -> bool:
    """Delete every chunk stored for one uploaded file.

    Chunks are matched on two metadata fields: ``source`` must equal
    ``<base_dir>/<category>/<filename>`` and ``category`` must equal
    ``category``.

    Args:
        filename: File name as stored under the category directory.
        category: Category the file was uploaded to.
        base_dir: Root upload directory (string or path-like) used to
            rebuild the ``source`` value. Defaults to ``config.UPLOAD_DIR``
            rather than a path hard-coded to one machine.

    Returns:
        True when at least one matching chunk existed and the delete was
        issued, False when nothing matched.

    Raises:
        Exception: Any error is printed and then re-raised unchanged.
    """
    try:
        if base_dir is None:
            # Imported lazily so the default is resolved only when needed.
            from config import UPLOAD_DIR
            base_dir = UPLOAD_DIR

        full_path = f"{base_dir}/{category}/{filename}"
        match_file = {
            "$and": [
                {"source": {"$eq": full_path}},
                {"category": {"$eq": category}},
            ]
        }

        # Only the ids are needed to know whether anything matches, so no
        # documents or metadata are requested.
        matches = self.collection.get(where=match_file, include=[])
        if not (matches and matches['ids']):
            return False

        self.collection.delete(where=match_file)
        return True

    except Exception as e:
        print(f"Error in delete_file: {str(e)}")
        raise
|
||||||
|
|
||||||
|
def update_file_content(self, filename: str, category: str, new_content: str) -> bool:
    """Replace all stored chunks of a file with chunks of ``new_content``.

    The new text is split first (1000-character chunks, 200 overlap); the
    existing chunks are removed through ``delete_file`` only once there is
    something to write, and the new chunks are then added.

    Args:
        filename: File name as stored under the category directory.
        category: Category the file belongs to.
        new_content: Full replacement text for the file.

    Returns:
        True when the chunks were replaced. False when ``new_content``
        yields no chunks, when no existing chunks were found for the file,
        or when any step raised (the error is printed).
    """
    try:
        # Resolved before anything is deleted, so a failure here cannot
        # leave the file without chunks.
        from config import UPLOAD_DIR

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
        )
        chunks = text_splitter.split_text(new_content)
        if not chunks:
            # Nothing to store: keep the existing chunks instead of
            # deleting them and then failing on an empty add.
            return False

        if not self.delete_file(filename, category):
            return False

        # ``delete_file`` and the file listing locate chunks through the
        # ``source`` metadata field, so the rewritten chunks must carry it
        # or the file becomes unreachable after its first update.
        source_path = f"{UPLOAD_DIR}/{category}/{filename}"
        uploaded_at = datetime.utcnow().isoformat()  # one stamp for the whole file
        doc_ids = [str(uuid.uuid4()) for _ in chunks]
        # NOTE(review): the file listing reads ``document_id`` while this
        # method writes ``doc_id``; confirm which key is canonical.
        metadatas = [
            {
                'source': source_path,
                'filename': filename,
                'category': category,
                'doc_id': doc_id,
                'upload_date': uploaded_at,
                'chunk_index': idx,
            }
            for idx, doc_id in enumerate(doc_ids)
        ]

        self.collection.add(
            documents=chunks,
            metadatas=metadatas,
            ids=doc_ids,
        )
        return True

    except Exception as e:
        print(f"Error updating file: {str(e)}")
        return False
|
||||||
|
|||||||
@@ -1,15 +1,29 @@
|
|||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
class Settings:
|
# Base directory configuration
|
||||||
MODEL_NAME = "BAAI/bge-large-en-v1.5"
|
#BASE_DIR = Path(__file__).resolve().parent
|
||||||
RERANKER_NAME = "BAAI/bge-reranker-large"
|
BASE_DIR = Path("/home/kowshik/work/ds_tjc")
|
||||||
GROQ_MODEL = "llama-3.3-70b-versatile"
|
|
||||||
#DOCS_PATH = "/home/kowshik/work/ds_tjc/datasets/Client-Assets"
|
|
||||||
DOCS_PATH = "/home/kowshik/work/ds_tjc/datasets/marketing_data"
|
|
||||||
CHROMA_PATH = "/home/kowshik/work/ds_tjc/chroma_index"
|
|
||||||
COLLECTION_NAME = "marketing_docs"
|
|
||||||
API_KEY = "4BkwTtVd5VwhTiFDdG3NfzgATrCq7aD8AjnvWNeivirTntHgRvL6Xe84ULHcVTLB"
|
|
||||||
SERVER_URL = "https://ma.rommelcorral.com"
|
|
||||||
GROQ_API_KEY = "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"
|
|
||||||
|
|
||||||
settings = Settings()
|
# Data directories
|
||||||
|
UPLOAD_DIR = BASE_DIR / "marketing_data"
|
||||||
|
CHROMA_PATH = BASE_DIR / "chroma_index"
|
||||||
|
|
||||||
|
# Ensure directories exist
|
||||||
|
UPLOAD_DIR.mkdir(exist_ok=True)
|
||||||
|
for category in ["email", "books", "article", "social"]:
|
||||||
|
(UPLOAD_DIR / category).mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# Model configurations
|
||||||
|
MODEL_NAME = "BAAI/bge-large-en-v1.5"
|
||||||
|
RERANKER_NAME = "BAAI/bge-reranker-large"
|
||||||
|
GROQ_MODEL = "llama-3.3-70b-versatile"
|
||||||
|
COLLECTION_NAME = "marketing_docs"
|
||||||
|
|
||||||
|
# API configurations
|
||||||
|
API_KEY = "4BkwTtVd5VwhTiFDdG3NfzgATrCq7aD8AjnvWNeivirTntHgRvL6Xe84ULHcVTLB"
|
||||||
|
SERVER_URL = "https://ma.rommelcorral.com"
|
||||||
|
GROQ_API_KEY = "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"
|
||||||
|
|
||||||
|
# Valid document categories
|
||||||
|
VALID_CATEGORIES = ["email", "books", "article", "social"]
|
||||||
@@ -1,83 +1,265 @@
|
|||||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
from fastapi import FastAPI, UploadFile, File, HTTPException, Query
|
||||||
from fastapi.responses import JSONResponse
|
from pydantic import BaseModel
|
||||||
from typing import List
|
from typing import List, Optional, Dict
|
||||||
import base64
|
from datetime import datetime
|
||||||
from langchain_core.documents import Document
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from chroma_manager import ChromaManager
|
import uvicorn
|
||||||
from rag import RAGSystem
|
from typing import List, Optional
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from datetime import datetime
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from urllib.parse import unquote
|
||||||
|
|
||||||
from schemas import (
|
from schemas import (
|
||||||
DocumentUpload,
|
CategoryEnum, DocumentResponse,
|
||||||
QueryRequest,
|
UpdateDocument, QueryRequest, CollectionInfo,
|
||||||
DocumentResponse,
|
CategoryFiles,
|
||||||
CollectionInfo,
|
)
|
||||||
UpdateDocumentRequest
|
from utils import save_upload_file, load_and_split_documents
|
||||||
|
from chroma_manager import ChromaManager
|
||||||
|
from rag import generate_marketing_response,format_context, RERANKER
|
||||||
|
from config import UPLOAD_DIR
|
||||||
|
|
||||||
|
app = FastAPI(title="Marketing Assistant AI")
|
||||||
|
|
||||||
|
# Ensure upload directories exist
|
||||||
|
for category in CategoryEnum:
|
||||||
|
os.makedirs(UPLOAD_DIR / category.value, exist_ok=True)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
from utils import save_uploaded_file
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
chroma_manager = ChromaManager()
|
chroma_manager = ChromaManager()
|
||||||
rag_system = RAGSystem()
|
|
||||||
|
|
||||||
@app.post("/upload/")
|
|
||||||
async def upload_document(file: UploadFile = File(...), file_category: str = Form(...)):
|
@app.get("/health")
async def check_health():
    """Liveness probe: always answers with a fixed OK payload."""
    health_payload = {"status": "ok"}
    return health_payload
|
||||||
|
|
||||||
|
@app.post("/upload/{category}", response_model=DocumentResponse)
|
||||||
|
async def upload_document(
|
||||||
|
category: CategoryEnum,
|
||||||
|
file: UploadFile = File(...)
|
||||||
|
):
|
||||||
|
"""Upload and process a document for a specific category"""
|
||||||
|
if not file.filename.lower().endswith(('.pdf', '.txt', '.docx', '.pptx', '.png', '.jpg', '.jpeg')):
|
||||||
|
raise HTTPException(400, "Only PDF and TXT files are supported")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if file_category not in ["email", "books", "article", "social"]:
|
# Save file temporarily
|
||||||
raise HTTPException(status_code=400, detail="Invalid file category")
|
file_path = save_upload_file(file, category)
|
||||||
|
|
||||||
content = await file.read()
|
# Process and split document
|
||||||
filepath = save_uploaded_file(content, file.filename)
|
splits = load_and_split_documents(file_path, folder_path=UPLOAD_DIR / category.value)
|
||||||
|
|
||||||
document = Document(
|
if not splits:
|
||||||
page_content=str(content), # Convert bytes to string representation
|
raise HTTPException(400, "No content could be extracted from the file")
|
||||||
metadata={"source": filepath, "filename": file.filename}
|
|
||||||
)
|
|
||||||
|
|
||||||
doc_id = chroma_manager.add_documents([document])[0]
|
# Add to vector store
|
||||||
return JSONResponse(
|
doc_ids = chroma_manager.add_documents(splits, category.value)
|
||||||
content={"message": "Document uploaded successfully", "doc_id": doc_id},
|
|
||||||
status_code=201
|
return DocumentResponse(
|
||||||
|
document_id=doc_ids[0],
|
||||||
|
category=category,
|
||||||
|
filename=file.filename,
|
||||||
|
status="success"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(500, f"Error processing document: {str(e)}")
|
||||||
|
|
||||||
@app.post("/query/")
|
# @app.delete("/document/{doc_id}")
|
||||||
async def process_query(query: QueryRequest):
|
# async def delete_document(doc_id: str):
|
||||||
try:
|
# """Delete a document by ID"""
|
||||||
response = rag_system.get_response(query.question)
|
# success = chroma_manager.delete_document(doc_id)
|
||||||
return {"response": response}
|
# if not success:
|
||||||
except Exception as e:
|
# raise HTTPException(404, "Document not found")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
# return {"status": "success", "message": "Document deleted"}
|
||||||
|
|
||||||
@app.delete("/documents/{doc_id}")
|
# @app.put("/document/{doc_id}")
|
||||||
async def delete_document(doc_id: str):
|
# async def update_document(doc_id: str, update: UpdateDocument):
|
||||||
try:
|
# """Update document content and metadata"""
|
||||||
success = chroma_manager.delete_document(doc_id)
|
# new_id = chroma_manager.update_document(
|
||||||
if success:
|
# doc_id,
|
||||||
return {"message": "Document deleted successfully"}
|
# update.content,
|
||||||
raise HTTPException(status_code=404, detail="Document not found")
|
# update.metadata
|
||||||
except Exception as e:
|
# )
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
# if not new_id:
|
||||||
|
# raise HTTPException(404, "Document not found")
|
||||||
|
# return {"status": "success", "new_id": new_id}
|
||||||
|
|
||||||
@app.put("/documents/{doc_id}")
|
@app.post("/query")
|
||||||
async def update_document(doc_id: str, update_data: UpdateDocumentRequest):
|
async def query_documents(request: QueryRequest,
|
||||||
|
category: CategoryEnum):
|
||||||
|
"""Query documents and generate marketing response"""
|
||||||
try:
|
try:
|
||||||
new_id = chroma_manager.update_document(
|
# Initial retrieval from vector store
|
||||||
doc_id,
|
initial_results = chroma_manager.query_documents(
|
||||||
update_data.new_content,
|
query=request.query,
|
||||||
update_data.metadata
|
category=category if category else None,
|
||||||
|
top_k=10 # Retrieve more documents initially for reranking
|
||||||
)
|
)
|
||||||
return {"message": "Document updated", "new_doc_id": new_id}
|
|
||||||
|
if not initial_results:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="No relevant documents found for the query"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Rerank the results
|
||||||
|
reranked_results = RERANKER.rerank(
|
||||||
|
query=request.query,
|
||||||
|
documents=initial_results,
|
||||||
|
top_k=5 # Keep top 5 most relevant documents after reranking
|
||||||
|
)
|
||||||
|
|
||||||
|
# Format the context from reranked results
|
||||||
|
context = format_context(reranked_results)
|
||||||
|
|
||||||
|
# Generate response using the formatted context
|
||||||
|
response = generate_marketing_response(request.query, context)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"response": response,
|
||||||
|
"context": context, # Optionally include context for transparency
|
||||||
|
"documents": [ # Optionally include document metadata
|
||||||
|
{
|
||||||
|
"category": doc["metadata"].get("category"),
|
||||||
|
"relevance_score": doc["relevance_score"]
|
||||||
|
}
|
||||||
|
for doc in reranked_results
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(500, f"Error processing query: {str(e)}")
|
||||||
|
|
||||||
@app.get("/collection-info/", response_model=CollectionInfo)
|
@app.get("/collection-info", response_model=CollectionInfo)
|
||||||
async def get_collection_info():
|
async def get_collection_info():
|
||||||
|
"""Get information about the document collection"""
|
||||||
try:
|
try:
|
||||||
info = chroma_manager.get_collection_info()
|
return chroma_manager.get_collection_info()
|
||||||
return info
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(500, f"Error getting collection info: {str(e)}")
|
||||||
|
|
||||||
|
@app.get("/files/{category}", response_model=Dict[str, CategoryFiles])
async def get_files(
    category: CategoryEnum
):
    """List the files stored for one category.

    Returns a single-key mapping from the category's string value to its
    ``category`` / ``files`` / ``total_files`` record; the record is empty
    when the category has no files. Any failure is reported as HTTP 500.
    """
    try:
        # Use the plain string value, as the other endpoints in this file
        # do, so that filtering and the response key never depend on how
        # the enum object compares with stored metadata strings.
        category_key = category.value
        result = chroma_manager.get_files_by_category(category_key)

        # Guarantee a well-formed entry even if the manager returned none.
        if category_key not in result:
            result[category_key] = {
                'category': category_key,
                'files': [],
                'total_files': 0
            }

        return result
    except Exception as e:
        raise HTTPException(500, f"Error getting files: {str(e)}")
|
||||||
|
|
||||||
|
@app.delete("/files/{category}/{filename}")
async def delete_file(
    category: CategoryEnum,
    filename: str
):
    """Delete a specific file and all of its chunks from the vector store.

    Responds with HTTP 404 when no chunks match the file in the given
    category, and with HTTP 500 when the vector store raises.
    """
    try:
        print(f"Attempting to delete file: {filename} from category: {category.value}")
        success = chroma_manager.delete_file(filename, category.value)
    except Exception as e:
        print(f"Error in delete_file endpoint: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Error deleting file: {str(e)}"
        )

    # Raised outside the try block: inside it, the generic handler caught
    # this 404 and re-reported it as a 500.
    if not success:
        raise HTTPException(
            status_code=404,
            detail=f"File {filename} not found in category {category.value}"
        )

    return {
        "status": "success",
        "message": f"File {filename} deleted successfully",
        "category": category.value,
        "filename": filename
    }
|
||||||
|
|
||||||
|
@app.put("/files/{category}/{filename}")
async def update_file(
    category: CategoryEnum,
    filename: str,
    content: str
):
    """Replace the stored content of a specific file.

    Responds with HTTP 404 when the manager reports that the update did
    not happen, and with HTTP 500 when the manager raises.
    """
    # NOTE(review): ``content`` is a bare ``str`` parameter, which FastAPI
    # presumably reads from the query string; confirm this is intended for
    # full file bodies.
    try:
        success = chroma_manager.update_file_content(filename, category.value, content)
    except Exception as e:
        raise HTTPException(500, f"Error updating file: {str(e)}")

    # Raised outside the try block: inside it, the generic handler caught
    # this 404 and re-reported it as a 500.
    if not success:
        raise HTTPException(404, f"File {filename} not found in category {category.value}")

    return {"status": "success", "message": f"File {filename} updated successfully"}
|
||||||
|
|
||||||
|
# Modify the upload endpoint to include filename in metadata
|
||||||
|
@app.post("/upload/{category}", response_model=DocumentResponse)
async def upload_document(
    category: CategoryEnum,
    file: UploadFile = File(...)
):
    """Upload a PDF or TXT file, split it, and index it under ``category``.

    Each chunk is tagged with the file name, category, upload time and a
    fresh ``doc_id`` before being added to the vector store. Responds with
    HTTP 400 for an unsupported extension or when no content could be
    extracted, and with HTTP 500 for any other failure.

    NOTE(review): a handler for the same method and path is registered
    earlier in this file. Routes are presumably matched in registration
    order, so confirm which of the two actually serves requests.
    """
    if not file.filename.lower().endswith(('.pdf', '.txt')):
        raise HTTPException(400, "Only PDF and TXT files are supported")

    try:
        file_path = save_upload_file(file, category)
        splits = load_and_split_documents(file_path)

        if not splits:
            raise HTTPException(400, "No content could be extracted from the file")

        # One timestamp for the whole upload, so every chunk of the file
        # reports the same upload_date.
        uploaded_at = datetime.utcnow().isoformat()
        for split in splits:
            split.metadata.update({
                'filename': file.filename,
                'category': category.value,
                'upload_date': uploaded_at,
                'doc_id': str(uuid.uuid4())
            })

        doc_ids = chroma_manager.add_documents(splits, category.value)

        return DocumentResponse(
            document_id=doc_ids[0],
            category=category,
            filename=file.filename,
            status="success"
        )
    except HTTPException:
        # Keep deliberate HTTP errors (such as the 400 above) intact
        # instead of rewrapping them as a 500.
        raise
    except Exception as e:
        raise HTTPException(500, f"Error processing document: {str(e)}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
||||||
@@ -1,91 +1,132 @@
|
|||||||
from langchain.retrievers import ContextualCompressionRetriever
|
from typing import List, Dict
|
||||||
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
import requests
|
||||||
from langchain_groq import ChatGroq
|
from langchain_groq import ChatGroq
|
||||||
from langchain_core.prompts import ChatPromptTemplate
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
from langchain_core.output_parsers import StrOutputParser
|
from langchain_core.output_parsers import StrOutputParser
|
||||||
from langchain_core.runnables import RunnablePassthrough
|
from chromadb.api.types import Documents, EmbeddingFunction
|
||||||
from config import settings
|
from config import (
|
||||||
from utils import CustomEmbeddings, CustomCrossEncoder
|
MODEL_NAME, RERANKER_NAME, API_KEY,
|
||||||
from langchain_chroma import Chroma
|
SERVER_URL, GROQ_API_KEY, GROQ_MODEL
|
||||||
|
)
|
||||||
|
|
||||||
class RAGSystem:
|
class CustomEmbeddingFunction(EmbeddingFunction):
|
||||||
def __init__(self):
|
def __init__(self, model_name: str):
|
||||||
self.embeddings = CustomEmbeddings(settings.MODEL_NAME)
|
self.model_name = model_name
|
||||||
self.reranker = CrossEncoderReranker(
|
self._api_key = API_KEY
|
||||||
model=CustomCrossEncoder(settings.RERANKER_NAME),
|
self._server_url = SERVER_URL
|
||||||
top_n=5
|
|
||||||
)
|
def __call__(self, input: Documents) -> List[List[float]]:
|
||||||
|
"""Implementation of the embedding function"""
|
||||||
|
if not input:
|
||||||
|
return []
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self._api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
payload = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"input": input
|
||||||
|
}
|
||||||
|
|
||||||
self.vector_store = Chroma(
|
try:
|
||||||
collection_name=settings.COLLECTION_NAME,
|
response = requests.post(
|
||||||
persist_directory=settings.CHROMA_PATH,
|
f"{self._server_url}/embeddings",
|
||||||
embedding_function=self.embeddings
|
json=payload,
|
||||||
)
|
headers=headers
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return [item['embedding'] for item in response.json()['data']]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in embedding: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
class CustomReranker:
|
||||||
|
def __init__(self, model_name: str):
|
||||||
|
self.model_name = model_name
|
||||||
|
self._api_key = API_KEY
|
||||||
|
self._server_url = SERVER_URL
|
||||||
|
|
||||||
|
def rerank(self, query: str, documents: List[Dict], top_k: int = 5) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Rerank documents using the reranking model
|
||||||
|
"""
|
||||||
|
if not documents:
|
||||||
|
return []
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self._api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
self.retriever = ContextualCompressionRetriever(
|
payload = {
|
||||||
base_compressor=self.reranker,
|
"model": self.model_name,
|
||||||
base_retriever=self.vector_store.as_retriever(search_kwargs={"k": 10})
|
"query": query,
|
||||||
)
|
"documents": [doc['content'] for doc in documents]
|
||||||
|
}
|
||||||
|
|
||||||
self.llm = ChatGroq(
|
try:
|
||||||
temperature=0.01,
|
response = requests.post(
|
||||||
groq_api_key=settings.GROQ_API_KEY,
|
f"{self._server_url}/rerank",
|
||||||
model_name=settings.GROQ_MODEL
|
json=payload,
|
||||||
)
|
headers=headers
|
||||||
|
)
|
||||||
self.prompt = ChatPromptTemplate.from_template("""
|
response.raise_for_status()
|
||||||
Act like you are Adriana James, write marketing copy in her signature style. Just mimic her style and provide the answer to the user's query. Make sure that you are Adriana James, and you are providing the answer to the user's query.
|
|
||||||
|
# Get reranked results
|
||||||
|
reranked_results = response.json()['results']
|
||||||
|
|
||||||
|
# Sort documents based on reranking scores
|
||||||
|
reranked_docs = []
|
||||||
|
for result in reranked_results[:top_k]:
|
||||||
|
doc_idx = result['index']
|
||||||
|
doc = documents[doc_idx].copy()
|
||||||
|
doc['relevance_score'] = result['relevance_score']
|
||||||
|
reranked_docs.append(doc)
|
||||||
|
|
||||||
|
return reranked_docs
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in reranking: {str(e)}")
|
||||||
|
return documents # Fall back to original ordering if reranking fails
|
||||||
|
|
||||||
Here is some of her past **Email_Templates**:
|
# Initialize global instances
|
||||||
|
EMBED_FUNCTION = CustomEmbeddingFunction(model_name=MODEL_NAME)
|
||||||
|
RERANKER = CustomReranker(model_name=RERANKER_NAME)
|
||||||
|
LLM = ChatGroq(temperature=0.01, groq_api_key=GROQ_API_KEY, model_name=GROQ_MODEL)
|
||||||
|
|
||||||
Template - 1:
|
def format_context(documents: List[Dict]) -> str:
|
||||||
Dear friend,
|
"""Format retrieved documents into a context string"""
|
||||||
|
context_parts = []
|
||||||
As we approach the final days of 2024, I wanted to reach out with a message of hope and possibility for the year ahead. The dawn of 2025 brings with it an opportunity not just for fresh starts, but for transformative growth and achievement.
|
for doc in documents:
|
||||||
You may have already begun thinking about your aspirations for the coming year. Whether you have or haven't, I'd like to personally invite you to join me for an intimate goal-setting session where we'll work together to crystallize your vision for 2025.
|
metadata = doc['metadata']
|
||||||
I believe that every remarkable success story begins with clarity - knowing exactly what you want and placing it firmly in your future Time Line in a way that makes it inevitable. This isn't just about writing down wishes; it's about crafting the blueprint for your next chapter.
|
category = metadata.get('category', 'unknown')
|
||||||
Before our session, I encourage you to reflect on three crucial questions:
|
content = doc['content']
|
||||||
|
context_parts.append(f"[{category.upper()}]\n{content}\n")
|
||||||
1. What is the most important achievement you envision for 2025?
|
|
||||||
2. How can you leverage your unique experiences and skills to create positive change in the world?
|
|
||||||
3. What stepping stones will you need to place along your path to ensure your primary goal becomes reality?
|
|
||||||
|
|
||||||
Here's what makes this journey so powerful: as you pursue specific goals, you naturally develop new skills, strategies, and behaviors. Sometimes, achieving a goal requires you to become an entirely new version of yourself - and that transformation is often the most valuable reward of all.
|
|
||||||
Join me for this complimentary goal-setting session:
|
|
||||||
Date: Thursday 15 January 2025
|
|
||||||
Time: 4pm AEDT
|
|
||||||
Register Today
|
|
||||||
The more attention you invest in this process, the more you'll free yourself from limitations, unleash your creativity, and uncover possibilities you never imagined. This creates a beautiful cycle: greater goals lead to greater successes, which build self-confidence and positive momentum.
|
|
||||||
Register today for this special session. I look forward to helping you lay the foundation for an extraordinary 2025.
|
|
||||||
Be Well!
|
|
||||||
|
|
||||||
Template - 2:
|
|
||||||
Hi [[contact.first_name]],
|
|
||||||
I trust you've been putting the valuable insights from our recent Goal-Setting Masterclass to good use. I hope you've had a chance to dive into the videos and set your sights on exciting goals for 2025 across all areas of your life -career, relationships, finance, health, and beyond.
|
|
||||||
Now, I'm thrilled to invite you to join me for an exclusive live Q&A session on Monday, February 3rd, 2025, at 7PM AEDT. This is your opportunity to delve deeper into the techniques shared and learn more about how to make 2025 truly exceptional.
|
|
||||||
Whether you're looking to fine-tune your objectives, overcome obstacles, or gain more insights into applying these powerful techniques, I'm here to support you. Let's work together to make sure you're fully equipped to create a prosperous and successful year in every aspect of your life.
|
|
||||||
Come prepared with your questions, and let's turn your 2025 goals into reality!
|
|
||||||
Zoom details
|
|
||||||
Be Well!
|
|
||||||
Dr Adriana James and team
|
|
||||||
For more information
|
|
||||||
visit www.nlpcoaching.com or email us via info@nlpcoaching.com | Copyright 2025 The Tad James Company. All rights reserved. Australia/International: 90-96 Bourke Road Alexandria, NSW 2015 United States / International: 1450 W Horizon Ridge Pkway #544, Henderson NV, 89012 Unsubscribe
|
|
||||||
|
|
||||||
|
|
||||||
Query: {question}
|
|
||||||
Adriana James Resource Context: {context}
|
|
||||||
|
|
||||||
Now, write marketing copy in Adriana James' signature style from the context(Adriana James content) above and provide the answer to the user's query.
|
|
||||||
|
|
||||||
Note: Don't provide anything extra. Just give me the response no extra words nothing at all. Just the response to the user's query.
|
|
||||||
""")
|
|
||||||
|
|
||||||
self.rag_chain = (
|
|
||||||
{"context": self.retriever, "question": RunnablePassthrough()}
|
|
||||||
| self.prompt
|
|
||||||
| self.llm
|
|
||||||
| StrOutputParser()
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_response(self, question: str) -> str:
|
return "\n".join(context_parts)
|
||||||
return self.rag_chain.invoke(question)
|
|
||||||
|
# Template for marketing copy
|
||||||
|
TEMPLATE = """
|
||||||
|
Act like you are Adriana James, write marketing copy in her signature style. Just mimic her style and provide the answer to the user's query. Make sure that you are Adriana James, and you are providing the answer to the user's query.
|
||||||
|
|
||||||
|
Query: {question}
|
||||||
|
Adriana James Resource Context: {context}
|
||||||
|
|
||||||
|
Note: Don't provide anything extra. Just give me the response no extra words nothing at all.
|
||||||
|
"""
|
||||||
|
|
||||||
|
PROMPT = ChatPromptTemplate.from_template(TEMPLATE)
|
||||||
|
|
||||||
|
def generate_marketing_response(query: str, context: str) -> str:
    """Produce the marketing copy for a query from already-retrieved context.

    Pipes the module-level ``PROMPT`` into ``LLM`` and a ``StrOutputParser``,
    then invokes that chain with ``query`` bound to the prompt's ``question``
    variable and ``context`` bound to its ``context`` variable. The parser's
    output is returned unchanged.
    """
    prompt_inputs = {"question": query, "context": context}
    pipeline = PROMPT | LLM | StrOutputParser()
    return pipeline.invoke(prompt_inputs)
|
||||||
@@ -1,24 +1,45 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import List, Optional
|
from typing import Optional, Dict, List
|
||||||
|
from enum import Enum
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
class DocumentUpload(BaseModel):
|
class CategoryEnum(str, Enum):
|
||||||
file: str # Base64 encoded file content
|
email = "email"
|
||||||
filename: str
|
books = "books"
|
||||||
metadata: Optional[dict] = {}
|
article = "article"
|
||||||
|
social = "social"
|
||||||
class QueryRequest(BaseModel):
|
|
||||||
question: str
|
|
||||||
|
|
||||||
class DocumentResponse(BaseModel):
|
class DocumentResponse(BaseModel):
|
||||||
id: str
|
document_id: str
|
||||||
|
category: CategoryEnum
|
||||||
|
filename: str
|
||||||
|
status: str
|
||||||
|
|
||||||
|
class UpdateDocument(BaseModel):
|
||||||
content: str
|
content: str
|
||||||
metadata: dict
|
metadata: dict
|
||||||
|
|
||||||
|
class QueryRequest(BaseModel):
|
||||||
|
query: str
|
||||||
|
|
||||||
class CollectionInfo(BaseModel):
|
class CollectionInfo(BaseModel):
|
||||||
collection_name: str
|
collection_name: str
|
||||||
document_count: int
|
document_count: int
|
||||||
index_size: str
|
index_size: str # This field was missing in the response
|
||||||
|
category_counts: Dict[str, int]
|
||||||
|
|
||||||
class UpdateDocumentRequest(BaseModel):
|
class FileInfo(BaseModel):
|
||||||
new_content: str
|
filename: str
|
||||||
metadata: dict
|
category: str
|
||||||
|
upload_date: Optional[datetime] = None
|
||||||
|
doc_ids: List[str] = []
|
||||||
|
|
||||||
|
class CategoryFiles(BaseModel):
|
||||||
|
category: str # This field was missing in the response
|
||||||
|
files: List[FileInfo] = []
|
||||||
|
total_files: int = 0
|
||||||
|
|
||||||
|
class CategoryResponse(BaseModel):
|
||||||
|
"""Response model for files by category"""
|
||||||
|
files: List[FileInfo] = []
|
||||||
|
total_files: int = 0
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
import requests
|
||||||
|
|
||||||
|
import os

# Base URL of the inference server. Override it with the SERVER_URL
# environment variable (e.g. to point at ma.nlpdynamo.com) instead of editing
# this file.
# NOTE(review): the variable names SERVER_URL / API_KEY are assumed to match
# the project's configuration -- confirm against config.py.
server_url = os.environ.get("SERVER_URL", "https://ma.rommelcorral.com")

# Bearer token sent with every request. It is read from the environment so the
# secret never lives in version control; the token that used to be hard-coded
# here has been published in history and must be rotated.
api_key = os.environ.get("API_KEY", "")
|
||||||
|
|
||||||
|
# Test connection
|
||||||
|
def test_connection():
    """Check that the server answers an authenticated GET on its base URL.

    Prints one line describing the outcome: reachable (HTTP 200), reachable
    with a different status code (the code and body are shown), or a
    transport-level failure. Nothing is returned.
    """
    try:
        # requests waits indefinitely unless a timeout is given; a timeout
        # raises a RequestException subclass, so it is reported below.
        response = requests.get(
            server_url,
            headers={"Authorization": f"Bearer {api_key}"},
            timeout=30,
        )
        if response.status_code == 200:
            print("✅ Server is reachable.")
        else:
            print(f"⚠️ Server responded with status code {response.status_code}: {response.text}")
    except requests.exceptions.RequestException as e:
        print(f"❌ Connection failed: {e}")
|
||||||
|
|
||||||
|
# Test embedding API
|
||||||
|
def test_embedding():
    """Send one sentence to the ``/embeddings`` endpoint and print the result.

    Posts a fixed payload for the ``BAAI/bge-large-en-v1.5`` model. On HTTP
    200 the decoded JSON body is printed; on any other status the code and
    raw body are printed; transport failures are printed as well. Nothing is
    returned.
    """
    endpoint = f"{server_url}/embeddings"
    payload = {
        "model": "BAAI/bge-large-en-v1.5",
        "input": "This is a test sentence.",  # Use a string instead of a list
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    try:
        # Bound the wait: without a timeout an unresponsive server would hang
        # this script forever. Timeouts surface as RequestException.
        response = requests.post(endpoint, json=payload, headers=headers, timeout=60)
        if response.status_code == 200:
            print("✅ Embedding API works.")
            print("Response:", response.json())
        else:
            print(f"⚠️ Embedding failed: {response.status_code} {response.text}")
    except requests.exceptions.RequestException as e:
        print(f"❌ Request failed: {e}")
|
||||||
|
|
||||||
|
# Test reranker API
|
||||||
|
def test_reranker():
    """Send a sample query and two documents to ``/rerank`` and print the result.

    Posts a fixed payload for the ``BAAI/bge-reranker-large`` model. On HTTP
    200 the decoded JSON body is printed; on any other status the code and
    raw body are printed; transport failures are printed as well. Nothing is
    returned.
    """
    endpoint = f"{server_url}/rerank"
    payload = {
        "model": "BAAI/bge-reranker-large",
        "query": "What is AI?",
        "documents": ["AI is artificial intelligence.", "AI stands for artificial innovation."],
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    try:
        # Bound the wait: without a timeout an unresponsive server would hang
        # this script forever. Timeouts surface as RequestException.
        response = requests.post(endpoint, json=payload, headers=headers, timeout=60)
        if response.status_code == 200:
            print("✅ Reranker API works.")
            print("Response:", response.json())
        else:
            print(f"⚠️ Reranker failed: {response.status_code} {response.text}")
    except requests.exceptions.RequestException as e:
        print(f"❌ Request failed: {e}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run every smoke check in order: connectivity, embeddings, reranker.
    for check in (test_connection, test_embedding, test_reranker):
        check()
|
||||||
@@ -1,37 +1,117 @@
|
|||||||
import os
|
import os
|
||||||
import requests
|
import shutil
|
||||||
from typing import List, Tuple
|
from typing import List
|
||||||
|
from pathlib import Path
|
||||||
|
from config import UPLOAD_DIR
|
||||||
|
import nltk
|
||||||
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
|
from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
from PyPDF2 import PdfReader
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from config import settings
|
from PIL import Image
|
||||||
from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder
|
import pytesseract
|
||||||
class CustomEmbeddings:
|
import easyocr
|
||||||
def __init__(self, model_name: str):
|
|
||||||
self.model_name = model_name
|
|
||||||
|
|
||||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
def save_upload_file(file, category) -> Path:
|
||||||
headers = {"Authorization": f"Bearer {settings.API_KEY}"}
|
"""Save uploaded file to appropriate category directory"""
|
||||||
payload = {"model": self.model_name, "input": texts}
|
# Convert CategoryEnum to string if needed
|
||||||
response = requests.post(f"{settings.SERVER_URL}/embeddings", json=payload, headers=headers)
|
category_str = category.value if hasattr(category, 'value') else str(category)
|
||||||
response.raise_for_status()
|
|
||||||
return [item['embedding'] for item in response.json()['data']]
|
file_path = UPLOAD_DIR / category_str / file.filename
|
||||||
|
print(f"Uploading dir: {UPLOAD_DIR}")
|
||||||
|
print(f"Category: {category_str}")
|
||||||
|
print(f"File name: {file.filename}")
|
||||||
|
print(f"Saving file to: {file_path}")
|
||||||
|
|
||||||
|
with open(file_path, "wb") as buffer:
|
||||||
|
shutil.copyfileobj(file.file, buffer)
|
||||||
|
return file_path
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
|
||||||
return self.embed_documents([text])[0]
|
|
||||||
|
|
||||||
class CustomCrossEncoder(BaseCrossEncoder):
|
def image_to_text(image_path: Path) -> str:
|
||||||
def __init__(self, model_name: str):
|
"""Convert image to text using OCR"""
|
||||||
self.model_name = model_name
|
reader = easyocr.Reader(['en'])
|
||||||
|
result = reader.readtext(str(image_path))
|
||||||
|
text = " ".join([res[1] for res in result])
|
||||||
|
return text
|
||||||
|
|
||||||
def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
|
def load_and_split_documents(file_path: Path, folder_path: Path) -> List:
|
||||||
query, documents = text_pairs[0][0], [doc for _, doc in text_pairs]
|
"""Load and split documents into chunks"""
|
||||||
headers = {"Authorization": f"Bearer {settings.API_KEY}"}
|
print(f"Loading file: {file_path}")
|
||||||
payload = {"model": self.model_name, "query": query, "documents": documents}
|
print(f"Loading folder: {folder_path}")
|
||||||
response = requests.post(f"{settings.SERVER_URL}/rerank", json=payload, headers=headers)
|
|
||||||
response.raise_for_status()
|
# Download required NLTK data
|
||||||
return [item['relevance_score'] for item in sorted(response.json()['results'], key=lambda x: x['index'])]
|
try:
|
||||||
|
nltk.download('punkt', quiet=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"NLTK download warning (non-critical): {e}")
|
||||||
|
|
||||||
|
# Handle PDF files directly
|
||||||
|
if file_path.suffix.lower() == '.pdf':
|
||||||
|
try:
|
||||||
|
pdf_reader = PdfReader(str(file_path))
|
||||||
|
text = ""
|
||||||
|
for page in pdf_reader.pages:
|
||||||
|
text += page.extract_text() or ""
|
||||||
|
|
||||||
|
if not text.strip():
|
||||||
|
raise ValueError("No text could be extracted from PDF")
|
||||||
|
|
||||||
|
documents = [Document(page_content=text, metadata={"source": str(file_path)})]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading PDF: {e}")
|
||||||
|
return []
|
||||||
|
elif file_path.suffix.lower() == '.txt':
|
||||||
|
try:
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
text = f.read()
|
||||||
|
documents = [Document(page_content=text, metadata={"source": str(file_path)})]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading text file: {e}")
|
||||||
|
return []
|
||||||
|
elif file_path.suffix.lower() == '.png' or file_path.suffix.lower() == '.jpg' or file_path.suffix.lower() == '.jpeg':
|
||||||
|
try:
|
||||||
|
image = Image.open(file_path)
|
||||||
|
text = image_to_text(file_path)
|
||||||
|
documents = [Document(page_content=text, metadata={"source": str(file_path)})]
|
||||||
|
print(f"Extracted text from image: {text}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading image file: {e}")
|
||||||
|
return []
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
reader = SimpleDirectoryReader(
|
||||||
|
input_dir=str(folder_path),
|
||||||
|
recursive=True
|
||||||
|
)
|
||||||
|
all_docs = []
|
||||||
|
for docs in reader.iter_data():
|
||||||
|
# <do something with the documents per file>
|
||||||
|
all_docs.extend(docs)
|
||||||
|
documents = all_docs
|
||||||
|
print(f"Read {len(documents)} documents from directory {folder_path}")
|
||||||
|
print("Documents:", documents)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading directory: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=1000,
|
||||||
|
chunk_overlap=200
|
||||||
|
)
|
||||||
|
|
||||||
def save_uploaded_file(content: bytes, filename: str) -> str:
|
# Remove existing files in folder
|
||||||
filepath = os.path.join(settings.DOCS_PATH, filename)
|
for file in folder_path.iterdir():
|
||||||
with open(filepath, "wb") as f:
|
if file.is_file():
|
||||||
f.write(content)
|
os.remove(file)
|
||||||
return filepath
|
|
||||||
|
return splitter.split_documents(documents)
|
||||||
|
|
||||||
|
def clean_category_files(category: str):
    """Delete the regular files directly inside a category's upload directory.

    Args:
        category: Category name, or an enum member whose ``value`` is the
            name. It is resolved to a string the same way
            ``save_upload_file`` resolves it, so both functions address the
            same directory.

    Sub-directories are left in place. A category directory that does not
    exist is treated as already clean instead of raising.
    """
    category_str = category.value if hasattr(category, "value") else str(category)
    category_dir = UPLOAD_DIR / category_str

    # iterdir() raises FileNotFoundError on a missing directory; nothing to
    # clean in that case.
    if not category_dir.is_dir():
        return

    for entry in category_dir.iterdir():
        if entry.is_file():
            os.remove(entry)
|
||||||
Reference in New Issue
Block a user