Files
ds_scp_task_solution/app/services/embedding.py
T
Aherobo Ovie Victor 0e3e22e8cb Initial commit
2025-07-17 22:20:25 +01:00

254 lines
9.7 KiB
Python

import cohere
from typing import List, Dict, Any, Optional
import uuid
from pinecone import Pinecone
import weaviate
from loguru import logger
from app.core.config import settings
from app.core.models import DocumentEmbedding
class EmbeddingService:
    """Service for document embedding and vector database operations.

    Embeddings come from Cohere; storage/search goes to the backend selected
    by ``settings.VECTOR_DB``:

    * ``"pinecone"`` when ``settings.PINECONE_API_KEY`` is set,
    * ``"weaviate"`` when ``settings.WEAVIATE_URL`` is set,
    * otherwise an in-process ``MockVectorDB`` (Pinecone-style API).

    NOTE(review): the public methods are ``async`` but every client call made
    here is synchronous, so each call blocks the event loop while waiting on
    the network.
    """

    def __init__(self):
        """Initialize the Cohere client and the configured vector DB client."""
        self.cohere_client = cohere.Client(settings.COHERE_API_KEY)
        self.vector_db_client = self._init_vector_db()
        self.embedding_model = settings.EMBEDDING_MODEL

    def _init_vector_db(self) -> Any:
        """Initialize the vector database client based on settings.

        Returns:
            A Pinecone index handle, a Weaviate client, or a ``MockVectorDB``
            when neither backend is fully configured.
        """
        if settings.VECTOR_DB == "pinecone" and settings.PINECONE_API_KEY:
            return self._init_pinecone()
        if settings.VECTOR_DB == "weaviate" and settings.WEAVIATE_URL:
            return self._init_weaviate()
        logger.warning("No valid vector database configuration found. Using mock implementation.")
        return MockVectorDB()

    @staticmethod
    def _init_pinecone() -> Any:
        """Connect to Pinecone, creating the configured index if it is missing."""
        # ServerlessSpec ships with the same client package that provides Pinecone.
        from pinecone import ServerlessSpec

        pc = Pinecone(api_key=settings.PINECONE_API_KEY)
        existing_names = [idx["name"] for idx in pc.list_indexes()]
        if settings.PINECONE_INDEX_NAME not in existing_names:
            pc.create_index(
                name=settings.PINECONE_INDEX_NAME,
                # 1024 is the Cohere embed-english-v3.0 dimension; overridable.
                dimension=getattr(settings, "EMBEDDING_DIMENSION", 1024),
                metric="cosine",
                # create_index needs a deployment spec; cloud/region are
                # optional settings that default to AWS us-east-1.
                spec=ServerlessSpec(
                    cloud=getattr(settings, "PINECONE_CLOUD", "aws"),
                    region=getattr(settings, "PINECONE_REGION", "us-east-1"),
                ),
            )
        return pc.Index(settings.PINECONE_INDEX_NAME)

    @staticmethod
    def _init_weaviate() -> Any:
        """Connect to Weaviate and create the ``Document`` class if it is missing."""
        auth_config = (
            weaviate.auth.AuthApiKey(api_key=settings.WEAVIATE_API_KEY)
            if settings.WEAVIATE_API_KEY
            else None
        )
        client = weaviate.Client(
            url=settings.WEAVIATE_URL,
            auth_client_secret=auth_config,
        )
        # schema.contains() returns a bool, so read the full schema and look
        # for our class by name (unrelated classes may already exist).
        existing_classes = {
            cls.get("class") for cls in (client.schema.get().get("classes") or [])
        }
        if "Document" not in existing_classes:
            class_obj = {
                "class": "Document",
                "vectorizer": "none",  # vectors are supplied by this service
                "properties": [
                    {"name": "content", "dataType": ["text"]},
                    {"name": "document_id", "dataType": ["string"]},
                    {"name": "section_name", "dataType": ["string"]},
                ],
            }
            client.schema.create_class(class_obj)
        return client

    def _backend(self) -> str:
        """Return which client API to talk to: ``"pinecone"`` or ``"weaviate"``.

        ``MockVectorDB`` mirrors the Pinecone ``upsert``/``query`` calls, so it
        is driven through the Pinecone code path whatever ``VECTOR_DB`` says.
        """
        if isinstance(self.vector_db_client, MockVectorDB):
            return "pinecone"
        return settings.VECTOR_DB

    def _embed(self, text: str, input_type: str) -> List[float]:
        """Return the Cohere embedding of ``text`` for the given ``input_type``."""
        response = self.cohere_client.embed(
            texts=[text],
            model=self.embedding_model,
            input_type=input_type,
        )
        return response.embeddings[0]

    async def embed_document(self, document_id: str, sections: Dict[str, str]) -> DocumentEmbedding:
        """
        Embed document sections and store them in the vector database.

        Args:
            document_id: Unique identifier for the document
            sections: Dictionary mapping section names to section content

        Returns:
            DocumentEmbedding object whose ``sections`` maps each section name
            to the ID it was stored under (a UUID string for Weaviate, which
            only accepts UUID object IDs; ``<document_id>_<section>_<8 hex>``
            otherwise).

        Raises:
            Re-raises any error from Cohere or the vector DB after logging it.
            Sections stored before the failure are left in the database.
        """
        backend = self._backend()
        section_ids: Dict[str, str] = {}
        for section_name, content in sections.items():
            try:
                embedding_vector = self._embed(content, "search_document")
                if backend == "weaviate":
                    # Weaviate rejects object IDs that are not valid UUIDs.
                    section_id = str(uuid.uuid4())
                    self.vector_db_client.data_object.create(
                        class_name="Document",
                        data_object={
                            "content": content,
                            "document_id": document_id,
                            "section_name": section_name,
                        },
                        uuid=section_id,
                        vector=embedding_vector,
                    )
                else:
                    section_id = f"{document_id}_{section_name}_{str(uuid.uuid4())[:8]}"
                    self.vector_db_client.upsert(
                        vectors=[{
                            "id": section_id,
                            "values": embedding_vector,
                            "metadata": {
                                "document_id": document_id,
                                "section_name": section_name,
                                # Truncated copy kept as context for search hits.
                                "content": content[:1000],
                            },
                        }],
                        namespace=document_id,
                    )
                section_ids[section_name] = section_id
                logger.info(f"Successfully embedded section '{section_name}' for document {document_id}")
            except Exception as e:
                logger.error(f"Error embedding section '{section_name}': {str(e)}")
                raise
        return DocumentEmbedding(
            embedding_id=str(uuid.uuid4()),
            embedding_model=self.embedding_model,
            vector_db=settings.VECTOR_DB,
            sections=section_ids,
        )

    async def retrieve_similar_sections(self, query: str, document_id: Optional[str] = None, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Retrieve document sections similar to a query.

        Args:
            query: The query text to find similar sections for
            document_id: Optional document ID to restrict search
            top_k: Number of results to return

        Returns:
            List of dicts with keys ``section_id``, ``document_id``,
            ``section_name``, ``content`` and ``score``.
        """
        query_embedding = self._embed(query, "search_query")
        if self._backend() == "weaviate":
            return self._search_weaviate(query_embedding, document_id, top_k)
        return self._search_pinecone(query_embedding, document_id, top_k)

    def _search_pinecone(self, query_embedding: List[float], document_id: Optional[str], top_k: int) -> List[Dict[str, Any]]:
        """Search via the Pinecone-style ``query`` API (real index or mock)."""
        # NOTE(review): embed_document upserts into a namespace named after the
        # document, so a search without document_id only covers the default
        # namespace and will not find those sections.
        results = self.vector_db_client.query(
            vector=query_embedding,
            top_k=top_k,
            namespace=document_id or None,
            include_metadata=True,
        )
        similar_sections = []
        for match in results.matches:
            metadata = match.metadata or {}
            similar_sections.append({
                "section_id": match.id,
                "document_id": metadata.get("document_id"),
                "section_name": metadata.get("section_name"),
                "content": metadata.get("content", ""),
                "score": match.score,  # cosine similarity: higher is closer
            })
        return similar_sections

    def _search_weaviate(self, query_embedding: List[float], document_id: Optional[str], top_k: int) -> List[Dict[str, Any]]:
        """Search the Weaviate ``Document`` class by vector, optionally per document."""
        query_builder = (
            self.vector_db_client.query
            .get("Document", ["content", "document_id", "section_name"])
            .with_near_vector({"vector": query_embedding})
            .with_limit(top_k)
            # Without this, _additional (object id, distance) is never returned.
            .with_additional(["id", "distance"])
        )
        if document_id:
            query_builder = query_builder.with_where({
                "path": ["document_id"],
                "operator": "Equal",
                "valueString": document_id,
            })
        results = query_builder.do()
        if results.get("errors"):
            logger.error(f"Weaviate query returned errors: {results['errors']}")
        # Any level may be missing or null when the query fails.
        items = ((results.get("data") or {}).get("Get") or {}).get("Document") or []
        return [
            {
                "section_id": (item.get("_additional") or {}).get("id"),
                "document_id": item.get("document_id"),
                "section_name": item.get("section_name"),
                "content": item.get("content", ""),
                # NOTE(review): a distance (lower is closer), unlike the
                # Pinecone similarity score.
                "score": (item.get("_additional") or {}).get("distance"),
            }
            for item in items
        ]
class MockVectorDB:
    """Mock vector database for development without actual vector DB.

    Mimics the ``upsert``/``query`` subset of the Pinecone index API. Only
    metadata is kept: vector values are discarded and ``query`` ignores the
    query vector, returning the first ``top_k`` entries of the namespace in
    insertion order with a fixed score of 0.8.
    """

    def __init__(self):
        # namespace -> {vector_id: metadata}
        self.vectors = {}
        logger.warning("Using mock vector database. Not suitable for production.")

    def upsert(self, vectors, namespace=None):
        """Store (or overwrite) each vector's metadata under ``namespace``.

        Args:
            vectors: Iterable of dicts with an ``"id"`` key and an optional
                ``"metadata"`` dict; any ``"values"`` entry is ignored.
            namespace: Target namespace; ``None``/empty means ``"default"``.
        """
        bucket = self.vectors.setdefault(namespace or "default", {})
        for vector in vectors:
            bucket[vector["id"]] = vector.get("metadata", {})

    def query(self, vector, top_k=5, namespace=None, include_metadata=True):
        """Return up to ``top_k`` stored entries from ``namespace``.

        ``vector`` and ``include_metadata`` are accepted for API compatibility
        only; metadata is always included.

        Returns:
            An object with a ``matches`` list of ``(id, score, metadata)``
            named tuples. ``matches`` is empty for an unknown namespace, so
            callers can always read ``.matches``.
        """
        from collections import namedtuple

        Match = namedtuple('Match', ['id', 'score', 'metadata'])
        Results = namedtuple('Results', ['matches'])
        bucket = self.vectors.get(namespace or "default", {})
        # Clamp so a non-positive top_k yields no matches instead of a
        # negative slice that drops entries from the end.
        limit = max(top_k, 0)
        matches = [
            Match(id=vector_id, score=0.8, metadata=metadata)
            for vector_id, metadata in list(bucket.items())[:limit]
        ]
        return Results(matches=matches)