152 lines
5.4 KiB
Python
152 lines
5.4 KiB
Python
"""News recommendation system"""
|
|
from typing import List, Dict, Any, Optional
|
|
import numpy as np
|
|
from embeddings import EmbeddingGenerator
|
|
from vector_store import VectorStore
|
|
from config import settings
|
|
|
|
class NewsRecommender:
    """Recommend news articles using an embedding generator and a vector store.

    The class wires together two collaborators created in ``__init__``:
    ``self.embedding_generator`` (turns text into a query embedding) and
    ``self.vector_store`` (holds the articles and answers similarity searches).
    Articles are handled as dictionaries; this class reads their ``'id'``,
    ``'source'`` and ``'published_date'`` keys.
    """

    def __init__(self):
        """Create the embedding generator and the vector store."""
        self.embedding_generator = EmbeddingGenerator()
        self.vector_store = VectorStore()

    @staticmethod
    def _normalized_source(value: Any) -> str:
        """Return a lower-cased source name, mapping ``None``/missing to ``''``."""
        return str(value or '').lower()

    @staticmethod
    def _newest_first(articles: List[Dict[str, Any]], top_k: int) -> List[Dict[str, Any]]:
        """Return up to ``top_k`` articles ordered by ``'published_date'``, newest first.

        Dates are compared as stored (no parsing), so the order is only
        chronological when the stored values sort chronologically.
        A missing or ``None`` date is treated as ``''`` and therefore sorts last.
        """
        ordered = sorted(
            articles,
            # ``or ''`` also covers an explicit None, which .get's default does not.
            key=lambda item: item.get('published_date') or '',
            reverse=True,
        )
        # Clamp so a negative top_k yields [] instead of slicing from the end.
        return ordered[:max(top_k, 0)]

    def recommend_by_article_id(self, article_id: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Recommend articles similar to a given article ID.

        Args:
            article_id: ID of the article to find neighbours for.
            top_k: Maximum number of recommendations to return.

        Returns:
            Up to ``top_k`` similar articles, never including the article
            itself. Empty when the article is unknown or ``top_k`` is not
            positive.
        """
        if top_k <= 0:
            return []

        article = self.vector_store.get_article_by_id(article_id)
        if not article:
            return []

        # Embed the article's own text and use it as the search query.
        article_text = self.embedding_generator.create_article_text(article)
        query_embedding = self.embedding_generator.generate_query_embedding(article_text)

        # Request one extra hit so that dropping the article itself still
        # leaves up to top_k results.
        candidates = self.vector_store.search_similar(query_embedding, top_k + 1)
        others = [hit for hit in candidates if hit.get('id') != article_id]
        return others[:top_k]

    def recommend_by_query(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Recommend articles based on a text query.

        Args:
            query: Free-text query.
            top_k: Maximum number of results requested from the vector store.

        Returns:
            The vector store's search results, or an empty list when the query
            is ``None``/blank or ``top_k`` is not positive.
        """
        if top_k <= 0 or not query or not query.strip():
            return []

        query_embedding = self.embedding_generator.generate_query_embedding(query)
        return self.vector_store.search_similar(query_embedding, top_k)

    def recommend_by_interests(self, interests: List[str], top_k: int = 10) -> List[Dict[str, Any]]:
        """Recommend articles based on user interests.

        The interests are joined with spaces into a single query and passed to
        ``recommend_by_query``.

        Args:
            interests: Interest keywords or phrases.
            top_k: Maximum number of results.

        Returns:
            Matching articles, or an empty list when no interests are given.
        """
        if not interests:
            return []

        return self.recommend_by_query(" ".join(interests), top_k)

    def get_trending_articles(self, top_k: int = 10) -> List[Dict[str, Any]]:
        """Get trending articles (most recent for now).

        Args:
            top_k: Maximum number of articles to return.

        Returns:
            Up to ``top_k`` articles from the store, newest first.
        """
        return self._newest_first(self.vector_store.get_all_articles(), top_k)

    def get_articles_by_source(self, source: str, top_k: int = 10) -> List[Dict[str, Any]]:
        """Get articles from a specific source.

        Args:
            source: Source name; compared case-insensitively.
            top_k: Maximum number of articles to return.

        Returns:
            Up to ``top_k`` articles whose ``'source'`` matches, newest first.
        """
        wanted = self._normalized_source(source)
        matching = [
            article for article in self.vector_store.get_all_articles()
            if self._normalized_source(article.get('source')) == wanted
        ]
        return self._newest_first(matching, top_k)

    def add_articles_to_store(self, articles: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Add new articles to the vector store.

        Args:
            articles: Articles to embed and store.

        Returns:
            On success: ``{"success": True, "articles_added": ..., "total_articles": ...}``.
            On failure or empty input: ``{"success": False, "message": ...}``.
        """
        if not articles:
            return {"success": False, "message": "No articles provided"}

        # Deliberately broad: this method reports failures through its result
        # dictionary instead of raising to the caller.
        try:
            embeddings = self.embedding_generator.generate_embeddings(articles)
            self.vector_store.add_articles(articles, embeddings)

            return {
                "success": True,
                "articles_added": len(articles),
                "total_articles": len(self.vector_store.get_all_articles()),
            }
        except Exception as e:
            return {
                "success": False,
                "message": f"Error adding articles: {str(e)}",
            }

    def get_store_stats(self) -> Dict[str, Any]:
        """Get vector store statistics as reported by ``vector_store.get_stats()``."""
        return self.vector_store.get_stats()

    def search_articles(self, query: str, filters: Optional[Dict[str, Any]] = None,
                        top_k: int = 10) -> List[Dict[str, Any]]:
        """Advanced search with filters.

        Args:
            query: Free-text query.
            filters: Optional filter dictionary. ``'source'`` restricts results
                to that source (case-insensitive). ``'date_from'`` and
                ``'date_to'`` are accepted but NOT applied yet; they need
                proper date parsing first.
            top_k: Maximum number of results to return.

        Returns:
            Up to ``top_k`` articles matching the query and the source filter.
        """
        # Fetch twice as many hits as requested so filtering can still leave
        # up to top_k results.
        results = self.recommend_by_query(query, top_k * 2)

        if filters and 'source' in filters:
            wanted = self._normalized_source(filters['source'])
            results = [
                article for article in results
                if self._normalized_source(article.get('source')) == wanted
            ]

        return results[:top_k]
|
|
|
|
# Smoke test: running this module directly builds a recommender and prints
# the statistics reported by its vector store.
if __name__ == "__main__":
    print(f"Recommender stats: {NewsRecommender().get_store_stats()}")
|