3559cbe19d
This commit introduces a new test script, `test_json_extraction.py`, which verifies the correctness of the JSON extraction logic. The script includes a function to extract the first valid JSON object from raw input and a series of test cases covering various scenarios, such as clean JSON, JSON with extra text, nested JSON, and escaped quotes. The tests ensure that the extraction function behaves as expected and handles edge cases appropriately.
908 lines
32 KiB
Python
908 lines
32 KiB
Python
import csv
|
|
import io
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import List
|
|
|
|
from database import (
|
|
DBReceipt,
|
|
DBTransaction,
|
|
DBUploadedFile,
|
|
create_db_tables,
|
|
db_dependency,
|
|
)
|
|
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from schemas import (
|
|
DocumentProcessRequest,
|
|
DocumentProcessResponse,
|
|
DocumentUploadResponse,
|
|
MatchingResponse,
|
|
MatchResponse,
|
|
MatchSpecificRequest,
|
|
Receipt,
|
|
RuleRequest,
|
|
Transaction,
|
|
)
|
|
from services.ai_rules import AIRule
|
|
from services.document_processor import DocumentProcessor
|
|
from services.matching_engine import MatchingEngine
|
|
from sqlalchemy.orm import Session
|
|
|
|
# Database tables are created by the startup event handler (startup_event),
# not at import time, so importing this module has no database side effects.
|
|
|
|
# Root logging: INFO and above, timestamped, written through a StreamHandler
# (which writes to sys.stderr by default). Note that basicConfig is a no-op if
# the root logger already has handlers configured.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
|
app = FastAPI(
    version="1.0.0",
    title="AI Bookkeeper - Data Science Engine",
    description=(
        "AI-powered receipt-to-transaction matching engine. "
        "Receives transaction data and provides intelligent matching capabilities."
    ),
)

# CORS: every origin, method and header is accepted, and credentials are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# website send credentialed cross-origin requests; consider restricting
# allow_origins to the known frontend origin(s) in production.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Initialize database on startup.

    Runs ``create_db_tables()`` when the application starts rather than at
    module import time, and logs before and after doing so.
    """
    # NOTE(review): recent FastAPI releases deprecate ``on_event`` in favour of
    # lifespan handlers; consider migrating when upgrading FastAPI.
    logger.info("Starting up application...")
    create_db_tables()
    logger.info("Application startup complete")
|
|
|
|
|
|
# Module-level service singletons shared by all endpoints below. They are
# constructed left to right: the matching engine first, then the document processor.
matching_engine, document_processor = MatchingEngine(), DocumentProcessor()
|
|
|
|
|
|
# Helper functions for database operations
|
|
def get_transactions_from_db(
    db: Session, user_id: str = None, categorization_id: str = None
):
    """Return every stored transaction, optionally narrowed by user and categorisation.

    Each filter is applied only when its argument is truthy, so ``None`` or an
    empty string means "don't filter on this column". Both filters are ANDed.
    """
    criteria = []
    if user_id:
        criteria.append(DBTransaction.user_id == user_id)
    if categorization_id:
        criteria.append(DBTransaction.categorisation_id == categorization_id)
    return db.query(DBTransaction).filter(*criteria).all()
|
|
|
|
|
|
def get_receipt_from_db(db: Session, file_id: str):
    """Return the first receipt linked to *file_id*, or ``None`` when there is none."""
    receipt_query = db.query(DBReceipt).filter(DBReceipt.file_id == file_id)
    return receipt_query.first()
|
|
|
|
|
|
def get_receipts_from_db(db: Session, file_ids: List[str]):
    """Return (as a list) every receipt whose ``file_id`` is one of *file_ids*."""
    receipt_query = db.query(DBReceipt)
    return receipt_query.filter(DBReceipt.file_id.in_(file_ids)).all()
|
|
|
|
|
|
def get_uploaded_file_from_db(db: Session, file_id: str):
    """Return the first upload record for *file_id*, or ``None`` if it is unknown."""
    upload_query = db.query(DBUploadedFile).filter(DBUploadedFile.file_id == file_id)
    return upload_query.first()
|
|
|
|
|
|
def get_uploaded_files_from_db(db: Session, file_ids: List[str]):
    """Return (as a list) every upload record whose ``file_id`` is in *file_ids*."""
    upload_query = db.query(DBUploadedFile)
    return upload_query.filter(DBUploadedFile.file_id.in_(file_ids)).all()
|
|
|
|
|
|
@app.get("/", tags=["Health"])
async def root():
    """Health check endpoint: report service name, version and status."""
    health = dict(
        message="AI Bookkeeper Data Science Engine is running",
        version="1.0.0",
        status="healthy",
    )
    return health
|
|
|
|
|
|
# ============================================================================
|
|
# TRANSACTION IMPORT ENDPOINTS
|
|
# ============================================================================
|
|
def _clean_csv_row(raw_row):
    """Normalize one ``csv.DictReader`` row so header lookups are whitespace-insensitive.

    Header names and cell values are stripped of surrounding whitespace, because
    bank exports often pad headers (e.g. ``"Amount "``). Cells missing from short
    rows (``None``) become ``""``; overflow cells from long rows, which
    ``DictReader`` stores under the ``None`` key, are dropped.
    """
    cleaned = {}
    for key, value in raw_row.items():
        if key is None:
            continue
        cleaned[key.strip()] = value.strip() if isinstance(value, str) else ""
    return cleaned


def _parse_csv_date(raw_date):
    """Parse a ``MM/DD/YY`` or ``MM/DD/YYYY`` string into a ``datetime``.

    Raises:
        ValueError: if *raw_date* matches neither format.
    """
    for fmt in ("%m/%d/%y", "%m/%d/%Y"):
        try:
            return datetime.strptime(raw_date, fmt)
        except ValueError:
            continue
    raise ValueError(f"Could not parse date: {raw_date}")


@app.post("/transactions/import/csv", tags=["Transaction Import"])
async def import_transactions_csv(
    db: db_dependency,
    file: UploadFile = File(...),
    categorization_id: str = Form(...),
    user_id: str = Form(...),
):
    """
    Import transactions from a CSV file (custom bank export format).

    Columns read: "Account Number", "Transaction Date" (falling back to "Date"),
    "Amount", "Description 2" (payee), and "Account Type" / "Cheque Number" /
    "Description 1" (combined into the memo). Header names are matched after
    stripping surrounding whitespace. Rows that cannot be parsed are skipped and
    reported in ``errors``; all valid rows are committed together.

    Raises:
        HTTPException(400): the file is not UTF-8 text.
        HTTPException(500): any other failure (the session is rolled back).
    """
    try:
        content = await file.read()
        # "utf-8-sig" decodes plain UTF-8 as well, but also drops a leading
        # byte-order mark (common in spreadsheet exports) that would otherwise be
        # glued onto the first header name and break its lookup.
        decoded = content.decode("utf-8-sig")
        reader = csv.DictReader(io.StringIO(decoded))

        transactions = []
        errors = []

        for idx, raw_row in enumerate(reader):
            row_number = idx + 1
            try:
                row = _clean_csv_row(raw_row)
                account_number = row.get("Account Number", "")
                # Use the generic "Date" column when "Transaction Date" is absent/empty.
                txn_date_raw = row.get("Transaction Date") or row.get("Date", "")
                amount_raw = row.get("Amount", "")
                payee_name = row.get("Description 2", "")
                memo = " ".join(
                    (
                        row.get("Account Type", ""),
                        row.get("Cheque Number", ""),
                        row.get("Description 1", ""),
                    )
                ).strip()

                if not txn_date_raw:
                    raise ValueError("Missing transaction date")
                if not amount_raw:
                    raise ValueError("Missing amount")

                txn_date_obj = _parse_csv_date(txn_date_raw)
                txn_date = txn_date_obj.strftime("%Y-%m-%d")
                # Drop thousands separators such as "1,234.56" before converting.
                amount = float(amount_raw.replace(",", ""))

                # Compose ID: "<account number>_<1-based data row number>".
                txn_id = f"{account_number}_{row_number}"

                db.add(
                    DBTransaction(
                        transaction_id=txn_id,
                        amount=amount,
                        date=txn_date_obj,
                        vendor=payee_name,
                        description=memo,
                        categorisation_id=categorization_id,
                        user_id=user_id,
                    )
                )

                transactions.append(
                    {
                        "id": txn_id,
                        "txn_date": txn_date,
                        "amount": amount,
                        "payee_name": payee_name,
                        "memo": memo,
                        "categorization_id": categorization_id,
                        "user_id": user_id,
                    }
                )
            except Exception as e:
                # Best effort: a malformed row is reported and skipped, not fatal.
                errors.append(f"Row {row_number}: {str(e)}")

        # Commit all successfully parsed transactions in one go.
        db.commit()

        return {
            "imported_count": len(transactions),
            "converted_transactions": transactions,
            "errors": errors,
            "categorization_id": categorization_id,
            "user_id": user_id,
        }

    except UnicodeDecodeError as e:
        # Decoding happens before any row is added, so there is nothing to roll back.
        raise HTTPException(
            status_code=400, detail=f"CSV file must be UTF-8 encoded: {str(e)}"
        ) from e
    except Exception as e:
        db.rollback()
        logger.error(f"Error importing transactions from CSV: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/transactions/import/image", tags=["Transaction Import"])
async def import_transactions_from_image(
    db: db_dependency,
    file: UploadFile = File(...),
    categorization_id: str = Form("image_import"),
    user_id: str = Form("default"),
):
    """
    Import transactions from an image (bank statement, credit card statement, etc.) using AI extraction.

    The upload is saved to disk, passed to the document processor for
    extraction, and every extracted transaction that can be parsed is stored.
    Transactions that fail to parse are skipped and listed in ``errors``.

    Raises:
        HTTPException(400): unsupported file extension.
        HTTPException(500): extraction failed, or any other unexpected error
            (the session is rolled back).
    """
    try:
        # Validate file type by extension.
        allowed_types = ["jpg", "jpeg", "png", "gif", "bmp", "pdf"]
        file_extension = file.filename.split(".")[-1].lower()
        if file_extension not in allowed_types:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported file type. Allowed: {allowed_types}",
            )

        content = await file.read()
        image_path = await document_processor.save_uploaded_file(content, file.filename)

        extraction_result = await document_processor.extract_transactions_from_image(
            image_path
        )
        if not extraction_result.get("extraction_success", False):
            raise HTTPException(
                status_code=500,
                detail=extraction_result.get("error", "Extraction failed"),
            )

        extracted_transactions = extraction_result.get("transactions", [])

        transactions = []
        errors = []
        # Year used for dates the processor cannot turn into a full ISO date.
        fallback_year = datetime.now().year

        for idx, txn in enumerate(extracted_transactions):
            try:
                txn_id = f"img_{file.filename}_{idx + 1}"
                txn_date_raw = txn.get("date")
                # Convert once so the DB row and the response carry the same value.
                amount = float(txn.get("amount"))
                vendor = txn.get("vendor")
                memo = txn.get("memo", "")

                # Presumably returns "YYYY-MM-DD" or a falsy value on failure
                # (it is parsed with "%Y-%m-%d" below) — TODO confirm.
                txn_date = document_processor._parse_date_to_iso(txn_date_raw)
                if not txn_date:
                    # Assumes the raw value is a year-less "MM-DD" string; prefix
                    # the current year. Anything else fails the parse below.
                    txn_date = f"{fallback_year}-{txn_date_raw}"

                txn_date_obj = datetime.strptime(txn_date, "%Y-%m-%d")

                db.add(
                    DBTransaction(
                        transaction_id=txn_id,
                        amount=amount,
                        date=txn_date_obj,
                        vendor=vendor,
                        description=memo,
                        categorisation_id=categorization_id,
                        user_id=user_id,
                    )
                )

                transactions.append(
                    {
                        "id": txn_id,
                        "txn_date": txn_date,
                        "amount": amount,
                        "payee_name": vendor,
                        "memo": memo,
                    }
                )
            except Exception as e:
                # Best effort: skip this transaction but report why.
                logger.warning(f"Error processing transaction {idx}: {str(e)}")
                errors.append(f"Transaction {idx + 1}: {str(e)}")

        db.commit()

        return {
            "imported_count": len(transactions),
            "converted_transactions": transactions,
            "errors": errors,
        }

    except HTTPException:
        # Preserve intended status codes (e.g. 400) instead of masking them as 500.
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Error importing transactions from image: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
# ============================================================================
|
|
# DOCUMENT PROCESSING ENDPOINTS
|
|
# ============================================================================
|
|
|
|
|
|
@app.post(
    "/upload-multiple",
    response_model=List[DocumentUploadResponse],
    tags=["Document Processing"],
)
async def upload_multiple_documents(
    files: List[UploadFile] = File(...), db: db_dependency = None
):
    """
    Upload multiple receipt images for processing.

    This endpoint accepts multiple image files and returns file IDs
    that can be used with the /process/{file_id} endpoint.

    Every file's extension is validated before anything is written, so a
    rejected file does not leave earlier files of the batch saved on disk.

    Raises:
        HTTPException(400): a file has an unsupported extension.
        HTTPException(500): any other failure (the session is rolled back).
    """
    try:
        allowed_types = ["jpg", "jpeg", "png", "gif", "bmp", "pdf"]

        # Pass 1: validate the whole batch before saving anything.
        extensions = []
        for file in files:
            file_extension = file.filename.split(".")[-1].lower()
            if file_extension not in allowed_types:
                raise HTTPException(
                    status_code=400,
                    detail=f"Unsupported file type for {file.filename}. Allowed: {allowed_types}",
                )
            extensions.append(file_extension)

        # Pass 2: save each file and stage its database record.
        responses = []
        for file, file_extension in zip(files, extensions):
            file_id = str(uuid.uuid4())

            content = await file.read()
            file_path = await document_processor.save_uploaded_file(
                content, file.filename
            )

            # One timestamp so the stored record and the response agree.
            upload_date = datetime.now()

            db.add(
                DBUploadedFile(
                    file_id=file_id,
                    filename=file.filename,
                    file_path=file_path,
                    file_type=file_extension,
                    upload_date=upload_date,
                    status="uploaded",
                )
            )

            responses.append(
                DocumentUploadResponse(
                    file_id=file_id,
                    filename=file.filename,
                    file_type=file_extension,
                    upload_date=upload_date,
                    status="uploaded",
                )
            )

        # Commit all uploaded files to database
        db.commit()

        return responses

    except HTTPException:
        # Keep the 400 for bad file types instead of converting it to a 500.
        raise
    except Exception as e:
        if db is not None:
            db.rollback()
        logger.error(f"Error uploading documents: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post(
    "/process/{file_id}",
    response_model=DocumentProcessResponse,
    tags=["Document Processing"],
)
async def process_document(
    file_id: str, request: DocumentProcessRequest, db: db_dependency
):
    """
    Process a previously uploaded document to extract receipt information.

    This endpoint uses AI to extract structured data from receipt images,
    including vendor, amount, date, and category information.

    Optionally accepts:
    - user_location: Guide tax calculations and depreciation based on location
      (format: "State/Province, Country" e.g., "Ontario, Canada")
    - ai_rules: Custom categorization rules to override default logic
      (e.g., [{"condition": "vendor is Starbucks", "action": "Food"}])

    Raises:
        HTTPException(404): no uploaded file with this ``file_id``.
        HTTPException(500): any other failure (the session is rolled back).
    """
    logger.info(f"Request: {request}")
    try:
        db_uploaded_file = get_uploaded_file_from_db(db, file_id)
        if not db_uploaded_file:
            raise HTTPException(status_code=404, detail=f"File {file_id} not found")

        # Convert ai_rules from Pydantic models to plain dictionaries if provided.
        ai_rules_list = None
        if request.ai_rules:
            ai_rules_list = [
                {"condition": rule.condition, "action": rule.action}
                for rule in request.ai_rules
            ]

        receipt_data = await document_processor.process_file(
            db_uploaded_file.file_path,
            db_uploaded_file.file_type,
            user_location=request.user_location,
            ai_rules=ai_rules_list,
        )

        # Store the extracted "YYYY-MM-DD" date; fall back to now when it is
        # missing, malformed, or not a string.
        receipt_date = datetime.now()
        raw_date = receipt_data.get("date")
        if raw_date:
            try:
                receipt_date = datetime.strptime(raw_date, "%Y-%m-%d")
            except (ValueError, TypeError):
                pass

        # NOTE(review): receipt_id is derived from file_id, so processing the same
        # file twice inserts a second row with the same receipt_id; if that column
        # is unique this commit fails — confirm whether an upsert is intended.
        db_receipt = DBReceipt(
            receipt_id=f"receipt_{file_id}",
            file_id=file_id,
            amount=receipt_data.get("total_amount", 0.0),
            date=receipt_date,
            vendor=receipt_data.get("vendor", ""),
            description=receipt_data.get("description", ""),
            category=receipt_data.get("category", ""),
            tax_amount=receipt_data.get("tax_amount", 0.0),
            confidence=receipt_data.get("confidence", 0.0),
            # Boolean fields are persisted as strings ("True"/"False").
            extraction_success=str(receipt_data.get("extraction_success", False)),
            error_message=receipt_data.get("error"),
            receipt_currency=receipt_data.get("currency"),
            receipt_location=receipt_data.get("location"),
            calculated_tax=receipt_data.get("calculated_tax"),
            is_depreciable=str(receipt_data.get("is_depreciable"))
            if receipt_data.get("is_depreciable") is not None
            else None,
            name_of_asset=receipt_data.get("name_of_asset"),
            cca_rate=receipt_data.get("cca_rate"),
            useful_life=receipt_data.get("useful_life"),
            residual_value=receipt_data.get("residual_value"),
        )

        db.add(db_receipt)
        db.commit()

        return DocumentProcessResponse(
            file_id=file_id,
            receipt_id=db_receipt.receipt_id,
            extraction_success=receipt_data.get("extraction_success", False),
            vendor=receipt_data.get("vendor", ""),
            description=receipt_data.get("description", ""),
            total_amount=receipt_data.get("total_amount", 0.0),
            tax_amount=receipt_data.get("tax_amount", 0.0),
            date=receipt_data.get("date", ""),
            category=receipt_data.get("category", ""),
            confidence=receipt_data.get("confidence", 0.0),
            error=receipt_data.get("error", None),
            receipt_currency=receipt_data.get("currency"),
            receipt_location=receipt_data.get("location"),
            calculated_tax=receipt_data.get("calculated_tax"),
            is_depreciable=receipt_data.get("is_depreciable"),
            name_of_asset=receipt_data.get("name_of_asset"),
            cca_rate=receipt_data.get("cca_rate"),
            useful_life=receipt_data.get("useful_life"),
            residual_value=receipt_data.get("residual_value"),
        )

    except HTTPException:
        # Keep the 404 for unknown files instead of converting it to a 500.
        raise
    except Exception as e:
        db.rollback()
        # Use the path parameter: the request body model is not guaranteed to
        # carry a file_id, and an AttributeError here would mask the real error.
        logger.error(f"Error processing document {file_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/match-specific", response_model=MatchingResponse, tags=["AI Matching"])
async def match_specific_receipts(request: MatchSpecificRequest, db: db_dependency):
    """
    Match specific receipts against imported transactions.

    This endpoint takes a request with receipt file IDs and categorization ID,
    and matches them against the currently imported transactions using AI-powered matching logic.

    Raises:
        HTTPException(400): no transactions for the categorization ID, or none
            of the requested receipts could be loaded.
        HTTPException(500): matching or any other step failed unexpectedly.
    """
    try:
        file_ids = request.file_ids
        categorization_id = request.categorization_id
        logger.info(
            f"Starting match-specific for file IDs: {file_ids}, categorization_id: {categorization_id}"
        )

        db_transactions = get_transactions_from_db(
            db, categorization_id=categorization_id
        )
        if not db_transactions:
            logger.warning("No transactions found in database")
            raise HTTPException(
                status_code=400,
                detail="No transactions found. Please upload CSV first.",
            )

        logger.info(f"Found {len(db_transactions)} transactions in database")

        # Convert database rows to Transaction schema objects; rows that fail
        # conversion are logged and skipped.
        transactions = []
        for db_txn in db_transactions:
            try:
                transactions.append(
                    Transaction(
                        id=db_txn.transaction_id,
                        transaction_date=db_txn.date,
                        amount=db_txn.amount,
                        vendor=db_txn.vendor,
                        notes=db_txn.description or "",
                    )
                )
            except Exception as e:
                logger.warning(
                    f"Error converting transaction {db_txn.transaction_id}: {str(e)}"
                )

        logger.info(f"Converted {len(transactions)} transactions")

        # Index the fetched receipts by file_id once instead of rescanning the
        # whole list for every requested ID. setdefault keeps the first row seen
        # per file_id, matching the previous first-match lookup.
        receipts_by_file_id = {}
        for db_receipt in get_receipts_from_db(db, file_ids):
            receipts_by_file_id.setdefault(db_receipt.file_id, db_receipt)

        receipts = []
        missing_files = []
        for file_id in file_ids:
            db_receipt = receipts_by_file_id.get(file_id)
            if db_receipt is None:
                logger.warning(f"Receipt {file_id} not found in database")
                missing_files.append(file_id)
                continue
            try:
                receipt = Receipt(
                    id=db_receipt.receipt_id,
                    receipt_date=db_receipt.date,
                    amount=db_receipt.amount,
                    vendor=db_receipt.vendor,
                    category=db_receipt.category or "Other",
                    description=db_receipt.description or "",
                    tax=db_receipt.tax_amount or 0.0,
                    file_name=db_receipt.file_id,
                    upload_date=datetime.now(),
                )
            except Exception as e:
                logger.error(f"Error creating receipt object for {file_id}: {str(e)}")
                missing_files.append(file_id)
                continue
            receipts.append(receipt)
            logger.info(f"Successfully loaded receipt for file_id: {file_id}")

        logger.info(f"Found {len(receipts)} receipts, {len(missing_files)} missing")
        if missing_files:
            logger.warning(f"Missing files: {missing_files}")

        if not receipts:
            logger.warning("No valid receipts found")
            raise HTTPException(
                status_code=400,
                detail="No valid receipts found for matching.",
            )

        logger.info(
            f"Starting matching with {len(receipts)} receipts and {len(transactions)} transactions"
        )

        # user_tax_info, when provided, takes precedence over user_location.
        user_location = request.user_location
        if request.user_tax_info:
            # Use the state code from user_tax_info (e.g. "ON", "QC", "BC").
            user_location = request.user_tax_info.state.state_code
            logger.info(
                f"Using location from user_tax_info: {user_location} ({request.user_tax_info.state.name}, {request.user_tax_info.country.name})"
            )
        else:
            logger.info(f"Using default/provided user_location: {user_location}")

        # Convert ai_rules from Pydantic models to plain dictionaries if provided.
        ai_rules_list = None
        if request.ai_rules:
            ai_rules_list = [
                {"condition": rule.condition, "action": rule.action}
                for rule in request.ai_rules
            ]
            logger.info(f"Applying {len(ai_rules_list)} custom AI rules to matching")

        try:
            matching_results = matching_engine.process_matching(
                receipts,
                transactions,
                user_location=user_location,
                ai_rules=ai_rules_list,
            )
            logger.info(f"Matching completed, got {len(matching_results)} results")

            match_responses = []
            for result in matching_results:
                txn = result.transaction
                # NOTE: overriding with tax_analysis["final_tax_amount"] is currently
                # disabled, so the receipt's stated tax is always reported.
                final_tax = result.receipt.tax

                flag_for_review = None
                auto_approve = None
                if result.tax_analysis:
                    flag_for_review = result.tax_analysis.get("flag_for_review")
                    auto_approve = result.tax_analysis.get("auto_approve")

                match_responses.append(
                    MatchResponse(
                        receipt_id=result.receipt.id,
                        transaction_id=txn.id if txn else "no_match",
                        # Reported as a percentage; the stats below use raw scores.
                        confidence_score=result.confidence_score * 100,
                        match_reason=result.match_reason,
                        receipt_vendor=result.receipt.vendor,
                        receipt_amount=result.receipt.amount,
                        receipt_description=result.receipt.description,
                        receipt_category=result.receipt.category,
                        receipt_tax_amount=final_tax,
                        transaction_vendor=txn.vendor if txn else "",
                        transaction_amount=txn.amount if txn else 0.0,
                        tax_analysis=result.tax_analysis,
                        flag_for_review=flag_for_review,
                        auto_approve=auto_approve,
                    )
                )

            # Statistics over the raw confidence scores.
            high_confidence = sum(
                1 for r in matching_results if r.confidence_score >= 0.8
            )
            low_confidence = sum(
                1 for r in matching_results if r.confidence_score < 0.5
            )
            avg_score = (
                sum(r.confidence_score for r in matching_results)
                / len(matching_results)
                if matching_results
                else 0
            )

            stats = {
                "total": len(match_responses),
                "high_confidence": high_confidence,
                "low_confidence": low_confidence,
                "avg_score": round(avg_score, 2),
            }

            logger.info(f"Generated stats: {stats}")
            logger.info(
                f"Match-specific completed successfully with {len(match_responses)} matches"
            )

            return MatchingResponse(matches=match_responses, stats=stats)

        except Exception as e:
            logger.error(f"Exception in matching section: {str(e)}")
            logger.error(f"Exception type: {type(e)}")
            logger.error(f"Exception args: {e.args}")
            raise HTTPException(
                status_code=500, detail=f"Unexpected matching error: {str(e)}"
            )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error in match_specific_receipts: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
# ============================================================================
|
|
# DATABASE QUERY ENDPOINTS
|
|
# ============================================================================
|
|
|
|
|
|
@app.get("/transactions", tags=["Database Queries"])
async def get_transactions(
    db: db_dependency,
    user_id: str = None,
    categorization_id: str = None,
    limit: int = 100,
):
    """
    Get transactions from the database.

    Optional ``user_id`` / ``categorization_id`` query parameters narrow the
    results (empty values are ignored); at most ``limit`` rows are returned.

    Raises:
        HTTPException(400): ``limit`` is negative.
        HTTPException(500): any database or serialization failure.
    """
    try:
        if limit < 0:
            raise HTTPException(status_code=400, detail="limit must be non-negative")

        query = db.query(DBTransaction)
        if user_id:
            query = query.filter(DBTransaction.user_id == user_id)
        if categorization_id:
            query = query.filter(DBTransaction.categorisation_id == categorization_id)
        # Apply the limit in SQL (as /receipts and /uploaded-files do) instead
        # of loading every matching row and slicing in Python.
        transactions = query.limit(limit).all()

        result = [
            {
                "id": txn.transaction_id,
                "amount": txn.amount,
                "date": txn.date.strftime("%Y-%m-%d") if txn.date else None,
                "vendor": txn.vendor,
                "description": txn.description,
                "category": txn.category,
                "tax_amount": txn.tax_amount,
                "categorisation_id": txn.categorisation_id,
                "user_id": txn.user_id,
            }
            for txn in transactions
        ]

        return {
            "transactions": result,
            "count": len(result),
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.get("/receipts", tags=["Database Queries"])
async def get_receipts(db: db_dependency, limit: int = 100):
    """
    Get receipts from the database.

    Returns up to ``limit`` stored receipts, serialized to plain dictionaries,
    together with the number returned. Any failure is reported as a 500.
    """
    try:
        rows = db.query(DBReceipt).limit(limit).all()
        serialized = [
            {
                "id": row.receipt_id,
                "file_id": row.file_id,
                "amount": row.amount,
                "date": row.date.strftime("%Y-%m-%d"),
                "vendor": row.vendor,
                "description": row.description,
                "category": row.category,
                "tax_amount": row.tax_amount,
                "confidence": row.confidence,
                "extraction_success": row.extraction_success,
                "error_message": row.error_message,
            }
            for row in rows
        ]
        return {"receipts": serialized, "count": len(serialized)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.get("/uploaded-files", tags=["Database Queries"])
async def get_uploaded_files(db: db_dependency, limit: int = 100):
    """
    Get uploaded files from the database.

    Returns up to ``limit`` upload records (timestamps formatted as
    ``YYYY-MM-DD HH:MM:SS``) together with the number returned.
    """
    try:
        records = db.query(DBUploadedFile).limit(limit).all()
        serialized = [
            {
                "file_id": record.file_id,
                "filename": record.filename,
                "file_path": record.file_path,
                "file_type": record.file_type,
                "upload_date": record.upload_date.strftime("%Y-%m-%d %H:%M:%S"),
                "status": record.status,
            }
            for record in records
        ]
        return {"uploaded_files": serialized, "count": len(serialized)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
# ============================================================================
|
|
# RULES MANAGEMENT ENDPOINTS
|
|
# ============================================================================
|
|
|
|
|
|
@app.post("/rules", tags=["AI Rules Management"])
async def add_rule(request: RuleRequest):
    """
    Add a new AI rule for transaction matching.

    The rule is appended to the shared matching engine's in-memory rule list;
    names are not checked for uniqueness.
    """
    try:
        rule_fields = dict(
            name=request.name,
            condition=request.condition,
            action=request.action,
            source=request.source,
        )
        matching_engine.rules_engine.rules.append(AIRule(**rule_fields))
        return {"message": f"Rule '{request.name}' added successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.get("/rules", tags=["AI Rules Management"])
async def get_rules():
    """
    Get all current AI rules.

    Each rule from the matching engine is serialized with its name, condition,
    action, source and status, in list order.
    """
    try:
        serialized_rules = [
            {
                "name": r.name,
                "condition": r.condition,
                "action": r.action,
                "source": r.source,
                "status": r.status,
            }
            for r in matching_engine.rules_engine.rules
        ]
        return {"rules": serialized_rules}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.delete("/rules/{rule_name}", tags=["AI Rules Management"])
async def delete_rule(rule_name: str):
    """
    Delete an AI rule by name.

    Only the first rule with a matching name is removed. Responds 404 when no
    rule has that name.
    """
    try:
        rules = matching_engine.rules_engine.rules
        position = next(
            (i for i, candidate in enumerate(rules) if candidate.name == rule_name),
            None,
        )
        if position is None:
            raise HTTPException(status_code=404, detail=f"Rule '{rule_name}' not found")
        del rules[position]
        return {"message": f"Rule '{rule_name}' deleted successfully"}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
# ============================================================================
|
|
# STATISTICS ENDPOINT
|
|
# ============================================================================
|
|
|
|
|
|
@app.get("/stats", tags=["Statistics"])
async def get_stats(db: db_dependency):
    """
    Get system statistics.

    Reports row counts for transactions, receipts and uploaded files, plus the
    number of rules currently loaded in the matching engine.
    """
    try:
        return {
            "total_transactions": db.query(DBTransaction).count(),
            "total_receipts": db.query(DBReceipt).count(),
            "total_uploaded_files": db.query(DBUploadedFile).count(),
            "rules_count": len(matching_engine.rules_engine.rules),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the app directly with uvicorn. The import lives
    # inside the guard, so uvicorn is only needed when the module is run as a script.
    import uvicorn

    # Listen on all network interfaces (0.0.0.0), port 8654.
    uvicorn.run(app, host="0.0.0.0", port=8654)
|