Refactor backend configuration and enhance news fetching functionality
- Introduced a Config dataclass in config.py to manage API keys, RSS feeds, and directory paths more effectively. - Updated the NewsFetcher class to include retry logic for fetching articles from RSS feeds. - Modified the EmbeddingGenerator and NewsRecommender classes to utilize the new configuration structure. - Enhanced main.py to implement API token verification for secure access to news fetching and recommendations.
This commit is contained in:
@@ -55,8 +55,7 @@ DS Task AI News is a web application that uses AI technologies to fetch, analyze
|
|||||||
```
|
```
|
||||||
|
|
||||||
4. Run the application:
|
4. Run the application:
|
||||||
```
|
``` python backend/main.py
|
||||||
python backend/main.py
|
|
||||||
```
|
```
|
||||||
|
|
||||||
5. Open your web browser and navigate to `http://localhost:8000`.
|
5. Open your web browser and navigate to `http://localhost:8000`.
|
||||||
@@ -104,3 +103,4 @@ This project is licensed under the MIT License - see the LICENSE file for detail
|
|||||||
- [Cohere](https://cohere.ai/)
|
- [Cohere](https://cohere.ai/)
|
||||||
- [Pinecone](https://www.pinecone.io/)
|
- [Pinecone](https://www.pinecone.io/)
|
||||||
- [Groq](https://groq.com/)
|
- [Groq](https://groq.com/)
|
||||||
|
|
||||||
|
|||||||
+56
-22
@@ -1,5 +1,10 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Optional
|
||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from fastapi import HTTPException, Depends, Security
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
|
from starlette.status import HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
# Load environment variables
|
# Load environment variables
|
||||||
|
|
||||||
@@ -9,30 +14,59 @@ from dotenv import load_dotenv
|
|||||||
# Load environment variables from the specified path
|
# Load environment variables from the specified path
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# API Keys
|
@dataclass
|
||||||
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
|
class Config:
|
||||||
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
# API Keys
|
||||||
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
|
cohere_api_key: str = os.getenv("COHERE_API_KEY", "")
|
||||||
|
groq_api_key: str = os.getenv("GROQ_API_KEY", "")
|
||||||
|
pinecone_api_key: str = os.getenv("PINECONE_API_KEY", "")
|
||||||
|
api_token: str = os.getenv("API_TOKEN", "default_secret_token") # Default token for development
|
||||||
|
|
||||||
# Pinecone Configuration
|
# Pinecone Configuration
|
||||||
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "news-articles")
|
pinecone_index_name: str = os.getenv("PINECONE_INDEX_NAME", "news-articles")
|
||||||
|
vector_dimension: int = 1024 # Cohere embedding dimension
|
||||||
|
top_k_results: int = 5
|
||||||
|
|
||||||
# News Sources
|
# News Sources
|
||||||
RSS_FEEDS = [
|
rss_feeds: List[str] = field(default_factory=lambda: [
|
||||||
# "https://feeds.feedburner.com/TechCrunch/",
|
# "https://feeds.feedburner.com/TechCrunch/",
|
||||||
# "https://www.theverge.com/rss/index.xml",
|
# "https://www.theverge.com/rss/index.xml",
|
||||||
"https://www.wired.com/feed/rss",
|
"https://www.wired.com/feed/rss",
|
||||||
"https://www.technologyreview.com/feed/",
|
"https://www.technologyreview.com/feed/",
|
||||||
]
|
])
|
||||||
|
|
||||||
# Vector Database Settings
|
# Data Directories
|
||||||
VECTOR_DIMENSION = 1024 # Cohere embedding dimension
|
raw_news_dir: str = "data/raw_news"
|
||||||
TOP_K_RESULTS = 5
|
processed_news_dir: str = "data/processed_news"
|
||||||
|
|
||||||
# Data Directories
|
def __post_init__(self):
|
||||||
RAW_NEWS_DIR = "data/raw_news"
|
# Create directories if they don't exist
|
||||||
PROCESSED_NEWS_DIR = "data/processed_news"
|
os.makedirs(self.raw_news_dir, exist_ok=True)
|
||||||
|
os.makedirs(self.processed_news_dir, exist_ok=True)
|
||||||
|
|
||||||
# Create directories if they don't exist
|
# Create a global config instance
|
||||||
os.makedirs(RAW_NEWS_DIR, exist_ok=True)
|
config = Config()
|
||||||
os.makedirs(PROCESSED_NEWS_DIR, exist_ok=True)
|
|
||||||
|
# API Key header
|
||||||
|
api_key_header = APIKeyHeader(name="X-API-Token", auto_error=False)
|
||||||
|
|
||||||
|
def verify_api_token(api_key: str):
|
||||||
|
"""
|
||||||
|
Verify the API token from the request header.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: The API key from the request header
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The API key if valid
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If the API key is invalid
|
||||||
|
"""
|
||||||
|
if api_key == config.api_token:
|
||||||
|
print(f"API key verified: {api_key}")
|
||||||
|
return api_key
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTP_403_FORBIDDEN,
|
||||||
|
detail="Invalid API token"
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import cohere
|
import cohere
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any, Optional
|
||||||
from config import COHERE_API_KEY
|
from config import config
|
||||||
|
|
||||||
class EmbeddingGenerator:
|
class EmbeddingGenerator:
|
||||||
def __init__(self):
|
def __init__(self, cohere_client: Optional[cohere.Client] = None):
|
||||||
self.client = cohere.Client(COHERE_API_KEY)
|
self.client = cohere_client or cohere.Client(config.cohere_api_key)
|
||||||
|
|
||||||
def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
|
def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
|
||||||
"""Generate embeddings for a list of texts using Cohere."""
|
"""Generate embeddings for a list of texts using Cohere."""
|
||||||
|
|||||||
+15
-7
@@ -1,4 +1,4 @@
|
|||||||
from fastapi import FastAPI, HTTPException, Request
|
from fastapi import FastAPI, HTTPException, Request, Depends
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
@@ -10,13 +10,20 @@ from news_fetcher import NewsFetcher
|
|||||||
from embeddings import EmbeddingGenerator
|
from embeddings import EmbeddingGenerator
|
||||||
from vector_store import VectorStore
|
from vector_store import VectorStore
|
||||||
from recommender import NewsRecommender
|
from recommender import NewsRecommender
|
||||||
from config import RAW_NEWS_DIR, PROCESSED_NEWS_DIR
|
from config import config
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
app = FastAPI(title="DS Task AI News API")
|
app = FastAPI(title="DS Task AI News API")
|
||||||
|
|
||||||
# Configure templates
|
# Configure templates
|
||||||
templates = Jinja2Templates(directory="backend/templates")
|
templates = Jinja2Templates(directory="backend/templates")
|
||||||
|
|
||||||
|
def verify_api_token(token: str):
|
||||||
|
if token == config.api_token:
|
||||||
|
print(f"API key verified: {token}")
|
||||||
|
return token
|
||||||
|
return None
|
||||||
|
|
||||||
# Add custom filters
|
# Add custom filters
|
||||||
def from_json(value):
|
def from_json(value):
|
||||||
"""Parse a JSON string into a Python object."""
|
"""Parse a JSON string into a Python object."""
|
||||||
@@ -51,7 +58,8 @@ async def root(request: Request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@app.get("/fetch-news", response_class=HTMLResponse)
|
@app.get("/fetch-news", response_class=HTMLResponse)
|
||||||
async def fetch_news(request: Request):
|
def fetch_news(request: Request, token: str = Depends(verify_api_token)):
|
||||||
|
# print(f"Fetching news with token: {token}")
|
||||||
"""Fetch news from RSS feeds and store in vector database."""
|
"""Fetch news from RSS feeds and store in vector database."""
|
||||||
try:
|
try:
|
||||||
result = news_fetcher.process()
|
result = news_fetcher.process()
|
||||||
@@ -59,11 +67,11 @@ async def fetch_news(request: Request):
|
|||||||
raise HTTPException(status_code=404, detail=result["message"])
|
raise HTTPException(status_code=404, detail=result["message"])
|
||||||
|
|
||||||
# Get the latest processed articles
|
# Get the latest processed articles
|
||||||
processed_files = sorted(os.listdir(PROCESSED_NEWS_DIR), reverse=True)
|
processed_files = sorted(os.listdir(config.processed_news_dir), reverse=True)
|
||||||
if not processed_files:
|
if not processed_files:
|
||||||
raise HTTPException(status_code=404, detail="No processed articles found")
|
raise HTTPException(status_code=404, detail="No processed articles found")
|
||||||
|
|
||||||
latest_file = os.path.join(PROCESSED_NEWS_DIR, processed_files[0])
|
latest_file = os.path.join(config.processed_news_dir, processed_files[0])
|
||||||
with open(latest_file, 'r', encoding='utf-8') as f:
|
with open(latest_file, 'r', encoding='utf-8') as f:
|
||||||
articles = json.load(f)
|
articles = json.load(f)
|
||||||
|
|
||||||
@@ -81,7 +89,7 @@ async def fetch_news(request: Request):
|
|||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@app.get("/recommend-news", response_class=HTMLResponse)
|
@app.get("/recommend-news", response_class=HTMLResponse)
|
||||||
async def recommend_news(request: Request, article_id: str = None, query: str = None):
|
async def recommend_news(request: Request, article_id: str = None, query: str = None, token: str = Depends(verify_api_token)):
|
||||||
"""Get news recommendations based on article ID or search query."""
|
"""Get news recommendations based on article ID or search query."""
|
||||||
try:
|
try:
|
||||||
if article_id:
|
if article_id:
|
||||||
@@ -129,7 +137,7 @@ async def recommend_news(request: Request, article_id: str = None, query: str =
|
|||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@app.get("/article/{article_id}")
|
@app.get("/article/{article_id}")
|
||||||
async def get_article(article_id: str):
|
async def get_article(article_id: str, token: str = Depends(verify_api_token)):
|
||||||
"""Get a specific article and its summary."""
|
"""Get a specific article and its summary."""
|
||||||
try:
|
try:
|
||||||
# Search for the article
|
# Search for the article
|
||||||
|
|||||||
+62
-36
@@ -3,12 +3,13 @@ import json
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any, Optional
|
||||||
from config import RSS_FEEDS, RAW_NEWS_DIR, PROCESSED_NEWS_DIR
|
|
||||||
from embeddings import EmbeddingGenerator
|
|
||||||
from vector_store import VectorStore
|
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
|
from config import config
|
||||||
|
from embeddings import EmbeddingGenerator
|
||||||
|
from vector_store import VectorStore
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -22,10 +23,18 @@ logging.basicConfig(
|
|||||||
logger = logging.getLogger('NewsFetcher')
|
logger = logging.getLogger('NewsFetcher')
|
||||||
|
|
||||||
class NewsFetcher:
|
class NewsFetcher:
|
||||||
def __init__(self):
|
def __init__(
|
||||||
self.feeds = RSS_FEEDS
|
self,
|
||||||
self.embedding_generator = EmbeddingGenerator()
|
embedding_generator: Optional[EmbeddingGenerator] = None,
|
||||||
self.vector_store = VectorStore()
|
vector_store: Optional[VectorStore] = None,
|
||||||
|
max_retries: int = 3,
|
||||||
|
retry_delay: int = 5
|
||||||
|
):
|
||||||
|
self.feeds = config.rss_feeds
|
||||||
|
self.embedding_generator = embedding_generator or EmbeddingGenerator()
|
||||||
|
self.vector_store = vector_store or VectorStore()
|
||||||
|
self.max_retries = max_retries
|
||||||
|
self.retry_delay = retry_delay
|
||||||
logger.info("NewsFetcher initialized with %d RSS feeds", len(self.feeds))
|
logger.info("NewsFetcher initialized with %d RSS feeds", len(self.feeds))
|
||||||
|
|
||||||
def clean_html_content(self, html_content: str) -> str:
|
def clean_html_content(self, html_content: str) -> str:
|
||||||
@@ -54,32 +63,52 @@ class NewsFetcher:
|
|||||||
return cleaned_text
|
return cleaned_text
|
||||||
|
|
||||||
def fetch_rss_news(self, feed_url: str) -> List[Dict[str, Any]]:
|
def fetch_rss_news(self, feed_url: str) -> List[Dict[str, Any]]:
|
||||||
"""Fetch news articles from a single RSS feed."""
|
"""Fetch news articles from a single RSS feed with retry logic."""
|
||||||
logger.info("Fetching news from feed: %s", feed_url)
|
logger.info("Fetching news from feed: %s", feed_url)
|
||||||
feed = feedparser.parse(feed_url)
|
|
||||||
articles = []
|
articles = []
|
||||||
|
|
||||||
for entry in feed.entries:
|
for attempt in range(self.max_retries):
|
||||||
# Get raw content with HTML
|
try:
|
||||||
raw_content = entry.get("summary", "")
|
feed = feedparser.parse(feed_url)
|
||||||
|
if not feed.entries:
|
||||||
|
logger.warning("No entries found in feed %s (attempt %d/%d)",
|
||||||
|
feed_url, attempt + 1, self.max_retries)
|
||||||
|
if attempt < self.max_retries - 1:
|
||||||
|
time.sleep(self.retry_delay)
|
||||||
|
continue
|
||||||
|
return []
|
||||||
|
|
||||||
# Clean HTML content
|
for entry in feed.entries:
|
||||||
clean_content = self.clean_html_content(raw_content)
|
# Get raw content with HTML
|
||||||
|
raw_content = entry.get("summary", "")
|
||||||
|
|
||||||
article = {
|
# Clean HTML content
|
||||||
"title": entry.title,
|
clean_content = self.clean_html_content(raw_content)
|
||||||
"raw_content": raw_content, # Store original HTML content
|
|
||||||
"content": clean_content, # Store cleaned text content
|
|
||||||
"link": entry.get("link", ""),
|
|
||||||
"published": entry.get("published", datetime.now().isoformat()),
|
|
||||||
"source": feed.feed.get("title", "Unknown"),
|
|
||||||
"categories": [tag.term for tag in entry.get("tags", [])],
|
|
||||||
"id": entry.get("id", entry.get("link", "")),
|
|
||||||
}
|
|
||||||
articles.append(article)
|
|
||||||
|
|
||||||
logger.info("Fetched %d articles from %s", len(articles), feed_url)
|
article = {
|
||||||
return articles
|
"title": entry.title,
|
||||||
|
"raw_content": raw_content, # Store original HTML content
|
||||||
|
"content": clean_content, # Store cleaned text content
|
||||||
|
"link": entry.get("link", ""),
|
||||||
|
"published": entry.get("published", datetime.now().isoformat()),
|
||||||
|
"source": feed.feed.get("title", "Unknown"),
|
||||||
|
"categories": [tag.term for tag in entry.get("tags", [])],
|
||||||
|
"id": entry.get("id", entry.get("link", "")),
|
||||||
|
}
|
||||||
|
articles.append(article)
|
||||||
|
|
||||||
|
logger.info("Fetched %d articles from %s", len(articles), feed_url)
|
||||||
|
return articles
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error fetching from %s (attempt %d/%d): %s",
|
||||||
|
feed_url, attempt + 1, self.max_retries, str(e))
|
||||||
|
if attempt < self.max_retries - 1:
|
||||||
|
time.sleep(self.retry_delay)
|
||||||
|
else:
|
||||||
|
logger.error("Failed to fetch from %s after %d attempts",
|
||||||
|
feed_url, self.max_retries)
|
||||||
|
return []
|
||||||
|
|
||||||
def fetch_all_news(self) -> List[Dict[str, Any]]:
|
def fetch_all_news(self) -> List[Dict[str, Any]]:
|
||||||
"""Fetch news from all configured RSS feeds."""
|
"""Fetch news from all configured RSS feeds."""
|
||||||
@@ -87,12 +116,9 @@ class NewsFetcher:
|
|||||||
all_articles = []
|
all_articles = []
|
||||||
|
|
||||||
for feed_url in self.feeds:
|
for feed_url in self.feeds:
|
||||||
try:
|
articles = self.fetch_rss_news(feed_url)
|
||||||
articles = self.fetch_rss_news(feed_url)
|
all_articles.extend(articles)
|
||||||
all_articles.extend(articles)
|
logger.info("Successfully fetched %d articles from %s", len(articles), feed_url)
|
||||||
logger.info("Successfully fetched %d articles from %s", len(articles), feed_url)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Error fetching from %s: %s", feed_url, str(e))
|
|
||||||
|
|
||||||
logger.info("Total articles fetched: %d", len(all_articles))
|
logger.info("Total articles fetched: %d", len(all_articles))
|
||||||
return all_articles
|
return all_articles
|
||||||
@@ -101,7 +127,7 @@ class NewsFetcher:
|
|||||||
"""Save raw articles to a JSON file."""
|
"""Save raw articles to a JSON file."""
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
filename = f"raw_news_{timestamp}.json"
|
filename = f"raw_news_{timestamp}.json"
|
||||||
filepath = os.path.join(RAW_NEWS_DIR, filename)
|
filepath = os.path.join(config.raw_news_dir, filename)
|
||||||
|
|
||||||
logger.info("Saving %d raw articles to %s", len(articles), filepath)
|
logger.info("Saving %d raw articles to %s", len(articles), filepath)
|
||||||
with open(filepath, "w", encoding="utf-8") as f:
|
with open(filepath, "w", encoding="utf-8") as f:
|
||||||
@@ -114,7 +140,7 @@ class NewsFetcher:
|
|||||||
"""Save processed articles with embeddings to a JSON file."""
|
"""Save processed articles with embeddings to a JSON file."""
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
filename = f"processed_news_{timestamp}.json"
|
filename = f"processed_news_{timestamp}.json"
|
||||||
filepath = os.path.join(PROCESSED_NEWS_DIR, filename)
|
filepath = os.path.join(config.processed_news_dir, filename)
|
||||||
|
|
||||||
# Create a copy of articles without raw_content for processed storage
|
# Create a copy of articles without raw_content for processed storage
|
||||||
processed_articles = []
|
processed_articles = []
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
from groq import Groq
|
from groq import Groq
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any, Optional
|
||||||
from config import GROQ_API_KEY
|
from config import config
|
||||||
import json
|
import json
|
||||||
|
|
||||||
class NewsRecommender:
|
class NewsRecommender:
|
||||||
def __init__(self):
|
def __init__(self, groq_client: Optional[Groq] = None):
|
||||||
self.client = Groq(api_key=GROQ_API_KEY)
|
self.client = groq_client or Groq(api_key=config.groq_api_key)
|
||||||
|
|
||||||
def analyze_articles(self, articles: List[Dict[str, Any]]) -> Dict[str, Any]:
|
def analyze_articles(self, articles: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||||
"""Analyze a set of articles using Groq to generate insights."""
|
"""Analyze a set of articles using Groq to generate insights."""
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
<div class="p-6">
|
<div class="p-6">
|
||||||
<h2 class="text-2xl font-semibold text-gray-800 mb-4">Latest News</h2>
|
<h2 class="text-2xl font-semibold text-gray-800 mb-4">Latest News</h2>
|
||||||
<p class="text-gray-600 mb-6">View the latest news articles fetched from our RSS feeds.</p>
|
<p class="text-gray-600 mb-6">View the latest news articles fetched from our RSS feeds.</p>
|
||||||
<a href="/fetch-news" class="inline-block bg-blue-600 text-white px-6 py-3 rounded-md font-medium hover:bg-blue-700 transition-colors duration-300">
|
<a href="/fetch-news?token=default_secret_token" class="inline-block bg-blue-600 text-white px-6 py-3 rounded-md font-medium hover:bg-blue-700 transition-colors duration-300">
|
||||||
View Latest News
|
View Latest News
|
||||||
</a>
|
</a>
|
||||||
</div>
|
</div>
|
||||||
@@ -27,7 +27,7 @@
|
|||||||
<h2 class="text-2xl font-semibold text-gray-800 mb-4">News Recommendations</h2>
|
<h2 class="text-2xl font-semibold text-gray-800 mb-4">News Recommendations</h2>
|
||||||
<p class="text-gray-600 mb-6">Get personalized news recommendations based on your interests.</p>
|
<p class="text-gray-600 mb-6">Get personalized news recommendations based on your interests.</p>
|
||||||
<div class="space-y-4">
|
<div class="space-y-4">
|
||||||
<a href="/recommend-news?query=technology" class="block bg-blue-600 text-white px-6 py-3 rounded-md font-medium hover:bg-blue-700 transition-colors duration-300 text-center">
|
<a href="/recommend-news?query=technology?token=default_secret_token" class="block bg-blue-600 text-white px-6 py-3 rounded-md font-medium hover:bg-blue-700 transition-colors duration-300 text-center">
|
||||||
Technology News
|
Technology News
|
||||||
</a>
|
</a>
|
||||||
<a href="/recommend-news?query=artificial intelligence" class="block bg-blue-600 text-white px-6 py-3 rounded-md font-medium hover:bg-blue-700 transition-colors duration-300 text-center">
|
<a href="/recommend-news?query=artificial intelligence" class="block bg-blue-600 text-white px-6 py-3 rounded-md font-medium hover:bg-blue-700 transition-colors duration-300 text-center">
|
||||||
|
|||||||
+9
-14
@@ -1,16 +1,11 @@
|
|||||||
from pinecone import Pinecone, ServerlessSpec
|
from pinecone import Pinecone, ServerlessSpec
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any, Optional
|
||||||
from config import (
|
from config import config
|
||||||
PINECONE_API_KEY,
|
|
||||||
PINECONE_INDEX_NAME,
|
|
||||||
VECTOR_DIMENSION,
|
|
||||||
TOP_K_RESULTS
|
|
||||||
)
|
|
||||||
|
|
||||||
class VectorStore:
|
class VectorStore:
|
||||||
def __init__(self):
|
def __init__(self, pinecone_client: Optional[Pinecone] = None):
|
||||||
self.pinecone = Pinecone(api_key=PINECONE_API_KEY)
|
self.pinecone = pinecone_client or Pinecone(api_key=config.pinecone_api_key)
|
||||||
self.index_name = PINECONE_INDEX_NAME
|
self.index_name = config.pinecone_index_name
|
||||||
self._ensure_index()
|
self._ensure_index()
|
||||||
|
|
||||||
def _ensure_index(self):
|
def _ensure_index(self):
|
||||||
@@ -20,11 +15,11 @@ class VectorStore:
|
|||||||
# Create a new index with the correct dimension
|
# Create a new index with the correct dimension
|
||||||
self.pinecone.create_index(
|
self.pinecone.create_index(
|
||||||
name=self.index_name,
|
name=self.index_name,
|
||||||
dimension=VECTOR_DIMENSION,
|
dimension=config.vector_dimension,
|
||||||
metric="cosine",
|
metric="cosine",
|
||||||
spec=ServerlessSpec(cloud="aws", region="us-east-1")
|
spec=ServerlessSpec(cloud="aws", region="us-east-1")
|
||||||
)
|
)
|
||||||
print(f"Created new index '{self.index_name}' with dimension {VECTOR_DIMENSION}")
|
print(f"Created new index '{self.index_name}' with dimension {config.vector_dimension}")
|
||||||
|
|
||||||
self.index = self.pinecone.Index(self.index_name)
|
self.index = self.pinecone.Index(self.index_name)
|
||||||
|
|
||||||
@@ -57,12 +52,12 @@ class VectorStore:
|
|||||||
print(f"Error upserting articles: {str(e)}")
|
print(f"Error upserting articles: {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def search_similar(self, query_embedding: List[float], top_k: int = TOP_K_RESULTS) -> List[Dict[str, Any]]:
|
def search_similar(self, query_embedding: List[float], top_k: int = None) -> List[Dict[str, Any]]:
|
||||||
"""Search for similar articles using the query embedding."""
|
"""Search for similar articles using the query embedding."""
|
||||||
try:
|
try:
|
||||||
results = self.index.query(
|
results = self.index.query(
|
||||||
vector=query_embedding,
|
vector=query_embedding,
|
||||||
top_k=top_k,
|
top_k=top_k or config.top_k_results,
|
||||||
include_metadata=True
|
include_metadata=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Reference in New Issue
Block a user