55 lines
1.6 KiB
Python
55 lines
1.6 KiB
Python
|
|
import cohere
|
||
|
|
import numpy as np
|
||
|
|
from typing import List, Union
|
||
|
|
import os
|
||
|
|
from dotenv import load_dotenv
|
||
|
|
|
||
|
|
# Load variables from a .env file (if python-dotenv finds one) into the process
# environment at import time, so that CohereEmbeddings.__init__ can read
# COHERE_API_KEY through os.getenv.
load_dotenv()
|
||
|
|
|
||
|
|
class CohereEmbeddings:
    """Small wrapper around a Cohere client that returns embeddings as numpy data.

    The API key is read from the ``COHERE_API_KEY`` environment variable when an
    instance is created.
    """

    # Values used when a call does not override them; these are the settings
    # the class has always sent to the embed endpoint.
    DEFAULT_MODEL = 'embed-english-v3.0'
    DEFAULT_INPUT_TYPE = 'search_document'

    def __init__(self):
        """Build the Cohere client from the ``COHERE_API_KEY`` environment variable.

        Raises:
            ValueError: If ``COHERE_API_KEY`` is unset or empty.
        """
        self.api_key = os.getenv('COHERE_API_KEY')
        if not self.api_key:
            raise ValueError("COHERE_API_KEY environment variable is not set")
        self.client = cohere.Client(self.api_key)

    def generate(
        self,
        text: Union[str, List[str]],
        *,
        model: str = DEFAULT_MODEL,
        input_type: str = DEFAULT_INPUT_TYPE,
    ) -> np.ndarray:
        """Generate embeddings for the given text using Cohere.

        Args:
            text: Single text string or list of texts. A single string is
                wrapped in a one-element list before the request is made.
            model: Name of the embedding model passed to ``client.embed``.
            input_type: ``input_type`` passed to ``client.embed``. The default
                is ``'search_document'``; pass a different value (for example
                when embedding search queries) as documented by Cohere.

        Returns:
            numpy array built from ``response.embeddings``. Presumably one row
            per input text, i.e. shape ``(len(text), dim)`` -- confirm against
            the SDK version in use. An empty input yields an array of shape
            ``(0, 0)`` without contacting the API.
        """
        if isinstance(text, str):
            text = [text]

        # Nothing to embed: skip the network round-trip and keep the result
        # two-dimensional so callers can still iterate over "rows".
        if len(text) == 0:
            return np.empty((0, 0))

        response = self.client.embed(
            texts=text,
            model=model,
            input_type=input_type,
        )
        return np.array(response.embeddings)

    def generate_batch(
        self,
        texts: List[str],
        batch_size: int = 96,
        *,
        model: str = DEFAULT_MODEL,
        input_type: str = DEFAULT_INPUT_TYPE,
    ) -> List[np.ndarray]:
        """Generate embeddings for a large batch of texts.

        The texts are split into consecutive slices of at most ``batch_size``
        items and each slice is embedded with :meth:`generate`.

        Args:
            texts: List of texts to generate embeddings for.
            batch_size: Maximum number of texts sent per request. The default
                of 96 presumably mirrors Cohere's per-request limit -- confirm
                against the API reference before raising it.
            model: Forwarded to :meth:`generate`.
            input_type: Forwarded to :meth:`generate`.

        Returns:
            List of numpy arrays containing embeddings, one entry per input
            text, in input order. Empty when ``texts`` is empty.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # A negative step would make range() empty and silently drop every
        # input; zero would fail with an unrelated-looking range() error.
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )

        all_embeddings: List[np.ndarray] = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            # extend() iterates the returned array, appending one row per text.
            all_embeddings.extend(
                self.generate(batch, model=model, input_type=input_type)
            )
        return all_embeddings
|