Add backend functionality for news fetching, processing, and recommendations
- Implemented NewsFetcher class to fetch articles from RSS feeds and clean HTML content.
- Added EmbeddingGenerator for generating embeddings using the Cohere API.
- Created VectorStore for storing and retrieving articles using Pinecone.
- Developed NewsRecommender for analyzing articles and generating insights with Groq.
- Set up FastAPI application with endpoints for fetching news and providing recommendations.
- Configured logging for better traceability and debugging.
- Updated .gitignore to include environment variables and data directories.
- Added requirements.txt for project dependencies.
This commit is contained in:
+38
-4
@@ -1,9 +1,43 @@
|
|||||||
|
# Environment variables
|
||||||
|
.env
|
||||||
|
|
||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
env/
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
|
||||||
# Virtual Environment
|
# Virtual Environment
|
||||||
.venv
|
.venv
|
||||||
|
|
||||||
# Environment Variables
|
|
||||||
.env
|
|
||||||
|
|
||||||
# vscode settings
|
# IDE
|
||||||
.vscode
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
|
||||||
|
# Data directories
|
||||||
|
data/raw_news/
|
||||||
|
data/processed_news/
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
*.log
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,38 @@
|
|||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# Load environment variables from a .env file, if one can be found.
# No explicit path is passed, so python-dotenv performs its own lookup.
load_dotenv()

# API keys (each is None when the corresponding variable is not set)
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")

# Pinecone configuration
# PINECONE_ENVIRONMENT is imported by vector_store.py, so it must exist here
# even though it is optional; without it that import raises ImportError.
PINECONE_ENVIRONMENT = os.getenv("PINECONE_ENVIRONMENT")
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "news-articles")

# News sources (uncomment a feed to enable it)
RSS_FEEDS = [
    "https://feeds.feedburner.com/TechCrunch/",
    # "https://www.theverge.com/rss/index.xml",
    # "https://www.wired.com/feed/rss",
    # "https://www.technologyreview.com/feed/",
]

# Vector database settings
# Must equal the output size of the embedding model used in embeddings.py
# ("embed-english-v3.0"), which produces 1024-dimensional vectors.
# NOTE(review): an index that was already created with another dimension has
# to be recreated, because the index dimension is fixed at creation time.
VECTOR_DIMENSION = 1024
TOP_K_RESULTS = 5

# Data directories
RAW_NEWS_DIR = "data/raw_news"
PROCESSED_NEWS_DIR = "data/processed_news"

# Create directories if they don't exist
os.makedirs(RAW_NEWS_DIR, exist_ok=True)
os.makedirs(PROCESSED_NEWS_DIR, exist_ok=True)
|
||||||
|
|||||||
@@ -0,0 +1,50 @@
|
|||||||
|
import cohere
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from config import COHERE_API_KEY
|
||||||
|
|
||||||
|
class EmbeddingGenerator:
    """Wrapper around the Cohere embed endpoint for articles and search queries."""

    # Embedding model shared by documents and queries so both live in the
    # same vector space.
    MODEL = "embed-english-v3.0"

    def __init__(self):
        self.client = cohere.Client(COHERE_API_KEY)

    def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate document embeddings for a list of texts using Cohere.

        Args:
            texts: Texts to embed.

        Returns:
            One embedding per text, or an empty list when ``texts`` is empty
            or the request fails (the error is printed, not raised).
        """
        # Nothing to embed: skip the network round trip entirely.
        if not texts:
            return []

        try:
            response = self.client.embed(
                texts=texts,
                model=self.MODEL,
                input_type="search_document",
            )
            return response.embeddings
        except Exception as e:
            print(f"Error generating embeddings: {str(e)}")
            return []

    def process_articles(self, articles: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Attach an ``embedding`` entry to each article, in place.

        Args:
            articles: Article dicts that each have ``title`` and ``content``.

        Returns:
            The same list. Articles are left without an ``embedding`` key when
            embedding generation fails or returns an unexpected number of
            vectors, so callers can detect the failure.
        """
        if not articles:
            return articles

        texts = [
            f"{article['title']} {article['content']}"
            for article in articles
        ]
        embeddings = self.generate_embeddings(texts)

        # A count mismatch means vectors cannot be paired with articles
        # reliably; attaching a partial result would mislabel articles.
        if len(embeddings) != len(articles):
            print(
                f"Error generating embeddings: expected {len(articles)} "
                f"embeddings, got {len(embeddings)}"
            )
            return articles

        for article, embedding in zip(articles, embeddings):
            article["embedding"] = embedding

        return articles

    def get_query_embedding(self, query: str) -> List[float]:
        """Generate the embedding for a search query.

        Args:
            query: The query text.

        Returns:
            The query embedding, or an empty list when ``query`` is blank or
            the request fails (the error is printed, not raised).
        """
        if not query or not query.strip():
            return []

        try:
            response = self.client.embed(
                texts=[query],
                model=self.MODEL,
                input_type="search_query",
            )
            return response.embeddings[0]
        except Exception as e:
            print(f"Error generating query embedding: {str(e)}")
            return []
|
||||||
|
|||||||
+112
@@ -0,0 +1,112 @@
|
|||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from news_fetcher import NewsFetcher
|
||||||
|
from embeddings import EmbeddingGenerator
|
||||||
|
from vector_store import VectorStore
|
||||||
|
from recommender import NewsRecommender
|
||||||
|
from config import RAW_NEWS_DIR, PROCESSED_NEWS_DIR
|
||||||
|
|
||||||
|
app = FastAPI(title="DS Task AI News API")

# Add CORS middleware.
# Credentials are disabled because every origin is allowed: the FastAPI CORS
# documentation states that wildcard origins/methods/headers cannot be
# combined with allow_credentials=True. The endpoints in this module do not
# read cookies or authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize components (each constructor runs at import time)
news_fetcher = NewsFetcher()
embedding_generator = EmbeddingGenerator()
vector_store = VectorStore()
recommender = NewsRecommender()
|
||||||
|
|
||||||
|
@app.get("/")
async def root():
    """Describe the service: its name, version and a short description."""
    api_info = dict(
        name="DS Task AI News API",
        version="1.0.0",
        description="AI-powered news retrieval and recommendation system",
    )
    return api_info
|
||||||
|
|
||||||
|
@app.get("/fetch-news")
async def fetch_news():
    """Fetch news from RSS feeds and store it in the vector database.

    Returns:
        The result dict produced by ``news_fetcher.process()``.

    Raises:
        HTTPException: 404 when the pipeline reports ``status == "error"``,
            500 when the pipeline raises unexpectedly.
    """
    try:
        # NOTE(review): process() is a synchronous call made from an async
        # handler; consider offloading it if request latency matters.
        result = news_fetcher.process()
        if result["status"] == "error":
            raise HTTPException(status_code=404, detail=result["message"])

        return result
    except HTTPException:
        # HTTPException is itself an Exception; without this clause the 404
        # above would be rewritten into a 500 by the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
@app.get("/recommend-news")
async def recommend_news(article_id: str = None, query: str = None):
    """Get news recommendations based on an article ID or a search query.

    Args:
        article_id: ID of an article stored in the vector index. Takes
            precedence over ``query`` when both are supplied.
        query: Free-text search query.

    Returns:
        A dict with ``articles`` (the similar articles) and ``insights``
        (the recommender's analysis of those articles).

    Raises:
        HTTPException: 400 when neither parameter is given; 404 when the
            article or any similar article cannot be found; 502 when no
            embedding could be generated; 500 on unexpected failures.
    """
    try:
        if article_id:
            # Look the article up by its ID. A similarity query with a
            # placeholder vector would return an arbitrary article and
            # ignore article_id altogether.
            fetched = vector_store.index.fetch(ids=[article_id])
            stored = fetched.vectors.get(article_id)
            if stored is None:
                raise HTTPException(status_code=404, detail="Article not found")

            metadata = stored.metadata or {}
            # Generate query embedding from article content
            query_embedding = embedding_generator.get_query_embedding(
                f"{metadata.get('title', '')} {metadata.get('content', '')}"
            )
        elif query:
            # Generate query embedding from search query
            query_embedding = embedding_generator.get_query_embedding(query)
        else:
            raise HTTPException(
                status_code=400,
                detail="Either article_id or query parameter is required"
            )

        # get_query_embedding() returns [] on failure; searching with an
        # empty vector would only surface as a misleading "not found".
        if not query_embedding:
            raise HTTPException(
                status_code=502,
                detail="Failed to generate query embedding"
            )

        # Search for similar articles
        similar_articles = vector_store.search_similar(query_embedding)
        if not similar_articles:
            raise HTTPException(status_code=404, detail="No similar articles found")

        # Generate insights for the articles
        insights = recommender.analyze_articles(similar_articles)

        return {
            "articles": similar_articles,
            "insights": insights
        }
    except HTTPException:
        # Preserve the intended 4xx/502 status instead of turning it into 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
# Article IDs are feed entry IDs or links (i.e. URLs), so the ":path"
# converter is used to let the parameter contain "/" characters.
@app.get("/article/{article_id:path}")
async def get_article(article_id: str):
    """Get a specific article and its summary.

    Args:
        article_id: ID of the article as stored in the vector index.

    Returns:
        A dict with ``article`` (its ID plus stored metadata) and
        ``summary`` (text produced by the recommender).

    Raises:
        HTTPException: 404 when no article has that ID, 500 on unexpected
            failures.
    """
    try:
        # Fetch by ID; a similarity query with a placeholder vector would
        # return an arbitrary article regardless of article_id.
        fetched = vector_store.index.fetch(ids=[article_id])
        stored = fetched.vectors.get(article_id)
        if stored is None:
            raise HTTPException(status_code=404, detail="Article not found")

        article = {"id": article_id, **(stored.metadata or {})}

        # Generate summary
        summary = recommender.generate_summary(article)

        return {
            "article": article,
            "summary": summary
        }
    except HTTPException:
        # Keep the 404 from being rewritten into a 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 8000.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
|
||||||
|
|||||||
@@ -0,0 +1,178 @@
|
|||||||
|
import feedparser
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from config import RSS_FEEDS, RAW_NEWS_DIR, PROCESSED_NEWS_DIR
|
||||||
|
from embeddings import EmbeddingGenerator
|
||||||
|
from vector_store import VectorStore
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
# Root logging setup: INFO and above is sent to the console and to the
# news_fetcher.log file, using one shared record format.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_console_handler = logging.StreamHandler()
_file_handler = logging.FileHandler('news_fetcher.log')
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[_console_handler, _file_handler],
)

# Logger used by everything in this module.
logger = logging.getLogger('NewsFetcher')
|
||||||
|
|
||||||
|
class NewsFetcher:
    """Fetch RSS articles, embed them, save them to disk and upsert them.

    The constructor stores the configured feed URLs and creates the
    EmbeddingGenerator and VectorStore used by ``process()``.
    """

    def __init__(self):
        self.feeds = RSS_FEEDS
        self.embedding_generator = EmbeddingGenerator()
        self.vector_store = VectorStore()
        logger.info("NewsFetcher initialized with %d RSS feeds", len(self.feeds))

    def clean_html_content(self, html_content: str) -> str:
        """Strip markup from *html_content* and return whitespace-normalised text.

        Args:
            html_content: HTML fragment; ``None`` is treated as empty.

        Returns:
            The visible text with every whitespace run collapsed to a single
            space and no leading/trailing whitespace.
        """
        html_content = html_content or ""
        logger.debug("Cleaning HTML content of length %d", len(html_content))

        soup = BeautifulSoup(html_content, 'html.parser')

        # Script and style bodies are not readable text; drop them entirely.
        for element in soup(["script", "style"]):
            element.decompose()

        # Collapse every whitespace run (newlines included) to one space.
        cleaned_text = re.sub(r'\s+', ' ', soup.get_text()).strip()
        logger.debug("Cleaned text length: %d", len(cleaned_text))
        return cleaned_text

    def fetch_rss_news(self, feed_url: str) -> List[Dict[str, Any]]:
        """Fetch news articles from a single RSS feed.

        Args:
            feed_url: URL of the feed to parse.

        Returns:
            One dict per usable entry with the keys ``title``,
            ``raw_content``, ``content``, ``link``, ``published``,
            ``source``, ``categories`` and ``id``. Entries that have neither
            an id nor a link are skipped, since they cannot be keyed in the
            vector store.
        """
        logger.info("Fetching news from feed: %s", feed_url)
        feed = feedparser.parse(feed_url)

        # feedparser reports malformed or unreachable feeds through the
        # "bozo" flag rather than by raising.
        if feed.get("bozo"):
            logger.warning(
                "Feed %s was not parsed cleanly: %s",
                feed_url,
                feed.get("bozo_exception"),
            )

        source = feed.feed.get("title", "Unknown")
        articles = []

        for entry in feed.entries:
            article_id = entry.get("id", entry.get("link", ""))
            if not article_id:
                logger.warning("Skipping entry without id or link in %s", feed_url)
                continue

            # Raw content keeps the HTML; content is the cleaned plain text.
            raw_content = entry.get("summary", "")

            articles.append({
                # .get() so that one entry without a title does not abort
                # the whole feed with an AttributeError.
                "title": entry.get("title", ""),
                "raw_content": raw_content,
                "content": self.clean_html_content(raw_content),
                "link": entry.get("link", ""),
                "published": entry.get("published", datetime.now().isoformat()),
                "source": source,
                # Tags may lack a term; keep only real category strings.
                "categories": [
                    tag.get("term")
                    for tag in entry.get("tags", [])
                    if tag.get("term")
                ],
                "id": article_id,
            })

        logger.info("Fetched %d articles from %s", len(articles), feed_url)
        return articles

    def fetch_all_news(self) -> List[Dict[str, Any]]:
        """Fetch news from all configured RSS feeds.

        A feed that raises is logged and skipped so the remaining feeds are
        still processed.

        Returns:
            The articles of every feed that could be fetched, concatenated.
        """
        logger.info("Starting to fetch news from all %d feeds", len(self.feeds))
        all_articles = []

        for feed_url in self.feeds:
            try:
                articles = self.fetch_rss_news(feed_url)
            except Exception as e:
                logger.error("Error fetching from %s: %s", feed_url, str(e))
                continue

            all_articles.extend(articles)
            logger.info("Successfully fetched %d articles from %s", len(articles), feed_url)

        logger.info("Total articles fetched: %d", len(all_articles))
        return all_articles

    def save_raw_articles(self, articles: List[Dict[str, Any]]) -> str:
        """Save raw articles to a timestamped JSON file in RAW_NEWS_DIR.

        Returns:
            The path of the file that was written.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filepath = os.path.join(RAW_NEWS_DIR, f"raw_news_{timestamp}.json")

        logger.info("Saving %d raw articles to %s", len(articles), filepath)
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(articles, f, ensure_ascii=False, indent=2)

        logger.info("Raw articles saved successfully")
        return filepath

    def save_processed_articles(self, articles: List[Dict[str, Any]]) -> str:
        """Save processed articles (with embeddings) to PROCESSED_NEWS_DIR.

        The ``raw_content`` field is left out of the processed file; the
        input dicts themselves are not modified.

        Returns:
            The path of the file that was written.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filepath = os.path.join(PROCESSED_NEWS_DIR, f"processed_news_{timestamp}.json")

        processed_articles = [
            {key: value for key, value in article.items() if key != 'raw_content'}
            for article in articles
        ]

        logger.info("Saving %d processed articles to %s", len(processed_articles), filepath)
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(processed_articles, f, ensure_ascii=False, indent=2)

        logger.info("Processed articles saved successfully")
        return filepath

    def process(self) -> Dict[str, Any]:
        """Run the full pipeline: fetch, save, embed, save again, upsert.

        Returns:
            A dict with ``status`` ("success" or "error") and ``message``.
            Once articles were fetched it also carries ``raw_filepath`` and
            ``article_count``; ``processed_filepath`` is present when the
            embedding step succeeded.
        """
        logger.info("Starting news processing pipeline")

        # Fetch articles
        logger.info("Step 1: Fetching articles from RSS feeds")
        articles = self.fetch_all_news()
        if not articles:
            logger.warning("No articles found during fetching")
            return {"status": "error", "message": "No articles found"}

        # Save raw articles
        logger.info("Step 2: Saving raw articles")
        raw_filepath = self.save_raw_articles(articles)

        # Generate embeddings
        logger.info("Step 3: Generating embeddings for %d articles", len(articles))
        articles_with_embeddings = self.embedding_generator.process_articles(articles)

        # The embedding generator signals failure by leaving articles without
        # an "embedding" key, so success has to be verified here instead of
        # being assumed.
        embedded_count = sum(
            1 for article in articles_with_embeddings if article.get("embedding")
        )
        if embedded_count == 0:
            logger.error("No embeddings were generated for %d articles", len(articles))
            return {
                "status": "error",
                "message": "Failed to generate embeddings",
                "raw_filepath": raw_filepath,
                "article_count": len(articles),
            }
        logger.info("Embeddings generated for %d of %d articles", embedded_count, len(articles))

        # Save processed articles
        logger.info("Step 4: Saving processed articles with embeddings")
        processed_filepath = self.save_processed_articles(articles_with_embeddings)

        # Store in vector database
        logger.info("Step 5: Storing articles in vector database")
        success = self.vector_store.upsert_articles(articles_with_embeddings)

        if success:
            logger.info("Articles successfully stored in vector database")
        else:
            logger.error("Failed to store articles in vector database")

        result = {
            "status": "success" if success else "error",
            "message": "Articles processed and stored successfully" if success else "Failed to store articles",
            "raw_filepath": raw_filepath,
            "processed_filepath": processed_filepath,
            "article_count": len(articles)
        }

        logger.info("News processing pipeline completed with status: %s", result["status"])
        return result
|
||||||
|
|
||||||
|
# Run the pipeline only when this file is executed as a script. main.py
# imports NewsFetcher from this module, and an unguarded call here would
# trigger a full fetch/embed/upsert run on every import.
if __name__ == "__main__":
    news_fetcher = NewsFetcher()
    print(news_fetcher.process())
|
||||||
|
|||||||
@@ -0,0 +1,75 @@
|
|||||||
|
from groq import Groq
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from config import GROQ_API_KEY
|
||||||
|
|
||||||
|
class NewsRecommender:
    """Produce insights and summaries for news articles through the Groq chat API."""

    # Chat model used for both analysis and summarisation.
    MODEL = "mixtral-8x7b-32768"

    # Keys the analysis prompt asks the model to return.
    _INSIGHT_KEYS = ("themes", "insights", "implications", "related_areas")

    def __init__(self):
        self.client = Groq(api_key=GROQ_API_KEY)

    @classmethod
    def _empty_insights(cls) -> Dict[str, Any]:
        """Return the insights structure with every key set to an empty list."""
        return {key: [] for key in cls._INSIGHT_KEYS}

    @classmethod
    def _parse_insights(cls, raw: str) -> Dict[str, Any]:
        """Turn the model's reply into the insights dict.

        The reply may wrap the JSON object in prose or a code fence, so the
        span from the first "{" to the last "}" is what gets decoded. When no
        JSON object can be decoded, the empty structure is returned with the
        original text kept under ``raw_response``.
        """
        import json  # local import: this module does not import json at file level

        result = cls._empty_insights()
        text = raw or ""
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end > start:
            try:
                parsed = json.loads(text[start:end + 1])
            except ValueError:
                parsed = None
            if isinstance(parsed, dict):
                result.update(parsed)
                return result

        result["raw_response"] = text
        return result

    def analyze_articles(self, articles: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Analyze a set of articles using Groq to generate insights.

        Args:
            articles: Article dicts that each have ``title`` and ``content``.

        Returns:
            A dict with the keys ``themes``, ``insights``, ``implications``
            and ``related_areas``. All four are empty lists when there are no
            articles or the request fails; ``raw_response`` is added when the
            model's reply could not be decoded as JSON.
        """
        # No articles means nothing to analyse; avoid a pointless API call.
        if not articles:
            return self._empty_insights()

        try:
            # Prepare the prompt
            articles_text = "\n\n".join(
                f"Title: {article['title']}\nContent: {article['content']}"
                for article in articles
            )

            prompt = (
                "Analyze these news articles and provide insights:\n"
                "\n"
                f"{articles_text}\n"
                "\n"
                "Please provide:\n"
                "1. Main themes and topics\n"
                "2. Key insights and trends\n"
                "3. Potential implications\n"
                "4. Related areas of interest\n"
                "\n"
                "Format the response as a JSON with these keys: "
                "themes, insights, implications, related_areas"
            )

            # Get completion from Groq
            completion = self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": "You are a news analyst providing insights about technology and AI news."},
                    {"role": "user", "content": prompt}
                ],
                model=self.MODEL,
                temperature=0.7,
                max_tokens=1000
            )

            # Parse the reply so success and failure return the same shape.
            return self._parse_insights(completion.choices[0].message.content)
        except Exception as e:
            print(f"Error analyzing articles: {str(e)}")
            return self._empty_insights()

    def generate_summary(self, article: Dict[str, Any]) -> str:
        """Generate a summary of a single article using Groq.

        Args:
            article: Article dict with ``title`` and ``content``.

        Returns:
            The summary text, or "Unable to generate summary." when the
            request fails or the reply is empty.
        """
        try:
            prompt = (
                "Summarize this news article:\n"
                "\n"
                f"Title: {article['title']}\n"
                f"Content: {article['content']}\n"
                "\n"
                "Please provide a concise summary focusing on the key points and implications."
            )

            completion = self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": "You are a news summarizer providing concise summaries of technology and AI news."},
                    {"role": "user", "content": prompt}
                ],
                model=self.MODEL,
                temperature=0.5,
                max_tokens=500
            )

            return completion.choices[0].message.content or "Unable to generate summary."
        except Exception as e:
            print(f"Error generating summary: {str(e)}")
            return "Unable to generate summary."
|
||||||
|
|||||||
@@ -0,0 +1,11 @@
|
|||||||
|
fastapi==0.109.2
|
||||||
|
uvicorn==0.27.1
|
||||||
|
feedparser==6.0.10
|
||||||
|
cohere==4.47
|
||||||
|
pinecone-client==3.0.2
|
||||||
|
python-dotenv==1.0.1
|
||||||
|
groq==0.4.2
|
||||||
|
pydantic==2.6.3
|
||||||
|
python-multipart==0.0.9
|
||||||
|
httpx==0.27.0
|
||||||
|
beautifulsoup4==4.12.3
|
||||||
|
|||||||
@@ -0,0 +1,88 @@
|
|||||||
|
from pinecone import Pinecone, ServerlessSpec
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from config import (
|
||||||
|
PINECONE_API_KEY,
|
||||||
|
PINECONE_ENVIRONMENT,
|
||||||
|
PINECONE_INDEX_NAME,
|
||||||
|
VECTOR_DIMENSION,
|
||||||
|
TOP_K_RESULTS
|
||||||
|
)
|
||||||
|
|
||||||
|
class VectorStore:
    """Store and query article embeddings in a Pinecone index."""

    # Number of vectors sent per upsert request; keeps each request small
    # instead of sending every article in a single call.
    UPSERT_BATCH_SIZE = 100

    def __init__(self):
        self.pinecone = Pinecone(api_key=PINECONE_API_KEY)
        self.index_name = PINECONE_INDEX_NAME
        self._ensure_index()

    def _ensure_index(self):
        """Ensure the Pinecone index exists, create it if it doesn't, and open it."""
        if self.index_name not in self.pinecone.list_indexes().names():
            self.pinecone.create_index(
                name=self.index_name,
                dimension=VECTOR_DIMENSION,
                metric="cosine",
                spec=ServerlessSpec(cloud="aws", region="us-east-1")
            )
        self.index = self.pinecone.Index(self.index_name)

    def upsert_articles(self, articles: List[Dict[str, Any]]) -> bool:
        """Upsert articles to the vector store.

        Articles without a non-empty ``embedding`` are skipped.

        Args:
            articles: Article dicts with ``id``, ``embedding``, ``title``,
                ``content``, ``link``, ``published``, ``source`` and
                ``categories``.

        Returns:
            True when every batch was upserted (or there was nothing to do
            because ``articles`` is empty); False when an error occurred or
            when articles were given but none carried an embedding.
        """
        try:
            vectors = [
                {
                    "id": article["id"],
                    "values": article["embedding"],
                    "metadata": {
                        "title": article["title"],
                        "content": article["content"],
                        "link": article["link"],
                        "published": article["published"],
                        "source": article["source"],
                        # Drop empty/None entries so the list holds only strings.
                        "categories": [
                            str(category)
                            for category in (article["categories"] or [])
                            if category
                        ],
                    },
                }
                for article in articles
                if article.get("embedding")
            ]

            if articles and not vectors:
                # Reporting success here would hide an upstream embedding failure.
                print("Error upserting articles: no article has an embedding")
                return False

            for start in range(0, len(vectors), self.UPSERT_BATCH_SIZE):
                self.index.upsert(
                    vectors=vectors[start:start + self.UPSERT_BATCH_SIZE]
                )
            return True
        except Exception as e:
            print(f"Error upserting articles: {str(e)}")
            return False

    def search_similar(self, query_embedding: List[float], top_k: int = TOP_K_RESULTS) -> List[Dict[str, Any]]:
        """Search for similar articles using the query embedding.

        Args:
            query_embedding: Vector to search with.
            top_k: Maximum number of matches to return.

        Returns:
            One dict per match holding ``id``, ``score`` and the stored
            metadata fields; an empty list when the embedding is empty or
            the query fails (the error is printed, not raised).
        """
        # An empty vector cannot be queried; skip the request.
        if not query_embedding:
            return []

        try:
            results = self.index.query(
                vector=query_embedding,
                top_k=top_k,
                include_metadata=True
            )

            return [
                {
                    "id": match.id,
                    "score": match.score,
                    # metadata can be absent on a match; treat that as empty.
                    **(match.metadata or {}),
                }
                for match in results.matches
            ]
        except Exception as e:
            print(f"Error searching similar articles: {str(e)}")
            return []

    def delete_article(self, article_id: str) -> bool:
        """Delete an article from the vector store.

        Returns:
            True when the delete request succeeded, False when it raised.
        """
        try:
            self.index.delete(ids=[article_id])
            return True
        except Exception as e:
            print(f"Error deleting article: {str(e)}")
            return False
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user