Refactor code for improved readability and maintainability across multiple files

This commit is contained in:
bolade
2025-08-07 09:06:05 +01:00
parent 1f530da7c4
commit 9698e2fcaf
5 changed files with 224 additions and 123 deletions
+78 -29
View File
@@ -1,9 +1,10 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Any, List from typing import Any, Dict, List
import config
from models import Receipt, Transaction from models import Receipt, Transaction
from tax_rules_engine import TaxRulesEngine from tax_rules_engine import TaxRulesEngine
@dataclass @dataclass
class AIRule: class AIRule:
name: str name: str
@@ -12,48 +13,88 @@ class AIRule:
source: str source: str
status: str = "active" status: str = "active"
class AIRulesEngine: class AIRulesEngine:
def __init__(self): def __init__(self):
self.rules: List[AIRule] = [] self.rules: List[AIRule] = []
self.tax_rules_engine = TaxRulesEngine() self.tax_rules_engine = TaxRulesEngine()
self._load_default_rules() self._load_default_rules()
def _load_default_rules(self): def _load_default_rules(self):
self.rules = [ self.rules = [
AIRule("exact_amount_match", "amount_diff <= 0.01", "auto_approve", "system"), AIRule(
AIRule("same_vendor_same_date", "vendor_match and date_diff <= 1", "high_confidence", "system"), "exact_amount_match", "amount_diff <= 0.01", "auto_approve", "system"
AIRule("gas_station_pattern", "vendor_contains_gas_or_fuel", "categorize_transport", "system"), ),
AIRule(
"same_vendor_same_date",
"vendor_match and date_diff <= 1",
"high_confidence",
"system",
),
AIRule(
"gas_station_pattern",
"vendor_contains_gas_or_fuel",
"categorize_transport",
"system",
),
# Tax-related rules # Tax-related rules
AIRule("fx_currency_mismatch", "currency_mismatch", "flag_fx_review", "tax_system"), AIRule(
AIRule("meals_entertainment", "is_meals_entertainment", "apply_me_tax_rule", "tax_system"), "fx_currency_mismatch",
AIRule("provincial_tax_calculation", "has_address_info", "calculate_provincial_tax", "tax_system") "currency_mismatch",
"flag_fx_review",
"tax_system",
),
AIRule(
"meals_entertainment",
"is_meals_entertainment",
"apply_me_tax_rule",
"tax_system",
),
AIRule(
"provincial_tax_calculation",
"has_address_info",
"calculate_provincial_tax",
"tax_system",
),
] ]
def apply_rules(self, receipt: Receipt, transaction: Transaction) -> Dict[str, Any]: def apply_rules(self, receipt: Receipt, transaction: Transaction) -> Dict[str, Any]:
results = {"auto_approve": False, "confidence_boost": 0, "category": None, "tax_analysis": {}} results = {
"auto_approve": False,
"confidence_boost": 0,
"category": None,
"tax_analysis": {},
}
for rule in self.rules: for rule in self.rules:
if rule.status != "active": if rule.status != "active":
continue continue
if self._evaluate_condition(rule.condition, receipt, transaction): if self._evaluate_condition(rule.condition, receipt, transaction):
self._execute_action(rule.action, results, receipt, transaction) self._execute_action(rule.action, results, receipt, transaction)
return results return results
def _evaluate_condition(self, condition: str, receipt: Receipt, transaction: Transaction) -> bool: def _evaluate_condition(
self, condition: str, receipt: Receipt, transaction: Transaction
) -> bool:
"""Safely evaluate rule conditions without using eval()""" """Safely evaluate rule conditions without using eval()"""
amount_diff = abs(receipt.amount - abs(transaction.amount)) amount_diff = abs(receipt.amount - abs(transaction.amount))
date_diff = abs((receipt.receipt_date - transaction.transaction_date).days) date_diff = abs((receipt.receipt_date - transaction.transaction_date).days)
vendor_match = receipt.vendor.lower() in transaction.vendor.lower() or transaction.vendor.lower() in receipt.vendor.lower() vendor_match = (
receipt.vendor.lower() in transaction.vendor.lower()
or transaction.vendor.lower() in receipt.vendor.lower()
)
vendor_lower = receipt.vendor.lower() vendor_lower = receipt.vendor.lower()
vendor_contains_gas_or_fuel = 'gas' in vendor_lower or 'fuel' in vendor_lower vendor_contains_gas_or_fuel = "gas" in vendor_lower or "fuel" in vendor_lower
# Tax-related conditions # Tax-related conditions
currency_mismatch = receipt.currency != transaction.currency currency_mismatch = receipt.currency != transaction.currency
is_meals_entertainment = receipt.is_meals_entertainment is_meals_entertainment = receipt.is_meals_entertainment
has_address_info = receipt.billing_address is not None or receipt.shipping_address is not None has_address_info = (
receipt.billing_address is not None or receipt.shipping_address is not None
)
# Handle specific condition types safely # Handle specific condition types safely
if condition == "amount_diff <= 0.01": if condition == "amount_diff <= 0.01":
return amount_diff <= 0.01 return amount_diff <= 0.01
@@ -86,14 +127,20 @@ class AIRulesEngine:
"min": min, "min": min,
"max": max, "max": max,
"sum": sum, "sum": sum,
"round": round "round": round,
} }
return eval(condition, safe_globals, {}) return eval(condition, safe_globals, {})
except (SyntaxError, NameError, TypeError) as e: except (SyntaxError, NameError, TypeError) as e:
print(f"Warning: Invalid condition '{condition}': {e}") print(f"Warning: Invalid condition '{condition}': {e}")
return False return False
def _execute_action(self, action: str, results: Dict[str, Any], receipt: Receipt, transaction: Transaction): def _execute_action(
self,
action: str,
results: Dict[str, Any],
receipt: Receipt,
transaction: Transaction,
):
if action == "auto_approve": if action == "auto_approve":
results["auto_approve"] = True results["auto_approve"] = True
elif action == "high_confidence": elif action == "high_confidence":
@@ -114,13 +161,15 @@ class AIRulesEngine:
# Calculate provincial tax # Calculate provincial tax
tax_result = self.tax_rules_engine.apply_sales_tax_rule(receipt) tax_result = self.tax_rules_engine.apply_sales_tax_rule(receipt)
results["tax_analysis"]["sales_tax"] = tax_result results["tax_analysis"]["sales_tax"] = tax_result
def add_rule(self, rule: AIRule): def add_rule(self, rule: AIRule):
self.rules.append(rule) self.rules.append(rule)
def remove_rule(self, rule_name: str): def remove_rule(self, rule_name: str):
self.rules = [r for r in self.rules if r.name != rule_name] self.rules = [r for r in self.rules if r.name != rule_name]
def apply_tax_rules(self, receipt: Receipt, transaction: Transaction = None) -> Dict[str, Any]: def apply_tax_rules(
self, receipt: Receipt, transaction: Transaction = None
) -> Dict[str, Any]:
"""Apply all tax rules to a receipt/transaction pair""" """Apply all tax rules to a receipt/transaction pair"""
return self.tax_rules_engine.apply_all_tax_rules(receipt, transaction) return self.tax_rules_engine.apply_all_tax_rules(receipt, transaction)
+25 -7
View File
@@ -1,13 +1,16 @@
from pydantic import BaseModel
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel
class AddressRequest(BaseModel): class AddressRequest(BaseModel):
province: str province: str
city: str city: str
postal_code: str postal_code: str
country: str = "Canada" country: str = "Canada"
class ReceiptRequest(BaseModel): class ReceiptRequest(BaseModel):
id: str id: str
file_name: str file_name: str
@@ -24,6 +27,7 @@ class ReceiptRequest(BaseModel):
currency: str = "CAD" currency: str = "CAD"
is_meals_entertainment: bool = False is_meals_entertainment: bool = False
class TransactionRequest(BaseModel): class TransactionRequest(BaseModel):
id: str id: str
transaction_date: datetime transaction_date: datetime
@@ -34,6 +38,7 @@ class TransactionRequest(BaseModel):
currency: str = "CAD" currency: str = "CAD"
fx_rate: Optional[float] = None fx_rate: Optional[float] = None
class AssetRequest(BaseModel): class AssetRequest(BaseModel):
id: str id: str
name: str name: str
@@ -44,42 +49,51 @@ class AssetRequest(BaseModel):
cca_rate: float cca_rate: float
asset_class: str asset_class: str
class MatchingRequest(BaseModel): class MatchingRequest(BaseModel):
receipt_ids: List[str] receipt_ids: List[str]
transaction_ids: List[str] transaction_ids: List[str]
class MatchResponse(BaseModel): class MatchResponse(BaseModel):
receipt_id: str receipt_id: str
transaction_id: str transaction_id: str
confidence_score: float confidence_score: float
match_reason: str match_reason: str
tax_analysis: Optional[dict] = None receipt_vendor: str
# Currency information receipt_amount: float
receipt_currency: str = "CAD" receipt_description: str
transaction_currency: str = "CAD" receipt_category: str
currency_match: bool = True receipt_tax_amount: float
transaction_vendor: str
transaction_amount: float
class MatchingResponse(BaseModel): class MatchingResponse(BaseModel):
matches: List[MatchResponse] matches: List[MatchResponse]
stats: dict stats: dict
class ApprovalRequest(BaseModel): class ApprovalRequest(BaseModel):
match_id: str match_id: str
approved: bool approved: bool
reason: Optional[str] = None reason: Optional[str] = None
class RuleRequest(BaseModel): class RuleRequest(BaseModel):
name: str name: str
condition: str condition: str
action: str action: str
source: str = "user" source: str = "user"
class DocumentUploadResponse(BaseModel): class DocumentUploadResponse(BaseModel):
file_id: str file_id: str
filename: str filename: str
upload_date: datetime upload_date: datetime
status: str status: str
class DocumentProcessResponse(BaseModel): class DocumentProcessResponse(BaseModel):
file_id: str file_id: str
extraction_success: bool extraction_success: bool
@@ -92,11 +106,13 @@ class DocumentProcessResponse(BaseModel):
confidence: Optional[float] = None confidence: Optional[float] = None
error: Optional[str] = None error: Optional[str] = None
# New tax-related models # New tax-related models
class TaxCalculationRequest(BaseModel): class TaxCalculationRequest(BaseModel):
receipt_id: str receipt_id: str
transaction_id: Optional[str] = None transaction_id: Optional[str] = None
class TaxCalculationResponse(BaseModel): class TaxCalculationResponse(BaseModel):
receipt_id: str receipt_id: str
rules_applied: List[str] rules_applied: List[str]
@@ -104,11 +120,13 @@ class TaxCalculationResponse(BaseModel):
fx_analysis: Optional[dict] = None fx_analysis: Optional[dict] = None
meals_entertainment: dict meals_entertainment: dict
class DepreciationRequest(BaseModel): class DepreciationRequest(BaseModel):
asset: AssetRequest asset: AssetRequest
year: int year: int
method: str # "straight_line" or "cca" method: str # "straight_line" or "cca"
class DepreciationResponse(BaseModel): class DepreciationResponse(BaseModel):
asset_id: str asset_id: str
year: int year: int
@@ -117,4 +135,4 @@ class DepreciationResponse(BaseModel):
book_value: float book_value: float
total_depreciation: Optional[float] = None total_depreciation: Optional[float] = None
success: bool success: bool
error: Optional[str] = None error: Optional[str] = None
+39 -23
View File
@@ -1,8 +1,9 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import List, Optional
import json import json
import os import os
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import List
@dataclass @dataclass
class FeedbackLog: class FeedbackLog:
@@ -13,48 +14,63 @@ class FeedbackLog:
timestamp: datetime timestamp: datetime
user_id: str user_id: str
class FeedbackLogger: class FeedbackLogger:
def __init__(self, log_file: str = "feedback_logs.json"): def __init__(self, log_file: str = "feedback_logs.json"):
self.log_file = log_file self.log_file = log_file
self.logs: List[FeedbackLog] = self._load_logs() self.logs: List[FeedbackLog] = self._load_logs()
def _load_logs(self) -> List[FeedbackLog]: def _load_logs(self) -> List[FeedbackLog]:
if not os.path.exists(self.log_file): if not os.path.exists(self.log_file):
return [] return []
try: try:
with open(self.log_file, 'r') as f: with open(self.log_file, "r") as f:
data = json.load(f) data = json.load(f)
return [FeedbackLog(**log) for log in data] return [FeedbackLog(**log) for log in data]
except: except Exception:
return [] return []
def _save_logs(self): def _save_logs(self):
with open(self.log_file, 'w') as f: with open(self.log_file, "w") as f:
json.dump([{ json.dump(
'transaction_id': log.transaction_id, [
'original_match': log.original_match, {
'correction': log.correction, "transaction_id": log.transaction_id,
'reason': log.reason, "original_match": log.original_match,
'timestamp': log.timestamp.isoformat(), "correction": log.correction,
'user_id': log.user_id "reason": log.reason,
} for log in self.logs], f, indent=2) "timestamp": log.timestamp.isoformat(),
"user_id": log.user_id,
def log_override(self, transaction_id: str, original_match: str, correction: str, reason: str, user_id: str): }
for log in self.logs
],
f,
indent=2,
)
def log_override(
self,
transaction_id: str,
original_match: str,
correction: str,
reason: str,
user_id: str,
):
log = FeedbackLog( log = FeedbackLog(
transaction_id=transaction_id, transaction_id=transaction_id,
original_match=original_match, original_match=original_match,
correction=correction, correction=correction,
reason=reason, reason=reason,
timestamp=datetime.now(), timestamp=datetime.now(),
user_id=user_id user_id=user_id,
) )
self.logs.append(log) self.logs.append(log)
self._save_logs() self._save_logs()
def get_logs_by_transaction(self, transaction_id: str) -> List[FeedbackLog]: def get_logs_by_transaction(self, transaction_id: str) -> List[FeedbackLog]:
return [log for log in self.logs if log.transaction_id == transaction_id] return [log for log in self.logs if log.transaction_id == transaction_id]
def get_recent_logs(self, days: int = 30) -> List[FeedbackLog]: def get_recent_logs(self, days: int = 30) -> List[FeedbackLog]:
cutoff = datetime.now() - timedelta(days=days) cutoff = datetime.now() - timedelta(days=days)
return [log for log in self.logs if log.timestamp > cutoff] return [log for log in self.logs if log.timestamp > cutoff]
+81 -62
View File
@@ -1,13 +1,13 @@
import os import os
import io
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, List
class GoogleDriveSync: class GoogleDriveSync:
def __init__(self): def __init__(self):
self.service = None self.service = None
self.processed_files = set() self.processed_files = set()
def authenticate(self): def authenticate(self):
"""Authenticate with Google Drive API""" """Authenticate with Google Drive API"""
try: try:
@@ -15,111 +15,130 @@ class GoogleDriveSync:
from google.oauth2.credentials import Credentials from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow from google_auth_oauthlib.flow import InstalledAppFlow
from googleapiclient.discovery import build from googleapiclient.discovery import build
SCOPES = ['https://www.googleapis.com/auth/drive.readonly'] SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
# Load existing credentials # Load existing credentials
if os.path.exists('token.json'): if os.path.exists("token.json"):
self.creds = Credentials.from_authorized_user_file('token.json', SCOPES) self.creds = Credentials.from_authorized_user_file("token.json", SCOPES)
# If no valid credentials available, let user log in # If no valid credentials available, let user log in
if not self.creds or not self.creds.valid: if not self.creds or not self.creds.valid:
if self.creds and self.creds.expired and self.creds.refresh_token: if self.creds and self.creds.expired and self.creds.refresh_token:
self.creds.refresh(Request()) self.creds.refresh(Request())
else: else:
if not os.path.exists('credentials.json'): if not os.path.exists("credentials.json"):
raise Exception("credentials.json not found. Please download from Google Cloud Console.") raise Exception(
"credentials.json not found. Please download from Google Cloud Console."
flow = InstalledAppFlow.from_client_secrets_file('credentials.json', SCOPES) )
flow = InstalledAppFlow.from_client_secrets_file(
"credentials.json", SCOPES
)
self.creds = flow.run_local_server(port=0) self.creds = flow.run_local_server(port=0)
# Save credentials for next run # Save credentials for next run
with open('token.json', 'w') as token: with open("token.json", "w") as token:
token.write(self.creds.to_json()) token.write(self.creds.to_json())
# Build the Drive service # Build the Drive service
self.service = build('drive', 'v3', credentials=self.creds) self.service = build("drive", "v3", credentials=self.creds)
return True return True
except Exception as e: except Exception as e:
print(f"Authentication error: {e}") print(f"Authentication error: {e}")
return False return False
def list_folders(self) -> List[Dict[str, Any]]: def list_folders(self) -> List[Dict[str, Any]]:
"""List all folders in Google Drive""" """List all folders in Google Drive"""
if not self.service: if not self.service:
if not self.authenticate(): if not self.authenticate():
return [] return []
try: try:
results = self.service.files().list( results = (
q="mimeType='application/vnd.google-apps.folder'", self.service.files()
pageSize=100, .list(
fields="nextPageToken, files(id, name, createdTime, modifiedTime)" q="mimeType='application/vnd.google-apps.folder'",
).execute() pageSize=100,
fields="nextPageToken, files(id, name, createdTime, modifiedTime)",
return results.get('files', []) )
.execute()
)
return results.get("files", [])
except Exception as e: except Exception as e:
print(f"Error listing folders: {e}") print(f"Error listing folders: {e}")
return [] return []
def get_folder_info(self, folder_id: str) -> Dict[str, Any]: def get_folder_info(self, folder_id: str) -> Dict[str, Any]:
"""Get information about a Google Drive folder""" """Get information about a Google Drive folder"""
if not self.service: if not self.service:
if not self.authenticate(): if not self.authenticate():
return {} return {}
try: try:
folder = self.service.files().get( folder = (
fileId=folder_id, self.service.files()
fields="id, name, createdTime, modifiedTime" .get(fileId=folder_id, fields="id, name, createdTime, modifiedTime")
).execute() .execute()
)
return folder return folder
except Exception as e: except Exception as e:
print(f"Error getting folder info: {e}") print(f"Error getting folder info: {e}")
return {} return {}
async def process_drive_files(self, folder_id: str = None) -> List[Dict[str, Any]]: async def process_drive_files(self, folder_id: str = None) -> List[Dict[str, Any]]:
"""Process all receipt files from Google Drive""" """Process all receipt files from Google Drive"""
if not self.service: if not self.service:
if not self.authenticate(): if not self.authenticate():
return [] return []
results = [] results = []
try: try:
# File types to look for # File types to look for
file_types = ["'application/pdf'", "'image/jpeg'", "'image/png'", "'image/gif'", "'image/bmp'"] file_types = [
"'application/pdf'",
"'image/jpeg'",
"'image/png'",
"'image/gif'",
"'image/bmp'",
]
mime_types = " or ".join(file_types) mime_types = " or ".join(file_types)
# Build query # Build query
query = f"mimeType contains {mime_types}" query = f"mimeType contains {mime_types}"
if folder_id: if folder_id:
query += f" and '{folder_id}' in parents" query += f" and '{folder_id}' in parents"
# Add date filter (last 30 days) # Add date filter (last 30 days)
thirty_days_ago = (datetime.now() - timedelta(days=30)).isoformat() + 'Z' thirty_days_ago = (datetime.now() - timedelta(days=30)).isoformat() + "Z"
query += f" and modifiedTime > '{thirty_days_ago}'" query += f" and modifiedTime > '{thirty_days_ago}'"
results_files = self.service.files().list( results_files = (
q=query, self.service.files()
pageSize=100, .list(
fields="nextPageToken, files(id, name, mimeType, modifiedTime, size)" q=query,
).execute() pageSize=100,
fields="nextPageToken, files(id, name, mimeType, modifiedTime, size)",
files = results_files.get('files', []) )
files = [file for file in files if file['id'] not in self.processed_files] .execute()
)
files = results_files.get("files", [])
files = [file for file in files if file["id"] not in self.processed_files]
# For demo purposes, return mock results # For demo purposes, return mock results
for file in files[:3]: # Process first 3 files for file in files[:3]: # Process first 3 files
mock_result = { mock_result = {
"file_id": file['id'], "file_id": file["id"],
"filename": file['name'], "filename": file["name"],
"drive_modified": file['modifiedTime'], "drive_modified": file["modifiedTime"],
"file_size": file.get('size', 0), "file_size": file.get("size", 0),
"extraction_success": True, "extraction_success": True,
"vendor": "Demo Vendor", "vendor": "Demo Vendor",
"description": "Coffee and sandwich", "description": "Coffee and sandwich",
@@ -127,12 +146,12 @@ class GoogleDriveSync:
"tax_amount": 2.04, "tax_amount": 2.04,
"date": "2024-01-15", "date": "2024-01-15",
"category": "Food", "category": "Food",
"confidence": 0.95 "confidence": 0.95,
} }
results.append(mock_result) results.append(mock_result)
self.processed_files.add(file['id']) self.processed_files.add(file["id"])
except Exception as e: except Exception as e:
print(f"Error processing Drive files: {e}") print(f"Error processing Drive files: {e}")
return results return results
+1 -2
View File
@@ -8,7 +8,6 @@ from typing import List
from fastapi import FastAPI, File, HTTPException, UploadFile from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
# Configure logging
from ai_rules import AIRule from ai_rules import AIRule
from api_models import ( from api_models import (
DocumentProcessResponse, DocumentProcessResponse,
@@ -71,7 +70,7 @@ async def root():
@app.post("/transactions/import/csv") @app.post("/transactions/import/csv")
async def import_transactions_csv(file: UploadFile = File(...)): async def import_transactions_csv(file: UploadFile = File(...), user_id: str = "", categorization_id: str = ""):
""" """
Import transactions from a CSV file (custom bank export format). Import transactions from a CSV file (custom bank export format).
""" """