update all endpoints
This commit is contained in:
+53
-8
@@ -1,14 +1,59 @@
|
||||
import numpy as np
|
||||
import faiss
|
||||
from backend.config import Config
|
||||
import os
|
||||
|
||||
|
||||
class VectorDB:
    """In-memory FAISS vector store that pairs each vector with an article.

    Row ``i`` of the FAISS index corresponds to ``self.articles[i]``; every
    method below preserves that positional invariant.
    """

    # One-pass equivalent of .replace(' ', '-').replace(',', '').replace('.', '')
    _SLUG_TABLE = str.maketrans({' ': '-', ',': None, '.': None})

    def __init__(self, dimension=1024):
        """Create an empty store.

        Args:
            dimension: Length of the embedding vectors. Defaults to 1024
                (Cohere embedding dimension).
        """
        self.index = None  # built lazily by initialize_index()
        self.articles = []
        self.dimension = dimension

    def initialize_index(self):
        """Initialize FAISS index (flat, exact L2 search)."""
        self.index = faiss.IndexFlatL2(self.dimension)

    def add_vectors(self, embeddings, articles):
        """Add vectors and their articles to the database.

        Args:
            embeddings: Sequence of embedding vectors, one per article.
            articles: Sequence of article records (dict-like), same length
                and order as ``embeddings``.

        Raises:
            ValueError: If ``embeddings`` is not a 2-D batch or its row count
                differs from ``len(articles)``. Accepting such input would
                shift every later index position onto the wrong article.
        """
        articles = list(articles)
        embeddings_array = np.ascontiguousarray(embeddings, dtype='float32')

        # Nothing to add: avoid handing FAISS a shapeless empty array.
        if embeddings_array.size == 0 and not articles:
            return

        if embeddings_array.ndim != 2 or embeddings_array.shape[0] != len(articles):
            raise ValueError(
                "embeddings must be a 2-D batch with one row per article "
                f"(got shape {embeddings_array.shape} for {len(articles)} articles)"
            )

        if self.index is None:
            self.initialize_index()

        self.index.add(embeddings_array)
        self.articles.extend(articles)

    def search(self, query_embedding, k=5):
        """Search for the ``k`` stored vectors nearest to ``query_embedding``.

        Args:
            query_embedding: A single embedding vector.
            k: Maximum number of neighbours to return.

        Returns:
            A list of shallow copies of the matching articles, nearest first,
            each with an added ``'similarity_score'`` key. The value is the
            distance reported by the index, so for the L2 index built by
            ``initialize_index`` a smaller score means a closer match.
            Returns an empty list when the store is empty or ``k`` is not
            positive.
        """
        if self.index is None or self.index.ntotal == 0 or k <= 0:
            return []

        query_array = np.ascontiguousarray([query_embedding], dtype='float32')
        distances, indices = self.index.search(query_array, k)

        results = []
        for distance, idx in zip(distances[0], indices[0]):
            position = int(idx)
            # FAISS pads with -1 when fewer than k neighbours exist; without
            # the lower bound, -1 would silently select the last article.
            if 0 <= position < len(self.articles):
                result = self.articles[position].copy()
                result['similarity_score'] = float(distance)
                results.append(result)

        return results

    def save_index(self, filepath):
        """Save the index to file.

        Only the FAISS index is written; ``self.articles`` is not persisted.
        Does nothing if no index has been created yet.
        """
        if self.index is not None:
            faiss.write_index(self.index, filepath)

    def load_index(self, filepath):
        """Load the index from file, if the file exists.

        ``self.articles`` is not restored here, so callers must repopulate it
        in the original insertion order for ``search`` to return results.
        """
        if os.path.exists(filepath):
            self.index = faiss.read_index(filepath)
            # Keep the advertised dimension in step with the loaded index.
            self.dimension = self.index.d

    def get_article_by_id(self, article_id):
        """Get article by ID/slug.

        An article matches when ``article_id`` equals its ``'slug'``, its
        exact ``'title'``, or a slug derived from the title (lower-cased,
        spaces replaced by hyphens, commas and periods removed).

        Returns:
            The first matching article, or ``None`` if there is none.
        """
        for article in self.articles:
            title = article.get('title')
            if article.get('slug') == article_id or title == article_id:
                return article
            # A missing or None title is treated as an empty string.
            derived_slug = (title or '').lower().translate(self._SLUG_TABLE)
            if derived_slug == article_id:
                return article
        return None
|
||||
|
||||
Reference in New Issue
Block a user