2025-07-07 18:31:38 +01:00
|
|
|
"""Vector database operations using FAISS"""
|
|
|
|
|
import os
|
|
|
|
|
import json
|
|
|
|
|
import pickle
|
2025-07-08 16:45:38 +01:00
|
|
|
import time
|
2025-07-07 18:31:38 +01:00
|
|
|
import numpy as np
|
|
|
|
|
import faiss
|
|
|
|
|
from typing import List, Dict, Any, Optional, Tuple
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
from config import settings
|
|
|
|
|
|
|
|
|
|
class VectorStore:
    """Persistent article vector store backed by a FAISS inner-product index.

    Embeddings are L2-normalised before they are added and before queries are
    run, so the inner-product scores FAISS returns are cosine similarities.
    Article metadata lives in ``articles_metadata``; position ``i`` in that
    list describes vector ``i`` in the FAISS index. Both are persisted side by
    side: the index via ``faiss.write_index`` and the metadata via ``pickle``.
    """

    def __init__(self):
        """Read paths/dimension from ``settings`` and load any persisted store."""
        self.index_path = settings.vector_index_path
        self.metadata_path = self._metadata_path_for(self.index_path)
        # Configured dimension. The live index is created from the width of the
        # first embeddings added (see add_articles), which may differ.
        self.dimension = settings.vector_dimension

        # Initialize FAISS index
        self.index = None
        self.articles_metadata = []

        # Simple in-memory cache for frequent queries: key -> (value, stored_at)
        self._cache = {}
        self._cache_ttl = 300  # 5 minutes

        # Load existing index if available
        self.load_index()

    @staticmethod
    def _metadata_path_for(index_path: str) -> str:
        """Return the metadata pickle path stored next to *index_path*.

        ``data/news.faiss`` -> ``data/news_metadata.pkl``. Only the final
        extension is replaced. A path without a ``.faiss`` suffix therefore
        never maps the metadata file onto the index file itself.
        """
        root, _ext = os.path.splitext(index_path)
        return f"{root}_metadata.pkl"

    def create_index(self, dimension: int):
        """Create a new, empty FAISS index and reset the metadata list.

        ``IndexFlatIP`` scores by inner product. Every vector is normalised
        before it is added or searched, so that score equals cosine similarity.

        Args:
            dimension: Width of the vectors the index will hold.
        """
        self.index = faiss.IndexFlatIP(dimension)
        self.articles_metadata = []
        print(f"Created new FAISS index with dimension {dimension}")

    def normalize_vectors(self, vectors: np.ndarray) -> np.ndarray:
        """Return *vectors* scaled to unit L2 norm row by row.

        All-zero rows are left as zeros instead of dividing by zero.
        """
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        norms[norms == 0] = 1  # Avoid division by zero
        return vectors / norms

    @staticmethod
    def _build_metadata(article: Dict[str, Any], vector_index: int, added_date: str) -> Dict[str, Any]:
        """Project *article* onto the fields stored alongside its vector."""
        return {
            'id': article.get('id'),
            'title': article.get('title'),
            # ``or ''`` also covers an explicit ``'content': None``.
            'content': (article.get('content') or '')[:200],  # Truncate for storage
            'url': article.get('url'),
            'source': article.get('source'),
            'published_date': article.get('published_date'),
            'added_date': added_date,
            'vector_index': vector_index,  # Row of this article in the FAISS index
        }

    def add_articles(self, articles: List[Dict[str, Any]], embeddings: np.ndarray):
        """Add articles and their embeddings to the store, then persist it.

        Args:
            articles: Article dicts. ``id``, ``title``, ``content``, ``url``,
                ``source`` and ``published_date`` are read if present.
            embeddings: 2-D array-like with one row per article.

        Raises:
            ValueError: if the article and embedding counts differ, the
                embeddings are not 2-D, or their width does not match the
                existing index.
        """
        if len(articles) != len(embeddings):
            raise ValueError("Number of articles must match number of embeddings")
        if not articles:
            return  # Nothing to add; also avoids reading shape[1] of an empty array.

        vectors = np.asarray(embeddings, dtype=np.float32)
        if vectors.ndim != 2:
            raise ValueError(f"Embeddings must be a 2-D array, got shape {vectors.shape}")

        # Create index if it doesn't exist; otherwise widths must agree.
        if self.index is None:
            self.create_index(vectors.shape[1])
        elif vectors.shape[1] != self.index.d:
            raise ValueError(
                f"Embedding dimension {vectors.shape[1]} does not match "
                f"index dimension {self.index.d}"
            )

        # Build all metadata before touching the index. A malformed article then
        # cannot leave vectors in FAISS without matching metadata rows.
        start = len(self.articles_metadata)
        added_date = datetime.now().isoformat()
        new_metadata = [
            self._build_metadata(article, start + offset, added_date)
            for offset, article in enumerate(articles)
        ]

        # Normalize embeddings for cosine similarity; FAISS expects C-contiguous float32.
        self.index.add(np.ascontiguousarray(self.normalize_vectors(vectors)))
        self.articles_metadata.extend(new_metadata)
        self._clear_cache()  # Previously cached results may now be stale.

        print(f"Added {len(articles)} articles to vector store")
        print(f"Total articles in store: {len(self.articles_metadata)}")

        # Save the updated index
        self.save_index()

    def search_similar(self, query_embedding: np.ndarray, top_k: int = 5) -> List[Dict[str, Any]]:
        """Return up to *top_k* stored articles most similar to *query_embedding*.

        Each result is a copy of the stored metadata with an added
        ``similarity_score`` float. Results keep the order returned by the
        FAISS search. No similarity threshold is applied.
        Returns ``[]`` when the store is empty or ``top_k <= 0``.
        """
        if self.index is None or not self.articles_metadata or top_k <= 0:
            return []

        # Same dtype and normalisation as the vectors stored by add_articles.
        query = np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)
        query = np.ascontiguousarray(self.normalize_vectors(query))

        k = min(top_k, len(self.articles_metadata))
        similarities, indices = self.index.search(query, k)

        results = []
        for similarity, idx in zip(similarities[0], indices[0]):
            # Skip invalid ids (FAISS reports missing neighbours as -1).
            if 0 <= idx < len(self.articles_metadata):
                article = self.articles_metadata[idx].copy()
                article['similarity_score'] = float(similarity)
                # Always include results (threshold removed for better recall)
                results.append(article)
        return results

    def get_article_by_id(self, article_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored metadata dict whose ``id`` equals *article_id*, or None."""
        return next(
            (article for article in self.articles_metadata if article.get('id') == article_id),
            None,
        )

    def get_all_articles(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the list of all article metadata dicts."""
        return self.articles_metadata.copy()

    def save_index(self):
        """Persist the FAISS index and the metadata pickle to disk.

        Each file is written to a ``.tmp`` sibling first and then moved into
        place with ``os.replace``. A crash mid-write therefore does not
        truncate the previously saved file. Errors are printed, not raised
        (best effort).
        """
        try:
            # Ensure directory exists. A bare filename has no directory part,
            # and os.makedirs('') would raise.
            directory = os.path.dirname(self.index_path)
            if directory:
                os.makedirs(directory, exist_ok=True)

            # Save FAISS index
            if self.index is not None:
                tmp_index_path = f"{self.index_path}.tmp"
                faiss.write_index(self.index, tmp_index_path)
                os.replace(tmp_index_path, self.index_path)

            # Save metadata
            tmp_metadata_path = f"{self.metadata_path}.tmp"
            with open(tmp_metadata_path, 'wb') as f:
                pickle.dump(self.articles_metadata, f)
            os.replace(tmp_metadata_path, self.metadata_path)

            print(f"Saved vector store to {self.index_path}")
        except Exception as e:
            print(f"Error saving vector store: {e}")

    def load_index(self):
        """Load the FAISS index and metadata from disk, if present.

        If the index's vector count differs from the number of metadata rows
        (e.g. one of the two files is missing), search hits would be attributed
        to the wrong articles. In that case the store is reset to empty, as it
        is on any load error.
        """
        try:
            # Load FAISS index
            if os.path.exists(self.index_path):
                self.index = faiss.read_index(self.index_path)
                print(f"Loaded FAISS index from {self.index_path}")

            # Load metadata.
            # NOTE(security): pickle.load can execute arbitrary code. This file
            # must only ever be produced by save_index, never by untrusted input.
            if os.path.exists(self.metadata_path):
                with open(self.metadata_path, 'rb') as f:
                    self.articles_metadata = pickle.load(f)
                print(f"Loaded {len(self.articles_metadata)} articles metadata")

            index_size = self.index.ntotal if self.index is not None else 0
            if index_size != len(self.articles_metadata):
                print(
                    f"Vector store is inconsistent: index has {index_size} vectors "
                    f"but metadata has {len(self.articles_metadata)} entries; starting empty"
                )
                self.index = None
                self.articles_metadata = []
        except Exception as e:
            print(f"Error loading vector store: {e}")
            # Create new index if loading fails
            self.index = None
            self.articles_metadata = []

    def clear_index(self):
        """Drop the in-memory index, metadata and cache, and delete the files."""
        self.index = None
        self.articles_metadata = []
        self._clear_cache()

        # Remove files
        for path in [self.index_path, self.metadata_path]:
            if os.path.exists(path):
                os.remove(path)

        print("Cleared vector store")

    def get_stats(self) -> Dict[str, Any]:
        """Return counts, dimension and last-update timestamp for the store.

        ``index_dimension`` is the live index's width when an index exists,
        otherwise the configured ``self.dimension``.
        """
        has_index = self.index is not None
        return {
            'total_articles': len(self.articles_metadata),
            'index_dimension': self.index.d if has_index else self.dimension,
            'index_exists': has_index,
            'index_size': self.index.ntotal if has_index else 0,
            # ISO-8601 strings from the same clock compare chronologically.
            'last_updated': max(
                (a.get('added_date', '') for a in self.articles_metadata), default=None
            ),
        }

    def _get_cache_key(self, operation: str, *args) -> str:
        """Return an MD5 hex digest of ``operation`` and the str() of each arg."""
        import hashlib
        key_data = f"{operation}:{':'.join(map(str, args))}"
        return hashlib.md5(key_data.encode()).hexdigest()

    def _get_from_cache(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, evicting it if older than the TTL."""
        if key in self._cache:
            cached_data, timestamp = self._cache[key]
            if time.time() - timestamp < self._cache_ttl:
                return cached_data
            del self._cache[key]
        return None

    def _set_cache(self, key: str, value: Any) -> None:
        """Store *value* under *key* together with the current time."""
        self._cache[key] = (value, time.time())

    def _clear_cache(self) -> None:
        """Clear all cache entries"""
        self._cache.clear()
|
|
|
|
|
|
2025-07-07 18:31:38 +01:00
|
|
|
# Manual smoke test: open (or create) the store and print its statistics.
if __name__ == "__main__":
    vector_store = VectorStore()
    print(f"Vector store stats: {vector_store.get_stats()}")
|