470 lines
19 KiB
Python
470 lines
19 KiB
Python
|
|
import concurrent.futures
import logging
import re
import time
from typing import List, Optional, Tuple

import groq

from config import settings
from schemas import Match, Receipt, Transaction
|
||
|
|
|
||
|
|
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class AIMatcher:
    """Match receipts to bank transactions by asking a Groq-hosted LLM to score them.

    Two strategies are available, selected by ``use_batch_matching``:

    * batch (default): one prompt per receipt listing all candidate
      transactions; the model returns the single best candidate.
    * legacy: one prompt per (receipt, candidate) pair; the highest-scoring
      candidate wins.

    Every API call is throttled by ``_rate_limit`` and retried with
    exponential backoff.

    Annotations that name schema types are written as strings (forward
    references) so they are not evaluated when the class is defined.
    """

    # Upper bound on candidates listed in one batch prompt (keeps the prompt
    # within the model's token limits).
    MAX_BATCH_CANDIDATES = 10

    # First decimal number in a piece of text, e.g. "0.87", "3", "85".
    _NUMBER_RE = re.compile(r"\d+\.?\d*")

    # Scoring bands shared by the batch and the legacy prompt.
    _SCORING_RUBRIC = (
        "- Perfect matches (same vendor, amount, date): 0.95-1.0\n"
        "- High confidence (minor differences): 0.8-0.94\n"
        "- Medium confidence (moderate differences): 0.6-0.79\n"
        "- Low confidence (significant differences): 0.4-0.59\n"
        "- Very low confidence (major differences): 0.2-0.39\n"
        "- Minimal similarity: 0.1-0.19\n"
        "- No meaningful similarity: 0.0-0.09\n"
    )

    def __init__(self, use_batch_matching=True):
        """Create the Groq client and set retry / rate-limit parameters.

        Args:
            use_batch_matching: True for one API call per receipt (batch),
                False for one API call per candidate (legacy).
        """
        self.client = groq.Groq(api_key=settings.GROQ_API_KEY)
        self.model = "llama-3.1-8b-instant"
        self.max_retries = 3
        self.retry_delay = 2  # seconds; doubled after every failed attempt
        self.rate_limit_delay = 1.0  # minimum seconds between API calls
        self.last_api_call = 0
        self.use_batch_matching = use_batch_matching

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------
    def match_receipts_to_transactions(
        self, receipts: "List[Receipt]", transactions: "List[Transaction]"
    ) -> "List[Match]":
        """Return the best match for each receipt, highest confidence first.

        Receipts for which no candidate or no usable AI answer exists are
        skipped (a warning is logged). Throttling happens inside
        ``_call_groq_api_with_timeout``, so it applies to every API call
        rather than once per receipt.
        """
        logger.info(
            "Starting AI matching for %d receipts against %d transactions",
            len(receipts),
            len(transactions),
        )
        matches = []
        total = len(receipts)

        for position, receipt in enumerate(receipts, start=1):
            logger.info(
                "Processing receipt %d/%d: %s - $%s",
                position,
                total,
                receipt.vendor,
                receipt.amount,
            )
            best_match = self._find_best_match(receipt, transactions)
            if best_match:
                matches.append(best_match)
                logger.info(
                    "Found match: %.3f - %s",
                    best_match.confidence_score,
                    best_match.match_reason,
                )
            else:
                logger.warning(
                    "No match found for receipt: %s - $%s",
                    receipt.vendor,
                    receipt.amount,
                )

        matches.sort(key=lambda match: match.confidence_score, reverse=True)
        logger.info("AI matching completed. Found %d matches", len(matches))
        return matches

    # ------------------------------------------------------------------
    # Throttling and transport
    # ------------------------------------------------------------------
    def _rate_limit(self):
        """Sleep as needed so calls are at least ``rate_limit_delay`` seconds apart."""
        elapsed = time.time() - self.last_api_call
        if elapsed < self.rate_limit_delay:
            pause = self.rate_limit_delay - elapsed
            logger.debug("Rate limiting: sleeping for %.2f seconds", pause)
            time.sleep(pause)
        self.last_api_call = time.time()

    def _call_groq_api_with_timeout(self, prompt: str, timeout: int = 15) -> str:
        """Send ``prompt`` to the model and return the stripped reply text.

        Args:
            prompt: User message sent as the only chat message.
            timeout: Seconds to wait for the reply.

        Returns:
            The reply content with surrounding whitespace removed; an empty
            string when the API returns no content.

        Raises:
            TimeoutError: No reply arrived within ``timeout`` seconds.
            Exception: Whatever the Groq client raised.
        """
        self._rate_limit()

        def api_call() -> str:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=200,
                temperature=0.1,
            )
            content = response.choices[0].message.content
            return (content or "").strip()

        # Not used as a context manager: leaving a ``with`` block calls
        # shutdown(wait=True), which would block until the request finishes
        # and defeat the timeout. On timeout the worker thread is abandoned;
        # the underlying request itself is not cancelled.
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            return executor.submit(api_call).result(timeout=timeout)
        except concurrent.futures.TimeoutError as exc:
            raise TimeoutError(
                f"API call timed out after {timeout} seconds"
            ) from exc
        finally:
            executor.shutdown(wait=False)

    def _complete_with_retries(self, prompt: str, timeout: int, receipt_id) -> str:
        """Call the API up to ``max_retries`` times with exponential backoff.

        Returns the reply text of the first successful attempt. Re-raises the
        last error once every attempt has failed.
        """
        last_error = None
        for attempt in range(self.max_retries):
            try:
                return self._call_groq_api_with_timeout(prompt, timeout=timeout)
            except Exception as exc:  # retry on any failure, as before
                last_error = exc
                logger.warning(
                    "Attempt %d failed for receipt %s: %s",
                    attempt + 1,
                    receipt_id,
                    exc,
                )
                if attempt < self.max_retries - 1:
                    pause = self.retry_delay * (2**attempt)
                    logger.info("Waiting %s seconds before retry...", pause)
                    time.sleep(pause)

        logger.error("All attempts failed for receipt %s", receipt_id)
        if last_error is None:  # max_retries <= 0: nothing was attempted
            raise RuntimeError("no API attempts were made (max_retries <= 0)")
        raise last_error

    # ------------------------------------------------------------------
    # Candidate selection
    # ------------------------------------------------------------------
    def _filter_candidates(
        self, receipt: "Receipt", transactions: "List[Transaction]"
    ) -> "List[Transaction]":
        """Keep transactions whose absolute amount is plausibly close to the receipt's.

        A transaction is kept when the difference between the two absolute
        amounts is at most 200% of the receipt amount. Absolute values are
        used on both sides so debits and refunds recorded as negatives are
        compared by magnitude.
        """
        receipt_amount = abs(receipt.amount)
        tolerance = receipt_amount * 2.0  # 200% threshold - very inclusive
        candidates = [
            transaction
            for transaction in transactions
            if abs(receipt_amount - abs(transaction.amount)) <= tolerance
        ]
        logger.debug(
            "Filtered %d transactions to %d candidates",
            len(transactions),
            len(candidates),
        )
        return candidates

    def _find_best_match(
        self, receipt: "Receipt", transactions: "List[Transaction]"
    ) -> "Optional[Match]":
        """Return the best match for ``receipt`` or None when there is none."""
        candidates = self._filter_candidates(receipt, transactions)
        if not candidates:
            logger.warning(
                "No candidates found for receipt: %s - $%s",
                receipt.vendor,
                receipt.amount,
            )
            return None

        logger.info(
            "Found %d candidates for receipt: %s", len(candidates), receipt.vendor
        )
        if self.use_batch_matching:
            # Single AI call covering all candidates.
            return self._find_best_match_single_call(receipt, candidates)
        # Legacy: one AI call per candidate.
        return self._find_best_match_legacy(receipt, candidates)

    @staticmethod
    def _amount_date_gap(receipt, transaction):
        """Return (abs transaction amount, days apart, amount diff, percent diff)."""
        receipt_amount = abs(receipt.amount)
        transaction_amount_abs = abs(transaction.amount)
        date_diff = abs((receipt.receipt_date - transaction.transaction_date).days)
        amount_diff = abs(receipt_amount - transaction_amount_abs)
        percent_diff = (
            (amount_diff / receipt_amount) * 100 if receipt_amount > 0 else 0
        )
        return transaction_amount_abs, date_diff, amount_diff, percent_diff

    # ------------------------------------------------------------------
    # Batch strategy
    # ------------------------------------------------------------------
    def _build_batch_prompt(self, receipt, candidates) -> str:
        """Build the batch prompt listing the receipt and numbered candidates."""
        blocks = []
        for number, transaction in enumerate(candidates, start=1):
            amount_abs, date_diff, amount_diff, percent_diff = self._amount_date_gap(
                receipt, transaction
            )
            blocks.append(
                f"\nCandidate {number}:\n"
                f"- Vendor: {transaction.vendor}\n"
                f"- Amount: ${transaction.amount} (absolute: ${amount_abs})\n"
                f"- Date: {transaction.transaction_date.strftime('%Y-%m-%d')}"
                f" ({date_diff} days difference)\n"
                f"- Notes: {transaction.notes}\n"
                f"- Amount difference: ${amount_diff} ({percent_diff:.1f}%)\n"
            )
        candidates_text = "".join(blocks)

        return (
            "\n"
            "You are an expert at matching receipts to bank transactions. "
            "Analyze the receipt below against ALL the candidate transactions "
            "and return the BEST match.\n"
            "\n"
            "RECEIPT TO MATCH:\n"
            f"- Vendor: {receipt.vendor}\n"
            f"- Amount: ${receipt.amount}\n"
            f"- Date: {receipt.receipt_date.strftime('%Y-%m-%d')}\n"
            f"- Description: {receipt.description}\n"
            f"- Category: {receipt.category}\n"
            "\n"
            "CANDIDATE TRANSACTIONS:\n"
            f"{candidates_text}\n"
            "\n"
            "SCORING CRITERIA:\n"
            f"{self._SCORING_RUBRIC}"
            "\n"
            "Consider vendor name similarity, amount accuracy, date proximity, "
            "and description/notes relevance.\n"
            "\n"
            "IMPORTANT: You MUST return the candidate with the highest match "
            "score, even if it's very low. Never return NONE.\n"
            "Return ONLY the best match in this exact format:\n"
            "CANDIDATE_NUMBER|CONFIDENCE_SCORE|REASON\n"
            "\n"
            "Example: 3|0.87|Same vendor name, exact amount match, 1 day apart\n"
            "Example of low match: 5|0.15|Best available option despite "
            "significant differences in vendor and amount\n"
        )

    def _find_best_match_single_call(
        self, receipt: "Receipt", candidates: "List[Transaction]"
    ) -> "Optional[Match]":
        """Pick the best candidate with one AI call that sees all candidates.

        Returns None when there are no candidates, the API fails on every
        attempt, the reply cannot be parsed, or the reply names a candidate
        number outside the list.
        """
        if not candidates:
            return None

        if len(candidates) > self.MAX_BATCH_CANDIDATES:
            # Keep the candidates closest in amount to the receipt.
            receipt_amount = abs(receipt.amount)
            candidates = sorted(
                candidates, key=lambda t: abs(receipt_amount - abs(t.amount))
            )[: self.MAX_BATCH_CANDIDATES]
            logger.info(
                "Limited candidates to top %d by amount similarity",
                self.MAX_BATCH_CANDIDATES,
            )

        prompt = self._build_batch_prompt(receipt, candidates)
        try:
            # Longer timeout: the batch prompt is larger than the legacy one.
            reply = self._complete_with_retries(prompt, 45, receipt.id)
        except Exception:
            return None  # already logged by _complete_with_retries

        candidate_index, score, reason = self._parse_single_match_response(reply)
        if candidate_index == -1:  # parsing error sentinel
            logger.warning(
                "Failed to parse AI response for receipt: %s", receipt.vendor
            )
            return None

        if not 0 <= candidate_index < len(candidates):
            # Report the 1-based number the model actually returned.
            logger.warning(
                "AI returned invalid candidate number: %d", candidate_index + 1
            )
            return None

        best_transaction = candidates[candidate_index]
        logger.info(
            "AI selected candidate %d: %s (score: %.3f)",
            candidate_index + 1,
            best_transaction.vendor,
            score,
        )
        return Match(receipt, best_transaction, score, reason)

    def _extract_score(self, text: str) -> float:
        """Return the first number in ``text`` as a score in [0, 1].

        Values above 1 are assumed to be percentages and divided by 100; the
        result is clamped to [0, 1]. Text without a number yields 0.0.
        """
        found = self._NUMBER_RE.search(text)
        if not found:
            return 0.0
        value = float(found.group())
        if value > 1:
            value /= 100  # assume percentage
        return max(0.0, min(1.0, value))

    def _parse_single_match_response(self, result: str) -> Tuple[int, float, str]:
        """Parse a ``CANDIDATE_NUMBER|CONFIDENCE_SCORE|REASON`` reply.

        Returns:
            ``(zero_based_candidate_index, score, reason)`` on success, or
            ``(-1, 0.0, message)`` when the reply is NONE or unparseable.
        """
        result = (result or "").strip()
        logger.debug("Parsing single match response: %s", result)

        if result.upper().startswith("NONE"):
            # Should not happen given the prompt; treated as a parsing error.
            logger.warning(
                "AI returned NONE despite being instructed to always return best match"
            )
            return -1, 0.0, "AI returned NONE unexpectedly"

        # Try individual lines with the full three-field shape first, so
        # digits in a preamble or trailing commentary are not mistaken for
        # the answer; fall back to the whole reply.
        attempts = [line for line in result.splitlines() if line.count("|") >= 2]
        attempts.append(result)

        for text in attempts:
            parts = text.split("|")
            if len(parts) < 3:
                continue
            number = re.search(r"\d+", parts[0])
            if not number or int(number.group()) < 1:
                continue  # candidate numbers are 1-based in the prompt
            candidate_index = int(number.group()) - 1
            score = self._extract_score(parts[1])
            reason = "|".join(parts[2:]).strip()
            logger.debug(
                "Parsed: candidate=%d, score=%s, reason=%s",
                candidate_index,
                score,
                reason,
            )
            return candidate_index, score, reason

        logger.warning("Could not parse single match response: %s", result)
        return -1, 0.0, f"Parse error: {result[:50]}..."

    # ------------------------------------------------------------------
    # Legacy strategy
    # ------------------------------------------------------------------
    def _find_best_match_legacy(
        self, receipt: "Receipt", transactions: "List[Transaction]"
    ) -> "Optional[Match]":
        """Score each candidate with its own AI call and keep the highest.

        A candidate must score strictly above 0 to be returned, so failed or
        unparseable calls (score 0.0) never produce a match.
        """
        candidates = self._filter_candidates(receipt, transactions)
        if not candidates:
            return None

        best_match = None
        highest_score = 0

        for transaction in candidates:
            score, reason = self._calculate_match_score(receipt, transaction)
            logger.debug(
                "Score %.3f for transaction %s: %s",
                score,
                transaction.vendor,
                reason,
            )
            if score > highest_score:
                highest_score = score
                best_match = Match(receipt, transaction, score, reason)

        return best_match

    def _build_pair_prompt(self, receipt, transaction) -> str:
        """Build the legacy prompt comparing one receipt with one transaction."""
        amount_abs, date_diff, amount_diff, percent_diff = self._amount_date_gap(
            receipt, transaction
        )
        return (
            "\n"
            "Compare this receipt with this transaction and provide a "
            "confidence score (0-1) and brief reason, the reason must be a "
            "single sentence without any special formatting.\n"
            "\n"
            f"Receipt: {receipt.vendor}, ${receipt.amount}, "
            f"{receipt.receipt_date.strftime('%Y-%m-%d')}\n"
            f"Receipt Description: {receipt.description}\n"
            f"Receipt Category: {receipt.category}\n"
            f"Transaction: {transaction.vendor}, ${transaction.amount} "
            f"(absolute: ${amount_abs}), "
            f"{transaction.transaction_date.strftime('%Y-%m-%d')}\n"
            f"Transaction Notes: {transaction.notes}\n"
            "\n"
            "Differences:\n"
            f"- Date difference: {date_diff} days\n"
            f"- Amount difference: ${amount_diff} ({percent_diff:.1f}%)\n"
            f'- Vendor comparison: "{receipt.vendor}" vs "{transaction.vendor}"\n'
            f'- Description/Notes comparison: "{receipt.description}" vs '
            f'"{transaction.notes}"\n'
            f"- Category: {receipt.category}\n"
            "\n"
            "Score this potential match based on how likely it is the correct match:\n"
            "\n"
            f"{self._SCORING_RUBRIC}"
            "\n"
            "Consider description and category similarity in your scoring.\n"
            "\n"
            "IMPORTANT: Return ONLY the score and reason separated by a pipe character.\n"
            "Format: [score]|[reason]\n"
            "Example: 0.85|Same vendor, same amount, 2 days apart\n"
        )

    def _calculate_match_score(
        self, receipt: "Receipt", transaction: "Transaction"
    ) -> Tuple[float, str]:
        """Ask the model to score one receipt/transaction pair.

        Returns:
            ``(score, reason)``; ``(0.0, error message)`` when every API
            attempt failed.
        """
        prompt = self._build_pair_prompt(receipt, transaction)
        try:
            reply = self._complete_with_retries(prompt, 30, receipt.id)
        except Exception as exc:
            return 0.0, f"AI error after {self.max_retries} attempts: {exc}"

        score, reason = self._parse_ai_response(reply)
        logger.debug("AI Response: %s", reply)
        logger.debug("Parsed: score=%s, reason=%s", score, reason)
        return score, reason

    def _parse_ai_response(self, result: str) -> Tuple[float, str]:
        """Parse a ``score|reason`` reply, tolerating loosely formatted output.

        Resolution order:
        1. pipe-separated reply: the first part whose leading number lies in
           [0, 1] is the score; the remaining parts form the reason;
        2. otherwise the first number anywhere in the reply within [0, 1];
        3. otherwise the first number, treated as a percentage and clamped;
        4. otherwise ``(0.0, "Unparseable response: ...")``.
        """
        result = (result or "").strip()
        logger.debug("Parsing AI response: %s", result)

        if "|" in result:
            parts = [part.strip() for part in result.split("|")]
            logger.debug("Split response into %d parts: %s", len(parts), parts)
            for index, part in enumerate(parts):
                found = self._NUMBER_RE.search(part)
                if not found:
                    continue
                score = float(found.group())
                if 0 <= score <= 1:  # valid confidence score
                    others = [p for j, p in enumerate(parts) if j != index and p]
                    reason = " | ".join(others) if others else "Score extracted"
                    logger.debug(
                        "Found score %s in part %d, reason: %s", score, index, reason
                    )
                    return score, reason

        numbers = [float(text) for text in self._NUMBER_RE.findall(result)]
        for score in numbers:
            if 0 <= score <= 1:
                logger.debug("Extracted score %s from response", score)
                return score, f"Extracted from response: {result[:50]}..."

        if numbers:
            # Every number is above 1 here: assume a percentage.
            score = max(0.0, min(1.0, numbers[0] / 100))
            logger.debug("Normalized score %s from response", score)
            return score, f"Normalized from response: {result[:50]}..."

        logger.warning("Could not parse AI response: %s", result)
        return 0.0, f"Unparseable response: {result[:50]}..."