feat: Implement complete RSS news fetching system with multi-source support

This commit is contained in:
Aherobo Ovie Victor
2025-07-07 18:31:38 +01:00
parent c158262a49
commit e188af8b17
22 changed files with 2210 additions and 0 deletions
+46
View File
@@ -0,0 +1,46 @@
"""Configuration settings for DS Task AI News"""
import os
from typing import List
from pydantic_settings import BaseSettings
from dotenv import load_dotenv
load_dotenv()
class Settings(BaseSettings):
# API Keys
cohere_api_key: str = os.getenv("COHERE_API_KEY", "")
groq_api_key: str = os.getenv("GROQ_API_KEY", "")
# Vector Database
vector_db_type: str = os.getenv("VECTOR_DB_TYPE", "faiss")
vector_dimension: int = int(os.getenv("VECTOR_DIMENSION", "384"))
# RSS Feeds
@property
def rss_feeds(self) -> List[str]:
feeds_str = os.getenv(
"RSS_FEEDS",
"https://feeds.bbci.co.uk/news/technology/rss.xml,"
"https://techcrunch.com/feed/,"
"https://www.wired.com/feed/rss"
)
return [feed.strip() for feed in feeds_str.split(",") if feed.strip()]
# Server Settings
host: str = os.getenv("HOST", "0.0.0.0")
port: int = int(os.getenv("PORT", "8000"))
debug: bool = os.getenv("DEBUG", "true").lower() == "true"
# Data Storage
raw_news_dir: str = os.getenv("RAW_NEWS_DIR", "data/raw_news")
processed_news_dir: str = os.getenv("PROCESSED_NEWS_DIR", "data/processed_news")
vector_index_path: str = os.getenv("VECTOR_INDEX_PATH", "data/news_vectors.faiss")
# Embedding Model
embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
# News Processing
max_articles_per_feed: int = 50
similarity_threshold: float = 0.7
settings = Settings()
+156
View File
@@ -0,0 +1,156 @@
"""Embeddings generation for DS Task AI News"""
import os
import numpy as np
from typing import List, Dict, Any, Optional
from sentence_transformers import SentenceTransformer
import cohere
from config import settings
class EmbeddingGenerator:
def __init__(self):
self.cohere_client = None
self.sentence_model = None
self.use_cohere = bool(settings.cohere_api_key)
# Initialize embedding model
if self.use_cohere:
try:
self.cohere_client = cohere.Client(settings.cohere_api_key)
print("Using Cohere for embeddings")
except Exception as e:
print(f"Cohere initialization failed: {e}")
self.use_cohere = False
if not self.use_cohere:
print("Using Sentence Transformers for embeddings")
self.sentence_model = SentenceTransformer(settings.embedding_model)
def create_article_text(self, article: Dict[str, Any]) -> str:
"""Combine article fields into text for embedding"""
title = article.get('title', '')
content = article.get('content', '')
source = article.get('source', '')
# Combine with weights (title is more important)
text = f"{title}. {content}"
if source:
text += f" Source: {source}"
return text.strip()
def generate_embeddings_cohere(self, texts: List[str]) -> np.ndarray:
"""Generate embeddings using Cohere"""
try:
response = self.cohere_client.embed(
texts=texts,
model='embed-english-v3.0',
input_type='search_document'
)
return np.array(response.embeddings)
except Exception as e:
print(f"Cohere embedding error: {e}")
raise
def generate_embeddings_sentence_transformer(self, texts: List[str]) -> np.ndarray:
"""Generate embeddings using Sentence Transformers"""
try:
embeddings = self.sentence_model.encode(texts, convert_to_numpy=True)
return embeddings
except Exception as e:
print(f"Sentence Transformer embedding error: {e}")
raise
def generate_embeddings(self, articles: List[Dict[str, Any]]) -> np.ndarray:
"""Generate embeddings for articles"""
if not articles:
return np.array([])
# Create texts for embedding
texts = [self.create_article_text(article) for article in articles]
print(f"Generating embeddings for {len(texts)} articles...")
# Generate embeddings
if self.use_cohere:
embeddings = self.generate_embeddings_cohere(texts)
else:
embeddings = self.generate_embeddings_sentence_transformer(texts)
print(f"Generated embeddings shape: {embeddings.shape}")
return embeddings
def generate_query_embedding(self, query: str) -> np.ndarray:
"""Generate embedding for a search query"""
if self.use_cohere:
try:
response = self.cohere_client.embed(
texts=[query],
model='embed-english-v3.0',
input_type='search_query'
)
return np.array(response.embeddings[0])
except Exception as e:
print(f"Cohere query embedding error: {e}")
# Fallback to sentence transformer
return self.sentence_model.encode([query], convert_to_numpy=True)[0]
else:
return self.sentence_model.encode([query], convert_to_numpy=True)[0]
def compute_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
"""Compute cosine similarity between two embeddings"""
# Normalize embeddings
norm1 = np.linalg.norm(embedding1)
norm2 = np.linalg.norm(embedding2)
if norm1 == 0 or norm2 == 0:
return 0.0
# Cosine similarity
similarity = np.dot(embedding1, embedding2) / (norm1 * norm2)
return float(similarity)
def find_similar_articles(self, query_embedding: np.ndarray,
article_embeddings: np.ndarray,
articles: List[Dict[str, Any]],
top_k: int = 5) -> List[Dict[str, Any]]:
"""Find most similar articles to query"""
if len(article_embeddings) == 0:
return []
similarities = []
for i, article_embedding in enumerate(article_embeddings):
similarity = self.compute_similarity(query_embedding, article_embedding)
similarities.append((similarity, i))
# Sort by similarity (descending)
similarities.sort(reverse=True)
# Get top-k results
results = []
for similarity, idx in similarities[:top_k]:
if similarity >= settings.similarity_threshold:
article = articles[idx].copy()
article['similarity_score'] = similarity
results.append(article)
return results
# Test function
if __name__ == "__main__":
# Test with sample articles
sample_articles = [
{
"title": "AI Revolution in Healthcare",
"content": "Artificial intelligence is transforming medical diagnosis and treatment.",
"source": "TechNews"
},
{
"title": "Climate Change Solutions",
"content": "New technologies are being developed to combat global warming.",
"source": "ScienceDaily"
}
]
generator = EmbeddingGenerator()
embeddings = generator.generate_embeddings(sample_articles)
print(f"Test embeddings shape: {embeddings.shape}")
+234
View File
@@ -0,0 +1,234 @@
"""FastAPI backend for DS Task AI News"""
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
import uvicorn
from config import settings
from news_fetcher import NewsFetcher
from recommender import NewsRecommender
# Initialize FastAPI app
app = FastAPI(
title="DS Task AI News API",
description="AI-powered news retrieval and recommendation system",
version="1.0.0"
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # In production, specify actual origins
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Initialize components
news_fetcher = NewsFetcher()
recommender = NewsRecommender()
# Pydantic models
class NewsQuery(BaseModel):
query: str
top_k: int = 5
class InterestsQuery(BaseModel):
interests: List[str]
top_k: int = 10
class SearchQuery(BaseModel):
query: str
source: Optional[str] = None
top_k: int = 10
# API Endpoints
@app.get("/")
async def root():
"""Health check endpoint"""
return {
"message": "DS Task AI News API is running!",
"version": "1.0.0",
"status": "healthy"
}
@app.get("/health")
async def health_check():
"""Detailed health check"""
stats = recommender.get_store_stats()
return {
"status": "healthy",
"vector_store": stats,
"settings": {
"embedding_model": settings.embedding_model,
"vector_db_type": settings.vector_db_type,
"rss_feeds_count": len(settings.rss_feeds)
}
}
@app.post("/fetch-news")
async def fetch_news():
"""Fetch news from RSS feeds and add to vector store"""
try:
# Fetch news articles
result = news_fetcher.fetch_and_save_news()
if not result["success"]:
raise HTTPException(status_code=500, detail=result.get("message", "Failed to fetch news"))
# Add articles to vector store
articles = result["articles"]
store_result = recommender.add_articles_to_store(articles)
if not store_result["success"]:
raise HTTPException(status_code=500, detail=store_result.get("message", "Failed to add articles to store"))
return {
"success": True,
"message": "News fetched and processed successfully",
"articles_fetched": result["articles_count"],
"articles_stored": store_result["articles_added"],
"total_articles": store_result["total_articles"]
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error fetching news: {str(e)}")
@app.get("/recommend-news")
async def recommend_news(
article_id: str = Query(..., description="ID of the article to find similar articles for"),
top_k: int = Query(5, description="Number of recommendations to return")
):
"""Get news recommendations based on article ID"""
try:
recommendations = recommender.recommend_by_article_id(article_id, top_k)
return {
"success": True,
"article_id": article_id,
"recommendations": recommendations,
"count": len(recommendations)
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting recommendations: {str(e)}")
@app.post("/recommend-by-query")
async def recommend_by_query(query_data: NewsQuery):
"""Get news recommendations based on text query"""
try:
recommendations = recommender.recommend_by_query(query_data.query, query_data.top_k)
return {
"success": True,
"query": query_data.query,
"recommendations": recommendations,
"count": len(recommendations)
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting recommendations: {str(e)}")
@app.post("/recommend-by-interests")
async def recommend_by_interests(interests_data: InterestsQuery):
"""Get news recommendations based on user interests"""
try:
recommendations = recommender.recommend_by_interests(interests_data.interests, interests_data.top_k)
return {
"success": True,
"interests": interests_data.interests,
"recommendations": recommendations,
"count": len(recommendations)
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting recommendations: {str(e)}")
@app.get("/trending")
async def get_trending_news(top_k: int = Query(10, description="Number of trending articles to return")):
"""Get trending news articles"""
try:
trending = recommender.get_trending_articles(top_k)
return {
"success": True,
"trending_articles": trending,
"count": len(trending)
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting trending news: {str(e)}")
@app.get("/articles")
async def get_all_articles(
source: Optional[str] = Query(None, description="Filter by news source"),
limit: int = Query(50, description="Maximum number of articles to return")
):
"""Get all articles with optional filtering"""
try:
if source:
articles = recommender.get_articles_by_source(source, limit)
else:
all_articles = recommender.vector_store.get_all_articles()
articles = sorted(all_articles, key=lambda x: x.get('published_date', ''), reverse=True)[:limit]
return {
"success": True,
"articles": articles,
"count": len(articles),
"source_filter": source
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting articles: {str(e)}")
@app.post("/search")
async def search_articles(search_data: SearchQuery):
"""Advanced search with filters"""
try:
filters = {}
if search_data.source:
filters['source'] = search_data.source
results = recommender.search_articles(search_data.query, filters, search_data.top_k)
return {
"success": True,
"query": search_data.query,
"filters": filters,
"results": results,
"count": len(results)
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error searching articles: {str(e)}")
@app.get("/stats")
async def get_stats():
"""Get system statistics"""
try:
stats = recommender.get_store_stats()
# Add RSS feed information
stats['rss_feeds'] = settings.rss_feeds
stats['embedding_model'] = settings.embedding_model
return {
"success": True,
"statistics": stats
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting stats: {str(e)}")
# Run the application
if __name__ == "__main__":
uvicorn.run(
"main:app",
host=settings.host,
port=settings.port,
reload=settings.debug
)
+147
View File
@@ -0,0 +1,147 @@
"""RSS News Fetcher for DS Task AI News"""
import feedparser
import requests
import json
import os
from datetime import datetime
from typing import List, Dict, Any
from urllib.parse import urlparse
import hashlib
from config import settings
class NewsFetcher:
def __init__(self):
self.raw_news_dir = settings.raw_news_dir
self.max_articles = settings.max_articles_per_feed
# Ensure directories exist
os.makedirs(self.raw_news_dir, exist_ok=True)
def generate_article_id(self, title: str, url: str) -> str:
"""Generate unique ID for article"""
content = f"{title}{url}"
return hashlib.md5(content.encode()).hexdigest()[:12]
def clean_content(self, content: str) -> str:
"""Clean and truncate content"""
if not content:
return ""
# Remove HTML tags (basic cleaning)
import re
content = re.sub(r'<[^>]+>', '', content)
# Truncate to reasonable length
return content[:1000] if len(content) > 1000 else content
def fetch_rss_feed(self, feed_url: str) -> List[Dict[str, Any]]:
"""Fetch articles from a single RSS feed"""
try:
print(f"Fetching from: {feed_url}")
feed = feedparser.parse(feed_url)
if feed.bozo:
print(f"Warning: Feed parsing issues for {feed_url}")
articles = []
source_name = getattr(feed.feed, 'title', urlparse(feed_url).netloc)
for entry in feed.entries[:self.max_articles]:
try:
# Extract article data
title = getattr(entry, 'title', 'No Title')
content = getattr(entry, 'summary', getattr(entry, 'description', ''))
url = getattr(entry, 'link', '')
published = getattr(entry, 'published', '')
# Parse date
try:
if published:
pub_date = datetime(*entry.published_parsed[:6])
else:
pub_date = datetime.now()
except:
pub_date = datetime.now()
# Create article object
article = {
"id": self.generate_article_id(title, url),
"title": title,
"content": self.clean_content(content),
"url": url,
"source": source_name,
"published_date": pub_date.isoformat(),
"fetched_date": datetime.now().isoformat(),
"categories": getattr(entry, 'tags', []),
"slug": title.lower().replace(" ", "-").replace("'", "")[:50]
}
articles.append(article)
except Exception as e:
print(f"Error processing entry: {e}")
continue
print(f"Fetched {len(articles)} articles from {source_name}")
return articles
except Exception as e:
print(f"Error fetching RSS feed {feed_url}: {e}")
return []
def fetch_all_news(self) -> List[Dict[str, Any]]:
"""Fetch news from all configured RSS feeds"""
all_articles = []
for feed_url in settings.rss_feeds:
feed_url = feed_url.strip()
if feed_url:
articles = self.fetch_rss_feed(feed_url)
all_articles.extend(articles)
# Remove duplicates based on ID
unique_articles = {}
for article in all_articles:
unique_articles[article['id']] = article
final_articles = list(unique_articles.values())
print(f"Total unique articles fetched: {len(final_articles)}")
return final_articles
def save_articles(self, articles: List[Dict[str, Any]]) -> str:
"""Save articles to JSON file"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"news_{timestamp}.json"
filepath = os.path.join(self.raw_news_dir, filename)
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(articles, f, indent=2, ensure_ascii=False)
print(f"Saved {len(articles)} articles to {filepath}")
return filepath
def fetch_and_save_news(self) -> Dict[str, Any]:
"""Fetch news and save to file"""
articles = self.fetch_all_news()
if articles:
filepath = self.save_articles(articles)
return {
"success": True,
"articles_count": len(articles),
"filepath": filepath,
"articles": articles
}
else:
return {
"success": False,
"articles_count": 0,
"message": "No articles fetched"
}
# Test function
if __name__ == "__main__":
fetcher = NewsFetcher()
result = fetcher.fetch_and_save_news()
print(f"Result: {result}")
+151
View File
@@ -0,0 +1,151 @@
"""News recommendation system"""
from typing import List, Dict, Any, Optional
import numpy as np
from embeddings import EmbeddingGenerator
from vector_store import VectorStore
from config import settings
class NewsRecommender:
def __init__(self):
self.embedding_generator = EmbeddingGenerator()
self.vector_store = VectorStore()
def recommend_by_article_id(self, article_id: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""Recommend articles similar to a given article ID"""
# Get the article
article = self.vector_store.get_article_by_id(article_id)
if not article:
return []
# Create text from article for embedding
article_text = self.embedding_generator.create_article_text(article)
# Generate embedding for the article
query_embedding = self.embedding_generator.generate_query_embedding(article_text)
# Search for similar articles
similar_articles = self.vector_store.search_similar(query_embedding, top_k + 1) # +1 to exclude self
# Remove the original article from results
filtered_results = [a for a in similar_articles if a.get('id') != article_id]
return filtered_results[:top_k]
def recommend_by_query(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""Recommend articles based on a text query"""
if not query.strip():
return []
# Generate embedding for query
query_embedding = self.embedding_generator.generate_query_embedding(query)
# Search for similar articles
similar_articles = self.vector_store.search_similar(query_embedding, top_k)
return similar_articles
def recommend_by_interests(self, interests: List[str], top_k: int = 10) -> List[Dict[str, Any]]:
"""Recommend articles based on user interests"""
if not interests:
return []
# Combine interests into a query
query = " ".join(interests)
return self.recommend_by_query(query, top_k)
def get_trending_articles(self, top_k: int = 10) -> List[Dict[str, Any]]:
"""Get trending articles (most recent for now)"""
all_articles = self.vector_store.get_all_articles()
# Sort by published date (most recent first)
sorted_articles = sorted(
all_articles,
key=lambda x: x.get('published_date', ''),
reverse=True
)
return sorted_articles[:top_k]
def get_articles_by_source(self, source: str, top_k: int = 10) -> List[Dict[str, Any]]:
"""Get articles from a specific source"""
all_articles = self.vector_store.get_all_articles()
# Filter by source
source_articles = [
article for article in all_articles
if article.get('source', '').lower() == source.lower()
]
# Sort by published date
sorted_articles = sorted(
source_articles,
key=lambda x: x.get('published_date', ''),
reverse=True
)
return sorted_articles[:top_k]
def add_articles_to_store(self, articles: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Add new articles to the vector store"""
if not articles:
return {"success": False, "message": "No articles provided"}
try:
# Generate embeddings
embeddings = self.embedding_generator.generate_embeddings(articles)
# Add to vector store
self.vector_store.add_articles(articles, embeddings)
return {
"success": True,
"articles_added": len(articles),
"total_articles": len(self.vector_store.get_all_articles())
}
except Exception as e:
return {
"success": False,
"message": f"Error adding articles: {str(e)}"
}
def get_store_stats(self) -> Dict[str, Any]:
"""Get vector store statistics"""
return self.vector_store.get_stats()
def search_articles(self, query: str, filters: Optional[Dict[str, Any]] = None,
top_k: int = 10) -> List[Dict[str, Any]]:
"""Advanced search with filters"""
# Get basic recommendations
results = self.recommend_by_query(query, top_k * 2) # Get more to allow filtering
# Apply filters if provided
if filters:
filtered_results = []
for article in results:
include = True
# Source filter
if 'source' in filters:
if article.get('source', '').lower() != filters['source'].lower():
include = False
# Date range filter (simplified)
if 'date_from' in filters or 'date_to' in filters:
# This would need proper date parsing in a real implementation
pass
if include:
filtered_results.append(article)
results = filtered_results
return results[:top_k]
# Test function
if __name__ == "__main__":
recommender = NewsRecommender()
stats = recommender.get_store_stats()
print(f"Recommender stats: {stats}")
+80
View File
@@ -0,0 +1,80 @@
# FastAPI and server
fastapi==0.116.0
uvicorn==0.35.0
starlette==0.46.2
# RSS and web scraping
feedparser==6.0.11
requests==2.32.4
beautifulsoup4==4.13.4
# AI and ML - Core
cohere==5.15.0
sentence-transformers==5.0.0
faiss-cpu==1.11.0
numpy==2.2.6
# AI and ML - Supporting
torch==2.7.1
transformers==4.53.1
scikit-learn==1.7.0
huggingface-hub==0.33.2
tokenizers==0.21.2
safetensors==0.5.3
# Data processing
pandas==2.3.0
python-dateutil==2.9.0.post0
scipy==1.15.3
# Environment and config
python-dotenv==1.1.1
pydantic==2.11.7
pydantic-settings==2.10.1
pydantic-core==2.33.2
# LLM Integration
groq==0.29.0
# Utilities
tqdm==4.67.1
click==8.2.1
typing-extensions==4.14.1
packaging==25.0
filelock==3.18.0
fsspec==2025.5.1
PyYAML==6.0.2
regex==2024.11.6
pillow==11.3.0
jinja2==3.1.6
markupsafe==3.0.2
certifi==2025.6.15
urllib3==2.5.0
charset-normalizer==3.4.2
idna==3.10
# HTTP and networking
httpx==0.28.1
httpcore==1.0.9
httpx-sse==0.4.0
anyio==4.9.0
sniffio==1.3.1
h11==0.16.0
# Additional utilities
joblib==1.5.1
threadpoolctl==3.6.0
sympy==1.14.0
mpmath==1.3.0
networkx==3.4.2
six==1.17.0
pytz==2025.2
tzdata==2025.2
colorama==0.4.6
distro==1.9.0
fastavro==1.11.1
soupsieve==2.7
types-requests==2.32.4.20250611
annotated-types==0.7.0
typing-inspection==0.4.1
exceptiongroup==1.3.0
Binary file not shown.
+173
View File
@@ -0,0 +1,173 @@
"""Vector database operations using FAISS"""
import os
import json
import pickle
import numpy as np
import faiss
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from config import settings
class VectorStore:
def __init__(self):
self.index_path = settings.vector_index_path
self.metadata_path = self.index_path.replace('.faiss', '_metadata.pkl')
self.dimension = settings.vector_dimension
# Initialize FAISS index
self.index = None
self.articles_metadata = []
# Load existing index if available
self.load_index()
def create_index(self, dimension: int):
"""Create a new FAISS index"""
# Using IndexFlatIP for cosine similarity (Inner Product)
# We'll normalize vectors before adding them
self.index = faiss.IndexFlatIP(dimension)
self.articles_metadata = []
print(f"Created new FAISS index with dimension {dimension}")
def normalize_vectors(self, vectors: np.ndarray) -> np.ndarray:
"""Normalize vectors for cosine similarity"""
norms = np.linalg.norm(vectors, axis=1, keepdims=True)
norms[norms == 0] = 1 # Avoid division by zero
return vectors / norms
def add_articles(self, articles: List[Dict[str, Any]], embeddings: np.ndarray):
"""Add articles and their embeddings to the vector store"""
if len(articles) != len(embeddings):
raise ValueError("Number of articles must match number of embeddings")
# Create index if it doesn't exist
if self.index is None:
self.create_index(embeddings.shape[1])
# Normalize embeddings for cosine similarity
normalized_embeddings = self.normalize_vectors(embeddings.astype(np.float32))
# Add to FAISS index
self.index.add(normalized_embeddings)
# Store metadata
for i, article in enumerate(articles):
metadata = {
'id': article.get('id'),
'title': article.get('title'),
'content': article.get('content', '')[:200], # Truncate for storage
'url': article.get('url'),
'source': article.get('source'),
'published_date': article.get('published_date'),
'added_date': datetime.now().isoformat(),
'vector_index': len(self.articles_metadata) # Current index in FAISS
}
self.articles_metadata.append(metadata)
print(f"Added {len(articles)} articles to vector store")
print(f"Total articles in store: {len(self.articles_metadata)}")
# Save the updated index
self.save_index()
def search_similar(self, query_embedding: np.ndarray, top_k: int = 5) -> List[Dict[str, Any]]:
"""Search for similar articles"""
if self.index is None or len(self.articles_metadata) == 0:
return []
# Normalize query embedding
query_embedding = self.normalize_vectors(query_embedding.reshape(1, -1))
# Search in FAISS
similarities, indices = self.index.search(query_embedding, min(top_k, len(self.articles_metadata)))
results = []
for similarity, idx in zip(similarities[0], indices[0]):
if idx >= 0 and idx < len(self.articles_metadata): # Valid index
article = self.articles_metadata[idx].copy()
article['similarity_score'] = float(similarity)
# Only include if above threshold
if similarity >= settings.similarity_threshold:
results.append(article)
return results
def get_article_by_id(self, article_id: str) -> Optional[Dict[str, Any]]:
"""Get article metadata by ID"""
for article in self.articles_metadata:
if article.get('id') == article_id:
return article
return None
def get_all_articles(self) -> List[Dict[str, Any]]:
"""Get all articles metadata"""
return self.articles_metadata.copy()
def save_index(self):
"""Save FAISS index and metadata to disk"""
try:
# Ensure directory exists
os.makedirs(os.path.dirname(self.index_path), exist_ok=True)
# Save FAISS index
if self.index is not None:
faiss.write_index(self.index, self.index_path)
# Save metadata
with open(self.metadata_path, 'wb') as f:
pickle.dump(self.articles_metadata, f)
print(f"Saved vector store to {self.index_path}")
except Exception as e:
print(f"Error saving vector store: {e}")
def load_index(self):
"""Load FAISS index and metadata from disk"""
try:
# Load FAISS index
if os.path.exists(self.index_path):
self.index = faiss.read_index(self.index_path)
print(f"Loaded FAISS index from {self.index_path}")
# Load metadata
if os.path.exists(self.metadata_path):
with open(self.metadata_path, 'rb') as f:
self.articles_metadata = pickle.load(f)
print(f"Loaded {len(self.articles_metadata)} articles metadata")
except Exception as e:
print(f"Error loading vector store: {e}")
# Create new index if loading fails
self.index = None
self.articles_metadata = []
def clear_index(self):
"""Clear the entire vector store"""
self.index = None
self.articles_metadata = []
# Remove files
for path in [self.index_path, self.metadata_path]:
if os.path.exists(path):
os.remove(path)
print("Cleared vector store")
def get_stats(self) -> Dict[str, Any]:
"""Get vector store statistics"""
return {
'total_articles': len(self.articles_metadata),
'index_dimension': self.dimension,
'index_exists': self.index is not None,
'index_size': self.index.ntotal if self.index else 0,
'last_updated': max([a.get('added_date', '') for a in self.articles_metadata]) if self.articles_metadata else None
}
# Test function
if __name__ == "__main__":
# Test vector store
store = VectorStore()
stats = store.get_stats()
print(f"Vector store stats: {stats}")