update all endpoints
This commit is contained in:
+8
-1
@@ -3,11 +3,18 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class Config:
|
||||
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
|
||||
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
||||
RSS_FEEDS = [
|
||||
"http://rss.nytimes.com/services/xml/rss/nyt/Technology.xml",
|
||||
"https://feeds.bbci.co.uk/news/technology/rss.xml"
|
||||
"https://feeds.bbci.co.uk/news/technology/rss.xml",
|
||||
"https://feeds.a.dj.com/rss/RSSMarketsMain.xml",
|
||||
"https://rss.cnn.com/rss/edition.rss"
|
||||
]
|
||||
VECTOR_DB_PATH = "data/vector_db.index"
|
||||
RAW_NEWS_PATH = "data/raw_news/"
|
||||
PROCESSED_NEWS_PATH = "data/processed_news/"
|
||||
EMBEDDING_MODEL = "embed-english-v3.0"
|
||||
GROQ_MODEL = "llama3-8b-8192"
|
||||
|
||||
+42
-3
@@ -1,8 +1,47 @@
|
||||
import cohere
|
||||
from backend.config import Config
|
||||
from .config import Config
|
||||
|
||||
co = cohere.Client(Config.COHERE_API_KEY)
|
||||
|
||||
|
||||
def get_embeddings(texts):
|
||||
response = co.embed(texts=texts, model="embed-english-v3.0")
|
||||
return response.embeddings
|
||||
"""Generate embeddings using Cohere"""
|
||||
try:
|
||||
response = co.embed(
|
||||
texts=texts,
|
||||
model=Config.EMBEDDING_MODEL,
|
||||
input_type="search_document"
|
||||
)
|
||||
return response.embeddings
|
||||
except Exception as e:
|
||||
print(f"Error generating embeddings: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def get_query_embedding(query):
|
||||
"""Generate embedding for search query"""
|
||||
try:
|
||||
response = co.embed(
|
||||
texts=[query],
|
||||
model=Config.EMBEDDING_MODEL,
|
||||
input_type="search_query"
|
||||
)
|
||||
return response.embeddings[0]
|
||||
except Exception as e:
|
||||
print(f"Error generating query embedding: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def rerank_results(query, documents):
|
||||
"""Re-rank search results using Cohere"""
|
||||
try:
|
||||
response = co.rerank(
|
||||
model="rerank-english-v2.0",
|
||||
query=query,
|
||||
documents=documents,
|
||||
top_n=5
|
||||
)
|
||||
return response.results
|
||||
except Exception as e:
|
||||
print(f"Error reranking results: {str(e)}")
|
||||
return []
|
||||
|
||||
+133
-14
@@ -1,20 +1,139 @@
|
||||
from fastapi import FastAPI
|
||||
from backend.news_fetcher import fetch_news
|
||||
from backend.recommender import recommend_similar
|
||||
from backend.config import Config
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from .news_fetcher import fetch_all_news, save_raw_news, save_processed_news
|
||||
from .recommender import recommend_similar, process_articles_for_vector_db
|
||||
from .recommender import analyze_article_with_groq
|
||||
from .recommender import get_personalized_recommendations, vector_db
|
||||
from .vector_store import VectorDB
|
||||
from .config import Config
|
||||
import os
|
||||
|
||||
app = FastAPI(title="DS Task AI News", version="1.0.0")
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Initialize the application"""
|
||||
# Create necessary directories
|
||||
os.makedirs(Config.RAW_NEWS_PATH, exist_ok=True)
|
||||
os.makedirs(Config.PROCESSED_NEWS_PATH, exist_ok=True)
|
||||
os.makedirs("data", exist_ok=True)
|
||||
|
||||
# Load existing vector database if available
|
||||
if os.path.exists(Config.VECTOR_DB_PATH):
|
||||
vector_db.load_index(Config.VECTOR_DB_PATH)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint"""
|
||||
return {"message": "DS Task AI News API", "version": "1.0.0"}
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/fetch-news")
|
||||
async def get_latest_news():
|
||||
all_news = []
|
||||
for feed in Config.RSS_FEEDS:
|
||||
all_news.extend(fetch_news(feed))
|
||||
return {"news": all_news}
|
||||
async def fetch_news():
|
||||
"""Fetch news from RSS feeds"""
|
||||
try:
|
||||
articles = fetch_all_news()
|
||||
|
||||
if not articles:
|
||||
raise HTTPException(status_code=404, detail="No articles found")
|
||||
|
||||
# Save raw news
|
||||
raw_file = save_raw_news(articles)
|
||||
|
||||
# Process articles for vector database
|
||||
process_articles_for_vector_db(articles)
|
||||
|
||||
# Save processed news
|
||||
processed_file = save_processed_news(articles)
|
||||
|
||||
# Save vector database
|
||||
vector_db.save_index(Config.VECTOR_DB_PATH)
|
||||
|
||||
return {
|
||||
"message": "News fetched successfully",
|
||||
"articles_count": len(articles),
|
||||
"raw_file": raw_file,
|
||||
"processed_file": processed_file,
|
||||
"articles": articles
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error fetching news: {str(e)}")
|
||||
|
||||
|
||||
@app.get("/recommend")
|
||||
@app.get("/recommend-news")
|
||||
async def recommend_news(article_id: str):
|
||||
sample_text = "AI breakthroughs in 2024"
|
||||
similar_ids = recommend_similar(sample_text)
|
||||
return {"similar_articles": similar_ids}
|
||||
"""Retrieve similar news based on the selected article"""
|
||||
try:
|
||||
recommendations = recommend_similar(article_id)
|
||||
|
||||
if not recommendations:
|
||||
raise HTTPException(status_code=404, detail="No recommendations found")
|
||||
|
||||
return {
|
||||
"article_id": article_id,
|
||||
"recommendations": recommendations,
|
||||
"count": len(recommendations)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting recommendations: {str(e)}")
|
||||
|
||||
|
||||
@app.get("/analyze-article")
|
||||
async def analyze_article(article_id: str):
|
||||
"""Analyze article using Groq LLM"""
|
||||
try:
|
||||
article = vector_db.get_article_by_id(article_id)
|
||||
|
||||
if not article:
|
||||
raise HTTPException(status_code=404, detail="Article not found")
|
||||
|
||||
article_text = f"{article['title']} {article['content']}"
|
||||
analysis = analyze_article_with_groq(article_text)
|
||||
|
||||
return {
|
||||
"article_id": article_id,
|
||||
"article": article,
|
||||
"analysis": analysis
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error analyzing article: {str(e)}")
|
||||
|
||||
|
||||
@app.get("/personalized-recommendations")
|
||||
async def personalized_recommendations(interests: str):
|
||||
"""Get personalized recommendations based on user interests"""
|
||||
try:
|
||||
recommendations = get_personalized_recommendations(interests)
|
||||
|
||||
return {
|
||||
"interests": interests,
|
||||
"recommendations": recommendations,
|
||||
"count": len(recommendations)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error getting personalized recommendations: {str(e)}")
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
return {"status": "healthy", "database_articles": len(vector_db.articles)}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
||||
+60
-20
@@ -1,26 +1,66 @@
|
||||
# backend/news_fetcher.py
|
||||
from datetime import datetime
|
||||
import feedparser
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from .config import Config
|
||||
|
||||
def fetch_news(rss_url):
|
||||
feed = feedparser.parse(rss_url)
|
||||
|
||||
def fetch_rss_news(feed_url):
|
||||
"""Fetch news from RSS feed"""
|
||||
feed = feedparser.parse(feed_url)
|
||||
articles = []
|
||||
for entry in feed.entries:
|
||||
try:
|
||||
# Try parsing with timezone first
|
||||
pub_date = datetime.strptime(entry.published, "%a, %d %b %Y %H:%M:%S %z")
|
||||
except ValueError:
|
||||
try:
|
||||
# Fallback to GMT format without timezone
|
||||
pub_date = datetime.strptime(entry.published, "%a, %d %b %Y %H:%M:%S %Z")
|
||||
except ValueError:
|
||||
# Final fallback - use current time if parsing fails
|
||||
pub_date = datetime.now()
|
||||
|
||||
articles.append({
|
||||
for entry in feed.entries:
|
||||
article = {
|
||||
"title": entry.title,
|
||||
"content": entry.description,
|
||||
"published": pub_date,
|
||||
"source": rss_url
|
||||
})
|
||||
"content": getattr(entry, 'summary', ''),
|
||||
"date": getattr(entry, 'published', ''),
|
||||
"slug": entry.title.lower().replace(" ", "-").replace(",", "").replace(".", ""),
|
||||
"categories": ["Technology", "AI and Innovation"],
|
||||
"tags": ["AI", "Technology", "Innovation"],
|
||||
"url": getattr(entry, 'link', ''),
|
||||
"source": feed_url
|
||||
}
|
||||
articles.append(article)
|
||||
|
||||
return articles
|
||||
|
||||
|
||||
def fetch_all_news():
|
||||
"""Fetch news from all RSS feeds"""
|
||||
all_articles = []
|
||||
|
||||
for feed_url in Config.RSS_FEEDS:
|
||||
try:
|
||||
articles = fetch_rss_news(feed_url)
|
||||
all_articles.extend(articles)
|
||||
except Exception as e:
|
||||
print(f"Error fetching from {feed_url}: {str(e)}")
|
||||
|
||||
return all_articles
|
||||
|
||||
|
||||
def save_raw_news(articles):
|
||||
"""Save raw news articles to file"""
|
||||
os.makedirs(Config.RAW_NEWS_PATH, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{Config.RAW_NEWS_PATH}news_{timestamp}.json"
|
||||
|
||||
with open(filename, 'w') as f:
|
||||
json.dump(articles, f, indent=2)
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
def save_processed_news(articles):
|
||||
"""Save processed news articles to file"""
|
||||
os.makedirs(Config.PROCESSED_NEWS_PATH, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{Config.PROCESSED_NEWS_PATH}processed_news_{timestamp}.json"
|
||||
|
||||
with open(filename, 'w') as f:
|
||||
json.dump(articles, f, indent=2)
|
||||
|
||||
return filename
|
||||
|
||||
+95
-6
@@ -1,8 +1,97 @@
|
||||
from backend.embeddings import get_embeddings
|
||||
from backend.vector_store import VectorDB
|
||||
from .embeddings import get_embeddings, get_query_embedding, rerank_results
|
||||
from .vector_store import VectorDB
|
||||
import groq
|
||||
from .config import Config
|
||||
|
||||
db = VectorDB()
|
||||
# Initialize Groq client
|
||||
groq_client = groq.Groq(api_key=Config.GROQ_API_KEY)
|
||||
|
||||
def recommend_similar(article_text, top_k=3):
|
||||
query_embed = get_embeddings([article_text])[0]
|
||||
return db.search(query_embed, k=top_k)
|
||||
# Vector database instance
|
||||
vector_db = VectorDB()
|
||||
|
||||
|
||||
def process_articles_for_vector_db(articles):
|
||||
"""Process articles and add to vector database"""
|
||||
if not articles:
|
||||
return
|
||||
|
||||
# Extract text content for embedding
|
||||
texts = []
|
||||
for article in articles:
|
||||
text = f"{article['title']} {article['content']}"
|
||||
texts.append(text)
|
||||
|
||||
# Generate embeddings
|
||||
embeddings = get_embeddings(texts)
|
||||
if embeddings:
|
||||
vector_db.add_vectors(embeddings, articles)
|
||||
|
||||
|
||||
def recommend_similar(article_id, top_n=3):
|
||||
"""Recommend similar articles based on article ID"""
|
||||
article = vector_db.get_article_by_id(article_id)
|
||||
if not article:
|
||||
return []
|
||||
|
||||
# Get embedding for the article
|
||||
article_text = f"{article['title']} {article['content']}"
|
||||
query_embedding = get_query_embedding(article_text)
|
||||
|
||||
if not query_embedding:
|
||||
return []
|
||||
|
||||
# Search for similar articles
|
||||
similar_articles = vector_db.search(query_embedding, k=top_n + 1)
|
||||
|
||||
# Filter out the original article
|
||||
recommendations = [art for art in similar_articles if art.get('slug') != article_id]
|
||||
|
||||
return recommendations[:top_n]
|
||||
|
||||
|
||||
def analyze_article_with_groq(article_text):
|
||||
"""Analyze article using Groq LLM"""
|
||||
try:
|
||||
response = groq_client.chat.completions.create(
|
||||
model=Config.GROQ_MODEL,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are an AI news analyst. Provide insights, key points, and sentiment analysis for the given article."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Analyze this news article: {article_text}"
|
||||
}
|
||||
],
|
||||
max_tokens=500,
|
||||
temperature=0.3
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
print(f"Error analyzing article with Groq: {str(e)}")
|
||||
return "Analysis unavailable"
|
||||
|
||||
|
||||
def get_personalized_recommendations(user_interests, top_n=5):
|
||||
"""Get personalized recommendations based on user interests"""
|
||||
query_embedding = get_query_embedding(user_interests)
|
||||
if not query_embedding:
|
||||
return []
|
||||
|
||||
recommendations = vector_db.search(query_embedding, k=top_n)
|
||||
|
||||
# Re-rank results for better relevance
|
||||
if recommendations:
|
||||
documents = [f"{art['title']} {art['content']}" for art in recommendations]
|
||||
reranked = rerank_results(user_interests, documents)
|
||||
|
||||
if reranked:
|
||||
# Reorder recommendations based on reranking
|
||||
reordered = []
|
||||
for result in reranked:
|
||||
if result.index < len(recommendations):
|
||||
reordered.append(recommendations[result.index])
|
||||
return reordered
|
||||
|
||||
return recommendations
|
||||
|
||||
+51
-6
@@ -1,14 +1,59 @@
|
||||
import numpy as np
|
||||
import faiss
|
||||
from backend.config import Config
|
||||
import os
|
||||
|
||||
|
||||
class VectorDB:
|
||||
def __init__(self):
|
||||
self.index = faiss.IndexFlatL2(768) # Cohere embedding dim
|
||||
self.index = None
|
||||
self.articles = []
|
||||
self.dimension = 1024 # Cohere embedding dimension
|
||||
|
||||
def add_vectors(self, ids, embeddings):
|
||||
self.index.add(np.array(embeddings).astype('float32'))
|
||||
def initialize_index(self):
|
||||
"""Initialize FAISS index"""
|
||||
self.index = faiss.IndexFlatL2(self.dimension)
|
||||
|
||||
def add_vectors(self, embeddings, articles):
|
||||
"""Add vectors to the database"""
|
||||
if self.index is None:
|
||||
self.initialize_index()
|
||||
|
||||
embeddings_array = np.array(embeddings).astype('float32')
|
||||
self.index.add(embeddings_array)
|
||||
self.articles.extend(articles)
|
||||
|
||||
def search(self, query_embedding, k=5):
|
||||
distances, indices = self.index.search(np.array([query_embedding]), k)
|
||||
return indices[0]
|
||||
"""Search for similar vectors"""
|
||||
if self.index is None or self.index.ntotal == 0:
|
||||
return []
|
||||
|
||||
query_array = np.array([query_embedding]).astype('float32')
|
||||
distances, indices = self.index.search(query_array, k)
|
||||
|
||||
results = []
|
||||
for i, idx in enumerate(indices[0]):
|
||||
if idx < len(self.articles):
|
||||
result = self.articles[idx].copy()
|
||||
result['similarity_score'] = float(distances[0][i])
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
def save_index(self, filepath):
|
||||
"""Save the index to file"""
|
||||
if self.index is not None:
|
||||
faiss.write_index(self.index, filepath)
|
||||
|
||||
def load_index(self, filepath):
|
||||
"""Load the index from file"""
|
||||
if os.path.exists(filepath):
|
||||
self.index = faiss.read_index(filepath)
|
||||
|
||||
def get_article_by_id(self, article_id):
|
||||
"""Get article by ID/slug"""
|
||||
for article in self.articles:
|
||||
if (article.get('slug') == article_id or
|
||||
article.get('title') == article_id or
|
||||
article.get('title', '').lower().replace(' ', '-').replace(',', '').replace('.', '') == article_id):
|
||||
return article
|
||||
return None
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
# API Endpoints
|
||||
|
||||
1. **GET /fetch-news** - Fetches news from RSS feeds
|
||||
2. **GET /recommend-news** - Retrieves similar news based on the selected article
|
||||
3. **GET /analyze-article** - Analyzes article using Groq LLM
|
||||
4. **GET /personalized-recommendations** - Get personalized recommendations
|
||||
5. **GET /health** - Health check endpoint
|
||||
|
||||
Reference in New Issue
Block a user