feat: Implement complete RSS news fetching system with multi-source support
This commit is contained in:
@@ -0,0 +1,173 @@
|
||||
"""Vector database operations using FAISS"""
|
||||
import os
|
||||
import json
|
||||
import pickle
|
||||
import numpy as np
|
||||
import faiss
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from config import settings
|
||||
|
||||
class VectorStore:
|
||||
def __init__(self):
|
||||
self.index_path = settings.vector_index_path
|
||||
self.metadata_path = self.index_path.replace('.faiss', '_metadata.pkl')
|
||||
self.dimension = settings.vector_dimension
|
||||
|
||||
# Initialize FAISS index
|
||||
self.index = None
|
||||
self.articles_metadata = []
|
||||
|
||||
# Load existing index if available
|
||||
self.load_index()
|
||||
|
||||
def create_index(self, dimension: int):
|
||||
"""Create a new FAISS index"""
|
||||
# Using IndexFlatIP for cosine similarity (Inner Product)
|
||||
# We'll normalize vectors before adding them
|
||||
self.index = faiss.IndexFlatIP(dimension)
|
||||
self.articles_metadata = []
|
||||
print(f"Created new FAISS index with dimension {dimension}")
|
||||
|
||||
def normalize_vectors(self, vectors: np.ndarray) -> np.ndarray:
|
||||
"""Normalize vectors for cosine similarity"""
|
||||
norms = np.linalg.norm(vectors, axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1 # Avoid division by zero
|
||||
return vectors / norms
|
||||
|
||||
def add_articles(self, articles: List[Dict[str, Any]], embeddings: np.ndarray):
|
||||
"""Add articles and their embeddings to the vector store"""
|
||||
if len(articles) != len(embeddings):
|
||||
raise ValueError("Number of articles must match number of embeddings")
|
||||
|
||||
# Create index if it doesn't exist
|
||||
if self.index is None:
|
||||
self.create_index(embeddings.shape[1])
|
||||
|
||||
# Normalize embeddings for cosine similarity
|
||||
normalized_embeddings = self.normalize_vectors(embeddings.astype(np.float32))
|
||||
|
||||
# Add to FAISS index
|
||||
self.index.add(normalized_embeddings)
|
||||
|
||||
# Store metadata
|
||||
for i, article in enumerate(articles):
|
||||
metadata = {
|
||||
'id': article.get('id'),
|
||||
'title': article.get('title'),
|
||||
'content': article.get('content', '')[:200], # Truncate for storage
|
||||
'url': article.get('url'),
|
||||
'source': article.get('source'),
|
||||
'published_date': article.get('published_date'),
|
||||
'added_date': datetime.now().isoformat(),
|
||||
'vector_index': len(self.articles_metadata) # Current index in FAISS
|
||||
}
|
||||
self.articles_metadata.append(metadata)
|
||||
|
||||
print(f"Added {len(articles)} articles to vector store")
|
||||
print(f"Total articles in store: {len(self.articles_metadata)}")
|
||||
|
||||
# Save the updated index
|
||||
self.save_index()
|
||||
|
||||
def search_similar(self, query_embedding: np.ndarray, top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
"""Search for similar articles"""
|
||||
if self.index is None or len(self.articles_metadata) == 0:
|
||||
return []
|
||||
|
||||
# Normalize query embedding
|
||||
query_embedding = self.normalize_vectors(query_embedding.reshape(1, -1))
|
||||
|
||||
# Search in FAISS
|
||||
similarities, indices = self.index.search(query_embedding, min(top_k, len(self.articles_metadata)))
|
||||
|
||||
results = []
|
||||
for similarity, idx in zip(similarities[0], indices[0]):
|
||||
if idx >= 0 and idx < len(self.articles_metadata): # Valid index
|
||||
article = self.articles_metadata[idx].copy()
|
||||
article['similarity_score'] = float(similarity)
|
||||
|
||||
# Only include if above threshold
|
||||
if similarity >= settings.similarity_threshold:
|
||||
results.append(article)
|
||||
|
||||
return results
|
||||
|
||||
def get_article_by_id(self, article_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get article metadata by ID"""
|
||||
for article in self.articles_metadata:
|
||||
if article.get('id') == article_id:
|
||||
return article
|
||||
return None
|
||||
|
||||
def get_all_articles(self) -> List[Dict[str, Any]]:
|
||||
"""Get all articles metadata"""
|
||||
return self.articles_metadata.copy()
|
||||
|
||||
def save_index(self):
|
||||
"""Save FAISS index and metadata to disk"""
|
||||
try:
|
||||
# Ensure directory exists
|
||||
os.makedirs(os.path.dirname(self.index_path), exist_ok=True)
|
||||
|
||||
# Save FAISS index
|
||||
if self.index is not None:
|
||||
faiss.write_index(self.index, self.index_path)
|
||||
|
||||
# Save metadata
|
||||
with open(self.metadata_path, 'wb') as f:
|
||||
pickle.dump(self.articles_metadata, f)
|
||||
|
||||
print(f"Saved vector store to {self.index_path}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error saving vector store: {e}")
|
||||
|
||||
def load_index(self):
|
||||
"""Load FAISS index and metadata from disk"""
|
||||
try:
|
||||
# Load FAISS index
|
||||
if os.path.exists(self.index_path):
|
||||
self.index = faiss.read_index(self.index_path)
|
||||
print(f"Loaded FAISS index from {self.index_path}")
|
||||
|
||||
# Load metadata
|
||||
if os.path.exists(self.metadata_path):
|
||||
with open(self.metadata_path, 'rb') as f:
|
||||
self.articles_metadata = pickle.load(f)
|
||||
print(f"Loaded {len(self.articles_metadata)} articles metadata")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading vector store: {e}")
|
||||
# Create new index if loading fails
|
||||
self.index = None
|
||||
self.articles_metadata = []
|
||||
|
||||
def clear_index(self):
|
||||
"""Clear the entire vector store"""
|
||||
self.index = None
|
||||
self.articles_metadata = []
|
||||
|
||||
# Remove files
|
||||
for path in [self.index_path, self.metadata_path]:
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
|
||||
print("Cleared vector store")
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get vector store statistics"""
|
||||
return {
|
||||
'total_articles': len(self.articles_metadata),
|
||||
'index_dimension': self.dimension,
|
||||
'index_exists': self.index is not None,
|
||||
'index_size': self.index.ntotal if self.index else 0,
|
||||
'last_updated': max([a.get('added_date', '') for a in self.articles_metadata]) if self.articles_metadata else None
|
||||
}
|
||||
|
||||
# Test function
|
||||
if __name__ == "__main__":
|
||||
# Test vector store
|
||||
store = VectorStore()
|
||||
stats = store.get_stats()
|
||||
print(f"Vector store stats: {stats}")
|
||||
Reference in New Issue
Block a user