Add test script for JSON extraction functionality

This commit introduces a new test script, `test_json_extraction.py`, which verifies the correctness of the JSON extraction logic. The script includes a function to extract the first valid JSON object from raw input and a series of test cases covering various scenarios, such as clean JSON, JSON with extra text, nested JSON, and escaped quotes. The tests ensure that the extraction function behaves as expected and handles edge cases appropriately.
This commit is contained in:
2025-10-09 19:56:22 +00:00
parent 2e020437a8
commit 3559cbe19d
5 changed files with 769 additions and 891 deletions
+26 -1
View File
@@ -27,7 +27,32 @@ Base = declarative_base()
def create_db_tables(): def create_db_tables():
Base.metadata.create_all(bind=engine) """Create database tables safely with error handling"""
import logging
logger = logging.getLogger(__name__)
try:
# Check if tables already exist to avoid unnecessary DDL operations
from sqlalchemy import inspect
inspector = inspect(engine)
existing_tables = inspector.get_table_names()
if existing_tables:
logger.info(f"Database tables already exist: {existing_tables}")
return
# Create tables with timeout protection
logger.info("Creating database tables...")
Base.metadata.create_all(bind=engine, checkfirst=True)
logger.info("Database tables created successfully")
except KeyboardInterrupt:
logger.warning("Database creation interrupted by user")
raise
except Exception as e:
logger.error(f"Error creating database tables: {e}")
# Don't crash the app - tables might already exist
pass
def clear_all_data(): def clear_all_data():
+12 -1
View File
@@ -30,7 +30,8 @@ from services.document_processor import DocumentProcessor
from services.matching_engine import MatchingEngine from services.matching_engine import MatchingEngine
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
create_db_tables() # Don't create tables at import time - do it on startup
# create_db_tables()
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
@@ -54,6 +55,15 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
@app.on_event("startup")
async def startup_event():
"""Initialize database on startup"""
logger.info("Starting up application...")
create_db_tables()
logger.info("Application startup complete")
# Initialize DS Engine components # Initialize DS Engine components
matching_engine = MatchingEngine() matching_engine = MatchingEngine()
document_processor = DocumentProcessor() document_processor = DocumentProcessor()
@@ -384,6 +394,7 @@ async def process_document(
- ai_rules: Custom categorization rules to override default logic - ai_rules: Custom categorization rules to override default logic
(e.g., [{"condition": "vendor is Starbucks", "action": "Food"}]) (e.g., [{"condition": "vendor is Starbucks", "action": "Food"}])
""" """
logger.info(f"Request: {request}")
try: try:
# Get file info from database # Get file info from database
db_uploaded_file = get_uploaded_file_from_db(db, file_id) db_uploaded_file = get_uploaded_file_from_db(db, file_id)
+5 -1
View File
@@ -152,9 +152,11 @@ SCORING CRITERIA:
- Minimal similarity: 0.1-0.19 - Minimal similarity: 0.1-0.19
- No meaningful similarity: 0.0-0.09 - No meaningful similarity: 0.0-0.09
The most important factor to consider is the Amount for both the transaction and the receipt. The closer the amounts, the higher the score. If the amounts are different or not close return a low score (0-0.1) based on other factors.
Consider vendor name similarity, amount accuracy, date proximity, and description/notes relevance. Consider vendor name similarity, amount accuracy, date proximity, and description/notes relevance.
IMPORTANT: You MUST return the candidate with the highest match score, even if it's very low. Never return NONE. IMPORTANT:
You MUST return the candidate with the highest match score, even if it's very low. Never return NONE.
Return ONLY the best match in this exact format: Return ONLY the best match in this exact format:
CANDIDATE_NUMBER|CONFIDENCE_SCORE|REASON CANDIDATE_NUMBER|CONFIDENCE_SCORE|REASON
@@ -338,6 +340,8 @@ Example of low match: 5|0.15|Best available option despite significant differenc
Consider description and category similarity in your scoring. Consider description and category similarity in your scoring.
The most important factor to consider is the Amount for both the transaction and the receipt. The closer the amounts, the higher the score. If the amounts are different or not close return a low score (0-0.1) based on other factors.
IMPORTANT: Return ONLY the score and reason separated by a pipe character. IMPORTANT: Return ONLY the score and reason separated by a pipe character.
Format: [score]|[reason] Format: [score]|[reason]
Example: 0.85|Same vendor, same amount, 2 days apart Example: 0.85|Same vendor, same amount, 2 days apart
+65 -4
View File
@@ -1,4 +1,5 @@
import base64 import base64
import json
import logging import logging
import os import os
from datetime import datetime from datetime import datetime
@@ -17,6 +18,59 @@ class DocumentProcessor:
self.client = groq.Groq(api_key=settings.GROQ_API_KEY) self.client = groq.Groq(api_key=settings.GROQ_API_KEY)
self.model = "meta-llama/llama-4-scout-17b-16e-instruct" # Vision model self.model = "meta-llama/llama-4-scout-17b-16e-instruct" # Vision model
def _extract_first_json(self, raw: str) -> dict:
"""Extract the first valid JSON object from raw LLM output.
Handles cases where LLM returns extra text after/before the JSON.
"""
try:
# First try direct parsing (fastest path)
return json.loads(raw)
except json.JSONDecodeError:
pass
# Find the first '{' and match closing '}'
start = raw.find("{")
if start == -1:
raise ValueError("No JSON object found in LLM output")
depth = 0
end = -1
in_string = False
escape_next = False
for i in range(start, len(raw)):
ch = raw[i]
# Handle string escaping
if escape_next:
escape_next = False
continue
if ch == "\\":
escape_next = True
continue
# Track if we're inside a string
if ch == '"':
in_string = not in_string
continue
# Only count braces outside of strings
if not in_string:
if ch == "{":
depth += 1
elif ch == "}":
depth -= 1
if depth == 0:
end = i + 1
break
if end == -1:
raise ValueError("Unbalanced JSON braces in LLM output")
json_str = raw[start:end]
return json.loads(json_str)
async def process_file( async def process_file(
self, self,
file_path: str, file_path: str,
@@ -145,6 +199,8 @@ class DocumentProcessor:
* residual_value: Estimated value at end of life (typically 10% of purchase price for equipment, 20% for vehicles) * residual_value: Estimated value at end of life (typically 10% of purchase price for equipment, 20% for vehicles)
- If is_depreciable is false, set name_of_asset, cca_rate, useful_life, and residual_value to null - If is_depreciable is false, set name_of_asset, cca_rate, useful_life, and residual_value to null
CATEGORY RULES:
- Assign the category based on all the details in the receipt
Return only valid JSON. Return only valid JSON.
""" """
@@ -334,11 +390,16 @@ class DocumentProcessor:
def _parse_extraction_result(self, result_text: str) -> Dict[str, Any]: def _parse_extraction_result(self, result_text: str) -> Dict[str, Any]:
"""Parse Groq response and extract JSON data""" """Parse Groq response and extract JSON data"""
try: try:
# Clean up response and extract JSON
import json
import re import re
# Find JSON in response - try multiple patterns # Try robust JSON extraction first (handles extra text)
try:
data = self._extract_first_json(result_text)
return data
except (json.JSONDecodeError, ValueError) as e:
logger.warning(f"Robust JSON extraction failed: {e}. Trying fallback methods...")
# Fallback: Find JSON in response - try multiple patterns
json_match = re.search(r"\{.*\}", result_text, re.DOTALL) json_match = re.search(r"\{.*\}", result_text, re.DOTALL)
if json_match: if json_match:
json_str = json_match.group() json_str = json_match.group()
@@ -355,7 +416,7 @@ class DocumentProcessor:
data = json.loads(json_str) data = json.loads(json_str)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
# Try to fix common JSON issues # Try to fix common JSON issues
logger.warning(f"Initial JSON parsing failed: {e}") logger.warning(f"Fallback JSON parsing also failed: {e}")
# Try to extract individual fields using regex # Try to extract individual fields using regex
vendor_match = re.search(r'"vendor"\s*:\s*"([^"]*)"', json_str) vendor_match = re.search(r'"vendor"\s*:\s*"([^"]*)"', json_str)
+661 -884
View File
File diff suppressed because it is too large Load Diff