# File: ds_scp_task_solution/app/services/embedding.py (258 lines, 9.9 KiB, Python)
# Repository-viewer residue captured with the file: "Files", "T", "Raw Normal View History".
# Blame timestamp: 2025-07-17 22:20:25 +01:00
import cohere
from typing import List, Dict, Any, Optional
import uuid
# Blame timestamp: 2025-07-18 00:52:51 +01:00
from pinecone import Pinecone, ServerlessSpec
# Blame timestamp: 2025-07-17 22:20:25 +01:00
import weaviate
from loguru import logger
from app.core.config import settings
from app.core.models import DocumentEmbedding
class EmbeddingService:
    """Service for document embedding and vector database operations.

    Embeddings are produced with the Cohere client. The vector store is
    selected by ``settings.VECTOR_DB`` ("pinecone" or "weaviate"); when the
    selected backend is not fully configured a ``MockVectorDB`` is used.
    """

    # Name of the Weaviate class that stores document sections.
    _WEAVIATE_CLASS = "Document"

    def __init__(self):
        """Initialize the embedding service with the Cohere client and vector DB."""
        # Initialize Cohere client
        self.cohere_client = cohere.Client(settings.COHERE_API_KEY)
        # Initialize vector database client based on configuration
        self.vector_db_client = self._init_vector_db()
        self.embedding_model = settings.EMBEDDING_MODEL

    def _init_vector_db(self) -> Any:
        """Initialize the vector database client based on settings.

        Returns:
            A Pinecone index handle, a Weaviate client, or a ``MockVectorDB``
            when neither backend is selected together with its credentials/URL.
        """
        if settings.VECTOR_DB == "pinecone" and settings.PINECONE_API_KEY:
            # Initialize Pinecone with new API
            pc = Pinecone(api_key=settings.PINECONE_API_KEY)
            # Check if index exists, if not create it
            existing_indexes = [idx["name"] for idx in pc.list_indexes()]
            if settings.PINECONE_INDEX_NAME not in existing_indexes:
                pc.create_index(
                    name=settings.PINECONE_INDEX_NAME,
                    dimension=1024,  # Cohere embed-english-v3.0 dimension
                    metric="cosine",
                    spec=ServerlessSpec(
                        cloud='aws',
                        region='us-east-1'
                    )
                )
            # Return the index
            return pc.Index(settings.PINECONE_INDEX_NAME)

        if settings.VECTOR_DB == "weaviate" and settings.WEAVIATE_URL:
            # Initialize Weaviate (API-key auth only when a key is configured)
            auth_config = (
                weaviate.auth.AuthApiKey(api_key=settings.WEAVIATE_API_KEY)
                if settings.WEAVIATE_API_KEY
                else None
            )
            client = weaviate.Client(
                url=settings.WEAVIATE_URL,
                auth_client_secret=auth_config
            )
            # Create the class only if it is missing. ``schema.contains()``
            # returns a bool, so the class list is read from ``schema.get()``
            # and checked for this specific class rather than "any class".
            existing_classes = client.schema.get().get("classes", [])
            has_document_class = any(
                cls.get("class") == self._WEAVIATE_CLASS for cls in existing_classes
            )
            if not has_document_class:
                class_obj = {
                    "class": self._WEAVIATE_CLASS,
                    "vectorizer": "none",  # We'll provide our own vectors
                    "properties": [
                        {
                            "name": "content",
                            "dataType": ["text"]
                        },
                        {
                            "name": "document_id",
                            "dataType": ["string"]
                        },
                        {
                            "name": "section_name",
                            "dataType": ["string"]
                        }
                    ]
                }
                client.schema.create_class(class_obj)
            return client

        logger.warning("No valid vector database configuration found. Using mock implementation.")
        return MockVectorDB()

    async def embed_document(self, document_id: str, sections: Dict[str, str]) -> DocumentEmbedding:
        """
        Embed document sections and store in vector database.

        Each section is embedded with a separate Cohere call and written to
        the backend named by ``settings.VECTOR_DB``. For any other value of
        ``settings.VECTOR_DB`` the embedding is computed but nothing is stored.

        Args:
            document_id: Unique identifier for the document
            sections: Dictionary mapping section names to section content

        Returns:
            DocumentEmbedding object with embedding metadata; ``sections``
            maps each section name to the ID it was stored under.

        Raises:
            Exception: Any error from the embedding call or the vector store
                is logged and re-raised; sections processed before the
                failure remain stored.
        """
        section_ids = {}
        for section_name, content in sections.items():
            try:
                # Generate embedding for section content
                embedding_response = self.cohere_client.embed(
                    texts=[content],
                    model=self.embedding_model,
                    input_type="search_document"
                )
                embedding_vector = embedding_response.embeddings[0]

                # Generate a unique, human-readable ID for this section
                section_id = f"{document_id}_{section_name}_{str(uuid.uuid4())[:8]}"

                # Store in vector database
                if settings.VECTOR_DB == "pinecone":
                    self.vector_db_client.upsert(
                        vectors=[{
                            "id": section_id,
                            "values": embedding_vector,
                            "metadata": {
                                "document_id": document_id,
                                "section_name": section_name,
                                "content": content[:1000]  # Store truncated content for context
                            }
                        }],
                        namespace=document_id
                    )
                elif settings.VECTOR_DB == "weaviate":
                    # Weaviate object IDs must be valid UUIDs, so the readable
                    # ID above cannot be used; a real UUID is generated and
                    # returned as the section ID so it matches what is stored.
                    section_id = str(uuid.uuid4())
                    self.vector_db_client.data_object.create(
                        class_name=self._WEAVIATE_CLASS,
                        data_object={
                            "content": content,
                            "document_id": document_id,
                            "section_name": section_name
                        },
                        uuid=section_id,
                        vector=embedding_vector
                    )

                # Store the section ID
                section_ids[section_name] = section_id
                logger.info(f"Successfully embedded section '{section_name}' for document {document_id}")
            except Exception as e:
                logger.error(f"Error embedding section '{section_name}': {str(e)}")
                raise

        # Create and return DocumentEmbedding object
        return DocumentEmbedding(
            embedding_id=str(uuid.uuid4()),
            embedding_model=self.embedding_model,
            vector_db=settings.VECTOR_DB,
            sections=section_ids
        )

    async def retrieve_similar_sections(self, query: str, document_id: Optional[str] = None, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Retrieve similar document sections for a query.

        Args:
            query: The query text to find similar sections for
            document_id: Optional document ID to restrict search
            top_k: Number of results to return

        Returns:
            List of similar sections with metadata. Each item has the keys
            ``section_id``, ``document_id``, ``section_name``, ``content`` and
            ``score``. For Pinecone ``score`` is the match score; for Weaviate
            it is the reported ``distance``.
            NOTE(review): the two backends therefore use different score
            scales -- confirm callers do not compare them directly.
        """
        # Generate embedding for query
        query_embedding = self.cohere_client.embed(
            texts=[query],
            model=self.embedding_model,
            input_type="search_query"
        ).embeddings[0]

        # Search vector database
        if settings.VECTOR_DB == "pinecone":
            # NOTE(review): embed_document writes into a per-document
            # namespace, so a query without document_id (namespace=None)
            # does not look into those namespaces -- confirm this is intended.
            namespace = document_id if document_id else None
            results = self.vector_db_client.query(
                vector=query_embedding,
                top_k=top_k,
                namespace=namespace,
                include_metadata=True
            )
            # Format results
            return [
                {
                    "section_id": match.id,
                    "document_id": match.metadata["document_id"],
                    "section_name": match.metadata["section_name"],
                    "content": match.metadata.get("content", ""),
                    "score": match.score
                }
                for match in results.matches
            ]

        if settings.VECTOR_DB == "weaviate":
            query_builder = self.vector_db_client.query.get(
                self._WEAVIATE_CLASS, ["content", "document_id", "section_name"]
            ).with_near_vector({
                "vector": query_embedding
            }).with_limit(top_k).with_additional(
                # Without this the response carries no ``_additional`` block
                # and both section_id and score would always be None.
                ["id", "distance"]
            )
            if document_id:
                query_builder = query_builder.with_where({
                    "path": ["document_id"],
                    "operator": "Equal",
                    "valueString": document_id
                })
            results = query_builder.do()
            # Format results
            items = results.get("data", {}).get("Get", {}).get(self._WEAVIATE_CLASS, [])
            similar_sections = []
            for item in items:
                additional = item.get("_additional", {})
                similar_sections.append({
                    "section_id": additional.get("id"),
                    "document_id": item.get("document_id"),
                    "section_name": item.get("section_name"),
                    "content": item.get("content", ""),
                    "score": additional.get("distance")
                })
            return similar_sections

        # Mock implementation
        return []
class MockVectorDB:
    """Mock vector database for development without actual vector DB.

    Only the metadata of each upserted vector is kept, grouped by namespace;
    vector values are ignored. ``query`` performs no similarity search: it
    returns the first ``top_k`` stored entries of the namespace with a fixed
    score of 0.8.
    """

    def __init__(self):
        # namespace -> {vector_id -> metadata}
        self.vectors = {}
        logger.warning("Using mock vector database. Not suitable for production.")

    def upsert(self, vectors, namespace=None):
        """Mock upsert method.

        Args:
            vectors: Iterable of dicts with an ``id`` key and an optional
                ``metadata`` key.
            namespace: Namespace to store under; falsy values map to
                ``"default"``.
        """
        bucket = self.vectors.setdefault(namespace or "default", {})
        for vector in vectors:
            # Vectors without metadata are stored with an empty dict instead
            # of raising KeyError.
            bucket[vector['id']] = vector.get('metadata', {})

    def query(self, vector, top_k=5, namespace=None, include_metadata=True):
        """Mock query method.

        Args:
            vector: Ignored; accepted for interface compatibility.
            top_k: Maximum number of matches to return.
            namespace: Namespace to read from; falsy values map to
                ``"default"``.
            include_metadata: Ignored; metadata is always included.

        Returns:
            An object with a ``matches`` attribute (list of objects exposing
            ``id``, ``score`` and ``metadata``). An unknown namespace yields
            an empty ``matches`` list, so callers can always iterate
            ``result.matches``.
        """
        from collections import namedtuple

        Match = namedtuple('Match', ['id', 'score', 'metadata'])
        Results = namedtuple('Results', ['matches'])

        stored = self.vectors.get(namespace or "default", {})
        # Just return some mock results
        matches = [
            Match(id=vector_id, score=0.8, metadata=metadata)
            for vector_id, metadata in list(stored.items())[:top_k]
        ]
        return Results(matches=matches)