"""Embeddings generation for DS Task AI News""" import os import numpy as np from typing import List, Dict, Any, Optional from sentence_transformers import SentenceTransformer import cohere from config import settings class EmbeddingGenerator: def __init__(self): self.cohere_client = None self.sentence_model = None self.use_cohere = bool(settings.cohere_api_key) # Initialize embedding model if self.use_cohere: try: self.cohere_client = cohere.Client(settings.cohere_api_key) print("Using Cohere for embeddings") except Exception as e: print(f"Cohere initialization failed: {e}") self.use_cohere = False if not self.use_cohere: print("Using Sentence Transformers for embeddings") self.sentence_model = SentenceTransformer(settings.embedding_model) def create_article_text(self, article: Dict[str, Any]) -> str: """Combine article fields into text for embedding""" title = article.get('title', '') content = article.get('content', '') source = article.get('source', '') # Combine with weights (title is more important) text = f"{title}. {content}" if source: text += f" Source: {source}" return text.strip() def generate_embeddings_cohere(self, texts: List[str]) -> np.ndarray: """Generate embeddings using Cohere""" try: response = self.cohere_client.embed( texts=texts, model='embed-english-v3.0', input_type='search_document' ) return np.array(response.embeddings) except Exception as e: print(f"Cohere embedding error: {e}") raise def generate_embeddings_sentence_transformer(self, texts: List[str]) -> np.ndarray: """Generate embeddings using Sentence Transformers""" try: embeddings = self.sentence_model.encode(texts, convert_to_numpy=True) return embeddings except Exception as e: print(f"Sentence Transformer embedding error: {e}") raise def generate_embeddings(self, articles: List[Dict[str, Any]]) -> np.ndarray: """Generate embeddings for articles""" if not articles: return np.array([]) # Create texts for embedding texts = [self.create_article_text(article) for article in articles] print(f"Generating embeddings for {len(texts)} articles...") # Generate embeddings if self.use_cohere: embeddings = self.generate_embeddings_cohere(texts) else: embeddings = self.generate_embeddings_sentence_transformer(texts) print(f"Generated embeddings shape: {embeddings.shape}") return embeddings def generate_query_embedding(self, query: str) -> np.ndarray: """Generate embedding for a search query""" if self.use_cohere: try: response = self.cohere_client.embed( texts=[query], model='embed-english-v3.0', input_type='search_query' ) return np.array(response.embeddings[0]) except Exception as e: print(f"Cohere query embedding error: {e}") # Fallback to sentence transformer return self.sentence_model.encode([query], convert_to_numpy=True)[0] else: return self.sentence_model.encode([query], convert_to_numpy=True)[0] def compute_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float: """Compute cosine similarity between two embeddings""" # Normalize embeddings norm1 = np.linalg.norm(embedding1) norm2 = np.linalg.norm(embedding2) if norm1 == 0 or norm2 == 0: return 0.0 # Cosine similarity similarity = np.dot(embedding1, embedding2) / (norm1 * norm2) return float(similarity) def find_similar_articles(self, query_embedding: np.ndarray, article_embeddings: np.ndarray, articles: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]: """Find most similar articles to query""" if len(article_embeddings) == 0: return [] similarities = [] for i, article_embedding in enumerate(article_embeddings): similarity = self.compute_similarity(query_embedding, article_embedding) similarities.append((similarity, i)) # Sort by similarity (descending) similarities.sort(reverse=True) # Get top-k results results = [] for similarity, idx in similarities[:top_k]: if similarity >= settings.similarity_threshold: article = articles[idx].copy() article['similarity_score'] = similarity results.append(article) return results # Test function if __name__ == "__main__": # Test with sample articles sample_articles = [ { "title": "AI Revolution in Healthcare", "content": "Artificial intelligence is transforming medical diagnosis and treatment.", "source": "TechNews" }, { "title": "Climate Change Solutions", "content": "New technologies are being developed to combat global warming.", "source": "ScienceDaily" } ] generator = EmbeddingGenerator() embeddings = generator.generate_embeddings(sample_articles) print(f"Test embeddings shape: {embeddings.shape}")