Files
DS_TASK_AI_VIEWS/backend/embeddings.py
T

157 lines
5.8 KiB
Python
Raw Normal View History

"""Embeddings generation for DS Task AI News"""
import os
import numpy as np
from typing import List, Dict, Any, Optional
from sentence_transformers import SentenceTransformer
import cohere
from config import settings
class EmbeddingGenerator:
    """Produce text embeddings for articles and queries.

    Uses the Cohere API when ``settings.cohere_api_key`` is set and the
    client can be constructed; otherwise loads the Sentence Transformers
    model named by ``settings.embedding_model``.
    """

    # Single source of truth so document and query embeddings always use
    # the same Cohere model.
    COHERE_MODEL = 'embed-english-v3.0'

    def __init__(self):
        """Select and initialise the embedding backend."""
        self.cohere_client = None
        self.sentence_model = None
        self.use_cohere = bool(settings.cohere_api_key)

        # Initialize embedding model
        if self.use_cohere:
            try:
                self.cohere_client = cohere.Client(settings.cohere_api_key)
                print("Using Cohere for embeddings")
            except Exception as e:
                print(f"Cohere initialization failed: {e}")
                self.use_cohere = False

        # Only load the local model when Cohere is not in use.
        if not self.use_cohere:
            print("Using Sentence Transformers for embeddings")
            self.sentence_model = SentenceTransformer(settings.embedding_model)

    def create_article_text(self, article: Dict[str, Any]) -> str:
        """Combine article fields into a single string for embedding.

        Args:
            article: Mapping that may contain ``'title'``, ``'content'`` and
                ``'source'``. Missing keys and ``None`` values are treated
                as empty strings.

        Returns:
            ``"<title>. <content>"`` with ``" Source: <source>"`` appended
            when a source is present, stripped of surrounding whitespace.
        """
        # `or ''` (rather than a .get default) also covers keys that are
        # present but hold None, which would otherwise be embedded as the
        # literal text "None".
        title = article.get('title') or ''
        content = article.get('content') or ''
        source = article.get('source') or ''

        # Title comes first; no explicit weighting is applied.
        text = f"{title}. {content}"
        if source:
            text += f" Source: {source}"
        return text.strip()

    def generate_embeddings_cohere(self, texts: List[str]) -> np.ndarray:
        """Embed ``texts`` as documents via the Cohere client.

        Args:
            texts: Strings to embed.

        Returns:
            Array built from ``response.embeddings``.

        Raises:
            Exception: Whatever the Cohere call raises, after logging it.
        """
        try:
            response = self.cohere_client.embed(
                texts=texts,
                model=self.COHERE_MODEL,
                input_type='search_document'
            )
            return np.array(response.embeddings)
        except Exception as e:
            print(f"Cohere embedding error: {e}")
            raise

    def generate_embeddings_sentence_transformer(self, texts: List[str]) -> np.ndarray:
        """Embed ``texts`` with the local Sentence Transformers model.

        Args:
            texts: Strings to embed.

        Returns:
            The result of ``encode(texts, convert_to_numpy=True)``.

        Raises:
            Exception: Whatever ``encode`` raises, after logging it.
        """
        try:
            return self.sentence_model.encode(texts, convert_to_numpy=True)
        except Exception as e:
            print(f"Sentence Transformer embedding error: {e}")
            raise

    def generate_embeddings(self, articles: List[Dict[str, Any]]) -> np.ndarray:
        """Generate embeddings for a list of articles.

        Args:
            articles: Article mappings (see ``create_article_text``).

        Returns:
            An empty array when ``articles`` is empty, otherwise the
            embeddings returned by the active backend.
        """
        if not articles:
            return np.array([])

        # Create texts for embedding
        texts = [self.create_article_text(article) for article in articles]
        print(f"Generating embeddings for {len(texts)} articles...")

        # Generate embeddings
        if self.use_cohere:
            embeddings = self.generate_embeddings_cohere(texts)
        else:
            embeddings = self.generate_embeddings_sentence_transformer(texts)

        print(f"Generated embeddings shape: {embeddings.shape}")
        return embeddings

    def generate_query_embedding(self, query: str) -> np.ndarray:
        """Generate the embedding for a search query.

        Args:
            query: The query text.

        Returns:
            The embedding for ``query``.

        Raises:
            Exception: The Cohere error, when Cohere fails and no local
                Sentence Transformers model is loaded to fall back to.
        """
        if self.use_cohere:
            try:
                response = self.cohere_client.embed(
                    texts=[query],
                    model=self.COHERE_MODEL,
                    input_type='search_query'
                )
                return np.array(response.embeddings[0])
            except Exception as e:
                print(f"Cohere query embedding error: {e}")
                # __init__ does not load the local model when Cohere is
                # active, so falling back blindly would mask the real error
                # with an AttributeError on None. Surface the Cohere failure
                # instead.
                if self.sentence_model is None:
                    raise
                # NOTE(review): a fallback vector comes from a different
                # model than Cohere document embeddings, so it is presumably
                # not comparable with them - confirm before relying on it.

        return self.sentence_model.encode([query], convert_to_numpy=True)[0]

    def compute_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
        """Compute cosine similarity between two embeddings.

        Returns:
            ``dot(a, b) / (|a| * |b|)`` as a float, or ``0.0`` when either
            vector has zero norm.
        """
        norm1 = np.linalg.norm(embedding1)
        norm2 = np.linalg.norm(embedding2)

        # Avoid division by zero for all-zero vectors.
        if norm1 == 0 or norm2 == 0:
            return 0.0

        return float(np.dot(embedding1, embedding2) / (norm1 * norm2))

    def find_similar_articles(self, query_embedding: np.ndarray,
                              article_embeddings: np.ndarray,
                              articles: List[Dict[str, Any]],
                              top_k: int = 5) -> List[Dict[str, Any]]:
        """Find the articles most similar to a query embedding.

        Args:
            query_embedding: Embedding of the query.
            article_embeddings: One embedding per article, in the same
                order as ``articles``.
            articles: The article mappings.
            top_k: Maximum number of candidates to consider.

        Returns:
            Shallow copies of up to ``top_k`` best-scoring articles whose
            similarity is at least ``settings.similarity_threshold``, each
            with an added ``'similarity_score'`` key, best first.
        """
        if len(article_embeddings) == 0:
            return []
        # A negative top_k would otherwise slice "all but the last k".
        if top_k <= 0:
            return []

        # (score, index) pairs, sorted descending.
        ranked = sorted(
            (
                (self.compute_similarity(query_embedding, article_embedding), i)
                for i, article_embedding in enumerate(article_embeddings)
            ),
            reverse=True,
        )

        # The threshold is applied after taking the top-k, so fewer than
        # top_k results may be returned.
        results = []
        for similarity, idx in ranked[:top_k]:
            if similarity >= settings.similarity_threshold:
                article = articles[idx].copy()
                article['similarity_score'] = similarity
                results.append(article)
        return results
# Smoke test: run this module directly to embed two sample articles
if __name__ == "__main__":
    # (title, content, source) for each sample article
    article_fields = [
        (
            "AI Revolution in Healthcare",
            "Artificial intelligence is transforming medical diagnosis and treatment.",
            "TechNews",
        ),
        (
            "Climate Change Solutions",
            "New technologies are being developed to combat global warming.",
            "ScienceDaily",
        ),
    ]
    sample_articles = [
        {"title": title, "content": content, "source": source}
        for title, content, source in article_fields
    ]

    # Constructing the generator selects the backend from settings.
    generator = EmbeddingGenerator()
    embeddings = generator.generate_embeddings(sample_articles)
    print(f"Test embeddings shape: {embeddings.shape}")