Refactor code for improved readability and maintainability across multiple files

This commit is contained in:
bolade
2025-08-07 09:06:05 +01:00
parent 1f530da7c4
commit 9698e2fcaf
5 changed files with 224 additions and 123 deletions
+65 -16
View File
@@ -1,9 +1,10 @@
from dataclasses import dataclass
from typing import Dict, Any, List
import config
from typing import Any, Dict, List
from models import Receipt, Transaction
from tax_rules_engine import TaxRulesEngine
@dataclass
class AIRule:
name: str
@@ -12,6 +13,7 @@ class AIRule:
source: str
status: str = "active"
class AIRulesEngine:
def __init__(self):
self.rules: List[AIRule] = []
@@ -20,17 +22,49 @@ class AIRulesEngine:
def _load_default_rules(self):
self.rules = [
AIRule("exact_amount_match", "amount_diff <= 0.01", "auto_approve", "system"),
AIRule("same_vendor_same_date", "vendor_match and date_diff <= 1", "high_confidence", "system"),
AIRule("gas_station_pattern", "vendor_contains_gas_or_fuel", "categorize_transport", "system"),
AIRule(
"exact_amount_match", "amount_diff <= 0.01", "auto_approve", "system"
),
AIRule(
"same_vendor_same_date",
"vendor_match and date_diff <= 1",
"high_confidence",
"system",
),
AIRule(
"gas_station_pattern",
"vendor_contains_gas_or_fuel",
"categorize_transport",
"system",
),
# Tax-related rules
AIRule("fx_currency_mismatch", "currency_mismatch", "flag_fx_review", "tax_system"),
AIRule("meals_entertainment", "is_meals_entertainment", "apply_me_tax_rule", "tax_system"),
AIRule("provincial_tax_calculation", "has_address_info", "calculate_provincial_tax", "tax_system")
AIRule(
"fx_currency_mismatch",
"currency_mismatch",
"flag_fx_review",
"tax_system",
),
AIRule(
"meals_entertainment",
"is_meals_entertainment",
"apply_me_tax_rule",
"tax_system",
),
AIRule(
"provincial_tax_calculation",
"has_address_info",
"calculate_provincial_tax",
"tax_system",
),
]
def apply_rules(self, receipt: Receipt, transaction: Transaction) -> Dict[str, Any]:
results = {"auto_approve": False, "confidence_boost": 0, "category": None, "tax_analysis": {}}
results = {
"auto_approve": False,
"confidence_boost": 0,
"category": None,
"tax_analysis": {},
}
for rule in self.rules:
if rule.status != "active":
@@ -41,18 +75,25 @@ class AIRulesEngine:
return results
def _evaluate_condition(self, condition: str, receipt: Receipt, transaction: Transaction) -> bool:
def _evaluate_condition(
self, condition: str, receipt: Receipt, transaction: Transaction
) -> bool:
"""Safely evaluate rule conditions without using eval()"""
amount_diff = abs(receipt.amount - abs(transaction.amount))
date_diff = abs((receipt.receipt_date - transaction.transaction_date).days)
vendor_match = receipt.vendor.lower() in transaction.vendor.lower() or transaction.vendor.lower() in receipt.vendor.lower()
vendor_match = (
receipt.vendor.lower() in transaction.vendor.lower()
or transaction.vendor.lower() in receipt.vendor.lower()
)
vendor_lower = receipt.vendor.lower()
vendor_contains_gas_or_fuel = 'gas' in vendor_lower or 'fuel' in vendor_lower
vendor_contains_gas_or_fuel = "gas" in vendor_lower or "fuel" in vendor_lower
# Tax-related conditions
currency_mismatch = receipt.currency != transaction.currency
is_meals_entertainment = receipt.is_meals_entertainment
has_address_info = receipt.billing_address is not None or receipt.shipping_address is not None
has_address_info = (
receipt.billing_address is not None or receipt.shipping_address is not None
)
# Handle specific condition types safely
if condition == "amount_diff <= 0.01":
@@ -86,14 +127,20 @@ class AIRulesEngine:
"min": min,
"max": max,
"sum": sum,
"round": round
"round": round,
}
return eval(condition, safe_globals, {})
except (SyntaxError, NameError, TypeError) as e:
print(f"Warning: Invalid condition '{condition}': {e}")
return False
def _execute_action(self, action: str, results: Dict[str, Any], receipt: Receipt, transaction: Transaction):
def _execute_action(
self,
action: str,
results: Dict[str, Any],
receipt: Receipt,
transaction: Transaction,
):
if action == "auto_approve":
results["auto_approve"] = True
elif action == "high_confidence":
@@ -121,6 +168,8 @@ class AIRulesEngine:
def remove_rule(self, rule_name: str):
self.rules = [r for r in self.rules if r.name != rule_name]
def apply_tax_rules(self, receipt: Receipt, transaction: Transaction = None) -> Dict[str, Any]:
def apply_tax_rules(
self, receipt: Receipt, transaction: Transaction = None
) -> Dict[str, Any]:
"""Apply all tax rules to a receipt/transaction pair"""
return self.tax_rules_engine.apply_all_tax_rules(receipt, transaction)
+24 -6
View File
@@ -1,13 +1,16 @@
from pydantic import BaseModel
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel
class AddressRequest(BaseModel):
province: str
city: str
postal_code: str
country: str = "Canada"
class ReceiptRequest(BaseModel):
id: str
file_name: str
@@ -24,6 +27,7 @@ class ReceiptRequest(BaseModel):
currency: str = "CAD"
is_meals_entertainment: bool = False
class TransactionRequest(BaseModel):
id: str
transaction_date: datetime
@@ -34,6 +38,7 @@ class TransactionRequest(BaseModel):
currency: str = "CAD"
fx_rate: Optional[float] = None
class AssetRequest(BaseModel):
id: str
name: str
@@ -44,42 +49,51 @@ class AssetRequest(BaseModel):
cca_rate: float
asset_class: str
class MatchingRequest(BaseModel):
receipt_ids: List[str]
transaction_ids: List[str]
class MatchResponse(BaseModel):
receipt_id: str
transaction_id: str
confidence_score: float
match_reason: str
tax_analysis: Optional[dict] = None
# Currency information
receipt_currency: str = "CAD"
transaction_currency: str = "CAD"
currency_match: bool = True
receipt_vendor: str
receipt_amount: float
receipt_description: str
receipt_category: str
receipt_tax_amount: float
transaction_vendor: str
transaction_amount: float
class MatchingResponse(BaseModel):
matches: List[MatchResponse]
stats: dict
class ApprovalRequest(BaseModel):
match_id: str
approved: bool
reason: Optional[str] = None
class RuleRequest(BaseModel):
name: str
condition: str
action: str
source: str = "user"
class DocumentUploadResponse(BaseModel):
file_id: str
filename: str
upload_date: datetime
status: str
class DocumentProcessResponse(BaseModel):
file_id: str
extraction_success: bool
@@ -92,11 +106,13 @@ class DocumentProcessResponse(BaseModel):
confidence: Optional[float] = None
error: Optional[str] = None
# New tax-related models
class TaxCalculationRequest(BaseModel):
receipt_id: str
transaction_id: Optional[str] = None
class TaxCalculationResponse(BaseModel):
receipt_id: str
rules_applied: List[str]
@@ -104,11 +120,13 @@ class TaxCalculationResponse(BaseModel):
fx_analysis: Optional[dict] = None
meals_entertainment: dict
class DepreciationRequest(BaseModel):
asset: AssetRequest
year: int
method: str # "straight_line" or "cca"
class DepreciationResponse(BaseModel):
asset_id: str
year: int
+32 -16
View File
@@ -1,8 +1,9 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import List, Optional
import json
import os
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import List
@dataclass
class FeedbackLog:
@@ -13,6 +14,7 @@ class FeedbackLog:
timestamp: datetime
user_id: str
class FeedbackLogger:
def __init__(self, log_file: str = "feedback_logs.json"):
self.log_file = log_file
@@ -23,31 +25,45 @@ class FeedbackLogger:
return []
try:
with open(self.log_file, 'r') as f:
with open(self.log_file, "r") as f:
data = json.load(f)
return [FeedbackLog(**log) for log in data]
except:
except Exception:
return []
def _save_logs(self):
with open(self.log_file, 'w') as f:
json.dump([{
'transaction_id': log.transaction_id,
'original_match': log.original_match,
'correction': log.correction,
'reason': log.reason,
'timestamp': log.timestamp.isoformat(),
'user_id': log.user_id
} for log in self.logs], f, indent=2)
with open(self.log_file, "w") as f:
json.dump(
[
{
"transaction_id": log.transaction_id,
"original_match": log.original_match,
"correction": log.correction,
"reason": log.reason,
"timestamp": log.timestamp.isoformat(),
"user_id": log.user_id,
}
for log in self.logs
],
f,
indent=2,
)
def log_override(self, transaction_id: str, original_match: str, correction: str, reason: str, user_id: str):
def log_override(
self,
transaction_id: str,
original_match: str,
correction: str,
reason: str,
user_id: str,
):
log = FeedbackLog(
transaction_id=transaction_id,
original_match=original_match,
correction=correction,
reason=reason,
timestamp=datetime.now(),
user_id=user_id
user_id=user_id,
)
self.logs.append(log)
self._save_logs()
+50 -31
View File
@@ -1,7 +1,7 @@
import os
import io
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
from typing import Any, Dict, List
class GoogleDriveSync:
def __init__(self):
@@ -16,29 +16,33 @@ class GoogleDriveSync:
from google_auth_oauthlib.flow import InstalledAppFlow
from googleapiclient.discovery import build
SCOPES = ['https://www.googleapis.com/auth/drive.readonly']
SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
# Load existing credentials
if os.path.exists('token.json'):
self.creds = Credentials.from_authorized_user_file('token.json', SCOPES)
if os.path.exists("token.json"):
self.creds = Credentials.from_authorized_user_file("token.json", SCOPES)
# If no valid credentials available, let user log in
if not self.creds or not self.creds.valid:
if self.creds and self.creds.expired and self.creds.refresh_token:
self.creds.refresh(Request())
else:
if not os.path.exists('credentials.json'):
raise Exception("credentials.json not found. Please download from Google Cloud Console.")
if not os.path.exists("credentials.json"):
raise Exception(
"credentials.json not found. Please download from Google Cloud Console."
)
flow = InstalledAppFlow.from_client_secrets_file('credentials.json', SCOPES)
flow = InstalledAppFlow.from_client_secrets_file(
"credentials.json", SCOPES
)
self.creds = flow.run_local_server(port=0)
# Save credentials for next run
with open('token.json', 'w') as token:
with open("token.json", "w") as token:
token.write(self.creds.to_json())
# Build the Drive service
self.service = build('drive', 'v3', credentials=self.creds)
self.service = build("drive", "v3", credentials=self.creds)
return True
except Exception as e:
@@ -52,13 +56,17 @@ class GoogleDriveSync:
return []
try:
results = self.service.files().list(
results = (
self.service.files()
.list(
q="mimeType='application/vnd.google-apps.folder'",
pageSize=100,
fields="nextPageToken, files(id, name, createdTime, modifiedTime)"
).execute()
fields="nextPageToken, files(id, name, createdTime, modifiedTime)",
)
.execute()
)
return results.get('files', [])
return results.get("files", [])
except Exception as e:
print(f"Error listing folders: {e}")
@@ -71,10 +79,11 @@ class GoogleDriveSync:
return {}
try:
folder = self.service.files().get(
fileId=folder_id,
fields="id, name, createdTime, modifiedTime"
).execute()
folder = (
self.service.files()
.get(fileId=folder_id, fields="id, name, createdTime, modifiedTime")
.execute()
)
return folder
@@ -92,7 +101,13 @@ class GoogleDriveSync:
try:
# File types to look for
file_types = ["'application/pdf'", "'image/jpeg'", "'image/png'", "'image/gif'", "'image/bmp'"]
file_types = [
"'application/pdf'",
"'image/jpeg'",
"'image/png'",
"'image/gif'",
"'image/bmp'",
]
mime_types = " or ".join(file_types)
# Build query
@@ -101,25 +116,29 @@ class GoogleDriveSync:
query += f" and '{folder_id}' in parents"
# Add date filter (last 30 days)
thirty_days_ago = (datetime.now() - timedelta(days=30)).isoformat() + 'Z'
thirty_days_ago = (datetime.now() - timedelta(days=30)).isoformat() + "Z"
query += f" and modifiedTime > '{thirty_days_ago}'"
results_files = self.service.files().list(
results_files = (
self.service.files()
.list(
q=query,
pageSize=100,
fields="nextPageToken, files(id, name, mimeType, modifiedTime, size)"
).execute()
fields="nextPageToken, files(id, name, mimeType, modifiedTime, size)",
)
.execute()
)
files = results_files.get('files', [])
files = [file for file in files if file['id'] not in self.processed_files]
files = results_files.get("files", [])
files = [file for file in files if file["id"] not in self.processed_files]
# For demo purposes, return mock results
for file in files[:3]: # Process first 3 files
mock_result = {
"file_id": file['id'],
"filename": file['name'],
"drive_modified": file['modifiedTime'],
"file_size": file.get('size', 0),
"file_id": file["id"],
"filename": file["name"],
"drive_modified": file["modifiedTime"],
"file_size": file.get("size", 0),
"extraction_success": True,
"vendor": "Demo Vendor",
"description": "Coffee and sandwich",
@@ -127,10 +146,10 @@ class GoogleDriveSync:
"tax_amount": 2.04,
"date": "2024-01-15",
"category": "Food",
"confidence": 0.95
"confidence": 0.95,
}
results.append(mock_result)
self.processed_files.add(file['id'])
self.processed_files.add(file["id"])
except Exception as e:
print(f"Error processing Drive files: {e}")
+1 -2
View File
@@ -8,7 +8,6 @@ from typing import List
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
# Configure logging
from ai_rules import AIRule
from api_models import (
DocumentProcessResponse,
@@ -71,7 +70,7 @@ async def root():
@app.post("/transactions/import/csv")
async def import_transactions_csv(file: UploadFile = File(...)):
async def import_transactions_csv(file: UploadFile = File(...), user_id: str = "", categorization_id: str = ""):
"""
Import transactions from a CSV file (custom bank export format).
"""