feat: Initial SCP project setup with AI-powered document compliance tools
This commit is contained in:
@@ -0,0 +1,24 @@
|
||||
from dataclasses import dataclass
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@dataclass
class Settings:
    """Configuration values shared by the document-compliance services.

    Credentials are looked up in the process environment when this class
    body is evaluated; a variable that is not set yields an empty string.
    The remaining fields are fixed defaults.
    """

    # --- Credentials, taken from the environment ---
    COHERE_API_KEY: str = os.environ.get("COHERE_API_KEY", "")
    DEEPSEEK_API_KEY: str = os.environ.get("DEEPSEEK_API_KEY", "")
    PINECONE_API_KEY: str = os.environ.get("PINECONE_API_KEY", "")
    PINECONE_ENVIRONMENT: str = os.environ.get("PINECONE_ENVIRONMENT", "")

    # --- Vector database ---
    PINECONE_INDEX_NAME: str = "document-compliance"

    # --- Models ---
    COHERE_EMBEDDING_MODEL: str = "embed-english-v3.0"
    COHERE_RERANKER_MODEL: str = "rerank-english-v2.0"
    DEEPSEEK_MODEL: str = "deepseek-r1"
    # Kept in step with the size of the Cohere embedding vectors.
    VECTOR_DIMENSION: int = 1024


# Module-level instance imported by the other services.
config = Settings()
|
||||
@@ -0,0 +1,162 @@
|
||||
import sqlite3
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
import os
|
||||
|
||||
class Database:
    """SQLite-backed store for document metadata and analysis results.

    Two tables keyed by ``document_id`` are maintained: ``metadata``
    (filename, document type, description) and ``analysis`` (summary plus
    JSON-encoded issues and recommendations). Every public method opens its
    own short-lived connection, logs any failure and re-raises it.
    """

    def __init__(self, db_path: str = "data/app.db"):
        """Record the database location and make sure the schema exists.

        Args:
            db_path: Path of the SQLite file. Missing parent directories are
                created; a bare filename (no directory part) is accepted.

        Raises:
            Exception: Whatever directory creation or schema setup raises.
        """
        self.db_path = db_path
        parent_dir = os.path.dirname(db_path)
        # os.makedirs("") raises FileNotFoundError, so only create a
        # directory when the path actually contains one.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        self._init_db()

    def _connect(self):
        """Return a context manager yielding a connection that is closed on exit.

        ``with sqlite3.connect(...)`` alone only scopes the transaction and
        leaves the connection open, so the connection is wrapped in
        ``contextlib.closing``. Writers must call ``commit()`` themselves.
        """
        from contextlib import closing

        return closing(sqlite3.connect(self.db_path))

    def _init_db(self):
        """Initialize the database with required tables."""
        try:
            with self._connect() as conn:
                # Create analysis table
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS analysis (
                        document_id TEXT PRIMARY KEY,
                        summary TEXT,
                        issues TEXT,
                        recommendations TEXT,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                ''')

                # Create metadata table
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS metadata (
                        document_id TEXT PRIMARY KEY,
                        filename TEXT,
                        document_type TEXT,
                        description TEXT,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                ''')

                conn.commit()
        except Exception as e:
            logging.error(f"Error initializing database: {str(e)}")
            raise

    def save_analysis(self, document_id: str, analysis: Dict[str, Any]):
        """Insert or replace the analysis row for ``document_id``.

        Args:
            document_id: Key of the document the analysis belongs to.
            analysis: Mapping with ``summary``, ``issues`` and
                ``recommendations``; the latter two are stored as JSON text.

        Raises:
            KeyError: If one of the three keys is missing.
            Exception: Any JSON or SQLite error (logged, then re-raised).
        """
        try:
            with self._connect() as conn:
                conn.execute('''
                    INSERT OR REPLACE INTO analysis (document_id, summary, issues, recommendations)
                    VALUES (?, ?, ?, ?)
                ''', (
                    document_id,
                    analysis['summary'],
                    json.dumps(analysis['issues']),
                    json.dumps(analysis['recommendations']),
                ))
                conn.commit()
        except Exception as e:
            logging.error(f"Error saving analysis for document {document_id}: {str(e)}")
            raise

    def get_analysis(self, document_id: str) -> Dict[str, Any]:
        """Return the stored analysis for ``document_id``.

        Returns:
            Dict with ``document_id``, ``summary``, and the decoded
            ``issues`` and ``recommendations``.

        Raises:
            FileNotFoundError: If no analysis row exists for the document.
            Exception: Any SQLite or JSON error (logged, then re-raised).
        """
        try:
            with self._connect() as conn:
                row = conn.execute(
                    'SELECT summary, issues, recommendations FROM analysis WHERE document_id = ?',
                    (document_id,),
                ).fetchone()

            if not row:
                raise FileNotFoundError(f"Analysis not found for document {document_id}")

            summary, issues_json, recommendations_json = row
            return {
                'document_id': document_id,
                'summary': summary,
                'issues': json.loads(issues_json),
                'recommendations': json.loads(recommendations_json),
            }
        except Exception as e:
            logging.error(f"Error retrieving analysis for document {document_id}: {str(e)}")
            raise

    def save_metadata(self, document_id: str, metadata: Dict[str, Any]):
        """Insert or replace the metadata row for ``document_id``.

        Args:
            document_id: Key of the document.
            metadata: Mapping with required ``filename`` and
                ``document_type`` and an optional ``description``.

        Raises:
            KeyError: If ``filename`` or ``document_type`` is missing.
            Exception: Any SQLite error (logged, then re-raised).
        """
        try:
            with self._connect() as conn:
                conn.execute('''
                    INSERT OR REPLACE INTO metadata (document_id, filename, document_type, description)
                    VALUES (?, ?, ?, ?)
                ''', (
                    document_id,
                    metadata['filename'],
                    metadata['document_type'],
                    metadata.get('description'),
                ))
                conn.commit()
        except Exception as e:
            logging.error(f"Error saving metadata for document {document_id}: {str(e)}")
            raise

    def get_metadata(self, document_id: str) -> Dict[str, Any]:
        """Return the stored metadata for ``document_id``.

        Returns:
            Dict with ``document_id``, ``filename``, ``document_type`` and
            ``description``.

        Raises:
            FileNotFoundError: If no metadata row exists for the document.
            Exception: Any SQLite error (logged, then re-raised).
        """
        try:
            with self._connect() as conn:
                row = conn.execute(
                    'SELECT filename, document_type, description FROM metadata WHERE document_id = ?',
                    (document_id,),
                ).fetchone()

            if not row:
                raise FileNotFoundError(f"Metadata not found for document {document_id}")

            filename, document_type, description = row
            return {
                'document_id': document_id,
                'filename': filename,
                'document_type': document_type,
                'description': description,
            }
        except Exception as e:
            logging.error(f"Error retrieving metadata for document {document_id}: {str(e)}")
            raise

    def get_all_metadata(self) -> list:
        """Return metadata for every document, newest first.

        Each entry carries ``document_id``, ``filename``, ``document_type``,
        ``description``, ``upload_date`` (the row's ``created_at``) and
        ``status``: ``'completed'`` when an analysis row exists for the
        document, ``'processing'`` otherwise.

        Raises:
            Exception: Any SQLite error (logged, then re-raised).
        """
        try:
            with self._connect() as conn:
                rows = conn.execute('''
                    SELECT m.document_id, m.filename, m.document_type, m.description, m.created_at,
                           CASE WHEN a.document_id IS NOT NULL THEN 1 ELSE 0 END as has_analysis
                    FROM metadata m
                    LEFT JOIN analysis a ON m.document_id = a.document_id
                    ORDER BY m.created_at DESC
                ''').fetchall()

            return [{
                'document_id': row[0],
                'filename': row[1],
                'document_type': row[2],
                'description': row[3],
                'upload_date': row[4],
                'status': 'completed' if row[5] == 1 else 'processing',
            } for row in rows]
        except Exception as e:
            logging.error(f"Error retrieving all metadata: {str(e)}")
            raise

    def delete_document(self, document_id: str):
        """Delete the analysis and metadata rows of ``document_id``.

        Deleting a document that has no rows is not an error.

        Raises:
            Exception: Any SQLite error (logged, then re-raised).
        """
        try:
            with self._connect() as conn:
                conn.execute('DELETE FROM analysis WHERE document_id = ?', (document_id,))
                conn.execute('DELETE FROM metadata WHERE document_id = ?', (document_id,))
                conn.commit()
        except Exception as e:
            logging.error(f"Error deleting document {document_id}: {str(e)}")
            raise
|
||||
@@ -0,0 +1,248 @@
|
||||
import cohere
|
||||
import requests
|
||||
from typing import List, Dict, Any
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from services.config import config
|
||||
from services.database import Database
|
||||
|
||||
class DocumentProcessor:
    """Run the compliance-analysis pipeline for a document.

    The pipeline embeds the document with Cohere, stores the vector, asks
    DeepSeek for a section summary and a compliance analysis, orders the
    extracted issues with the Cohere reranker, asks DeepSeek for a
    recommendation per top issue and saves the result in the database.
    """

    # Model name sent with every DeepSeek chat-completion request.
    _DEEPSEEK_CHAT_MODEL = "deepseek-chat"
    # Seconds to wait for a DeepSeek response before giving up.
    _REQUEST_TIMEOUT_SECONDS = 60

    def __init__(self, vector_store):
        """Create the API clients and the database handle.

        Args:
            vector_store: Object providing ``store_embedding`` and
                ``add_document``.
        """
        self.vector_store = vector_store
        self.cohere_client = cohere.Client(config.COHERE_API_KEY)
        self.deepseek_url = "https://api.deepseek.com/v1/chat/completions"
        self.deepseek_headers = {
            "Authorization": f"Bearer {config.DEEPSEEK_API_KEY}",
            "Content-Type": "application/json"
        }
        self.database = Database()

    async def process_document(self, doc_id: str, file_path: str, document_type: str, is_resubmission: bool = False):
        """Analyse one document and persist the result.

        A failing DeepSeek call does not abort processing: the summary or
        issue list is replaced by a fixed fallback sentence instead. The
        network calls made here are synchronous.

        Args:
            doc_id: Identifier under which vector and analysis are stored.
            file_path: Path of the text file to analyse.
            document_type: Type label inserted into the analysis prompt.
            is_resubmission: When true, the stored metadata's document type
                is refreshed after the analysis has been saved.

        Returns:
            ``True`` when processing finished.

        Raises:
            Exception: Any error outside the tolerated DeepSeek, reranker
                and metadata-refresh failures (logged, then re-raised).
        """
        try:
            content = self._read_text(file_path)
            logging.info(f"Processing document {doc_id} with content length: {len(content)}")

            # Generate embeddings; input_type is required by the model.
            embeddings = self.cohere_client.embed(
                texts=[content],
                model=config.COHERE_EMBEDDING_MODEL,
                input_type="search_document"
            ).embeddings[0]

            # Store in vector database
            self.vector_store.store_embedding(doc_id, embeddings, content)

            # First pass: extract key sections and requirements.
            summary, failure = self._call_deepseek(
                "You are a document analysis assistant. Extract key sections and requirements from the following document.",
                content,
                4000,
            )
            if failure is not None:
                summary = f"Document analysis could not be completed due to {failure}."

            # Second pass: compliance analysis, given the first-pass summary.
            analysis_prompt = (
                f"Analyze this type of document {document_type} for compliance issues "
                f"and provide detailed feedback:\n\n{content}\n"
                f"and these are the main sections of the document:\n\n{summary}"
            )
            analysis_text, failure = self._call_deepseek(
                "You are an expert in document compliance analysis. Analyze the following document for compliance issues and provide detailed feedback.",
                analysis_prompt,
                4000,
            )
            if failure is None:
                issues = self._extract_issues(analysis_text)
            else:
                issues = [f"Document analysis could not be completed due to {failure}."]

            # Use Cohere reranker to prioritize issues
            ranked_issues = self._rank_issues(issues)

            analysis = {
                "document_id": doc_id,
                "summary": summary,
                "issues": self._format_issues(ranked_issues),
                "recommendations": self._generate_recommendations(ranked_issues)
            }
            self.database.save_analysis(doc_id, analysis)

            if is_resubmission:
                self._refresh_metadata(doc_id, document_type)

            logging.info(f"Document {doc_id} processed successfully")
            return True
        except Exception as e:
            logging.error(f"Error processing document {doc_id}: {str(e)}")
            raise

    async def get_analysis(self, doc_id: str) -> Dict[str, Any]:
        """Return the stored analysis for ``doc_id`` from the database."""
        return self.database.get_analysis(doc_id)

    @staticmethod
    def _read_text(file_path: str) -> str:
        """Read a file as UTF-8, retrying as latin-1 if decoding fails."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()
        except UnicodeDecodeError:
            # latin-1 maps every byte to a character, so this cannot fail
            # on decoding.
            with open(file_path, 'r', encoding='latin-1') as f:
                return f.read()

    def _call_deepseek(self, system_prompt: str, user_prompt: str, max_tokens: int) -> tuple:
        """Send one chat-completion request to DeepSeek.

        Returns:
            ``(content, None)`` on success, otherwise ``(None, reason)``
            where ``reason`` is ``"API connection issues"``,
            ``"API limitations"`` (non-200 status) or ``"parsing errors"``.
            Failures are logged, never raised.
        """
        payload = {
            "model": self._DEEPSEEK_CHAT_MODEL,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "max_tokens": max_tokens
        }

        try:
            response = requests.post(
                self.deepseek_url,
                json=payload,
                headers=self.deepseek_headers,
                timeout=self._REQUEST_TIMEOUT_SECONDS,
            )
        except requests.exceptions.RequestException as e:
            logging.error(f"Error calling DeepSeek API: {str(e)}")
            return None, "API connection issues"

        if response.status_code != 200:
            logging.error(f"DeepSeek API error: {response.status_code} - {response.text}")
            return None, "API limitations"

        try:
            return response.json()['choices'][0]['message']['content'], None
        except (ValueError, KeyError, IndexError, TypeError) as e:
            # ValueError covers JSON decoding errors; IndexError covers an
            # empty "choices" list.
            logging.error(f"Error parsing DeepSeek response: {str(e)}")
            logging.error(f"Response text: {response.text}")
            return None, "parsing errors"

    def _rank_issues(self, issues: List[str]) -> List[str]:
        """Return ``issues`` ordered by the Cohere reranker.

        Each result is mapped back to its issue text through the result's
        ``index``. If the reranker call fails the original order is kept.
        """
        if not issues:
            return []
        try:
            response = self.cohere_client.rerank(
                query="Compliance issues in technical document",
                documents=issues,
                model=config.COHERE_RERANKER_MODEL
            )
            return [issues[result.index] for result in self._unwrap_results(response)]
        except Exception as e:
            logging.error(f"Error using Cohere reranker: {str(e)}")
            return list(issues)

    def _refresh_metadata(self, doc_id: str, document_type: str) -> None:
        """Update the stored document type after a resubmission.

        Failures are logged and swallowed so that a metadata problem does
        not fail an analysis that has already been saved.
        """
        try:
            existing_metadata = self.database.get_metadata(doc_id)
            if document_type:
                existing_metadata["document_type"] = document_type
            self.database.save_metadata(doc_id, existing_metadata)
            logging.info(f"Updated metadata for resubmitted document {doc_id}")
        except Exception as e:
            logging.error(f"Error updating metadata for resubmitted document {doc_id}: {str(e)}")

    def _extract_issues(self, deepseek_response: str) -> List[str]:
        """Split a numbered-list response into its individual entries.

        The text is cut at every ``<digits>.`` marker; empty fragments are
        dropped and the rest stripped of surrounding whitespace.
        """
        # Imported locally: the module's import block does not provide re.
        import re

        logging.debug("DeepSeek analysis response: %s", deepseek_response)
        return [issue.strip() for issue in re.split(r'\d+\.', deepseek_response) if issue.strip()]

    @staticmethod
    def _unwrap_results(reranked_issues):
        """Return the ``results`` attribute of a rerank response, or the argument itself."""
        return reranked_issues.results if hasattr(reranked_issues, 'results') else reranked_issues

    @staticmethod
    def _issue_text(item) -> str:
        """Return the issue text of one ranked entry.

        Accepts a plain string, a tuple (first element is used) or an
        object whose ``document`` is a string, a mapping with a ``text``
        key or an object with a ``text`` attribute.
        """
        if isinstance(item, str):
            return item
        if isinstance(item, tuple):
            return item[0]
        document = getattr(item, 'document', None)
        if isinstance(document, str):
            return document
        if isinstance(document, dict):
            return document.get('text', '')
        text = getattr(document, 'text', None)
        return text if text is not None else str(item)

    def _format_issues(self, reranked_issues) -> List[Dict[str, Any]]:
        """Turn ranked issues into dicts with ``issue``, ``severity`` and ``rank``.

        Severity depends on position only: the first three entries are
        ``"high"``, the next three ``"medium"``, all later ones ``"low"``.
        """
        return [
            {
                "issue": self._issue_text(issue),
                "severity": "high" if i < 3 else "medium" if i < 6 else "low",
                "rank": i + 1
            }
            for i, issue in enumerate(self._unwrap_results(reranked_issues))
        ]

    def _generate_recommendations(self, reranked_issues) -> List[str]:
        """Ask DeepSeek for one recommendation per issue, top five only.

        A failed request contributes a fallback sentence naming the reason
        instead of raising, so the result has one entry per issue handled.
        """
        recommendations = []
        top_issues = list(self._unwrap_results(reranked_issues))[:5]
        logging.debug("Generating recommendations for %d issues", len(top_issues))

        for item in top_issues:
            issue_text = self._issue_text(item)
            recommendation, failure = self._call_deepseek(
                "You are an expert in document compliance. Provide specific, actionable recommendations to fix compliance issues.",
                f"Provide a specific, actionable recommendation to fix this compliance issue: {issue_text}",
                1000,
            )
            if failure is None:
                recommendations.append(recommendation)
            else:
                recommendations.append(f"Recommendation could not be generated due to {failure}.")

        return recommendations

    def _store_document(self, doc_id: str, file_path: str):
        """Hand the document file to the vector store."""
        self.vector_store.add_document(doc_id, file_path)
|
||||
@@ -0,0 +1,57 @@
|
||||
import cohere
|
||||
from typing import List, Union
|
||||
from services.config import config
|
||||
|
||||
class EmbeddingService:
    """Create text embeddings with the Cohere embedding model from config."""

    def __init__(self):
        """Create the Cohere client and remember the embedding model name."""
        self.cohere_client = cohere.Client(config.COHERE_API_KEY)
        self.model = config.COHERE_EMBEDDING_MODEL

    def create_embedding(self, text: str) -> List[float]:
        """
        Create an embedding for a single text using Cohere.

        Args:
            text (str): The text to create an embedding for

        Returns:
            List[float]: The embedding vector
        """
        return self.create_embeddings([text])[0]

    def create_embeddings(self, texts: List[str]) -> List[List[float]]:
        """
        Create embeddings for multiple texts using Cohere.

        An empty list returns an empty list without calling the API.

        Args:
            texts (List[str]): List of texts to create embeddings for

        Returns:
            List[List[float]]: List of embedding vectors, one per text
        """
        if not texts:
            return []
        # No dimension argument is passed: the call mirrors the
        # single-text request, which never sent one.
        response = self.cohere_client.embed(
            texts=texts,
            model=self.model,
            input_type="search_document"
        )
        return response.embeddings

    def create_embedding_from_file(self, file_path: str) -> List[float]:
        """
        Create an embedding from a file's contents.

        The file is read as UTF-8; if that fails to decode it is read again
        as latin-1.

        Args:
            file_path (str): Path to the file to create an embedding for

        Returns:
            List[float]: The embedding vector
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except UnicodeDecodeError:
            with open(file_path, 'r', encoding='latin-1') as f:
                content = f.read()
        return self.create_embedding(content)
|
||||
@@ -0,0 +1,139 @@
|
||||
from pinecone import Pinecone, ServerlessSpec
|
||||
from typing import List, Any, Optional
|
||||
from services.config import config
|
||||
from services.embedding_service import EmbeddingService
|
||||
import logging
|
||||
import os
|
||||
|
||||
class VectorStore:
    """Wrapper around one Pinecone index holding document embeddings."""

    def __init__(self, pinecone_client: Optional[Pinecone] = None, embedding_service: Optional[EmbeddingService] = None):
        """Connect to Pinecone and make sure the configured index is usable.

        Args:
            pinecone_client: Client to use; one is built from the configured
                API key when omitted.
            embedding_service: Service used by ``add_document``; a default
                instance is created when omitted.
        """
        self.pinecone = pinecone_client or Pinecone(api_key=config.PINECONE_API_KEY)
        self.index_name = config.PINECONE_INDEX_NAME
        self.embedding_service = embedding_service or EmbeddingService()
        self._ensure_index()

    def _create_index(self):
        """Create the configured index with the configured dimension."""
        self.pinecone.create_index(
            name=self.index_name,
            dimension=config.VECTOR_DIMENSION,
            metric="cosine",
            spec=ServerlessSpec(cloud="aws", region="us-east-1")
        )

    def _ensure_index(self):
        """Create the index when it is missing, then bind and verify it."""
        try:
            known_indexes = self.pinecone.list_indexes().names()
            if self.index_name not in known_indexes:
                self._create_index()
                logging.info(f"Created new index '{self.index_name}' with dimension {config.VECTOR_DIMENSION}")

            self.index = self.pinecone.Index(self.index_name)

            # An existing index may have been created with another dimension.
            self._check_index_dimension()
        except Exception as e:
            logging.error(f"Error ensuring index exists: {str(e)}")
            raise

    def _check_index_dimension(self):
        """Recreate the index when its dimension differs from the config.

        NOTE(review): the mismatching index is deleted before it is
        recreated, which presumably drops every vector stored in it —
        confirm this is intended.
        """
        try:
            index_dimension = self.pinecone.describe_index(self.index_name).dimension
            if index_dimension == config.VECTOR_DIMENSION:
                return

            logging.warning(f"Index dimension {index_dimension} does not match config dimension {config.VECTOR_DIMENSION}")
            logging.info("Recreating index with correct dimension...")

            self.pinecone.delete_index(self.index_name)
            self._create_index()

            # Rebind the handle to the freshly created index.
            self.index = self.pinecone.Index(self.index_name)
            logging.info(f"Index recreated with dimension {config.VECTOR_DIMENSION}")
        except Exception as e:
            logging.error(f"Error checking index dimension: {str(e)}")
            raise

    def store_embedding(self, doc_id: str, embedding: List[float], content: str):
        """Upsert one vector, with the document text as metadata.

        Raises:
            ValueError: If the embedding length differs from the configured
                dimension.
            Exception: Any error from the upsert (logged, then re-raised).
        """
        try:
            actual_dimension = len(embedding)
            if actual_dimension != config.VECTOR_DIMENSION:
                raise ValueError(f"Embedding dimension {len(embedding)} does not match index dimension {config.VECTOR_DIMENSION}")

            record = {
                "id": doc_id,
                "values": embedding,
                "metadata": {
                    "content": content
                }
            }
            self.index.upsert(vectors=[record])
            logging.info(f"Stored embedding for document {doc_id}")
        except Exception as e:
            logging.error(f"Error storing embedding for document {doc_id}: {str(e)}")
            raise

    def search_similar(self, query_embedding: List[float], top_k: int = 5) -> List[Any]:
        """Return the ``top_k`` closest matches, metadata included.

        Raises:
            ValueError: If the query vector length differs from the
                configured dimension.
            Exception: Any error from the query (logged, then re-raised).
        """
        try:
            if len(query_embedding) != config.VECTOR_DIMENSION:
                raise ValueError(f"Query embedding dimension {len(query_embedding)} does not match index dimension {config.VECTOR_DIMENSION}")

            response = self.index.query(
                vector=query_embedding,
                top_k=top_k,
                include_metadata=True
            )
            return response.matches
        except Exception as e:
            logging.error(f"Error searching for similar documents: {str(e)}")
            raise

    def delete_document(self, doc_id: str):
        """Remove the vector stored under ``doc_id`` from the index."""
        try:
            self.index.delete(ids=[doc_id])
            logging.info(f"Deleted document {doc_id} from index")
        except Exception as e:
            logging.error(f"Error deleting document {doc_id}: {str(e)}")
            raise

    def add_document(self, doc_id: str, file_path: str):
        """Read a file, embed its text and store the vector.

        Returns:
            ``True`` once the embedding has been stored.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            Exception: Any read, embedding or storage error (logged, then
                re-raised).
        """
        try:
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"File not found: {file_path}")

            # UTF-8 first; fall back to latin-1 when decoding fails.
            try:
                with open(file_path, "r", encoding="utf-8") as handle:
                    content = handle.read()
            except UnicodeDecodeError:
                with open(file_path, "r", encoding="latin-1") as handle:
                    content = handle.read()

            embedding = self.embedding_service.create_embedding(content)

            logging.info(f"Storing embedding for document {doc_id}")
            self.store_embedding(doc_id, embedding, content)
            return True
        except Exception as e:
            logging.error(f"Error adding document {doc_id}: {str(e)}")
            raise
|
||||
Reference in New Issue
Block a user