Files
ds_task_scp/backend/embeddings.py
T

29 lines
792 B
Python
Raw Normal View History

2025-07-11 22:29:45 +01:00
import cohere
from .config import Config
class EmbeddingGenerator:
    """Wrapper around a Cohere client for embedding text and reranking issues.

    The API key and model names are read from ``Config``
    (``COHERE_API_KEY``, ``EMBED_MODEL``, ``RERANK_MODEL``).
    """

    def __init__(self):
        # Build the Cohere client from the configured API key.
        self.client = cohere.Client(Config.COHERE_API_KEY)

    def generate_embeddings(self, text: str, input_type: str = "search_document"):
        """Embed a single piece of text and return its embedding.

        Args:
            text: The text to embed.
            input_type: Cohere ``input_type`` hint sent with the request.
                Defaults to ``"search_document"`` (the value this method
                always used before). Pass ``"search_query"`` when embedding
                a query that will be compared against stored documents.

        Returns:
            ``response.embeddings[0]`` -- the embedding for ``text``, since
            exactly one text is sent.
        """
        response = self.client.embed(
            texts=[text],
            model=Config.EMBED_MODEL,
            input_type=input_type,
        )
        return response.embeddings[0]

    def rerank_issues(self, issues: list, query: str, top_n: int = 5):
        """Rerank ``issues`` by relevance to ``query`` and return the top matches.

        Args:
            issues: Candidate documents, passed to Cohere as ``documents``.
            query: The query to rank the issues against.
            top_n: Maximum number of results to return. Clamped to
                ``len(issues)`` before the request is made.

        Returns:
            The ``document`` of each rerank result, in the order the API
            returned them. Returns ``[]`` without calling the API when
            ``issues`` is empty or ``top_n`` is not positive.
        """
        # Nothing to rank, or nothing requested: return early instead of
        # sending a request with an empty document list or top_n <= 0.
        if not issues or top_n <= 0:
            return []
        response = self.client.rerank(
            query=query,
            documents=issues,
            top_n=min(top_n, len(issues)),
            model=Config.RERANK_MODEL,
        )
        # NOTE(review): depending on the Cohere SDK version / return_documents
        # setting, `result.document` may be a document object/dict (or None)
        # rather than the original item from `issues` -- confirm against callers.
        return [result.document for result in response.results]