Files
DS_TASK_AI_VIEWS/backend/embeddings.py
T

157 lines
5.8 KiB
Python

"""Embeddings generation for DS Task AI News"""
import os
import numpy as np
from typing import List, Dict, Any, Optional
from sentence_transformers import SentenceTransformer
import cohere
from config import settings
class EmbeddingGenerator:
    """Generate and compare text embeddings for news articles.

    Uses the Cohere embed API when ``settings.cohere_api_key`` is set and the
    client can be constructed; otherwise uses a local Sentence Transformers
    model named by ``settings.embedding_model``.
    """

    # Cohere model name shared by the document and query embedding calls.
    _COHERE_MODEL = 'embed-english-v3.0'

    def __init__(self):
        """Select and initialise the embedding backend from ``settings``."""
        self.cohere_client = None
        self.sentence_model = None
        self.use_cohere = bool(settings.cohere_api_key)

        if self.use_cohere:
            try:
                self.cohere_client = cohere.Client(settings.cohere_api_key)
                print("Using Cohere for embeddings")
            except Exception as e:
                # Best effort: a failed Cohere setup downgrades to the local
                # model instead of aborting construction.
                print(f"Cohere initialization failed: {e}")
                self.use_cohere = False

        if not self.use_cohere:
            print("Using Sentence Transformers for embeddings")
            self._get_sentence_model()

    def _get_sentence_model(self):
        """Return the Sentence Transformers model, loading it on first use."""
        if self.sentence_model is None:
            self.sentence_model = SentenceTransformer(settings.embedding_model)
        return self.sentence_model

    def create_article_text(self, article: Dict[str, Any]) -> str:
        """Combine an article's title, content and source into one string.

        Missing fields and fields explicitly set to ``None`` are treated as
        empty, so the literal text "None" is never embedded.

        Args:
            article: Mapping that may contain 'title', 'content' and 'source'.

        Returns:
            "<title>. <content>" with " Source: <source>" appended when a
            source is present, stripped of surrounding whitespace.
        """
        title = article.get('title') or ''
        content = article.get('content') or ''
        source = article.get('source') or ''

        # Title comes first; no explicit weighting is applied.
        text = f"{title}. {content}"
        if source:
            text += f" Source: {source}"
        return text.strip()

    def generate_embeddings_cohere(self, texts: List[str]) -> np.ndarray:
        """Embed ``texts`` as search documents using Cohere.

        Returns:
            Array built from ``response.embeddings``.

        Raises:
            Exception: Any error from the Cohere call is printed and re-raised.
        """
        try:
            response = self.cohere_client.embed(
                texts=texts,
                model=self._COHERE_MODEL,
                input_type='search_document',
            )
            return np.array(response.embeddings)
        except Exception as e:
            print(f"Cohere embedding error: {e}")
            raise

    def generate_embeddings_sentence_transformer(self, texts: List[str]) -> np.ndarray:
        """Embed ``texts`` using the local Sentence Transformers model.

        The model is loaded on demand, so this also works on an instance that
        was configured for Cohere.

        Raises:
            Exception: Any error from loading or encoding is printed and
                re-raised.
        """
        try:
            return self._get_sentence_model().encode(texts, convert_to_numpy=True)
        except Exception as e:
            print(f"Sentence Transformer embedding error: {e}")
            raise

    def generate_embeddings(self, articles: List[Dict[str, Any]]) -> np.ndarray:
        """Generate embeddings for a list of articles.

        Args:
            articles: Article dicts as accepted by ``create_article_text``.

        Returns:
            One embedding per article from the active backend, or an empty
            array when ``articles`` is empty.
        """
        if not articles:
            return np.array([])

        texts = [self.create_article_text(article) for article in articles]
        print(f"Generating embeddings for {len(texts)} articles...")

        if self.use_cohere:
            embeddings = self.generate_embeddings_cohere(texts)
        else:
            embeddings = self.generate_embeddings_sentence_transformer(texts)

        print(f"Generated embeddings shape: {embeddings.shape}")
        return embeddings

    def generate_query_embedding(self, query: str) -> np.ndarray:
        """Generate the embedding for a search query.

        With Cohere enabled the query is embedded with
        ``input_type='search_query'``; if that call fails the error is printed
        and the local Sentence Transformers model is used instead.

        Returns:
            A single embedding vector for ``query``.
        """
        if self.use_cohere:
            try:
                response = self.cohere_client.embed(
                    texts=[query],
                    model=self._COHERE_MODEL,
                    input_type='search_query',
                )
                return np.array(response.embeddings[0])
            except Exception as e:
                print(f"Cohere query embedding error: {e}")
                # Fall through to the local model. It is loaded lazily here
                # because __init__ does not create it when Cohere is active.
                # NOTE(review): the fallback vector comes from a different
                # model than Cohere-built article embeddings and may differ
                # in dimension - confirm callers tolerate this.
        return self._get_sentence_model().encode([query], convert_to_numpy=True)[0]

    def compute_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
        """Compute the cosine similarity between two embeddings.

        Returns:
            The cosine similarity as a float, or 0.0 when either vector has
            zero norm (avoids dividing by zero).
        """
        norm1 = np.linalg.norm(embedding1)
        norm2 = np.linalg.norm(embedding2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return float(np.dot(embedding1, embedding2) / (norm1 * norm2))

    def find_similar_articles(self, query_embedding: np.ndarray,
                              article_embeddings: np.ndarray,
                              articles: List[Dict[str, Any]],
                              top_k: int = 5) -> List[Dict[str, Any]]:
        """Find the articles most similar to a query embedding.

        Args:
            query_embedding: Embedding of the query.
            article_embeddings: Embeddings, where position ``i`` corresponds
                to ``articles[i]``.
            articles: Article dicts; matched ones are shallow-copied.
            top_k: Maximum number of candidates considered. Values <= 0 yield
                an empty result.

        Returns:
            Copies of up to ``top_k`` best-scoring articles whose similarity
            is at least ``settings.similarity_threshold``, best first, each
            with an added 'similarity_score' key.
        """
        if top_k <= 0 or len(article_embeddings) == 0:
            return []

        # (score, index) pairs sorted descending; ties resolve to the higher
        # index first, as tuple comparison dictates.
        scored = sorted(
            (
                (self.compute_similarity(query_embedding, vector), idx)
                for idx, vector in enumerate(article_embeddings)
            ),
            reverse=True,
        )

        results = []
        for score, idx in scored[:top_k]:
            if score >= settings.similarity_threshold:
                match = articles[idx].copy()
                match['similarity_score'] = score
                results.append(match)
        return results
# Manual smoke test: embed two sample articles and report the result shape.
if __name__ == "__main__":
    # (title, content, source) triples expanded into article dicts below.
    demo_rows = (
        (
            "AI Revolution in Healthcare",
            "Artificial intelligence is transforming medical diagnosis and treatment.",
            "TechNews",
        ),
        (
            "Climate Change Solutions",
            "New technologies are being developed to combat global warming.",
            "ScienceDaily",
        ),
    )
    demo_articles = [
        {"title": title, "content": content, "source": source}
        for title, content, source in demo_rows
    ]

    vectors = EmbeddingGenerator().generate_embeddings(demo_articles)
    print(f"Test embeddings shape: {vectors.shape}")