feat: Implement complete RSS news fetching system with multi-source support
This commit is contained in:
@@ -0,0 +1,156 @@
|
||||
"""Embeddings generation for DS Task AI News"""
|
||||
import os
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import cohere
|
||||
from config import settings
|
||||
|
||||
class EmbeddingGenerator:
|
||||
def __init__(self):
|
||||
self.cohere_client = None
|
||||
self.sentence_model = None
|
||||
self.use_cohere = bool(settings.cohere_api_key)
|
||||
|
||||
# Initialize embedding model
|
||||
if self.use_cohere:
|
||||
try:
|
||||
self.cohere_client = cohere.Client(settings.cohere_api_key)
|
||||
print("Using Cohere for embeddings")
|
||||
except Exception as e:
|
||||
print(f"Cohere initialization failed: {e}")
|
||||
self.use_cohere = False
|
||||
|
||||
if not self.use_cohere:
|
||||
print("Using Sentence Transformers for embeddings")
|
||||
self.sentence_model = SentenceTransformer(settings.embedding_model)
|
||||
|
||||
def create_article_text(self, article: Dict[str, Any]) -> str:
|
||||
"""Combine article fields into text for embedding"""
|
||||
title = article.get('title', '')
|
||||
content = article.get('content', '')
|
||||
source = article.get('source', '')
|
||||
|
||||
# Combine with weights (title is more important)
|
||||
text = f"{title}. {content}"
|
||||
if source:
|
||||
text += f" Source: {source}"
|
||||
|
||||
return text.strip()
|
||||
|
||||
def generate_embeddings_cohere(self, texts: List[str]) -> np.ndarray:
|
||||
"""Generate embeddings using Cohere"""
|
||||
try:
|
||||
response = self.cohere_client.embed(
|
||||
texts=texts,
|
||||
model='embed-english-v3.0',
|
||||
input_type='search_document'
|
||||
)
|
||||
return np.array(response.embeddings)
|
||||
except Exception as e:
|
||||
print(f"Cohere embedding error: {e}")
|
||||
raise
|
||||
|
||||
def generate_embeddings_sentence_transformer(self, texts: List[str]) -> np.ndarray:
|
||||
"""Generate embeddings using Sentence Transformers"""
|
||||
try:
|
||||
embeddings = self.sentence_model.encode(texts, convert_to_numpy=True)
|
||||
return embeddings
|
||||
except Exception as e:
|
||||
print(f"Sentence Transformer embedding error: {e}")
|
||||
raise
|
||||
|
||||
def generate_embeddings(self, articles: List[Dict[str, Any]]) -> np.ndarray:
|
||||
"""Generate embeddings for articles"""
|
||||
if not articles:
|
||||
return np.array([])
|
||||
|
||||
# Create texts for embedding
|
||||
texts = [self.create_article_text(article) for article in articles]
|
||||
|
||||
print(f"Generating embeddings for {len(texts)} articles...")
|
||||
|
||||
# Generate embeddings
|
||||
if self.use_cohere:
|
||||
embeddings = self.generate_embeddings_cohere(texts)
|
||||
else:
|
||||
embeddings = self.generate_embeddings_sentence_transformer(texts)
|
||||
|
||||
print(f"Generated embeddings shape: {embeddings.shape}")
|
||||
return embeddings
|
||||
|
||||
def generate_query_embedding(self, query: str) -> np.ndarray:
|
||||
"""Generate embedding for a search query"""
|
||||
if self.use_cohere:
|
||||
try:
|
||||
response = self.cohere_client.embed(
|
||||
texts=[query],
|
||||
model='embed-english-v3.0',
|
||||
input_type='search_query'
|
||||
)
|
||||
return np.array(response.embeddings[0])
|
||||
except Exception as e:
|
||||
print(f"Cohere query embedding error: {e}")
|
||||
# Fallback to sentence transformer
|
||||
return self.sentence_model.encode([query], convert_to_numpy=True)[0]
|
||||
else:
|
||||
return self.sentence_model.encode([query], convert_to_numpy=True)[0]
|
||||
|
||||
def compute_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
|
||||
"""Compute cosine similarity between two embeddings"""
|
||||
# Normalize embeddings
|
||||
norm1 = np.linalg.norm(embedding1)
|
||||
norm2 = np.linalg.norm(embedding2)
|
||||
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0.0
|
||||
|
||||
# Cosine similarity
|
||||
similarity = np.dot(embedding1, embedding2) / (norm1 * norm2)
|
||||
return float(similarity)
|
||||
|
||||
def find_similar_articles(self, query_embedding: np.ndarray,
|
||||
article_embeddings: np.ndarray,
|
||||
articles: List[Dict[str, Any]],
|
||||
top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
"""Find most similar articles to query"""
|
||||
if len(article_embeddings) == 0:
|
||||
return []
|
||||
|
||||
similarities = []
|
||||
for i, article_embedding in enumerate(article_embeddings):
|
||||
similarity = self.compute_similarity(query_embedding, article_embedding)
|
||||
similarities.append((similarity, i))
|
||||
|
||||
# Sort by similarity (descending)
|
||||
similarities.sort(reverse=True)
|
||||
|
||||
# Get top-k results
|
||||
results = []
|
||||
for similarity, idx in similarities[:top_k]:
|
||||
if similarity >= settings.similarity_threshold:
|
||||
article = articles[idx].copy()
|
||||
article['similarity_score'] = similarity
|
||||
results.append(article)
|
||||
|
||||
return results
|
||||
|
||||
# Test function
|
||||
if __name__ == "__main__":
|
||||
# Test with sample articles
|
||||
sample_articles = [
|
||||
{
|
||||
"title": "AI Revolution in Healthcare",
|
||||
"content": "Artificial intelligence is transforming medical diagnosis and treatment.",
|
||||
"source": "TechNews"
|
||||
},
|
||||
{
|
||||
"title": "Climate Change Solutions",
|
||||
"content": "New technologies are being developed to combat global warming.",
|
||||
"source": "ScienceDaily"
|
||||
}
|
||||
]
|
||||
|
||||
generator = EmbeddingGenerator()
|
||||
embeddings = generator.generate_embeddings(sample_articles)
|
||||
print(f"Test embeddings shape: {embeddings.shape}")
|
||||
Reference in New Issue
Block a user