Add test script for JSON extraction functionality
This commit introduces a new test script, `test_json_extraction.py`, which verifies the correctness of the JSON extraction logic. The script includes a function to extract the first valid JSON object from raw input and a series of test cases covering various scenarios, such as clean JSON, JSON with extra text, nested JSON, and escaped quotes. The tests ensure that the extraction function behaves as expected and handles edge cases appropriately.
This commit is contained in:
+26
-1
@@ -27,7 +27,32 @@ Base = declarative_base()
|
||||
|
||||
|
||||
def create_db_tables():
|
||||
Base.metadata.create_all(bind=engine)
|
||||
"""Create database tables safely with error handling"""
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
# Check if tables already exist to avoid unnecessary DDL operations
|
||||
from sqlalchemy import inspect
|
||||
inspector = inspect(engine)
|
||||
existing_tables = inspector.get_table_names()
|
||||
|
||||
if existing_tables:
|
||||
logger.info(f"Database tables already exist: {existing_tables}")
|
||||
return
|
||||
|
||||
# Create tables with timeout protection
|
||||
logger.info("Creating database tables...")
|
||||
Base.metadata.create_all(bind=engine, checkfirst=True)
|
||||
logger.info("Database tables created successfully")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.warning("Database creation interrupted by user")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating database tables: {e}")
|
||||
# Don't crash the app - tables might already exist
|
||||
pass
|
||||
|
||||
|
||||
def clear_all_data():
|
||||
|
||||
+12
-1
@@ -30,7 +30,8 @@ from services.document_processor import DocumentProcessor
|
||||
from services.matching_engine import MatchingEngine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
create_db_tables()
|
||||
# Don't create tables at import time - do it on startup
|
||||
# create_db_tables()
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@@ -54,6 +55,15 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Initialize database on startup"""
|
||||
logger.info("Starting up application...")
|
||||
create_db_tables()
|
||||
logger.info("Application startup complete")
|
||||
|
||||
|
||||
# Initialize DS Engine components
|
||||
matching_engine = MatchingEngine()
|
||||
document_processor = DocumentProcessor()
|
||||
@@ -384,6 +394,7 @@ async def process_document(
|
||||
- ai_rules: Custom categorization rules to override default logic
|
||||
(e.g., [{"condition": "vendor is Starbucks", "action": "Food"}])
|
||||
"""
|
||||
logger.info(f"Request: {request}")
|
||||
try:
|
||||
# Get file info from database
|
||||
db_uploaded_file = get_uploaded_file_from_db(db, file_id)
|
||||
|
||||
@@ -152,9 +152,11 @@ SCORING CRITERIA:
|
||||
- Minimal similarity: 0.1-0.19
|
||||
- No meaningful similarity: 0.0-0.09
|
||||
|
||||
The most important factor to consider is the Amount for both the transaction and the receipt. The closer the amounts, the higher the score. If the amounts are different or not close return a low score (0-0.1) based on other factors.
|
||||
Consider vendor name similarity, amount accuracy, date proximity, and description/notes relevance.
|
||||
|
||||
IMPORTANT: You MUST return the candidate with the highest match score, even if it's very low. Never return NONE.
|
||||
IMPORTANT:
|
||||
You MUST return the candidate with the highest match score, even if it's very low. Never return NONE.
|
||||
Return ONLY the best match in this exact format:
|
||||
CANDIDATE_NUMBER|CONFIDENCE_SCORE|REASON
|
||||
|
||||
@@ -338,6 +340,8 @@ Example of low match: 5|0.15|Best available option despite significant differenc
|
||||
|
||||
Consider description and category similarity in your scoring.
|
||||
|
||||
The most important factor to consider is the Amount for both the transaction and the receipt. The closer the amounts, the higher the score. If the amounts are different or not close return a low score (0-0.1) based on other factors.
|
||||
|
||||
IMPORTANT: Return ONLY the score and reason separated by a pipe character.
|
||||
Format: [score]|[reason]
|
||||
Example: 0.85|Same vendor, same amount, 2 days apart
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
@@ -17,6 +18,59 @@ class DocumentProcessor:
|
||||
self.client = groq.Groq(api_key=settings.GROQ_API_KEY)
|
||||
self.model = "meta-llama/llama-4-scout-17b-16e-instruct" # Vision model
|
||||
|
||||
def _extract_first_json(self, raw: str) -> dict:
|
||||
"""Extract the first valid JSON object from raw LLM output.
|
||||
|
||||
Handles cases where LLM returns extra text after/before the JSON.
|
||||
"""
|
||||
try:
|
||||
# First try direct parsing (fastest path)
|
||||
return json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Find the first '{' and match closing '}'
|
||||
start = raw.find("{")
|
||||
if start == -1:
|
||||
raise ValueError("No JSON object found in LLM output")
|
||||
|
||||
depth = 0
|
||||
end = -1
|
||||
in_string = False
|
||||
escape_next = False
|
||||
|
||||
for i in range(start, len(raw)):
|
||||
ch = raw[i]
|
||||
|
||||
# Handle string escaping
|
||||
if escape_next:
|
||||
escape_next = False
|
||||
continue
|
||||
if ch == "\\":
|
||||
escape_next = True
|
||||
continue
|
||||
|
||||
# Track if we're inside a string
|
||||
if ch == '"':
|
||||
in_string = not in_string
|
||||
continue
|
||||
|
||||
# Only count braces outside of strings
|
||||
if not in_string:
|
||||
if ch == "{":
|
||||
depth += 1
|
||||
elif ch == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
end = i + 1
|
||||
break
|
||||
|
||||
if end == -1:
|
||||
raise ValueError("Unbalanced JSON braces in LLM output")
|
||||
|
||||
json_str = raw[start:end]
|
||||
return json.loads(json_str)
|
||||
|
||||
async def process_file(
|
||||
self,
|
||||
file_path: str,
|
||||
@@ -145,6 +199,8 @@ class DocumentProcessor:
|
||||
* residual_value: Estimated value at end of life (typically 10% of purchase price for equipment, 20% for vehicles)
|
||||
- If is_depreciable is false, set name_of_asset, cca_rate, useful_life, and residual_value to null
|
||||
|
||||
CATEGORY RULES:
|
||||
- Assign the category based on all the details in the receipt
|
||||
Return only valid JSON.
|
||||
"""
|
||||
|
||||
@@ -334,11 +390,16 @@ class DocumentProcessor:
|
||||
def _parse_extraction_result(self, result_text: str) -> Dict[str, Any]:
|
||||
"""Parse Groq response and extract JSON data"""
|
||||
try:
|
||||
# Clean up response and extract JSON
|
||||
import json
|
||||
import re
|
||||
|
||||
# Find JSON in response - try multiple patterns
|
||||
# Try robust JSON extraction first (handles extra text)
|
||||
try:
|
||||
data = self._extract_first_json(result_text)
|
||||
return data
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.warning(f"Robust JSON extraction failed: {e}. Trying fallback methods...")
|
||||
|
||||
# Fallback: Find JSON in response - try multiple patterns
|
||||
json_match = re.search(r"\{.*\}", result_text, re.DOTALL)
|
||||
if json_match:
|
||||
json_str = json_match.group()
|
||||
@@ -355,7 +416,7 @@ class DocumentProcessor:
|
||||
data = json.loads(json_str)
|
||||
except json.JSONDecodeError as e:
|
||||
# Try to fix common JSON issues
|
||||
logger.warning(f"Initial JSON parsing failed: {e}")
|
||||
logger.warning(f"Fallback JSON parsing also failed: {e}")
|
||||
|
||||
# Try to extract individual fields using regex
|
||||
vendor_match = re.search(r'"vendor"\s*:\s*"([^"]*)"', json_str)
|
||||
|
||||
Reference in New Issue
Block a user