Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1784d2e406 | |||
| 55ffc52339 | |||
| 9698e2fcaf | |||
| 1f530da7c4 |
+304
-79
@@ -1,115 +1,322 @@
|
|||||||
import groq
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import List, Tuple
|
|
||||||
import config
|
|
||||||
from models import Receipt, Transaction, Match
|
|
||||||
import time
|
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import time
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import groq
|
||||||
|
|
||||||
|
import config
|
||||||
|
from models import Match, Receipt, Transaction
|
||||||
|
|
||||||
# Set up logging
|
# Set up logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AIMatcher:
|
class AIMatcher:
|
||||||
def __init__(self):
|
def __init__(self, use_batch_matching=True):
|
||||||
self.client = groq.Groq(api_key=config.GROQ_API_KEY)
|
self.client = groq.Groq(api_key=config.GROQ_API_KEY)
|
||||||
self.model = "llama3-8b-8192"
|
self.model = "llama3-8b-8192"
|
||||||
self.max_retries = 3
|
self.max_retries = 3
|
||||||
self.retry_delay = 2 # seconds - increased for rate limiting
|
self.retry_delay = 2 # seconds - increased for rate limiting
|
||||||
self.rate_limit_delay = 1.0 # seconds between API calls
|
self.rate_limit_delay = 1.0 # seconds between API calls
|
||||||
self.last_api_call = 0
|
self.last_api_call = 0
|
||||||
|
self.use_batch_matching = (
|
||||||
def match_receipts_to_transactions(self, receipts: List[Receipt], transactions: List[Transaction]) -> List[Match]:
|
use_batch_matching # Toggle between new and legacy methods
|
||||||
|
)
|
||||||
|
|
||||||
|
def match_receipts_to_transactions(
|
||||||
|
self, receipts: List[Receipt], transactions: List[Transaction]
|
||||||
|
) -> List[Match]:
|
||||||
"""Match receipts to transactions using AI"""
|
"""Match receipts to transactions using AI"""
|
||||||
logger.info(f"Starting AI matching for {len(receipts)} receipts against {len(transactions)} transactions")
|
logger.info(
|
||||||
|
f"Starting AI matching for {len(receipts)} receipts against {len(transactions)} transactions"
|
||||||
|
)
|
||||||
matches = []
|
matches = []
|
||||||
|
|
||||||
for i, receipt in enumerate(receipts):
|
for i, receipt in enumerate(receipts):
|
||||||
logger.info(f"Processing receipt {i+1}/{len(receipts)}: {receipt.vendor} - ${receipt.amount}")
|
logger.info(
|
||||||
|
f"Processing receipt {i + 1}/{len(receipts)}: {receipt.vendor} - ${receipt.amount}"
|
||||||
|
)
|
||||||
|
|
||||||
# Rate limiting
|
# Rate limiting
|
||||||
self._rate_limit()
|
self._rate_limit()
|
||||||
|
|
||||||
# Get the BEST match for this receipt (highest confidence score)
|
# Get the BEST match for this receipt (highest confidence score)
|
||||||
best_match = self._find_best_match(receipt, transactions)
|
best_match = self._find_best_match(receipt, transactions)
|
||||||
if best_match:
|
if best_match:
|
||||||
matches.append(best_match)
|
matches.append(best_match)
|
||||||
logger.info(f"Found match: {best_match.confidence_score:.3f} - {best_match.match_reason}")
|
logger.info(
|
||||||
|
f"Found match: {best_match.confidence_score:.3f} - {best_match.match_reason}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"No match found for receipt: {receipt.vendor} - ${receipt.amount}")
|
logger.warning(
|
||||||
|
f"No match found for receipt: {receipt.vendor} - ${receipt.amount}"
|
||||||
|
)
|
||||||
|
|
||||||
# Sort by confidence score (highest first)
|
# Sort by confidence score (highest first)
|
||||||
matches = sorted(matches, key=lambda x: x.confidence_score, reverse=True)
|
matches = sorted(matches, key=lambda x: x.confidence_score, reverse=True)
|
||||||
logger.info(f"AI matching completed. Found {len(matches)} matches")
|
logger.info(f"AI matching completed. Found {len(matches)} matches")
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
def _rate_limit(self):
|
def _rate_limit(self):
|
||||||
"""Implement rate limiting to avoid API quota exhaustion"""
|
"""Implement rate limiting to avoid API quota exhaustion"""
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
time_since_last_call = current_time - self.last_api_call
|
time_since_last_call = current_time - self.last_api_call
|
||||||
|
|
||||||
if time_since_last_call < self.rate_limit_delay:
|
if time_since_last_call < self.rate_limit_delay:
|
||||||
sleep_time = self.rate_limit_delay - time_since_last_call
|
sleep_time = self.rate_limit_delay - time_since_last_call
|
||||||
logger.debug(f"Rate limiting: sleeping for {sleep_time:.2f} seconds")
|
logger.debug(f"Rate limiting: sleeping for {sleep_time:.2f} seconds")
|
||||||
time.sleep(sleep_time)
|
time.sleep(sleep_time)
|
||||||
|
|
||||||
self.last_api_call = time.time()
|
self.last_api_call = time.time()
|
||||||
|
|
||||||
def _find_best_match(self, receipt: Receipt, transactions: List[Transaction]) -> Match:
|
def _find_best_match(
|
||||||
"""Find the BEST match for a receipt (highest confidence score)"""
|
self, receipt: Receipt, transactions: List[Transaction]
|
||||||
|
) -> Match:
|
||||||
|
"""Find the BEST match for a receipt using a single AI call for all candidates"""
|
||||||
candidates = self._filter_candidates(receipt, transactions)
|
candidates = self._filter_candidates(receipt, transactions)
|
||||||
if not candidates:
|
if not candidates:
|
||||||
logger.warning(f"No candidates found for receipt: {receipt.vendor} - ${receipt.amount}")
|
logger.warning(
|
||||||
|
f"No candidates found for receipt: {receipt.vendor} - ${receipt.amount}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
logger.info(f"Found {len(candidates)} candidates for receipt: {receipt.vendor}")
|
logger.info(f"Found {len(candidates)} candidates for receipt: {receipt.vendor}")
|
||||||
|
|
||||||
best_match = None
|
# Choose matching method based on configuration
|
||||||
highest_score = 0
|
if self.use_batch_matching:
|
||||||
|
# New efficient method: single AI call for all candidates
|
||||||
for transaction in candidates:
|
best_match = self._find_best_match_single_call(receipt, candidates)
|
||||||
score, reason = self._calculate_match_score(receipt, transaction)
|
else:
|
||||||
logger.debug(f"Score {score:.3f} for transaction {transaction.vendor}: {reason}")
|
# Legacy method: individual AI calls (fallback)
|
||||||
|
best_match = self._find_best_match_legacy(receipt, candidates)
|
||||||
# Keep the match with the highest score, regardless of how low it is
|
|
||||||
if score > highest_score:
|
|
||||||
highest_score = score
|
|
||||||
best_match = Match(receipt, transaction, score, reason)
|
|
||||||
|
|
||||||
return best_match
|
return best_match
|
||||||
|
|
||||||
def _filter_candidates(self, receipt: Receipt, transactions: List[Transaction]) -> List[Transaction]:
|
def _find_best_match_single_call(
|
||||||
|
self, receipt: Receipt, candidates: List[Transaction]
|
||||||
|
) -> Match:
|
||||||
|
"""Find the best match using a single AI call to evaluate all candidates"""
|
||||||
|
if not candidates:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Limit candidates to avoid token limits (adjust based on your needs)
|
||||||
|
max_candidates = 10
|
||||||
|
if len(candidates) > max_candidates:
|
||||||
|
# Sort by amount similarity and take top candidates
|
||||||
|
candidates = sorted(
|
||||||
|
candidates, key=lambda t: abs(receipt.amount - abs(t.amount))
|
||||||
|
)[:max_candidates]
|
||||||
|
logger.info(
|
||||||
|
f"Limited candidates to top {max_candidates} by amount similarity"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build comprehensive prompt with all candidates
|
||||||
|
candidates_text = ""
|
||||||
|
for i, transaction in enumerate(candidates):
|
||||||
|
transaction_amount_abs = abs(transaction.amount)
|
||||||
|
date_diff = abs((receipt.receipt_date - transaction.transaction_date).days)
|
||||||
|
amount_diff = abs(receipt.amount - transaction_amount_abs)
|
||||||
|
amount_percent_diff = (
|
||||||
|
(amount_diff / receipt.amount) * 100 if receipt.amount > 0 else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
candidates_text += f"""
|
||||||
|
Candidate {i + 1}:
|
||||||
|
- Vendor: {transaction.vendor}
|
||||||
|
- Amount: ${transaction.amount} (absolute: ${transaction_amount_abs})
|
||||||
|
- Date: {transaction.transaction_date.strftime("%Y-%m-%d")} ({date_diff} days difference)
|
||||||
|
- Notes: {transaction.notes}
|
||||||
|
- Amount difference: ${amount_diff} ({amount_percent_diff:.1f}%)
|
||||||
|
"""
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
You are an expert at matching receipts to bank transactions. Analyze the receipt below against ALL the candidate transactions and return the BEST match.
|
||||||
|
|
||||||
|
RECEIPT TO MATCH:
|
||||||
|
- Vendor: {receipt.vendor}
|
||||||
|
- Amount: ${receipt.amount}
|
||||||
|
- Date: {receipt.receipt_date.strftime("%Y-%m-%d")}
|
||||||
|
- Description: {receipt.description}
|
||||||
|
- Category: {receipt.category}
|
||||||
|
|
||||||
|
CANDIDATE TRANSACTIONS:
|
||||||
|
{candidates_text}
|
||||||
|
|
||||||
|
SCORING CRITERIA:
|
||||||
|
- Perfect matches (same vendor, amount, date): 0.95-1.0
|
||||||
|
- High confidence (minor differences): 0.8-0.94
|
||||||
|
- Medium confidence (moderate differences): 0.6-0.79
|
||||||
|
- Low confidence (significant differences): 0.4-0.59
|
||||||
|
- Very low confidence (major differences): 0.2-0.39
|
||||||
|
- Minimal similarity: 0.1-0.19
|
||||||
|
- No meaningful similarity: 0.0-0.09
|
||||||
|
|
||||||
|
Consider vendor name similarity, amount accuracy, date proximity, and description/notes relevance.
|
||||||
|
|
||||||
|
IMPORTANT: You MUST return the candidate with the highest match score, even if it's very low. Never return NONE.
|
||||||
|
Return ONLY the best match in this exact format:
|
||||||
|
CANDIDATE_NUMBER|CONFIDENCE_SCORE|REASON
|
||||||
|
|
||||||
|
Example: 3|0.87|Same vendor name, exact amount match, 1 day apart
|
||||||
|
Example of low match: 5|0.15|Best available option despite significant differences in vendor and amount
|
||||||
|
"""
|
||||||
|
|
||||||
|
for attempt in range(self.max_retries):
|
||||||
|
try:
|
||||||
|
result = self._call_groq_api_with_timeout(
|
||||||
|
prompt, timeout=45
|
||||||
|
) # Longer timeout for complex prompt
|
||||||
|
|
||||||
|
# Parse the single result
|
||||||
|
candidate_num, score, reason = self._parse_single_match_response(result)
|
||||||
|
|
||||||
|
if candidate_num == -1: # Parsing error occurred
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to parse AI response for receipt: {receipt.vendor}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if 0 <= candidate_num < len(candidates):
|
||||||
|
best_transaction = candidates[candidate_num]
|
||||||
|
logger.info(
|
||||||
|
f"AI selected candidate {candidate_num + 1}: {best_transaction.vendor} (score: {score:.3f})"
|
||||||
|
)
|
||||||
|
return Match(receipt, best_transaction, score, reason)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"AI returned invalid candidate number: {candidate_num}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Attempt {attempt + 1} failed for receipt {receipt.id}: {str(e)}"
|
||||||
|
)
|
||||||
|
if attempt < self.max_retries - 1:
|
||||||
|
sleep_time = self.retry_delay * (2**attempt)
|
||||||
|
logger.info(f"Waiting {sleep_time} seconds before retry...")
|
||||||
|
time.sleep(sleep_time)
|
||||||
|
else:
|
||||||
|
logger.error(f"All attempts failed for receipt {receipt.id}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _parse_single_match_response(self, result: str) -> Tuple[int, float, str]:
|
||||||
|
"""Parse AI response for single best match"""
|
||||||
|
result = result.strip()
|
||||||
|
logger.debug(f"Parsing single match response: {result}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if result.upper().startswith("NONE"):
|
||||||
|
# This should not happen with new prompt, but handle as parsing error
|
||||||
|
logger.warning(
|
||||||
|
"AI returned NONE despite being instructed to always return best match"
|
||||||
|
)
|
||||||
|
return -1, 0.0, "AI returned NONE unexpectedly"
|
||||||
|
|
||||||
|
if "|" in result:
|
||||||
|
parts = result.split("|")
|
||||||
|
if len(parts) >= 3:
|
||||||
|
candidate_str = parts[0].strip()
|
||||||
|
score_str = parts[1].strip()
|
||||||
|
reason = "|".join(parts[2:]).strip()
|
||||||
|
|
||||||
|
# Extract candidate number
|
||||||
|
import re
|
||||||
|
|
||||||
|
candidate_match = re.search(r"\d+", candidate_str)
|
||||||
|
if candidate_match:
|
||||||
|
candidate_num = (
|
||||||
|
int(candidate_match.group()) - 1
|
||||||
|
) # Convert to 0-based index
|
||||||
|
else:
|
||||||
|
raise ValueError("No candidate number found")
|
||||||
|
|
||||||
|
# Extract score
|
||||||
|
score_clean = "".join(
|
||||||
|
c for c in score_str if c.isdigit() or c == "."
|
||||||
|
)
|
||||||
|
score = float(score_clean) if score_clean else 0.0
|
||||||
|
|
||||||
|
# Ensure score is in valid range
|
||||||
|
score = max(0.0, min(1.0, score))
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Parsed: candidate={candidate_num}, score={score}, reason={reason}"
|
||||||
|
)
|
||||||
|
return candidate_num, score, reason
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error parsing single match response: {e}")
|
||||||
|
|
||||||
|
# Fallback
|
||||||
|
logger.warning(f"Could not parse single match response: {result}")
|
||||||
|
return -1, 0.0, f"Parse error: {result[:50]}..."
|
||||||
|
|
||||||
|
def _filter_candidates(
|
||||||
|
self, receipt: Receipt, transactions: List[Transaction]
|
||||||
|
) -> List[Transaction]:
|
||||||
"""Filter transactions to create a reasonable candidate list"""
|
"""Filter transactions to create a reasonable candidate list"""
|
||||||
candidates = []
|
candidates = []
|
||||||
amount_threshold = receipt.amount * 2.0 # 200% threshold - very inclusive
|
amount_threshold = receipt.amount * 2.0 # 200% threshold - very inclusive
|
||||||
|
|
||||||
for transaction in transactions:
|
for transaction in transactions:
|
||||||
# Use absolute value for transaction amount comparison
|
# Use absolute value for transaction amount comparison
|
||||||
transaction_amount_abs = abs(transaction.amount)
|
transaction_amount_abs = abs(transaction.amount)
|
||||||
|
|
||||||
# Only exclude transactions with obviously different amounts
|
# Only exclude transactions with obviously different amounts
|
||||||
if abs(receipt.amount - transaction_amount_abs) <= amount_threshold:
|
if abs(receipt.amount - transaction_amount_abs) <= amount_threshold:
|
||||||
candidates.append(transaction)
|
candidates.append(transaction)
|
||||||
|
|
||||||
logger.debug(f"Filtered {len(transactions)} transactions to {len(candidates)} candidates")
|
logger.debug(
|
||||||
|
f"Filtered {len(transactions)} transactions to {len(candidates)} candidates"
|
||||||
|
)
|
||||||
return candidates
|
return candidates
|
||||||
|
|
||||||
def _calculate_match_score(self, receipt: Receipt, transaction: Transaction) -> Tuple[float, str]:
|
def _find_best_match_legacy(
|
||||||
|
self, receipt: Receipt, transactions: List[Transaction]
|
||||||
|
) -> Match:
|
||||||
|
"""Legacy method: Find the best match using individual API calls (kept as fallback)"""
|
||||||
|
candidates = self._filter_candidates(receipt, transactions)
|
||||||
|
if not candidates:
|
||||||
|
return None
|
||||||
|
|
||||||
|
best_match = None
|
||||||
|
highest_score = 0
|
||||||
|
|
||||||
|
for transaction in candidates:
|
||||||
|
score, reason = self._calculate_match_score(receipt, transaction)
|
||||||
|
logger.debug(
|
||||||
|
f"Score {score:.3f} for transaction {transaction.vendor}: {reason}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if score > highest_score:
|
||||||
|
highest_score = score
|
||||||
|
best_match = Match(receipt, transaction, score, reason)
|
||||||
|
|
||||||
|
return best_match
|
||||||
|
|
||||||
|
def _calculate_match_score(
|
||||||
|
self, receipt: Receipt, transaction: Transaction
|
||||||
|
) -> Tuple[float, str]:
|
||||||
"""Calculate match score using AI"""
|
"""Calculate match score using AI"""
|
||||||
# Calculate differences for the AI to consider
|
# Calculate differences for the AI to consider
|
||||||
date_diff = abs((receipt.receipt_date - transaction.transaction_date).days)
|
date_diff = abs((receipt.receipt_date - transaction.transaction_date).days)
|
||||||
transaction_amount_abs = abs(transaction.amount)
|
transaction_amount_abs = abs(transaction.amount)
|
||||||
amount_diff = abs(receipt.amount - transaction_amount_abs)
|
amount_diff = abs(receipt.amount - transaction_amount_abs)
|
||||||
amount_percent_diff = (amount_diff / receipt.amount) * 100 if receipt.amount > 0 else 0
|
amount_percent_diff = (
|
||||||
|
(amount_diff / receipt.amount) * 100 if receipt.amount > 0 else 0
|
||||||
|
)
|
||||||
|
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
Compare this receipt with this transaction and provide a confidence score (0-1) and brief reason.
|
Compare this receipt with this transaction and provide a confidence score (0-1) and brief reason, the reason must be a single sentence without any special formatting.
|
||||||
|
|
||||||
Receipt: {receipt.vendor}, ${receipt.amount}, {receipt.receipt_date.strftime('%Y-%m-%d')}
|
Receipt: {receipt.vendor}, ${receipt.amount}, {receipt.receipt_date.strftime("%Y-%m-%d")}
|
||||||
Receipt Description: {receipt.description}
|
Receipt Description: {receipt.description}
|
||||||
Receipt Category: {receipt.category}
|
Receipt Category: {receipt.category}
|
||||||
Transaction: {transaction.vendor}, ${transaction.amount} (absolute: ${transaction_amount_abs}), {transaction.transaction_date.strftime('%Y-%m-%d')}
|
Transaction: {transaction.vendor}, ${transaction.amount} (absolute: ${transaction_amount_abs}), {transaction.transaction_date.strftime("%Y-%m-%d")}
|
||||||
Transaction Notes: {transaction.notes}
|
Transaction Notes: {transaction.notes}
|
||||||
|
|
||||||
Differences:
|
Differences:
|
||||||
@@ -135,61 +342,78 @@ class AIMatcher:
|
|||||||
Format: [score]|[reason]
|
Format: [score]|[reason]
|
||||||
Example: 0.85|Same vendor, same amount, 2 days apart
|
Example: 0.85|Same vendor, same amount, 2 days apart
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for attempt in range(self.max_retries):
|
for attempt in range(self.max_retries):
|
||||||
try:
|
try:
|
||||||
result = self._call_groq_api_with_timeout(prompt, timeout=30) # Increased timeout
|
result = self._call_groq_api_with_timeout(
|
||||||
|
prompt, timeout=30
|
||||||
|
) # Increased timeout
|
||||||
|
|
||||||
# Parse the result - handle multiple formats
|
# Parse the result - handle multiple formats
|
||||||
score, reason = self._parse_ai_response(result)
|
score, reason = self._parse_ai_response(result)
|
||||||
|
|
||||||
logger.debug(f"AI Response: {result}")
|
logger.debug(f"AI Response: {result}")
|
||||||
logger.debug(f"Parsed: score={score}, reason={reason}")
|
logger.debug(f"Parsed: score={score}, reason={reason}")
|
||||||
|
|
||||||
return score, reason
|
return score, reason
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Attempt {attempt + 1} failed for receipt {receipt.id}: {str(e)}")
|
logger.warning(
|
||||||
|
f"Attempt {attempt + 1} failed for receipt {receipt.id}: {str(e)}"
|
||||||
|
)
|
||||||
if attempt < self.max_retries - 1:
|
if attempt < self.max_retries - 1:
|
||||||
# Exponential backoff for rate limiting
|
# Exponential backoff for rate limiting
|
||||||
sleep_time = self.retry_delay * (2 ** attempt)
|
sleep_time = self.retry_delay * (2**attempt)
|
||||||
logger.info(f"Waiting {sleep_time} seconds before retry...")
|
logger.info(f"Waiting {sleep_time} seconds before retry...")
|
||||||
time.sleep(sleep_time)
|
time.sleep(sleep_time)
|
||||||
else:
|
else:
|
||||||
logger.error(f"All attempts failed for receipt {receipt.id}")
|
logger.error(f"All attempts failed for receipt {receipt.id}")
|
||||||
return 0.0, f"AI error after {self.max_retries} attempts: {str(e)}"
|
return 0.0, f"AI error after {self.max_retries} attempts: {str(e)}"
|
||||||
|
|
||||||
def _parse_ai_response(self, result: str) -> Tuple[float, str]:
|
def _parse_ai_response(self, result: str) -> Tuple[float, str]:
|
||||||
"""Parse AI response with robust error handling"""
|
"""Parse AI response with robust error handling"""
|
||||||
result = result.strip()
|
result = result.strip()
|
||||||
logger.debug(f"Parsing AI response: {result}")
|
logger.debug(f"Parsing AI response: {result}")
|
||||||
|
|
||||||
# Try to find score in various formats
|
# Try to find score in various formats
|
||||||
if '|' in result:
|
if "|" in result:
|
||||||
parts = result.split('|')
|
parts = result.split("|")
|
||||||
logger.debug(f"Split response into {len(parts)} parts: {parts}")
|
logger.debug(f"Split response into {len(parts)} parts: {parts}")
|
||||||
|
|
||||||
# Look for a numeric score in any part
|
# Look for a numeric score in any part
|
||||||
for i, part in enumerate(parts):
|
for i, part in enumerate(parts):
|
||||||
part = part.strip()
|
part = part.strip()
|
||||||
try:
|
try:
|
||||||
# Remove any non-numeric characters except decimal point
|
# Remove any non-numeric characters except decimal point
|
||||||
score_str_clean = ''.join(c for c in part if c.isdigit() or c == '.')
|
score_str_clean = "".join(
|
||||||
|
c for c in part if c.isdigit() or c == "."
|
||||||
|
)
|
||||||
if score_str_clean:
|
if score_str_clean:
|
||||||
score = float(score_str_clean)
|
score = float(score_str_clean)
|
||||||
if 0 <= score <= 1: # Valid confidence score
|
if 0 <= score <= 1: # Valid confidence score
|
||||||
# Get reason from other parts
|
# Get reason from other parts
|
||||||
reason_parts = [p.strip() for j, p in enumerate(parts) if j != i and p.strip()]
|
reason_parts = [
|
||||||
reason = ' | '.join(reason_parts) if reason_parts else "Score extracted"
|
p.strip()
|
||||||
logger.debug(f"Found score {score} in part {i}, reason: {reason}")
|
for j, p in enumerate(parts)
|
||||||
|
if j != i and p.strip()
|
||||||
|
]
|
||||||
|
reason = (
|
||||||
|
" | ".join(reason_parts)
|
||||||
|
if reason_parts
|
||||||
|
else "Score extracted"
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"Found score {score} in part {i}, reason: {reason}"
|
||||||
|
)
|
||||||
return score, reason
|
return score, reason
|
||||||
except ValueError:
|
except ValueError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Try to extract just a number from the response
|
# Try to extract just a number from the response
|
||||||
try:
|
try:
|
||||||
import re
|
import re
|
||||||
numbers = re.findall(r'\d+\.?\d*', result)
|
|
||||||
|
numbers = re.findall(r"\d+\.?\d*", result)
|
||||||
if numbers:
|
if numbers:
|
||||||
for num_str in numbers:
|
for num_str in numbers:
|
||||||
score = float(num_str)
|
score = float(num_str)
|
||||||
@@ -198,11 +422,12 @@ class AIMatcher:
|
|||||||
return score, f"Extracted from response: {result[:50]}..."
|
return score, f"Extracted from response: {result[:50]}..."
|
||||||
except (ValueError, IndexError):
|
except (ValueError, IndexError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Fallback - try to find any number and normalize it
|
# Fallback - try to find any number and normalize it
|
||||||
try:
|
try:
|
||||||
import re
|
import re
|
||||||
numbers = re.findall(r'\d+\.?\d*', result)
|
|
||||||
|
numbers = re.findall(r"\d+\.?\d*", result)
|
||||||
if numbers:
|
if numbers:
|
||||||
score = float(numbers[0])
|
score = float(numbers[0])
|
||||||
# Normalize to 0-1 range if it's a percentage or other scale
|
# Normalize to 0-1 range if it's a percentage or other scale
|
||||||
@@ -213,27 +438,27 @@ class AIMatcher:
|
|||||||
return score, f"Normalized from response: {result[:50]}..."
|
return score, f"Normalized from response: {result[:50]}..."
|
||||||
except (ValueError, IndexError):
|
except (ValueError, IndexError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Final fallback
|
# Final fallback
|
||||||
logger.warning(f"Could not parse AI response: {result}")
|
logger.warning(f"Could not parse AI response: {result}")
|
||||||
return 0.0, f"Unparseable response: {result[:50]}..."
|
return 0.0, f"Unparseable response: {result[:50]}..."
|
||||||
|
|
||||||
def _call_groq_api_with_timeout(self, prompt: str, timeout: int = 15) -> str:
|
def _call_groq_api_with_timeout(self, prompt: str, timeout: int = 15) -> str:
|
||||||
"""Make API call with timeout and retry logic"""
|
"""Make API call with timeout and retry logic"""
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
|
|
||||||
def api_call():
|
def api_call():
|
||||||
try:
|
try:
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=[{"role": "user", "content": prompt}],
|
messages=[{"role": "user", "content": prompt}],
|
||||||
max_tokens=200,
|
max_tokens=200,
|
||||||
temperature=0.1
|
temperature=0.1,
|
||||||
)
|
)
|
||||||
return response.choices[0].message.content.strip()
|
return response.choices[0].message.content.strip()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
future = executor.submit(api_call)
|
future = executor.submit(api_call)
|
||||||
@@ -241,4 +466,4 @@ class AIMatcher:
|
|||||||
except concurrent.futures.TimeoutError:
|
except concurrent.futures.TimeoutError:
|
||||||
raise Exception(f"API call timed out after {timeout} seconds")
|
raise Exception(f"API call timed out after {timeout} seconds")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|||||||
+78
-29
@@ -1,9 +1,10 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Any, List
|
from typing import Any, Dict, List
|
||||||
import config
|
|
||||||
from models import Receipt, Transaction
|
from models import Receipt, Transaction
|
||||||
from tax_rules_engine import TaxRulesEngine
|
from tax_rules_engine import TaxRulesEngine
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AIRule:
|
class AIRule:
|
||||||
name: str
|
name: str
|
||||||
@@ -12,48 +13,88 @@ class AIRule:
|
|||||||
source: str
|
source: str
|
||||||
status: str = "active"
|
status: str = "active"
|
||||||
|
|
||||||
|
|
||||||
class AIRulesEngine:
|
class AIRulesEngine:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.rules: List[AIRule] = []
|
self.rules: List[AIRule] = []
|
||||||
self.tax_rules_engine = TaxRulesEngine()
|
self.tax_rules_engine = TaxRulesEngine()
|
||||||
self._load_default_rules()
|
self._load_default_rules()
|
||||||
|
|
||||||
def _load_default_rules(self):
|
def _load_default_rules(self):
|
||||||
self.rules = [
|
self.rules = [
|
||||||
AIRule("exact_amount_match", "amount_diff <= 0.01", "auto_approve", "system"),
|
AIRule(
|
||||||
AIRule("same_vendor_same_date", "vendor_match and date_diff <= 1", "high_confidence", "system"),
|
"exact_amount_match", "amount_diff <= 0.01", "auto_approve", "system"
|
||||||
AIRule("gas_station_pattern", "vendor_contains_gas_or_fuel", "categorize_transport", "system"),
|
),
|
||||||
|
AIRule(
|
||||||
|
"same_vendor_same_date",
|
||||||
|
"vendor_match and date_diff <= 1",
|
||||||
|
"high_confidence",
|
||||||
|
"system",
|
||||||
|
),
|
||||||
|
AIRule(
|
||||||
|
"gas_station_pattern",
|
||||||
|
"vendor_contains_gas_or_fuel",
|
||||||
|
"categorize_transport",
|
||||||
|
"system",
|
||||||
|
),
|
||||||
# Tax-related rules
|
# Tax-related rules
|
||||||
AIRule("fx_currency_mismatch", "currency_mismatch", "flag_fx_review", "tax_system"),
|
AIRule(
|
||||||
AIRule("meals_entertainment", "is_meals_entertainment", "apply_me_tax_rule", "tax_system"),
|
"fx_currency_mismatch",
|
||||||
AIRule("provincial_tax_calculation", "has_address_info", "calculate_provincial_tax", "tax_system")
|
"currency_mismatch",
|
||||||
|
"flag_fx_review",
|
||||||
|
"tax_system",
|
||||||
|
),
|
||||||
|
AIRule(
|
||||||
|
"meals_entertainment",
|
||||||
|
"is_meals_entertainment",
|
||||||
|
"apply_me_tax_rule",
|
||||||
|
"tax_system",
|
||||||
|
),
|
||||||
|
AIRule(
|
||||||
|
"provincial_tax_calculation",
|
||||||
|
"has_address_info",
|
||||||
|
"calculate_provincial_tax",
|
||||||
|
"tax_system",
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
def apply_rules(self, receipt: Receipt, transaction: Transaction) -> Dict[str, Any]:
|
def apply_rules(self, receipt: Receipt, transaction: Transaction) -> Dict[str, Any]:
|
||||||
results = {"auto_approve": False, "confidence_boost": 0, "category": None, "tax_analysis": {}}
|
results = {
|
||||||
|
"auto_approve": False,
|
||||||
|
"confidence_boost": 0,
|
||||||
|
"category": None,
|
||||||
|
"tax_analysis": {},
|
||||||
|
}
|
||||||
|
|
||||||
for rule in self.rules:
|
for rule in self.rules:
|
||||||
if rule.status != "active":
|
if rule.status != "active":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if self._evaluate_condition(rule.condition, receipt, transaction):
|
if self._evaluate_condition(rule.condition, receipt, transaction):
|
||||||
self._execute_action(rule.action, results, receipt, transaction)
|
self._execute_action(rule.action, results, receipt, transaction)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _evaluate_condition(self, condition: str, receipt: Receipt, transaction: Transaction) -> bool:
|
def _evaluate_condition(
|
||||||
|
self, condition: str, receipt: Receipt, transaction: Transaction
|
||||||
|
) -> bool:
|
||||||
"""Safely evaluate rule conditions without using eval()"""
|
"""Safely evaluate rule conditions without using eval()"""
|
||||||
amount_diff = abs(receipt.amount - abs(transaction.amount))
|
amount_diff = abs(receipt.amount - abs(transaction.amount))
|
||||||
date_diff = abs((receipt.receipt_date - transaction.transaction_date).days)
|
date_diff = abs((receipt.receipt_date - transaction.transaction_date).days)
|
||||||
vendor_match = receipt.vendor.lower() in transaction.vendor.lower() or transaction.vendor.lower() in receipt.vendor.lower()
|
vendor_match = (
|
||||||
|
receipt.vendor.lower() in transaction.vendor.lower()
|
||||||
|
or transaction.vendor.lower() in receipt.vendor.lower()
|
||||||
|
)
|
||||||
vendor_lower = receipt.vendor.lower()
|
vendor_lower = receipt.vendor.lower()
|
||||||
vendor_contains_gas_or_fuel = 'gas' in vendor_lower or 'fuel' in vendor_lower
|
vendor_contains_gas_or_fuel = "gas" in vendor_lower or "fuel" in vendor_lower
|
||||||
|
|
||||||
# Tax-related conditions
|
# Tax-related conditions
|
||||||
currency_mismatch = receipt.currency != transaction.currency
|
currency_mismatch = receipt.currency != transaction.currency
|
||||||
is_meals_entertainment = receipt.is_meals_entertainment
|
is_meals_entertainment = receipt.is_meals_entertainment
|
||||||
has_address_info = receipt.billing_address is not None or receipt.shipping_address is not None
|
has_address_info = (
|
||||||
|
receipt.billing_address is not None or receipt.shipping_address is not None
|
||||||
|
)
|
||||||
|
|
||||||
# Handle specific condition types safely
|
# Handle specific condition types safely
|
||||||
if condition == "amount_diff <= 0.01":
|
if condition == "amount_diff <= 0.01":
|
||||||
return amount_diff <= 0.01
|
return amount_diff <= 0.01
|
||||||
@@ -86,14 +127,20 @@ class AIRulesEngine:
|
|||||||
"min": min,
|
"min": min,
|
||||||
"max": max,
|
"max": max,
|
||||||
"sum": sum,
|
"sum": sum,
|
||||||
"round": round
|
"round": round,
|
||||||
}
|
}
|
||||||
return eval(condition, safe_globals, {})
|
return eval(condition, safe_globals, {})
|
||||||
except (SyntaxError, NameError, TypeError) as e:
|
except (SyntaxError, NameError, TypeError) as e:
|
||||||
print(f"Warning: Invalid condition '{condition}': {e}")
|
print(f"Warning: Invalid condition '{condition}': {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _execute_action(self, action: str, results: Dict[str, Any], receipt: Receipt, transaction: Transaction):
|
def _execute_action(
|
||||||
|
self,
|
||||||
|
action: str,
|
||||||
|
results: Dict[str, Any],
|
||||||
|
receipt: Receipt,
|
||||||
|
transaction: Transaction,
|
||||||
|
):
|
||||||
if action == "auto_approve":
|
if action == "auto_approve":
|
||||||
results["auto_approve"] = True
|
results["auto_approve"] = True
|
||||||
elif action == "high_confidence":
|
elif action == "high_confidence":
|
||||||
@@ -114,13 +161,15 @@ class AIRulesEngine:
|
|||||||
# Calculate provincial tax
|
# Calculate provincial tax
|
||||||
tax_result = self.tax_rules_engine.apply_sales_tax_rule(receipt)
|
tax_result = self.tax_rules_engine.apply_sales_tax_rule(receipt)
|
||||||
results["tax_analysis"]["sales_tax"] = tax_result
|
results["tax_analysis"]["sales_tax"] = tax_result
|
||||||
|
|
||||||
def add_rule(self, rule: AIRule):
|
def add_rule(self, rule: AIRule):
|
||||||
self.rules.append(rule)
|
self.rules.append(rule)
|
||||||
|
|
||||||
def remove_rule(self, rule_name: str):
|
def remove_rule(self, rule_name: str):
|
||||||
self.rules = [r for r in self.rules if r.name != rule_name]
|
self.rules = [r for r in self.rules if r.name != rule_name]
|
||||||
|
|
||||||
def apply_tax_rules(self, receipt: Receipt, transaction: Transaction = None) -> Dict[str, Any]:
|
def apply_tax_rules(
|
||||||
|
self, receipt: Receipt, transaction: Transaction = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Apply all tax rules to a receipt/transaction pair"""
|
"""Apply all tax rules to a receipt/transaction pair"""
|
||||||
return self.tax_rules_engine.apply_all_tax_rules(receipt, transaction)
|
return self.tax_rules_engine.apply_all_tax_rules(receipt, transaction)
|
||||||
|
|||||||
+29
-7
@@ -1,6 +1,7 @@
|
|||||||
from pydantic import BaseModel
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class AddressRequest(BaseModel):
|
class AddressRequest(BaseModel):
|
||||||
province: str
|
province: str
|
||||||
@@ -8,6 +9,7 @@ class AddressRequest(BaseModel):
|
|||||||
postal_code: str
|
postal_code: str
|
||||||
country: str = "Canada"
|
country: str = "Canada"
|
||||||
|
|
||||||
|
|
||||||
class ReceiptRequest(BaseModel):
|
class ReceiptRequest(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
file_name: str
|
file_name: str
|
||||||
@@ -24,6 +26,7 @@ class ReceiptRequest(BaseModel):
|
|||||||
currency: str = "CAD"
|
currency: str = "CAD"
|
||||||
is_meals_entertainment: bool = False
|
is_meals_entertainment: bool = False
|
||||||
|
|
||||||
|
|
||||||
class TransactionRequest(BaseModel):
|
class TransactionRequest(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
transaction_date: datetime
|
transaction_date: datetime
|
||||||
@@ -34,6 +37,7 @@ class TransactionRequest(BaseModel):
|
|||||||
currency: str = "CAD"
|
currency: str = "CAD"
|
||||||
fx_rate: Optional[float] = None
|
fx_rate: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
class AssetRequest(BaseModel):
|
class AssetRequest(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
@@ -44,42 +48,51 @@ class AssetRequest(BaseModel):
|
|||||||
cca_rate: float
|
cca_rate: float
|
||||||
asset_class: str
|
asset_class: str
|
||||||
|
|
||||||
|
|
||||||
class MatchingRequest(BaseModel):
|
class MatchingRequest(BaseModel):
|
||||||
receipt_ids: List[str]
|
receipt_ids: List[str]
|
||||||
transaction_ids: List[str]
|
transaction_ids: List[str]
|
||||||
|
|
||||||
|
|
||||||
class MatchResponse(BaseModel):
|
class MatchResponse(BaseModel):
|
||||||
receipt_id: str
|
receipt_id: str
|
||||||
transaction_id: str
|
transaction_id: str
|
||||||
confidence_score: float
|
confidence_score: float
|
||||||
match_reason: str
|
match_reason: str
|
||||||
tax_analysis: Optional[dict] = None
|
receipt_vendor: str
|
||||||
# Currency information
|
receipt_amount: float
|
||||||
receipt_currency: str = "CAD"
|
receipt_description: str
|
||||||
transaction_currency: str = "CAD"
|
receipt_category: str
|
||||||
currency_match: bool = True
|
receipt_tax_amount: float
|
||||||
|
transaction_vendor: str
|
||||||
|
transaction_amount: float
|
||||||
|
|
||||||
|
|
||||||
class MatchingResponse(BaseModel):
|
class MatchingResponse(BaseModel):
|
||||||
matches: List[MatchResponse]
|
matches: List[MatchResponse]
|
||||||
stats: dict
|
stats: dict
|
||||||
|
|
||||||
|
|
||||||
class ApprovalRequest(BaseModel):
|
class ApprovalRequest(BaseModel):
|
||||||
match_id: str
|
match_id: str
|
||||||
approved: bool
|
approved: bool
|
||||||
reason: Optional[str] = None
|
reason: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class RuleRequest(BaseModel):
|
class RuleRequest(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
condition: str
|
condition: str
|
||||||
action: str
|
action: str
|
||||||
source: str = "user"
|
source: str = "user"
|
||||||
|
|
||||||
|
|
||||||
class DocumentUploadResponse(BaseModel):
|
class DocumentUploadResponse(BaseModel):
|
||||||
file_id: str
|
file_id: str
|
||||||
filename: str
|
filename: str
|
||||||
upload_date: datetime
|
upload_date: datetime
|
||||||
status: str
|
status: str
|
||||||
|
|
||||||
|
|
||||||
class DocumentProcessResponse(BaseModel):
|
class DocumentProcessResponse(BaseModel):
|
||||||
file_id: str
|
file_id: str
|
||||||
extraction_success: bool
|
extraction_success: bool
|
||||||
@@ -92,11 +105,13 @@ class DocumentProcessResponse(BaseModel):
|
|||||||
confidence: Optional[float] = None
|
confidence: Optional[float] = None
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
# New tax-related models
|
# New tax-related models
|
||||||
class TaxCalculationRequest(BaseModel):
|
class TaxCalculationRequest(BaseModel):
|
||||||
receipt_id: str
|
receipt_id: str
|
||||||
transaction_id: Optional[str] = None
|
transaction_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class TaxCalculationResponse(BaseModel):
|
class TaxCalculationResponse(BaseModel):
|
||||||
receipt_id: str
|
receipt_id: str
|
||||||
rules_applied: List[str]
|
rules_applied: List[str]
|
||||||
@@ -104,11 +119,13 @@ class TaxCalculationResponse(BaseModel):
|
|||||||
fx_analysis: Optional[dict] = None
|
fx_analysis: Optional[dict] = None
|
||||||
meals_entertainment: dict
|
meals_entertainment: dict
|
||||||
|
|
||||||
|
|
||||||
class DepreciationRequest(BaseModel):
|
class DepreciationRequest(BaseModel):
|
||||||
asset: AssetRequest
|
asset: AssetRequest
|
||||||
year: int
|
year: int
|
||||||
method: str # "straight_line" or "cca"
|
method: str # "straight_line" or "cca"
|
||||||
|
|
||||||
|
|
||||||
class DepreciationResponse(BaseModel):
|
class DepreciationResponse(BaseModel):
|
||||||
asset_id: str
|
asset_id: str
|
||||||
year: int
|
year: int
|
||||||
@@ -117,4 +134,9 @@ class DepreciationResponse(BaseModel):
|
|||||||
book_value: float
|
book_value: float
|
||||||
total_depreciation: Optional[float] = None
|
total_depreciation: Optional[float] = None
|
||||||
success: bool
|
success: bool
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
class MatchSpecificRequest(BaseModel):
|
||||||
|
file_ids: List[str]
|
||||||
|
categorization_id: str
|
||||||
|
|
||||||
+75
@@ -0,0 +1,75 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from sqlalchemy import Column, DateTime, Float, Integer, String, create_engine
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
|
SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db"
|
||||||
|
|
||||||
|
engine = create_engine(
|
||||||
|
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
||||||
|
)
|
||||||
|
|
||||||
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
|
|
||||||
|
def get_db():
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
db_dependency = Annotated[Session, Depends(get_db)]
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
def create_db_tables():
|
||||||
|
Base.metadata.create_all(bind=engine)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_all_data():
|
||||||
|
"""Clear all data from the database (useful for testing)"""
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
db.query(Transaction).delete()
|
||||||
|
db.query(Receipt).delete()
|
||||||
|
db.commit()
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
# Transactions table
|
||||||
|
class Transaction(Base):
|
||||||
|
__tablename__ = "transactions"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
transaction_id = Column(String, unique=True, index=True)
|
||||||
|
amount = Column(Float, nullable=False)
|
||||||
|
date = Column(DateTime, nullable=False)
|
||||||
|
vendor = Column(String, nullable=False)
|
||||||
|
description = Column(String, nullable=True)
|
||||||
|
category = Column(String, nullable=True)
|
||||||
|
tax_amount = Column(Float, nullable=True)
|
||||||
|
categorisation_id = Column(String, nullable=True)
|
||||||
|
user_id = Column(String, nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
# Receipts table
|
||||||
|
class Receipt(Base):
|
||||||
|
__tablename__ = "receipts"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
receipt_id = Column(String, unique=True, index=True)
|
||||||
|
file_id = Column(String, unique=True, index=True)
|
||||||
|
amount = Column(Float, nullable=False)
|
||||||
|
date = Column(DateTime, nullable=False)
|
||||||
|
vendor = Column(String, nullable=False)
|
||||||
|
description = Column(String, nullable=True)
|
||||||
|
category = Column(String, nullable=True)
|
||||||
|
tax_amount = Column(Float, nullable=True)
|
||||||
|
confidence = Column(Float, nullable=True)
|
||||||
|
extraction_success = Column(String, nullable=True)
|
||||||
|
error_message = Column(String, nullable=True)
|
||||||
+39
-23
@@ -1,8 +1,9 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import List, Optional
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FeedbackLog:
|
class FeedbackLog:
|
||||||
@@ -13,48 +14,63 @@ class FeedbackLog:
|
|||||||
timestamp: datetime
|
timestamp: datetime
|
||||||
user_id: str
|
user_id: str
|
||||||
|
|
||||||
|
|
||||||
class FeedbackLogger:
|
class FeedbackLogger:
|
||||||
def __init__(self, log_file: str = "feedback_logs.json"):
|
def __init__(self, log_file: str = "feedback_logs.json"):
|
||||||
self.log_file = log_file
|
self.log_file = log_file
|
||||||
self.logs: List[FeedbackLog] = self._load_logs()
|
self.logs: List[FeedbackLog] = self._load_logs()
|
||||||
|
|
||||||
def _load_logs(self) -> List[FeedbackLog]:
|
def _load_logs(self) -> List[FeedbackLog]:
|
||||||
if not os.path.exists(self.log_file):
|
if not os.path.exists(self.log_file):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(self.log_file, 'r') as f:
|
with open(self.log_file, "r") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
return [FeedbackLog(**log) for log in data]
|
return [FeedbackLog(**log) for log in data]
|
||||||
except:
|
except Exception:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _save_logs(self):
|
def _save_logs(self):
|
||||||
with open(self.log_file, 'w') as f:
|
with open(self.log_file, "w") as f:
|
||||||
json.dump([{
|
json.dump(
|
||||||
'transaction_id': log.transaction_id,
|
[
|
||||||
'original_match': log.original_match,
|
{
|
||||||
'correction': log.correction,
|
"transaction_id": log.transaction_id,
|
||||||
'reason': log.reason,
|
"original_match": log.original_match,
|
||||||
'timestamp': log.timestamp.isoformat(),
|
"correction": log.correction,
|
||||||
'user_id': log.user_id
|
"reason": log.reason,
|
||||||
} for log in self.logs], f, indent=2)
|
"timestamp": log.timestamp.isoformat(),
|
||||||
|
"user_id": log.user_id,
|
||||||
def log_override(self, transaction_id: str, original_match: str, correction: str, reason: str, user_id: str):
|
}
|
||||||
|
for log in self.logs
|
||||||
|
],
|
||||||
|
f,
|
||||||
|
indent=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
def log_override(
|
||||||
|
self,
|
||||||
|
transaction_id: str,
|
||||||
|
original_match: str,
|
||||||
|
correction: str,
|
||||||
|
reason: str,
|
||||||
|
user_id: str,
|
||||||
|
):
|
||||||
log = FeedbackLog(
|
log = FeedbackLog(
|
||||||
transaction_id=transaction_id,
|
transaction_id=transaction_id,
|
||||||
original_match=original_match,
|
original_match=original_match,
|
||||||
correction=correction,
|
correction=correction,
|
||||||
reason=reason,
|
reason=reason,
|
||||||
timestamp=datetime.now(),
|
timestamp=datetime.now(),
|
||||||
user_id=user_id
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
self.logs.append(log)
|
self.logs.append(log)
|
||||||
self._save_logs()
|
self._save_logs()
|
||||||
|
|
||||||
def get_logs_by_transaction(self, transaction_id: str) -> List[FeedbackLog]:
|
def get_logs_by_transaction(self, transaction_id: str) -> List[FeedbackLog]:
|
||||||
return [log for log in self.logs if log.transaction_id == transaction_id]
|
return [log for log in self.logs if log.transaction_id == transaction_id]
|
||||||
|
|
||||||
def get_recent_logs(self, days: int = 30) -> List[FeedbackLog]:
|
def get_recent_logs(self, days: int = 30) -> List[FeedbackLog]:
|
||||||
cutoff = datetime.now() - timedelta(days=days)
|
cutoff = datetime.now() - timedelta(days=days)
|
||||||
return [log for log in self.logs if log.timestamp > cutoff]
|
return [log for log in self.logs if log.timestamp > cutoff]
|
||||||
|
|||||||
+81
-62
@@ -1,13 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
import io
|
|
||||||
from typing import List, Dict, Any, Optional
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
|
||||||
class GoogleDriveSync:
|
class GoogleDriveSync:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.service = None
|
self.service = None
|
||||||
self.processed_files = set()
|
self.processed_files = set()
|
||||||
|
|
||||||
def authenticate(self):
|
def authenticate(self):
|
||||||
"""Authenticate with Google Drive API"""
|
"""Authenticate with Google Drive API"""
|
||||||
try:
|
try:
|
||||||
@@ -15,111 +15,130 @@ class GoogleDriveSync:
|
|||||||
from google.oauth2.credentials import Credentials
|
from google.oauth2.credentials import Credentials
|
||||||
from google_auth_oauthlib.flow import InstalledAppFlow
|
from google_auth_oauthlib.flow import InstalledAppFlow
|
||||||
from googleapiclient.discovery import build
|
from googleapiclient.discovery import build
|
||||||
|
|
||||||
SCOPES = ['https://www.googleapis.com/auth/drive.readonly']
|
SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
|
||||||
|
|
||||||
# Load existing credentials
|
# Load existing credentials
|
||||||
if os.path.exists('token.json'):
|
if os.path.exists("token.json"):
|
||||||
self.creds = Credentials.from_authorized_user_file('token.json', SCOPES)
|
self.creds = Credentials.from_authorized_user_file("token.json", SCOPES)
|
||||||
|
|
||||||
# If no valid credentials available, let user log in
|
# If no valid credentials available, let user log in
|
||||||
if not self.creds or not self.creds.valid:
|
if not self.creds or not self.creds.valid:
|
||||||
if self.creds and self.creds.expired and self.creds.refresh_token:
|
if self.creds and self.creds.expired and self.creds.refresh_token:
|
||||||
self.creds.refresh(Request())
|
self.creds.refresh(Request())
|
||||||
else:
|
else:
|
||||||
if not os.path.exists('credentials.json'):
|
if not os.path.exists("credentials.json"):
|
||||||
raise Exception("credentials.json not found. Please download from Google Cloud Console.")
|
raise Exception(
|
||||||
|
"credentials.json not found. Please download from Google Cloud Console."
|
||||||
flow = InstalledAppFlow.from_client_secrets_file('credentials.json', SCOPES)
|
)
|
||||||
|
|
||||||
|
flow = InstalledAppFlow.from_client_secrets_file(
|
||||||
|
"credentials.json", SCOPES
|
||||||
|
)
|
||||||
self.creds = flow.run_local_server(port=0)
|
self.creds = flow.run_local_server(port=0)
|
||||||
|
|
||||||
# Save credentials for next run
|
# Save credentials for next run
|
||||||
with open('token.json', 'w') as token:
|
with open("token.json", "w") as token:
|
||||||
token.write(self.creds.to_json())
|
token.write(self.creds.to_json())
|
||||||
|
|
||||||
# Build the Drive service
|
# Build the Drive service
|
||||||
self.service = build('drive', 'v3', credentials=self.creds)
|
self.service = build("drive", "v3", credentials=self.creds)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Authentication error: {e}")
|
print(f"Authentication error: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def list_folders(self) -> List[Dict[str, Any]]:
|
def list_folders(self) -> List[Dict[str, Any]]:
|
||||||
"""List all folders in Google Drive"""
|
"""List all folders in Google Drive"""
|
||||||
if not self.service:
|
if not self.service:
|
||||||
if not self.authenticate():
|
if not self.authenticate():
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
results = self.service.files().list(
|
results = (
|
||||||
q="mimeType='application/vnd.google-apps.folder'",
|
self.service.files()
|
||||||
pageSize=100,
|
.list(
|
||||||
fields="nextPageToken, files(id, name, createdTime, modifiedTime)"
|
q="mimeType='application/vnd.google-apps.folder'",
|
||||||
).execute()
|
pageSize=100,
|
||||||
|
fields="nextPageToken, files(id, name, createdTime, modifiedTime)",
|
||||||
return results.get('files', [])
|
)
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
return results.get("files", [])
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error listing folders: {e}")
|
print(f"Error listing folders: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_folder_info(self, folder_id: str) -> Dict[str, Any]:
|
def get_folder_info(self, folder_id: str) -> Dict[str, Any]:
|
||||||
"""Get information about a Google Drive folder"""
|
"""Get information about a Google Drive folder"""
|
||||||
if not self.service:
|
if not self.service:
|
||||||
if not self.authenticate():
|
if not self.authenticate():
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
folder = self.service.files().get(
|
folder = (
|
||||||
fileId=folder_id,
|
self.service.files()
|
||||||
fields="id, name, createdTime, modifiedTime"
|
.get(fileId=folder_id, fields="id, name, createdTime, modifiedTime")
|
||||||
).execute()
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
return folder
|
return folder
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error getting folder info: {e}")
|
print(f"Error getting folder info: {e}")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def process_drive_files(self, folder_id: str = None) -> List[Dict[str, Any]]:
|
async def process_drive_files(self, folder_id: str = None) -> List[Dict[str, Any]]:
|
||||||
"""Process all receipt files from Google Drive"""
|
"""Process all receipt files from Google Drive"""
|
||||||
if not self.service:
|
if not self.service:
|
||||||
if not self.authenticate():
|
if not self.authenticate():
|
||||||
return []
|
return []
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# File types to look for
|
# File types to look for
|
||||||
file_types = ["'application/pdf'", "'image/jpeg'", "'image/png'", "'image/gif'", "'image/bmp'"]
|
file_types = [
|
||||||
|
"'application/pdf'",
|
||||||
|
"'image/jpeg'",
|
||||||
|
"'image/png'",
|
||||||
|
"'image/gif'",
|
||||||
|
"'image/bmp'",
|
||||||
|
]
|
||||||
mime_types = " or ".join(file_types)
|
mime_types = " or ".join(file_types)
|
||||||
|
|
||||||
# Build query
|
# Build query
|
||||||
query = f"mimeType contains {mime_types}"
|
query = f"mimeType contains {mime_types}"
|
||||||
if folder_id:
|
if folder_id:
|
||||||
query += f" and '{folder_id}' in parents"
|
query += f" and '{folder_id}' in parents"
|
||||||
|
|
||||||
# Add date filter (last 30 days)
|
# Add date filter (last 30 days)
|
||||||
thirty_days_ago = (datetime.now() - timedelta(days=30)).isoformat() + 'Z'
|
thirty_days_ago = (datetime.now() - timedelta(days=30)).isoformat() + "Z"
|
||||||
query += f" and modifiedTime > '{thirty_days_ago}'"
|
query += f" and modifiedTime > '{thirty_days_ago}'"
|
||||||
|
|
||||||
results_files = self.service.files().list(
|
results_files = (
|
||||||
q=query,
|
self.service.files()
|
||||||
pageSize=100,
|
.list(
|
||||||
fields="nextPageToken, files(id, name, mimeType, modifiedTime, size)"
|
q=query,
|
||||||
).execute()
|
pageSize=100,
|
||||||
|
fields="nextPageToken, files(id, name, mimeType, modifiedTime, size)",
|
||||||
files = results_files.get('files', [])
|
)
|
||||||
files = [file for file in files if file['id'] not in self.processed_files]
|
.execute()
|
||||||
|
)
|
||||||
|
|
||||||
|
files = results_files.get("files", [])
|
||||||
|
files = [file for file in files if file["id"] not in self.processed_files]
|
||||||
|
|
||||||
# For demo purposes, return mock results
|
# For demo purposes, return mock results
|
||||||
for file in files[:3]: # Process first 3 files
|
for file in files[:3]: # Process first 3 files
|
||||||
mock_result = {
|
mock_result = {
|
||||||
"file_id": file['id'],
|
"file_id": file["id"],
|
||||||
"filename": file['name'],
|
"filename": file["name"],
|
||||||
"drive_modified": file['modifiedTime'],
|
"drive_modified": file["modifiedTime"],
|
||||||
"file_size": file.get('size', 0),
|
"file_size": file.get("size", 0),
|
||||||
"extraction_success": True,
|
"extraction_success": True,
|
||||||
"vendor": "Demo Vendor",
|
"vendor": "Demo Vendor",
|
||||||
"description": "Coffee and sandwich",
|
"description": "Coffee and sandwich",
|
||||||
@@ -127,12 +146,12 @@ class GoogleDriveSync:
|
|||||||
"tax_amount": 2.04,
|
"tax_amount": 2.04,
|
||||||
"date": "2024-01-15",
|
"date": "2024-01-15",
|
||||||
"category": "Food",
|
"category": "Food",
|
||||||
"confidence": 0.95
|
"confidence": 0.95,
|
||||||
}
|
}
|
||||||
results.append(mock_result)
|
results.append(mock_result)
|
||||||
self.processed_files.add(file['id'])
|
self.processed_files.add(file["id"])
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error processing Drive files: {e}")
|
print(f"Error processing Drive files: {e}")
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
+36
-24
@@ -1,46 +1,53 @@
|
|||||||
from typing import List, Dict, Any
|
from typing import Any, Dict, List
|
||||||
from datetime import datetime
|
|
||||||
from ai_matcher import AIMatcher
|
from ai_matcher import AIMatcher
|
||||||
from ai_rules import AIRulesEngine
|
from ai_rules import AIRulesEngine
|
||||||
from feedback_logger import FeedbackLogger
|
from feedback_logger import FeedbackLogger
|
||||||
from models import Receipt, Transaction, Match
|
from models import Match, Receipt, Transaction
|
||||||
|
|
||||||
|
|
||||||
class MatchingEngine:
|
class MatchingEngine:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.ai_matcher = AIMatcher()
|
self.ai_matcher = AIMatcher()
|
||||||
self.rules_engine = AIRulesEngine()
|
self.rules_engine = AIRulesEngine()
|
||||||
self.feedback_logger = FeedbackLogger()
|
self.feedback_logger = FeedbackLogger()
|
||||||
|
|
||||||
def process_matching(self, receipts: List[Receipt], transactions: List[Transaction]) -> List[Match]:
|
def process_matching(
|
||||||
|
self, receipts: List[Receipt], transactions: List[Transaction]
|
||||||
|
) -> List[Match]:
|
||||||
# Get AI matches
|
# Get AI matches
|
||||||
ai_matches = self.ai_matcher.match_receipts_to_transactions(receipts, transactions)
|
ai_matches = self.ai_matcher.match_receipts_to_transactions(
|
||||||
|
receipts, transactions
|
||||||
|
)
|
||||||
|
|
||||||
# Apply rules and enhance matches
|
# Apply rules and enhance matches
|
||||||
enhanced_matches = []
|
enhanced_matches = []
|
||||||
for match in ai_matches:
|
for match in ai_matches:
|
||||||
enhanced_match = self._enhance_match_with_rules(match)
|
enhanced_match = self._enhance_match_with_rules(match)
|
||||||
enhanced_matches.append(enhanced_match)
|
enhanced_matches.append(enhanced_match)
|
||||||
|
|
||||||
return enhanced_matches
|
return enhanced_matches
|
||||||
|
|
||||||
def _enhance_match_with_rules(self, match: Match) -> Match:
|
def _enhance_match_with_rules(self, match: Match) -> Match:
|
||||||
rule_results = self.rules_engine.apply_rules(match.receipt, match.transaction)
|
rule_results = self.rules_engine.apply_rules(match.receipt, match.transaction)
|
||||||
|
|
||||||
# Apply confidence boost from rules
|
# Apply confidence boost from rules
|
||||||
if rule_results["confidence_boost"] > 0:
|
if rule_results["confidence_boost"] > 0:
|
||||||
match.confidence_score = min(1.0, match.confidence_score + rule_results["confidence_boost"])
|
match.confidence_score = min(
|
||||||
|
1.0, match.confidence_score + rule_results["confidence_boost"]
|
||||||
|
)
|
||||||
|
|
||||||
# Auto-approve if rules say so
|
# Auto-approve if rules say so
|
||||||
if rule_results["auto_approve"]:
|
if rule_results["auto_approve"]:
|
||||||
match.confidence_score = 1.0
|
match.confidence_score = 1.0
|
||||||
match.match_reason += " (Auto-approved by rules)"
|
match.match_reason += " (Auto-approved by rules)"
|
||||||
|
|
||||||
# Add tax analysis to match
|
# Add tax analysis to match
|
||||||
if rule_results.get("tax_analysis"):
|
if rule_results.get("tax_analysis"):
|
||||||
match.tax_analysis = rule_results["tax_analysis"]
|
match.tax_analysis = rule_results["tax_analysis"]
|
||||||
|
|
||||||
return match
|
return match
|
||||||
|
|
||||||
def approve_match(self, match: Match, user_id: str):
|
def approve_match(self, match: Match, user_id: str):
|
||||||
# Log the approval
|
# Log the approval
|
||||||
self.feedback_logger.log_override(
|
self.feedback_logger.log_override(
|
||||||
@@ -48,9 +55,9 @@ class MatchingEngine:
|
|||||||
original_match=f"AI Score: {match.confidence_score}",
|
original_match=f"AI Score: {match.confidence_score}",
|
||||||
correction="Approved",
|
correction="Approved",
|
||||||
reason="User approved match",
|
reason="User approved match",
|
||||||
user_id=user_id
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def reject_match(self, match: Match, reason: str, user_id: str):
|
def reject_match(self, match: Match, reason: str, user_id: str):
|
||||||
# Log the rejection
|
# Log the rejection
|
||||||
self.feedback_logger.log_override(
|
self.feedback_logger.log_override(
|
||||||
@@ -58,20 +65,25 @@ class MatchingEngine:
|
|||||||
original_match=f"AI Score: {match.confidence_score}",
|
original_match=f"AI Score: {match.confidence_score}",
|
||||||
correction="Rejected",
|
correction="Rejected",
|
||||||
reason=reason,
|
reason=reason,
|
||||||
user_id=user_id
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_matching_stats(self, matches: List[Match]) -> Dict[str, Any]:
|
def get_matching_stats(self, matches: List[Match]) -> Dict[str, Any]:
|
||||||
if not matches:
|
if not matches:
|
||||||
return {"total": 0, "high_confidence": 0, "low_confidence": 0, "avg_score": 0}
|
return {
|
||||||
|
"total": 0,
|
||||||
|
"high_confidence": 0,
|
||||||
|
"low_confidence": 0,
|
||||||
|
"avg_score": 0,
|
||||||
|
}
|
||||||
|
|
||||||
high_confidence = len([m for m in matches if m.confidence_score >= 0.8])
|
high_confidence = len([m for m in matches if m.confidence_score >= 0.8])
|
||||||
low_confidence = len([m for m in matches if m.confidence_score < 0.8])
|
low_confidence = len([m for m in matches if m.confidence_score < 0.8])
|
||||||
avg_score = sum(m.confidence_score for m in matches) / len(matches)
|
avg_score = sum(m.confidence_score for m in matches) / len(matches)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"total": len(matches),
|
"total": len(matches),
|
||||||
"high_confidence": high_confidence,
|
"high_confidence": high_confidence,
|
||||||
"low_confidence": low_confidence,
|
"low_confidence": low_confidence,
|
||||||
"avg_score": round(avg_score, 3)
|
"avg_score": round(avg_score, 3),
|
||||||
}
|
}
|
||||||
|
|||||||
+16
-16
@@ -1,16 +1,16 @@
|
|||||||
groq>=0.5.0
|
groq
|
||||||
python-dotenv==1.0.0
|
python-dotenv
|
||||||
pandas==2.1.4
|
pandas
|
||||||
numpy==1.24.3
|
numpy
|
||||||
fastapi==0.104.1
|
fastapi
|
||||||
uvicorn==0.24.0
|
uvicorn
|
||||||
pydantic==2.5.0
|
pydantic
|
||||||
requests==2.31.0
|
requests
|
||||||
python-multipart==0.0.6
|
python-multipart
|
||||||
Pillow==10.0.1
|
Pillow
|
||||||
PyPDF2==3.0.1
|
PyPDF2
|
||||||
aiofiles==23.2.1
|
aiofiles
|
||||||
google-auth==2.23.4
|
google-auth
|
||||||
google-auth-oauthlib==1.1.0
|
google-auth-oauthlib
|
||||||
google-auth-httplib2==0.1.1
|
google-auth-httplib2
|
||||||
google-api-python-client==2.108.0
|
google-api-python-client
|
||||||
+75
-70
@@ -1,13 +1,14 @@
|
|||||||
from typing import Dict, Any, Optional, Tuple
|
|
||||||
from datetime import datetime
|
|
||||||
from models import Receipt, Transaction, Address, Asset
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from models import Address, Asset, Receipt, Transaction
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TaxRulesEngine:
|
class TaxRulesEngine:
|
||||||
"""Engine to handle tax calculations based on the four tax rules"""
|
"""Engine to handle tax calculations based on the four tax rules"""
|
||||||
|
|
||||||
# Provincial tax rates (simplified - in production, use a tax rate API)
|
# Provincial tax rates (simplified - in production, use a tax rate API)
|
||||||
PROVINCIAL_TAX_RATES = {
|
PROVINCIAL_TAX_RATES = {
|
||||||
"ON": 0.13, # Ontario HST
|
"ON": 0.13, # Ontario HST
|
||||||
@@ -24,10 +25,10 @@ class TaxRulesEngine:
|
|||||||
"NU": 0.05, # Nunavut
|
"NU": 0.05, # Nunavut
|
||||||
"YT": 0.05, # Yukon
|
"YT": 0.05, # Yukon
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def apply_sales_tax_rule(self, receipt: Receipt) -> Dict[str, Any]:
|
def apply_sales_tax_rule(self, receipt: Receipt) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Sales Tax Rule: Apply correct sales tax based on billing vs shipping addresses
|
Sales Tax Rule: Apply correct sales tax based on billing vs shipping addresses
|
||||||
@@ -35,43 +36,45 @@ class TaxRulesEngine:
|
|||||||
try:
|
try:
|
||||||
# Determine which address to use for tax calculation
|
# Determine which address to use for tax calculation
|
||||||
tax_address = self._get_tax_address(receipt)
|
tax_address = self._get_tax_address(receipt)
|
||||||
|
|
||||||
if not tax_address:
|
if not tax_address:
|
||||||
return {
|
return {
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": "No valid address found for tax calculation",
|
"error": "No valid address found for tax calculation",
|
||||||
"calculated_tax": 0.0,
|
"calculated_tax": 0.0,
|
||||||
"tax_rate": 0.0
|
"tax_rate": 0.0,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Get tax rate for the province
|
# Get tax rate for the province
|
||||||
tax_rate = self.PROVINCIAL_TAX_RATES.get(tax_address.province, 0.0)
|
tax_rate = self.PROVINCIAL_TAX_RATES.get(tax_address.province, 0.0)
|
||||||
|
|
||||||
# Calculate tax amount
|
# Calculate tax amount
|
||||||
calculated_tax = receipt.amount * tax_rate
|
calculated_tax = receipt.amount * tax_rate
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"calculated_tax": calculated_tax,
|
"calculated_tax": calculated_tax,
|
||||||
"tax_rate": tax_rate,
|
"tax_rate": tax_rate,
|
||||||
"tax_address": tax_address.province,
|
"tax_address": tax_address.province,
|
||||||
"rule_applied": "Sales Tax Rule"
|
"rule_applied": "Sales Tax Rule",
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Error applying sales tax rule: {str(e)}")
|
self.logger.error(f"Error applying sales tax rule: {str(e)}")
|
||||||
return {
|
return {
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
"calculated_tax": 0.0,
|
"calculated_tax": 0.0,
|
||||||
"tax_rate": 0.0
|
"tax_rate": 0.0,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _get_tax_address(self, receipt: Receipt) -> Optional[Address]:
|
def _get_tax_address(self, receipt: Receipt) -> Optional[Address]:
|
||||||
"""Determine which address to use for tax calculation"""
|
"""Determine which address to use for tax calculation"""
|
||||||
# Rule: Use shipping address if different from billing, otherwise use billing
|
# Rule: Use shipping address if different from billing, otherwise use billing
|
||||||
if receipt.shipping_address and receipt.billing_address:
|
if receipt.shipping_address and receipt.billing_address:
|
||||||
if self._addresses_different(receipt.billing_address, receipt.shipping_address):
|
if self._addresses_different(
|
||||||
|
receipt.billing_address, receipt.shipping_address
|
||||||
|
):
|
||||||
return receipt.shipping_address
|
return receipt.shipping_address
|
||||||
else:
|
else:
|
||||||
return receipt.billing_address
|
return receipt.billing_address
|
||||||
@@ -81,14 +84,18 @@ class TaxRulesEngine:
|
|||||||
return receipt.billing_address
|
return receipt.billing_address
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _addresses_different(self, billing: Address, shipping: Address) -> bool:
|
def _addresses_different(self, billing: Address, shipping: Address) -> bool:
|
||||||
"""Check if billing and shipping addresses are different"""
|
"""Check if billing and shipping addresses are different"""
|
||||||
return (billing.province != shipping.province or
|
return (
|
||||||
billing.city != shipping.city or
|
billing.province != shipping.province
|
||||||
billing.postal_code != shipping.postal_code)
|
or billing.city != shipping.city
|
||||||
|
or billing.postal_code != shipping.postal_code
|
||||||
def apply_fx_rule(self, receipt: Receipt, transaction: Transaction) -> Dict[str, Any]:
|
)
|
||||||
|
|
||||||
|
def apply_fx_rule(
|
||||||
|
self, receipt: Receipt, transaction: Transaction
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Foreign Exchange Rule: Handle currency mismatches
|
Foreign Exchange Rule: Handle currency mismatches
|
||||||
"""
|
"""
|
||||||
@@ -96,7 +103,7 @@ class TaxRulesEngine:
|
|||||||
# Check for currency mismatch
|
# Check for currency mismatch
|
||||||
if receipt.currency != transaction.currency:
|
if receipt.currency != transaction.currency:
|
||||||
fx_discrepancy = abs(receipt.amount - abs(transaction.amount))
|
fx_discrepancy = abs(receipt.amount - abs(transaction.amount))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"fx_discrepancy": fx_discrepancy,
|
"fx_discrepancy": fx_discrepancy,
|
||||||
@@ -105,26 +112,28 @@ class TaxRulesEngine:
|
|||||||
"receipt_amount": receipt.amount,
|
"receipt_amount": receipt.amount,
|
||||||
"transaction_amount": abs(transaction.amount),
|
"transaction_amount": abs(transaction.amount),
|
||||||
"requires_manual_review": True,
|
"requires_manual_review": True,
|
||||||
"rule_applied": "Foreign Exchange Rule"
|
"rule_applied": "Foreign Exchange Rule",
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"fx_discrepancy": 0.0,
|
"fx_discrepancy": 0.0,
|
||||||
"requires_manual_review": False,
|
"requires_manual_review": False,
|
||||||
"rule_applied": "No FX Rule (same currency)"
|
"rule_applied": "No FX Rule (same currency)",
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Error applying FX rule: {str(e)}")
|
self.logger.error(f"Error applying FX rule: {str(e)}")
|
||||||
return {
|
return {
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
"fx_discrepancy": 0.0,
|
"fx_discrepancy": 0.0,
|
||||||
"requires_manual_review": False
|
"requires_manual_review": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
def calculate_straight_line_depreciation(self, asset: Asset, year: int) -> Dict[str, Any]:
|
def calculate_straight_line_depreciation(
|
||||||
|
self, asset: Asset, year: int
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Straight-Line Depreciation for accounting purposes
|
Straight-Line Depreciation for accounting purposes
|
||||||
"""
|
"""
|
||||||
@@ -133,28 +142,26 @@ class TaxRulesEngine:
|
|||||||
return {
|
return {
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": f"Year {year} exceeds useful life of {asset.useful_life_years} years",
|
"error": f"Year {year} exceeds useful life of {asset.useful_life_years} years",
|
||||||
"depreciation": 0.0
|
"depreciation": 0.0,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Straight-line formula: (Cost - Residual Value) / Useful Life
|
# Straight-line formula: (Cost - Residual Value) / Useful Life
|
||||||
annual_depreciation = (asset.purchase_amount - asset.residual_value) / asset.useful_life_years
|
annual_depreciation = (
|
||||||
|
asset.purchase_amount - asset.residual_value
|
||||||
|
) / asset.useful_life_years
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"depreciation": annual_depreciation,
|
"depreciation": annual_depreciation,
|
||||||
"book_value": asset.purchase_amount - (annual_depreciation * year),
|
"book_value": asset.purchase_amount - (annual_depreciation * year),
|
||||||
"method": "Straight-Line",
|
"method": "Straight-Line",
|
||||||
"rule_applied": "Depreciation Rule (Accounting)"
|
"rule_applied": "Depreciation Rule (Accounting)",
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Error calculating straight-line depreciation: {str(e)}")
|
self.logger.error(f"Error calculating straight-line depreciation: {str(e)}")
|
||||||
return {
|
return {"success": False, "error": str(e), "depreciation": 0.0}
|
||||||
"success": False,
|
|
||||||
"error": str(e),
|
|
||||||
"depreciation": 0.0
|
|
||||||
}
|
|
||||||
|
|
||||||
def calculate_cca_depreciation(self, asset: Asset, year: int) -> Dict[str, Any]:
|
def calculate_cca_depreciation(self, asset: Asset, year: int) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
CCA (Capital Cost Allowance) Depreciation for tax purposes
|
CCA (Capital Cost Allowance) Depreciation for tax purposes
|
||||||
@@ -164,40 +171,36 @@ class TaxRulesEngine:
|
|||||||
return {
|
return {
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": "Year must be at least 1",
|
"error": "Year must be at least 1",
|
||||||
"depreciation": 0.0
|
"depreciation": 0.0,
|
||||||
}
|
}
|
||||||
|
|
||||||
# CCA uses declining balance method
|
# CCA uses declining balance method
|
||||||
book_value = asset.purchase_amount
|
book_value = asset.purchase_amount
|
||||||
total_depreciation = 0.0
|
total_depreciation = 0.0
|
||||||
|
|
||||||
for current_year in range(1, year + 1):
|
for current_year in range(1, year + 1):
|
||||||
# CCA is calculated on the declining balance
|
# CCA is calculated on the declining balance
|
||||||
cca_amount = book_value * asset.cca_rate
|
cca_amount = book_value * asset.cca_rate
|
||||||
book_value -= cca_amount
|
book_value -= cca_amount
|
||||||
total_depreciation += cca_amount
|
total_depreciation += cca_amount
|
||||||
|
|
||||||
# Stop if book value reaches residual value
|
# Stop if book value reaches residual value
|
||||||
if book_value <= asset.residual_value:
|
if book_value <= asset.residual_value:
|
||||||
break
|
break
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"depreciation": cca_amount, # Current year depreciation
|
"depreciation": cca_amount, # Current year depreciation
|
||||||
"total_depreciation": total_depreciation,
|
"total_depreciation": total_depreciation,
|
||||||
"book_value": max(book_value, asset.residual_value),
|
"book_value": max(book_value, asset.residual_value),
|
||||||
"method": "CCA Declining Balance",
|
"method": "CCA Declining Balance",
|
||||||
"rule_applied": "Depreciation Rule (Tax)"
|
"rule_applied": "Depreciation Rule (Tax)",
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Error calculating CCA depreciation: {str(e)}")
|
self.logger.error(f"Error calculating CCA depreciation: {str(e)}")
|
||||||
return {
|
return {"success": False, "error": str(e), "depreciation": 0.0}
|
||||||
"success": False,
|
|
||||||
"error": str(e),
|
|
||||||
"depreciation": 0.0
|
|
||||||
}
|
|
||||||
|
|
||||||
def apply_meals_entertainment_rule(self, receipt: Receipt) -> Dict[str, Any]:
|
def apply_meals_entertainment_rule(self, receipt: Receipt) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Meals & Entertainment Tax Deduction Rule
|
Meals & Entertainment Tax Deduction Rule
|
||||||
@@ -208,36 +211,38 @@ class TaxRulesEngine:
|
|||||||
"success": True,
|
"success": True,
|
||||||
"tax_deduction": receipt.amount,
|
"tax_deduction": receipt.amount,
|
||||||
"accounting_deduction": receipt.amount,
|
"accounting_deduction": receipt.amount,
|
||||||
"rule_applied": "No M&E Rule (not meals/entertainment)"
|
"rule_applied": "No M&E Rule (not meals/entertainment)",
|
||||||
}
|
}
|
||||||
|
|
||||||
# For tax purposes: 50% deductible
|
# For tax purposes: 50% deductible
|
||||||
tax_deduction = receipt.amount * 0.5
|
tax_deduction = receipt.amount * 0.5
|
||||||
|
|
||||||
# For accounting purposes: 100% deductible
|
# For accounting purposes: 100% deductible
|
||||||
accounting_deduction = receipt.amount
|
accounting_deduction = receipt.amount
|
||||||
|
|
||||||
# Sales tax is fully deductible for accounting
|
# Sales tax is fully deductible for accounting
|
||||||
tax_on_meal = receipt.tax
|
tax_on_meal = receipt.tax
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"tax_deduction": tax_deduction,
|
"tax_deduction": tax_deduction,
|
||||||
"accounting_deduction": accounting_deduction,
|
"accounting_deduction": accounting_deduction,
|
||||||
"tax_on_meal": tax_on_meal,
|
"tax_on_meal": tax_on_meal,
|
||||||
"rule_applied": "Meals & Entertainment Rule"
|
"rule_applied": "Meals & Entertainment Rule",
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Error applying meals & entertainment rule: {str(e)}")
|
self.logger.error(f"Error applying meals & entertainment rule: {str(e)}")
|
||||||
return {
|
return {
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
"tax_deduction": 0.0,
|
"tax_deduction": 0.0,
|
||||||
"accounting_deduction": 0.0
|
"accounting_deduction": 0.0,
|
||||||
}
|
}
|
||||||
|
|
||||||
def apply_all_tax_rules(self, receipt: Receipt, transaction: Transaction = None) -> Dict[str, Any]:
|
def apply_all_tax_rules(
|
||||||
|
self, receipt: Receipt, transaction: Transaction = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Apply all tax rules to a receipt
|
Apply all tax rules to a receipt
|
||||||
"""
|
"""
|
||||||
@@ -246,26 +251,26 @@ class TaxRulesEngine:
|
|||||||
"rules_applied": [],
|
"rules_applied": [],
|
||||||
"sales_tax": {},
|
"sales_tax": {},
|
||||||
"fx_analysis": {},
|
"fx_analysis": {},
|
||||||
"meals_entertainment": {}
|
"meals_entertainment": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Apply Sales Tax Rule
|
# Apply Sales Tax Rule
|
||||||
sales_tax_result = self.apply_sales_tax_rule(receipt)
|
sales_tax_result = self.apply_sales_tax_rule(receipt)
|
||||||
results["sales_tax"] = sales_tax_result
|
results["sales_tax"] = sales_tax_result
|
||||||
if sales_tax_result["success"]:
|
if sales_tax_result["success"]:
|
||||||
results["rules_applied"].append("Sales Tax Rule")
|
results["rules_applied"].append("Sales Tax Rule")
|
||||||
|
|
||||||
# Apply FX Rule (if transaction provided)
|
# Apply FX Rule (if transaction provided)
|
||||||
if transaction:
|
if transaction:
|
||||||
fx_result = self.apply_fx_rule(receipt, transaction)
|
fx_result = self.apply_fx_rule(receipt, transaction)
|
||||||
results["fx_analysis"] = fx_result
|
results["fx_analysis"] = fx_result
|
||||||
if fx_result["success"]:
|
if fx_result["success"]:
|
||||||
results["rules_applied"].append("Foreign Exchange Rule")
|
results["rules_applied"].append("Foreign Exchange Rule")
|
||||||
|
|
||||||
# Apply Meals & Entertainment Rule
|
# Apply Meals & Entertainment Rule
|
||||||
me_result = self.apply_meals_entertainment_rule(receipt)
|
me_result = self.apply_meals_entertainment_rule(receipt)
|
||||||
results["meals_entertainment"] = me_result
|
results["meals_entertainment"] = me_result
|
||||||
if me_result["success"]:
|
if me_result["success"]:
|
||||||
results["rules_applied"].append("Meals & Entertainment Rule")
|
results["rules_applied"].append("Meals & Entertainment Rule")
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
Reference in New Issue
Block a user