Initial commit
This commit is contained in:
@@ -0,0 +1,254 @@
|
||||
import cohere
|
||||
from typing import List, Dict, Any, Optional
|
||||
import uuid
|
||||
from pinecone import Pinecone
|
||||
import weaviate
|
||||
from loguru import logger
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.models import DocumentEmbedding
|
||||
|
||||
class EmbeddingService:
    """Service for document embedding and vector database operations.

    Text is embedded with Cohere and the resulting vectors are stored in, and
    queried from, the backend selected by ``settings.VECTOR_DB`` ("pinecone"
    or "weaviate"). When that backend is not fully configured a
    ``MockVectorDB`` is used instead.

    NOTE(review): the public methods are ``async`` but call the synchronous
    Cohere / vector-DB clients directly, so they block the event loop for the
    duration of each network call.
    """

    # Name of the Weaviate class that holds one object per document section.
    _WEAVIATE_CLASS = "Document"

    def __init__(self):
        """Initialize the embedding service with the Cohere client and vector DB."""
        self.cohere_client = cohere.Client(settings.COHERE_API_KEY)

        # Initialize vector database client based on configuration
        self.vector_db_client = self._init_vector_db()
        self.embedding_model = settings.EMBEDDING_MODEL

    # ------------------------------------------------------------------
    # Backend initialisation
    # ------------------------------------------------------------------
    def _init_vector_db(self) -> Any:
        """Initialize the vector database client based on settings.

        Returns:
            A Pinecone index, a Weaviate client, or a ``MockVectorDB`` when
            the selected backend lacks its API key / URL (or is unknown).
        """
        if settings.VECTOR_DB == "pinecone" and settings.PINECONE_API_KEY:
            return self._init_pinecone()
        if settings.VECTOR_DB == "weaviate" and settings.WEAVIATE_URL:
            return self._init_weaviate()

        logger.warning("No valid vector database configuration found. Using mock implementation.")
        return MockVectorDB()

    def _init_pinecone(self) -> Any:
        """Return the configured Pinecone index, creating it if it is missing."""
        # Local import: ``ServerlessSpec`` ships with the already-imported
        # ``pinecone`` package and is only needed when an index is created.
        from pinecone import ServerlessSpec

        pc = Pinecone(api_key=settings.PINECONE_API_KEY)
        index_name = settings.PINECONE_INDEX_NAME

        if index_name not in [idx["name"] for idx in pc.list_indexes()]:
            pc.create_index(
                name=index_name,
                # 1024 is the Cohere embed-english-v3.0 dimension; an optional
                # EMBEDDING_DIMENSION setting overrides it for other models.
                dimension=getattr(settings, "EMBEDDING_DIMENSION", 1024),
                metric="cosine",
                # ``spec`` is mandatory with the ``Pinecone`` class client.
                # NOTE(review): cloud/region fall back to aws/us-east-1 when
                # the optional settings are absent - confirm for deployment.
                spec=ServerlessSpec(
                    cloud=getattr(settings, "PINECONE_CLOUD", "aws"),
                    region=getattr(settings, "PINECONE_REGION", "us-east-1"),
                ),
            )

        return pc.Index(index_name)

    def _init_weaviate(self) -> Any:
        """Return a Weaviate client, creating the ``Document`` class if missing."""
        auth_config = None
        if settings.WEAVIATE_API_KEY:
            auth_config = weaviate.auth.AuthApiKey(api_key=settings.WEAVIATE_API_KEY)

        client = weaviate.Client(
            url=settings.WEAVIATE_URL,
            auth_client_secret=auth_config,
        )

        # ``schema.contains()`` returns a bool, so inspect the schema dict
        # instead and look for our class specifically rather than "any class".
        schema = client.schema.get() or {}
        existing = {cls.get("class") for cls in schema.get("classes") or []}
        if self._WEAVIATE_CLASS not in existing:
            client.schema.create_class({
                "class": self._WEAVIATE_CLASS,
                "vectorizer": "none",  # We'll provide our own vectors
                "properties": [
                    {"name": "content", "dataType": ["text"]},
                    {"name": "document_id", "dataType": ["string"]},
                    {"name": "section_name", "dataType": ["string"]},
                ],
            })
        return client

    def _active_backend(self) -> str:
        """Return which client API to speak: "mock", or ``settings.VECTOR_DB``.

        Dispatching on the client that was actually built (instead of on the
        setting alone) keeps a mis-configured "weaviate"/"pinecone" setting
        from driving the mock through the wrong interface.
        """
        if isinstance(self.vector_db_client, MockVectorDB):
            return "mock"
        return settings.VECTOR_DB

    def _embed_text(self, text: str, input_type: str) -> List[float]:
        """Embed a single text with Cohere and return its vector."""
        response = self.cohere_client.embed(
            texts=[text],
            model=self.embedding_model,
            input_type=input_type,
        )
        return response.embeddings[0]

    # ------------------------------------------------------------------
    # Indexing
    # ------------------------------------------------------------------
    def _store_section(
        self,
        backend: str,
        document_id: str,
        section_name: str,
        content: str,
        embedding_vector: List[float],
    ) -> str:
        """Persist one section vector and return the ID it was stored under."""
        if backend == "weaviate":
            # Weaviate only accepts real UUIDs as object IDs.
            section_id = str(uuid.uuid4())
            self.vector_db_client.data_object.create(
                class_name=self._WEAVIATE_CLASS,
                data_object={
                    "content": content,
                    "document_id": document_id,
                    "section_name": section_name,
                },
                uuid=section_id,
                vector=embedding_vector,
            )
            return section_id

        section_id = f"{document_id}_{section_name}_{str(uuid.uuid4())[:8]}"
        if backend in ("pinecone", "mock"):
            # The mock exposes the same ``upsert`` signature as a Pinecone index.
            self.vector_db_client.upsert(
                vectors=[{
                    "id": section_id,
                    "values": embedding_vector,
                    "metadata": {
                        "document_id": document_id,
                        "section_name": section_name,
                        "content": content[:1000],  # Store truncated content for context
                    },
                }],
                namespace=document_id,
            )
        return section_id

    async def embed_document(self, document_id: str, sections: Dict[str, str]) -> DocumentEmbedding:
        """
        Embed document sections and store in vector database.

        Args:
            document_id: Unique identifier for the document
            sections: Dictionary mapping section names to section content

        Returns:
            DocumentEmbedding object with embedding metadata; its ``sections``
            maps each section name to the ID the vector was stored under.

        Raises:
            Exception: Any error from Cohere or the vector database is logged
                and re-raised; sections stored before the failure stay stored.
        """
        backend = self._active_backend()
        section_ids = {}

        for section_name, content in sections.items():
            try:
                embedding_vector = self._embed_text(content, "search_document")
                section_ids[section_name] = self._store_section(
                    backend, document_id, section_name, content, embedding_vector
                )
                logger.info(f"Successfully embedded section '{section_name}' for document {document_id}")
            except Exception as e:
                logger.error(f"Error embedding section '{section_name}': {str(e)}")
                raise

        return DocumentEmbedding(
            embedding_id=str(uuid.uuid4()),
            embedding_model=self.embedding_model,
            vector_db=settings.VECTOR_DB,
            sections=section_ids,
        )

    # ------------------------------------------------------------------
    # Retrieval
    # ------------------------------------------------------------------
    def _query_pinecone_like(
        self, query_embedding: List[float], document_id: Optional[str], top_k: int
    ) -> List[Dict[str, Any]]:
        """Query a Pinecone index (or the mock) and normalise the matches.

        NOTE(review): vectors are upserted into a per-document namespace, so a
        query without ``document_id`` only searches the default namespace.
        """
        results = self.vector_db_client.query(
            vector=query_embedding,
            top_k=top_k,
            namespace=document_id if document_id else None,
            include_metadata=True,
        )

        similar_sections = []
        # Tolerate a response object without ``matches`` (e.g. a bare list).
        for match in getattr(results, "matches", None) or []:
            metadata = match.metadata or {}
            similar_sections.append({
                "section_id": match.id,
                "document_id": metadata.get("document_id"),
                "section_name": metadata.get("section_name"),
                "content": metadata.get("content", ""),
                "score": match.score,
            })
        return similar_sections

    def _query_weaviate(
        self, query_embedding: List[float], document_id: Optional[str], top_k: int
    ) -> List[Dict[str, Any]]:
        """Run a near-vector search in Weaviate and normalise the hits."""
        query_builder = (
            self.vector_db_client.query
            .get(self._WEAVIATE_CLASS, ["content", "document_id", "section_name"])
            .with_near_vector({"vector": query_embedding})
            # Without this the response carries no ``_additional`` block and
            # both the ID and the score come back as None.
            .with_additional(["id", "distance"])
            .with_limit(top_k)
        )

        if document_id:
            query_builder = query_builder.with_where({
                "path": ["document_id"],
                "operator": "Equal",
                "valueString": document_id,
            })

        results = query_builder.do() or {}
        # On GraphQL errors "data" may be present but null; treat as no hits.
        hits = ((results.get("data") or {}).get("Get") or {}).get(self._WEAVIATE_CLASS) or []

        similar_sections = []
        for item in hits:
            additional = item.get("_additional") or {}
            similar_sections.append({
                "section_id": additional.get("id"),
                "document_id": item.get("document_id"),
                "section_name": item.get("section_name"),
                "content": item.get("content", ""),
                # This is the distance reported by Weaviate, not a similarity.
                "score": additional.get("distance"),
            })
        return similar_sections

    async def retrieve_similar_sections(self, query: str, document_id: Optional[str] = None, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Retrieve similar document sections for a query.

        Args:
            query: The query text to find similar sections for
            document_id: Optional document ID to restrict search
            top_k: Number of results to return

        Returns:
            List of dicts with keys ``section_id``, ``document_id``,
            ``section_name``, ``content`` and ``score``. Empty when the
            configured backend is not recognised.
        """
        query_embedding = self._embed_text(query, "search_query")

        backend = self._active_backend()
        if backend in ("pinecone", "mock"):
            return self._query_pinecone_like(query_embedding, document_id, top_k)
        if backend == "weaviate":
            return self._query_weaviate(query_embedding, document_id, top_k)
        return []
|
||||
|
||||
|
||||
class MockVectorDB:
    """Mock vector database for development without actual vector DB.

    Exposes ``upsert`` and ``query`` with the same keyword arguments that
    EmbeddingService passes to a Pinecone index. Only the metadata of each
    record is kept; vector values are discarded, so ``query`` does not perform
    any similarity search.
    """

    def __init__(self):
        # Maps namespace -> {vector_id: metadata}.
        self.vectors = {}
        logger.warning("Using mock vector database. Not suitable for production.")

    def upsert(self, vectors, namespace=None):
        """Mock upsert method.

        Args:
            vectors: Iterable of dicts with an ``id`` key and an optional
                ``metadata`` key; any other keys are ignored.
            namespace: Bucket to store into; falsy values map to "default".
        """
        bucket = self.vectors.setdefault(namespace or "default", {})
        for vector in vectors:
            # Metadata is optional on a record; store an empty dict if absent.
            bucket[vector['id']] = vector.get('metadata', {})

    def query(self, vector, top_k=5, namespace=None, include_metadata=True):
        """Mock query method.

        Args:
            vector: Ignored; no similarity is computed.
            top_k: Maximum number of stored records to return.
            namespace: Bucket to read from; falsy values map to "default".
            include_metadata: Accepted for signature compatibility; metadata
                is always included.

        Returns:
            A ``Results`` namedtuple whose ``matches`` is a list of ``Match``
            namedtuples (``id``, ``score``, ``metadata``) in insertion order,
            each with a fixed score of 0.8. ``matches`` is empty for an
            unknown namespace, so callers can always read ``.matches``.
        """
        from collections import namedtuple

        Match = namedtuple('Match', ['id', 'score', 'metadata'])
        Results = namedtuple('Results', ['matches'])

        stored = self.vectors.get(namespace or "default", {})

        # Just return some mock results
        matches = [
            Match(id=vector_id, score=0.8, metadata=metadata)
            for vector_id, metadata in list(stored.items())[:top_k]
        ]
        return Results(matches=matches)
|
||||
Reference in New Issue
Block a user