Refactor code for improved readability and maintainability across multiple files
This commit is contained in:
+78
-29
@@ -1,9 +1,10 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any, List
|
||||
import config
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from models import Receipt, Transaction
|
||||
from tax_rules_engine import TaxRulesEngine
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIRule:
|
||||
name: str
|
||||
@@ -12,48 +13,88 @@ class AIRule:
|
||||
source: str
|
||||
status: str = "active"
|
||||
|
||||
|
||||
class AIRulesEngine:
|
||||
def __init__(self):
|
||||
self.rules: List[AIRule] = []
|
||||
self.tax_rules_engine = TaxRulesEngine()
|
||||
self._load_default_rules()
|
||||
|
||||
|
||||
def _load_default_rules(self):
|
||||
self.rules = [
|
||||
AIRule("exact_amount_match", "amount_diff <= 0.01", "auto_approve", "system"),
|
||||
AIRule("same_vendor_same_date", "vendor_match and date_diff <= 1", "high_confidence", "system"),
|
||||
AIRule("gas_station_pattern", "vendor_contains_gas_or_fuel", "categorize_transport", "system"),
|
||||
AIRule(
|
||||
"exact_amount_match", "amount_diff <= 0.01", "auto_approve", "system"
|
||||
),
|
||||
AIRule(
|
||||
"same_vendor_same_date",
|
||||
"vendor_match and date_diff <= 1",
|
||||
"high_confidence",
|
||||
"system",
|
||||
),
|
||||
AIRule(
|
||||
"gas_station_pattern",
|
||||
"vendor_contains_gas_or_fuel",
|
||||
"categorize_transport",
|
||||
"system",
|
||||
),
|
||||
# Tax-related rules
|
||||
AIRule("fx_currency_mismatch", "currency_mismatch", "flag_fx_review", "tax_system"),
|
||||
AIRule("meals_entertainment", "is_meals_entertainment", "apply_me_tax_rule", "tax_system"),
|
||||
AIRule("provincial_tax_calculation", "has_address_info", "calculate_provincial_tax", "tax_system")
|
||||
AIRule(
|
||||
"fx_currency_mismatch",
|
||||
"currency_mismatch",
|
||||
"flag_fx_review",
|
||||
"tax_system",
|
||||
),
|
||||
AIRule(
|
||||
"meals_entertainment",
|
||||
"is_meals_entertainment",
|
||||
"apply_me_tax_rule",
|
||||
"tax_system",
|
||||
),
|
||||
AIRule(
|
||||
"provincial_tax_calculation",
|
||||
"has_address_info",
|
||||
"calculate_provincial_tax",
|
||||
"tax_system",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def apply_rules(self, receipt: Receipt, transaction: Transaction) -> Dict[str, Any]:
|
||||
results = {"auto_approve": False, "confidence_boost": 0, "category": None, "tax_analysis": {}}
|
||||
|
||||
results = {
|
||||
"auto_approve": False,
|
||||
"confidence_boost": 0,
|
||||
"category": None,
|
||||
"tax_analysis": {},
|
||||
}
|
||||
|
||||
for rule in self.rules:
|
||||
if rule.status != "active":
|
||||
continue
|
||||
|
||||
|
||||
if self._evaluate_condition(rule.condition, receipt, transaction):
|
||||
self._execute_action(rule.action, results, receipt, transaction)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
def _evaluate_condition(self, condition: str, receipt: Receipt, transaction: Transaction) -> bool:
|
||||
|
||||
def _evaluate_condition(
|
||||
self, condition: str, receipt: Receipt, transaction: Transaction
|
||||
) -> bool:
|
||||
"""Safely evaluate rule conditions without using eval()"""
|
||||
amount_diff = abs(receipt.amount - abs(transaction.amount))
|
||||
date_diff = abs((receipt.receipt_date - transaction.transaction_date).days)
|
||||
vendor_match = receipt.vendor.lower() in transaction.vendor.lower() or transaction.vendor.lower() in receipt.vendor.lower()
|
||||
vendor_match = (
|
||||
receipt.vendor.lower() in transaction.vendor.lower()
|
||||
or transaction.vendor.lower() in receipt.vendor.lower()
|
||||
)
|
||||
vendor_lower = receipt.vendor.lower()
|
||||
vendor_contains_gas_or_fuel = 'gas' in vendor_lower or 'fuel' in vendor_lower
|
||||
|
||||
vendor_contains_gas_or_fuel = "gas" in vendor_lower or "fuel" in vendor_lower
|
||||
|
||||
# Tax-related conditions
|
||||
currency_mismatch = receipt.currency != transaction.currency
|
||||
is_meals_entertainment = receipt.is_meals_entertainment
|
||||
has_address_info = receipt.billing_address is not None or receipt.shipping_address is not None
|
||||
|
||||
has_address_info = (
|
||||
receipt.billing_address is not None or receipt.shipping_address is not None
|
||||
)
|
||||
|
||||
# Handle specific condition types safely
|
||||
if condition == "amount_diff <= 0.01":
|
||||
return amount_diff <= 0.01
|
||||
@@ -86,14 +127,20 @@ class AIRulesEngine:
|
||||
"min": min,
|
||||
"max": max,
|
||||
"sum": sum,
|
||||
"round": round
|
||||
"round": round,
|
||||
}
|
||||
return eval(condition, safe_globals, {})
|
||||
except (SyntaxError, NameError, TypeError) as e:
|
||||
print(f"Warning: Invalid condition '{condition}': {e}")
|
||||
return False
|
||||
|
||||
def _execute_action(self, action: str, results: Dict[str, Any], receipt: Receipt, transaction: Transaction):
|
||||
|
||||
def _execute_action(
|
||||
self,
|
||||
action: str,
|
||||
results: Dict[str, Any],
|
||||
receipt: Receipt,
|
||||
transaction: Transaction,
|
||||
):
|
||||
if action == "auto_approve":
|
||||
results["auto_approve"] = True
|
||||
elif action == "high_confidence":
|
||||
@@ -114,13 +161,15 @@ class AIRulesEngine:
|
||||
# Calculate provincial tax
|
||||
tax_result = self.tax_rules_engine.apply_sales_tax_rule(receipt)
|
||||
results["tax_analysis"]["sales_tax"] = tax_result
|
||||
|
||||
|
||||
def add_rule(self, rule: AIRule):
|
||||
self.rules.append(rule)
|
||||
|
||||
|
||||
def remove_rule(self, rule_name: str):
|
||||
self.rules = [r for r in self.rules if r.name != rule_name]
|
||||
|
||||
def apply_tax_rules(self, receipt: Receipt, transaction: Transaction = None) -> Dict[str, Any]:
|
||||
|
||||
def apply_tax_rules(
|
||||
self, receipt: Receipt, transaction: Transaction = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply all tax rules to a receipt/transaction pair"""
|
||||
return self.tax_rules_engine.apply_all_tax_rules(receipt, transaction)
|
||||
return self.tax_rules_engine.apply_all_tax_rules(receipt, transaction)
|
||||
|
||||
Reference in New Issue
Block a user