refactor: Transition to SQLAlchemy for database management

- Replaced direct SQLite connections with the SQLAlchemy ORM for better abstraction and maintainability.
- Introduced new database models for 'Analysis' and 'Metadata' with appropriate fields.
- Enhanced database initialization and session management.
- Updated methods for saving, retrieving, and deleting analysis and metadata records to use SQLAlchemy sessions.
This commit is contained in:
boladeE
2025-04-23 14:27:15 +01:00
parent f4cb9dfa92
commit 932f76b603
2 changed files with 130 additions and 148 deletions
+1
View File
@@ -6,3 +6,4 @@ cohere==4.47
groq==0.4.2 groq==0.4.2
python-dotenv==1.0.1 python-dotenv==1.0.1
pydantic==2.6.3 pydantic==2.6.3
sqlalchemy==2.0.27
+128 -147
View File
@@ -1,116 +1,98 @@
import sqlite3
import json
import logging
from typing import Dict, Any, Optional
import os import os
from datetime import datetime
from typing import Dict, Any, List, Optional
from sqlalchemy import create_engine, Column, String, DateTime, Integer, Boolean, event, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.engine import Engine
import logging
import json
Base = declarative_base()
# Enable foreign key support for SQLite
@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
    """Enable foreign-key enforcement on each new SQLite DBAPI connection.

    The listener is attached to the ``Engine`` class itself, so it is invoked
    for every engine created in the process, not only the one built by
    ``Database``. Connections that are not ``sqlite3`` connections are left
    untouched, because ``PRAGMA`` is SQLite-specific syntax.

    Args:
        dbapi_connection: The raw DBAPI connection that was just opened.
        connection_record: Pool bookkeeping record for the connection (unused).
    """
    # Imported locally so this block does not depend on a module-level import.
    import sqlite3

    if not isinstance(dbapi_connection, sqlite3.Connection):
        return

    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA foreign_keys=ON")
    finally:
        # Release the cursor even when the PRAGMA statement fails.
        cursor.close()
class Analysis(Base):
    """ORM model for the ``analysis`` table (one row per analysed document)."""

    __tablename__ = "analysis"

    # Primary key: identifier of the document the analysis belongs to.
    document_id = Column(String, primary_key=True)
    # Free-text summary of the analysis.
    summary = Column(String)
    # Text column; Database.save_analysis writes a JSON-encoded string here.
    issues_and_recommendations = Column(String)
    # Filled in on the Python side at insert time via datetime.utcnow.
    created_at = Column(DateTime, default=datetime.utcnow)
class Metadata(Base):
    """ORM model for the ``metadata`` table (one row per uploaded document)."""

    __tablename__ = "metadata"

    # Primary key: identifier of the document being described.
    document_id = Column(String, primary_key=True)
    # Name of the file as recorded by the caller.
    filename = Column(String)
    # Caller-supplied document type label.
    document_type = Column(String)
    # Optional description; explicitly nullable.
    description = Column(String, nullable=True)
    # Filled in on the Python side at insert time via datetime.utcnow.
    created_at = Column(DateTime, default=datetime.utcnow)
class Database: class Database:
def __init__(self, db_path: str = "data/app.db"):
    """Create the SQLite engine and session factory, then initialise tables.

    Args:
        db_path: Filesystem path of the SQLite database file. Its parent
            directory is created when the path has one.

    Raises:
        Exception: Whatever ``_init_db`` re-raises when initialisation fails.
    """
    self.db_path = db_path

    # os.path.dirname() returns "" for a bare filename such as "app.db", and
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually has a parent component.
    parent_dir = os.path.dirname(db_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Configure SQLite engine with better defaults
    self.engine = create_engine(
        f'sqlite:///{db_path}',
        connect_args={
            'check_same_thread': False,  # Needed for FastAPI
            'timeout': 30,  # Set a reasonable timeout
        },
        pool_pre_ping=True,  # Check connections before using them
        pool_recycle=3600,  # Recycle connections after an hour
    )
    self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
    self._init_db()
def _init_db(self):
    """Create all mapped tables and apply the SQLite tuning PRAGMAs.

    Any failure is logged and re-raised to the caller.
    """
    # Set SQLite-specific optimizations
    tuning_statements = (
        "PRAGMA journal_mode=WAL",    # Write-Ahead Logging
        "PRAGMA synchronous=NORMAL",  # Better performance
        "PRAGMA cache_size=-2000",    # Use 2MB of memory for cache
        "PRAGMA temp_store=MEMORY",   # Store temp tables in memory
    )
    try:
        Base.metadata.create_all(bind=self.engine)
        # NOTE(review): these statements run on a single connection here;
        # confirm whether the non-persistent ones should instead be applied
        # to every new connection.
        with self.engine.connect() as conn:
            for statement in tuning_statements:
                conn.execute(text(statement))
            conn.commit()
    except Exception as e:
        logging.error(f"Error initializing database: {str(e)}")
        raise
def get_db(self):
"""Get a database session."""
db = self.SessionLocal()
try:
yield db
finally:
db.close()
def save_analysis(self, document_id: str, analysis: Dict[str, Any]):
    """Insert or update the analysis row for ``document_id``.

    Args:
        document_id: Primary key of the analysis row.
        analysis: Mapping providing ``'summary'`` and
            ``'issues_and_recommendations'``; the latter is stored as a
            JSON-encoded string.

    Raises:
        Exception: Any error (including a missing key) is logged and
            re-raised.
    """
    try:
        with self.SessionLocal() as session:
            record = (
                session.query(Analysis)
                .filter(Analysis.document_id == document_id)
                .first()
            )
            if record is None:
                # No row yet for this document: insert a fresh one.
                session.add(
                    Analysis(
                        document_id=document_id,
                        summary=analysis['summary'],
                        issues_and_recommendations=json.dumps(analysis['issues_and_recommendations']),
                    )
                )
            else:
                # Row already present: overwrite its payload in place.
                record.summary = analysis['summary']
                record.issues_and_recommendations = json.dumps(analysis['issues_and_recommendations'])
            session.commit()
    except Exception as e:
        logging.error(f"Error saving analysis for document {document_id}: {str(e)}")
        raise
@@ -118,18 +100,15 @@ class Database:
def get_analysis(self, document_id: str) -> Dict[str, Any]: def get_analysis(self, document_id: str) -> Dict[str, Any]:
"""Retrieve analysis results from the database.""" """Retrieve analysis results from the database."""
try: try:
with sqlite3.connect(self.db_path) as conn: with self.SessionLocal() as session:
cursor = conn.cursor() analysis = session.query(Analysis).filter(Analysis.document_id == document_id).first()
cursor.execute('SELECT summary, issues_and_recommendations FROM analysis WHERE document_id = ?', (document_id,)) if not analysis:
result = cursor.fetchone()
if not result:
raise FileNotFoundError(f"Analysis not found for document {document_id}") raise FileNotFoundError(f"Analysis not found for document {document_id}")
return { return {
'document_id': document_id, 'document_id': analysis.document_id,
'summary': result[0], 'summary': analysis.summary,
'issues_and_recommendations': json.loads(result[1]) 'issues_and_recommendations': json.loads(analysis.issues_and_recommendations)
} }
except Exception as e: except Exception as e:
logging.error(f"Error retrieving analysis for document {document_id}: {str(e)}") logging.error(f"Error retrieving analysis for document {document_id}: {str(e)}")
@@ -138,18 +117,24 @@ class Database:
def save_metadata(self, document_id: str, metadata: Dict[str, Any]):
    """Insert or update the metadata row for ``document_id``.

    Args:
        document_id: Primary key of the metadata row.
        metadata: Mapping providing ``'filename'`` and ``'document_type'``;
            ``'description'`` is optional and defaults to ``None``.

    Raises:
        Exception: Any error (including a missing required key) is logged
            and re-raised.
    """
    try:
        with self.SessionLocal() as session:
            record = (
                session.query(Metadata)
                .filter(Metadata.document_id == document_id)
                .first()
            )
            if record is None:
                # No row yet for this document: insert a fresh one.
                session.add(
                    Metadata(
                        document_id=document_id,
                        filename=metadata['filename'],
                        document_type=metadata['document_type'],
                        description=metadata.get('description'),
                    )
                )
            else:
                # Row already present: overwrite its fields in place.
                record.filename = metadata['filename']
                record.document_type = metadata['document_type']
                record.description = metadata.get('description')
            session.commit()
    except Exception as e:
        logging.error(f"Error saving metadata for document {document_id}: {str(e)}")
        raise
@@ -157,45 +142,42 @@ class Database:
def get_metadata(self, document_id: str) -> Dict[str, Any]:
    """Return the stored metadata for ``document_id`` as a plain dict.

    Returns:
        A dict with the keys ``document_id``, ``filename``,
        ``document_type`` and ``description``.

    Raises:
        FileNotFoundError: No metadata row exists for ``document_id``.
        Exception: Any error (the one above included) is logged and
            re-raised.
    """
    exported_fields = ('document_id', 'filename', 'document_type', 'description')
    try:
        with self.SessionLocal() as session:
            record = (
                session.query(Metadata)
                .filter(Metadata.document_id == document_id)
                .first()
            )
            if not record:
                raise FileNotFoundError(f"Metadata not found for document {document_id}")
            # Copy the column values out while the session is still open.
            return {name: getattr(record, name) for name in exported_fields}
    except Exception as e:
        logging.error(f"Error retrieving metadata for document {document_id}: {str(e)}")
        raise
def get_all_metadata(self) -> list: def get_all_metadata(self) -> List[Dict[str, Any]]:
"""Retrieve metadata for all documents.""" """Retrieve metadata for all documents."""
try: try:
with sqlite3.connect(self.db_path) as conn: with self.SessionLocal() as session:
cursor = conn.cursor() results = session.query(
cursor.execute(''' Metadata,
SELECT m.document_id, m.filename, m.document_type, m.description, m.created_at, Analysis.document_id.isnot(None).label('has_analysis')
CASE WHEN a.document_id IS NOT NULL THEN 1 ELSE 0 END as has_analysis ).outerjoin(
FROM metadata m Analysis,
LEFT JOIN analysis a ON m.document_id = a.document_id Metadata.document_id == Analysis.document_id
ORDER BY m.created_at DESC ).order_by(
''') Metadata.created_at.desc()
results = cursor.fetchall() ).all()
return [{ return [{
'document_id': row[0], 'document_id': row[0].document_id,
'filename': row[1], 'filename': row[0].filename,
'document_type': row[2], 'document_type': row[0].document_type,
'description': row[3], 'description': row[0].description,
'upload_date': row[4], 'upload_date': row[0].created_at,
'status': 'completed' if row[5] == 1 else 'processing' 'status': 'completed' if row[1] else 'processing'
} for row in results] } for row in results]
except Exception as e: except Exception as e:
logging.error(f"Error retrieving all metadata: {str(e)}") logging.error(f"Error retrieving all metadata: {str(e)}")
@@ -204,11 +186,10 @@ class Database:
def delete_document(self, document_id: str):
    """Remove the analysis and metadata rows for ``document_id``.

    Both deletes are issued in one session and committed together. Deleting
    a document that has no rows is not treated as an error.

    Raises:
        Exception: Any error is logged and re-raised.
    """
    try:
        with self.SessionLocal() as session:
            # Same order as before: analysis first, then metadata.
            for model in (Analysis, Metadata):
                session.query(model).filter(model.document_id == document_id).delete()
            session.commit()
    except Exception as e:
        logging.error(f"Error deleting document {document_id}: {str(e)}")
        raise