# Source: ds_task_marketing_assistant…/backend/embeddings.py

import os
from typing import List, Optional, Union

import cohere
import numpy as np
from dotenv import load_dotenv
# Load environment variables from a .env file (python-dotenv) at import time,
# before CohereEmbeddings reads COHERE_API_KEY via os.getenv.
load_dotenv()
class CohereEmbeddings:
    """Thin wrapper around the Cohere embed endpoint that returns NumPy arrays.

    The API key comes from the ``api_key`` constructor argument or, when that
    is omitted, from the ``COHERE_API_KEY`` environment variable.
    """

    # Embedding model used when the caller does not choose one explicitly.
    DEFAULT_MODEL = 'embed-english-v3.0'
    # Cohere v3 embed models take an ``input_type``; "search_document" is for
    # texts that get indexed. Callers embedding a *query* to compare against
    # those documents should pass input_type="search_query" instead.
    DEFAULT_INPUT_TYPE = 'search_document'

    def __init__(self, api_key: Optional[str] = None):
        """Create the underlying Cohere client.

        Args:
            api_key: Cohere API key. Falls back to the ``COHERE_API_KEY``
                environment variable when not given (or empty).

        Raises:
            ValueError: If no API key is available from either source.
        """
        self.api_key = api_key or os.getenv('COHERE_API_KEY')
        if not self.api_key:
            raise ValueError("COHERE_API_KEY environment variable is not set")
        self.client = cohere.Client(self.api_key)

    def generate(
        self,
        text: Union[str, List[str]],
        *,
        model: str = DEFAULT_MODEL,
        input_type: str = DEFAULT_INPUT_TYPE,
    ) -> np.ndarray:
        """Generate embeddings for the given text using Cohere.

        Args:
            text: A single text string or a sequence of texts. A single string
                is treated as a one-element batch.
            model: Cohere embedding model name.
            input_type: Cohere input type, e.g. "search_document" for indexed
                texts or "search_query" for queries.

        Returns:
            2-D numpy array with one row per input text, in input order.
        """
        # Normalise to a list so tuples and other sequences work as well.
        texts = [text] if isinstance(text, str) else list(text)
        response = self.client.embed(
            texts=texts,
            model=model,
            input_type=input_type,
        )
        return np.array(response.embeddings)

    def generate_batch(
        self,
        texts: List[str],
        batch_size: int = 96,
        *,
        model: str = DEFAULT_MODEL,
        input_type: str = DEFAULT_INPUT_TYPE,
    ) -> List[np.ndarray]:
        """Generate embeddings for a large list of texts, one request per chunk.

        Args:
            texts: Texts to embed.
            batch_size: Maximum number of texts sent per request. The default
                of 96 matches Cohere's documented per-request limit for embed.
            model: Cohere embedding model name (forwarded to ``generate``).
            input_type: Cohere input type (forwarded to ``generate``).

        Returns:
            List of 1-D numpy arrays, one embedding per input text, in input
            order. An empty ``texts`` yields an empty list without any request.

        Raises:
            ValueError: If ``batch_size`` is not a positive integer.
        """
        # Without this guard, a negative step makes range() empty and silently
        # drops every text, and a zero step raises an unrelated range() error.
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )
        all_embeddings: List[np.ndarray] = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            # Iterating the 2-D result yields one 1-D row per text.
            all_embeddings.extend(
                self.generate(batch, model=model, input_type=input_type)
            )
        return all_embeddings