Files
DS_Task_AI_News/backend/vector_store.py
T
2025-07-08 19:57:35 +01:00

60 lines
1.9 KiB
Python

import numpy as np
import faiss
import os
class VectorDB:
    """FAISS vector index paired with a parallel list of article dicts.

    Row ``i`` of the FAISS index is assumed to describe ``self.articles[i]``;
    ``add_vectors`` keeps both in step by appending to them together.
    Only the FAISS index is persisted by ``save_index``/``load_index`` --
    ``self.articles`` is neither written nor restored by this class.
    """

    # Translation used to derive a slug from a title: spaces become hyphens,
    # commas and periods are dropped (applied after lower-casing).
    _SLUG_TABLE = str.maketrans({' ': '-', ',': None, '.': None})

    def __init__(self):
        self.index = None  # created lazily by initialize_index()/load_index()
        self.articles = []  # article dicts, positionally aligned with index rows
        self.dimension = 1024  # Cohere embedding dimension

    def initialize_index(self):
        """Create a fresh, empty exact-L2 FAISS index of ``self.dimension``."""
        self.index = faiss.IndexFlatL2(self.dimension)

    def add_vectors(self, embeddings, articles):
        """Append embeddings and their articles to the store.

        Args:
            embeddings: 2-D array-like of shape ``(n, dimension)``.
            articles: iterable of ``n`` article dicts, one per embedding row.

        Raises:
            ValueError: if ``embeddings`` is not 2-D or its row count differs
                from the number of articles. Accepting such input would
                silently misalign index rows and article metadata.
        """
        articles = list(articles)
        # FAISS requires C-contiguous float32 input.
        vectors = np.ascontiguousarray(embeddings, dtype=np.float32)

        # Nothing to add: avoid handing FAISS a shapeless empty array.
        if vectors.size == 0 and not articles:
            return

        if vectors.ndim != 2 or vectors.shape[0] != len(articles):
            raise ValueError(
                "embeddings must be a 2-D array with one row per article "
                f"(got shape {vectors.shape} for {len(articles)} articles)"
            )

        if self.index is None:
            self.initialize_index()
        self.index.add(vectors)
        self.articles.extend(articles)

    def search(self, query_embedding, k=5):
        """Return up to ``k`` stored articles nearest to ``query_embedding``.

        Args:
            query_embedding: a single embedding vector (1-D, or shaped
                ``(1, dimension)``).
            k: maximum number of neighbours to return.

        Returns:
            A list of shallow copies of the matching article dicts, each with
            an added ``'similarity_score'`` key holding the distance reported
            by the index. For the ``IndexFlatL2`` built by
            ``initialize_index`` this is an L2 distance, so *lower* means
            *more similar*. Returns ``[]`` when the index is missing or empty,
            or when ``k`` is not positive.
        """
        if self.index is None or self.index.ntotal == 0 or k <= 0:
            return []

        query = np.ascontiguousarray(query_embedding, dtype=np.float32).reshape(1, -1)
        # Asking for more neighbours than stored vectors only yields padding.
        k = min(int(k), self.index.ntotal)
        distances, indices = self.index.search(query, k)

        results = []
        for distance, idx in zip(distances[0], indices[0]):
            idx = int(idx)
            # FAISS pads missing neighbours with label -1; without the lower
            # bound a -1 would wrap around and return the last article.
            if 0 <= idx < len(self.articles):
                result = self.articles[idx].copy()
                result['similarity_score'] = float(distance)
                results.append(result)
        return results

    def save_index(self, filepath):
        """Write the FAISS index to ``filepath``; no-op if no index exists.

        Only the index is written -- ``self.articles`` is not saved here.
        """
        if self.index is not None:
            faiss.write_index(self.index, filepath)

    def load_index(self, filepath):
        """Load a FAISS index from ``filepath`` if that path exists.

        A missing file is silently ignored and the current state is kept.
        ``self.articles`` is left untouched, so the caller is responsible for
        supplying articles that line up with the loaded index rows.
        """
        if os.path.exists(filepath):
            self.index = faiss.read_index(filepath)
            # Keep the advertised dimension in sync with what was loaded.
            self.dimension = self.index.d

    def get_article_by_id(self, article_id):
        """Return the first article matching ``article_id``, else ``None``.

        An article matches when ``article_id`` equals its ``'slug'``, its
        ``'title'``, or a slug derived from the title (lower-cased, spaces
        replaced by hyphens, commas and periods removed).
        """
        for article in self.articles:
            if article.get('slug') == article_id:
                return article

            title = article.get('title', '')
            if title == article_id:
                return article

            # A present-but-None (or otherwise non-string) title cannot be
            # slugified; treat it as empty instead of raising AttributeError.
            if not isinstance(title, str):
                title = ''
            if title.lower().translate(self._SLUG_TABLE) == article_id:
                return article
        return None