Refactor code for improved readability and maintainability across multiple files

This commit is contained in:
bolade
2025-08-07 09:06:05 +01:00
parent 1f530da7c4
commit 9698e2fcaf
5 changed files with 224 additions and 123 deletions
+78 -29
View File
@@ -1,9 +1,10 @@
from dataclasses import dataclass
from typing import Dict, Any, List
import config
from typing import Any, Dict, List
from models import Receipt, Transaction
from tax_rules_engine import TaxRulesEngine
@dataclass
class AIRule:
name: str
@@ -12,48 +13,88 @@ class AIRule:
source: str
status: str = "active"
class AIRulesEngine:
def __init__(self):
self.rules: List[AIRule] = []
self.tax_rules_engine = TaxRulesEngine()
self._load_default_rules()
def _load_default_rules(self):
self.rules = [
AIRule("exact_amount_match", "amount_diff <= 0.01", "auto_approve", "system"),
AIRule("same_vendor_same_date", "vendor_match and date_diff <= 1", "high_confidence", "system"),
AIRule("gas_station_pattern", "vendor_contains_gas_or_fuel", "categorize_transport", "system"),
AIRule(
"exact_amount_match", "amount_diff <= 0.01", "auto_approve", "system"
),
AIRule(
"same_vendor_same_date",
"vendor_match and date_diff <= 1",
"high_confidence",
"system",
),
AIRule(
"gas_station_pattern",
"vendor_contains_gas_or_fuel",
"categorize_transport",
"system",
),
# Tax-related rules
AIRule("fx_currency_mismatch", "currency_mismatch", "flag_fx_review", "tax_system"),
AIRule("meals_entertainment", "is_meals_entertainment", "apply_me_tax_rule", "tax_system"),
AIRule("provincial_tax_calculation", "has_address_info", "calculate_provincial_tax", "tax_system")
AIRule(
"fx_currency_mismatch",
"currency_mismatch",
"flag_fx_review",
"tax_system",
),
AIRule(
"meals_entertainment",
"is_meals_entertainment",
"apply_me_tax_rule",
"tax_system",
),
AIRule(
"provincial_tax_calculation",
"has_address_info",
"calculate_provincial_tax",
"tax_system",
),
]
def apply_rules(self, receipt: Receipt, transaction: Transaction) -> Dict[str, Any]:
results = {"auto_approve": False, "confidence_boost": 0, "category": None, "tax_analysis": {}}
results = {
"auto_approve": False,
"confidence_boost": 0,
"category": None,
"tax_analysis": {},
}
for rule in self.rules:
if rule.status != "active":
continue
if self._evaluate_condition(rule.condition, receipt, transaction):
self._execute_action(rule.action, results, receipt, transaction)
return results
def _evaluate_condition(self, condition: str, receipt: Receipt, transaction: Transaction) -> bool:
def _evaluate_condition(
self, condition: str, receipt: Receipt, transaction: Transaction
) -> bool:
"""Safely evaluate rule conditions without using eval()"""
amount_diff = abs(receipt.amount - abs(transaction.amount))
date_diff = abs((receipt.receipt_date - transaction.transaction_date).days)
vendor_match = receipt.vendor.lower() in transaction.vendor.lower() or transaction.vendor.lower() in receipt.vendor.lower()
vendor_match = (
receipt.vendor.lower() in transaction.vendor.lower()
or transaction.vendor.lower() in receipt.vendor.lower()
)
vendor_lower = receipt.vendor.lower()
vendor_contains_gas_or_fuel = 'gas' in vendor_lower or 'fuel' in vendor_lower
vendor_contains_gas_or_fuel = "gas" in vendor_lower or "fuel" in vendor_lower
# Tax-related conditions
currency_mismatch = receipt.currency != transaction.currency
is_meals_entertainment = receipt.is_meals_entertainment
has_address_info = receipt.billing_address is not None or receipt.shipping_address is not None
has_address_info = (
receipt.billing_address is not None or receipt.shipping_address is not None
)
# Handle specific condition types safely
if condition == "amount_diff <= 0.01":
return amount_diff <= 0.01
@@ -86,14 +127,20 @@ class AIRulesEngine:
"min": min,
"max": max,
"sum": sum,
"round": round
"round": round,
}
return eval(condition, safe_globals, {})
except (SyntaxError, NameError, TypeError) as e:
print(f"Warning: Invalid condition '{condition}': {e}")
return False
def _execute_action(self, action: str, results: Dict[str, Any], receipt: Receipt, transaction: Transaction):
def _execute_action(
self,
action: str,
results: Dict[str, Any],
receipt: Receipt,
transaction: Transaction,
):
if action == "auto_approve":
results["auto_approve"] = True
elif action == "high_confidence":
@@ -114,13 +161,15 @@ class AIRulesEngine:
# Calculate provincial tax
tax_result = self.tax_rules_engine.apply_sales_tax_rule(receipt)
results["tax_analysis"]["sales_tax"] = tax_result
def add_rule(self, rule: AIRule):
self.rules.append(rule)
def remove_rule(self, rule_name: str):
self.rules = [r for r in self.rules if r.name != rule_name]
def apply_tax_rules(self, receipt: Receipt, transaction: Transaction = None) -> Dict[str, Any]:
def apply_tax_rules(
self, receipt: Receipt, transaction: Transaction = None
) -> Dict[str, Any]:
"""Apply all tax rules to a receipt/transaction pair"""
return self.tax_rules_engine.apply_all_tax_rules(receipt, transaction)
return self.tax_rules_engine.apply_all_tax_rules(receipt, transaction)
+25 -7
View File
@@ -1,13 +1,16 @@
from pydantic import BaseModel
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel
class AddressRequest(BaseModel):
province: str
city: str
postal_code: str
country: str = "Canada"
class ReceiptRequest(BaseModel):
id: str
file_name: str
@@ -24,6 +27,7 @@ class ReceiptRequest(BaseModel):
currency: str = "CAD"
is_meals_entertainment: bool = False
class TransactionRequest(BaseModel):
id: str
transaction_date: datetime
@@ -34,6 +38,7 @@ class TransactionRequest(BaseModel):
currency: str = "CAD"
fx_rate: Optional[float] = None
class AssetRequest(BaseModel):
id: str
name: str
@@ -44,42 +49,51 @@ class AssetRequest(BaseModel):
cca_rate: float
asset_class: str
class MatchingRequest(BaseModel):
receipt_ids: List[str]
transaction_ids: List[str]
class MatchResponse(BaseModel):
receipt_id: str
transaction_id: str
confidence_score: float
match_reason: str
tax_analysis: Optional[dict] = None
# Currency information
receipt_currency: str = "CAD"
transaction_currency: str = "CAD"
currency_match: bool = True
receipt_vendor: str
receipt_amount: float
receipt_description: str
receipt_category: str
receipt_tax_amount: float
transaction_vendor: str
transaction_amount: float
class MatchingResponse(BaseModel):
matches: List[MatchResponse]
stats: dict
class ApprovalRequest(BaseModel):
match_id: str
approved: bool
reason: Optional[str] = None
class RuleRequest(BaseModel):
name: str
condition: str
action: str
source: str = "user"
class DocumentUploadResponse(BaseModel):
file_id: str
filename: str
upload_date: datetime
status: str
class DocumentProcessResponse(BaseModel):
file_id: str
extraction_success: bool
@@ -92,11 +106,13 @@ class DocumentProcessResponse(BaseModel):
confidence: Optional[float] = None
error: Optional[str] = None
# New tax-related models
class TaxCalculationRequest(BaseModel):
receipt_id: str
transaction_id: Optional[str] = None
class TaxCalculationResponse(BaseModel):
receipt_id: str
rules_applied: List[str]
@@ -104,11 +120,13 @@ class TaxCalculationResponse(BaseModel):
fx_analysis: Optional[dict] = None
meals_entertainment: dict
class DepreciationRequest(BaseModel):
asset: AssetRequest
year: int
method: str # "straight_line" or "cca"
class DepreciationResponse(BaseModel):
asset_id: str
year: int
@@ -117,4 +135,4 @@ class DepreciationResponse(BaseModel):
book_value: float
total_depreciation: Optional[float] = None
success: bool
error: Optional[str] = None
error: Optional[str] = None
+39 -23
View File
@@ -1,8 +1,9 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import List, Optional
import json
import os
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import List
@dataclass
class FeedbackLog:
@@ -13,48 +14,63 @@ class FeedbackLog:
timestamp: datetime
user_id: str
class FeedbackLogger:
def __init__(self, log_file: str = "feedback_logs.json"):
self.log_file = log_file
self.logs: List[FeedbackLog] = self._load_logs()
def _load_logs(self) -> List[FeedbackLog]:
if not os.path.exists(self.log_file):
return []
try:
with open(self.log_file, 'r') as f:
with open(self.log_file, "r") as f:
data = json.load(f)
return [FeedbackLog(**log) for log in data]
except:
except Exception:
return []
def _save_logs(self):
with open(self.log_file, 'w') as f:
json.dump([{
'transaction_id': log.transaction_id,
'original_match': log.original_match,
'correction': log.correction,
'reason': log.reason,
'timestamp': log.timestamp.isoformat(),
'user_id': log.user_id
} for log in self.logs], f, indent=2)
def log_override(self, transaction_id: str, original_match: str, correction: str, reason: str, user_id: str):
with open(self.log_file, "w") as f:
json.dump(
[
{
"transaction_id": log.transaction_id,
"original_match": log.original_match,
"correction": log.correction,
"reason": log.reason,
"timestamp": log.timestamp.isoformat(),
"user_id": log.user_id,
}
for log in self.logs
],
f,
indent=2,
)
def log_override(
self,
transaction_id: str,
original_match: str,
correction: str,
reason: str,
user_id: str,
):
log = FeedbackLog(
transaction_id=transaction_id,
original_match=original_match,
correction=correction,
reason=reason,
timestamp=datetime.now(),
user_id=user_id
user_id=user_id,
)
self.logs.append(log)
self._save_logs()
def get_logs_by_transaction(self, transaction_id: str) -> List[FeedbackLog]:
return [log for log in self.logs if log.transaction_id == transaction_id]
def get_recent_logs(self, days: int = 30) -> List[FeedbackLog]:
cutoff = datetime.now() - timedelta(days=days)
return [log for log in self.logs if log.timestamp > cutoff]
return [log for log in self.logs if log.timestamp > cutoff]
+81 -62
View File
@@ -1,13 +1,13 @@
import os
import io
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
from typing import Any, Dict, List
class GoogleDriveSync:
def __init__(self):
self.service = None
self.processed_files = set()
def authenticate(self):
"""Authenticate with Google Drive API"""
try:
@@ -15,111 +15,130 @@ class GoogleDriveSync:
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow
from googleapiclient.discovery import build
SCOPES = ['https://www.googleapis.com/auth/drive.readonly']
SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
# Load existing credentials
if os.path.exists('token.json'):
self.creds = Credentials.from_authorized_user_file('token.json', SCOPES)
if os.path.exists("token.json"):
self.creds = Credentials.from_authorized_user_file("token.json", SCOPES)
# If no valid credentials available, let user log in
if not self.creds or not self.creds.valid:
if self.creds and self.creds.expired and self.creds.refresh_token:
self.creds.refresh(Request())
else:
if not os.path.exists('credentials.json'):
raise Exception("credentials.json not found. Please download from Google Cloud Console.")
flow = InstalledAppFlow.from_client_secrets_file('credentials.json', SCOPES)
if not os.path.exists("credentials.json"):
raise Exception(
"credentials.json not found. Please download from Google Cloud Console."
)
flow = InstalledAppFlow.from_client_secrets_file(
"credentials.json", SCOPES
)
self.creds = flow.run_local_server(port=0)
# Save credentials for next run
with open('token.json', 'w') as token:
with open("token.json", "w") as token:
token.write(self.creds.to_json())
# Build the Drive service
self.service = build('drive', 'v3', credentials=self.creds)
self.service = build("drive", "v3", credentials=self.creds)
return True
except Exception as e:
print(f"Authentication error: {e}")
return False
def list_folders(self) -> List[Dict[str, Any]]:
"""List all folders in Google Drive"""
if not self.service:
if not self.authenticate():
return []
try:
results = self.service.files().list(
q="mimeType='application/vnd.google-apps.folder'",
pageSize=100,
fields="nextPageToken, files(id, name, createdTime, modifiedTime)"
).execute()
return results.get('files', [])
results = (
self.service.files()
.list(
q="mimeType='application/vnd.google-apps.folder'",
pageSize=100,
fields="nextPageToken, files(id, name, createdTime, modifiedTime)",
)
.execute()
)
return results.get("files", [])
except Exception as e:
print(f"Error listing folders: {e}")
return []
def get_folder_info(self, folder_id: str) -> Dict[str, Any]:
"""Get information about a Google Drive folder"""
if not self.service:
if not self.authenticate():
return {}
try:
folder = self.service.files().get(
fileId=folder_id,
fields="id, name, createdTime, modifiedTime"
).execute()
folder = (
self.service.files()
.get(fileId=folder_id, fields="id, name, createdTime, modifiedTime")
.execute()
)
return folder
except Exception as e:
print(f"Error getting folder info: {e}")
return {}
async def process_drive_files(self, folder_id: str = None) -> List[Dict[str, Any]]:
"""Process all receipt files from Google Drive"""
if not self.service:
if not self.authenticate():
return []
results = []
try:
# File types to look for
file_types = ["'application/pdf'", "'image/jpeg'", "'image/png'", "'image/gif'", "'image/bmp'"]
file_types = [
"'application/pdf'",
"'image/jpeg'",
"'image/png'",
"'image/gif'",
"'image/bmp'",
]
mime_types = " or ".join(file_types)
# Build query
query = f"mimeType contains {mime_types}"
if folder_id:
query += f" and '{folder_id}' in parents"
# Add date filter (last 30 days)
thirty_days_ago = (datetime.now() - timedelta(days=30)).isoformat() + 'Z'
thirty_days_ago = (datetime.now() - timedelta(days=30)).isoformat() + "Z"
query += f" and modifiedTime > '{thirty_days_ago}'"
results_files = self.service.files().list(
q=query,
pageSize=100,
fields="nextPageToken, files(id, name, mimeType, modifiedTime, size)"
).execute()
files = results_files.get('files', [])
files = [file for file in files if file['id'] not in self.processed_files]
results_files = (
self.service.files()
.list(
q=query,
pageSize=100,
fields="nextPageToken, files(id, name, mimeType, modifiedTime, size)",
)
.execute()
)
files = results_files.get("files", [])
files = [file for file in files if file["id"] not in self.processed_files]
# For demo purposes, return mock results
for file in files[:3]: # Process first 3 files
mock_result = {
"file_id": file['id'],
"filename": file['name'],
"drive_modified": file['modifiedTime'],
"file_size": file.get('size', 0),
"file_id": file["id"],
"filename": file["name"],
"drive_modified": file["modifiedTime"],
"file_size": file.get("size", 0),
"extraction_success": True,
"vendor": "Demo Vendor",
"description": "Coffee and sandwich",
@@ -127,12 +146,12 @@ class GoogleDriveSync:
"tax_amount": 2.04,
"date": "2024-01-15",
"category": "Food",
"confidence": 0.95
"confidence": 0.95,
}
results.append(mock_result)
self.processed_files.add(file['id'])
self.processed_files.add(file["id"])
except Exception as e:
print(f"Error processing Drive files: {e}")
return results
return results
+1 -2
View File
@@ -8,7 +8,6 @@ from typing import List
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
# Configure logging
from ai_rules import AIRule
from api_models import (
DocumentProcessResponse,
@@ -71,7 +70,7 @@ async def root():
@app.post("/transactions/import/csv")
async def import_transactions_csv(file: UploadFile = File(...)):
async def import_transactions_csv(file: UploadFile = File(...), user_id: str = "", categorization_id: str = ""):
"""
Import transactions from a CSV file (custom bank export format).
"""