3559cbe19d
This commit introduces a new test script, `test_json_extraction.py`, which verifies the correctness of the JSON extraction logic. The script includes a function to extract the first valid JSON object from raw input and a series of test cases covering various scenarios, such as clean JSON, JSON with extra text, nested JSON, and escaped quotes. The tests ensure that the extraction function behaves as expected and handles edge cases appropriately.
123 lines
4.0 KiB
Python
123 lines
4.0 KiB
Python
from typing import Annotated
|
|
|
|
from fastapi import Depends
|
|
from sqlalchemy import Column, DateTime, Float, Integer, String, create_engine
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.orm import Session, sessionmaker
|
|
|
|
# SQLite database file; the "./" path is resolved against the process's
# current working directory.
SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db"

# By default a SQLite connection may only be used by the thread that created
# it. Disabling that check lets pooled connections be used from whichever
# thread handles a request (each request still gets its own session via get_db).
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
)

# Factory for Session objects bound to ``engine``. Sessions neither autocommit
# nor autoflush, so callers commit explicitly.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
|
|
|
|
|
def get_db():
    """Yield a fresh database session and close it once the caller is done.

    Written as a generator so it can serve as a FastAPI dependency
    (``db_dependency`` wraps it): code after ``yield`` runs after the request
    has been handled, and the ``finally`` ensures the session is closed even
    if the handler raised.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
|
|
|
|
|
|
# Declarative base class: the ORM models in this module subclass it, which
# registers their tables on ``Base.metadata``.
Base = declarative_base()

# Reusable FastAPI parameter annotation: a path-operation parameter typed
# ``db: db_dependency`` receives a Session produced (and later closed) by get_db.
db_dependency = Annotated[Session, Depends(get_db)]
|
|
|
|
|
|
def create_db_tables():
    """Create any database tables declared on ``Base`` that do not exist yet.

    The tables registered in ``Base.metadata`` are compared with the tables
    already present in the database bound to ``engine``. ``CREATE TABLE`` is
    issued only when at least one declared table is missing. Previously, the
    presence of *any* table short-circuited the call, so newly added models
    were never created.

    Existing tables are left untouched. Note that ``create_all`` does not
    alter existing tables, so schema changes such as new columns are not
    applied here.

    Error handling is best-effort: ``KeyboardInterrupt`` is logged and
    re-raised. Any other exception is logged with its traceback and
    swallowed, so application startup continues.
    """
    import logging

    logger = logging.getLogger(__name__)

    try:
        from sqlalchemy import inspect

        existing_tables = set(inspect(engine).get_table_names())
        missing_tables = sorted(set(Base.metadata.tables) - existing_tables)

        if not missing_tables:
            logger.info("Database tables already exist: %s", sorted(existing_tables))
            return

        logger.info("Creating missing database tables: %s", missing_tables)
        # checkfirst=True makes create_all skip tables that exist at CREATE
        # time, so tables that already existed (or appeared since the
        # inspection above) do not cause "table already exists" errors.
        Base.metadata.create_all(bind=engine, checkfirst=True)
        logger.info("Database tables created successfully")

    except KeyboardInterrupt:
        logger.warning("Database creation interrupted by user")
        raise
    except Exception as e:
        # Deliberately best-effort: record the full traceback, but do not
        # crash the app because table creation failed.
        logger.exception("Error creating database tables: %s", e)
|
|
|
|
|
|
def clear_all_data():
    """Delete every row from the transactions, receipts and uploaded_files tables.

    Intended for tests. Uses a dedicated session and issues one bulk delete
    per model, in the order transactions, receipts, uploaded files. It commits
    once after all three deletes. The session is always closed. If a delete
    raises, the commit is skipped.
    """
    session = SessionLocal()
    try:
        for model in (DBTransaction, DBReceipt, DBUploadedFile):
            session.query(model).delete()
        session.commit()
    finally:
        session.close()
|
|
|
|
|
|
# Transactions table
class DBTransaction(Base):
    """ORM model mapped to the ``transactions`` table.

    NOTE(review): monetary columns (``amount``, ``tax_amount``) use ``Float``
    (binary floating point), which cannot represent most decimal currency
    values exactly. Consider ``Numeric`` if exact arithmetic matters. Changing
    the type changes the Python type returned, so coordinate with callers.
    """

    __tablename__ = "transactions"

    # Integer primary key (also explicitly indexed).
    id = Column(Integer, primary_key=True, index=True)
    # Identifier for the transaction; indexed but NOT unique, so the database
    # does not prevent duplicates.
    transaction_id = Column(String, index=True)
    amount = Column(Float, nullable=False)
    # DateTime without timezone=True: values are stored without tz info.
    # NOTE(review): confirm whether callers store UTC or local time.
    date = Column(DateTime, nullable=False)
    vendor = Column(String, nullable=False)
    description = Column(String, nullable=True)
    category = Column(String, nullable=True)
    tax_amount = Column(Float, nullable=True)
    # Plain string reference; no ForeignKey is declared.
    # Presumably links to a categorisation result — TODO confirm.
    categorisation_id = Column(String, nullable=True)
    # Owning user's id as a plain string; no ForeignKey is declared.
    user_id = Column(String, nullable=True)
|
|
|
|
|
|
# Uploaded Files table
class DBUploadedFile(Base):
    """ORM model mapped to the ``uploaded_files`` table."""

    __tablename__ = "uploaded_files"

    # Integer primary key (also explicitly indexed).
    id = Column(Integer, primary_key=True, index=True)
    # Identifier of the file; a unique constraint enforces one row per file_id.
    file_id = Column(String, unique=True, index=True)
    # Presumably the name supplied at upload time — TODO confirm.
    filename = Column(String, nullable=False)
    # Where the stored file lives. NOTE(review): confirm absolute vs relative.
    file_path = Column(String, nullable=False)
    file_type = Column(String, nullable=False)
    # DateTime without timezone=True: values are stored without tz info.
    upload_date = Column(DateTime, nullable=False)
    # Processing status. "uploaded" is a Python-side default applied on INSERT
    # when no value is given; it is not a server-side DEFAULT in the DDL.
    status = Column(String, nullable=False, default="uploaded")
|
|
|
|
|
|
# Receipts table
class DBReceipt(Base):
    """ORM model mapped to the ``receipts`` table.

    Holds the values extracted from a receipt plus asset/depreciation fields.

    NOTE(review): monetary columns use ``Float``. Boolean-like fields are
    stored as strings (see ``is_depreciable``), so compare them as strings,
    not truthiness: ``"False"`` is truthy in Python.
    """

    __tablename__ = "receipts"

    # Integer primary key (also explicitly indexed).
    id = Column(Integer, primary_key=True, index=True)
    receipt_id = Column(String, unique=True, index=True)
    # Unique, so at most one receipt row per file_id. No ForeignKey to
    # uploaded_files.file_id is declared, so the link is not enforced by the DB.
    file_id = Column(String, unique=True, index=True)
    amount = Column(Float, nullable=False)
    # DateTime without timezone=True: values are stored without tz info.
    date = Column(DateTime, nullable=False)
    vendor = Column(String, nullable=False)
    description = Column(String, nullable=True)
    category = Column(String, nullable=True)
    tax_amount = Column(Float, nullable=True)
    # Extraction confidence score. NOTE(review): confirm range (e.g. 0–1).
    confidence = Column(Float, nullable=True)
    # Stored as a string. Presumably "True"/"False" like is_depreciable — TODO confirm.
    extraction_success = Column(String, nullable=True)
    error_message = Column(String, nullable=True)
    receipt_currency = Column(String, nullable=True)
    receipt_location = Column(String, nullable=True)
    calculated_tax = Column(Float, nullable=True)
    is_depreciable = Column(String, nullable=True)  # Store as string "True"/"False"
    name_of_asset = Column(String, nullable=True)  # Name/description of the asset
    # Presumably a Capital Cost Allowance rate. NOTE(review): confirm fraction vs percent.
    cca_rate = Column(Float, nullable=True)
    # NOTE(review): units not visible here (years?) — confirm.
    useful_life = Column(Integer, nullable=True)
    residual_value = Column(Float, nullable=True)
|