Files
ds_tjc/utils/vector_db.py
T

70 lines
2.7 KiB
Python
Raw Normal View History

import json
from typing import List, Dict, Tuple
from concurrent.futures import ThreadPoolExecutor
import os
import numpy as np
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_cohere import CohereEmbeddings
import faiss
from langchain_core.documents import Document
from config import COHERE_API_KEY, EMBEDDING_MODEL, EMBEDDING_DIMENSION
class VectorDB:
    """
    Thin wrapper around a Cohere embedding client and LangChain FAISS stores.

    Holds an empty in-memory FAISS store built from ``config`` settings, and
    offers helpers to load a previously saved FAISS store from disk and to
    run similarity searches that return JSON-serializable results.
    """

    def __init__(self):
        """
        Create the Cohere embedding client and an empty in-memory FAISS store.

        Reads ``COHERE_API_KEY``, ``EMBEDDING_MODEL`` and ``EMBEDDING_DIMENSION``
        from ``config`` and exports the API key to ``os.environ``.
        """
        # NOTE(review): not used by any method of this class; kept only in
        # case other code relies on the attribute.
        self._executor = ThreadPoolExecutor(max_workers=10)
        self.COHERE_API_KEY = COHERE_API_KEY
        # Published process-wide, presumably so the Cohere client can read the
        # key from the environment — confirm before removing.
        os.environ["COHERE_API_KEY"] = self.COHERE_API_KEY
        self.embeddings = CohereEmbeddings(model=EMBEDDING_MODEL)
        # Flat (exact) L2 index; EMBEDDING_DIMENSION has to match the vector
        # size produced by EMBEDDING_MODEL.
        self.index = faiss.IndexFlatL2(EMBEDDING_DIMENSION)
        # NOTE(review): "vector_score" looks like a typo for "vector_store";
        # the name is kept because callers may reference it.
        self.vector_score = FAISS(
            embedding_function=self.embeddings,
            index=self.index,
            docstore=InMemoryDocstore(),
            index_to_docstore_id={},
        )

    def load_embeddings(self, file_id: str, file_path: str):
        """
        Load a FAISS vector store previously saved to ``file_path``.

        Args:
            file_id: Currently unused; kept for interface compatibility.
            file_path: Directory holding the saved store (saved under the
                index name ``"index"``). Relative paths are resolved against
                the caller's current working directory.

        Returns:
            The loaded LangChain ``FAISS`` store, wired to this instance's
            embedding client.

        Raises:
            Exception: If ``file_path`` is not a directory or loading fails.
                The underlying error is chained as ``__cause__``.
        """
        try:
            # Validate and load the same path: this method deliberately does
            # not chdir, since the working directory is process-wide state.
            if not os.path.isdir(file_path):
                raise Exception(f"{file_path} is not a valid directory.")
            # SECURITY: load_local unpickles the saved docstore, which can run
            # arbitrary code — only point this at trusted directories.
            return FAISS.load_local(
                folder_path=file_path,
                index_name="index",
                embeddings=self.embeddings,
                allow_dangerous_deserialization=True,
            )
        except Exception as e:
            raise Exception(f"Error loading embeddings: {str(e)}") from e

    def search(self, new_vector_store, query: str, top_k: int = 5) -> List[Dict]:
        """
        Search ``new_vector_store`` and return JSON-serializable results.

        Args:
            new_vector_store: Any store exposing
                ``similarity_search_with_score(query, k=...)``, e.g. the value
                returned by ``load_embeddings``.
            query: Text to search for.
            top_k: Maximum number of results to request from the store.

        Returns:
            One dict per hit, in the store's order, with keys ``content``
            (page text), ``metadata`` and ``score``. For L2 FAISS indexes the
            score is presumably a distance (lower = closer) — confirm for the
            store in use.

        Raises:
            Exception: Wrapping any error from the search or result
                conversion; the original error is chained as ``__cause__``.
        """
        try:
            raw_results = new_vector_store.similarity_search_with_score(query, k=top_k)
            # float() turns numpy scalar scores (e.g. numpy.float32) into plain
            # Python floats so the result can be passed to json.dumps.
            return [
                {
                    'content': doc.page_content,
                    'metadata': doc.metadata,
                    'score': float(score),
                }
                for doc, score in raw_results
            ]
        except Exception as e:
            raise Exception(f"Error during search: {str(e)}") from e