Refactor code for improved readability and maintainability across multiple files
This commit is contained in:
+78
-29
@@ -1,9 +1,10 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Any, List
|
from typing import Any, Dict, List
|
||||||
import config
|
|
||||||
from models import Receipt, Transaction
|
from models import Receipt, Transaction
|
||||||
from tax_rules_engine import TaxRulesEngine
|
from tax_rules_engine import TaxRulesEngine
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AIRule:
|
class AIRule:
|
||||||
name: str
|
name: str
|
||||||
@@ -12,48 +13,88 @@ class AIRule:
|
|||||||
source: str
|
source: str
|
||||||
status: str = "active"
|
status: str = "active"
|
||||||
|
|
||||||
|
|
||||||
class AIRulesEngine:
|
class AIRulesEngine:
|
||||||
def __init__(self):
    """Create an engine pre-populated with the built-in rules."""
    self.rules: List[AIRule] = []
    # Engine that the tax-related actions and apply_tax_rules() delegate to.
    self.tax_rules_engine = TaxRulesEngine()
    # Replaces the empty list above with the default rule set.
    self._load_default_rules()
|
||||||
|
|
||||||
def _load_default_rules(self):
    """Reset ``self.rules`` to the built-in matching rules followed by the tax rules."""
    # Every row is handed to AIRule positionally, in the order listed here.
    matching_specs = (
        ("exact_amount_match", "amount_diff <= 0.01", "auto_approve", "system"),
        (
            "same_vendor_same_date",
            "vendor_match and date_diff <= 1",
            "high_confidence",
            "system",
        ),
        (
            "gas_station_pattern",
            "vendor_contains_gas_or_fuel",
            "categorize_transport",
            "system",
        ),
    )
    # Tax-related rules
    tax_specs = (
        ("fx_currency_mismatch", "currency_mismatch", "flag_fx_review", "tax_system"),
        (
            "meals_entertainment",
            "is_meals_entertainment",
            "apply_me_tax_rule",
            "tax_system",
        ),
        (
            "provincial_tax_calculation",
            "has_address_info",
            "calculate_provincial_tax",
            "tax_system",
        ),
    )
    self.rules = [AIRule(*spec) for spec in matching_specs + tax_specs]
|
||||||
|
|
||||||
def apply_rules(self, receipt: Receipt, transaction: Transaction) -> Dict[str, Any]:
    """Run every active rule against a receipt/transaction pair.

    Rules are visited in list order. A rule whose condition evaluates true
    has its action executed against the shared ``results`` dict, which is
    returned once all rules have been visited.
    """
    results = {
        "auto_approve": False,
        "confidence_boost": 0,
        "category": None,
        "tax_analysis": {},
    }

    # Lazy filter: a rule's status is read only when the loop reaches it.
    active_rules = (rule for rule in self.rules if rule.status == "active")
    for rule in active_rules:
        if not self._evaluate_condition(rule.condition, receipt, transaction):
            continue
        self._execute_action(rule.action, results, receipt, transaction)

    return results
|
||||||
|
|
||||||
def _evaluate_condition(self, condition: str, receipt: Receipt, transaction: Transaction) -> bool:
|
def _evaluate_condition(
|
||||||
|
self, condition: str, receipt: Receipt, transaction: Transaction
|
||||||
|
) -> bool:
|
||||||
"""Safely evaluate rule conditions without using eval()"""
|
"""Safely evaluate rule conditions without using eval()"""
|
||||||
amount_diff = abs(receipt.amount - abs(transaction.amount))
|
amount_diff = abs(receipt.amount - abs(transaction.amount))
|
||||||
date_diff = abs((receipt.receipt_date - transaction.transaction_date).days)
|
date_diff = abs((receipt.receipt_date - transaction.transaction_date).days)
|
||||||
vendor_match = receipt.vendor.lower() in transaction.vendor.lower() or transaction.vendor.lower() in receipt.vendor.lower()
|
vendor_match = (
|
||||||
|
receipt.vendor.lower() in transaction.vendor.lower()
|
||||||
|
or transaction.vendor.lower() in receipt.vendor.lower()
|
||||||
|
)
|
||||||
vendor_lower = receipt.vendor.lower()
|
vendor_lower = receipt.vendor.lower()
|
||||||
vendor_contains_gas_or_fuel = 'gas' in vendor_lower or 'fuel' in vendor_lower
|
vendor_contains_gas_or_fuel = "gas" in vendor_lower or "fuel" in vendor_lower
|
||||||
|
|
||||||
# Tax-related conditions
|
# Tax-related conditions
|
||||||
currency_mismatch = receipt.currency != transaction.currency
|
currency_mismatch = receipt.currency != transaction.currency
|
||||||
is_meals_entertainment = receipt.is_meals_entertainment
|
is_meals_entertainment = receipt.is_meals_entertainment
|
||||||
has_address_info = receipt.billing_address is not None or receipt.shipping_address is not None
|
has_address_info = (
|
||||||
|
receipt.billing_address is not None or receipt.shipping_address is not None
|
||||||
|
)
|
||||||
|
|
||||||
# Handle specific condition types safely
|
# Handle specific condition types safely
|
||||||
if condition == "amount_diff <= 0.01":
|
if condition == "amount_diff <= 0.01":
|
||||||
return amount_diff <= 0.01
|
return amount_diff <= 0.01
|
||||||
@@ -86,14 +127,20 @@ class AIRulesEngine:
|
|||||||
"min": min,
|
"min": min,
|
||||||
"max": max,
|
"max": max,
|
||||||
"sum": sum,
|
"sum": sum,
|
||||||
"round": round
|
"round": round,
|
||||||
}
|
}
|
||||||
return eval(condition, safe_globals, {})
|
return eval(condition, safe_globals, {})
|
||||||
except (SyntaxError, NameError, TypeError) as e:
|
except (SyntaxError, NameError, TypeError) as e:
|
||||||
print(f"Warning: Invalid condition '{condition}': {e}")
|
print(f"Warning: Invalid condition '{condition}': {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _execute_action(self, action: str, results: Dict[str, Any], receipt: Receipt, transaction: Transaction):
|
def _execute_action(
|
||||||
|
self,
|
||||||
|
action: str,
|
||||||
|
results: Dict[str, Any],
|
||||||
|
receipt: Receipt,
|
||||||
|
transaction: Transaction,
|
||||||
|
):
|
||||||
if action == "auto_approve":
|
if action == "auto_approve":
|
||||||
results["auto_approve"] = True
|
results["auto_approve"] = True
|
||||||
elif action == "high_confidence":
|
elif action == "high_confidence":
|
||||||
@@ -114,13 +161,15 @@ class AIRulesEngine:
|
|||||||
# Calculate provincial tax
|
# Calculate provincial tax
|
||||||
tax_result = self.tax_rules_engine.apply_sales_tax_rule(receipt)
|
tax_result = self.tax_rules_engine.apply_sales_tax_rule(receipt)
|
||||||
results["tax_analysis"]["sales_tax"] = tax_result
|
results["tax_analysis"]["sales_tax"] = tax_result
|
||||||
|
|
||||||
def add_rule(self, rule: AIRule):
    """Append ``rule`` to the end of the rule list.

    No check is made for an existing rule with the same name.
    """
    self.rules.append(rule)
|
||||||
|
|
||||||
def remove_rule(self, rule_name: str):
|
def remove_rule(self, rule_name: str):
|
||||||
self.rules = [r for r in self.rules if r.name != rule_name]
|
self.rules = [r for r in self.rules if r.name != rule_name]
|
||||||
|
|
||||||
def apply_tax_rules(
    self, receipt: Receipt, transaction: Transaction = None
) -> Dict[str, Any]:
    """Apply all tax rules to a receipt/transaction pair.

    Thin delegate: forwards both arguments to
    ``self.tax_rules_engine.apply_all_tax_rules`` and returns its result
    unchanged. ``transaction`` may be omitted, in which case ``None`` is
    forwarded.
    """
    # NOTE(review): how the tax engine treats a missing transaction is not
    # visible from here — confirm against TaxRulesEngine.
    return self.tax_rules_engine.apply_all_tax_rules(receipt, transaction)
|
||||||
|
|||||||
+25
-7
@@ -1,13 +1,16 @@
|
|||||||
from pydantic import BaseModel
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
# Request payload describing a postal address.
class AddressRequest(BaseModel):
    province: str
    city: str
    postal_code: str
    country: str = "Canada"  # used when the client omits the field
|
||||||
|
|
||||||
|
|
||||||
class ReceiptRequest(BaseModel):
|
class ReceiptRequest(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
file_name: str
|
file_name: str
|
||||||
@@ -24,6 +27,7 @@ class ReceiptRequest(BaseModel):
|
|||||||
currency: str = "CAD"
|
currency: str = "CAD"
|
||||||
is_meals_entertainment: bool = False
|
is_meals_entertainment: bool = False
|
||||||
|
|
||||||
|
|
||||||
class TransactionRequest(BaseModel):
|
class TransactionRequest(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
transaction_date: datetime
|
transaction_date: datetime
|
||||||
@@ -34,6 +38,7 @@ class TransactionRequest(BaseModel):
|
|||||||
currency: str = "CAD"
|
currency: str = "CAD"
|
||||||
fx_rate: Optional[float] = None
|
fx_rate: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
class AssetRequest(BaseModel):
|
class AssetRequest(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
@@ -44,42 +49,51 @@ class AssetRequest(BaseModel):
|
|||||||
cca_rate: float
|
cca_rate: float
|
||||||
asset_class: str
|
asset_class: str
|
||||||
|
|
||||||
|
|
||||||
# Request payload carrying the receipt ids and transaction ids for a matching run.
class MatchingRequest(BaseModel):
    receipt_ids: List[str]
    transaction_ids: List[str]
|
||||||
|
|
||||||
|
|
||||||
# One receipt/transaction pairing, with its score and the reason it was made.
class MatchResponse(BaseModel):
    receipt_id: str
    transaction_id: str
    confidence_score: float
    match_reason: str
    # Receipt-side details carried on the match itself; all are required.
    receipt_vendor: str
    receipt_amount: float
    receipt_description: str
    receipt_category: str
    receipt_tax_amount: float
    # Transaction-side details carried on the match itself; both are required.
    transaction_vendor: str
    transaction_amount: float
|
||||||
|
|
||||||
|
|
||||||
# Response envelope for a matching run: the individual matches plus summary stats.
class MatchingResponse(BaseModel):
    matches: List[MatchResponse]
    stats: dict  # untyped; its keys are not defined in this module
|
||||||
|
|
||||||
|
|
||||||
# Request payload recording an approve/reject decision for one match.
class ApprovalRequest(BaseModel):
    match_id: str
    approved: bool
    reason: Optional[str] = None  # optional free-text justification
|
||||||
|
|
||||||
|
|
||||||
# Request payload describing a rule by name, condition and action.
class RuleRequest(BaseModel):
    name: str
    condition: str
    action: str
    source: str = "user"  # used when the client omits the field
|
||||||
|
|
||||||
|
|
||||||
# Response describing an uploaded document: its id, name, upload time and status.
class DocumentUploadResponse(BaseModel):
    file_id: str
    filename: str
    upload_date: datetime
    status: str  # plain string; allowed values are not defined in this module
|
||||||
|
|
||||||
|
|
||||||
class DocumentProcessResponse(BaseModel):
|
class DocumentProcessResponse(BaseModel):
|
||||||
file_id: str
|
file_id: str
|
||||||
extraction_success: bool
|
extraction_success: bool
|
||||||
@@ -92,11 +106,13 @@ class DocumentProcessResponse(BaseModel):
|
|||||||
confidence: Optional[float] = None
|
confidence: Optional[float] = None
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
# New tax-related models
|
# New tax-related models
|
||||||
# Request payload for a tax calculation on one receipt.
class TaxCalculationRequest(BaseModel):
    receipt_id: str
    transaction_id: Optional[str] = None  # may be omitted
|
||||||
|
|
||||||
|
|
||||||
class TaxCalculationResponse(BaseModel):
|
class TaxCalculationResponse(BaseModel):
|
||||||
receipt_id: str
|
receipt_id: str
|
||||||
rules_applied: List[str]
|
rules_applied: List[str]
|
||||||
@@ -104,11 +120,13 @@ class TaxCalculationResponse(BaseModel):
|
|||||||
fx_analysis: Optional[dict] = None
|
fx_analysis: Optional[dict] = None
|
||||||
meals_entertainment: dict
|
meals_entertainment: dict
|
||||||
|
|
||||||
|
|
||||||
# Request payload for a depreciation calculation on one asset for one year.
class DepreciationRequest(BaseModel):
    asset: AssetRequest
    year: int
    method: str  # "straight_line" or "cca"
|
||||||
|
|
||||||
|
|
||||||
class DepreciationResponse(BaseModel):
|
class DepreciationResponse(BaseModel):
|
||||||
asset_id: str
|
asset_id: str
|
||||||
year: int
|
year: int
|
||||||
@@ -117,4 +135,4 @@ class DepreciationResponse(BaseModel):
|
|||||||
book_value: float
|
book_value: float
|
||||||
total_depreciation: Optional[float] = None
|
total_depreciation: Optional[float] = None
|
||||||
success: bool
|
success: bool
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
|||||||
+39
-23
@@ -1,8 +1,9 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import List, Optional
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FeedbackLog:
|
class FeedbackLog:
|
||||||
@@ -13,48 +14,63 @@ class FeedbackLog:
|
|||||||
timestamp: datetime
|
timestamp: datetime
|
||||||
user_id: str
|
user_id: str
|
||||||
|
|
||||||
|
|
||||||
class FeedbackLogger:
    """Persist user overrides of automatic matches to a JSON file.

    The full log is kept in memory in ``self.logs`` and the file is rewritten
    in its entirety after every ``log_override`` call.
    """

    def __init__(self, log_file: str = "feedback_logs.json"):
        """Remember ``log_file`` and load whatever entries it already holds."""
        self.log_file = log_file
        self.logs: List[FeedbackLog] = self._load_logs()

    @staticmethod
    def _parse_timestamp(value):
        """Return ``value`` as a ``datetime``.

        ``_save_logs`` stores timestamps with ``isoformat()``, so they come
        back from JSON as strings. A value that is already a ``datetime`` is
        returned unchanged.

        Raises:
            ValueError: if ``value`` is a string that is not ISO formatted.
            TypeError: if ``value`` is neither a ``datetime`` nor a string.
        """
        if isinstance(value, datetime):
            return value
        return datetime.fromisoformat(value)

    def _load_logs(self) -> "List[FeedbackLog]":
        """Read the log file and rebuild the ``FeedbackLog`` entries.

        Returns an empty list when the file does not exist or cannot be
        parsed. Loading stays best-effort, but a parse failure is reported
        rather than swallowed silently.
        """
        try:
            with open(self.log_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            # No log written yet: a normal first-run condition.
            return []
        except (OSError, ValueError) as exc:
            print(f"Warning: could not read feedback logs from '{self.log_file}': {exc}")
            return []

        try:
            # Timestamps are stored as ISO strings; convert them back so that
            # get_recent_logs() can compare them with datetime values.
            return [
                FeedbackLog(
                    **{**entry, "timestamp": self._parse_timestamp(entry["timestamp"])}
                )
                for entry in data
            ]
        except (KeyError, TypeError, ValueError) as exc:
            # NOTE(review): the next log_override() will overwrite the file,
            # discarding the entries that could not be parsed.
            print(f"Warning: malformed feedback logs in '{self.log_file}': {exc}")
            return []

    def _save_logs(self):
        """Rewrite the log file from ``self.logs`` as an indented JSON array."""
        payload = [
            {
                "transaction_id": log.transaction_id,
                "original_match": log.original_match,
                "correction": log.correction,
                "reason": log.reason,
                # datetime is not JSON serialisable; _load_logs reverses this.
                "timestamp": log.timestamp.isoformat(),
                "user_id": log.user_id,
            }
            for log in self.logs
        ]
        with open(self.log_file, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)

    def log_override(
        self,
        transaction_id: str,
        original_match: str,
        correction: str,
        reason: str,
        user_id: str,
    ):
        """Record one override, stamped with the current local time, and persist it."""
        log = FeedbackLog(
            transaction_id=transaction_id,
            original_match=original_match,
            correction=correction,
            reason=reason,
            timestamp=datetime.now(),
            user_id=user_id,
        )
        self.logs.append(log)
        self._save_logs()

    def get_logs_by_transaction(self, transaction_id: str) -> "List[FeedbackLog]":
        """Return the entries recorded for ``transaction_id``, in insertion order."""
        return [log for log in self.logs if log.transaction_id == transaction_id]

    def get_recent_logs(self, days: int = 30) -> "List[FeedbackLog]":
        """Return the entries whose timestamp is newer than ``days`` days ago."""
        cutoff = datetime.now() - timedelta(days=days)
        return [log for log in self.logs if log.timestamp > cutoff]
|
||||||
|
|||||||
+81
-62
@@ -1,13 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
import io
|
|
||||||
from typing import List, Dict, Any, Optional
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
|
||||||
class GoogleDriveSync:
|
class GoogleDriveSync:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.service = None
|
self.service = None
|
||||||
self.processed_files = set()
|
self.processed_files = set()
|
||||||
|
|
||||||
def authenticate(self):
|
def authenticate(self):
|
||||||
"""Authenticate with Google Drive API"""
|
"""Authenticate with Google Drive API"""
|
||||||
try:
|
try:
|
||||||
@@ -15,111 +15,130 @@ class GoogleDriveSync:
|
|||||||
from google.oauth2.credentials import Credentials
|
from google.oauth2.credentials import Credentials
|
||||||
from google_auth_oauthlib.flow import InstalledAppFlow
|
from google_auth_oauthlib.flow import InstalledAppFlow
|
||||||
from googleapiclient.discovery import build
|
from googleapiclient.discovery import build
|
||||||
|
|
||||||
SCOPES = ['https://www.googleapis.com/auth/drive.readonly']
|
SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
|
||||||
|
|
||||||
# Load existing credentials
|
# Load existing credentials
|
||||||
if os.path.exists('token.json'):
|
if os.path.exists("token.json"):
|
||||||
self.creds = Credentials.from_authorized_user_file('token.json', SCOPES)
|
self.creds = Credentials.from_authorized_user_file("token.json", SCOPES)
|
||||||
|
|
||||||
# If no valid credentials available, let user log in
|
# If no valid credentials available, let user log in
|
||||||
if not self.creds or not self.creds.valid:
|
if not self.creds or not self.creds.valid:
|
||||||
if self.creds and self.creds.expired and self.creds.refresh_token:
|
if self.creds and self.creds.expired and self.creds.refresh_token:
|
||||||
self.creds.refresh(Request())
|
self.creds.refresh(Request())
|
||||||
else:
|
else:
|
||||||
if not os.path.exists('credentials.json'):
|
if not os.path.exists("credentials.json"):
|
||||||
raise Exception("credentials.json not found. Please download from Google Cloud Console.")
|
raise Exception(
|
||||||
|
"credentials.json not found. Please download from Google Cloud Console."
|
||||||
flow = InstalledAppFlow.from_client_secrets_file('credentials.json', SCOPES)
|
)
|
||||||
|
|
||||||
|
flow = InstalledAppFlow.from_client_secrets_file(
|
||||||
|
"credentials.json", SCOPES
|
||||||
|
)
|
||||||
self.creds = flow.run_local_server(port=0)
|
self.creds = flow.run_local_server(port=0)
|
||||||
|
|
||||||
# Save credentials for next run
|
# Save credentials for next run
|
||||||
with open('token.json', 'w') as token:
|
with open("token.json", "w") as token:
|
||||||
token.write(self.creds.to_json())
|
token.write(self.creds.to_json())
|
||||||
|
|
||||||
# Build the Drive service
|
# Build the Drive service
|
||||||
self.service = build('drive', 'v3', credentials=self.creds)
|
self.service = build("drive", "v3", credentials=self.creds)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Authentication error: {e}")
|
print(f"Authentication error: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def list_folders(self) -> List[Dict[str, Any]]:
|
def list_folders(self) -> List[Dict[str, Any]]:
|
||||||
"""List all folders in Google Drive"""
|
"""List all folders in Google Drive"""
|
||||||
if not self.service:
|
if not self.service:
|
||||||
if not self.authenticate():
|
if not self.authenticate():
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
results = self.service.files().list(
|
results = (
|
||||||
q="mimeType='application/vnd.google-apps.folder'",
|
self.service.files()
|
||||||
pageSize=100,
|
.list(
|
||||||
fields="nextPageToken, files(id, name, createdTime, modifiedTime)"
|
q="mimeType='application/vnd.google-apps.folder'",
|
||||||
).execute()
|
pageSize=100,
|
||||||
|
fields="nextPageToken, files(id, name, createdTime, modifiedTime)",
|
||||||
return results.get('files', [])
|
)
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
return results.get("files", [])
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error listing folders: {e}")
|
print(f"Error listing folders: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_folder_info(self, folder_id: str) -> Dict[str, Any]:
|
def get_folder_info(self, folder_id: str) -> Dict[str, Any]:
|
||||||
"""Get information about a Google Drive folder"""
|
"""Get information about a Google Drive folder"""
|
||||||
if not self.service:
|
if not self.service:
|
||||||
if not self.authenticate():
|
if not self.authenticate():
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
folder = self.service.files().get(
|
folder = (
|
||||||
fileId=folder_id,
|
self.service.files()
|
||||||
fields="id, name, createdTime, modifiedTime"
|
.get(fileId=folder_id, fields="id, name, createdTime, modifiedTime")
|
||||||
).execute()
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
return folder
|
return folder
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error getting folder info: {e}")
|
print(f"Error getting folder info: {e}")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def process_drive_files(self, folder_id: str = None) -> List[Dict[str, Any]]:
|
async def process_drive_files(self, folder_id: str = None) -> List[Dict[str, Any]]:
|
||||||
"""Process all receipt files from Google Drive"""
|
"""Process all receipt files from Google Drive"""
|
||||||
if not self.service:
|
if not self.service:
|
||||||
if not self.authenticate():
|
if not self.authenticate():
|
||||||
return []
|
return []
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# File types to look for
|
# File types to look for
|
||||||
file_types = ["'application/pdf'", "'image/jpeg'", "'image/png'", "'image/gif'", "'image/bmp'"]
|
file_types = [
|
||||||
|
"'application/pdf'",
|
||||||
|
"'image/jpeg'",
|
||||||
|
"'image/png'",
|
||||||
|
"'image/gif'",
|
||||||
|
"'image/bmp'",
|
||||||
|
]
|
||||||
mime_types = " or ".join(file_types)
|
mime_types = " or ".join(file_types)
|
||||||
|
|
||||||
# Build query
|
# Build query
|
||||||
query = f"mimeType contains {mime_types}"
|
query = f"mimeType contains {mime_types}"
|
||||||
if folder_id:
|
if folder_id:
|
||||||
query += f" and '{folder_id}' in parents"
|
query += f" and '{folder_id}' in parents"
|
||||||
|
|
||||||
# Add date filter (last 30 days)
|
# Add date filter (last 30 days)
|
||||||
thirty_days_ago = (datetime.now() - timedelta(days=30)).isoformat() + 'Z'
|
thirty_days_ago = (datetime.now() - timedelta(days=30)).isoformat() + "Z"
|
||||||
query += f" and modifiedTime > '{thirty_days_ago}'"
|
query += f" and modifiedTime > '{thirty_days_ago}'"
|
||||||
|
|
||||||
results_files = self.service.files().list(
|
results_files = (
|
||||||
q=query,
|
self.service.files()
|
||||||
pageSize=100,
|
.list(
|
||||||
fields="nextPageToken, files(id, name, mimeType, modifiedTime, size)"
|
q=query,
|
||||||
).execute()
|
pageSize=100,
|
||||||
|
fields="nextPageToken, files(id, name, mimeType, modifiedTime, size)",
|
||||||
files = results_files.get('files', [])
|
)
|
||||||
files = [file for file in files if file['id'] not in self.processed_files]
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
files = results_files.get("files", [])
|
||||||
|
files = [file for file in files if file["id"] not in self.processed_files]
|
||||||
|
|
||||||
# For demo purposes, return mock results
|
# For demo purposes, return mock results
|
||||||
for file in files[:3]: # Process first 3 files
|
for file in files[:3]: # Process first 3 files
|
||||||
mock_result = {
|
mock_result = {
|
||||||
"file_id": file['id'],
|
"file_id": file["id"],
|
||||||
"filename": file['name'],
|
"filename": file["name"],
|
||||||
"drive_modified": file['modifiedTime'],
|
"drive_modified": file["modifiedTime"],
|
||||||
"file_size": file.get('size', 0),
|
"file_size": file.get("size", 0),
|
||||||
"extraction_success": True,
|
"extraction_success": True,
|
||||||
"vendor": "Demo Vendor",
|
"vendor": "Demo Vendor",
|
||||||
"description": "Coffee and sandwich",
|
"description": "Coffee and sandwich",
|
||||||
@@ -127,12 +146,12 @@ class GoogleDriveSync:
|
|||||||
"tax_amount": 2.04,
|
"tax_amount": 2.04,
|
||||||
"date": "2024-01-15",
|
"date": "2024-01-15",
|
||||||
"category": "Food",
|
"category": "Food",
|
||||||
"confidence": 0.95
|
"confidence": 0.95,
|
||||||
}
|
}
|
||||||
results.append(mock_result)
|
results.append(mock_result)
|
||||||
self.processed_files.add(file['id'])
|
self.processed_files.add(file["id"])
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error processing Drive files: {e}")
|
print(f"Error processing Drive files: {e}")
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from typing import List
|
|||||||
from fastapi import FastAPI, File, HTTPException, UploadFile
|
from fastapi import FastAPI, File, HTTPException, UploadFile
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
from ai_rules import AIRule
|
from ai_rules import AIRule
|
||||||
from api_models import (
|
from api_models import (
|
||||||
DocumentProcessResponse,
|
DocumentProcessResponse,
|
||||||
@@ -71,7 +70,7 @@ async def root():
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/transactions/import/csv")
|
@app.post("/transactions/import/csv")
|
||||||
async def import_transactions_csv(file: UploadFile = File(...)):
|
async def import_transactions_csv(file: UploadFile = File(...), user_id: str = "", categorization_id: str = ""):
|
||||||
"""
|
"""
|
||||||
Import transactions from a CSV file (custom bank export format).
|
Import transactions from a CSV file (custom bank export format).
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user