Files

285 lines
12 KiB
Python
Raw Permalink Normal View History

"""Embeddings generation for DS Task AI News"""
import os
import numpy as np
from typing import List, Dict, Any, Optional
try:
from sentence_transformers import SentenceTransformer
SENTENCE_TRANSFORMERS_AVAILABLE = True
except ImportError:
SENTENCE_TRANSFORMERS_AVAILABLE = False
print("⚠️ Sentence Transformers not available")
try:
import cohere
COHERE_AVAILABLE = True
except ImportError:
COHERE_AVAILABLE = False
print("⚠️ Cohere not available")
from config import settings
class EmbeddingGenerator:
    """Produce article and query embeddings with a tiered backend.

    Backend priority (decided in ``__init__``):
      1. Sentence Transformers (the model is loaded lazily on first use),
      2. Cohere (only when Sentence Transformers is not importable and an
         API key is configured),
      3. a deterministic hash-based bag-of-words fallback.
    """

    # Cohere model used for both document and query embeddings.
    COHERE_MODEL = 'embed-english-v3.0'
    # Maximum number of texts sent in a single Cohere ``embed`` request;
    # larger inputs are split across several requests.
    COHERE_BATCH_SIZE = 96

    def __init__(self):
        """Pick the embedding backend from what is importable and configured.

        Reads ``settings.cohere_api_key`` and ``settings.vector_dimension``.
        No model is downloaded here; Sentence Transformers is loaded on the
        first embedding request.
        """
        self.cohere_client = None
        self.sentence_model = None
        self.use_cohere = COHERE_AVAILABLE and bool(settings.cohere_api_key)
        self.use_sentence_transformers = SENTENCE_TRANSFORMERS_AVAILABLE
        # Set to True once a backend is ready. For Sentence Transformers this
        # means "ready for lazy loading", not that the model is in memory.
        self.model_loaded = False
        # Dimension of the hash-based fallback vectors.
        self.dimension = settings.vector_dimension
        self.embedding_method = "hash"  # Default fallback

        # Priority: 1. Local Sentence Transformers, 2. Cohere, 3. Hash fallback
        if self.use_sentence_transformers:
            print("🚀 Sentence Transformers available - will load on first use")
            self.embedding_method = "sentence_transformers"
            self.model_loaded = True  # Mark as ready for lazy loading
        if not self.use_sentence_transformers and self.use_cohere:
            try:
                self.cohere_client = cohere.Client(settings.cohere_api_key)
                self.embedding_method = "cohere"
                print("✅ Using Cohere for embeddings")
                self.model_loaded = True
            except Exception as e:
                print(f"❌ Cohere initialization failed: {e}")
                self.use_cohere = False
        if not self.use_sentence_transformers and not self.use_cohere:
            print("⚡ Using enhanced hash-based embeddings as fallback")
            self.embedding_method = "hash"
            self.model_loaded = True

    def _load_sentence_model(self):
        """Lazily load the Sentence Transformers model on first use.

        Returns:
            True if a model is loaded (now or previously), False when
            Sentence Transformers is disabled for this generator.

        Raises:
            RuntimeError: if both the default load and the retry with an
                explicit cache folder fail.
        """
        if self.sentence_model is not None or not self.use_sentence_transformers:
            return self.sentence_model is not None

        import socket

        print("📥 Loading Sentence Transformers model (first use)...")
        print("🌐 This may take a few minutes for initial download...")
        # The default socket timeout is process-wide. Raise it for the
        # (possibly slow) download, keep it raised for the retry as well, and
        # always restore the previous value afterwards.
        original_timeout = socket.getdefaulttimeout()
        socket.setdefaulttimeout(300)  # 5 minutes timeout
        try:
            try:
                self.sentence_model = SentenceTransformer(settings.embedding_model)
                print("✅ Sentence Transformers loaded successfully!")
            except Exception as e:
                print(f"❌ Failed to load Sentence Transformers: {e}")
                print("🔄 Retrying with cache_folder parameter...")
                # Retry with an explicit, pre-created cache folder.
                cache_dir = os.path.expanduser("~/.cache/huggingface/transformers")
                try:
                    os.makedirs(cache_dir, exist_ok=True)
                    self.sentence_model = SentenceTransformer(
                        settings.embedding_model,
                        cache_folder=cache_dir,
                    )
                except Exception as e2:
                    print(f"❌ Retry also failed: {e2}")
                    raise RuntimeError(
                        f"Cannot load Sentence Transformers model: {e2}"
                    ) from e2
                print("✅ Sentence Transformers loaded successfully on retry!")
        finally:
            socket.setdefaulttimeout(original_timeout)

        print(f"📊 Model dimension: {self.sentence_model.get_sentence_embedding_dimension()}")
        self.model_loaded = True
        return True

    def _try_load_sentence_model(self) -> bool:
        """Call ``_load_sentence_model`` but report failure as ``False``.

        Lets callers fall through to the next backend instead of crashing
        when the model cannot be loaded.
        """
        try:
            return self._load_sentence_model()
        except Exception as e:
            print(f"⚠️ Sentence Transformers unavailable: {e}")
            return False

    def _simple_text_to_vector(self, text: str) -> np.ndarray:
        """Convert text to a unit-norm bag-of-words vector (fallback method).

        Each of the first 50 lower-cased whitespace tokens is hashed into one
        of ``self.dimension`` buckets and weighted by ``1 / (position + 1)``.
        ``zlib.crc32`` is used instead of the built-in ``hash`` because
        ``hash(str)`` is salted per process (PYTHONHASHSEED). With the salted
        hash, vectors computed in different processes would not be comparable.

        Returns an all-zero vector when ``text`` contains no tokens.
        """
        import zlib

        vector = np.zeros(self.dimension)
        for position, word in enumerate(text.lower().split()[:50]):
            bucket = zlib.crc32(word.encode("utf-8")) % self.dimension
            vector[bucket] += 1.0 / (position + 1)  # Weight by position
        norm = np.linalg.norm(vector)
        if norm > 0:
            vector = vector / norm
        return vector

    def _hash_embeddings(self, texts: List[str]) -> np.ndarray:
        """Stack hash-based fallback vectors for ``texts`` into one array."""
        return np.array([self._simple_text_to_vector(text) for text in texts])

    def create_article_text(self, article: Dict[str, Any]) -> str:
        """Combine an article's title, content and source into one string.

        Produces ``"<title>. <content> Source: <source>"``. Missing, ``None``
        or empty fields are skipped rather than rendered as ``"None"`` or as
        a dangling separator. No field weighting is applied.
        """
        title = article.get('title') or ''
        content = article.get('content') or ''
        source = article.get('source') or ''
        text = ". ".join(str(part) for part in (title, content) if part)
        if source:
            text += f" Source: {source}"
        return text.strip()

    def generate_embeddings_cohere(self, texts: List[str],
                                   batch_size: Optional[int] = None) -> np.ndarray:
        """Generate document embeddings using Cohere.

        Args:
            texts: Texts to embed.
            batch_size: Maximum texts per request; defaults to
                ``COHERE_BATCH_SIZE``.

        Returns:
            Array with one embedding row per input text.

        Raises:
            ValueError: if ``batch_size`` is not positive.
            Exception: any Cohere client error is logged and re-raised.
        """
        size = self.COHERE_BATCH_SIZE if batch_size is None else batch_size
        if size < 1:
            raise ValueError("batch_size must be positive")
        vectors = []
        try:
            for start in range(0, len(texts), size):
                response = self.cohere_client.embed(
                    texts=texts[start:start + size],
                    model=self.COHERE_MODEL,
                    input_type='search_document'
                )
                vectors.extend(response.embeddings)
        except Exception as e:
            print(f"Cohere embedding error: {e}")
            raise
        return np.array(vectors)

    def generate_embeddings_sentence_transformer(self, texts: List[str]) -> np.ndarray:
        """Generate embeddings using Sentence Transformers.

        Loads the model on demand. Falls back to hash-based vectors when the
        model is unavailable or encoding fails.
        """
        try:
            # Check the model itself: ``model_loaded`` is already True after
            # __init__ (lazy mode), so it cannot signal "not yet loaded".
            if self.sentence_model is None and self.use_sentence_transformers:
                self._load_sentence_model()
            if self.sentence_model is None:
                print("⚠️ Using simple hash-based embeddings (Sentence Transformers not available)")
                return self._hash_embeddings(texts)
            return self.sentence_model.encode(texts, convert_to_numpy=True)
        except Exception as e:
            print(f"❌ Sentence Transformer embedding error: {e}")
            print("⚠️ Falling back to simple hash-based embeddings")
            return self._hash_embeddings(texts)

    def generate_embeddings(self, articles: List[Dict[str, Any]]) -> np.ndarray:
        """Generate embeddings for articles using the best available method.

        Returns ``np.array([])`` for an empty ``articles`` list. Otherwise
        returns one row per article.
        """
        if not articles:
            return np.array([])

        texts = [self.create_article_text(article) for article in articles]
        print(f"🔄 Generating embeddings for {len(texts)} articles using {self.embedding_method}...")

        # Priority: Sentence Transformers > Cohere > Hash fallback
        if self.use_sentence_transformers:
            if self._try_load_sentence_model():
                embeddings = self.generate_embeddings_sentence_transformer(texts)
            else:
                # Model could not be loaded: degrade instead of crashing.
                embeddings = self._hash_embeddings(texts)
        elif self.use_cohere:
            embeddings = self.generate_embeddings_cohere(texts)
        else:
            embeddings = self._hash_embeddings(texts)

        print(f"✅ Generated embeddings shape: {embeddings.shape}")
        return embeddings

    def generate_query_embedding(self, query: str) -> np.ndarray:
        """Generate an embedding for a search query using the best available method.

        Tries Sentence Transformers, then Cohere (only if a client exists),
        then the hash fallback.

        NOTE(review): a fallback backend may produce vectors in a different
        space or dimension than the stored article embeddings. Confirm how
        callers handle that.
        """
        print(f"🔍 Generating query embedding using {self.embedding_method}...")

        if self.use_sentence_transformers and self._try_load_sentence_model():
            try:
                embedding = self.sentence_model.encode([query], convert_to_numpy=True)[0]
                print(f"✅ Query embedding generated with shape: {embedding.shape}")
                return embedding
            except Exception as e:
                print(f"❌ Sentence Transformers query error: {e}")

        # __init__ only creates a Cohere client when Sentence Transformers is
        # not importable, so check for the client itself.
        if self.use_cohere and self.cohere_client is not None:
            try:
                response = self.cohere_client.embed(
                    texts=[query],
                    model=self.COHERE_MODEL,
                    input_type='search_query'
                )
                embedding = np.array(response.embeddings[0])
                print(f"✅ Query embedding generated with shape: {embedding.shape}")
                return embedding
            except Exception as e:
                print(f"❌ Cohere query embedding error: {e}")

        print("⚡ Using hash-based fallback for query embedding")
        return self._simple_text_to_vector(query)

    def compute_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
        """Return the cosine similarity of two vectors (0.0 if either is all-zero)."""
        norm1 = np.linalg.norm(embedding1)
        norm2 = np.linalg.norm(embedding2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return float(np.dot(embedding1, embedding2) / (norm1 * norm2))

    def find_similar_articles(self, query_embedding: np.ndarray,
                              article_embeddings: np.ndarray,
                              articles: List[Dict[str, Any]],
                              top_k: int = 5) -> List[Dict[str, Any]]:
        """Return up to ``top_k`` articles most similar to the query.

        Only the ``top_k`` best-scoring candidates are considered. Of those,
        articles scoring at least ``settings.similarity_threshold`` are
        returned as shallow copies with an added ``similarity_score`` key,
        best first. ``articles[i]`` must correspond to ``article_embeddings[i]``.
        Returns ``[]`` when there are no embeddings or ``top_k <= 0``.
        """
        # A non-positive top_k would otherwise slice from the end of the list.
        if top_k <= 0 or len(article_embeddings) == 0:
            return []

        scored = [
            (self.compute_similarity(query_embedding, article_embedding), index)
            for index, article_embedding in enumerate(article_embeddings)
        ]
        # Highest similarity first; the sort is stable, so ties keep article order.
        scored.sort(key=lambda pair: pair[0], reverse=True)

        threshold = settings.similarity_threshold
        results = []
        for similarity, index in scored[:top_k]:
            if similarity >= threshold:
                article = articles[index].copy()
                article['similarity_score'] = similarity
                results.append(article)
        return results
# Manual smoke test: embed two sample articles and print the resulting shape.
if __name__ == "__main__":
    demo_rows = [
        (
            "AI Revolution in Healthcare",
            "Artificial intelligence is transforming medical diagnosis and treatment.",
            "TechNews",
        ),
        (
            "Climate Change Solutions",
            "New technologies are being developed to combat global warming.",
            "ScienceDaily",
        ),
    ]
    sample_articles = [
        {"title": title, "content": content, "source": source}
        for title, content, source in demo_rows
    ]
    generator = EmbeddingGenerator()
    embeddings = generator.generate_embeddings(sample_articles)
    print(f"Test embeddings shape: {embeddings.shape}")