Refactor backend configuration and enhance news fetching functionality
- Introduced a Config dataclass in config.py to manage API keys, RSS feeds, and directory paths more effectively. - Updated the NewsFetcher class to include retry logic for fetching articles from RSS feeds. - Modified the EmbeddingGenerator and NewsRecommender classes to utilize the new configuration structure. - Enhanced main.py to implement API token verification for secure access to news fetching and recommendations.
This commit is contained in:
+56
-22
@@ -1,5 +1,10 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import HTTPException, Depends, Security
|
||||
from fastapi.security import APIKeyHeader
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
|
||||
# Load environment variables
|
||||
|
||||
@@ -9,30 +14,59 @@ from dotenv import load_dotenv
|
||||
# Load environment variables from the specified path
|
||||
load_dotenv()
|
||||
|
||||
# API Keys
|
||||
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
|
||||
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
||||
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
|
||||
@dataclass
|
||||
class Config:
|
||||
# API Keys
|
||||
cohere_api_key: str = os.getenv("COHERE_API_KEY", "")
|
||||
groq_api_key: str = os.getenv("GROQ_API_KEY", "")
|
||||
pinecone_api_key: str = os.getenv("PINECONE_API_KEY", "")
|
||||
api_token: str = os.getenv("API_TOKEN", "default_secret_token") # Default token for development
|
||||
|
||||
# Pinecone Configuration
|
||||
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "news-articles")
|
||||
# Pinecone Configuration
|
||||
pinecone_index_name: str = os.getenv("PINECONE_INDEX_NAME", "news-articles")
|
||||
vector_dimension: int = 1024 # Cohere embedding dimension
|
||||
top_k_results: int = 5
|
||||
|
||||
# News Sources
|
||||
RSS_FEEDS = [
|
||||
# "https://feeds.feedburner.com/TechCrunch/",
|
||||
# "https://www.theverge.com/rss/index.xml",
|
||||
"https://www.wired.com/feed/rss",
|
||||
"https://www.technologyreview.com/feed/",
|
||||
]
|
||||
# News Sources
|
||||
rss_feeds: List[str] = field(default_factory=lambda: [
|
||||
# "https://feeds.feedburner.com/TechCrunch/",
|
||||
# "https://www.theverge.com/rss/index.xml",
|
||||
"https://www.wired.com/feed/rss",
|
||||
"https://www.technologyreview.com/feed/",
|
||||
])
|
||||
|
||||
# Vector Database Settings
|
||||
VECTOR_DIMENSION = 1024 # Cohere embedding dimension
|
||||
TOP_K_RESULTS = 5
|
||||
# Data Directories
|
||||
raw_news_dir: str = "data/raw_news"
|
||||
processed_news_dir: str = "data/processed_news"
|
||||
|
||||
# Data Directories
|
||||
RAW_NEWS_DIR = "data/raw_news"
|
||||
PROCESSED_NEWS_DIR = "data/processed_news"
|
||||
def __post_init__(self):
|
||||
# Create directories if they don't exist
|
||||
os.makedirs(self.raw_news_dir, exist_ok=True)
|
||||
os.makedirs(self.processed_news_dir, exist_ok=True)
|
||||
|
||||
# Create directories if they don't exist
|
||||
os.makedirs(RAW_NEWS_DIR, exist_ok=True)
|
||||
os.makedirs(PROCESSED_NEWS_DIR, exist_ok=True)
|
||||
# Create a global config instance
|
||||
config = Config()
|
||||
|
||||
# API Key header
|
||||
api_key_header = APIKeyHeader(name="X-API-Token", auto_error=False)
|
||||
|
||||
def verify_api_token(api_key: str):
|
||||
"""
|
||||
Verify the API token from the request header.
|
||||
|
||||
Args:
|
||||
api_key: The API key from the request header
|
||||
|
||||
Returns:
|
||||
The API key if valid
|
||||
|
||||
Raises:
|
||||
HTTPException: If the API key is invalid
|
||||
"""
|
||||
if api_key == config.api_token:
|
||||
print(f"API key verified: {api_key}")
|
||||
return api_key
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="Invalid API token"
|
||||
)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import cohere
|
||||
from typing import List, Dict, Any
|
||||
from config import COHERE_API_KEY
|
||||
from typing import List, Dict, Any, Optional
|
||||
from config import config
|
||||
|
||||
class EmbeddingGenerator:
|
||||
def __init__(self):
|
||||
self.client = cohere.Client(COHERE_API_KEY)
|
||||
def __init__(self, cohere_client: Optional[cohere.Client] = None):
|
||||
self.client = cohere_client or cohere.Client(config.cohere_api_key)
|
||||
|
||||
def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for a list of texts using Cohere."""
|
||||
|
||||
+15
-7
@@ -1,4 +1,4 @@
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi import FastAPI, HTTPException, Request, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.responses import HTMLResponse
|
||||
@@ -10,13 +10,20 @@ from news_fetcher import NewsFetcher
|
||||
from embeddings import EmbeddingGenerator
|
||||
from vector_store import VectorStore
|
||||
from recommender import NewsRecommender
|
||||
from config import RAW_NEWS_DIR, PROCESSED_NEWS_DIR
|
||||
from config import config
|
||||
from fastapi import HTTPException
|
||||
|
||||
app = FastAPI(title="DS Task AI News API")
|
||||
|
||||
# Configure templates
|
||||
templates = Jinja2Templates(directory="backend/templates")
|
||||
|
||||
def verify_api_token(token: str):
|
||||
if token == config.api_token:
|
||||
print(f"API key verified: {token}")
|
||||
return token
|
||||
return None
|
||||
|
||||
# Add custom filters
|
||||
def from_json(value):
|
||||
"""Parse a JSON string into a Python object."""
|
||||
@@ -51,7 +58,8 @@ async def root(request: Request):
|
||||
)
|
||||
|
||||
@app.get("/fetch-news", response_class=HTMLResponse)
|
||||
async def fetch_news(request: Request):
|
||||
def fetch_news(request: Request, token: str = Depends(verify_api_token)):
|
||||
# print(f"Fetching news with token: {token}")
|
||||
"""Fetch news from RSS feeds and store in vector database."""
|
||||
try:
|
||||
result = news_fetcher.process()
|
||||
@@ -59,11 +67,11 @@ async def fetch_news(request: Request):
|
||||
raise HTTPException(status_code=404, detail=result["message"])
|
||||
|
||||
# Get the latest processed articles
|
||||
processed_files = sorted(os.listdir(PROCESSED_NEWS_DIR), reverse=True)
|
||||
processed_files = sorted(os.listdir(config.processed_news_dir), reverse=True)
|
||||
if not processed_files:
|
||||
raise HTTPException(status_code=404, detail="No processed articles found")
|
||||
|
||||
latest_file = os.path.join(PROCESSED_NEWS_DIR, processed_files[0])
|
||||
latest_file = os.path.join(config.processed_news_dir, processed_files[0])
|
||||
with open(latest_file, 'r', encoding='utf-8') as f:
|
||||
articles = json.load(f)
|
||||
|
||||
@@ -81,7 +89,7 @@ async def fetch_news(request: Request):
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/recommend-news", response_class=HTMLResponse)
|
||||
async def recommend_news(request: Request, article_id: str = None, query: str = None):
|
||||
async def recommend_news(request: Request, article_id: str = None, query: str = None, token: str = Depends(verify_api_token)):
|
||||
"""Get news recommendations based on article ID or search query."""
|
||||
try:
|
||||
if article_id:
|
||||
@@ -129,7 +137,7 @@ async def recommend_news(request: Request, article_id: str = None, query: str =
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/article/{article_id}")
|
||||
async def get_article(article_id: str):
|
||||
async def get_article(article_id: str, token: str = Depends(verify_api_token)):
|
||||
"""Get a specific article and its summary."""
|
||||
try:
|
||||
# Search for the article
|
||||
|
||||
+65
-39
@@ -3,12 +3,13 @@ import json
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
from config import RSS_FEEDS, RAW_NEWS_DIR, PROCESSED_NEWS_DIR
|
||||
from embeddings import EmbeddingGenerator
|
||||
from vector_store import VectorStore
|
||||
from typing import List, Dict, Any, Optional
|
||||
from bs4 import BeautifulSoup
|
||||
import re
|
||||
import time
|
||||
from config import config
|
||||
from embeddings import EmbeddingGenerator
|
||||
from vector_store import VectorStore
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@@ -22,10 +23,18 @@ logging.basicConfig(
|
||||
logger = logging.getLogger('NewsFetcher')
|
||||
|
||||
class NewsFetcher:
|
||||
def __init__(self):
|
||||
self.feeds = RSS_FEEDS
|
||||
self.embedding_generator = EmbeddingGenerator()
|
||||
self.vector_store = VectorStore()
|
||||
def __init__(
|
||||
self,
|
||||
embedding_generator: Optional[EmbeddingGenerator] = None,
|
||||
vector_store: Optional[VectorStore] = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: int = 5
|
||||
):
|
||||
self.feeds = config.rss_feeds
|
||||
self.embedding_generator = embedding_generator or EmbeddingGenerator()
|
||||
self.vector_store = vector_store or VectorStore()
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
logger.info("NewsFetcher initialized with %d RSS feeds", len(self.feeds))
|
||||
|
||||
def clean_html_content(self, html_content: str) -> str:
|
||||
@@ -54,32 +63,52 @@ class NewsFetcher:
|
||||
return cleaned_text
|
||||
|
||||
def fetch_rss_news(self, feed_url: str) -> List[Dict[str, Any]]:
|
||||
"""Fetch news articles from a single RSS feed."""
|
||||
"""Fetch news articles from a single RSS feed with retry logic."""
|
||||
logger.info("Fetching news from feed: %s", feed_url)
|
||||
feed = feedparser.parse(feed_url)
|
||||
articles = []
|
||||
|
||||
for entry in feed.entries:
|
||||
# Get raw content with HTML
|
||||
raw_content = entry.get("summary", "")
|
||||
|
||||
# Clean HTML content
|
||||
clean_content = self.clean_html_content(raw_content)
|
||||
|
||||
article = {
|
||||
"title": entry.title,
|
||||
"raw_content": raw_content, # Store original HTML content
|
||||
"content": clean_content, # Store cleaned text content
|
||||
"link": entry.get("link", ""),
|
||||
"published": entry.get("published", datetime.now().isoformat()),
|
||||
"source": feed.feed.get("title", "Unknown"),
|
||||
"categories": [tag.term for tag in entry.get("tags", [])],
|
||||
"id": entry.get("id", entry.get("link", "")),
|
||||
}
|
||||
articles.append(article)
|
||||
|
||||
logger.info("Fetched %d articles from %s", len(articles), feed_url)
|
||||
return articles
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
feed = feedparser.parse(feed_url)
|
||||
if not feed.entries:
|
||||
logger.warning("No entries found in feed %s (attempt %d/%d)",
|
||||
feed_url, attempt + 1, self.max_retries)
|
||||
if attempt < self.max_retries - 1:
|
||||
time.sleep(self.retry_delay)
|
||||
continue
|
||||
return []
|
||||
|
||||
for entry in feed.entries:
|
||||
# Get raw content with HTML
|
||||
raw_content = entry.get("summary", "")
|
||||
|
||||
# Clean HTML content
|
||||
clean_content = self.clean_html_content(raw_content)
|
||||
|
||||
article = {
|
||||
"title": entry.title,
|
||||
"raw_content": raw_content, # Store original HTML content
|
||||
"content": clean_content, # Store cleaned text content
|
||||
"link": entry.get("link", ""),
|
||||
"published": entry.get("published", datetime.now().isoformat()),
|
||||
"source": feed.feed.get("title", "Unknown"),
|
||||
"categories": [tag.term for tag in entry.get("tags", [])],
|
||||
"id": entry.get("id", entry.get("link", "")),
|
||||
}
|
||||
articles.append(article)
|
||||
|
||||
logger.info("Fetched %d articles from %s", len(articles), feed_url)
|
||||
return articles
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error fetching from %s (attempt %d/%d): %s",
|
||||
feed_url, attempt + 1, self.max_retries, str(e))
|
||||
if attempt < self.max_retries - 1:
|
||||
time.sleep(self.retry_delay)
|
||||
else:
|
||||
logger.error("Failed to fetch from %s after %d attempts",
|
||||
feed_url, self.max_retries)
|
||||
return []
|
||||
|
||||
def fetch_all_news(self) -> List[Dict[str, Any]]:
|
||||
"""Fetch news from all configured RSS feeds."""
|
||||
@@ -87,12 +116,9 @@ class NewsFetcher:
|
||||
all_articles = []
|
||||
|
||||
for feed_url in self.feeds:
|
||||
try:
|
||||
articles = self.fetch_rss_news(feed_url)
|
||||
all_articles.extend(articles)
|
||||
logger.info("Successfully fetched %d articles from %s", len(articles), feed_url)
|
||||
except Exception as e:
|
||||
logger.error("Error fetching from %s: %s", feed_url, str(e))
|
||||
articles = self.fetch_rss_news(feed_url)
|
||||
all_articles.extend(articles)
|
||||
logger.info("Successfully fetched %d articles from %s", len(articles), feed_url)
|
||||
|
||||
logger.info("Total articles fetched: %d", len(all_articles))
|
||||
return all_articles
|
||||
@@ -101,7 +127,7 @@ class NewsFetcher:
|
||||
"""Save raw articles to a JSON file."""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"raw_news_{timestamp}.json"
|
||||
filepath = os.path.join(RAW_NEWS_DIR, filename)
|
||||
filepath = os.path.join(config.raw_news_dir, filename)
|
||||
|
||||
logger.info("Saving %d raw articles to %s", len(articles), filepath)
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
@@ -114,7 +140,7 @@ class NewsFetcher:
|
||||
"""Save processed articles with embeddings to a JSON file."""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"processed_news_{timestamp}.json"
|
||||
filepath = os.path.join(PROCESSED_NEWS_DIR, filename)
|
||||
filepath = os.path.join(config.processed_news_dir, filename)
|
||||
|
||||
# Create a copy of articles without raw_content for processed storage
|
||||
processed_articles = []
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from groq import Groq
|
||||
from typing import List, Dict, Any
|
||||
from config import GROQ_API_KEY
|
||||
from typing import List, Dict, Any, Optional
|
||||
from config import config
|
||||
import json
|
||||
|
||||
class NewsRecommender:
|
||||
def __init__(self):
|
||||
self.client = Groq(api_key=GROQ_API_KEY)
|
||||
def __init__(self, groq_client: Optional[Groq] = None):
|
||||
self.client = groq_client or Groq(api_key=config.groq_api_key)
|
||||
|
||||
def analyze_articles(self, articles: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""Analyze a set of articles using Groq to generate insights."""
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
<div class="p-6">
|
||||
<h2 class="text-2xl font-semibold text-gray-800 mb-4">Latest News</h2>
|
||||
<p class="text-gray-600 mb-6">View the latest news articles fetched from our RSS feeds.</p>
|
||||
<a href="/fetch-news" class="inline-block bg-blue-600 text-white px-6 py-3 rounded-md font-medium hover:bg-blue-700 transition-colors duration-300">
|
||||
<a href="/fetch-news?token=default_secret_token" class="inline-block bg-blue-600 text-white px-6 py-3 rounded-md font-medium hover:bg-blue-700 transition-colors duration-300">
|
||||
View Latest News
|
||||
</a>
|
||||
</div>
|
||||
@@ -27,7 +27,7 @@
|
||||
<h2 class="text-2xl font-semibold text-gray-800 mb-4">News Recommendations</h2>
|
||||
<p class="text-gray-600 mb-6">Get personalized news recommendations based on your interests.</p>
|
||||
<div class="space-y-4">
|
||||
<a href="/recommend-news?query=technology" class="block bg-blue-600 text-white px-6 py-3 rounded-md font-medium hover:bg-blue-700 transition-colors duration-300 text-center">
|
||||
<a href="/recommend-news?query=technology?token=default_secret_token" class="block bg-blue-600 text-white px-6 py-3 rounded-md font-medium hover:bg-blue-700 transition-colors duration-300 text-center">
|
||||
Technology News
|
||||
</a>
|
||||
<a href="/recommend-news?query=artificial intelligence" class="block bg-blue-600 text-white px-6 py-3 rounded-md font-medium hover:bg-blue-700 transition-colors duration-300 text-center">
|
||||
|
||||
+9
-14
@@ -1,16 +1,11 @@
|
||||
from pinecone import Pinecone, ServerlessSpec
|
||||
from typing import List, Dict, Any
|
||||
from config import (
|
||||
PINECONE_API_KEY,
|
||||
PINECONE_INDEX_NAME,
|
||||
VECTOR_DIMENSION,
|
||||
TOP_K_RESULTS
|
||||
)
|
||||
from typing import List, Dict, Any, Optional
|
||||
from config import config
|
||||
|
||||
class VectorStore:
|
||||
def __init__(self):
|
||||
self.pinecone = Pinecone(api_key=PINECONE_API_KEY)
|
||||
self.index_name = PINECONE_INDEX_NAME
|
||||
def __init__(self, pinecone_client: Optional[Pinecone] = None):
|
||||
self.pinecone = pinecone_client or Pinecone(api_key=config.pinecone_api_key)
|
||||
self.index_name = config.pinecone_index_name
|
||||
self._ensure_index()
|
||||
|
||||
def _ensure_index(self):
|
||||
@@ -20,11 +15,11 @@ class VectorStore:
|
||||
# Create a new index with the correct dimension
|
||||
self.pinecone.create_index(
|
||||
name=self.index_name,
|
||||
dimension=VECTOR_DIMENSION,
|
||||
dimension=config.vector_dimension,
|
||||
metric="cosine",
|
||||
spec=ServerlessSpec(cloud="aws", region="us-east-1")
|
||||
)
|
||||
print(f"Created new index '{self.index_name}' with dimension {VECTOR_DIMENSION}")
|
||||
print(f"Created new index '{self.index_name}' with dimension {config.vector_dimension}")
|
||||
|
||||
self.index = self.pinecone.Index(self.index_name)
|
||||
|
||||
@@ -57,12 +52,12 @@ class VectorStore:
|
||||
print(f"Error upserting articles: {str(e)}")
|
||||
return False
|
||||
|
||||
def search_similar(self, query_embedding: List[float], top_k: int = TOP_K_RESULTS) -> List[Dict[str, Any]]:
|
||||
def search_similar(self, query_embedding: List[float], top_k: int = None) -> List[Dict[str, Any]]:
|
||||
"""Search for similar articles using the query embedding."""
|
||||
try:
|
||||
results = self.index.query(
|
||||
vector=query_embedding,
|
||||
top_k=top_k,
|
||||
top_k=top_k or config.top_k_results,
|
||||
include_metadata=True
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user