"""FastAPI application for the AI Bookkeeper data-science engine.

Exposes endpoints to import bank transactions (CSV or image), upload and
process receipt documents, match receipts against transactions, query the
stored records, and manage the matching engine's AI rules.
"""

import csv
import io
import logging
import uuid
from datetime import datetime
from typing import List, Optional

from database import (
    DBReceipt,
    DBTransaction,
    DBUploadedFile,
    create_db_tables,
    db_dependency,
)
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from schemas import (
    DocumentProcessRequest,
    DocumentProcessResponse,
    DocumentUploadResponse,
    MatchingResponse,
    MatchResponse,
    MatchSpecificRequest,
    Receipt,
    RuleRequest,
    Transaction,
)
from services.ai_rules import AIRule
from services.document_processor import DocumentProcessor
from services.matching_engine import MatchingEngine
from sqlalchemy.orm import Session

# Tables are deliberately NOT created at import time; see startup_event().

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# File extensions accepted by the upload / image-import endpoints.
ALLOWED_FILE_TYPES = ["jpg", "jpeg", "png", "gif", "bmp", "pdf"]

# Date layouts tried, in order, when parsing the CSV "Transaction Date" column.
CSV_DATE_FORMATS = ("%m/%d/%y", "%m/%d/%Y")

app = FastAPI(
    title="AI Bookkeeper - Data Science Engine",
    description=(
        "AI-powered receipt-to-transaction matching engine. "
        "Receives transaction data and provides intelligent matching capabilities."
    ),
    version="1.0.0",
)

# CORS middleware
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
    """Create the database tables when the application starts."""
    logger.info("Starting up application...")
    create_db_tables()
    logger.info("Application startup complete")


# Initialize DS Engine components
matching_engine = MatchingEngine()
document_processor = DocumentProcessor()


# ============================================================================
# INTERNAL HELPERS
# ============================================================================


def _file_extension(filename: Optional[str]) -> str:
    """Return the lower-cased text after the last dot of *filename*.

    A missing filename yields ``""``; a filename without a dot yields the
    whole (lower-cased) name, so neither matches an allowed extension.
    """
    return (filename or "").split(".")[-1].lower()


def _normalize_csv_row(row: dict) -> dict:
    """Return *row* with surrounding whitespace stripped from header names.

    Non-string keys (``csv.DictReader`` stores surplus cells under ``None``)
    are dropped. Cell values are left untouched.
    """
    return {key.strip(): value for key, value in row.items() if isinstance(key, str)}


def _parse_csv_date(raw: Optional[str]) -> datetime:
    """Parse a CSV date cell using CSV_DATE_FORMATS.

    Raises:
        ValueError: if the cell is missing or matches none of the formats.
    """
    text = (raw or "").strip()
    for fmt in CSV_DATE_FORMATS:
        try:
            return datetime.strptime(text, fmt)
        except ValueError:
            continue
    raise ValueError(f"Could not parse date: {text}")


# Helper functions for database operations


def get_transactions_from_db(
    db: Session,
    user_id: Optional[str] = None,
    categorization_id: Optional[str] = None,
    limit: Optional[int] = None,
) -> List[DBTransaction]:
    """Retrieve transactions, optionally filtered and limited.

    Args:
        db: Active database session.
        user_id: When truthy, only transactions with this user_id.
        categorization_id: When truthy, only transactions with this
            categorisation_id.
        limit: When given, maximum number of rows to return (negative
            values are treated as 0). ``None`` returns every match.
    """
    query = db.query(DBTransaction)
    if user_id:
        query = query.filter(DBTransaction.user_id == user_id)
    if categorization_id:
        query = query.filter(DBTransaction.categorisation_id == categorization_id)
    if limit is not None:
        query = query.limit(max(limit, 0))
    return query.all()


def get_receipt_from_db(db: Session, file_id: str) -> Optional[DBReceipt]:
    """Retrieve the first receipt stored for *file_id*, or None."""
    return db.query(DBReceipt).filter(DBReceipt.file_id == file_id).first()


def get_receipts_from_db(db: Session, file_ids: List[str]) -> List[DBReceipt]:
    """Retrieve all receipts whose file_id is in *file_ids*."""
    return db.query(DBReceipt).filter(DBReceipt.file_id.in_(file_ids)).all()


def get_uploaded_file_from_db(db: Session, file_id: str) -> Optional[DBUploadedFile]:
    """Retrieve the first uploaded-file record for *file_id*, or None."""
    return db.query(DBUploadedFile).filter(DBUploadedFile.file_id == file_id).first()


def get_uploaded_files_from_db(
    db: Session, file_ids: List[str]
) -> List[DBUploadedFile]:
    """Retrieve all uploaded-file records whose file_id is in *file_ids*."""
    return db.query(DBUploadedFile).filter(DBUploadedFile.file_id.in_(file_ids)).all()


@app.get("/", tags=["Health"])
async def root():
    """Health check endpoint."""
    return {
        "message": "AI Bookkeeper Data Science Engine is running",
        "version": "1.0.0",
        "status": "healthy",
    }


# ============================================================================
# TRANSACTION IMPORT ENDPOINTS
# ============================================================================


@app.post("/transactions/import/csv", tags=["Transaction Import"])
async def import_transactions_csv(
    db: db_dependency,
    file: UploadFile = File(...),
    categorization_id: str = Form(...),
    user_id: str = Form(...),
):
    """
    Import transactions from a CSV file (custom bank export format).

    Columns read (header whitespace is ignored): "Account Number",
    "Transaction Date" (falling back to "Date"), "Amount", "Description 2",
    and optionally "Account Type", "Cheque Number", "Description 1".

    Rows that cannot be parsed are skipped and reported in ``errors``; the
    remaining rows are stored in a single commit.

    Raises:
        HTTPException: 500 if the file cannot be read/decoded or the commit fails.
    """
    try:
        content = await file.read()
        # utf-8-sig also accepts plain UTF-8 but drops a leading BOM, which
        # would otherwise become part of the first header name.
        decoded = content.decode("utf-8-sig")
        reader = csv.DictReader(io.StringIO(decoded))

        transactions = []
        errors = []
        for idx, raw_row in enumerate(reader):
            try:
                row = _normalize_csv_row(raw_row)

                account_number = row.get("Account Number")
                txn_date_raw = row.get("Transaction Date") or row.get("Date")
                amount_raw = row.get("Amount")
                payee_name = (row.get("Description 2") or "").strip()
                memo = (
                    f"{(row.get('Account Type') or '').strip()} "
                    f"{(row.get('Cheque Number') or '').strip()} "
                    f"{(row.get('Description 1') or '').strip()}"
                ).strip()

                # Compose ID from the account number and 1-based row position
                txn_id = f"{account_number}_{idx + 1}"

                txn_date_obj = _parse_csv_date(txn_date_raw)
                txn_date = txn_date_obj.strftime("%Y-%m-%d")

                if not amount_raw or not amount_raw.strip():
                    raise ValueError("Missing amount")
                # Thousands separators are removed before conversion.
                amount = float(amount_raw.replace(",", "").strip())

                db.add(
                    DBTransaction(
                        transaction_id=txn_id,
                        amount=amount,
                        date=txn_date_obj,
                        vendor=payee_name,
                        description=memo,
                        categorisation_id=categorization_id,
                        user_id=user_id,
                    )
                )
                transactions.append(
                    {
                        "id": txn_id,
                        "txn_date": txn_date,
                        "amount": amount,
                        "payee_name": payee_name,
                        "memo": memo,
                        "categorization_id": categorization_id,
                        "user_id": user_id,
                    }
                )
            except Exception as e:
                # Best effort: one bad row must not abort the whole import.
                errors.append(f"Row {idx + 1}: {str(e)}")

        # Commit all transactions to database
        db.commit()

        return {
            "imported_count": len(transactions),
            "converted_transactions": transactions,
            "errors": errors,
            "categorization_id": categorization_id,
            "user_id": user_id,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/transactions/import/image", tags=["Transaction Import"])
async def import_transactions_from_image(
    db: db_dependency,
    file: UploadFile = File(...),
    categorization_id: str = Form("image_import"),
    user_id: str = Form("default"),
):
    """
    Import transactions from an image (bank statement, credit card statement,
    etc.) using AI extraction.

    Extracted entries that cannot be stored are skipped and reported in
    ``errors``.

    Raises:
        HTTPException: 400 for an unsupported file type; 500 when extraction
            fails or an unexpected error occurs.
    """
    try:
        # Validate file type
        file_extension = _file_extension(file.filename)
        if file_extension not in ALLOWED_FILE_TYPES:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported file type. Allowed: {ALLOWED_FILE_TYPES}",
            )

        # Read file content and save it to disk
        content = await file.read()
        image_path = await document_processor.save_uploaded_file(content, file.filename)

        # Extract transactions from image (pass file path)
        extraction_result = await document_processor.extract_transactions_from_image(
            image_path
        )
        if not extraction_result.get("extraction_success", False):
            raise HTTPException(
                status_code=500,
                detail=extraction_result.get("error", "Extraction failed"),
            )

        extracted_transactions = extraction_result.get("transactions", [])

        # Store transactions in database
        transactions = []
        errors = []
        for idx, txn in enumerate(extracted_transactions):
            try:
                txn_id = f"img_{file.filename}_{idx + 1}"
                txn_date_raw = txn.get("date")
                amount = txn.get("amount")
                vendor = txn.get("vendor")
                memo = txn.get("memo", "")

                # Parse date to YYYY-MM-DD format
                txn_date = document_processor._parse_date_to_iso(txn_date_raw)
                if not txn_date:
                    # Fallback: prefix the current year if parsing fails
                    txn_date = f"{datetime.now().year}-{txn_date_raw}"

                txn_date_obj = datetime.strptime(txn_date, "%Y-%m-%d")

                db.add(
                    DBTransaction(
                        transaction_id=txn_id,
                        amount=float(amount),
                        date=txn_date_obj,
                        vendor=vendor,
                        description=memo,
                        categorisation_id=categorization_id,
                        user_id=user_id,
                    )
                )
                transactions.append(
                    {
                        "id": txn_id,
                        "txn_date": txn_date,
                        "amount": amount,
                        "payee_name": vendor,
                        "memo": memo,
                    }
                )
            except Exception as e:
                # Best effort: skip the entry but tell the caller about it.
                logger.warning("Error processing transaction %s: %s", idx, e)
                errors.append(f"Transaction {idx + 1}: {str(e)}")

        # Commit all transactions to database
        db.commit()

        return {
            "imported_count": len(transactions),
            "converted_transactions": transactions,
            "errors": errors,
        }
    except HTTPException:
        # Keep the intended status code (e.g. 400) instead of masking it as 500.
        raise
    except Exception as e:
        logger.exception("Error importing transactions from image: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


# ============================================================================
# DOCUMENT PROCESSING ENDPOINTS
# ============================================================================


@app.post(
    "/upload-multiple",
    response_model=List[DocumentUploadResponse],
    tags=["Document Processing"],
)
async def upload_multiple_documents(
    files: List[UploadFile] = File(...), db: db_dependency = None
):
    """
    Upload multiple receipt images for processing.

    This endpoint accepts multiple image files and returns file IDs that can
    be used with the /process/{file_id} endpoint.

    Raises:
        HTTPException: 400 if any file has an unsupported type (nothing is
            saved in that case); 500 on unexpected errors.
    """
    try:
        # Validate every file first so a bad file later in the batch does not
        # leave earlier files saved on disk without a database record.
        extensions = []
        for file in files:
            file_extension = _file_extension(file.filename)
            if file_extension not in ALLOWED_FILE_TYPES:
                raise HTTPException(
                    status_code=400,
                    detail=(
                        f"Unsupported file type for {file.filename}. "
                        f"Allowed: {ALLOWED_FILE_TYPES}"
                    ),
                )
            extensions.append(file_extension)

        responses = []
        for file, file_extension in zip(files, extensions):
            # Generate unique file ID
            file_id = str(uuid.uuid4())

            # Read file content and save to disk
            content = await file.read()
            file_path = await document_processor.save_uploaded_file(
                content, file.filename
            )

            # One timestamp shared by the DB record and the response.
            upload_date = datetime.now()

            db.add(
                DBUploadedFile(
                    file_id=file_id,
                    filename=file.filename,
                    file_path=file_path,
                    file_type=file_extension,
                    upload_date=upload_date,
                    status="uploaded",
                )
            )
            responses.append(
                DocumentUploadResponse(
                    file_id=file_id,
                    filename=file.filename,
                    file_type=file_extension,
                    upload_date=upload_date,
                    status="uploaded",
                )
            )

        # Commit all uploaded files to database
        db.commit()

        return responses
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error uploading documents: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post(
    "/process/{file_id}",
    response_model=DocumentProcessResponse,
    tags=["Document Processing"],
)
async def process_document(
    file_id: str, request: DocumentProcessRequest, db: db_dependency
):
    """
    Process a previously uploaded document to extract receipt information.

    This endpoint uses AI to extract structured data from receipt images,
    including vendor, amount, date, and category information.

    Optionally accepts:
    - user_location: Guide tax calculations and depreciation based on location
      (format: "State/Province, Country" e.g., "Ontario, Canada")
    - ai_rules: Custom categorization rules to override default logic
      (e.g., [{"condition": "vendor is Starbucks", "action": "Food"}])

    Raises:
        HTTPException: 404 if no uploaded file exists for *file_id*; 500 on
            unexpected errors.
    """
    logger.info("Request: %s", request)
    try:
        # Get file info from database
        db_uploaded_file = get_uploaded_file_from_db(db, file_id)
        if not db_uploaded_file:
            raise HTTPException(status_code=404, detail=f"File {file_id} not found")

        # Convert ai_rules from Pydantic models to dictionaries if provided
        ai_rules_list = None
        if request.ai_rules:
            ai_rules_list = [
                {"condition": rule.condition, "action": rule.action}
                for rule in request.ai_rules
            ]

        # Process the file using the stored file path
        receipt_data = await document_processor.process_file(
            db_uploaded_file.file_path,
            db_uploaded_file.file_type,
            user_location=request.user_location,
            ai_rules=ai_rules_list,
        )
        logger.info("Extracted receipt data: %s", receipt_data)

        # Parse date for database storage; fall back to "now" when the
        # extracted date is missing or not in YYYY-MM-DD form.
        receipt_date = datetime.now()
        if receipt_data.get("date"):
            try:
                receipt_date = datetime.strptime(receipt_data["date"], "%Y-%m-%d")
            except (ValueError, TypeError):
                pass

        is_depreciable = receipt_data.get("is_depreciable")

        # Create database receipt object
        db_receipt = DBReceipt(
            receipt_id=f"receipt_{file_id}",
            file_id=file_id,
            amount=receipt_data.get("total_amount", 0.0),
            date=receipt_date,
            vendor=receipt_data.get("vendor", ""),
            description=receipt_data.get("description", ""),
            category=receipt_data.get("category", ""),
            tax_amount=receipt_data.get("tax_amount", 0.0),
            confidence=receipt_data.get("confidence", 0.0),
            extraction_success=str(receipt_data.get("extraction_success", False)),
            error_message=receipt_data.get("error"),
            receipt_currency=receipt_data.get("currency"),
            receipt_location=receipt_data.get("location"),
            calculated_tax=receipt_data.get("calculated_tax"),
            is_depreciable=str(is_depreciable) if is_depreciable is not None else None,
            name_of_asset=receipt_data.get("name_of_asset"),
            cca_rate=receipt_data.get("cca_rate"),
            useful_life=receipt_data.get("useful_life"),
            residual_value=receipt_data.get("residual_value"),
        )

        # Add to database
        db.add(db_receipt)
        db.commit()

        return DocumentProcessResponse(
            file_id=file_id,
            receipt_id=db_receipt.receipt_id,
            extraction_success=receipt_data.get("extraction_success", False),
            vendor=receipt_data.get("vendor", ""),
            description=receipt_data.get("description", ""),
            total_amount=receipt_data.get("total_amount", 0.0),
            tax_amount=receipt_data.get("tax_amount", 0.0),
            date=receipt_data.get("date", ""),
            category=receipt_data.get("category", ""),
            confidence=receipt_data.get("confidence", 0.0),
            error=receipt_data.get("error", None),
            receipt_currency=receipt_data.get("currency"),
            receipt_location=receipt_data.get("location"),
            calculated_tax=receipt_data.get("calculated_tax"),
            is_depreciable=is_depreciable,
            name_of_asset=receipt_data.get("name_of_asset"),
            cca_rate=receipt_data.get("cca_rate"),
            useful_life=receipt_data.get("useful_life"),
            residual_value=receipt_data.get("residual_value"),
        )
    except HTTPException:
        # Preserve the 404 raised above.
        raise
    except Exception as e:
        # Use the path parameter: the request body is not guaranteed to carry
        # a file_id attribute.
        logger.exception("Error processing document %s: %s", file_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/match-specific", response_model=MatchingResponse, tags=["AI Matching"])
async def match_specific_receipts(request: MatchSpecificRequest, db: db_dependency):
    """
    Match specific receipts against imported transactions.

    This endpoint takes a request with receipt file IDs and categorization
    ID, and matches them against the currently imported transactions using
    AI-powered matching logic.

    Raises:
        HTTPException: 400 if no transactions or no usable receipts are
            found; 500 if matching fails or an unexpected error occurs.
    """
    try:
        file_ids = request.file_ids
        categorization_id = request.categorization_id

        logger.info(
            "Starting match-specific for file IDs: %s, categorization_id: %s",
            file_ids,
            categorization_id,
        )

        # Get transactions from database
        db_transactions = get_transactions_from_db(
            db, categorization_id=categorization_id
        )
        if not db_transactions:
            logger.warning("No transactions found in database")
            raise HTTPException(
                status_code=400,
                detail="No transactions found. Please upload CSV first.",
            )
        logger.info("Found %s transactions in database", len(db_transactions))

        # Convert database transactions to Transaction objects
        transactions = []
        for db_txn in db_transactions:
            try:
                transactions.append(
                    Transaction(
                        id=db_txn.transaction_id,
                        transaction_date=db_txn.date,
                        amount=db_txn.amount,
                        vendor=db_txn.vendor,
                        notes=db_txn.description or "",
                    )
                )
            except Exception as e:
                logger.warning(
                    "Error converting transaction %s: %s", db_txn.transaction_id, e
                )
        logger.info("Converted %s transactions", len(transactions))

        # Index receipts by file_id once; setdefault keeps the first receipt
        # returned for a file_id when several exist.
        receipts_by_file_id = {}
        for db_receipt in get_receipts_from_db(db, file_ids):
            receipts_by_file_id.setdefault(db_receipt.file_id, db_receipt)

        receipts = []
        missing_files = []
        for file_id in file_ids:
            db_receipt = receipts_by_file_id.get(file_id)
            if db_receipt is None:
                logger.warning("Receipt %s not found in database", file_id)
                missing_files.append(file_id)
                continue
            try:
                receipts.append(
                    Receipt(
                        id=db_receipt.receipt_id,
                        receipt_date=db_receipt.date,
                        amount=db_receipt.amount,
                        vendor=db_receipt.vendor,
                        category=db_receipt.category or "Other",
                        description=db_receipt.description or "",
                        tax=db_receipt.tax_amount or 0.0,
                        file_name=db_receipt.file_id,
                        upload_date=datetime.now(),
                    )
                )
                logger.info("Successfully loaded receipt for file_id: %s", file_id)
            except Exception as e:
                logger.error("Error creating receipt object for %s: %s", file_id, e)
                missing_files.append(file_id)

        logger.info(
            "Found %s receipts, %s missing", len(receipts), len(missing_files)
        )
        if missing_files:
            logger.warning("Missing files: %s", missing_files)

        if not receipts:
            logger.warning("No valid receipts found")
            raise HTTPException(
                status_code=400,
                detail="No valid receipts found for matching.",
            )

        # Perform matching
        logger.info(
            "Starting matching with %s receipts and %s transactions",
            len(receipts),
            len(transactions),
        )

        # user_tax_info, when present, takes precedence over user_location.
        user_location = request.user_location  # Default/fallback
        if request.user_tax_info:
            # Use state_code from user_tax_info (e.g., "ON", "QC", "BC")
            user_location = request.user_tax_info.state.state_code
            logger.info(
                "Using location from user_tax_info: %s (%s, %s)",
                user_location,
                request.user_tax_info.state.name,
                request.user_tax_info.country.name,
            )
        else:
            logger.info("Using default/provided user_location: %s", user_location)

        # Convert ai_rules from Pydantic models to dictionaries if provided
        ai_rules_list = None
        if request.ai_rules:
            ai_rules_list = [
                {"condition": rule.condition, "action": rule.action}
                for rule in request.ai_rules
            ]
            logger.info("Applying %s custom AI rules to matching", len(ai_rules_list))

        try:
            matching_results = matching_engine.process_matching(
                receipts,
                transactions,
                user_location=user_location,
                ai_rules=ai_rules_list,
            )
            logger.info("Matching completed, got %s results", len(matching_results))

            # Convert matching results to response format
            match_responses = []
            for result in matching_results:
                # The receipt's stated tax is reported; overriding it with
                # tax_analysis["final_tax_amount"] is currently disabled.
                final_tax = result.receipt.tax

                # Extract flag_for_review and auto_approve from tax_analysis
                flag_for_review = None
                auto_approve = None
                if result.tax_analysis:
                    flag_for_review = result.tax_analysis.get("flag_for_review")
                    auto_approve = result.tax_analysis.get("auto_approve")

                matched_txn = result.transaction
                match_responses.append(
                    MatchResponse(
                        receipt_id=result.receipt.id,
                        transaction_id=matched_txn.id if matched_txn else "no_match",
                        # confidence_score is multiplied by 100 for the response
                        confidence_score=result.confidence_score * 100,
                        match_reason=result.match_reason,
                        receipt_vendor=result.receipt.vendor,
                        receipt_amount=result.receipt.amount,
                        receipt_description=result.receipt.description,
                        receipt_category=result.receipt.category,
                        receipt_tax_amount=final_tax,
                        transaction_vendor=matched_txn.vendor if matched_txn else "",
                        transaction_amount=matched_txn.amount if matched_txn else 0.0,
                        tax_analysis=result.tax_analysis,
                        flag_for_review=flag_for_review,
                        auto_approve=auto_approve,
                    )
                )

            # Calculate statistics on the raw (un-scaled) confidence scores
            scores = [r.confidence_score for r in matching_results]
            high_confidence = sum(1 for score in scores if score >= 0.8)
            low_confidence = sum(1 for score in scores if score < 0.5)
            avg_score = sum(scores) / len(scores) if scores else 0

            stats = {
                "total": len(match_responses),
                "high_confidence": high_confidence,
                "low_confidence": low_confidence,
                "avg_score": round(avg_score, 2),
            }

            logger.info("Generated stats: %s", stats)
            logger.info(
                "Match-specific completed successfully with %s matches",
                len(match_responses),
            )

            return MatchingResponse(matches=match_responses, stats=stats)

        except Exception as e:
            logger.error("Exception in matching section: %s", e)
            logger.error("Exception type: %s", type(e))
            logger.error("Exception args: %s", e.args)
            raise HTTPException(
                status_code=500, detail=f"Unexpected matching error: {str(e)}"
            ) from e

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Unexpected error in match_specific_receipts: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


# ============================================================================
# DATABASE QUERY ENDPOINTS
# ============================================================================


@app.get("/transactions", tags=["Database Queries"])
async def get_transactions(
    db: db_dependency,
    user_id: Optional[str] = None,
    categorization_id: Optional[str] = None,
    limit: int = 100,
):
    """
    Get transactions from the database.

    Optionally filtered by user_id and/or categorization_id; at most *limit*
    rows are returned.
    """
    try:
        # The limit is applied in the query so only the needed rows are loaded.
        transactions = get_transactions_from_db(
            db, user_id, categorization_id, limit=limit
        )

        # Convert to response format
        result = [
            {
                "id": txn.transaction_id,
                "amount": txn.amount,
                "date": txn.date.strftime("%Y-%m-%d") if txn.date else None,
                "vendor": txn.vendor,
                "description": txn.description,
                "category": txn.category,
                "tax_amount": txn.tax_amount,
                "categorisation_id": txn.categorisation_id,
                "user_id": txn.user_id,
            }
            for txn in transactions
        ]

        return {
            "transactions": result,
            "count": len(result),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/receipts", tags=["Database Queries"])
async def get_receipts(db: db_dependency, limit: int = 100):
    """
    Get receipts from the database (at most *limit* rows).
    """
    try:
        receipts = db.query(DBReceipt).limit(max(limit, 0)).all()

        # Convert to response format
        result = [
            {
                "id": receipt.receipt_id,
                "file_id": receipt.file_id,
                "amount": receipt.amount,
                "date": receipt.date.strftime("%Y-%m-%d") if receipt.date else None,
                "vendor": receipt.vendor,
                "description": receipt.description,
                "category": receipt.category,
                "tax_amount": receipt.tax_amount,
                "confidence": receipt.confidence,
                "extraction_success": receipt.extraction_success,
                "error_message": receipt.error_message,
            }
            for receipt in receipts
        ]

        return {
            "receipts": result,
            "count": len(result),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/uploaded-files", tags=["Database Queries"])
async def get_uploaded_files(db: db_dependency, limit: int = 100):
    """
    Get uploaded files from the database (at most *limit* rows).
    """
    try:
        uploaded_files = db.query(DBUploadedFile).limit(max(limit, 0)).all()

        # Convert to response format
        result = [
            {
                "file_id": uploaded.file_id,
                "filename": uploaded.filename,
                "file_path": uploaded.file_path,
                "file_type": uploaded.file_type,
                "upload_date": (
                    uploaded.upload_date.strftime("%Y-%m-%d %H:%M:%S")
                    if uploaded.upload_date
                    else None
                ),
                "status": uploaded.status,
            }
            for uploaded in uploaded_files
        ]

        return {
            "uploaded_files": result,
            "count": len(result),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


# ============================================================================
# RULES MANAGEMENT ENDPOINTS
# ============================================================================


@app.post("/rules", tags=["AI Rules Management"])
async def add_rule(request: RuleRequest):
    """
    Add a new AI rule for transaction matching.

    The rule is appended to the in-memory rule list of the matching engine.
    """
    try:
        new_rule = AIRule(
            name=request.name,
            condition=request.condition,
            action=request.action,
            source=request.source,
        )
        matching_engine.rules_engine.rules.append(new_rule)
        return {"message": f"Rule '{request.name}' added successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/rules", tags=["AI Rules Management"])
async def get_rules():
    """
    Get all current AI rules.
    """
    try:
        rules = [
            {
                "name": rule.name,
                "condition": rule.condition,
                "action": rule.action,
                "source": rule.source,
                "status": rule.status,
            }
            for rule in matching_engine.rules_engine.rules
        ]
        return {"rules": rules}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.delete("/rules/{rule_name}", tags=["AI Rules Management"])
async def delete_rule(rule_name: str):
    """
    Delete an AI rule by name.

    Only the first rule with a matching name is removed.

    Raises:
        HTTPException: 404 if no rule has that name.
    """
    try:
        rules = matching_engine.rules_engine.rules
        for i, rule in enumerate(rules):
            if rule.name == rule_name:
                # Safe to mutate here: we return immediately after deleting.
                del rules[i]
                return {"message": f"Rule '{rule_name}' deleted successfully"}

        raise HTTPException(status_code=404, detail=f"Rule '{rule_name}' not found")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


# ============================================================================
# STATISTICS ENDPOINT
# ============================================================================


@app.get("/stats", tags=["Statistics"])
async def get_stats(db: db_dependency):
    """
    Get system statistics: row counts per table and the number of AI rules.
    """
    try:
        return {
            "total_transactions": db.query(DBTransaction).count(),
            "total_receipts": db.query(DBReceipt).count(),
            "total_uploaded_files": db.query(DBUploadedFile).count(),
            "rules_count": len(matching_engine.rules_engine.rules),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8654)