2025-07-17 22:20:25 +01:00
|
|
|
import cohere
|
|
|
|
|
from typing import List, Dict, Any, Optional
|
|
|
|
|
import uuid
|
2025-07-18 00:52:51 +01:00
|
|
|
from pinecone import Pinecone, ServerlessSpec
|
2025-07-17 22:20:25 +01:00
|
|
|
import weaviate
|
|
|
|
|
from loguru import logger
|
|
|
|
|
|
|
|
|
|
from app.core.config import settings
|
|
|
|
|
from app.core.models import DocumentEmbedding
|
|
|
|
|
|
|
|
|
|
class EmbeddingService:
    """Service for document embedding and vector database operations.

    Section text is embedded with Cohere and stored in (or searched from) the
    vector database selected by ``settings.VECTOR_DB``. When the selected
    database is not fully configured, a ``MockVectorDB`` is used instead.
    """

    # Vector size used when creating the Pinecone index
    # (Cohere embed-english-v3.0 dimension).
    _EMBEDDING_DIMENSION = 1024

    # Name of the Weaviate class holding one object per embedded section.
    _WEAVIATE_CLASS = "Document"

    def __init__(self):
        """Initialize the embedding service with the Cohere client and vector DB."""
        # Initialize Cohere client
        self.cohere_client = cohere.Client(settings.COHERE_API_KEY)

        # Initialize vector database client based on configuration
        self.vector_db_client = self._init_vector_db()
        self.embedding_model = settings.EMBEDDING_MODEL

    def _init_vector_db(self) -> Any:
        """Initialize the vector database client based on settings.

        Returns:
            A Pinecone index, a Weaviate client, or a ``MockVectorDB`` when
            neither database is fully configured.
        """
        if settings.VECTOR_DB == "pinecone" and settings.PINECONE_API_KEY:
            # Initialize Pinecone with new API
            pc = Pinecone(api_key=settings.PINECONE_API_KEY)

            # Check if index exists, if not create it
            existing_indexes = [idx["name"] for idx in pc.list_indexes()]
            if settings.PINECONE_INDEX_NAME not in existing_indexes:
                pc.create_index(
                    name=settings.PINECONE_INDEX_NAME,
                    dimension=self._EMBEDDING_DIMENSION,
                    metric="cosine",
                    spec=ServerlessSpec(
                        cloud='aws',
                        region='us-east-1'
                    )
                )

            # Return the index
            return pc.Index(settings.PINECONE_INDEX_NAME)

        if settings.VECTOR_DB == "weaviate" and settings.WEAVIATE_URL:
            # Initialize Weaviate
            auth_config = (
                weaviate.auth.AuthApiKey(api_key=settings.WEAVIATE_API_KEY)
                if settings.WEAVIATE_API_KEY
                else None
            )
            client = weaviate.Client(
                url=settings.WEAVIATE_URL,
                auth_client_secret=auth_config
            )

            # Create the class only when it is missing. ``schema.get()``
            # returns the schema dict; ``schema.contains()`` returns a bool
            # and therefore cannot be inspected for class names.
            schema = client.schema.get() or {}
            existing_classes = {
                cls.get("class") for cls in schema.get("classes") or []
            }
            if self._WEAVIATE_CLASS not in existing_classes:
                client.schema.create_class({
                    "class": self._WEAVIATE_CLASS,
                    "vectorizer": "none",  # We'll provide our own vectors
                    "properties": [
                        {"name": "content", "dataType": ["text"]},
                        {"name": "document_id", "dataType": ["string"]},
                        {"name": "section_name", "dataType": ["string"]},
                    ],
                })
            return client

        logger.warning("No valid vector database configuration found. Using mock implementation.")
        return MockVectorDB()

    def _active_backend(self) -> str:
        """Return the backend actually in use: ``"mock"`` or ``settings.VECTOR_DB``.

        The configured name alone is not enough: ``_init_vector_db`` falls
        back to ``MockVectorDB`` when credentials are missing, and the mock
        only offers ``upsert``/``query``.
        """
        if isinstance(self.vector_db_client, MockVectorDB):
            return "mock"
        return settings.VECTOR_DB

    @staticmethod
    def _weaviate_object_id(section_id: str) -> str:
        """Derive a valid, deterministic UUID string from a readable section ID.

        Weaviate object IDs must be UUIDs, which the
        ``<document>_<section>_<suffix>`` section ID is not.
        """
        return str(uuid.uuid5(uuid.NAMESPACE_URL, section_id))

    async def embed_document(self, document_id: str, sections: Dict[str, str]) -> "DocumentEmbedding":
        """
        Embed document sections and store in vector database.

        Each section is embedded with its own Cohere call and stored
        immediately. The client calls are made without ``await``.

        Args:
            document_id: Unique identifier for the document
            sections: Dictionary mapping section names to section content

        Returns:
            DocumentEmbedding object with embedding metadata. ``sections``
            maps each section name to the ID it was stored under (a UUID
            string when the backend is Weaviate).

        Raises:
            Exception: Any error raised while embedding or storing a section
                is logged and re-raised; sections stored before the failure
                are not rolled back.
        """
        backend = self._active_backend()
        section_ids: Dict[str, str] = {}

        for section_name, content in sections.items():
            try:
                # Generate embedding for section content
                embedding_response = self.cohere_client.embed(
                    texts=[content],
                    model=self.embedding_model,
                    input_type="search_document"
                )
                embedding_vector = embedding_response.embeddings[0]

                # Generate a unique ID for this section
                section_id = f"{document_id}_{section_name}_{str(uuid.uuid4())[:8]}"

                # Store in vector database
                if backend in ("pinecone", "mock"):
                    # The mock exposes the same upsert() call as a Pinecone index.
                    self.vector_db_client.upsert(
                        vectors=[{
                            "id": section_id,
                            "values": embedding_vector,
                            "metadata": {
                                "document_id": document_id,
                                "section_name": section_name,
                                "content": content[:1000]  # Store truncated content for context
                            }
                        }],
                        namespace=document_id
                    )
                elif backend == "weaviate":
                    # Store under a real UUID and report that same ID, so it
                    # matches the ``_additional.id`` returned by retrieval.
                    section_id = self._weaviate_object_id(section_id)
                    self.vector_db_client.data_object.create(
                        class_name=self._WEAVIATE_CLASS,
                        data_object={
                            "content": content,
                            "document_id": document_id,
                            "section_name": section_name
                        },
                        uuid=section_id,
                        vector=embedding_vector
                    )

                # Store the section ID
                section_ids[section_name] = section_id
                logger.info(f"Successfully embedded section '{section_name}' for document {document_id}")

            except Exception as e:
                logger.error(f"Error embedding section '{section_name}': {str(e)}")
                raise

        # Create and return DocumentEmbedding object
        return DocumentEmbedding(
            embedding_id=str(uuid.uuid4()),
            embedding_model=self.embedding_model,
            vector_db=settings.VECTOR_DB,
            sections=section_ids
        )

    async def retrieve_similar_sections(self, query: str, document_id: Optional[str] = None, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Retrieve similar document sections for a query.

        Args:
            query: The query text to find similar sections for
            document_id: Optional document ID to restrict search
            top_k: Number of results to return

        Returns:
            List of similar sections with metadata. Each entry has the keys
            ``section_id``, ``document_id``, ``section_name``, ``content``
            and ``score``. For Pinecone (and the mock) ``score`` is the
            match score; for Weaviate it is the reported ``distance``.
            NOTE(review): a distance is presumably "lower is closer" -
            confirm callers do not compare it with Pinecone scores.
        """
        # Generate embedding for query
        query_embedding = self.cohere_client.embed(
            texts=[query],
            model=self.embedding_model,
            input_type="search_query"
        ).embeddings[0]

        backend = self._active_backend()

        # Search vector database
        if backend in ("pinecone", "mock"):
            results = self.vector_db_client.query(
                vector=query_embedding,
                top_k=top_k,
                # Sections are stored in a namespace named after their document.
                namespace=document_id if document_id else None,
                include_metadata=True
            )

            # Tolerate a result object without ``matches`` (e.g. a bare empty
            # list) by treating it as "nothing found".
            matches = getattr(results, "matches", None) or []
            return [
                {
                    "section_id": match.id,
                    "document_id": match.metadata["document_id"],
                    "section_name": match.metadata["section_name"],
                    "content": match.metadata.get("content", ""),
                    "score": match.score
                }
                for match in matches
            ]

        if backend == "weaviate":
            query_builder = (
                self.vector_db_client.query.get(
                    self._WEAVIATE_CLASS, ["content", "document_id", "section_name"]
                )
                .with_near_vector({"vector": query_embedding})
                .with_limit(top_k)
                # ``_additional`` fields are only returned when requested.
                .with_additional(["id", "distance"])
            )

            if document_id:
                query_builder = query_builder.with_where({
                    "path": ["document_id"],
                    "operator": "Equal",
                    "valueString": document_id
                })

            results = query_builder.do() or {}

            # Any level of the response may be missing or None; treat that
            # as an empty result set.
            data = results.get("data") or {}
            items = (data.get("Get") or {}).get(self._WEAVIATE_CLASS) or []
            similar_sections = []
            for item in items:
                additional = item.get("_additional") or {}
                similar_sections.append({
                    "section_id": additional.get("id"),
                    "document_id": item.get("document_id"),
                    "section_name": item.get("section_name"),
                    "content": item.get("content", ""),
                    "score": additional.get("distance")
                })
            return similar_sections

        # Unknown backend: nothing to search
        return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MockVectorDB:
    """Mock vector database for development without actual vector DB.

    Keeps only the metadata of upserted vectors, in memory, grouped by
    namespace. Vector values are ignored and no similarity is computed.
    """

    def __init__(self):
        """Create an empty store and warn that the mock is in use."""
        # namespace -> {vector id -> metadata dict}
        self.vectors = {}
        logger.warning("Using mock vector database. Not suitable for production.")

    def upsert(self, vectors, namespace=None):
        """Mock upsert method.

        Args:
            vectors: Iterable of dicts with an ``id`` key and an optional
                ``metadata`` key; any other keys are ignored.
            namespace: Namespace to store into; ``None`` or empty means
                ``"default"``.
        """
        bucket = self.vectors.setdefault(namespace or "default", {})
        for vector in vectors:
            # A vector without metadata is stored with an empty dict.
            bucket[vector['id']] = vector.get('metadata', {})

    def query(self, vector, top_k=5, namespace=None, include_metadata=True):
        """Mock query method.

        Args:
            vector: Query vector (ignored).
            top_k: Maximum number of matches to return; values below 1
                yield no matches.
            namespace: Namespace to read from; ``None`` or empty means
                ``"default"``.
            include_metadata: Accepted for call compatibility; stored
                metadata is always returned.

        Returns:
            An object with a ``matches`` list. Each match has ``id``,
            ``score`` (always 0.8) and ``metadata``, in insertion order. An
            unknown namespace yields an empty ``matches`` list, so callers
            can always read ``.matches``.
        """
        from collections import namedtuple

        Match = namedtuple('Match', ['id', 'score', 'metadata'])
        Results = namedtuple('Results', ['matches'])

        stored = self.vectors.get(namespace or "default", {})
        # Clamp so a negative top_k cannot slice from the end.
        limit = max(top_k, 0)

        # Just return some mock results
        matches = [
            Match(id=vector_id, score=0.8, metadata=metadata)
            for vector_id, metadata in list(stored.items())[:limit]
        ]
        return Results(matches=matches)
|