Refactor code for improved readability and maintainability across multiple files
This commit is contained in:
+78
-29
@@ -1,9 +1,10 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any, List
|
||||
import config
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from models import Receipt, Transaction
|
||||
from tax_rules_engine import TaxRulesEngine
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIRule:
|
||||
name: str
|
||||
@@ -12,48 +13,88 @@ class AIRule:
|
||||
source: str
|
||||
status: str = "active"
|
||||
|
||||
|
||||
class AIRulesEngine:
|
||||
def __init__(self):
|
||||
self.rules: List[AIRule] = []
|
||||
self.tax_rules_engine = TaxRulesEngine()
|
||||
self._load_default_rules()
|
||||
|
||||
|
||||
def _load_default_rules(self):
|
||||
self.rules = [
|
||||
AIRule("exact_amount_match", "amount_diff <= 0.01", "auto_approve", "system"),
|
||||
AIRule("same_vendor_same_date", "vendor_match and date_diff <= 1", "high_confidence", "system"),
|
||||
AIRule("gas_station_pattern", "vendor_contains_gas_or_fuel", "categorize_transport", "system"),
|
||||
AIRule(
|
||||
"exact_amount_match", "amount_diff <= 0.01", "auto_approve", "system"
|
||||
),
|
||||
AIRule(
|
||||
"same_vendor_same_date",
|
||||
"vendor_match and date_diff <= 1",
|
||||
"high_confidence",
|
||||
"system",
|
||||
),
|
||||
AIRule(
|
||||
"gas_station_pattern",
|
||||
"vendor_contains_gas_or_fuel",
|
||||
"categorize_transport",
|
||||
"system",
|
||||
),
|
||||
# Tax-related rules
|
||||
AIRule("fx_currency_mismatch", "currency_mismatch", "flag_fx_review", "tax_system"),
|
||||
AIRule("meals_entertainment", "is_meals_entertainment", "apply_me_tax_rule", "tax_system"),
|
||||
AIRule("provincial_tax_calculation", "has_address_info", "calculate_provincial_tax", "tax_system")
|
||||
AIRule(
|
||||
"fx_currency_mismatch",
|
||||
"currency_mismatch",
|
||||
"flag_fx_review",
|
||||
"tax_system",
|
||||
),
|
||||
AIRule(
|
||||
"meals_entertainment",
|
||||
"is_meals_entertainment",
|
||||
"apply_me_tax_rule",
|
||||
"tax_system",
|
||||
),
|
||||
AIRule(
|
||||
"provincial_tax_calculation",
|
||||
"has_address_info",
|
||||
"calculate_provincial_tax",
|
||||
"tax_system",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def apply_rules(self, receipt: Receipt, transaction: Transaction) -> Dict[str, Any]:
|
||||
results = {"auto_approve": False, "confidence_boost": 0, "category": None, "tax_analysis": {}}
|
||||
|
||||
results = {
|
||||
"auto_approve": False,
|
||||
"confidence_boost": 0,
|
||||
"category": None,
|
||||
"tax_analysis": {},
|
||||
}
|
||||
|
||||
for rule in self.rules:
|
||||
if rule.status != "active":
|
||||
continue
|
||||
|
||||
|
||||
if self._evaluate_condition(rule.condition, receipt, transaction):
|
||||
self._execute_action(rule.action, results, receipt, transaction)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
def _evaluate_condition(self, condition: str, receipt: Receipt, transaction: Transaction) -> bool:
|
||||
|
||||
def _evaluate_condition(
|
||||
self, condition: str, receipt: Receipt, transaction: Transaction
|
||||
) -> bool:
|
||||
"""Safely evaluate rule conditions without using eval()"""
|
||||
amount_diff = abs(receipt.amount - abs(transaction.amount))
|
||||
date_diff = abs((receipt.receipt_date - transaction.transaction_date).days)
|
||||
vendor_match = receipt.vendor.lower() in transaction.vendor.lower() or transaction.vendor.lower() in receipt.vendor.lower()
|
||||
vendor_match = (
|
||||
receipt.vendor.lower() in transaction.vendor.lower()
|
||||
or transaction.vendor.lower() in receipt.vendor.lower()
|
||||
)
|
||||
vendor_lower = receipt.vendor.lower()
|
||||
vendor_contains_gas_or_fuel = 'gas' in vendor_lower or 'fuel' in vendor_lower
|
||||
|
||||
vendor_contains_gas_or_fuel = "gas" in vendor_lower or "fuel" in vendor_lower
|
||||
|
||||
# Tax-related conditions
|
||||
currency_mismatch = receipt.currency != transaction.currency
|
||||
is_meals_entertainment = receipt.is_meals_entertainment
|
||||
has_address_info = receipt.billing_address is not None or receipt.shipping_address is not None
|
||||
|
||||
has_address_info = (
|
||||
receipt.billing_address is not None or receipt.shipping_address is not None
|
||||
)
|
||||
|
||||
# Handle specific condition types safely
|
||||
if condition == "amount_diff <= 0.01":
|
||||
return amount_diff <= 0.01
|
||||
@@ -86,14 +127,20 @@ class AIRulesEngine:
|
||||
"min": min,
|
||||
"max": max,
|
||||
"sum": sum,
|
||||
"round": round
|
||||
"round": round,
|
||||
}
|
||||
return eval(condition, safe_globals, {})
|
||||
except (SyntaxError, NameError, TypeError) as e:
|
||||
print(f"Warning: Invalid condition '{condition}': {e}")
|
||||
return False
|
||||
|
||||
def _execute_action(self, action: str, results: Dict[str, Any], receipt: Receipt, transaction: Transaction):
|
||||
|
||||
def _execute_action(
|
||||
self,
|
||||
action: str,
|
||||
results: Dict[str, Any],
|
||||
receipt: Receipt,
|
||||
transaction: Transaction,
|
||||
):
|
||||
if action == "auto_approve":
|
||||
results["auto_approve"] = True
|
||||
elif action == "high_confidence":
|
||||
@@ -114,13 +161,15 @@ class AIRulesEngine:
|
||||
# Calculate provincial tax
|
||||
tax_result = self.tax_rules_engine.apply_sales_tax_rule(receipt)
|
||||
results["tax_analysis"]["sales_tax"] = tax_result
|
||||
|
||||
|
||||
def add_rule(self, rule: AIRule):
|
||||
self.rules.append(rule)
|
||||
|
||||
|
||||
def remove_rule(self, rule_name: str):
|
||||
self.rules = [r for r in self.rules if r.name != rule_name]
|
||||
|
||||
def apply_tax_rules(self, receipt: Receipt, transaction: Transaction = None) -> Dict[str, Any]:
|
||||
|
||||
def apply_tax_rules(
|
||||
self, receipt: Receipt, transaction: Transaction = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply all tax rules to a receipt/transaction pair"""
|
||||
return self.tax_rules_engine.apply_all_tax_rules(receipt, transaction)
|
||||
return self.tax_rules_engine.apply_all_tax_rules(receipt, transaction)
|
||||
|
||||
+25
-7
@@ -1,13 +1,16 @@
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AddressRequest(BaseModel):
|
||||
province: str
|
||||
city: str
|
||||
postal_code: str
|
||||
country: str = "Canada"
|
||||
|
||||
|
||||
class ReceiptRequest(BaseModel):
|
||||
id: str
|
||||
file_name: str
|
||||
@@ -24,6 +27,7 @@ class ReceiptRequest(BaseModel):
|
||||
currency: str = "CAD"
|
||||
is_meals_entertainment: bool = False
|
||||
|
||||
|
||||
class TransactionRequest(BaseModel):
|
||||
id: str
|
||||
transaction_date: datetime
|
||||
@@ -34,6 +38,7 @@ class TransactionRequest(BaseModel):
|
||||
currency: str = "CAD"
|
||||
fx_rate: Optional[float] = None
|
||||
|
||||
|
||||
class AssetRequest(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
@@ -44,42 +49,51 @@ class AssetRequest(BaseModel):
|
||||
cca_rate: float
|
||||
asset_class: str
|
||||
|
||||
|
||||
class MatchingRequest(BaseModel):
|
||||
receipt_ids: List[str]
|
||||
transaction_ids: List[str]
|
||||
|
||||
|
||||
class MatchResponse(BaseModel):
|
||||
receipt_id: str
|
||||
transaction_id: str
|
||||
confidence_score: float
|
||||
match_reason: str
|
||||
tax_analysis: Optional[dict] = None
|
||||
# Currency information
|
||||
receipt_currency: str = "CAD"
|
||||
transaction_currency: str = "CAD"
|
||||
currency_match: bool = True
|
||||
receipt_vendor: str
|
||||
receipt_amount: float
|
||||
receipt_description: str
|
||||
receipt_category: str
|
||||
receipt_tax_amount: float
|
||||
transaction_vendor: str
|
||||
transaction_amount: float
|
||||
|
||||
|
||||
class MatchingResponse(BaseModel):
|
||||
matches: List[MatchResponse]
|
||||
stats: dict
|
||||
|
||||
|
||||
class ApprovalRequest(BaseModel):
|
||||
match_id: str
|
||||
approved: bool
|
||||
reason: Optional[str] = None
|
||||
|
||||
|
||||
class RuleRequest(BaseModel):
|
||||
name: str
|
||||
condition: str
|
||||
action: str
|
||||
source: str = "user"
|
||||
|
||||
|
||||
class DocumentUploadResponse(BaseModel):
|
||||
file_id: str
|
||||
filename: str
|
||||
upload_date: datetime
|
||||
status: str
|
||||
|
||||
|
||||
class DocumentProcessResponse(BaseModel):
|
||||
file_id: str
|
||||
extraction_success: bool
|
||||
@@ -92,11 +106,13 @@ class DocumentProcessResponse(BaseModel):
|
||||
confidence: Optional[float] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
# New tax-related models
|
||||
class TaxCalculationRequest(BaseModel):
|
||||
receipt_id: str
|
||||
transaction_id: Optional[str] = None
|
||||
|
||||
|
||||
class TaxCalculationResponse(BaseModel):
|
||||
receipt_id: str
|
||||
rules_applied: List[str]
|
||||
@@ -104,11 +120,13 @@ class TaxCalculationResponse(BaseModel):
|
||||
fx_analysis: Optional[dict] = None
|
||||
meals_entertainment: dict
|
||||
|
||||
|
||||
class DepreciationRequest(BaseModel):
|
||||
asset: AssetRequest
|
||||
year: int
|
||||
method: str # "straight_line" or "cca"
|
||||
|
||||
|
||||
class DepreciationResponse(BaseModel):
|
||||
asset_id: str
|
||||
year: int
|
||||
@@ -117,4 +135,4 @@ class DepreciationResponse(BaseModel):
|
||||
book_value: float
|
||||
total_depreciation: Optional[float] = None
|
||||
success: bool
|
||||
error: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
+39
-23
@@ -1,8 +1,9 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeedbackLog:
|
||||
@@ -13,48 +14,63 @@ class FeedbackLog:
|
||||
timestamp: datetime
|
||||
user_id: str
|
||||
|
||||
|
||||
class FeedbackLogger:
|
||||
def __init__(self, log_file: str = "feedback_logs.json"):
|
||||
self.log_file = log_file
|
||||
self.logs: List[FeedbackLog] = self._load_logs()
|
||||
|
||||
|
||||
def _load_logs(self) -> List[FeedbackLog]:
|
||||
if not os.path.exists(self.log_file):
|
||||
return []
|
||||
|
||||
|
||||
try:
|
||||
with open(self.log_file, 'r') as f:
|
||||
with open(self.log_file, "r") as f:
|
||||
data = json.load(f)
|
||||
return [FeedbackLog(**log) for log in data]
|
||||
except:
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def _save_logs(self):
|
||||
with open(self.log_file, 'w') as f:
|
||||
json.dump([{
|
||||
'transaction_id': log.transaction_id,
|
||||
'original_match': log.original_match,
|
||||
'correction': log.correction,
|
||||
'reason': log.reason,
|
||||
'timestamp': log.timestamp.isoformat(),
|
||||
'user_id': log.user_id
|
||||
} for log in self.logs], f, indent=2)
|
||||
|
||||
def log_override(self, transaction_id: str, original_match: str, correction: str, reason: str, user_id: str):
|
||||
with open(self.log_file, "w") as f:
|
||||
json.dump(
|
||||
[
|
||||
{
|
||||
"transaction_id": log.transaction_id,
|
||||
"original_match": log.original_match,
|
||||
"correction": log.correction,
|
||||
"reason": log.reason,
|
||||
"timestamp": log.timestamp.isoformat(),
|
||||
"user_id": log.user_id,
|
||||
}
|
||||
for log in self.logs
|
||||
],
|
||||
f,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
def log_override(
|
||||
self,
|
||||
transaction_id: str,
|
||||
original_match: str,
|
||||
correction: str,
|
||||
reason: str,
|
||||
user_id: str,
|
||||
):
|
||||
log = FeedbackLog(
|
||||
transaction_id=transaction_id,
|
||||
original_match=original_match,
|
||||
correction=correction,
|
||||
reason=reason,
|
||||
timestamp=datetime.now(),
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
)
|
||||
self.logs.append(log)
|
||||
self._save_logs()
|
||||
|
||||
|
||||
def get_logs_by_transaction(self, transaction_id: str) -> List[FeedbackLog]:
|
||||
return [log for log in self.logs if log.transaction_id == transaction_id]
|
||||
|
||||
|
||||
def get_recent_logs(self, days: int = 30) -> List[FeedbackLog]:
|
||||
cutoff = datetime.now() - timedelta(days=days)
|
||||
return [log for log in self.logs if log.timestamp > cutoff]
|
||||
return [log for log in self.logs if log.timestamp > cutoff]
|
||||
|
||||
+81
-62
@@ -1,13 +1,13 @@
|
||||
import os
|
||||
import io
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
class GoogleDriveSync:
|
||||
def __init__(self):
|
||||
self.service = None
|
||||
self.processed_files = set()
|
||||
|
||||
|
||||
def authenticate(self):
|
||||
"""Authenticate with Google Drive API"""
|
||||
try:
|
||||
@@ -15,111 +15,130 @@ class GoogleDriveSync:
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
SCOPES = ['https://www.googleapis.com/auth/drive.readonly']
|
||||
|
||||
|
||||
SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
|
||||
|
||||
# Load existing credentials
|
||||
if os.path.exists('token.json'):
|
||||
self.creds = Credentials.from_authorized_user_file('token.json', SCOPES)
|
||||
|
||||
if os.path.exists("token.json"):
|
||||
self.creds = Credentials.from_authorized_user_file("token.json", SCOPES)
|
||||
|
||||
# If no valid credentials available, let user log in
|
||||
if not self.creds or not self.creds.valid:
|
||||
if self.creds and self.creds.expired and self.creds.refresh_token:
|
||||
self.creds.refresh(Request())
|
||||
else:
|
||||
if not os.path.exists('credentials.json'):
|
||||
raise Exception("credentials.json not found. Please download from Google Cloud Console.")
|
||||
|
||||
flow = InstalledAppFlow.from_client_secrets_file('credentials.json', SCOPES)
|
||||
if not os.path.exists("credentials.json"):
|
||||
raise Exception(
|
||||
"credentials.json not found. Please download from Google Cloud Console."
|
||||
)
|
||||
|
||||
flow = InstalledAppFlow.from_client_secrets_file(
|
||||
"credentials.json", SCOPES
|
||||
)
|
||||
self.creds = flow.run_local_server(port=0)
|
||||
|
||||
|
||||
# Save credentials for next run
|
||||
with open('token.json', 'w') as token:
|
||||
with open("token.json", "w") as token:
|
||||
token.write(self.creds.to_json())
|
||||
|
||||
|
||||
# Build the Drive service
|
||||
self.service = build('drive', 'v3', credentials=self.creds)
|
||||
self.service = build("drive", "v3", credentials=self.creds)
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Authentication error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def list_folders(self) -> List[Dict[str, Any]]:
|
||||
"""List all folders in Google Drive"""
|
||||
if not self.service:
|
||||
if not self.authenticate():
|
||||
return []
|
||||
|
||||
|
||||
try:
|
||||
results = self.service.files().list(
|
||||
q="mimeType='application/vnd.google-apps.folder'",
|
||||
pageSize=100,
|
||||
fields="nextPageToken, files(id, name, createdTime, modifiedTime)"
|
||||
).execute()
|
||||
|
||||
return results.get('files', [])
|
||||
|
||||
results = (
|
||||
self.service.files()
|
||||
.list(
|
||||
q="mimeType='application/vnd.google-apps.folder'",
|
||||
pageSize=100,
|
||||
fields="nextPageToken, files(id, name, createdTime, modifiedTime)",
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
return results.get("files", [])
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error listing folders: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_folder_info(self, folder_id: str) -> Dict[str, Any]:
|
||||
"""Get information about a Google Drive folder"""
|
||||
if not self.service:
|
||||
if not self.authenticate():
|
||||
return {}
|
||||
|
||||
|
||||
try:
|
||||
folder = self.service.files().get(
|
||||
fileId=folder_id,
|
||||
fields="id, name, createdTime, modifiedTime"
|
||||
).execute()
|
||||
|
||||
folder = (
|
||||
self.service.files()
|
||||
.get(fileId=folder_id, fields="id, name, createdTime, modifiedTime")
|
||||
.execute()
|
||||
)
|
||||
|
||||
return folder
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting folder info: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
async def process_drive_files(self, folder_id: str = None) -> List[Dict[str, Any]]:
|
||||
"""Process all receipt files from Google Drive"""
|
||||
if not self.service:
|
||||
if not self.authenticate():
|
||||
return []
|
||||
|
||||
|
||||
results = []
|
||||
|
||||
|
||||
try:
|
||||
# File types to look for
|
||||
file_types = ["'application/pdf'", "'image/jpeg'", "'image/png'", "'image/gif'", "'image/bmp'"]
|
||||
file_types = [
|
||||
"'application/pdf'",
|
||||
"'image/jpeg'",
|
||||
"'image/png'",
|
||||
"'image/gif'",
|
||||
"'image/bmp'",
|
||||
]
|
||||
mime_types = " or ".join(file_types)
|
||||
|
||||
|
||||
# Build query
|
||||
query = f"mimeType contains {mime_types}"
|
||||
if folder_id:
|
||||
query += f" and '{folder_id}' in parents"
|
||||
|
||||
|
||||
# Add date filter (last 30 days)
|
||||
thirty_days_ago = (datetime.now() - timedelta(days=30)).isoformat() + 'Z'
|
||||
thirty_days_ago = (datetime.now() - timedelta(days=30)).isoformat() + "Z"
|
||||
query += f" and modifiedTime > '{thirty_days_ago}'"
|
||||
|
||||
results_files = self.service.files().list(
|
||||
q=query,
|
||||
pageSize=100,
|
||||
fields="nextPageToken, files(id, name, mimeType, modifiedTime, size)"
|
||||
).execute()
|
||||
|
||||
files = results_files.get('files', [])
|
||||
files = [file for file in files if file['id'] not in self.processed_files]
|
||||
|
||||
|
||||
results_files = (
|
||||
self.service.files()
|
||||
.list(
|
||||
q=query,
|
||||
pageSize=100,
|
||||
fields="nextPageToken, files(id, name, mimeType, modifiedTime, size)",
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
files = results_files.get("files", [])
|
||||
files = [file for file in files if file["id"] not in self.processed_files]
|
||||
|
||||
# For demo purposes, return mock results
|
||||
for file in files[:3]: # Process first 3 files
|
||||
mock_result = {
|
||||
"file_id": file['id'],
|
||||
"filename": file['name'],
|
||||
"drive_modified": file['modifiedTime'],
|
||||
"file_size": file.get('size', 0),
|
||||
"file_id": file["id"],
|
||||
"filename": file["name"],
|
||||
"drive_modified": file["modifiedTime"],
|
||||
"file_size": file.get("size", 0),
|
||||
"extraction_success": True,
|
||||
"vendor": "Demo Vendor",
|
||||
"description": "Coffee and sandwich",
|
||||
@@ -127,12 +146,12 @@ class GoogleDriveSync:
|
||||
"tax_amount": 2.04,
|
||||
"date": "2024-01-15",
|
||||
"category": "Food",
|
||||
"confidence": 0.95
|
||||
"confidence": 0.95,
|
||||
}
|
||||
results.append(mock_result)
|
||||
self.processed_files.add(file['id'])
|
||||
|
||||
self.processed_files.add(file["id"])
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing Drive files: {e}")
|
||||
|
||||
return results
|
||||
|
||||
return results
|
||||
|
||||
@@ -8,7 +8,6 @@ from typing import List
|
||||
from fastapi import FastAPI, File, HTTPException, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
# Configure logging
|
||||
from ai_rules import AIRule
|
||||
from api_models import (
|
||||
DocumentProcessResponse,
|
||||
@@ -71,7 +70,7 @@ async def root():
|
||||
|
||||
|
||||
@app.post("/transactions/import/csv")
|
||||
async def import_transactions_csv(file: UploadFile = File(...)):
|
||||
async def import_transactions_csv(file: UploadFile = File(...), user_id: str = "", categorization_id: str = ""):
|
||||
"""
|
||||
Import transactions from a CSV file (custom bank export format).
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user