3559cbe19d
This commit introduces a new test script, `test_json_extraction.py`, which verifies the correctness of the JSON extraction logic. The script includes a function to extract the first valid JSON object from raw input and a series of test cases covering various scenarios, such as clean JSON, JSON with extra text, nested JSON, and escaped quotes. The tests ensure that the extraction function behaves as expected and handles edge cases appropriately.
474 lines
19 KiB
Python
474 lines
19 KiB
Python
import logging
|
|
import time
|
|
from typing import List, Tuple
|
|
|
|
import groq
|
|
|
|
from config import settings
|
|
from schemas import Match, Receipt, Transaction
|
|
|
|
# Module-wide logger; the root logger is configured to emit INFO and above.
# NOTE(review): basicConfig runs at import time and affects the whole process
# (it is a no-op if the root logger already has handlers); the application
# entry point is usually a better home for it.
logger = logging.getLogger(__name__)
logging.basicConfig(level="INFO")
|
|
|
|
|
|
class AIMatcher:
    """Match receipts to bank transactions by asking a Groq-hosted LLM to score them.

    Two strategies are available:

    * batch (default): one API call per receipt that asks the model to pick
      the best of up to ``max_candidates`` pre-filtered transactions;
    * legacy: one API call per (receipt, candidate transaction) pair.

    Every API call goes through ``_call_groq_api_with_timeout``, which applies
    client-side rate limiting (``rate_limit_delay``) and a per-call timeout.
    """

    def __init__(self, use_batch_matching=True):
        """Create the Groq client and the matching configuration.

        Args:
            use_batch_matching: True selects the single-call (batch) strategy,
                False the legacy per-candidate strategy.
        """
        self.client = groq.Groq(api_key=settings.GROQ_API_KEY)
        self.model = "llama-3.1-8b-instant"
        self.max_retries = 3
        self.retry_delay = 2  # seconds; base of the exponential retry backoff
        self.rate_limit_delay = 1.0  # minimum seconds between API calls
        self.last_api_call = 0  # time.monotonic() of the last rate-limited call
        # Toggle between the new (batch) and legacy matching methods.
        self.use_batch_matching = use_batch_matching

    def match_receipts_to_transactions(
        self, receipts: List[Receipt], transactions: List[Transaction]
    ) -> List[Match]:
        """Match receipts to transactions using AI.

        Each receipt gets at most one match (its best-scoring transaction).
        Nothing prevents the same transaction from being chosen for several
        receipts. Receipts without a usable match are logged and skipped.

        Returns:
            The matches sorted by ``confidence_score``, highest first.
        """
        logger.info(
            f"Starting AI matching for {len(receipts)} receipts against {len(transactions)} transactions"
        )
        matches = []

        for i, receipt in enumerate(receipts):
            logger.info(
                f"Processing receipt {i + 1}/{len(receipts)}: {receipt.vendor} - ${receipt.amount}"
            )

            # Rate limiting is applied per API call (in _call_groq_api_with_timeout),
            # so the legacy per-candidate mode is throttled as well and receipts
            # that need no API call are not delayed.
            best_match = self._find_best_match(receipt, transactions)
            if best_match:
                matches.append(best_match)
                logger.info(
                    f"Found match: {best_match.confidence_score:.3f} - {best_match.match_reason}"
                )
            else:
                logger.warning(
                    f"No match found for receipt: {receipt.vendor} - ${receipt.amount}"
                )

        # Sort by confidence score (highest first)
        matches.sort(key=lambda m: m.confidence_score, reverse=True)
        logger.info(f"AI matching completed. Found {len(matches)} matches")
        return matches

    def _rate_limit(self):
        """Sleep as needed so API calls are at least ``rate_limit_delay`` seconds apart."""
        # Monotonic clock: wall-clock adjustments must not produce negative
        # intervals (and hence oversized sleeps).
        elapsed = time.monotonic() - self.last_api_call
        if elapsed < self.rate_limit_delay:
            sleep_time = self.rate_limit_delay - elapsed
            logger.debug(f"Rate limiting: sleeping for {sleep_time:.2f} seconds")
            time.sleep(sleep_time)
        self.last_api_call = time.monotonic()

    def _find_best_match(
        self, receipt: Receipt, transactions: List[Transaction]
    ) -> Match:
        """Find the best match for ``receipt`` among ``transactions``.

        Returns:
            A ``Match``, or ``None`` when no transaction passes the amount
            filter or the selected strategy finds no match.
        """
        candidates = self._filter_candidates(receipt, transactions)
        if not candidates:
            logger.warning(
                f"No candidates found for receipt: {receipt.vendor} - ${receipt.amount}"
            )
            return None

        logger.info(f"Found {len(candidates)} candidates for receipt: {receipt.vendor}")

        if self.use_batch_matching:
            # New efficient method: single AI call for all candidates
            return self._find_best_match_single_call(receipt, candidates)
        # Legacy method: individual AI calls (fallback)
        return self._find_best_match_legacy(receipt, candidates)

    @staticmethod
    def _pair_differences(receipt: Receipt, transaction: Transaction):
        """Return ``(abs_amount, date_diff_days, amount_diff, amount_percent_diff)``.

        The percentage is relative to the receipt amount and is 0 when the
        receipt amount is not positive.
        """
        transaction_amount_abs = abs(transaction.amount)
        date_diff = abs((receipt.receipt_date - transaction.transaction_date).days)
        amount_diff = abs(receipt.amount - transaction_amount_abs)
        amount_percent_diff = (
            (amount_diff / receipt.amount) * 100 if receipt.amount > 0 else 0
        )
        return transaction_amount_abs, date_diff, amount_diff, amount_percent_diff

    def _build_single_call_prompt(
        self, receipt: Receipt, candidates: List[Transaction]
    ) -> str:
        """Build the prompt asking the model to pick the best of ``candidates``.

        Candidates are numbered from 1 in list order; the model is asked to
        answer ``CANDIDATE_NUMBER|CONFIDENCE_SCORE|REASON``.
        """
        candidate_blocks = []
        for i, transaction in enumerate(candidates):
            (
                transaction_amount_abs,
                date_diff,
                amount_diff,
                amount_percent_diff,
            ) = self._pair_differences(receipt, transaction)
            candidate_blocks.append(f"""
Candidate {i + 1}:
- Vendor: {transaction.vendor}
- Amount: ${transaction.amount} (absolute: ${transaction_amount_abs})
- Date: {transaction.transaction_date.strftime("%Y-%m-%d")} ({date_diff} days difference)
- Notes: {transaction.notes}
- Amount difference: ${amount_diff} ({amount_percent_diff:.1f}%)
""")
        candidates_text = "".join(candidate_blocks)

        return f"""
You are an expert at matching receipts to bank transactions. Analyze the receipt below against ALL the candidate transactions and return the BEST match.

RECEIPT TO MATCH:
- Vendor: {receipt.vendor}
- Amount: ${receipt.amount}
- Date: {receipt.receipt_date.strftime("%Y-%m-%d")}
- Description: {receipt.description}
- Category: {receipt.category}

CANDIDATE TRANSACTIONS:
{candidates_text}

SCORING CRITERIA:
- Perfect matches (same vendor, amount, date): 0.95-1.0
- High confidence (minor differences): 0.8-0.94
- Medium confidence (moderate differences): 0.6-0.79
- Low confidence (significant differences): 0.4-0.59
- Very low confidence (major differences): 0.2-0.39
- Minimal similarity: 0.1-0.19
- No meaningful similarity: 0.0-0.09

The most important factor to consider is the Amount for both the transaction and the receipt. The closer the amounts, the higher the score. If the amounts are different or not close return a low score (0-0.1) based on other factors.
Consider vendor name similarity, amount accuracy, date proximity, and description/notes relevance.

IMPORTANT:
You MUST return the candidate with the highest match score, even if it's very low. Never return NONE.
Return ONLY the best match in this exact format:
CANDIDATE_NUMBER|CONFIDENCE_SCORE|REASON

Example: 3|0.87|Same vendor name, exact amount match, 1 day apart
Example of low match: 5|0.15|Best available option despite significant differences in vendor and amount
"""

    def _find_best_match_single_call(
        self,
        receipt: Receipt,
        candidates: List[Transaction],
        max_candidates: int = 10,
    ) -> Match:
        """Find the best match using a single AI call to evaluate all candidates.

        Args:
            receipt: The receipt to match.
            candidates: Pre-filtered transactions to choose from.
            max_candidates: Upper bound on how many candidates are sent to the
                model (to limit prompt size); when exceeded, the candidates
                whose absolute amount is closest to the receipt amount are kept.

        Returns:
            A ``Match`` for the transaction the model selected, or ``None`` when
            there are no candidates or every attempt fails (API error, timeout,
            unparseable answer, or out-of-range candidate number).
        """
        if not candidates:
            return None

        if len(candidates) > max_candidates:
            # Sort by amount similarity and take top candidates
            candidates = sorted(
                candidates, key=lambda t: abs(receipt.amount - abs(t.amount))
            )[:max_candidates]
            logger.info(
                f"Limited candidates to top {max_candidates} by amount similarity"
            )

        prompt = self._build_single_call_prompt(receipt, candidates)

        for attempt in range(self.max_retries):
            try:
                # Longer timeout than the per-pair path: the prompt is larger.
                result = self._call_groq_api_with_timeout(prompt, timeout=45)
                candidate_num, score, reason = self._parse_single_match_response(
                    result
                )

                # Malformed answers are retried like API errors, since a new
                # sample may come back well-formed.
                if candidate_num == -1:  # Parsing error occurred
                    logger.warning(
                        f"Failed to parse AI response for receipt: {receipt.vendor}"
                    )
                    raise ValueError(f"Unparseable AI response: {reason}")
                if not 0 <= candidate_num < len(candidates):
                    logger.warning(
                        f"AI returned invalid candidate number: {candidate_num}"
                    )
                    raise ValueError(
                        f"Candidate number {candidate_num + 1} outside 1-{len(candidates)}"
                    )

                best_transaction = candidates[candidate_num]
                logger.info(
                    f"AI selected candidate {candidate_num + 1}: {best_transaction.vendor} (score: {score:.3f})"
                )
                return Match(receipt, best_transaction, score, reason)

            except Exception as e:
                logger.warning(
                    f"Attempt {attempt + 1} failed for receipt {receipt.id}: {str(e)}"
                )
                if attempt < self.max_retries - 1:
                    sleep_time = self.retry_delay * (2**attempt)
                    logger.info(f"Waiting {sleep_time} seconds before retry...")
                    time.sleep(sleep_time)

        logger.error(f"All attempts failed for receipt {receipt.id}")
        return None

    def _parse_single_match_response(self, result: str) -> Tuple[int, float, str]:
        """Parse a ``CANDIDATE_NUMBER|CONFIDENCE_SCORE|REASON`` answer.

        Returns:
            ``(index, score, reason)`` with a 0-based candidate ``index`` and a
            score in [0, 1] (values above 1 are read as percentages, then
            clamped), or ``(-1, 0.0, message)`` when the answer cannot be parsed.
        """
        import re

        result = result.strip()
        logger.debug(f"Parsing single match response: {result}")

        if result.upper().startswith("NONE"):
            # This should not happen with the current prompt; treat as a parse error.
            logger.warning(
                "AI returned NONE despite being instructed to always return best match"
            )
            return -1, 0.0, "AI returned NONE unexpectedly"

        # Skip any preamble the model adds (it may contain digits that would be
        # mistaken for the candidate number) by starting at the first line that
        # carries both separators. A reason continuing onto later lines is kept.
        lines = result.splitlines()
        start = next(
            (idx for idx, line in enumerate(lines) if line.count("|") >= 2), None
        )
        answer = "\n".join(lines[start:]) if start is not None else result

        parts = answer.split("|")
        if len(parts) >= 3:
            candidate_match = re.search(r"\d+", parts[0])
            if candidate_match:
                candidate_num = int(candidate_match.group()) - 1  # 0-based index
                reason = "|".join(parts[2:]).strip()

                score_match = re.search(r"\d+(?:\.\d+)?|\.\d+", parts[1])
                score = float(score_match.group()) if score_match else 0.0
                if score > 1.0:
                    score /= 100.0  # Assume percentage (as _parse_ai_response does)
                # Ensure score is in valid range
                score = max(0.0, min(1.0, score))

                logger.debug(
                    f"Parsed: candidate={candidate_num}, score={score}, reason={reason}"
                )
                return candidate_num, score, reason
            logger.warning(
                "Error parsing single match response: No candidate number found"
            )

        # Fallback
        logger.warning(f"Could not parse single match response: {result}")
        return -1, 0.0, f"Parse error: {result[:50]}..."

    def _filter_candidates(
        self, receipt: Receipt, transactions: List[Transaction]
    ) -> List[Transaction]:
        """Filter transactions to create a reasonable candidate list.

        A transaction is kept when its absolute amount differs from the receipt
        amount by at most 200% of the receipt amount. For a non-negative receipt
        amount this keeps every transaction whose absolute amount is at most
        three times the receipt amount.
        NOTE(review): a negative receipt amount gives a negative threshold and
        therefore no candidates.
        """
        amount_threshold = receipt.amount * 2.0  # 200% threshold - very inclusive
        # Transaction amounts are compared by absolute value (debits may be negative).
        candidates = [
            transaction
            for transaction in transactions
            if abs(receipt.amount - abs(transaction.amount)) <= amount_threshold
        ]
        logger.debug(
            f"Filtered {len(transactions)} transactions to {len(candidates)} candidates"
        )
        return candidates

    def _find_best_match_legacy(
        self, receipt: Receipt, transactions: List[Transaction]
    ) -> Match:
        """Legacy method: Find the best match using individual API calls (kept as fallback).

        Scores each filtered candidate with its own API call and keeps the first
        one with the strictly highest score.

        Returns:
            A ``Match``, or ``None`` when there are no candidates or no candidate
            scores above 0 (which includes candidates whose scoring failed).
        """
        candidates = self._filter_candidates(receipt, transactions)
        if not candidates:
            return None

        best_match = None
        highest_score = 0

        for transaction in candidates:
            score, reason = self._calculate_match_score(receipt, transaction)
            logger.debug(
                f"Score {score:.3f} for transaction {transaction.vendor}: {reason}"
            )
            if score > highest_score:
                highest_score = score
                best_match = Match(receipt, transaction, score, reason)

        return best_match

    def _build_pair_prompt(self, receipt: Receipt, transaction: Transaction) -> str:
        """Build the prompt asking the model to score one receipt/transaction pair."""
        (
            transaction_amount_abs,
            date_diff,
            amount_diff,
            amount_percent_diff,
        ) = self._pair_differences(receipt, transaction)

        return f"""
Compare this receipt with this transaction and provide a confidence score (0-1) and brief reason, the reason must be a single sentence without any special formatting.

Receipt: {receipt.vendor}, ${receipt.amount}, {receipt.receipt_date.strftime("%Y-%m-%d")}
Receipt Description: {receipt.description}
Receipt Category: {receipt.category}
Transaction: {transaction.vendor}, ${transaction.amount} (absolute: ${transaction_amount_abs}), {transaction.transaction_date.strftime("%Y-%m-%d")}
Transaction Notes: {transaction.notes}

Differences:
- Date difference: {date_diff} days
- Amount difference: ${amount_diff} ({amount_percent_diff:.1f}%)
- Vendor comparison: "{receipt.vendor}" vs "{transaction.vendor}"
- Description/Notes comparison: "{receipt.description}" vs "{transaction.notes}"
- Category: {receipt.category}

Score this potential match based on how likely it is the correct match:

- Perfect matches (same vendor, amount, date): 0.95-1.0
- High confidence (minor differences): 0.8-0.94
- Medium confidence (moderate differences): 0.6-0.79
- Low confidence (significant differences): 0.4-0.59
- Very low confidence (major differences): 0.2-0.39
- Minimal similarity: 0.1-0.19
- No meaningful similarity: 0.0-0.09

Consider description and category similarity in your scoring.

The most important factor to consider is the Amount for both the transaction and the receipt. The closer the amounts, the higher the score. If the amounts are different or not close return a low score (0-0.1) based on other factors.

IMPORTANT: Return ONLY the score and reason separated by a pipe character.
Format: [score]|[reason]
Example: 0.85|Same vendor, same amount, 2 days apart
"""

    def _calculate_match_score(
        self, receipt: Receipt, transaction: Transaction
    ) -> Tuple[float, str]:
        """Calculate match score using AI (one API call per pair, with retries).

        Returns:
            ``(score, reason)``. When every attempt fails, the score is ``0.0``
            and the reason carries the last error.
        """
        prompt = self._build_pair_prompt(receipt, transaction)

        last_error = None
        for attempt in range(self.max_retries):
            try:
                result = self._call_groq_api_with_timeout(prompt, timeout=30)

                # Parse the result - handle multiple formats
                score, reason = self._parse_ai_response(result)

                logger.debug(f"AI Response: {result}")
                logger.debug(f"Parsed: score={score}, reason={reason}")
                return score, reason

            except Exception as e:
                last_error = e
                logger.warning(
                    f"Attempt {attempt + 1} failed for receipt {receipt.id}: {str(e)}"
                )
                if attempt < self.max_retries - 1:
                    # Exponential backoff for rate limiting
                    sleep_time = self.retry_delay * (2**attempt)
                    logger.info(f"Waiting {sleep_time} seconds before retry...")
                    time.sleep(sleep_time)

        # Also reached when max_retries <= 0, which previously returned None and
        # broke the caller's tuple unpacking.
        logger.error(f"All attempts failed for receipt {receipt.id}")
        return 0.0, f"AI error after {self.max_retries} attempts: {str(last_error)}"

    def _parse_ai_response(self, result: str) -> Tuple[float, str]:
        """Parse a ``score|reason`` answer, falling back to looser formats.

        Strategy, in order:

        1. If the text contains ``|``, the first part whose digits/dots form a
           number in [0, 1] is the score; the other non-empty parts, joined
           with `` | ``, are the reason.
        2. Otherwise, the first number in the text that lies in [0, 1].
        3. Otherwise, the first number, divided by 100 when above 1, clamped.
        4. Otherwise ``0.0``.
        """
        import re

        result = result.strip()
        logger.debug(f"Parsing AI response: {result}")

        # Try to find score in various formats
        if "|" in result:
            parts = result.split("|")
            logger.debug(f"Split response into {len(parts)} parts: {parts}")

            for i, part in enumerate(parts):
                # Remove any non-numeric characters except decimal point
                score_str_clean = "".join(
                    c for c in part.strip() if c.isdigit() or c == "."
                )
                if not score_str_clean:
                    continue
                try:
                    score = float(score_str_clean)
                except ValueError:  # e.g. several dots
                    continue
                if 0 <= score <= 1:  # Valid confidence score
                    reason_parts = [
                        p.strip() for j, p in enumerate(parts) if j != i and p.strip()
                    ]
                    reason = (
                        " | ".join(reason_parts) if reason_parts else "Score extracted"
                    )
                    logger.debug(f"Found score {score} in part {i}, reason: {reason}")
                    return score, reason

        # Every match of this pattern ("5", "5.", "5.25") is a valid float literal.
        numbers = re.findall(r"\d+\.?\d*", result)

        # Try to extract just a number from the response
        for num_str in numbers:
            score = float(num_str)
            if 0 <= score <= 1:  # Valid confidence score
                logger.debug(f"Extracted score {score} from response")
                return score, f"Extracted from response: {result[:50]}..."

        # Fallback - take the first number and normalize it
        if numbers:
            score = float(numbers[0])
            if score > 1:
                score = score / 100  # Assume percentage
            score = max(0.0, min(1.0, score))  # Clamp to 0-1
            logger.debug(f"Normalized score {score} from response")
            return score, f"Normalized from response: {result[:50]}..."

        # Final fallback
        logger.warning(f"Could not parse AI response: {result}")
        return 0.0, f"Unparseable response: {result[:50]}..."

    def _call_groq_api_with_timeout(self, prompt: str, timeout: int = 15) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        The call is rate-limited (see ``_rate_limit``) and run on a worker
        thread so the caller stops waiting after ``timeout`` seconds.

        Returns:
            The stripped completion text (``""`` if the API returned no content).

        Raises:
            TimeoutError: The call did not finish within ``timeout`` seconds.
                The request itself keeps running in the background.
            Exception: Any error raised by the Groq client is propagated.
        """
        import concurrent.futures

        self._rate_limit()

        def api_call():
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=200,
                temperature=0.1,
            )
            # content can be None; normalise so callers can always .strip()/parse.
            return (response.choices[0].message.content or "").strip()

        # Deliberately not a `with` block: ThreadPoolExecutor.__exit__ calls
        # shutdown(wait=True), which would block until the request completes
        # and make the timeout ineffective.
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            future = executor.submit(api_call)
            return future.result(timeout=timeout)
        except concurrent.futures.TimeoutError as err:
            raise TimeoutError(f"API call timed out after {timeout} seconds") from err
        finally:
            # Return immediately; a timed-out request finishes in the background.
            executor.shutdown(wait=False)
|