diff --git a/ai_rules.py b/ai_rules.py index e247a6d..f77f119 100644 --- a/ai_rules.py +++ b/ai_rules.py @@ -1,9 +1,10 @@ from dataclasses import dataclass -from typing import Dict, Any, List -import config +from typing import Any, Dict, List + from models import Receipt, Transaction from tax_rules_engine import TaxRulesEngine + @dataclass class AIRule: name: str @@ -12,48 +13,88 @@ class AIRule: source: str status: str = "active" + class AIRulesEngine: def __init__(self): self.rules: List[AIRule] = [] self.tax_rules_engine = TaxRulesEngine() self._load_default_rules() - + def _load_default_rules(self): self.rules = [ - AIRule("exact_amount_match", "amount_diff <= 0.01", "auto_approve", "system"), - AIRule("same_vendor_same_date", "vendor_match and date_diff <= 1", "high_confidence", "system"), - AIRule("gas_station_pattern", "vendor_contains_gas_or_fuel", "categorize_transport", "system"), + AIRule( + "exact_amount_match", "amount_diff <= 0.01", "auto_approve", "system" + ), + AIRule( + "same_vendor_same_date", + "vendor_match and date_diff <= 1", + "high_confidence", + "system", + ), + AIRule( + "gas_station_pattern", + "vendor_contains_gas_or_fuel", + "categorize_transport", + "system", + ), # Tax-related rules - AIRule("fx_currency_mismatch", "currency_mismatch", "flag_fx_review", "tax_system"), - AIRule("meals_entertainment", "is_meals_entertainment", "apply_me_tax_rule", "tax_system"), - AIRule("provincial_tax_calculation", "has_address_info", "calculate_provincial_tax", "tax_system") + AIRule( + "fx_currency_mismatch", + "currency_mismatch", + "flag_fx_review", + "tax_system", + ), + AIRule( + "meals_entertainment", + "is_meals_entertainment", + "apply_me_tax_rule", + "tax_system", + ), + AIRule( + "provincial_tax_calculation", + "has_address_info", + "calculate_provincial_tax", + "tax_system", + ), ] - + def apply_rules(self, receipt: Receipt, transaction: Transaction) -> Dict[str, Any]: - results = {"auto_approve": False, "confidence_boost": 0, "category": None, "tax_analysis": {}} - + results = { + "auto_approve": False, + "confidence_boost": 0, + "category": None, + "tax_analysis": {}, + } + for rule in self.rules: if rule.status != "active": continue - + if self._evaluate_condition(rule.condition, receipt, transaction): self._execute_action(rule.action, results, receipt, transaction) - + return results - - def _evaluate_condition(self, condition: str, receipt: Receipt, transaction: Transaction) -> bool: + + def _evaluate_condition( + self, condition: str, receipt: Receipt, transaction: Transaction + ) -> bool: """Safely evaluate rule conditions without using eval()""" amount_diff = abs(receipt.amount - abs(transaction.amount)) date_diff = abs((receipt.receipt_date - transaction.transaction_date).days) - vendor_match = receipt.vendor.lower() in transaction.vendor.lower() or transaction.vendor.lower() in receipt.vendor.lower() + vendor_match = ( + receipt.vendor.lower() in transaction.vendor.lower() + or transaction.vendor.lower() in receipt.vendor.lower() + ) vendor_lower = receipt.vendor.lower() - vendor_contains_gas_or_fuel = 'gas' in vendor_lower or 'fuel' in vendor_lower - + vendor_contains_gas_or_fuel = "gas" in vendor_lower or "fuel" in vendor_lower + # Tax-related conditions currency_mismatch = receipt.currency != transaction.currency is_meals_entertainment = receipt.is_meals_entertainment - has_address_info = receipt.billing_address is not None or receipt.shipping_address is not None - + has_address_info = ( + receipt.billing_address is not None or receipt.shipping_address is not None + ) + # Handle specific condition types safely if condition == "amount_diff <= 0.01": return amount_diff <= 0.01 @@ -86,14 +127,20 @@ class AIRulesEngine: "min": min, "max": max, "sum": sum, - "round": round + "round": round, } return eval(condition, safe_globals, {}) except (SyntaxError, NameError, TypeError) as e: print(f"Warning: Invalid condition '{condition}': {e}") return False - - def _execute_action(self, action: str, results: Dict[str, Any], receipt: Receipt, transaction: Transaction): + + def _execute_action( + self, + action: str, + results: Dict[str, Any], + receipt: Receipt, + transaction: Transaction, + ): if action == "auto_approve": results["auto_approve"] = True elif action == "high_confidence": @@ -114,13 +161,15 @@ class AIRulesEngine: # Calculate provincial tax tax_result = self.tax_rules_engine.apply_sales_tax_rule(receipt) results["tax_analysis"]["sales_tax"] = tax_result - + def add_rule(self, rule: AIRule): self.rules.append(rule) - + def remove_rule(self, rule_name: str): self.rules = [r for r in self.rules if r.name != rule_name] - - def apply_tax_rules(self, receipt: Receipt, transaction: Transaction = None) -> Dict[str, Any]: + + def apply_tax_rules( + self, receipt: Receipt, transaction: Transaction = None + ) -> Dict[str, Any]: """Apply all tax rules to a receipt/transaction pair""" - return self.tax_rules_engine.apply_all_tax_rules(receipt, transaction) \ No newline at end of file + return self.tax_rules_engine.apply_all_tax_rules(receipt, transaction) diff --git a/api_models.py b/api_models.py index a98e576..9addb5f 100644 --- a/api_models.py +++ b/api_models.py @@ -1,13 +1,16 @@ -from pydantic import BaseModel from datetime import datetime from typing import List, Optional +from pydantic import BaseModel + + class AddressRequest(BaseModel): province: str city: str postal_code: str country: str = "Canada" + class ReceiptRequest(BaseModel): id: str file_name: str @@ -24,6 +27,7 @@ class ReceiptRequest(BaseModel): currency: str = "CAD" is_meals_entertainment: bool = False + class TransactionRequest(BaseModel): id: str transaction_date: datetime @@ -34,6 +38,7 @@ class TransactionRequest(BaseModel): currency: str = "CAD" fx_rate: Optional[float] = None + class AssetRequest(BaseModel): id: str name: str @@ -44,42 +49,51 @@ class AssetRequest(BaseModel): cca_rate: float asset_class: str + class MatchingRequest(BaseModel): receipt_ids: List[str] transaction_ids: List[str] + class MatchResponse(BaseModel): receipt_id: str transaction_id: str confidence_score: float match_reason: str - tax_analysis: Optional[dict] = None - # Currency information - receipt_currency: str = "CAD" - transaction_currency: str = "CAD" - currency_match: bool = True + receipt_vendor: str + receipt_amount: float + receipt_description: str + receipt_category: str + receipt_tax_amount: float + transaction_vendor: str + transaction_amount: float + class MatchingResponse(BaseModel): matches: List[MatchResponse] stats: dict + class ApprovalRequest(BaseModel): match_id: str approved: bool reason: Optional[str] = None + class RuleRequest(BaseModel): name: str condition: str action: str source: str = "user" + class DocumentUploadResponse(BaseModel): file_id: str filename: str upload_date: datetime status: str + class DocumentProcessResponse(BaseModel): file_id: str extraction_success: bool @@ -92,11 +106,13 @@ class DocumentProcessResponse(BaseModel): confidence: Optional[float] = None error: Optional[str] = None + # New tax-related models class TaxCalculationRequest(BaseModel): receipt_id: str transaction_id: Optional[str] = None + class TaxCalculationResponse(BaseModel): receipt_id: str rules_applied: List[str] @@ -104,11 +120,13 @@ class TaxCalculationResponse(BaseModel): fx_analysis: Optional[dict] = None meals_entertainment: dict + class DepreciationRequest(BaseModel): asset: AssetRequest year: int method: str # "straight_line" or "cca" + class DepreciationResponse(BaseModel): asset_id: str year: int @@ -117,4 +135,4 @@ class DepreciationResponse(BaseModel): book_value: float total_depreciation: Optional[float] = None success: bool - error: Optional[str] = None \ No newline at end of file + error: Optional[str] = None diff --git a/feedback_logger.py b/feedback_logger.py index 3511b17..d28453b 100644 --- a/feedback_logger.py +++ b/feedback_logger.py @@ -1,8 +1,9 @@ -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import List, Optional import json import os +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import List + @dataclass class FeedbackLog: @@ -13,48 +14,63 @@ class FeedbackLog: timestamp: datetime user_id: str + class FeedbackLogger: def __init__(self, log_file: str = "feedback_logs.json"): self.log_file = log_file self.logs: List[FeedbackLog] = self._load_logs() - + def _load_logs(self) -> List[FeedbackLog]: if not os.path.exists(self.log_file): return [] - + try: - with open(self.log_file, 'r') as f: + with open(self.log_file, "r") as f: data = json.load(f) return [FeedbackLog(**log) for log in data] - except: + except Exception: return [] - + def _save_logs(self): - with open(self.log_file, 'w') as f: - json.dump([{ - 'transaction_id': log.transaction_id, - 'original_match': log.original_match, - 'correction': log.correction, - 'reason': log.reason, - 'timestamp': log.timestamp.isoformat(), - 'user_id': log.user_id - } for log in self.logs], f, indent=2) - - def log_override(self, transaction_id: str, original_match: str, correction: str, reason: str, user_id: str): + with open(self.log_file, "w") as f: + json.dump( + [ + { + "transaction_id": log.transaction_id, + "original_match": log.original_match, + "correction": log.correction, + "reason": log.reason, + "timestamp": log.timestamp.isoformat(), + "user_id": log.user_id, + } + for log in self.logs + ], + f, + indent=2, + ) + + def log_override( + self, + transaction_id: str, + original_match: str, + correction: str, + reason: str, + user_id: str, + ): log = FeedbackLog( transaction_id=transaction_id, original_match=original_match, correction=correction, reason=reason, timestamp=datetime.now(), - user_id=user_id + user_id=user_id, ) self.logs.append(log) self._save_logs() - + def get_logs_by_transaction(self, transaction_id: str) -> List[FeedbackLog]: return [log for log in self.logs if log.transaction_id == transaction_id] - + def get_recent_logs(self, days: int = 30) -> List[FeedbackLog]: cutoff = datetime.now() - timedelta(days=days) - return [log for log in self.logs if log.timestamp > cutoff] \ No newline at end of file + return [log for log in self.logs if log.timestamp > cutoff] diff --git a/google_drive_sync.py b/google_drive_sync.py index 1596060..dda7ed4 100644 --- a/google_drive_sync.py +++ b/google_drive_sync.py @@ -1,13 +1,13 @@ import os -import io -from typing import List, Dict, Any, Optional from datetime import datetime, timedelta +from typing import Any, Dict, List + class GoogleDriveSync: def __init__(self): self.service = None self.processed_files = set() - + def authenticate(self): """Authenticate with Google Drive API""" try: @@ -15,111 +15,130 @@ class GoogleDriveSync: from google.oauth2.credentials import Credentials from google_auth_oauthlib.flow import InstalledAppFlow from googleapiclient.discovery import build - - SCOPES = ['https://www.googleapis.com/auth/drive.readonly'] - + + SCOPES = ["https://www.googleapis.com/auth/drive.readonly"] + # Load existing credentials - if os.path.exists('token.json'): - self.creds = Credentials.from_authorized_user_file('token.json', SCOPES) - + if os.path.exists("token.json"): + self.creds = Credentials.from_authorized_user_file("token.json", SCOPES) + # If no valid credentials available, let user log in if not self.creds or not self.creds.valid: if self.creds and self.creds.expired and self.creds.refresh_token: self.creds.refresh(Request()) else: - if not os.path.exists('credentials.json'): - raise Exception("credentials.json not found. Please download from Google Cloud Console.") - - flow = InstalledAppFlow.from_client_secrets_file('credentials.json', SCOPES) + if not os.path.exists("credentials.json"): + raise Exception( + "credentials.json not found. Please download from Google Cloud Console." + ) + + flow = InstalledAppFlow.from_client_secrets_file( + "credentials.json", SCOPES + ) self.creds = flow.run_local_server(port=0) - + # Save credentials for next run - with open('token.json', 'w') as token: + with open("token.json", "w") as token: token.write(self.creds.to_json()) - + # Build the Drive service - self.service = build('drive', 'v3', credentials=self.creds) + self.service = build("drive", "v3", credentials=self.creds) return True - + except Exception as e: print(f"Authentication error: {e}") return False - + def list_folders(self) -> List[Dict[str, Any]]: """List all folders in Google Drive""" if not self.service: if not self.authenticate(): return [] - + try: - results = self.service.files().list( - q="mimeType='application/vnd.google-apps.folder'", - pageSize=100, - fields="nextPageToken, files(id, name, createdTime, modifiedTime)" - ).execute() - - return results.get('files', []) - + results = ( + self.service.files() + .list( + q="mimeType='application/vnd.google-apps.folder'", + pageSize=100, + fields="nextPageToken, files(id, name, createdTime, modifiedTime)", + ) + .execute() + ) + + return results.get("files", []) + except Exception as e: print(f"Error listing folders: {e}") return [] - + def get_folder_info(self, folder_id: str) -> Dict[str, Any]: """Get information about a Google Drive folder""" if not self.service: if not self.authenticate(): return {} - + try: - folder = self.service.files().get( - fileId=folder_id, - fields="id, name, createdTime, modifiedTime" - ).execute() - + folder = ( + self.service.files() + .get(fileId=folder_id, fields="id, name, createdTime, modifiedTime") + .execute() + ) + return folder - + except Exception as e: print(f"Error getting folder info: {e}") return {} - + async def process_drive_files(self, folder_id: str = None) -> List[Dict[str, Any]]: """Process all receipt files from Google Drive""" if not self.service: if not self.authenticate(): return [] - + results = [] - + try: # File types to look for - file_types = ["'application/pdf'", "'image/jpeg'", "'image/png'", "'image/gif'", "'image/bmp'"] + file_types = [ + "'application/pdf'", + "'image/jpeg'", + "'image/png'", + "'image/gif'", + "'image/bmp'", + ] mime_types = " or ".join(file_types) - + # Build query query = f"mimeType contains {mime_types}" if folder_id: query += f" and '{folder_id}' in parents" - + # Add date filter (last 30 days) - thirty_days_ago = (datetime.now() - timedelta(days=30)).isoformat() + 'Z' + thirty_days_ago = (datetime.now() - timedelta(days=30)).isoformat() + "Z" query += f" and modifiedTime > '{thirty_days_ago}'" - - results_files = self.service.files().list( - q=query, - pageSize=100, - fields="nextPageToken, files(id, name, mimeType, modifiedTime, size)" - ).execute() - - files = results_files.get('files', []) - files = [file for file in files if file['id'] not in self.processed_files] - + + results_files = ( + self.service.files() + .list( + q=query, + pageSize=100, + fields="nextPageToken, files(id, name, mimeType, modifiedTime, size)", + ) + .execute() + ) + + files = results_files.get("files", []) + files = [file for file in files if file["id"] not in self.processed_files] + # For demo purposes, return mock results for file in files[:3]: # Process first 3 files mock_result = { - "file_id": file['id'], - "filename": file['name'], - "drive_modified": file['modifiedTime'], - "file_size": file.get('size', 0), + "file_id": file["id"], + "filename": file["name"], + "drive_modified": file["modifiedTime"], + "file_size": file.get("size", 0), "extraction_success": True, "vendor": "Demo Vendor", "description": "Coffee and sandwich", @@ -127,12 +146,12 @@ class GoogleDriveSync: "tax_amount": 2.04, "date": "2024-01-15", "category": "Food", - "confidence": 0.95 + "confidence": 0.95, } results.append(mock_result) - self.processed_files.add(file['id']) - + self.processed_files.add(file["id"]) + except Exception as e: print(f"Error processing Drive files: {e}") - - return results \ No newline at end of file + + return results diff --git a/main.py b/main.py index 2022034..1f19df6 100644 --- a/main.py +++ b/main.py @@ -8,7 +8,6 @@ from typing import List from fastapi import FastAPI, File, HTTPException, UploadFile from fastapi.middleware.cors import CORSMiddleware -# Configure logging from ai_rules import AIRule from api_models import ( DocumentProcessResponse, @@ -71,7 +70,7 @@ async def root(): @app.post("/transactions/import/csv") -async def import_transactions_csv(file: UploadFile = File(...)): +async def import_transactions_csv(file: UploadFile = File(...), user_id: str = "", categorization_id: str = ""): """ Import transactions from a CSV file (custom bank export format). """