Files
DS_TASK_AI_VIEWS/backend/vector_store.py
T

202 lines
7.4 KiB
Python
Raw Normal View History

"""Vector database operations using FAISS"""
import os
import json
import pickle
import time
import numpy as np
import faiss
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from config import settings
class VectorStore:
    """FAISS-backed store of article embeddings plus per-article metadata.

    Embeddings are L2-normalised before being added to an inner-product index
    (``faiss.IndexFlatIP``), so search scores are cosine similarities.
    Metadata lives in a Python list whose position ``i`` corresponds to vector
    ``i`` in the FAISS index; it is persisted as a pickle file next to the index.
    """

    def __init__(self):
        self.index_path = settings.vector_index_path
        self.metadata_path = self._derive_metadata_path(self.index_path)
        # Configured dimension. create_index() takes the real width from the
        # first batch of embeddings, so this is only a fallback for get_stats().
        self.dimension = settings.vector_dimension

        # FAISS index: created lazily on the first add, or loaded below.
        self.index = None
        self.articles_metadata = []

        # Simple in-memory cache for frequent queries: key -> (value, stored_at).
        self._cache = {}
        self._cache_ttl = 300  # 5 minutes

        # Load existing index if available
        self.load_index()

    @staticmethod
    def _derive_metadata_path(index_path: str) -> str:
        """Return the metadata pickle path that sits next to *index_path*.

        ``data/index.faiss`` -> ``data/index_metadata.pkl``. Only the file
        extension is replaced (not every ``.faiss`` substring in the path). A
        path without a ``.faiss`` extension gets the suffix appended, so the
        metadata file can never be the index file itself.
        """
        root, ext = os.path.splitext(index_path)
        if ext == '.faiss':
            return root + '_metadata.pkl'
        return index_path + '_metadata.pkl'

    def create_index(self, dimension: int):
        """Create a new, empty FAISS index and drop all existing metadata.

        Args:
            dimension: Width of the vectors the index will hold.
        """
        # IndexFlatIP scores by inner product; because vectors are normalised
        # before insertion, inner product equals cosine similarity.
        self.index = faiss.IndexFlatIP(dimension)
        self.articles_metadata = []
        self._clear_cache()
        print(f"Created new FAISS index with dimension {dimension}")

    def normalize_vectors(self, vectors: np.ndarray) -> np.ndarray:
        """Return a copy of the 2-D *vectors* with each row scaled to unit L2 norm.

        All-zero rows are left as zeros instead of dividing by zero.
        """
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        norms[norms == 0] = 1  # Avoid division by zero
        return vectors / norms

    @staticmethod
    def _build_metadata(article: Dict[str, Any], vector_index: int) -> Dict[str, Any]:
        """Build the stored metadata record for one article.

        Args:
            article: Article dict; every key read here is optional.
            vector_index: Position of the article's vector in the FAISS index.

        Returns:
            The record stored in ``articles_metadata``. ``content`` is truncated
            to 200 characters; a missing or ``None`` content becomes ``''``.
        """
        return {
            'id': article.get('id'),
            'title': article.get('title'),
            # `or ''` also covers an explicit None, which would break slicing.
            'content': (article.get('content') or '')[:200],  # Truncate for storage
            'url': article.get('url'),
            'source': article.get('source'),
            'published_date': article.get('published_date'),
            'added_date': datetime.now().isoformat(),
            'vector_index': vector_index,
        }

    def add_articles(self, articles: List[Dict[str, Any]], embeddings: np.ndarray):
        """Add articles and their embeddings to the store, then save to disk.

        Args:
            articles: Article dicts (id, title, content, url, source and
                published_date are read when present).
            embeddings: One row per article, i.e. shape (len(articles), dim).
                Any numeric array-like is accepted and converted to float32.

        Raises:
            ValueError: if the row count differs from the number of articles,
                the embeddings are not 2-D, or their width does not match the
                existing index.
        """
        if len(articles) != len(embeddings):
            raise ValueError("Number of articles must match number of embeddings")

        vectors = np.asarray(embeddings, dtype=np.float32)
        if vectors.ndim != 2:
            raise ValueError(
                f"Embeddings must be a 2-D array, got shape {vectors.shape}"
            )

        # Create index if it doesn't exist
        if self.index is None:
            self.create_index(vectors.shape[1])
        elif vectors.shape[1] != self.index.d:
            raise ValueError(
                f"Embedding dimension {vectors.shape[1]} does not match "
                f"index dimension {self.index.d}"
            )

        # Normalize embeddings for cosine similarity, then add to FAISS.
        self.index.add(self.normalize_vectors(vectors))

        # FAISS assigns sequential ids, so the i-th new vector sits at start + i.
        start = len(self.articles_metadata)
        self.articles_metadata.extend(
            self._build_metadata(article, start + offset)
            for offset, article in enumerate(articles)
        )
        # Any cached query results were computed against the old contents.
        self._clear_cache()

        print(f"Added {len(articles)} articles to vector store")
        print(f"Total articles in store: {len(self.articles_metadata)}")

        # Save the updated index
        self.save_index()

    def search_similar(self, query_embedding: np.ndarray, top_k: int = 5) -> List[Dict[str, Any]]:
        """Return up to *top_k* stored articles most similar to *query_embedding*.

        Args:
            query_embedding: A single query vector (flattened to one row).
            top_k: Maximum number of results; values <= 0 yield ``[]``.

        Returns:
            Copies of the matching metadata records, each with an added
            ``similarity_score`` (cosine similarity), in FAISS result order.
            No similarity threshold is applied.
        """
        if self.index is None or not self.articles_metadata or top_k <= 0:
            return []

        # Match the float32 + normalisation applied in add_articles.
        query = self.normalize_vectors(
            np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)
        )
        k = min(top_k, len(self.articles_metadata))
        similarities, indices = self.index.search(query, k)

        results = []
        for similarity, idx in zip(similarities[0], indices[0]):
            # FAISS pads missing neighbours with -1; also skip ids that have no
            # metadata record.
            if 0 <= idx < len(self.articles_metadata):
                article = self.articles_metadata[idx].copy()
                article['similarity_score'] = float(similarity)
                results.append(article)
        return results

    def get_article_by_id(self, article_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored metadata record whose ``id`` equals *article_id*, or None.

        Note: the record itself (not a copy) is returned.
        """
        for article in self.articles_metadata:
            if article.get('id') == article_id:
                return article
        return None

    def get_all_articles(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the list of all metadata records."""
        return self.articles_metadata.copy()

    def save_index(self):
        """Write the FAISS index (if any) and the metadata pickle to disk.

        Best effort: any error is printed rather than raised.
        """
        try:
            # os.makedirs('') raises, so only create a directory when the path
            # actually has one (a bare filename lives in the current directory).
            directory = os.path.dirname(self.index_path)
            if directory:
                os.makedirs(directory, exist_ok=True)

            # Save FAISS index
            if self.index is not None:
                faiss.write_index(self.index, self.index_path)

            # Save metadata
            with open(self.metadata_path, 'wb') as f:
                pickle.dump(self.articles_metadata, f)

            print(f"Saved vector store to {self.index_path}")
        except Exception as e:
            print(f"Error saving vector store: {e}")

    def load_index(self):
        """Load the FAISS index and metadata from disk when the files exist.

        On any error, both are reset to an empty store.
        """
        try:
            # Load FAISS index
            if os.path.exists(self.index_path):
                self.index = faiss.read_index(self.index_path)
                print(f"Loaded FAISS index from {self.index_path}")

            # Load metadata. NOTE: pickle.load can execute arbitrary code; only
            # load metadata files written by this application.
            if os.path.exists(self.metadata_path):
                with open(self.metadata_path, 'rb') as f:
                    self.articles_metadata = pickle.load(f)
                print(f"Loaded {len(self.articles_metadata)} articles metadata")

            # Search maps FAISS ids to list positions, so the counts must agree.
            if self.index is not None and self.index.ntotal != len(self.articles_metadata):
                print(
                    f"Warning: FAISS index holds {self.index.ntotal} vectors but "
                    f"{len(self.articles_metadata)} metadata records were loaded; "
                    "index and metadata are out of sync"
                )
        except Exception as e:
            print(f"Error loading vector store: {e}")
            # Create new index if loading fails
            self.index = None
            self.articles_metadata = []
        self._clear_cache()

    def clear_index(self):
        """Empty the store in memory and delete its files from disk."""
        self.index = None
        self.articles_metadata = []
        self._clear_cache()

        # Remove files
        for path in [self.index_path, self.metadata_path]:
            if os.path.exists(path):
                os.remove(path)

        print("Cleared vector store")

    def get_stats(self) -> Dict[str, Any]:
        """Return counts, dimension and last-update time of the store."""
        has_index = self.index is not None
        return {
            'total_articles': len(self.articles_metadata),
            # Prefer the live index's width: it comes from the embeddings and
            # may differ from the configured dimension.
            'index_dimension': self.index.d if has_index else self.dimension,
            'index_exists': has_index,
            'index_size': self.index.ntotal if has_index else 0,
            # ISO-8601 strings from the same clock compare chronologically.
            'last_updated': max(
                (a.get('added_date', '') for a in self.articles_metadata),
                default=None,
            ),
        }

    def _get_cache_key(self, operation: str, *args) -> str:
        """Build a cache key from *operation* and the ``str()`` of each argument."""
        import hashlib
        key_data = f"{operation}:{':'.join(map(str, args))}"
        # md5 is used only to shorten the key, not for security.
        return hashlib.md5(key_data.encode()).hexdigest()

    def _get_from_cache(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None if missing or older than the TTL.

        Expired entries are removed when they are looked up.
        """
        entry = self._cache.get(key)
        if entry is None:
            return None
        cached_data, timestamp = entry
        if time.time() - timestamp < self._cache_ttl:
            return cached_data
        del self._cache[key]
        return None

    def _set_cache(self, key: str, value: Any) -> None:
        """Store *value* under *key*, stamped with the current time."""
        self._cache[key] = (value, time.time())

    def _clear_cache(self) -> None:
        """Drop every cache entry."""
        self._cache.clear()
def _print_store_stats() -> None:
    """Smoke test: open the configured vector store and print its statistics."""
    vector_store = VectorStore()
    print(f"Vector store stats: {vector_store.get_stats()}")


if __name__ == "__main__":
    _print_store_stats()