2025-07-07 22:08:02 +01:00
|
|
|
import numpy as np
|
|
|
|
|
import faiss
|
2025-07-08 19:57:35 +01:00
|
|
|
import os
|
|
|
|
|
|
2025-07-07 22:08:02 +01:00
|
|
|
|
|
|
|
|
class VectorDB:
    """In-memory vector store pairing a FAISS flat L2 index with article dicts.

    Row ``i`` of ``self.index`` corresponds to ``self.articles[i]``;
    :meth:`add_vectors` is the only method that keeps the two aligned.

    Note that :meth:`save_index` / :meth:`load_index` persist only the FAISS
    index, not ``self.articles``.
    """

    # Translation used to derive a slug from a title: spaces become hyphens,
    # commas and periods are dropped.
    _SLUG_TABLE = str.maketrans({' ': '-', ',': None, '.': None})

    def __init__(self):
        self.index = None  # created lazily by add_vectors / initialize_index
        self.articles = []
        self.dimension = 1024  # Cohere embedding dimension

    def initialize_index(self):
        """Replace ``self.index`` with a new, empty ``IndexFlatL2``."""
        self.index = faiss.IndexFlatL2(self.dimension)

    def add_vectors(self, embeddings, articles):
        """Append embeddings and their articles to the database.

        Args:
            embeddings: Sequence of vectors (or a single 1-D vector), each of
                length ``self.dimension``. Converted to a float32 array.
            articles: Sequence of article records, one per embedding.

        Raises:
            ValueError: If the embeddings are not a 2-D batch, or the number
                of embeddings differs from the number of articles (which
                would misalign index rows and ``self.articles``).
        """
        article_list = list(articles)
        vectors = np.asarray(embeddings, dtype='float32')

        # Nothing to add: avoid handing FAISS an empty 1-D array.
        if vectors.size == 0 and not article_list:
            return

        # Accept a single bare vector as a batch of one.
        if vectors.ndim == 1 and vectors.size > 0:
            vectors = vectors.reshape(1, -1)

        if vectors.ndim != 2:
            raise ValueError(
                f"embeddings must be a 2-D batch of vectors, got shape {vectors.shape}"
            )
        if vectors.shape[0] != len(article_list):
            raise ValueError(
                f"got {vectors.shape[0]} embeddings but {len(article_list)} articles"
            )

        if self.index is None:
            self.initialize_index()

        self.index.add(vectors)
        self.articles.extend(article_list)

    def search(self, query_embedding, k=5):
        """Return up to ``k`` nearest articles for ``query_embedding``.

        Args:
            query_embedding: A single query vector (1-D, or shaped ``(1, d)``).
            k: Maximum number of neighbours to return.

        Returns:
            A list of shallow copies of the matching articles, nearest first
            as returned by the index. Each copy gains a ``'similarity_score'``
            key holding the distance reported by the index (for a flat L2
            index a smaller value means a closer match). Returns ``[]`` when
            the index is missing or empty, or when ``k`` is not positive.
        """
        if self.index is None or self.index.ntotal == 0:
            return []

        # Never ask for more neighbours than exist; FAISS pads with -1 labels.
        k = min(int(k), self.index.ntotal)
        if k <= 0:
            return []

        query_array = np.asarray(query_embedding, dtype='float32').reshape(1, -1)
        distances, indices = self.index.search(query_array, k)

        article_count = len(self.articles)
        results = []
        for distance, idx in zip(distances[0], indices[0]):
            idx = int(idx)
            # Skip the -1 "no result" label (a negative index would otherwise
            # wrap around to the last article) and rows without an article.
            if 0 <= idx < article_count:
                result = self.articles[idx].copy()
                result['similarity_score'] = float(distance)
                results.append(result)

        return results

    def save_index(self, filepath):
        """Write the FAISS index to ``filepath``; no-op if no index exists.

        Only the index is written; ``self.articles`` is not persisted.
        """
        if self.index is not None:
            faiss.write_index(self.index, filepath)

    def load_index(self, filepath):
        """Load a FAISS index from ``filepath`` if the file exists.

        A missing file leaves the current state untouched. After a successful
        load ``self.dimension`` is synced to the loaded index so that a later
        :meth:`initialize_index` stays consistent with it. ``self.articles``
        is not modified.
        """
        if os.path.exists(filepath):
            self.index = faiss.read_index(filepath)
            self.dimension = self.index.d

    def get_article_by_id(self, article_id):
        """Return the first article matching ``article_id``, else ``None``.

        An article matches when its ``'slug'`` equals ``article_id``, its
        ``'title'`` equals ``article_id``, or the slug derived from its title
        (lower-cased, spaces to hyphens, commas/periods removed) equals it.
        """
        for article in self.articles:
            title = article.get('title')
            if article.get('slug') == article_id or title == article_id:
                return article
            # A missing or non-string title slugifies to '' instead of raising.
            title_text = title if isinstance(title, str) else ''
            if title_text.lower().translate(self._SLUG_TABLE) == article_id:
                return article
        return None
|