Implement database integration for transactions and receipts, including CRUD operations and data retrieval endpoints

This commit is contained in:
bolade
2025-08-07 10:58:35 +01:00
parent 55ffc52339
commit 1784d2e406
2 changed files with 363 additions and 172 deletions
+75
View File
@@ -0,0 +1,75 @@
from typing import Annotated
from fastapi import Depends
from sqlalchemy import Column, DateTime, Float, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker
SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db"
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
db_dependency = Annotated[Session, Depends(get_db)]
Base = declarative_base()
def create_db_tables():
Base.metadata.create_all(bind=engine)
def clear_all_data():
"""Clear all data from the database (useful for testing)"""
db = SessionLocal()
try:
db.query(Transaction).delete()
db.query(Receipt).delete()
db.commit()
finally:
db.close()
# Transactions table
class Transaction(Base):
__tablename__ = "transactions"
id = Column(Integer, primary_key=True, index=True)
transaction_id = Column(String, unique=True, index=True)
amount = Column(Float, nullable=False)
date = Column(DateTime, nullable=False)
vendor = Column(String, nullable=False)
description = Column(String, nullable=True)
category = Column(String, nullable=True)
tax_amount = Column(Float, nullable=True)
categorisation_id = Column(String, nullable=True)
user_id = Column(String, nullable=True)
# Receipts table
class Receipt(Base):
__tablename__ = "receipts"
id = Column(Integer, primary_key=True, index=True)
receipt_id = Column(String, unique=True, index=True)
file_id = Column(String, unique=True, index=True)
amount = Column(Float, nullable=False)
date = Column(DateTime, nullable=False)
vendor = Column(String, nullable=False)
description = Column(String, nullable=True)
category = Column(String, nullable=True)
tax_amount = Column(Float, nullable=True)
confidence = Column(Float, nullable=True)
extraction_success = Column(String, nullable=True)
error_message = Column(String, nullable=True)
+288 -172
View File
@@ -7,6 +7,7 @@ from typing import List
from fastapi import FastAPI, File, Form, HTTPException, UploadFile from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.orm import Session
from ai_rules import AIRule from ai_rules import AIRule
from api_models import ( from api_models import (
@@ -17,10 +18,15 @@ from api_models import (
MatchSpecificRequest, MatchSpecificRequest,
RuleRequest, RuleRequest,
) )
from database import Receipt as DBReceipt
from database import Transaction as DBTransaction
from database import create_db_tables, db_dependency
from document_processor import DocumentProcessor from document_processor import DocumentProcessor
from matching_engine import MatchingEngine from matching_engine import MatchingEngine
from models import Receipt, Transaction from models import Receipt, Transaction
create_db_tables()
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
@@ -50,9 +56,28 @@ document_processor = DocumentProcessor()
# In-memory storage for uploaded files (in production, use a database) # In-memory storage for uploaded files (in production, use a database)
uploaded_files = {} uploaded_files = {}
# Store imported transactions globally for easy access
stored_transactions = [] # Helper functions for database operations
processed_receipts = {} def get_transactions_from_db(
db: Session, user_id: str = None, categorization_id: str = None
):
"""Retrieve transactions from database"""
query = db.query(DBTransaction)
if user_id:
query = query.filter(DBTransaction.user_id == user_id)
if categorization_id:
query = query.filter(DBTransaction.categorisation_id == categorization_id)
return query.all()
def get_receipt_from_db(db: Session, file_id: str):
"""Retrieve receipt from database by file_id"""
return db.query(DBReceipt).filter(DBReceipt.file_id == file_id).first()
def get_receipts_from_db(db: Session, file_ids: List[str]):
"""Retrieve multiple receipts from database by file_ids"""
return db.query(DBReceipt).filter(DBReceipt.file_id.in_(file_ids)).all()
@app.get("/") @app.get("/")
@@ -138,6 +163,7 @@ async def root():
@app.post("/transactions/import/csv") @app.post("/transactions/import/csv")
async def import_transactions_csv( async def import_transactions_csv(
db: db_dependency,
file: UploadFile = File(...), file: UploadFile = File(...),
categorization_id: str = Form(...), categorization_id: str = Form(...),
user_id: str = Form(...), user_id: str = Form(...),
@@ -182,6 +208,22 @@ async def import_transactions_csv(
raise ValueError(f"Could not parse date: {txn_date_str}") raise ValueError(f"Could not parse date: {txn_date_str}")
# Parse amount # Parse amount
amount = float(amount_raw.replace(",", "").strip()) amount = float(amount_raw.replace(",", "").strip())
# Create database transaction object
txn_date_obj = datetime.strptime(txn_date, "%Y-%m-%d")
db_transaction = DBTransaction(
transaction_id=txn_id,
amount=amount,
date=txn_date_obj,
vendor=payee_name.strip(),
description=memo,
categorisation_id=categorization_id,
user_id=user_id,
)
# Add to database
db.add(db_transaction)
transactions.append( transactions.append(
{ {
"id": txn_id, "id": txn_id,
@@ -189,13 +231,15 @@ async def import_transactions_csv(
"amount": amount, "amount": amount,
"payee_name": payee_name.strip(), "payee_name": payee_name.strip(),
"memo": memo, "memo": memo,
"categorization_id": categorization_id,
"user_id": user_id,
} }
) )
except Exception as e: except Exception as e:
errors.append(f"Row {idx + 1}: {str(e)}") errors.append(f"Row {idx + 1}: {str(e)}")
# Store transactions globally for auto-matching
global stored_transactions # Commit all transactions to database
stored_transactions = transactions db.commit()
return { return {
"imported_count": len(transactions), "imported_count": len(transactions),
@@ -209,7 +253,12 @@ async def import_transactions_csv(
@app.post("/transactions/import/image") @app.post("/transactions/import/image")
async def import_transactions_from_image(file: UploadFile = File(...)): async def import_transactions_from_image(
db: db_dependency,
file: UploadFile = File(...),
categorization_id: str = Form("image_import"),
user_id: str = Form("default"),
):
""" """
Import transactions from an image (bank statement, credit card statement, etc.) using AI extraction. Import transactions from an image (bank statement, credit card statement, etc.) using AI extraction.
""" """
@@ -236,9 +285,9 @@ async def import_transactions_from_image(file: UploadFile = File(...)):
detail=extraction_result.get("error", "Extraction failed"), detail=extraction_result.get("error", "Extraction failed"),
) )
extracted_transactions = extraction_result.get("transactions", []) extracted_transactions = extraction_result.get("transactions", [])
# Store transactions globally for auto-matching
global stored_transactions # Store transactions in database
stored_transactions = [] transactions = []
for idx, txn in enumerate(extracted_transactions): for idx, txn in enumerate(extracted_transactions):
try: try:
txn_id = f"img_{file.filename}_{idx + 1}" txn_id = f"img_{file.filename}_{idx + 1}"
@@ -253,7 +302,24 @@ async def import_transactions_from_image(file: UploadFile = File(...)):
# Fallback: use current year if parsing fails # Fallback: use current year if parsing fails
txn_date = f"2024-{txn_date_raw}" txn_date = f"2024-{txn_date_raw}"
stored_transactions.append( # Parse date for database
txn_date_obj = datetime.strptime(txn_date, "%Y-%m-%d")
# Create database transaction object
db_transaction = DBTransaction(
transaction_id=txn_id,
amount=float(amount),
date=txn_date_obj,
vendor=vendor,
description=memo,
categorisation_id=categorization_id,
user_id=user_id,
)
# Add to database
db.add(db_transaction)
transactions.append(
{ {
"id": txn_id, "id": txn_id,
"txn_date": txn_date, "txn_date": txn_date,
@@ -262,11 +328,16 @@ async def import_transactions_from_image(file: UploadFile = File(...)):
"memo": memo, "memo": memo,
} }
) )
except Exception: except Exception as e:
logger.warning(f"Error processing transaction {idx}: {str(e)}")
continue continue
# Commit all transactions to database
db.commit()
return { return {
"imported_count": len(stored_transactions), "imported_count": len(transactions),
"converted_transactions": stored_transactions, "converted_transactions": transactions,
"errors": [], "errors": [],
} }
except Exception as e: except Exception as e:
@@ -330,7 +401,7 @@ async def upload_multiple_documents(files: List[UploadFile] = File(...)):
@app.post("/process/{file_id}", response_model=DocumentProcessResponse) @app.post("/process/{file_id}", response_model=DocumentProcessResponse)
async def process_document(file_id: str): async def process_document(file_id: str, db: db_dependency):
""" """
Process a previously uploaded document to extract receipt information. Process a previously uploaded document to extract receipt information.
@@ -351,8 +422,34 @@ async def process_document(file_id: str):
file_type = file_data["filename"].split(".")[-1].lower() file_type = file_data["filename"].split(".")[-1].lower()
receipt_data = await document_processor.process_file(file_path, file_type) receipt_data = await document_processor.process_file(file_path, file_type)
# Store processed receipt # Parse date for database storage
processed_receipts[file_id] = receipt_data receipt_date = None
if receipt_data.get("date"):
try:
receipt_date = datetime.strptime(receipt_data["date"], "%Y-%m-%d")
except ValueError:
receipt_date = datetime.now()
else:
receipt_date = datetime.now()
# Create database receipt object
db_receipt = DBReceipt(
receipt_id=f"receipt_{file_id}",
file_id=file_id,
amount=receipt_data.get("total_amount", 0.0),
date=receipt_date,
vendor=receipt_data.get("vendor", ""),
description=receipt_data.get("description", ""),
category=receipt_data.get("category", ""),
tax_amount=receipt_data.get("tax_amount", 0.0),
confidence=receipt_data.get("confidence", 0.0),
extraction_success=str(receipt_data.get("extraction_success", False)),
error_message=receipt_data.get("error"),
)
# Add to database
db.add(db_receipt)
db.commit()
return DocumentProcessResponse( return DocumentProcessResponse(
file_id=file_id, file_id=file_id,
@@ -610,7 +707,7 @@ async def process_document(file_id: str):
@app.post("/match-specific", response_model=MatchingResponse) @app.post("/match-specific", response_model=MatchingResponse)
async def match_specific_receipts(request: MatchSpecificRequest): async def match_specific_receipts(request: MatchSpecificRequest, db: db_dependency):
""" """
Match specific receipts against imported transactions. Match specific receipts against imported transactions.
@@ -624,196 +721,131 @@ async def match_specific_receipts(request: MatchSpecificRequest):
f"Starting match-specific for file IDs: {file_ids}, categorization_id: {categorization_id}" f"Starting match-specific for file IDs: {file_ids}, categorization_id: {categorization_id}"
) )
# Check if transactions are imported # Get transactions from database
if not stored_transactions: db_transactions = get_transactions_from_db(
logger.warning("No transactions imported") db, categorization_id=categorization_id
)
if not db_transactions:
logger.warning("No transactions found in database")
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="No transactions imported. Please upload CSV first.", detail="No transactions found. Please upload CSV first.",
) )
logger.info(f"Found {len(stored_transactions)} stored transactions") logger.info(f"Found {len(db_transactions)} transactions in database")
# Convert stored transactions to Transaction objects # Convert database transactions to Transaction objects
transactions = [] transactions = []
for txn in stored_transactions: for db_txn in db_transactions:
try: try:
txn_date = datetime.strptime(txn["txn_date"], "%Y-%m-%d")
transaction = Transaction( transaction = Transaction(
id=txn["id"], id=db_txn.transaction_id,
transaction_date=txn_date, transaction_date=db_txn.date,
amount=txn["amount"], amount=db_txn.amount,
vendor=txn["payee_name"], vendor=db_txn.vendor,
notes=txn["memo"], notes=db_txn.description or "",
) )
transactions.append(transaction) transactions.append(transaction)
except Exception as e: except Exception as e:
logger.warning(f"Error converting transaction {txn['id']}: {str(e)}") logger.warning(
f"Error converting transaction {db_txn.transaction_id}: {str(e)}"
)
continue continue
logger.info(f"Converted {len(transactions)} transactions") logger.info(f"Converted {len(transactions)} transactions")
# Get receipts for the specified file IDs # Get receipts for the specified file IDs from database
db_receipts = get_receipts_from_db(db, file_ids)
receipts = [] receipts = []
missing_files = [] missing_files = []
for file_id in file_ids: for file_id in file_ids:
if file_id in processed_receipts: # Find the corresponding database receipt
receipt_data = processed_receipts[file_id] db_receipt = next((r for r in db_receipts if r.file_id == file_id), None)
logger.info(f"DEBUG: receipt_data for {file_id}: {receipt_data}")
logger.info( if db_receipt:
f"DEBUG: receipt_data keys for {file_id}: {list(receipt_data.keys())}"
)
try: try:
# Handle missing date field
if "date" not in receipt_data or not receipt_data["date"]:
logger.warning(
f"Missing date for receipt {file_id}, using current date"
)
receipt_date = datetime.now()
else:
receipt_date = datetime.strptime(
receipt_data["date"], "%Y-%m-%d"
)
# Handle missing amount field - try multiple possible keys
amount = receipt_data.get("amount")
if amount is None:
amount = receipt_data.get("total_amount")
if amount is None:
amount = receipt_data.get("amount_total")
if amount is None:
logger.warning(
f"Missing amount for receipt {file_id}, using 0.0"
)
amount = 0.0
# Ensure amount is a float
try:
amount = float(amount)
except (ValueError, TypeError):
logger.warning(
f"Invalid amount '{amount}' for receipt {file_id}, using 0.0"
)
amount = 0.0
logger.info(f"DEBUG: amount for {file_id}: {amount}")
# Handle missing vendor field
vendor = receipt_data.get("vendor", "")
if not vendor:
logger.warning(
f"Missing vendor for receipt {file_id}, using 'Unknown'"
)
vendor = "Unknown"
# Handle missing category field
category = receipt_data.get("category", "Other")
# Handle description field
description = receipt_data.get("description", "")
# Handle tax field
tax = receipt_data.get("tax", receipt_data.get("tax_amount", 0.0))
try:
tax = float(tax)
except (ValueError, TypeError):
tax = 0.0
receipt = Receipt( receipt = Receipt(
id=file_id, id=db_receipt.receipt_id,
file_name=uploaded_files[file_id]["filename"], receipt_date=db_receipt.date,
upload_date=uploaded_files[file_id]["upload_date"], amount=db_receipt.amount,
receipt_date=receipt_date, vendor=db_receipt.vendor,
amount=amount, category=db_receipt.category or "Other",
tax=tax, description=db_receipt.description or "",
vendor=vendor, tax=db_receipt.tax_amount or 0.0,
category=category, file_name=db_receipt.file_id,
description=description, upload_date=datetime.now(),
) )
receipts.append(receipt) receipts.append(receipt)
logger.info(f"Added receipt: {receipt.vendor} - ${receipt.amount}") logger.info(f"Successfully loaded receipt for file_id: {file_id}")
except Exception as e: except Exception as e:
logger.warning( logger.error(
f"Error creating receipt object for {file_id}: {str(e)}" f"Error creating receipt object for {file_id}: {str(e)}"
) )
missing_files.append(f"{file_id} (error: {str(e)})") missing_files.append(file_id)
else: else:
logger.warning(f"Receipt {file_id} not found in processed_receipts") logger.warning(f"Receipt {file_id} not found in database")
missing_files.append(f"{file_id} (not found)") missing_files.append(file_id)
logger.info(f"Found {len(receipts)} receipts, {len(missing_files)} missing")
if missing_files: if missing_files:
logger.error(f"Missing files: {missing_files}") logger.warning(f"Missing files: {missing_files}")
raise HTTPException(
status_code=400, detail=f"Missing files: {missing_files}"
)
logger.info( if not receipts:
f"Processing {len(receipts)} receipts against {len(transactions)} transactions" logger.warning("No valid receipts found")
) raise HTTPException(
status_code=400,
detail="No valid receipts found for matching.",
)
# Perform matching # Perform matching
logger.info(
f"Starting matching with {len(receipts)} receipts and {len(transactions)} transactions"
)
try: try:
logger.info("Starting direct matching call (without ThreadPoolExecutor)") matching_results = matching_engine.process_matching(receipts, transactions)
logger.info(f"matching_engine type: {type(matching_engine)}") logger.info(f"Matching completed, got {len(matching_results)} results")
logger.info(
f"matching_engine.process_matching type: {type(matching_engine.process_matching)}"
)
logger.info(f"receipts type: {type(receipts)}, length: {len(receipts)}")
logger.info(
f"transactions type: {type(transactions)}, length: {len(transactions)}"
)
matches = matching_engine.process_matching(receipts, transactions) # Convert matching results to response format
logger.info(
f"Matching completed successfully. Found {len(matches)} matches"
)
# Convert matches to response format
match_responses = [] match_responses = []
for match in matches: for result in matching_results:
logger.info(f"Raw match object: {match}")
logger.info(f" receipt_id: {match.receipt.id}")
logger.info(f" transaction_id: {match.transaction.id}")
logger.info(f" confidence_score: {match.confidence_score}")
logger.info(f" match_reason: {match.match_reason}")
logger.info(f" receipt_vendor: {match.receipt.vendor}")
logger.info(f" receipt_amount: {match.receipt.amount}")
logger.info(f" transaction_vendor: {match.transaction.vendor}")
logger.info(f" transaction_amount: {match.transaction.amount}")
match_response = MatchResponse( match_response = MatchResponse(
receipt_id=match.receipt.id, receipt_id=result.receipt.id,
transaction_id=match.transaction.id, transaction_id=result.transaction.id
confidence_score=match.confidence_score, if result.transaction
match_reason=match.match_reason, else "no_match",
receipt_vendor=match.receipt.vendor, confidence_score=result.confidence_score,
receipt_amount=match.receipt.amount, match_reason=result.match_reason,
receipt_description=match.receipt.description, receipt_vendor=result.receipt.vendor,
receipt_category=match.receipt.category, receipt_amount=result.receipt.amount,
receipt_tax_amount=match.receipt.tax, receipt_description=result.receipt.description,
transaction_vendor=match.transaction.vendor, receipt_category=result.receipt.category,
transaction_amount=match.transaction.amount, receipt_tax_amount=result.receipt.tax,
transaction_vendor=result.transaction.vendor
if result.transaction
else "",
transaction_amount=result.transaction.amount
if result.transaction
else 0.0,
) )
match_responses.append(match_response) match_responses.append(match_response)
logger.info(
f"Successfully created MatchResponse for {match.receipt.vendor} -> {match.transaction.vendor}"
)
logger.info(f"Formatted {len(match_responses)} match responses")
# Calculate statistics # Calculate statistics
if match_responses: high_confidence = len(
high_confidence = sum( [r for r in matching_results if r.confidence_score >= 0.8]
1 for m in match_responses if m.confidence_score >= 0.8 )
) low_confidence = len(
low_confidence = len(match_responses) - high_confidence [r for r in matching_results if r.confidence_score < 0.5]
avg_score = sum(m.confidence_score for m in match_responses) / len( )
match_responses avg_score = (
) sum(r.confidence_score for r in matching_results)
else: / len(matching_results)
high_confidence = low_confidence = avg_score = 0 if matching_results
else 0
)
stats = { stats = {
"total": len(match_responses), "total": len(match_responses),
@@ -833,7 +865,6 @@ async def match_specific_receipts(request: MatchSpecificRequest):
logger.error(f"Exception in matching section: {str(e)}") logger.error(f"Exception in matching section: {str(e)}")
logger.error(f"Exception type: {type(e)}") logger.error(f"Exception type: {type(e)}")
logger.error(f"Exception args: {e.args}") logger.error(f"Exception args: {e.args}")
logger.error(f"Traceback: {e.__traceback__}")
raise HTTPException( raise HTTPException(
status_code=500, detail=f"Unexpected matching error: {str(e)}" status_code=500, detail=f"Unexpected matching error: {str(e)}"
) )
@@ -845,6 +876,87 @@ async def match_specific_receipts(request: MatchSpecificRequest):
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
# ============================================================================
# DATABASE QUERY ENDPOINTS
# ============================================================================
@app.get("/transactions")
async def get_transactions(
db: db_dependency,
user_id: str = None,
categorization_id: str = None,
limit: int = 100,
):
"""
Get transactions from the database.
"""
try:
transactions = get_transactions_from_db(db, user_id, categorization_id)
# Limit results
transactions = transactions[:limit]
# Convert to response format
result = []
for txn in transactions:
result.append(
{
"id": txn.transaction_id,
"amount": txn.amount,
"date": txn.date.strftime("%Y-%m-%d"),
"vendor": txn.vendor,
"description": txn.description,
"category": txn.category,
"tax_amount": txn.tax_amount,
"categorisation_id": txn.categorisation_id,
"user_id": txn.user_id,
}
)
return {
"transactions": result,
"count": len(result),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/receipts")
async def get_receipts(db: db_dependency, limit: int = 100):
"""
Get receipts from the database.
"""
try:
receipts = db.query(DBReceipt).limit(limit).all()
# Convert to response format
result = []
for receipt in receipts:
result.append(
{
"id": receipt.receipt_id,
"file_id": receipt.file_id,
"amount": receipt.amount,
"date": receipt.date.strftime("%Y-%m-%d"),
"vendor": receipt.vendor,
"description": receipt.description,
"category": receipt.category,
"tax_amount": receipt.tax_amount,
"confidence": receipt.confidence,
"extraction_success": receipt.extraction_success,
"error_message": receipt.error_message,
}
)
return {
"receipts": result,
"count": len(result),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# ============================================================================ # ============================================================================
# RULES MANAGEMENT ENDPOINTS # RULES MANAGEMENT ENDPOINTS
# ============================================================================ # ============================================================================
@@ -921,14 +1033,18 @@ async def delete_rule(rule_name: str):
@app.get("/stats") @app.get("/stats")
async def get_stats(): async def get_stats(db: db_dependency):
""" """
Get system statistics. Get system statistics.
""" """
try: try:
# Count transactions and receipts from database
total_transactions = db.query(DBTransaction).count()
total_receipts = db.query(DBReceipt).count()
return { return {
"total_transactions": len(stored_transactions), "total_transactions": total_transactions,
"total_receipts": len(processed_receipts), "total_receipts": total_receipts,
"total_uploaded_files": len(uploaded_files), "total_uploaded_files": len(uploaded_files),
"rules_count": len(matching_engine.rules_engine.rules), "rules_count": len(matching_engine.rules_engine.rules),
} }