From 1784d2e40645cf8d38852bea888c9885a8cc840a Mon Sep 17 00:00:00 2001 From: bolade Date: Thu, 7 Aug 2025 10:58:35 +0100 Subject: [PATCH] Implement database integration for transactions and receipts, including CRUD operations and data retrieval endpoints --- database.py | 75 +++++++++ main.py | 460 ++++++++++++++++++++++++++++++++-------------------- 2 files changed, 363 insertions(+), 172 deletions(-) diff --git a/database.py b/database.py index e69de29..1786d3c 100644 --- a/database.py +++ b/database.py @@ -0,0 +1,75 @@ +from typing import Annotated + +from fastapi import Depends +from sqlalchemy import Column, DateTime, Float, Integer, String, create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import Session, sessionmaker + +SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db" + +engine = create_engine( + SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} +) + +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() + + +db_dependency = Annotated[Session, Depends(get_db)] +Base = declarative_base() + + +def create_db_tables(): + Base.metadata.create_all(bind=engine) + + +def clear_all_data(): + """Clear all data from the database (useful for testing)""" + db = SessionLocal() + try: + db.query(Transaction).delete() + db.query(Receipt).delete() + db.commit() + finally: + db.close() + + +# Transactions table +class Transaction(Base): + __tablename__ = "transactions" + + id = Column(Integer, primary_key=True, index=True) + transaction_id = Column(String, unique=True, index=True) + amount = Column(Float, nullable=False) + date = Column(DateTime, nullable=False) + vendor = Column(String, nullable=False) + description = Column(String, nullable=True) + category = Column(String, nullable=True) + tax_amount = Column(Float, nullable=True) + categorisation_id = Column(String, nullable=True) + user_id = Column(String, nullable=True) + + +# Receipts table +class Receipt(Base): + __tablename__ = "receipts" + + id = Column(Integer, primary_key=True, index=True) + receipt_id = Column(String, unique=True, index=True) + file_id = Column(String, unique=True, index=True) + amount = Column(Float, nullable=False) + date = Column(DateTime, nullable=False) + vendor = Column(String, nullable=False) + description = Column(String, nullable=True) + category = Column(String, nullable=True) + tax_amount = Column(Float, nullable=True) + confidence = Column(Float, nullable=True) + extraction_success = Column(String, nullable=True) + error_message = Column(String, nullable=True) diff --git a/main.py b/main.py index 72bc9d9..1e23bfa 100644 --- a/main.py +++ b/main.py @@ -7,6 +7,7 @@ from typing import List from fastapi import FastAPI, File, Form, HTTPException, UploadFile from fastapi.middleware.cors import CORSMiddleware +from sqlalchemy.orm import Session from ai_rules import AIRule from api_models import ( @@ -17,10 +18,15 @@ from api_models import ( MatchSpecificRequest, RuleRequest, ) +from database import Receipt as DBReceipt +from database import Transaction as DBTransaction +from database import create_db_tables, db_dependency from document_processor import DocumentProcessor from matching_engine import MatchingEngine from models import Receipt, Transaction +create_db_tables() + logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", @@ -50,9 +56,28 @@ document_processor = DocumentProcessor() # In-memory storage for uploaded files (in production, use a database) uploaded_files = {} -# Store imported transactions globally for easy access -stored_transactions = [] -processed_receipts = {} + +# Helper functions for database operations +def get_transactions_from_db( + db: Session, user_id: str = None, categorization_id: str = None +): + """Retrieve transactions from database""" + query = db.query(DBTransaction) + if user_id: + query = query.filter(DBTransaction.user_id == user_id) + if categorization_id: + query = query.filter(DBTransaction.categorisation_id == categorization_id) + return query.all() + + +def get_receipt_from_db(db: Session, file_id: str): + """Retrieve receipt from database by file_id""" + return db.query(DBReceipt).filter(DBReceipt.file_id == file_id).first() + + +def get_receipts_from_db(db: Session, file_ids: List[str]): + """Retrieve multiple receipts from database by file_ids""" + return db.query(DBReceipt).filter(DBReceipt.file_id.in_(file_ids)).all() @app.get("/") @@ -138,6 +163,7 @@ async def root(): @app.post("/transactions/import/csv") async def import_transactions_csv( + db: db_dependency, file: UploadFile = File(...), categorization_id: str = Form(...), user_id: str = Form(...), @@ -182,6 +208,22 @@ async def import_transactions_csv( raise ValueError(f"Could not parse date: {txn_date_str}") # Parse amount amount = float(amount_raw.replace(",", "").strip()) + + # Create database transaction object + txn_date_obj = datetime.strptime(txn_date, "%Y-%m-%d") + db_transaction = DBTransaction( + transaction_id=txn_id, + amount=amount, + date=txn_date_obj, + vendor=payee_name.strip(), + description=memo, + categorisation_id=categorization_id, + user_id=user_id, + ) + + # Add to database + db.add(db_transaction) + transactions.append( { "id": txn_id, @@ -189,13 +231,15 @@ async def import_transactions_csv( "amount": amount, "payee_name": payee_name.strip(), "memo": memo, + "categorization_id": categorization_id, + "user_id": user_id, } ) except Exception as e: errors.append(f"Row {idx + 1}: {str(e)}") - # Store transactions globally for auto-matching - global stored_transactions - stored_transactions = transactions + + # Commit all transactions to database + db.commit() return { "imported_count": len(transactions), @@ -209,7 +253,12 @@ async def import_transactions_csv( @app.post("/transactions/import/image") -async def import_transactions_from_image(file: UploadFile = File(...)): +async def import_transactions_from_image( + db: db_dependency, + file: UploadFile = File(...), + categorization_id: str = Form("image_import"), + user_id: str = Form("default"), +): """ Import transactions from an image (bank statement, credit card statement, etc.) using AI extraction. """ @@ -236,9 +285,9 @@ async def import_transactions_from_image(file: UploadFile = File(...)): detail=extraction_result.get("error", "Extraction failed"), ) extracted_transactions = extraction_result.get("transactions", []) - # Store transactions globally for auto-matching - global stored_transactions - stored_transactions = [] + + # Store transactions in database + transactions = [] for idx, txn in enumerate(extracted_transactions): try: txn_id = f"img_{file.filename}_{idx + 1}" @@ -253,7 +302,24 @@ async def import_transactions_from_image(file: UploadFile = File(...)): # Fallback: use current year if parsing fails txn_date = f"2024-{txn_date_raw}" - stored_transactions.append( + # Parse date for database + txn_date_obj = datetime.strptime(txn_date, "%Y-%m-%d") + + # Create database transaction object + db_transaction = DBTransaction( + transaction_id=txn_id, + amount=float(amount), + date=txn_date_obj, + vendor=vendor, + description=memo, + categorisation_id=categorization_id, + user_id=user_id, + ) + + # Add to database + db.add(db_transaction) + + transactions.append( { "id": txn_id, "txn_date": txn_date, @@ -262,11 +328,16 @@ async def import_transactions_from_image(file: UploadFile = File(...)): "memo": memo, } ) - except Exception: + except Exception as e: + logger.warning(f"Error processing transaction {idx}: {str(e)}") continue + + # Commit all transactions to database + db.commit() + return { - "imported_count": len(stored_transactions), - "converted_transactions": stored_transactions, + "imported_count": len(transactions), + "converted_transactions": transactions, "errors": [], } except Exception as e: @@ -330,7 +401,7 @@ async def upload_multiple_documents(files: List[UploadFile] = File(...)): @app.post("/process/{file_id}", response_model=DocumentProcessResponse) -async def process_document(file_id: str): +async def process_document(file_id: str, db: db_dependency): """ Process a previously uploaded document to extract receipt information. @@ -351,8 +422,34 @@ async def process_document(file_id: str): file_type = file_data["filename"].split(".")[-1].lower() receipt_data = await document_processor.process_file(file_path, file_type) - # Store processed receipt - processed_receipts[file_id] = receipt_data + # Parse date for database storage + receipt_date = None + if receipt_data.get("date"): + try: + receipt_date = datetime.strptime(receipt_data["date"], "%Y-%m-%d") + except ValueError: + receipt_date = datetime.now() + else: + receipt_date = datetime.now() + + # Create database receipt object + db_receipt = DBReceipt( + receipt_id=f"receipt_{file_id}", + file_id=file_id, + amount=receipt_data.get("total_amount", 0.0), + date=receipt_date, + vendor=receipt_data.get("vendor", ""), + description=receipt_data.get("description", ""), + category=receipt_data.get("category", ""), + tax_amount=receipt_data.get("tax_amount", 0.0), + confidence=receipt_data.get("confidence", 0.0), + extraction_success=str(receipt_data.get("extraction_success", False)), + error_message=receipt_data.get("error"), + ) + + # Add to database + db.add(db_receipt) + db.commit() return DocumentProcessResponse( file_id=file_id, @@ -610,7 +707,7 @@ async def process_document(file_id: str): @app.post("/match-specific", response_model=MatchingResponse) -async def match_specific_receipts(request: MatchSpecificRequest): +async def match_specific_receipts(request: MatchSpecificRequest, db: db_dependency): """ Match specific receipts against imported transactions. @@ -624,196 +721,131 @@ async def match_specific_receipts(request: MatchSpecificRequest): f"Starting match-specific for file IDs: {file_ids}, categorization_id: {categorization_id}" ) - # Check if transactions are imported - if not stored_transactions: - logger.warning("No transactions imported") + # Get transactions from database + db_transactions = get_transactions_from_db( + db, categorization_id=categorization_id + ) + + if not db_transactions: + logger.warning("No transactions found in database") raise HTTPException( status_code=400, - detail="No transactions imported. Please upload CSV first.", + detail="No transactions found. Please upload CSV first.", ) - logger.info(f"Found {len(stored_transactions)} stored transactions") + logger.info(f"Found {len(db_transactions)} transactions in database") - # Convert stored transactions to Transaction objects + # Convert database transactions to Transaction objects transactions = [] - for txn in stored_transactions: + for db_txn in db_transactions: try: - txn_date = datetime.strptime(txn["txn_date"], "%Y-%m-%d") transaction = Transaction( - id=txn["id"], - transaction_date=txn_date, - amount=txn["amount"], - vendor=txn["payee_name"], - notes=txn["memo"], + id=db_txn.transaction_id, + transaction_date=db_txn.date, + amount=db_txn.amount, + vendor=db_txn.vendor, + notes=db_txn.description or "", ) transactions.append(transaction) except Exception as e: - logger.warning(f"Error converting transaction {txn['id']}: {str(e)}") + logger.warning( + f"Error converting transaction {db_txn.transaction_id}: {str(e)}" + ) continue logger.info(f"Converted {len(transactions)} transactions") - # Get receipts for the specified file IDs + # Get receipts for the specified file IDs from database + db_receipts = get_receipts_from_db(db, file_ids) receipts = [] missing_files = [] for file_id in file_ids: - if file_id in processed_receipts: - receipt_data = processed_receipts[file_id] - logger.info(f"DEBUG: receipt_data for {file_id}: {receipt_data}") - logger.info( - f"DEBUG: receipt_data keys for {file_id}: {list(receipt_data.keys())}" - ) + # Find the corresponding database receipt + db_receipt = next((r for r in db_receipts if r.file_id == file_id), None) + + if db_receipt: try: - # Handle missing date field - if "date" not in receipt_data or not receipt_data["date"]: - logger.warning( - f"Missing date for receipt {file_id}, using current date" - ) - receipt_date = datetime.now() - else: - receipt_date = datetime.strptime( - receipt_data["date"], "%Y-%m-%d" - ) - - # Handle missing amount field - try multiple possible keys - amount = receipt_data.get("amount") - if amount is None: - amount = receipt_data.get("total_amount") - if amount is None: - amount = receipt_data.get("amount_total") - if amount is None: - logger.warning( - f"Missing amount for receipt {file_id}, using 0.0" - ) - amount = 0.0 - - # Ensure amount is a float - try: - amount = float(amount) - except (ValueError, TypeError): - logger.warning( - f"Invalid amount '{amount}' for receipt {file_id}, using 0.0" - ) - amount = 0.0 - - logger.info(f"DEBUG: amount for {file_id}: {amount}") - - # Handle missing vendor field - vendor = receipt_data.get("vendor", "") - if not vendor: - logger.warning( - f"Missing vendor for receipt {file_id}, using 'Unknown'" - ) - vendor = "Unknown" - - # Handle missing category field - category = receipt_data.get("category", "Other") - - # Handle description field - description = receipt_data.get("description", "") - - # Handle tax field - tax = receipt_data.get("tax", receipt_data.get("tax_amount", 0.0)) - try: - tax = float(tax) - except (ValueError, TypeError): - tax = 0.0 - receipt = Receipt( - id=file_id, - file_name=uploaded_files[file_id]["filename"], - upload_date=uploaded_files[file_id]["upload_date"], - receipt_date=receipt_date, - amount=amount, - tax=tax, - vendor=vendor, - category=category, - description=description, + id=db_receipt.receipt_id, + receipt_date=db_receipt.date, + amount=db_receipt.amount, + vendor=db_receipt.vendor, + category=db_receipt.category or "Other", + description=db_receipt.description or "", + tax=db_receipt.tax_amount or 0.0, + file_name=db_receipt.file_id, + upload_date=datetime.now(), ) receipts.append(receipt) - logger.info(f"Added receipt: {receipt.vendor} - ${receipt.amount}") + logger.info(f"Successfully loaded receipt for file_id: {file_id}") except Exception as e: - logger.warning( + logger.error( f"Error creating receipt object for {file_id}: {str(e)}" ) - missing_files.append(f"{file_id} (error: {str(e)})") + missing_files.append(file_id) else: - logger.warning(f"Receipt {file_id} not found in processed_receipts") - missing_files.append(f"{file_id} (not found)") + logger.warning(f"Receipt {file_id} not found in database") + missing_files.append(file_id) + + logger.info(f"Found {len(receipts)} receipts, {len(missing_files)} missing") if missing_files: - logger.error(f"Missing files: {missing_files}") - raise HTTPException( - status_code=400, detail=f"Missing files: {missing_files}" - ) + logger.warning(f"Missing files: {missing_files}") - logger.info( - f"Processing {len(receipts)} receipts against {len(transactions)} transactions" - ) + if not receipts: + logger.warning("No valid receipts found") + raise HTTPException( + status_code=400, + detail="No valid receipts found for matching.", + ) # Perform matching + logger.info( + f"Starting matching with {len(receipts)} receipts and {len(transactions)} transactions" + ) + try: - logger.info("Starting direct matching call (without ThreadPoolExecutor)") - logger.info(f"matching_engine type: {type(matching_engine)}") - logger.info( - f"matching_engine.process_matching type: {type(matching_engine.process_matching)}" - ) - logger.info(f"receipts type: {type(receipts)}, length: {len(receipts)}") - logger.info( - f"transactions type: {type(transactions)}, length: {len(transactions)}" - ) + matching_results = matching_engine.process_matching(receipts, transactions) + logger.info(f"Matching completed, got {len(matching_results)} results") - matches = matching_engine.process_matching(receipts, transactions) - - logger.info( - f"Matching completed successfully. Found {len(matches)} matches" - ) - - # Convert matches to response format + # Convert matching results to response format match_responses = [] - for match in matches: - logger.info(f"Raw match object: {match}") - logger.info(f" receipt_id: {match.receipt.id}") - logger.info(f" transaction_id: {match.transaction.id}") - logger.info(f" confidence_score: {match.confidence_score}") - logger.info(f" match_reason: {match.match_reason}") - logger.info(f" receipt_vendor: {match.receipt.vendor}") - logger.info(f" receipt_amount: {match.receipt.amount}") - logger.info(f" transaction_vendor: {match.transaction.vendor}") - logger.info(f" transaction_amount: {match.transaction.amount}") - + for result in matching_results: match_response = MatchResponse( - receipt_id=match.receipt.id, - transaction_id=match.transaction.id, - confidence_score=match.confidence_score, - match_reason=match.match_reason, - receipt_vendor=match.receipt.vendor, - receipt_amount=match.receipt.amount, - receipt_description=match.receipt.description, - receipt_category=match.receipt.category, - receipt_tax_amount=match.receipt.tax, - transaction_vendor=match.transaction.vendor, - transaction_amount=match.transaction.amount, + receipt_id=result.receipt.id, + transaction_id=result.transaction.id + if result.transaction + else "no_match", + confidence_score=result.confidence_score, + match_reason=result.match_reason, + receipt_vendor=result.receipt.vendor, + receipt_amount=result.receipt.amount, + receipt_description=result.receipt.description, + receipt_category=result.receipt.category, + receipt_tax_amount=result.receipt.tax, + transaction_vendor=result.transaction.vendor + if result.transaction + else "", + transaction_amount=result.transaction.amount + if result.transaction + else 0.0, ) match_responses.append(match_response) - logger.info( - f"Successfully created MatchResponse for {match.receipt.vendor} -> {match.transaction.vendor}" - ) - - logger.info(f"Formatted {len(match_responses)} match responses") # Calculate statistics - if match_responses: - high_confidence = sum( - 1 for m in match_responses if m.confidence_score >= 0.8 - ) - low_confidence = len(match_responses) - high_confidence - avg_score = sum(m.confidence_score for m in match_responses) / len( - match_responses - ) - else: - high_confidence = low_confidence = avg_score = 0 + high_confidence = len( + [r for r in matching_results if r.confidence_score >= 0.8] + ) + low_confidence = len( + [r for r in matching_results if r.confidence_score < 0.5] + ) + avg_score = ( + sum(r.confidence_score for r in matching_results) + / len(matching_results) + if matching_results + else 0 + ) stats = { "total": len(match_responses), @@ -833,7 +865,6 @@ async def match_specific_receipts(request: MatchSpecificRequest): logger.error(f"Exception in matching section: {str(e)}") logger.error(f"Exception type: {type(e)}") logger.error(f"Exception args: {e.args}") - logger.error(f"Traceback: {e.__traceback__}") raise HTTPException( status_code=500, detail=f"Unexpected matching error: {str(e)}" ) @@ -845,6 +876,87 @@ async def match_specific_receipts(request: MatchSpecificRequest): raise HTTPException(status_code=500, detail=str(e)) +# ============================================================================ +# DATABASE QUERY ENDPOINTS +# ============================================================================ + + +@app.get("/transactions") +async def get_transactions( + db: db_dependency, + user_id: str = None, + categorization_id: str = None, + limit: int = 100, +): + """ + Get transactions from the database. + """ + try: + transactions = get_transactions_from_db(db, user_id, categorization_id) + + # Limit results + transactions = transactions[:limit] + + # Convert to response format + result = [] + for txn in transactions: + result.append( + { + "id": txn.transaction_id, + "amount": txn.amount, + "date": txn.date.strftime("%Y-%m-%d"), + "vendor": txn.vendor, + "description": txn.description, + "category": txn.category, + "tax_amount": txn.tax_amount, + "categorisation_id": txn.categorisation_id, + "user_id": txn.user_id, + } + ) + + return { + "transactions": result, + "count": len(result), + } + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/receipts") +async def get_receipts(db: db_dependency, limit: int = 100): + """ + Get receipts from the database. + """ + try: + receipts = db.query(DBReceipt).limit(limit).all() + + # Convert to response format + result = [] + for receipt in receipts: + result.append( + { + "id": receipt.receipt_id, + "file_id": receipt.file_id, + "amount": receipt.amount, + "date": receipt.date.strftime("%Y-%m-%d"), + "vendor": receipt.vendor, + "description": receipt.description, + "category": receipt.category, + "tax_amount": receipt.tax_amount, + "confidence": receipt.confidence, + "extraction_success": receipt.extraction_success, + "error_message": receipt.error_message, + } + ) + + return { + "receipts": result, + "count": len(result), + } + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + # ============================================================================ # RULES MANAGEMENT ENDPOINTS # ============================================================================ @@ -921,14 +1033,18 @@ async def delete_rule(rule_name: str): @app.get("/stats") -async def get_stats(): +async def get_stats(db: db_dependency): """ Get system statistics. """ try: + # Count transactions and receipts from database + total_transactions = db.query(DBTransaction).count() + total_receipts = db.query(DBReceipt).count() + return { - "total_transactions": len(stored_transactions), - "total_receipts": len(processed_receipts), + "total_transactions": total_transactions, + "total_receipts": total_receipts, "total_uploaded_files": len(uploaded_files), "rules_count": len(matching_engine.rules_engine.rules), }