43 lines
1.4 KiB
Python
43 lines
1.4 KiB
Python
|
|
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
||
|
|
from langchain_community.vectorstores import FAISS
|
||
|
|
|
||
|
|
|
||
|
|
def load_embedding_model(model_name="BAAI/bge-small-en", device=None, normalize_embeddings=True):
    """Load the BGE sentence-embedding model used to build the vector stores.

    Args:
        model_name: Hugging Face model identifier passed to
            ``HuggingFaceBgeEmbeddings``.
        device: Device string forwarded in ``model_kwargs`` (e.g. ``"cuda"``
            or ``"cpu"``). When ``None`` (the default), ``"cuda"`` is used if
            torch reports a CUDA device is available, otherwise ``"cpu"``.
            On CUDA machines this matches the previously hard-coded value.
        normalize_embeddings: Forwarded in ``encode_kwargs``. Normalized
            vectors make inner-product similarity behave like cosine
            similarity.

    Returns:
        A configured ``HuggingFaceBgeEmbeddings`` instance.
    """
    if device is None:
        # Imported lazily: torch is only needed for auto-detection, and the
        # embedding backend already depends on it.
        import torch

        device = "cuda" if torch.cuda.is_available() else "cpu"

    return HuggingFaceBgeEmbeddings(
        model_name=model_name,
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": normalize_embeddings},
    )


# Default embedding model. It is loaded once, when this module is imported, and
# is bound as the default ``embeddings`` argument of the functions below.
embeddings = load_embedding_model()
|
||
|
|
|
||
|
|
|
||
|
|
def create_vector_store(document, embeddings=embeddings):
    """Embed ``document`` and index the resulting vectors in a new FAISS store.

    Args:
        document: The documents to index, passed straight to
            ``FAISS.from_documents``.
        embeddings: Embedding model used to vectorise the documents; defaults
            to the module-level model loaded at import time.

    Returns:
        The newly built ``FAISS`` vector store.
    """
    return FAISS.from_documents(document, embeddings)
|
||
|
|
|
||
|
|
def save_embedded_data(docs, key="pdf"):
    """Persist a vector store to ``vec-db/index/faiss_index_<key>``.

    Args:
        docs: The vector store to save; only its ``save_local`` method is used.
        key: Suffix that distinguishes this index directory from others.
    """
    index_dir = "vec-db/index/faiss_index_{}".format(key)
    docs.save_local(index_dir)
    print("Embeddings saved")
|
||
|
|
|
||
|
|
def load_embedded_data(embeddings=embeddings, key="pdf"):
    """Load the FAISS index stored at ``vec-db/index/faiss_index_<key>``.

    This is the same directory layout that ``save_embedded_data`` writes to.

    Args:
        embeddings: Embedding model to attach to the loaded store; defaults
            to the module-level model loaded at import time.
        key: Suffix selecting which saved index directory to load.

    Returns:
        The loaded ``FAISS`` vector store.
    """
    index_dir = "vec-db/index/faiss_index_{}".format(key)
    # NOTE(review): allow_dangerous_deserialization opts in to unpickling the
    # stored docstore. Only acceptable for index files this application wrote
    # itself; never point this at files from an untrusted source.
    return FAISS.load_local(
        index_dir,
        embeddings,
        allow_dangerous_deserialization=True,
    )
|
||
|
|
|
||
|
|
def search(db, query, k=4):
    """Run a similarity search and collect the matched text and page numbers.

    Args:
        db: Vector store exposing ``similarity_search(query, k=...)``, such as
            the FAISS store built by this module.
        query: Text to search for.
        k: Maximum number of documents to retrieve.

    Returns:
        A tuple ``(top_content, all_content, pages)``:

        - ``top_content``: ``page_content`` of the first returned document,
          or ``""`` when nothing was returned.
        - ``all_content``: every returned document's ``page_content``, each
          followed by a newline, in result order.
        - ``pages``: each returned document's ``metadata["page"]``, or
          ``None`` for a document without a ``"page"`` entry.
    """
    docs = db.similarity_search(query, k=k)
    if not docs:
        # Nothing retrieved (e.g. k == 0 or an empty index): indexing docs[0]
        # would raise IndexError.
        return "", "", []

    combined = "".join(f"{doc.page_content}\n" for doc in docs)
    # .get() so documents from loaders that record no page number do not
    # raise KeyError.
    pages = [doc.metadata.get("page") for doc in docs]
    return docs[0].page_content, combined, pages
|