Refactor backend configuration and enhance news fetching functionality
- Introduced a Config dataclass in config.py to manage API keys, RSS feeds, and directory paths more effectively. - Updated the NewsFetcher class to include retry logic for fetching articles from RSS feeds. - Modified the EmbeddingGenerator and NewsRecommender classes to utilize the new configuration structure. - Enhanced main.py to implement API token verification for secure access to news fetching and recommendations.
This commit is contained in:
@@ -55,8 +55,7 @@ DS Task AI News is a web application that uses AI technologies to fetch, analyze
|
||||
```
|
||||
|
||||
4. Run the application:
|
||||
```
|
||||
python backend/main.py
|
||||
``` python backend/main.py
|
||||
```
|
||||
|
||||
5. Open your web browser and navigate to `http://localhost:8000`.
|
||||
@@ -104,3 +103,4 @@ This project is licensed under the MIT License - see the LICENSE file for detail
|
||||
- [Cohere](https://cohere.ai/)
|
||||
- [Pinecone](https://www.pinecone.io/)
|
||||
- [Groq](https://groq.com/)
|
||||
|
||||
|
||||
+52
-18
@@ -1,5 +1,10 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import HTTPException, Depends, Security
|
||||
from fastapi.security import APIKeyHeader
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
|
||||
# Load environment variables
|
||||
|
||||
@@ -9,30 +14,59 @@ from dotenv import load_dotenv
|
||||
# Load environment variables from the specified path
|
||||
load_dotenv()
|
||||
|
||||
# API Keys
|
||||
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
|
||||
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
||||
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
|
||||
@dataclass
|
||||
class Config:
|
||||
# API Keys
|
||||
cohere_api_key: str = os.getenv("COHERE_API_KEY", "")
|
||||
groq_api_key: str = os.getenv("GROQ_API_KEY", "")
|
||||
pinecone_api_key: str = os.getenv("PINECONE_API_KEY", "")
|
||||
api_token: str = os.getenv("API_TOKEN", "default_secret_token") # Default token for development
|
||||
|
||||
# Pinecone Configuration
|
||||
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "news-articles")
|
||||
# Pinecone Configuration
|
||||
pinecone_index_name: str = os.getenv("PINECONE_INDEX_NAME", "news-articles")
|
||||
vector_dimension: int = 1024 # Cohere embedding dimension
|
||||
top_k_results: int = 5
|
||||
|
||||
# News Sources
|
||||
RSS_FEEDS = [
|
||||
# News Sources
|
||||
rss_feeds: List[str] = field(default_factory=lambda: [
|
||||
# "https://feeds.feedburner.com/TechCrunch/",
|
||||
# "https://www.theverge.com/rss/index.xml",
|
||||
"https://www.wired.com/feed/rss",
|
||||
"https://www.technologyreview.com/feed/",
|
||||
]
|
||||
])
|
||||
|
||||
# Vector Database Settings
|
||||
VECTOR_DIMENSION = 1024 # Cohere embedding dimension
|
||||
TOP_K_RESULTS = 5
|
||||
# Data Directories
|
||||
raw_news_dir: str = "data/raw_news"
|
||||
processed_news_dir: str = "data/processed_news"
|
||||
|
||||
# Data Directories
|
||||
RAW_NEWS_DIR = "data/raw_news"
|
||||
PROCESSED_NEWS_DIR = "data/processed_news"
|
||||
def __post_init__(self):
|
||||
# Create directories if they don't exist
|
||||
os.makedirs(self.raw_news_dir, exist_ok=True)
|
||||
os.makedirs(self.processed_news_dir, exist_ok=True)
|
||||
|
||||
# Create directories if they don't exist
|
||||
os.makedirs(RAW_NEWS_DIR, exist_ok=True)
|
||||
os.makedirs(PROCESSED_NEWS_DIR, exist_ok=True)
|
||||
# Create a global config instance
|
||||
config = Config()
|
||||
|
||||
# API Key header
|
||||
api_key_header = APIKeyHeader(name="X-API-Token", auto_error=False)
|
||||
|
||||
def verify_api_token(api_key: str):
|
||||
"""
|
||||
Verify the API token from the request header.
|
||||
|
||||
Args:
|
||||
api_key: The API key from the request header
|
||||
|
||||
Returns:
|
||||
The API key if valid
|
||||
|
||||
Raises:
|
||||
HTTPException: If the API key is invalid
|
||||
"""
|
||||
if api_key == config.api_token:
|
||||
print(f"API key verified: {api_key}")
|
||||
return api_key
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="Invalid API token"
|
||||
)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import cohere
|
||||
from typing import List, Dict, Any
|
||||
from config import COHERE_API_KEY
|
||||
from typing import List, Dict, Any, Optional
|
||||
from config import config
|
||||
|
||||
class EmbeddingGenerator:
|
||||
def __init__(self):
|
||||
self.client = cohere.Client(COHERE_API_KEY)
|
||||
def __init__(self, cohere_client: Optional[cohere.Client] = None):
|
||||
self.client = cohere_client or cohere.Client(config.cohere_api_key)
|
||||
|
||||
def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for a list of texts using Cohere."""
|
||||
|
||||
+15
-7
@@ -1,4 +1,4 @@
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi import FastAPI, HTTPException, Request, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.responses import HTMLResponse
|
||||
@@ -10,13 +10,20 @@ from news_fetcher import NewsFetcher
|
||||
from embeddings import EmbeddingGenerator
|
||||
from vector_store import VectorStore
|
||||
from recommender import NewsRecommender
|
||||
from config import RAW_NEWS_DIR, PROCESSED_NEWS_DIR
|
||||
from config import config
|
||||
from fastapi import HTTPException
|
||||
|
||||
app = FastAPI(title="DS Task AI News API")
|
||||
|
||||
# Configure templates
|
||||
templates = Jinja2Templates(directory="backend/templates")
|
||||
|
||||
def verify_api_token(token: str):
|
||||
if token == config.api_token:
|
||||
print(f"API key verified: {token}")
|
||||
return token
|
||||
return None
|
||||
|
||||
# Add custom filters
|
||||
def from_json(value):
|
||||
"""Parse a JSON string into a Python object."""
|
||||
@@ -51,7 +58,8 @@ async def root(request: Request):
|
||||
)
|
||||
|
||||
@app.get("/fetch-news", response_class=HTMLResponse)
|
||||
async def fetch_news(request: Request):
|
||||
def fetch_news(request: Request, token: str = Depends(verify_api_token)):
|
||||
# print(f"Fetching news with token: {token}")
|
||||
"""Fetch news from RSS feeds and store in vector database."""
|
||||
try:
|
||||
result = news_fetcher.process()
|
||||
@@ -59,11 +67,11 @@ async def fetch_news(request: Request):
|
||||
raise HTTPException(status_code=404, detail=result["message"])
|
||||
|
||||
# Get the latest processed articles
|
||||
processed_files = sorted(os.listdir(PROCESSED_NEWS_DIR), reverse=True)
|
||||
processed_files = sorted(os.listdir(config.processed_news_dir), reverse=True)
|
||||
if not processed_files:
|
||||
raise HTTPException(status_code=404, detail="No processed articles found")
|
||||
|
||||
latest_file = os.path.join(PROCESSED_NEWS_DIR, processed_files[0])
|
||||
latest_file = os.path.join(config.processed_news_dir, processed_files[0])
|
||||
with open(latest_file, 'r', encoding='utf-8') as f:
|
||||
articles = json.load(f)
|
||||
|
||||
@@ -81,7 +89,7 @@ async def fetch_news(request: Request):
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/recommend-news", response_class=HTMLResponse)
|
||||
async def recommend_news(request: Request, article_id: str = None, query: str = None):
|
||||
async def recommend_news(request: Request, article_id: str = None, query: str = None, token: str = Depends(verify_api_token)):
|
||||
"""Get news recommendations based on article ID or search query."""
|
||||
try:
|
||||
if article_id:
|
||||
@@ -129,7 +137,7 @@ async def recommend_news(request: Request, article_id: str = None, query: str =
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/article/{article_id}")
|
||||
async def get_article(article_id: str):
|
||||
async def get_article(article_id: str, token: str = Depends(verify_api_token)):
|
||||
"""Get a specific article and its summary."""
|
||||
try:
|
||||
# Search for the article
|
||||
|
||||
+41
-15
@@ -3,12 +3,13 @@ import json
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
from config import RSS_FEEDS, RAW_NEWS_DIR, PROCESSED_NEWS_DIR
|
||||
from embeddings import EmbeddingGenerator
|
||||
from vector_store import VectorStore
|
||||
from typing import List, Dict, Any, Optional
|
||||
from bs4 import BeautifulSoup
|
||||
import re
|
||||
import time
|
||||
from config import config
|
||||
from embeddings import EmbeddingGenerator
|
||||
from vector_store import VectorStore
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@@ -22,10 +23,18 @@ logging.basicConfig(
|
||||
logger = logging.getLogger('NewsFetcher')
|
||||
|
||||
class NewsFetcher:
|
||||
def __init__(self):
|
||||
self.feeds = RSS_FEEDS
|
||||
self.embedding_generator = EmbeddingGenerator()
|
||||
self.vector_store = VectorStore()
|
||||
def __init__(
|
||||
self,
|
||||
embedding_generator: Optional[EmbeddingGenerator] = None,
|
||||
vector_store: Optional[VectorStore] = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: int = 5
|
||||
):
|
||||
self.feeds = config.rss_feeds
|
||||
self.embedding_generator = embedding_generator or EmbeddingGenerator()
|
||||
self.vector_store = vector_store or VectorStore()
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
logger.info("NewsFetcher initialized with %d RSS feeds", len(self.feeds))
|
||||
|
||||
def clean_html_content(self, html_content: str) -> str:
|
||||
@@ -54,11 +63,21 @@ class NewsFetcher:
|
||||
return cleaned_text
|
||||
|
||||
def fetch_rss_news(self, feed_url: str) -> List[Dict[str, Any]]:
|
||||
"""Fetch news articles from a single RSS feed."""
|
||||
"""Fetch news articles from a single RSS feed with retry logic."""
|
||||
logger.info("Fetching news from feed: %s", feed_url)
|
||||
feed = feedparser.parse(feed_url)
|
||||
articles = []
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
feed = feedparser.parse(feed_url)
|
||||
if not feed.entries:
|
||||
logger.warning("No entries found in feed %s (attempt %d/%d)",
|
||||
feed_url, attempt + 1, self.max_retries)
|
||||
if attempt < self.max_retries - 1:
|
||||
time.sleep(self.retry_delay)
|
||||
continue
|
||||
return []
|
||||
|
||||
for entry in feed.entries:
|
||||
# Get raw content with HTML
|
||||
raw_content = entry.get("summary", "")
|
||||
@@ -81,18 +100,25 @@ class NewsFetcher:
|
||||
logger.info("Fetched %d articles from %s", len(articles), feed_url)
|
||||
return articles
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error fetching from %s (attempt %d/%d): %s",
|
||||
feed_url, attempt + 1, self.max_retries, str(e))
|
||||
if attempt < self.max_retries - 1:
|
||||
time.sleep(self.retry_delay)
|
||||
else:
|
||||
logger.error("Failed to fetch from %s after %d attempts",
|
||||
feed_url, self.max_retries)
|
||||
return []
|
||||
|
||||
def fetch_all_news(self) -> List[Dict[str, Any]]:
|
||||
"""Fetch news from all configured RSS feeds."""
|
||||
logger.info("Starting to fetch news from all %d feeds", len(self.feeds))
|
||||
all_articles = []
|
||||
|
||||
for feed_url in self.feeds:
|
||||
try:
|
||||
articles = self.fetch_rss_news(feed_url)
|
||||
all_articles.extend(articles)
|
||||
logger.info("Successfully fetched %d articles from %s", len(articles), feed_url)
|
||||
except Exception as e:
|
||||
logger.error("Error fetching from %s: %s", feed_url, str(e))
|
||||
|
||||
logger.info("Total articles fetched: %d", len(all_articles))
|
||||
return all_articles
|
||||
@@ -101,7 +127,7 @@ class NewsFetcher:
|
||||
"""Save raw articles to a JSON file."""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"raw_news_{timestamp}.json"
|
||||
filepath = os.path.join(RAW_NEWS_DIR, filename)
|
||||
filepath = os.path.join(config.raw_news_dir, filename)
|
||||
|
||||
logger.info("Saving %d raw articles to %s", len(articles), filepath)
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
@@ -114,7 +140,7 @@ class NewsFetcher:
|
||||
"""Save processed articles with embeddings to a JSON file."""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"processed_news_{timestamp}.json"
|
||||
filepath = os.path.join(PROCESSED_NEWS_DIR, filename)
|
||||
filepath = os.path.join(config.processed_news_dir, filename)
|
||||
|
||||
# Create a copy of articles without raw_content for processed storage
|
||||
processed_articles = []
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from groq import Groq
|
||||
from typing import List, Dict, Any
|
||||
from config import GROQ_API_KEY
|
||||
from typing import List, Dict, Any, Optional
|
||||
from config import config
|
||||
import json
|
||||
|
||||
class NewsRecommender:
|
||||
def __init__(self):
|
||||
self.client = Groq(api_key=GROQ_API_KEY)
|
||||
def __init__(self, groq_client: Optional[Groq] = None):
|
||||
self.client = groq_client or Groq(api_key=config.groq_api_key)
|
||||
|
||||
def analyze_articles(self, articles: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""Analyze a set of articles using Groq to generate insights."""
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
<div class="p-6">
|
||||
<h2 class="text-2xl font-semibold text-gray-800 mb-4">Latest News</h2>
|
||||
<p class="text-gray-600 mb-6">View the latest news articles fetched from our RSS feeds.</p>
|
||||
<a href="/fetch-news" class="inline-block bg-blue-600 text-white px-6 py-3 rounded-md font-medium hover:bg-blue-700 transition-colors duration-300">
|
||||
<a href="/fetch-news?token=default_secret_token" class="inline-block bg-blue-600 text-white px-6 py-3 rounded-md font-medium hover:bg-blue-700 transition-colors duration-300">
|
||||
View Latest News
|
||||
</a>
|
||||
</div>
|
||||
@@ -27,7 +27,7 @@
|
||||
<h2 class="text-2xl font-semibold text-gray-800 mb-4">News Recommendations</h2>
|
||||
<p class="text-gray-600 mb-6">Get personalized news recommendations based on your interests.</p>
|
||||
<div class="space-y-4">
|
||||
<a href="/recommend-news?query=technology" class="block bg-blue-600 text-white px-6 py-3 rounded-md font-medium hover:bg-blue-700 transition-colors duration-300 text-center">
|
||||
<a href="/recommend-news?query=technology?token=default_secret_token" class="block bg-blue-600 text-white px-6 py-3 rounded-md font-medium hover:bg-blue-700 transition-colors duration-300 text-center">
|
||||
Technology News
|
||||
</a>
|
||||
<a href="/recommend-news?query=artificial intelligence" class="block bg-blue-600 text-white px-6 py-3 rounded-md font-medium hover:bg-blue-700 transition-colors duration-300 text-center">
|
||||
|
||||
+9
-14
@@ -1,16 +1,11 @@
|
||||
from pinecone import Pinecone, ServerlessSpec
|
||||
from typing import List, Dict, Any
|
||||
from config import (
|
||||
PINECONE_API_KEY,
|
||||
PINECONE_INDEX_NAME,
|
||||
VECTOR_DIMENSION,
|
||||
TOP_K_RESULTS
|
||||
)
|
||||
from typing import List, Dict, Any, Optional
|
||||
from config import config
|
||||
|
||||
class VectorStore:
|
||||
def __init__(self):
|
||||
self.pinecone = Pinecone(api_key=PINECONE_API_KEY)
|
||||
self.index_name = PINECONE_INDEX_NAME
|
||||
def __init__(self, pinecone_client: Optional[Pinecone] = None):
|
||||
self.pinecone = pinecone_client or Pinecone(api_key=config.pinecone_api_key)
|
||||
self.index_name = config.pinecone_index_name
|
||||
self._ensure_index()
|
||||
|
||||
def _ensure_index(self):
|
||||
@@ -20,11 +15,11 @@ class VectorStore:
|
||||
# Create a new index with the correct dimension
|
||||
self.pinecone.create_index(
|
||||
name=self.index_name,
|
||||
dimension=VECTOR_DIMENSION,
|
||||
dimension=config.vector_dimension,
|
||||
metric="cosine",
|
||||
spec=ServerlessSpec(cloud="aws", region="us-east-1")
|
||||
)
|
||||
print(f"Created new index '{self.index_name}' with dimension {VECTOR_DIMENSION}")
|
||||
print(f"Created new index '{self.index_name}' with dimension {config.vector_dimension}")
|
||||
|
||||
self.index = self.pinecone.Index(self.index_name)
|
||||
|
||||
@@ -57,12 +52,12 @@ class VectorStore:
|
||||
print(f"Error upserting articles: {str(e)}")
|
||||
return False
|
||||
|
||||
def search_similar(self, query_embedding: List[float], top_k: int = TOP_K_RESULTS) -> List[Dict[str, Any]]:
|
||||
def search_similar(self, query_embedding: List[float], top_k: int = None) -> List[Dict[str, Any]]:
|
||||
"""Search for similar articles using the query embedding."""
|
||||
try:
|
||||
results = self.index.query(
|
||||
vector=query_embedding,
|
||||
top_k=top_k,
|
||||
top_k=top_k or config.top_k_results,
|
||||
include_metadata=True
|
||||
)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Reference in New Issue
Block a user