174 lines
6.5 KiB
Python
174 lines
6.5 KiB
Python
|
|
"""Vector database operations using FAISS"""
|
||
|
|
import os
|
||
|
|
import json
|
||
|
|
import pickle
|
||
|
|
import numpy as np
|
||
|
|
import faiss
|
||
|
|
from typing import List, Dict, Any, Optional, Tuple
|
||
|
|
from datetime import datetime
|
||
|
|
from config import settings
|
||
|
|
|
||
|
|
class VectorStore:
    """FAISS-backed store of article embeddings plus per-article metadata.

    Vectors are L2-normalised before insertion into an ``IndexFlatIP`` so the
    inner product returned by FAISS equals cosine similarity. Metadata lives in
    a Python list whose positions match the FAISS row ids (rows and metadata
    entries are appended together in ``add_articles``). Both are persisted to
    disk: the index at ``index_path`` and the metadata pickle at
    ``metadata_path``.
    """

    def __init__(self):
        """Read paths/dimension from ``settings`` and load any saved store."""
        # Normalise to str so Path-like settings values work with string ops
        # and faiss.read_index/write_index.
        self.index_path = os.fspath(settings.vector_index_path)
        self.metadata_path = self._derive_metadata_path(self.index_path)
        # Configured dimension. The live index is created from the width of
        # the first embeddings batch, so it may differ (see get_stats).
        self.dimension = settings.vector_dimension

        # Populated by load_index(), or lazily by add_articles().
        self.index = None
        self.articles_metadata: List[Dict[str, Any]] = []

        self.load_index()

    @staticmethod
    def _derive_metadata_path(index_path) -> str:
        """Return the metadata pickle path stored alongside *index_path*.

        ``data/index.faiss`` -> ``data/index_metadata.pkl``. Only a trailing
        ``.faiss`` is stripped, and ``_metadata.pkl`` is always appended, so
        the metadata file can never share a name with the index file.
        """
        path = os.fspath(index_path)
        suffix = '.faiss'
        if path.endswith(suffix):
            path = path[:-len(suffix)]
        return path + '_metadata.pkl'

    def create_index(self, dimension: int):
        """Create a new, empty FAISS index and reset the metadata list.

        Args:
            dimension: width of the vectors the index will hold.
        """
        # IndexFlatIP = exact inner-product search; with unit-length vectors
        # (see normalize_vectors) the inner product is cosine similarity.
        self.index = faiss.IndexFlatIP(dimension)
        self.articles_metadata = []
        print(f"Created new FAISS index with dimension {dimension}")

    def normalize_vectors(self, vectors: np.ndarray) -> np.ndarray:
        """Return a copy of the 2-D array *vectors* with each row scaled to unit L2 norm.

        All-zero rows are left as zeros instead of producing NaNs.
        """
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        norms[norms == 0] = 1  # Avoid division by zero
        return vectors / norms

    def add_articles(self, articles: List[Dict[str, Any]], embeddings: np.ndarray):
        """Append *articles* and their *embeddings* to the store, then save it.

        Args:
            articles: one dict per article. The keys ``id``, ``title``,
                ``content``, ``url``, ``source`` and ``published_date`` are
                read if present; ``content`` is truncated to 200 characters.
            embeddings: array of shape ``(len(articles), dim)``.

        Raises:
            ValueError: if the counts differ, *embeddings* is not 2-D, or its
                width does not match the existing index.
        """
        if len(articles) != len(embeddings):
            raise ValueError("Number of articles must match number of embeddings")
        if not articles:
            return  # Nothing to add; leave the index and files untouched.

        vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
        if vectors.ndim != 2:
            raise ValueError(f"Embeddings must be a 2-D array, got shape {vectors.shape}")

        if self.index is None:
            self.create_index(vectors.shape[1])
        elif vectors.shape[1] != self.index.d:
            raise ValueError(
                f"Embedding dimension {vectors.shape[1]} does not match "
                f"index dimension {self.index.d}"
            )

        # Normalise so inner-product search yields cosine similarity.
        self.index.add(self.normalize_vectors(vectors))

        # New FAISS rows are appended after the existing ones, so metadata
        # entry i corresponds to FAISS row first_row + i.
        first_row = len(self.articles_metadata)
        added_at = datetime.now().isoformat()
        for offset, article in enumerate(articles):
            self.articles_metadata.append({
                'id': article.get('id'),
                'title': article.get('title'),
                # `or ''` also covers an explicit None value.
                'content': (article.get('content') or '')[:200],  # Truncate for storage
                'url': article.get('url'),
                'source': article.get('source'),
                'published_date': article.get('published_date'),
                'added_date': added_at,
                'vector_index': first_row + offset,  # Row id in FAISS
            })

        print(f"Added {len(articles)} articles to vector store")
        print(f"Total articles in store: {len(self.articles_metadata)}")

        self.save_index()

    def search_similar(self, query_embedding: np.ndarray, top_k: int = 5) -> List[Dict[str, Any]]:
        """Return up to *top_k* stored articles most similar to *query_embedding*.

        Each result is a copy of the article's metadata with an added
        ``similarity_score`` (cosine similarity). Results scoring below
        ``settings.similarity_threshold`` are dropped. The order is the one
        FAISS returns.
        """
        if self.index is None or not self.articles_metadata or top_k <= 0:
            return []

        query = self.normalize_vectors(
            np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)
        )
        k = min(top_k, len(self.articles_metadata))
        similarities, indices = self.index.search(query, k)

        threshold = settings.similarity_threshold
        metadata_count = len(self.articles_metadata)
        results = []
        for similarity, idx in zip(similarities[0], indices[0]):
            # FAISS reports missing results as -1. Also skip rows that have
            # no metadata entry (index and metadata out of sync).
            if not 0 <= idx < metadata_count:
                continue
            if similarity >= threshold:
                article = self.articles_metadata[idx].copy()
                article['similarity_score'] = float(similarity)
                results.append(article)

        return results

    def get_article_by_id(self, article_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored metadata dict whose ``id`` equals *article_id*, else None.

        The dict itself (not a copy) is returned, so mutating it changes the store.
        """
        return next(
            (article for article in self.articles_metadata if article.get('id') == article_id),
            None,
        )

    def get_all_articles(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the metadata list."""
        return self.articles_metadata.copy()

    def save_index(self):
        """Persist the FAISS index and metadata to disk (best-effort).

        Each file is written to a ``.tmp`` sibling and then moved into place
        with ``os.replace``, so an interrupted write does not leave a truncated
        file under the real name. Errors are reported, not raised.
        """
        try:
            # dirname() is '' for a bare filename; makedirs('') would raise.
            directory = os.path.dirname(self.index_path)
            if directory:
                os.makedirs(directory, exist_ok=True)

            if self.index is not None:
                tmp_index_path = self.index_path + '.tmp'
                faiss.write_index(self.index, tmp_index_path)
                os.replace(tmp_index_path, self.index_path)

            tmp_metadata_path = self.metadata_path + '.tmp'
            with open(tmp_metadata_path, 'wb') as f:
                pickle.dump(self.articles_metadata, f)
            os.replace(tmp_metadata_path, self.metadata_path)

            print(f"Saved vector store to {self.index_path}")

        except Exception as e:
            print(f"Error saving vector store: {e}")

    def load_index(self):
        """Load the FAISS index and metadata from disk if the files exist.

        Best-effort: on any error the store is reset to empty, so the next
        add_articles() call creates a fresh index.
        """
        try:
            if os.path.exists(self.index_path):
                self.index = faiss.read_index(self.index_path)
                print(f"Loaded FAISS index from {self.index_path}")

            if os.path.exists(self.metadata_path):
                # NOTE(review): pickle.load can execute arbitrary code; only
                # load metadata files this store wrote itself.
                with open(self.metadata_path, 'rb') as f:
                    self.articles_metadata = pickle.load(f)
                print(f"Loaded {len(self.articles_metadata)} articles metadata")

        except Exception as e:
            print(f"Error loading vector store: {e}")
            # Create new index if loading fails
            self.index = None
            self.articles_metadata = []
            return

        if self.index is not None and self.index.ntotal != len(self.articles_metadata):
            print(
                f"Warning: FAISS index holds {self.index.ntotal} vectors but metadata "
                f"has {len(self.articles_metadata)} entries; search results may be incomplete"
            )

    def clear_index(self):
        """Drop the in-memory store and delete its files from disk."""
        self.index = None
        self.articles_metadata = []

        for path in (self.index_path, self.metadata_path):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

        print("Cleared vector store")

    def get_stats(self) -> Dict[str, Any]:
        """Return counts, the effective dimension, and the latest ``added_date``."""
        has_index = self.index is not None
        return {
            'total_articles': len(self.articles_metadata),
            # Report the live index's width; fall back to the configured value.
            'index_dimension': self.index.d if has_index else self.dimension,
            'index_exists': has_index,
            'index_size': self.index.ntotal if has_index else 0,
            # ISO-8601 strings from the same clock sort chronologically.
            'last_updated': max(
                (a.get('added_date', '') for a in self.articles_metadata),
                default=None,
            ),
        }
|
||
|
|
|
||
|
|
# Manual smoke check: build a store from the configured settings and print its stats.
if __name__ == "__main__":
    vector_store = VectorStore()
    print(f"Vector store stats: {vector_store.get_stats()}")
|