Add requirements.txt with essential dependencies for the project
This commit is contained in:
@@ -0,0 +1,12 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional
|
||||
|
||||
class Settings(BaseSettings):
    """Application configuration model built on pydantic-settings.

    Three values are optional and fall back to ``None``; ``GROQ_API_KEY``
    has no default, so instantiation fails validation when it is not
    supplied.
    """

    # Optional settings.
    database_url: Optional[str] = None
    secret_key: Optional[str] = None
    api_key: Optional[str] = None

    # Required setting (no default value).
    GROQ_API_KEY: str

    class Config:
        # Dotenv file that pydantic-settings is pointed at.
        env_file = ".env"


# Shared settings instance, created once when this module is imported.
settings = Settings()
|
||||
@@ -0,0 +1,90 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy import Column, DateTime, Float, Integer, String, create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
# SQLite database stored in ``sql_app.db`` relative to the working directory.
SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db"

# ``check_same_thread`` is switched off so the connection may be used from a
# thread other than the one that created it.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
)

# Session factory bound to the engine; commits and flushes are explicit.
SessionLocal = sessionmaker(
    autocommit=False,
    autoflush=False,
    bind=engine,
)
|
||||
|
||||
|
||||
def get_db():
    """Yield a fresh database session and always close it afterwards.

    Written as a generator so it can be used as a FastAPI dependency: the
    session is handed to the caller at the ``yield`` and closed in the
    ``finally`` clause whether or not the caller raised.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
|
||||
|
||||
|
||||
# Annotated alias: a parameter typed ``db_dependency`` receives the session
# yielded by ``get_db``.
db_dependency = Annotated[Session, Depends(get_db)]

# Declarative base that the ORM model classes in this module inherit from.
Base = declarative_base()
|
||||
|
||||
|
||||
def create_db_tables():
    """Create the tables of every model registered on ``Base`` using ``engine``."""
    metadata = Base.metadata
    metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def clear_all_data():
    """Clear all data from the database (useful for testing).

    Deletes every row from the transactions, receipts and uploaded-files
    tables (in that order) inside one session, commits, and closes the
    session even when a delete or the commit fails.
    """
    session = SessionLocal()
    try:
        # Same deletion order as before: transactions, receipts, files.
        for model in (DBTransaction, DBReceipt, DBUploadedFile):
            session.query(model).delete()
        session.commit()
    finally:
        session.close()
|
||||
|
||||
|
||||
# Transactions table
|
||||
class DBTransaction(Base):
    """ORM model for one row of the ``transactions`` table."""

    __tablename__ = "transactions"

    # Surrogate primary key.
    id = Column(Integer, primary_key=True, index=True)
    # Identifier assigned by the importer; indexed but not declared unique.
    transaction_id = Column(String, index=True)

    # Required fields.
    amount = Column(Float, nullable=False)
    date = Column(DateTime, nullable=False)
    vendor = Column(String, nullable=False)

    # Optional fields.
    description = Column(String, nullable=True)
    category = Column(String, nullable=True)
    tax_amount = Column(Float, nullable=True)
    # Column name uses the British spelling; callers filter on it by this name.
    categorisation_id = Column(String, nullable=True)
    user_id = Column(String, nullable=True)
|
||||
|
||||
|
||||
# Uploaded Files table
|
||||
class DBUploadedFile(Base):
    """ORM model for one row of the ``uploaded_files`` table."""

    __tablename__ = "uploaded_files"

    # Surrogate primary key.
    id = Column(Integer, primary_key=True, index=True)
    # Public identifier of the upload; unique and indexed.
    file_id = Column(String, unique=True, index=True)

    # Required metadata about the stored file.
    filename = Column(String, nullable=False)
    file_path = Column(String, nullable=False)
    file_type = Column(String, nullable=False)
    upload_date = Column(DateTime, nullable=False)

    # Status string; defaults to "uploaded" when not set explicitly.
    status = Column(String, nullable=False, default="uploaded")
|
||||
|
||||
|
||||
# Receipts table
|
||||
class DBReceipt(Base):
    """ORM model for one row of the ``receipts`` table."""

    __tablename__ = "receipts"

    # Surrogate primary key.
    id = Column(Integer, primary_key=True, index=True)
    # Both identifiers are unique, so a given file_id can back only one row.
    receipt_id = Column(String, unique=True, index=True)
    file_id = Column(String, unique=True, index=True)

    # Required extracted fields.
    amount = Column(Float, nullable=False)
    date = Column(DateTime, nullable=False)
    vendor = Column(String, nullable=False)

    # Optional extracted fields.
    description = Column(String, nullable=True)
    category = Column(String, nullable=True)
    tax_amount = Column(Float, nullable=True)
    confidence = Column(Float, nullable=True)

    # Stored as a string column (not Boolean).
    extraction_success = Column(String, nullable=True)
    error_message = Column(String, nullable=True)
    receipt_currency = Column(String, nullable=True)
|
||||
+821
@@ -0,0 +1,821 @@
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import (
|
||||
DBReceipt,
|
||||
DBTransaction,
|
||||
DBUploadedFile,
|
||||
create_db_tables,
|
||||
db_dependency,
|
||||
)
|
||||
from schemas import (
|
||||
DocumentProcessResponse,
|
||||
DocumentUploadResponse,
|
||||
MatchingResponse,
|
||||
MatchResponse,
|
||||
MatchSpecificRequest,
|
||||
Receipt,
|
||||
RuleRequest,
|
||||
Transaction,
|
||||
)
|
||||
from services.ai_rules import AIRule
|
||||
from services.document_processor import DocumentProcessor
|
||||
from services.matching_engine import MatchingEngine
|
||||
|
||||
# Create the database tables as an import-time side effect of this module.
create_db_tables()

# Root logging configuration: INFO level, timestamped records on a stream handler.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="AI Bookkeeper - Data Science Engine",
    description="AI-powered receipt-to-transaction matching engine. Receives transaction data and provides intelligent matching capabilities.",
    version="1.0.0",
)

# CORS middleware
# NOTE(review): every origin, method and header is allowed together with
# credentials. Confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize DS Engine components (shared by all request handlers).
matching_engine = MatchingEngine()
document_processor = DocumentProcessor()
|
||||
|
||||
|
||||
# Helper functions for database operations
|
||||
def get_transactions_from_db(
    db: Session, user_id: str = None, categorization_id: str = None
):
    """Return all transactions, optionally narrowed by user and categorization.

    A filter is applied only when its argument is truthy, so ``None`` and
    the empty string both mean "do not filter on this column".
    """
    query = db.query(DBTransaction)
    optional_filters = (
        (DBTransaction.user_id, user_id),
        (DBTransaction.categorisation_id, categorization_id),
    )
    for column, wanted in optional_filters:
        if wanted:
            query = query.filter(column == wanted)
    return query.all()
|
||||
|
||||
|
||||
def get_receipt_from_db(db: Session, file_id: str):
    """Return the first receipt whose ``file_id`` matches, or ``None``."""
    matching = db.query(DBReceipt).filter(DBReceipt.file_id == file_id)
    return matching.first()
|
||||
|
||||
|
||||
def get_receipts_from_db(db: Session, file_ids: List[str]):
    """Return every receipt whose ``file_id`` is in ``file_ids``."""
    matching = db.query(DBReceipt).filter(DBReceipt.file_id.in_(file_ids))
    return matching.all()
|
||||
|
||||
|
||||
def get_uploaded_file_from_db(db: Session, file_id: str):
    """Return the first uploaded-file record with this ``file_id``, or ``None``."""
    matching = db.query(DBUploadedFile).filter(DBUploadedFile.file_id == file_id)
    return matching.first()
|
||||
|
||||
|
||||
def get_uploaded_files_from_db(db: Session, file_ids: List[str]):
    """Return every uploaded-file record whose ``file_id`` is in ``file_ids``."""
    matching = db.query(DBUploadedFile).filter(DBUploadedFile.file_id.in_(file_ids))
    return matching.all()
|
||||
|
||||
|
||||
@app.get("/", tags=["Health"])
async def root():
    """Health check endpoint"""
    # Static payload; no database or engine access is performed here.
    payload = {
        "message": "AI Bookkeeper Data Science Engine is running",
        "version": "1.0.0",
        "status": "healthy",
    }
    return payload
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TRANSACTION IMPORT ENDPOINTS
|
||||
# ============================================================================
|
||||
def _parse_bank_csv_row(row, idx):
    """Extract the transaction fields from one row of the bank CSV export.

    Args:
        row: Mapping produced by ``csv.DictReader`` for a single data row.
        idx: Zero-based row position; ``idx + 1`` becomes part of the ID.

    Returns:
        Tuple ``(txn_id, txn_date_obj, amount, payee_name, memo)``.

    Raises:
        ValueError: A required column is missing, or the date or amount
            cannot be parsed.
    """
    # Header names may carry stray whitespace, so look columns up by their
    # stripped name. Non-string keys (DictReader's overflow key) are skipped.
    fields = {
        key.strip(): value for key, value in row.items() if isinstance(key, str)
    }

    account_number = fields.get("Account Number")
    # "Date" is accepted as an alternative header for the transaction date.
    txn_date_raw = fields.get("Transaction Date") or fields.get("Date")
    amount_raw = fields.get("Amount")
    payee_name = fields.get("Description 2")

    required = (
        ("Transaction Date", txn_date_raw),
        ("Amount", amount_raw),
        ("Description 2", payee_name),
    )
    for label, value in required:
        if value is None:
            raise ValueError(f"Missing value for '{label}'")

    # Short rows yield None for absent columns; treat those as empty text.
    memo_parts = [
        (fields.get(name) or "").strip()
        for name in ("Account Type", "Cheque Number", "Description 1")
    ]
    memo = " ".join(memo_parts).strip()

    # Compose ID
    txn_id = f"{account_number}_{idx + 1}"

    # Parse date (two-digit year first, then four-digit year)
    txn_date_str = txn_date_raw.strip()
    txn_date_obj = None
    for fmt in ("%m/%d/%y", "%m/%d/%Y"):
        try:
            txn_date_obj = datetime.strptime(txn_date_str, fmt)
            break
        except ValueError:
            continue
    if txn_date_obj is None:
        raise ValueError(f"Could not parse date: {txn_date_str}")

    # Parse amount (thousands separators removed)
    amount = float(amount_raw.replace(",", "").strip())

    return txn_id, txn_date_obj, amount, payee_name.strip(), memo


@app.post("/transactions/import/csv", tags=["Transaction Import"])
async def import_transactions_csv(
    db: db_dependency,
    file: UploadFile = File(...),
    categorization_id: str = Form(...),
    user_id: str = Form(...),
):
    """
    Import transactions from a CSV file (custom bank export format).
    """
    # Rows that fail to parse are reported in "errors" and skipped; the
    # remaining rows are committed together at the end.
    try:
        content = await file.read()
        try:
            # utf-8-sig also accepts plain UTF-8 and drops a leading BOM,
            # which would otherwise corrupt the first header name.
            decoded = content.decode("utf-8-sig")
        except UnicodeDecodeError as exc:
            raise HTTPException(
                status_code=400, detail=f"File is not valid UTF-8: {exc}"
            ) from exc

        reader = csv.DictReader(io.StringIO(decoded))
        transactions = []
        errors = []
        for idx, row in enumerate(reader):
            try:
                txn_id, txn_date_obj, amount, payee_name, memo = _parse_bank_csv_row(
                    row, idx
                )

                # Create database transaction object and stage it
                db.add(
                    DBTransaction(
                        transaction_id=txn_id,
                        amount=amount,
                        date=txn_date_obj,
                        vendor=payee_name,
                        description=memo,
                        categorisation_id=categorization_id,
                        user_id=user_id,
                    )
                )

                transactions.append(
                    {
                        "id": txn_id,
                        "txn_date": txn_date_obj.strftime("%Y-%m-%d"),
                        "amount": amount,
                        "payee_name": payee_name,
                        "memo": memo,
                        "categorization_id": categorization_id,
                        "user_id": user_id,
                    }
                )
            except Exception as e:
                errors.append(f"Row {idx + 1}: {str(e)}")

        # Commit all transactions to database
        db.commit()

        return {
            "imported_count": len(transactions),
            "converted_transactions": transactions,
            "errors": errors,
            "categorization_id": categorization_id,
            "user_id": user_id,
        }
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/transactions/import/image", tags=["Transaction Import"])
async def import_transactions_from_image(
    db: db_dependency,
    file: UploadFile = File(...),
    categorization_id: str = Form("image_import"),
    user_id: str = Form("default"),
):
    """
    Import transactions from an image (bank statement, credit card statement, etc.) using AI extraction.
    """
    # Flow: validate extension -> save upload -> extract -> stage each
    # extracted transaction -> commit. Items that fail are logged, reported
    # in "errors", and skipped.
    try:
        # Validate file type (a missing filename counts as unsupported)
        allowed_types = ["jpg", "jpeg", "png", "gif", "bmp", "pdf"]
        file_extension = (file.filename or "").split(".")[-1].lower()
        if file_extension not in allowed_types:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported file type. Allowed: {allowed_types}",
            )

        # Read file content and save it to disk
        content = await file.read()
        image_path = await document_processor.save_uploaded_file(content, file.filename)

        # Extract transactions from image (pass file path)
        extraction_result = await document_processor.extract_transactions_from_image(
            image_path
        )
        if not extraction_result.get("extraction_success", False):
            raise HTTPException(
                status_code=500,
                detail=extraction_result.get("error", "Extraction failed"),
            )
        extracted_transactions = extraction_result.get("transactions", [])

        # Store transactions in database
        transactions = []
        errors = []
        for idx, txn in enumerate(extracted_transactions):
            try:
                txn_id = f"img_{file.filename}_{idx + 1}"
                txn_date_raw = txn.get("date")
                amount = txn.get("amount")
                vendor = txn.get("vendor")
                memo = txn.get("memo", "")

                # Parse date to YYYY-MM-DD format
                txn_date = document_processor._parse_date_to_iso(txn_date_raw)
                if not txn_date:
                    # Fallback: prefix the current year when parsing fails
                    txn_date = f"{datetime.now().year}-{txn_date_raw}"

                # Parse date for database
                txn_date_obj = datetime.strptime(txn_date, "%Y-%m-%d")

                # Create database transaction object and stage it
                db.add(
                    DBTransaction(
                        transaction_id=txn_id,
                        amount=float(amount),
                        date=txn_date_obj,
                        vendor=vendor,
                        description=memo,
                        categorisation_id=categorization_id,
                        user_id=user_id,
                    )
                )

                transactions.append(
                    {
                        "id": txn_id,
                        "txn_date": txn_date,
                        "amount": amount,
                        "payee_name": vendor,
                        "memo": memo,
                    }
                )
            except Exception as e:
                logger.warning(f"Error processing transaction {idx}: {str(e)}")
                # Surface the skipped item to the caller as well as the log.
                errors.append(f"Transaction {idx + 1}: {str(e)}")
                continue

        # Commit all transactions to database
        db.commit()

        return {
            "imported_count": len(transactions),
            "converted_transactions": transactions,
            "errors": errors,
        }
    except HTTPException:
        # Preserve the status code chosen above instead of rewrapping as 500.
        raise
    except Exception as e:
        logger.error(f"Error importing transactions from image: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DOCUMENT PROCESSING ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@app.post(
    "/upload-multiple",
    response_model=List[DocumentUploadResponse],
    tags=["Document Processing"],
)
async def upload_multiple_documents(
    files: List[UploadFile] = File(...), db: db_dependency = None
):
    """
    Upload multiple receipt images for processing.

    This endpoint accepts multiple image files and returns file IDs
    that can be used with the /process/{file_id} endpoint.
    """
    try:
        allowed_types = ["jpg", "jpeg", "png", "gif", "bmp", "pdf"]

        # Validate every file before saving anything, so one bad file does
        # not leave earlier files of the same request written to disk.
        extensions = []
        for file in files:
            file_extension = (file.filename or "").split(".")[-1].lower()
            if file_extension not in allowed_types:
                raise HTTPException(
                    status_code=400,
                    detail=f"Unsupported file type for {file.filename}. Allowed: {allowed_types}",
                )
            extensions.append(file_extension)

        responses = []
        for file, file_extension in zip(files, extensions):
            # Generate unique file ID
            file_id = str(uuid.uuid4())

            # Read file content and save to disk
            content = await file.read()
            file_path = await document_processor.save_uploaded_file(
                content, file.filename
            )

            # One timestamp so the stored row and the response agree.
            uploaded_at = datetime.now()

            # Create database record for uploaded file and stage it
            db.add(
                DBUploadedFile(
                    file_id=file_id,
                    filename=file.filename,
                    file_path=file_path,
                    file_type=file_extension,
                    upload_date=uploaded_at,
                    status="uploaded",
                )
            )

            responses.append(
                DocumentUploadResponse(
                    file_id=file_id,
                    filename=file.filename,
                    file_type=file_extension,
                    upload_date=uploaded_at,
                    status="uploaded",
                )
            )

        # Commit all uploaded files to database
        db.commit()

        return responses

    except HTTPException:
        # Keep the 400 raised during validation instead of turning it into 500.
        raise
    except Exception as e:
        logger.error(f"Error uploading documents: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post(
    "/process/{file_id}",
    response_model=DocumentProcessResponse,
    tags=["Document Processing"],
)
async def process_document(file_id: str, db: db_dependency):
    """
    Process a previously uploaded document to extract receipt information.

    This endpoint uses AI to extract structured data from receipt images,
    including vendor, amount, date, and category information.
    """
    try:
        # Get file info from database
        db_uploaded_file = get_uploaded_file_from_db(db, file_id)
        if not db_uploaded_file:
            raise HTTPException(status_code=404, detail=f"File {file_id} not found")

        # Process the file using the stored file path
        receipt_data = await document_processor.process_file(
            db_uploaded_file.file_path, db_uploaded_file.file_type
        )

        # Parse date for database storage; fall back to "now" when the
        # extracted date is absent or not in YYYY-MM-DD form.
        raw_date = receipt_data.get("date")
        try:
            receipt_date = (
                datetime.strptime(raw_date, "%Y-%m-%d") if raw_date else datetime.now()
            )
        except ValueError:
            receipt_date = datetime.now()

        # Create database receipt object
        db_receipt = DBReceipt(
            receipt_id=f"receipt_{file_id}",
            file_id=file_id,
            amount=receipt_data.get("total_amount", 0.0),
            date=receipt_date,
            vendor=receipt_data.get("vendor", ""),
            description=receipt_data.get("description", ""),
            category=receipt_data.get("category", ""),
            tax_amount=receipt_data.get("tax_amount", 0.0),
            confidence=receipt_data.get("confidence", 0.0),
            # The column is a String, so the flag is stored as "True"/"False".
            extraction_success=str(receipt_data.get("extraction_success", False)),
            error_message=receipt_data.get("error"),
            receipt_currency=receipt_data.get("currency"),
        )

        # Add to database
        db.add(db_receipt)
        db.commit()

        return DocumentProcessResponse(
            file_id=file_id,
            receipt_id=db_receipt.receipt_id,
            extraction_success=receipt_data.get("extraction_success", False),
            vendor=receipt_data.get("vendor", ""),
            description=receipt_data.get("description", ""),
            total_amount=receipt_data.get("total_amount", 0.0),
            tax_amount=receipt_data.get("tax_amount", 0.0),
            date=receipt_data.get("date", ""),
            category=receipt_data.get("category", ""),
            confidence=receipt_data.get("confidence", 0.0),
            error=receipt_data.get("error", None),
            receipt_currency=receipt_data.get("currency"),
        )

    except HTTPException:
        # Let the 404 above reach the client unchanged.
        raise
    except Exception as e:
        logger.error(f"Error processing document {file_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/match-specific", response_model=MatchingResponse, tags=["AI Matching"])
async def match_specific_receipts(request: MatchSpecificRequest, db: db_dependency):
    """
    Match specific receipts against imported transactions.

    This endpoint takes a request with receipt file IDs and categorization ID,
    and matches them against the currently imported transactions using AI-powered matching logic.
    """
    # Responds 400 when no transactions or no usable receipts are found, and
    # 500 when the matching step itself fails.
    try:
        file_ids = request.file_ids
        categorization_id = request.categorization_id
        logger.info(
            f"Starting match-specific for file IDs: {file_ids}, categorization_id: {categorization_id}"
        )

        # Get transactions from database
        db_transactions = get_transactions_from_db(
            db, categorization_id=categorization_id
        )

        if not db_transactions:
            logger.warning("No transactions found in database")
            raise HTTPException(
                status_code=400,
                detail="No transactions found. Please upload CSV first.",
            )

        logger.info(f"Found {len(db_transactions)} transactions in database")

        # Convert database transactions to Transaction objects; rows that
        # cannot be converted are logged and left out.
        transactions = []
        for db_txn in db_transactions:
            try:
                transaction = Transaction(
                    id=db_txn.transaction_id,
                    transaction_date=db_txn.date,
                    amount=db_txn.amount,
                    vendor=db_txn.vendor,
                    notes=db_txn.description or "",
                )
                transactions.append(transaction)
            except Exception as e:
                logger.warning(
                    f"Error converting transaction {db_txn.transaction_id}: {str(e)}"
                )
                continue

        logger.info(f"Converted {len(transactions)} transactions")

        # Get receipts for the specified file IDs from database and index
        # them once by file_id (first row wins) instead of rescanning the
        # result list for every requested ID.
        db_receipts = get_receipts_from_db(db, file_ids)
        receipts_by_file_id = {}
        for row in db_receipts:
            receipts_by_file_id.setdefault(row.file_id, row)

        receipts = []
        missing_files = []

        for file_id in file_ids:
            db_receipt = receipts_by_file_id.get(file_id)
            if db_receipt is None:
                logger.warning(f"Receipt {file_id} not found in database")
                missing_files.append(file_id)
                continue

            try:
                receipt = Receipt(
                    id=db_receipt.receipt_id,
                    receipt_date=db_receipt.date,
                    amount=db_receipt.amount,
                    vendor=db_receipt.vendor,
                    category=db_receipt.category or "Other",
                    description=db_receipt.description or "",
                    tax=db_receipt.tax_amount or 0.0,
                    file_name=db_receipt.file_id,
                    upload_date=datetime.now(),
                )
                receipts.append(receipt)
                logger.info(f"Successfully loaded receipt for file_id: {file_id}")
            except Exception as e:
                logger.error(
                    f"Error creating receipt object for {file_id}: {str(e)}"
                )
                missing_files.append(file_id)

        logger.info(f"Found {len(receipts)} receipts, {len(missing_files)} missing")

        if missing_files:
            logger.warning(f"Missing files: {missing_files}")

        if not receipts:
            logger.warning("No valid receipts found")
            raise HTTPException(
                status_code=400,
                detail="No valid receipts found for matching.",
            )

        # Perform matching
        logger.info(
            f"Starting matching with {len(receipts)} receipts and {len(transactions)} transactions"
        )

        try:
            matching_results = matching_engine.process_matching(receipts, transactions)
            logger.info(f"Matching completed, got {len(matching_results)} results")

            # Convert matching results to response format
            match_responses = []
            for result in matching_results:
                matched_txn = result.transaction
                match_response = MatchResponse(
                    receipt_id=result.receipt.id,
                    # Placeholder values are used when no transaction matched.
                    transaction_id=matched_txn.id if matched_txn else "no_match",
                    # Reported on a 0-100 scale.
                    confidence_score=result.confidence_score * 100,
                    match_reason=result.match_reason,
                    receipt_vendor=result.receipt.vendor,
                    receipt_amount=result.receipt.amount,
                    receipt_description=result.receipt.description,
                    receipt_category=result.receipt.category,
                    receipt_tax_amount=result.receipt.tax,
                    transaction_vendor=matched_txn.vendor if matched_txn else "",
                    transaction_amount=matched_txn.amount if matched_txn else 0.0,
                )
                match_responses.append(match_response)

            # Calculate statistics on the engine's raw (unscaled) scores
            high_confidence = sum(
                1 for r in matching_results if r.confidence_score >= 0.8
            )
            low_confidence = sum(
                1 for r in matching_results if r.confidence_score < 0.5
            )
            avg_score = (
                sum(r.confidence_score for r in matching_results)
                / len(matching_results)
                if matching_results
                else 0
            )

            stats = {
                "total": len(match_responses),
                "high_confidence": high_confidence,
                "low_confidence": low_confidence,
                "avg_score": round(avg_score, 2),
            }

            logger.info(f"Generated stats: {stats}")
            logger.info(
                f"Match-specific completed successfully with {len(match_responses)} matches"
            )

            return MatchingResponse(matches=match_responses, stats=stats)

        except Exception as e:
            logger.error(f"Exception in matching section: {str(e)}")
            logger.error(f"Exception type: {type(e)}")
            logger.error(f"Exception args: {e.args}")
            raise HTTPException(
                status_code=500, detail=f"Unexpected matching error: {str(e)}"
            )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error in match_specific_receipts: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DATABASE QUERY ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@app.get("/transactions", tags=["Database Queries"])
async def get_transactions(
    db: db_dependency,
    user_id: str = None,
    categorization_id: str = None,
    limit: int = 100,
):
    """
    Get transactions from the database.
    """
    try:
        # All matching rows are fetched first; the limit is applied by slicing.
        rows = get_transactions_from_db(db, user_id, categorization_id)[:limit]

        # Convert to response format
        payload = [
            {
                "id": row.transaction_id,
                "amount": row.amount,
                "date": row.date.strftime("%Y-%m-%d"),
                "vendor": row.vendor,
                "description": row.description,
                "category": row.category,
                "tax_amount": row.tax_amount,
                "categorisation_id": row.categorisation_id,
                "user_id": row.user_id,
            }
            for row in rows
        ]

        return {"transactions": payload, "count": len(payload)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/receipts", tags=["Database Queries"])
async def get_receipts(db: db_dependency, limit: int = 100):
    """
    Get receipts from the database.
    """
    try:
        # The limit is applied in the query itself.
        rows = db.query(DBReceipt).limit(limit).all()

        # Convert to response format
        payload = [
            {
                "id": row.receipt_id,
                "file_id": row.file_id,
                "amount": row.amount,
                "date": row.date.strftime("%Y-%m-%d"),
                "vendor": row.vendor,
                "description": row.description,
                "category": row.category,
                "tax_amount": row.tax_amount,
                "confidence": row.confidence,
                "extraction_success": row.extraction_success,
                "error_message": row.error_message,
            }
            for row in rows
        ]

        return {"receipts": payload, "count": len(payload)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/uploaded-files", tags=["Database Queries"])
async def get_uploaded_files(db: db_dependency, limit: int = 100):
    """
    Get uploaded files from the database.
    """
    try:
        # The limit is applied in the query itself.
        rows = db.query(DBUploadedFile).limit(limit).all()

        # Convert to response format
        payload = [
            {
                "file_id": row.file_id,
                "filename": row.filename,
                "file_path": row.file_path,
                "file_type": row.file_type,
                "upload_date": row.upload_date.strftime("%Y-%m-%d %H:%M:%S"),
                "status": row.status,
            }
            for row in rows
        ]

        return {"uploaded_files": payload, "count": len(payload)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# RULES MANAGEMENT ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@app.post("/rules", tags=["AI Rules Management"])
async def add_rule(request: RuleRequest):
    """
    Add a new AI rule for transaction matching.
    """
    try:
        # The rule is appended to the engine's in-memory rule list.
        rule = AIRule(
            name=request.name,
            condition=request.condition,
            action=request.action,
            source=request.source,
        )
        active_rules = matching_engine.rules_engine.rules
        active_rules.append(rule)

        return {"message": f"Rule '{request.name}' added successfully"}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/rules", tags=["AI Rules Management"])
async def get_rules():
    """
    Get all current AI rules.
    """
    try:
        # Serialise each rule held by the matching engine's rules engine.
        serialised = [
            {
                "name": rule.name,
                "condition": rule.condition,
                "action": rule.action,
                "source": rule.source,
                "status": rule.status,
            }
            for rule in matching_engine.rules_engine.rules
        ]

        return {"rules": serialised}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.delete("/rules/{rule_name}", tags=["AI Rules Management"])
async def delete_rule(rule_name: str):
    """
    Delete an AI rule by name.
    """
    try:
        active_rules = matching_engine.rules_engine.rules

        # Position of the first rule with this name, if any.
        position = next(
            (i for i, rule in enumerate(active_rules) if rule.name == rule_name),
            None,
        )
        if position is None:
            raise HTTPException(status_code=404, detail=f"Rule '{rule_name}' not found")

        # Only that first match is removed.
        del active_rules[position]
        return {"message": f"Rule '{rule_name}' deleted successfully"}

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STATISTICS ENDPOINT
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@app.get("/stats", tags=["Statistics"])
async def get_stats(db: db_dependency):
    """
    Get system statistics.
    """
    try:
        # Row counts per table, queried in the same order as before.
        transaction_count = db.query(DBTransaction).count()
        receipt_count = db.query(DBReceipt).count()
        uploaded_file_count = db.query(DBUploadedFile).count()

        # Rules are counted from the engine's in-memory list.
        rule_count = len(matching_engine.rules_engine.rules)

        return {
            "total_transactions": transaction_count,
            "total_receipts": receipt_count,
            "total_uploaded_files": uploaded_file_count,
            "rules_count": rule_count,
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 8654.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8654)
|
||||
+210
@@ -0,0 +1,210 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@dataclass
class Address:
    """Address information for tax calculations."""

    # Required location fields.
    province: str
    city: str
    postal_code: str
    # Defaults to "Canada" when not given.
    country: str = "Canada"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Receipt:
|
||||
id: str
|
||||
file_name: str
|
||||
upload_date: datetime
|
||||
receipt_date: datetime
|
||||
amount: float
|
||||
tax: float
|
||||
vendor: str
|
||||
category: str
|
||||
description: str
|
||||
# Tax rule fields
|
||||
billing_address: Optional[Address] = None
|
||||
shipping_address: Optional[Address] = None
|
||||
currency: str = "CAD"
|
||||
is_meals_entertainment: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transaction:
|
||||
id: str
|
||||
transaction_date: datetime
|
||||
amount: float
|
||||
vendor: str
|
||||
notes: str
|
||||
# Tax rule fields
|
||||
currency: str = "CAD"
|
||||
fx_rate: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Asset:
|
||||
"""Asset for depreciation calculations"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
purchase_date: datetime
|
||||
purchase_amount: float
|
||||
useful_life_years: int
|
||||
residual_value: float
|
||||
cca_rate: float # Capital Cost Allowance rate
|
||||
asset_class: str
|
||||
|
||||
|
||||
@dataclass
class Match:
    """Pairing of one receipt with one transaction plus its score."""

    receipt: Receipt
    transaction: Transaction
    confidence_score: float
    match_reason: str  # free-text explanation accompanying the score
    tax_analysis: Optional[dict] = None  # None until a tax analysis is attached
|
||||
|
||||
|
||||
# Pydantic model carrying the same four fields as the Address dataclass.
class AddressRequest(BaseModel):
    province: str
    city: str
    postal_code: str
    country: str = "Canada"  # value used when the field is omitted
|
||||
|
||||
|
||||
# Pydantic model carrying the same fields as the Receipt dataclass, with
# addresses expressed as AddressRequest.
class ReceiptRequest(BaseModel):
    id: str
    file_name: str
    upload_date: datetime
    receipt_date: datetime
    amount: float
    tax: float
    vendor: str
    category: str
    description: str
    # Tax rule fields
    billing_address: Optional[AddressRequest] = None
    shipping_address: Optional[AddressRequest] = None
    currency: str = "CAD"
    is_meals_entertainment: bool = False
|
||||
|
||||
|
||||
# Pydantic model carrying the same fields as the Transaction dataclass.
class TransactionRequest(BaseModel):
    id: str
    transaction_date: datetime
    amount: float
    vendor: str
    notes: str
    # Tax rule fields
    currency: str = "CAD"
    fx_rate: Optional[float] = None  # optional; None when omitted
|
||||
|
||||
|
||||
# Pydantic model carrying the same fields as the Asset dataclass.
class AssetRequest(BaseModel):
    id: str
    name: str
    purchase_date: datetime
    purchase_amount: float
    useful_life_years: int
    residual_value: float
    cca_rate: float  # same field as Asset.cca_rate
    asset_class: str
|
||||
|
||||
|
||||
# Request body naming the receipts and transactions to consider, by id.
class MatchingRequest(BaseModel):
    receipt_ids: List[str]
    transaction_ids: List[str]
|
||||
|
||||
|
||||
# Flat description of one receipt/transaction pairing: ids, score, reason and
# selected receipt and transaction attributes.
class MatchResponse(BaseModel):
    receipt_id: str
    transaction_id: str
    confidence_score: float
    match_reason: str
    receipt_vendor: str
    receipt_amount: float
    receipt_description: str
    receipt_category: str
    receipt_tax_amount: float
    transaction_vendor: str
    transaction_amount: float
|
||||
|
||||
|
||||
# Response wrapper: the list of matches plus a free-form stats dictionary.
class MatchingResponse(BaseModel):
    matches: List[MatchResponse]
    stats: dict  # keys are not constrained by this model
|
||||
|
||||
|
||||
# Request body recording an approve/reject decision for a match.
class ApprovalRequest(BaseModel):
    match_id: str
    approved: bool
    reason: Optional[str] = None  # optional free-text justification
|
||||
|
||||
|
||||
# Request body describing a rule as name / condition / action strings.
class RuleRequest(BaseModel):
    name: str
    condition: str
    action: str
    source: str = "user"  # value used when the field is omitted
|
||||
|
||||
|
||||
# Response describing an uploaded file.
class DocumentUploadResponse(BaseModel):
    file_id: str
    filename: str
    file_type: str
    upload_date: datetime
    status: str  # allowed values are not constrained by this model
|
||||
|
||||
|
||||
# Response describing the outcome of processing one uploaded file. Only the
# ids and the success flag are required; every extracted field is optional.
class DocumentProcessResponse(BaseModel):
    file_id: str
    receipt_id: str
    extraction_success: bool
    vendor: Optional[str] = None
    description: Optional[str] = None
    total_amount: Optional[float] = None
    tax_amount: Optional[float] = None
    date: Optional[str] = None  # kept as a string, not parsed into a datetime here
    category: Optional[str] = None
    confidence: Optional[float] = None
    error: Optional[str] = None
    receipt_currency: Optional[str] = "CAD"  # defaults to "CAD" when omitted
|
||||
|
||||
|
||||
# New tax-related models
|
||||
# Request body for a tax calculation on a receipt; the transaction is optional.
class TaxCalculationRequest(BaseModel):
    receipt_id: str
    transaction_id: Optional[str] = None
|
||||
|
||||
|
||||
# Response carrying the outcome of a tax calculation for one receipt.
class TaxCalculationResponse(BaseModel):
    receipt_id: str
    rules_applied: List[str]
    sales_tax: dict  # structure is not constrained by this model
    fx_analysis: Optional[dict] = None  # optional; None when absent
    meals_entertainment: dict  # structure is not constrained by this model
|
||||
|
||||
|
||||
# Request body for a depreciation calculation on one asset for one year.
class DepreciationRequest(BaseModel):
    asset: AssetRequest
    year: int
    method: str  # "straight_line" or "cca"
|
||||
|
||||
|
||||
# Response carrying the result of a depreciation calculation.
class DepreciationResponse(BaseModel):
    asset_id: str
    year: int
    method: str
    depreciation: float
    book_value: float
    total_depreciation: Optional[float] = None
    success: bool
    error: Optional[str] = None  # optional error text; None when absent
|
||||
|
||||
|
||||
# Request body naming specific files (by id) together with a categorization id.
class MatchSpecificRequest(BaseModel):
    file_ids: List[str]
    categorization_id: str
|
||||
@@ -0,0 +1,469 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
|
||||
import groq
|
||||
|
||||
from config import settings
|
||||
from schemas import Match, Receipt, Transaction
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIMatcher:
|
||||
def __init__(self, use_batch_matching=True):
    """Set up the Groq client together with retry and throttling settings.

    Args:
        use_batch_matching: When true, all candidates for a receipt are
            evaluated in one AI call; otherwise the legacy
            one-call-per-candidate path is used.
    """
    # Strategy toggle between the batched and the legacy matching methods.
    self.use_batch_matching = use_batch_matching

    # Retry and throttling knobs (delays are in seconds).
    self.max_retries = 3
    self.retry_delay = 2
    self.rate_limit_delay = 1.0
    self.last_api_call = 0

    # Remote model access.
    self.model = "llama-3.1-8b-instant"
    self.client = groq.Groq(api_key=settings.GROQ_API_KEY)
|
||||
|
||||
def match_receipts_to_transactions(
    self, receipts: List[Receipt], transactions: List[Transaction]
) -> List[Match]:
    """Find the best transaction for every receipt.

    Each receipt is throttled through ``_rate_limit`` and then handed to
    ``_find_best_match``; receipts without a match are logged and skipped.

    Returns:
        The matches found, ordered by descending confidence score.
    """
    receipt_count = len(receipts)
    logger.info(
        f"Starting AI matching for {receipt_count} receipts against {len(transactions)} transactions"
    )

    found = []
    for position, receipt in enumerate(receipts, start=1):
        logger.info(
            f"Processing receipt {position}/{receipt_count}: {receipt.vendor} - ${receipt.amount}"
        )

        # Throttle before every receipt to stay within the API quota.
        self._rate_limit()

        best_match = self._find_best_match(receipt, transactions)
        if not best_match:
            logger.warning(
                f"No match found for receipt: {receipt.vendor} - ${receipt.amount}"
            )
            continue

        found.append(best_match)
        logger.info(
            f"Found match: {best_match.confidence_score:.3f} - {best_match.match_reason}"
        )

    # Highest confidence first; the sort is stable, so ties keep input order.
    found.sort(key=lambda match: match.confidence_score, reverse=True)
    logger.info(f"AI matching completed. Found {len(found)} matches")
    return found
|
||||
|
||||
def _rate_limit(self):
    """Sleep as needed so calls are at least ``rate_limit_delay`` seconds apart.

    Updates ``last_api_call`` to the current time before returning.
    """
    elapsed = time.time() - self.last_api_call
    remaining = self.rate_limit_delay - elapsed

    if remaining > 0:
        logger.debug(f"Rate limiting: sleeping for {remaining:.2f} seconds")
        time.sleep(remaining)

    self.last_api_call = time.time()
|
||||
|
||||
def _find_best_match(
    self, receipt: Receipt, transactions: List[Transaction]
) -> Match:
    """Pick the best transaction for ``receipt``.

    Transactions are first narrowed with ``_filter_candidates``; the
    survivors go to the batched or the legacy matcher depending on
    ``use_batch_matching``.

    Returns:
        The chosen match, or ``None`` when there are no candidates or the
        selected matcher returns nothing.
    """
    candidates = self._filter_candidates(receipt, transactions)

    if candidates:
        logger.info(f"Found {len(candidates)} candidates for receipt: {receipt.vendor}")

        # Batched: one AI call for all candidates. Legacy: one call each.
        if self.use_batch_matching:
            matcher = self._find_best_match_single_call
        else:
            matcher = self._find_best_match_legacy
        return matcher(receipt, candidates)

    logger.warning(
        f"No candidates found for receipt: {receipt.vendor} - ${receipt.amount}"
    )
    return None
|
||||
|
||||
def _find_best_match_single_call(
    self, receipt: Receipt, candidates: List[Transaction]
) -> Match:
    """Ask the model once to choose the best candidate for ``receipt``.

    At most ten candidates (those closest in amount) are described in a
    single prompt. The reply is parsed by ``_parse_single_match_response``.
    API failures are retried with exponential backoff up to
    ``max_retries`` attempts.

    Returns:
        A ``Match`` for the chosen candidate, or ``None`` when there are no
        candidates, the reply cannot be parsed, the reply names a candidate
        that does not exist, or every attempt fails.
    """
    if not candidates:
        return None

    # Limit candidates to avoid token limits (adjust based on your needs)
    max_candidates = 10
    if len(candidates) > max_candidates:
        by_amount_closeness = sorted(
            candidates, key=lambda t: abs(receipt.amount - abs(t.amount))
        )
        candidates = by_amount_closeness[:max_candidates]
        logger.info(
            f"Limited candidates to top {max_candidates} by amount similarity"
        )

    def describe(i, transaction):
        # Render one numbered candidate entry for the prompt.
        transaction_amount_abs = abs(transaction.amount)
        date_diff = abs((receipt.receipt_date - transaction.transaction_date).days)
        amount_diff = abs(receipt.amount - transaction_amount_abs)
        if receipt.amount > 0:
            amount_percent_diff = (amount_diff / receipt.amount) * 100
        else:
            amount_percent_diff = 0

        return f"""
Candidate {i + 1}:
- Vendor: {transaction.vendor}
- Amount: ${transaction.amount} (absolute: ${transaction_amount_abs})
- Date: {transaction.transaction_date.strftime("%Y-%m-%d")} ({date_diff} days difference)
- Notes: {transaction.notes}
- Amount difference: ${amount_diff} ({amount_percent_diff:.1f}%)
"""

    candidates_text = "".join(
        describe(i, transaction) for i, transaction in enumerate(candidates)
    )

    prompt = f"""
You are an expert at matching receipts to bank transactions. Analyze the receipt below against ALL the candidate transactions and return the BEST match.

RECEIPT TO MATCH:
- Vendor: {receipt.vendor}
- Amount: ${receipt.amount}
- Date: {receipt.receipt_date.strftime("%Y-%m-%d")}
- Description: {receipt.description}
- Category: {receipt.category}

CANDIDATE TRANSACTIONS:
{candidates_text}

SCORING CRITERIA:
- Perfect matches (same vendor, amount, date): 0.95-1.0
- High confidence (minor differences): 0.8-0.94
- Medium confidence (moderate differences): 0.6-0.79
- Low confidence (significant differences): 0.4-0.59
- Very low confidence (major differences): 0.2-0.39
- Minimal similarity: 0.1-0.19
- No meaningful similarity: 0.0-0.09

Consider vendor name similarity, amount accuracy, date proximity, and description/notes relevance.

IMPORTANT: You MUST return the candidate with the highest match score, even if it's very low. Never return NONE.
Return ONLY the best match in this exact format:
CANDIDATE_NUMBER|CONFIDENCE_SCORE|REASON

Example: 3|0.87|Same vendor name, exact amount match, 1 day apart
Example of low match: 5|0.15|Best available option despite significant differences in vendor and amount
"""

    for attempt in range(self.max_retries):
        try:
            # Longer timeout because the prompt covers every candidate.
            result = self._call_groq_api_with_timeout(prompt, timeout=45)
            candidate_num, score, reason = self._parse_single_match_response(result)

            # -1 is the parser's signal that the reply was unusable.
            if candidate_num == -1:
                logger.warning(
                    f"Failed to parse AI response for receipt: {receipt.vendor}"
                )
                return None

            if not 0 <= candidate_num < len(candidates):
                logger.warning(
                    f"AI returned invalid candidate number: {candidate_num}"
                )
                return None

            best_transaction = candidates[candidate_num]
            logger.info(
                f"AI selected candidate {candidate_num + 1}: {best_transaction.vendor} (score: {score:.3f})"
            )
            return Match(receipt, best_transaction, score, reason)

        except Exception as e:
            logger.warning(
                f"Attempt {attempt + 1} failed for receipt {receipt.id}: {str(e)}"
            )
            if attempt >= self.max_retries - 1:
                logger.error(f"All attempts failed for receipt {receipt.id}")
                return None

            # Exponential backoff between attempts.
            sleep_time = self.retry_delay * (2**attempt)
            logger.info(f"Waiting {sleep_time} seconds before retry...")
            time.sleep(sleep_time)

    return None
|
||||
|
||||
def _parse_single_match_response(self, result: str) -> Tuple[int, float, str]:
    """Parse a reply of the form ``CANDIDATE_NUMBER|CONFIDENCE_SCORE|REASON``.

    Args:
        result: Raw text returned by the model.

    Returns:
        ``(candidate_index, score, reason)``. ``candidate_index`` is the
        0-based index (the model's number minus one), ``score`` is clamped
        to ``[0.0, 1.0]`` and ``reason`` is everything after the second
        pipe. When the reply cannot be parsed the tuple is
        ``(-1, 0.0, <diagnostic message>)``.
    """
    import re

    result = result.strip()
    logger.debug(f"Parsing single match response: {result}")

    try:
        if result.upper().startswith("NONE"):
            # This should not happen with new prompt, but handle as parsing error
            logger.warning(
                "AI returned NONE despite being instructed to always return best match"
            )
            return -1, 0.0, "AI returned NONE unexpectedly"

        parts = result.split("|")
        if len(parts) >= 3:
            candidate_str = parts[0].strip()
            score_str = parts[1].strip()
            # The reason may itself contain pipes, so re-join the remainder.
            reason = "|".join(parts[2:]).strip()

            candidate_match = re.search(r"\d+", candidate_str)
            if candidate_match is None:
                raise ValueError("No candidate number found")
            # Convert the model's 1-based number to a 0-based index.
            candidate_num = int(candidate_match.group()) - 1

            # Use the first well-formed number in the score field. Gluing
            # every digit and dot together (the previous approach) turned
            # replies such as "0.87/1.0" or "0.9." into unparseable text
            # and threw away an otherwise valid match.
            score_match = re.search(r"\d+(?:\.\d+)?|\.\d+", score_str)
            score = float(score_match.group()) if score_match else 0.0

            # Ensure score is in valid range
            score = max(0.0, min(1.0, score))

            logger.debug(
                f"Parsed: candidate={candidate_num}, score={score}, reason={reason}"
            )
            return candidate_num, score, reason

    except ValueError as e:
        logger.warning(f"Error parsing single match response: {e}")

    # Fallback: fewer than three fields, or a field could not be converted.
    logger.warning(f"Could not parse single match response: {result}")
    return -1, 0.0, f"Parse error: {result[:50]}..."
|
||||
|
||||
def _filter_candidates(
    self, receipt: Receipt, transactions: List[Transaction]
) -> List[Transaction]:
    """Keep transactions whose absolute amount is plausibly close to the receipt.

    A transaction is kept when the gap between the receipt amount and the
    transaction's absolute amount is at most twice the receipt amount.
    Input order is preserved.
    """
    # 200% threshold - very inclusive
    amount_threshold = receipt.amount * 2.0

    candidates = [
        transaction
        for transaction in transactions
        if abs(receipt.amount - abs(transaction.amount)) <= amount_threshold
    ]

    logger.debug(
        f"Filtered {len(transactions)} transactions to {len(candidates)} candidates"
    )
    return candidates
|
||||
|
||||
def _find_best_match_legacy(
    self, receipt: Receipt, transactions: List[Transaction]
) -> Match:
    """Legacy matcher: score each candidate with its own AI call.

    Returns:
        The match with the highest score strictly greater than zero, or
        ``None`` when there are no candidates or no score exceeds zero.
    """
    candidates = self._filter_candidates(receipt, transactions)
    if not candidates:
        return None

    best_match, highest_score = None, 0

    for transaction in candidates:
        score, reason = self._calculate_match_score(receipt, transaction)
        logger.debug(
            f"Score {score:.3f} for transaction {transaction.vendor}: {reason}"
        )

        # Strict comparison: the first candidate wins ties.
        if score > highest_score:
            best_match = Match(receipt, transaction, score, reason)
            highest_score = score

    return best_match
|
||||
|
||||
def _calculate_match_score(
    self, receipt: Receipt, transaction: Transaction
) -> Tuple[float, str]:
    """Score one receipt/transaction pair with a dedicated AI call.

    The prompt includes pre-computed date and amount differences. API
    failures are retried with exponential backoff.

    Returns:
        ``(score, reason)`` as produced by ``_parse_ai_response``, or
        ``(0.0, <error text>)`` once every attempt has failed.
    """
    # Differences handed to the model as extra context.
    transaction_amount_abs = abs(transaction.amount)
    amount_diff = abs(receipt.amount - transaction_amount_abs)
    date_diff = abs((receipt.receipt_date - transaction.transaction_date).days)
    if receipt.amount > 0:
        amount_percent_diff = (amount_diff / receipt.amount) * 100
    else:
        amount_percent_diff = 0

    prompt = f"""
Compare this receipt with this transaction and provide a confidence score (0-1) and brief reason, the reason must be a single sentence without any special formatting.

Receipt: {receipt.vendor}, ${receipt.amount}, {receipt.receipt_date.strftime("%Y-%m-%d")}
Receipt Description: {receipt.description}
Receipt Category: {receipt.category}
Transaction: {transaction.vendor}, ${transaction.amount} (absolute: ${transaction_amount_abs}), {transaction.transaction_date.strftime("%Y-%m-%d")}
Transaction Notes: {transaction.notes}

Differences:
- Date difference: {date_diff} days
- Amount difference: ${amount_diff} ({amount_percent_diff:.1f}%)
- Vendor comparison: "{receipt.vendor}" vs "{transaction.vendor}"
- Description/Notes comparison: "{receipt.description}" vs "{transaction.notes}"
- Category: {receipt.category}

Score this potential match based on how likely it is the correct match:

- Perfect matches (same vendor, amount, date): 0.95-1.0
- High confidence (minor differences): 0.8-0.94
- Medium confidence (moderate differences): 0.6-0.79
- Low confidence (significant differences): 0.4-0.59
- Very low confidence (major differences): 0.2-0.39
- Minimal similarity: 0.1-0.19
- No meaningful similarity: 0.0-0.09

Consider description and category similarity in your scoring.

IMPORTANT: Return ONLY the score and reason separated by a pipe character.
Format: [score]|[reason]
Example: 0.85|Same vendor, same amount, 2 days apart
"""

    for attempt in range(self.max_retries):
        try:
            result = self._call_groq_api_with_timeout(prompt, timeout=30)

            # The parser copes with several reply formats.
            score, reason = self._parse_ai_response(result)

            logger.debug(f"AI Response: {result}")
            logger.debug(f"Parsed: score={score}, reason={reason}")
            return score, reason

        except Exception as e:
            logger.warning(
                f"Attempt {attempt + 1} failed for receipt {receipt.id}: {str(e)}"
            )
            if attempt >= self.max_retries - 1:
                logger.error(f"All attempts failed for receipt {receipt.id}")
                return 0.0, f"AI error after {self.max_retries} attempts: {str(e)}"

            # Exponential backoff for rate limiting
            sleep_time = self.retry_delay * (2**attempt)
            logger.info(f"Waiting {sleep_time} seconds before retry...")
            time.sleep(sleep_time)
|
||||
|
||||
def _parse_ai_response(self, result: str) -> Tuple[float, str]:
|
||||
"""Parse AI response with robust error handling"""
|
||||
result = result.strip()
|
||||
logger.debug(f"Parsing AI response: {result}")
|
||||
|
||||
# Try to find score in various formats
|
||||
if "|" in result:
|
||||
parts = result.split("|")
|
||||
logger.debug(f"Split response into {len(parts)} parts: {parts}")
|
||||
|
||||
# Look for a numeric score in any part
|
||||
for i, part in enumerate(parts):
|
||||
part = part.strip()
|
||||
try:
|
||||
# Remove any non-numeric characters except decimal point
|
||||
score_str_clean = "".join(
|
||||
c for c in part if c.isdigit() or c == "."
|
||||
)
|
||||
if score_str_clean:
|
||||
score = float(score_str_clean)
|
||||
if 0 <= score <= 1: # Valid confidence score
|
||||
# Get reason from other parts
|
||||
reason_parts = [
|
||||
p.strip()
|
||||
for j, p in enumerate(parts)
|
||||
if j != i and p.strip()
|
||||
]
|
||||
reason = (
|
||||
" | ".join(reason_parts)
|
||||
if reason_parts
|
||||
else "Score extracted"
|
||||
)
|
||||
logger.debug(
|
||||
f"Found score {score} in part {i}, reason: {reason}"
|
||||
)
|
||||
return score, reason
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Try to extract just a number from the response
|
||||
try:
|
||||
import re
|
||||
|
||||
numbers = re.findall(r"\d+\.?\d*", result)
|
||||
if numbers:
|
||||
for num_str in numbers:
|
||||
score = float(num_str)
|
||||
if 0 <= score <= 1: # Valid confidence score
|
||||
logger.debug(f"Extracted score {score} from response")
|
||||
return score, f"Extracted from response: {result[:50]}..."
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
# Fallback - try to find any number and normalize it
|
||||
try:
|
||||
import re
|
||||
|
||||
numbers = re.findall(r"\d+\.?\d*", result)
|
||||
if numbers:
|
||||
score = float(numbers[0])
|
||||
# Normalize to 0-1 range if it's a percentage or other scale
|
||||
if score > 1:
|
||||
score = score / 100 # Assume percentage
|
||||
score = max(0, min(1, score)) # Clamp to 0-1
|
||||
logger.debug(f"Normalized score {score} from response")
|
||||
return score, f"Normalized from response: {result[:50]}..."
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
# Final fallback
|
||||
logger.warning(f"Could not parse AI response: {result}")
|
||||
return 0.0, f"Unparseable response: {result[:50]}..."
|
||||
|
||||
def _call_groq_api_with_timeout(self, prompt: str, timeout: int = 15) -> str:
    """Send ``prompt`` to the chat model and give up after ``timeout`` seconds.

    The request runs on a worker thread so the caller can stop waiting
    once the deadline passes. This method does not retry; callers
    implement their own retry loops.

    Args:
        prompt: Text sent as the single user message.
        timeout: Maximum number of seconds to wait for the reply.

    Returns:
        The reply text with surrounding whitespace removed.

    Raises:
        TimeoutError: No reply arrived within ``timeout`` seconds.
        Exception: Any error raised by the client call is propagated unchanged.
    """
    import concurrent.futures

    def api_call():
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=200,
            temperature=0.1,
        )
        return response.choices[0].message.content.strip()

    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        return executor.submit(api_call).result(timeout=timeout)
    except concurrent.futures.TimeoutError as exc:
        raise TimeoutError(f"API call timed out after {timeout} seconds") from exc
    finally:
        # Do not wait for the worker. Using the executor as a context
        # manager calls shutdown(wait=True), which would block here until
        # the request finished and so defeat the timeout. The abandoned
        # request cannot be cancelled; it finishes in the background and
        # its result is discarded.
        executor.shutdown(wait=False)
|
||||
@@ -0,0 +1,175 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from schemas import Receipt, Transaction
|
||||
from services.tax_rules_engine import TaxRulesEngine
|
||||
|
||||
|
||||
@dataclass
class AIRule:
    """A named rule pairing a condition string with an action string."""

    name: str
    condition: str  # evaluated by AIRulesEngine._evaluate_condition
    action: str  # dispatched by AIRulesEngine._execute_action
    source: str  # origin label, e.g. "system" or "tax_system" for the defaults
    status: str = "active"  # apply_rules skips rules whose status is not "active"
|
||||
|
||||
|
||||
class AIRulesEngine:
|
||||
def __init__(self):
    """Create the tax rules engine and install the default rule set."""
    self.tax_rules_engine = TaxRulesEngine()
    # Starts empty; _load_default_rules() fills it right away.
    self.rules: List[AIRule] = []
    self._load_default_rules()
|
||||
|
||||
def _load_default_rules(self):
    """Replace ``self.rules`` with the built-in rules.

    Each entry is ``(name, condition, action, source)``; all rules use the
    default ``"active"`` status.
    """
    matching_rules = (
        ("exact_amount_match", "amount_diff <= 0.01", "auto_approve", "system"),
        (
            "same_vendor_same_date",
            "vendor_match and date_diff <= 1",
            "high_confidence",
            "system",
        ),
        (
            "gas_station_pattern",
            "vendor_contains_gas_or_fuel",
            "categorize_transport",
            "system",
        ),
    )
    # Tax-related rules
    tax_rules = (
        ("fx_currency_mismatch", "currency_mismatch", "flag_fx_review", "tax_system"),
        (
            "meals_entertainment",
            "is_meals_entertainment",
            "apply_me_tax_rule",
            "tax_system",
        ),
        (
            "provincial_tax_calculation",
            "has_address_info",
            "calculate_provincial_tax",
            "tax_system",
        ),
    )
    self.rules = [AIRule(*spec) for spec in matching_rules + tax_rules]
|
||||
|
||||
def apply_rules(self, receipt: Receipt, transaction: Transaction) -> Dict[str, Any]:
    """Run every active rule against a receipt/transaction pair.

    Rules are visited in list order. When a rule's condition holds, its
    action is applied to the shared result dictionary.

    Returns:
        A dict with keys ``auto_approve``, ``confidence_boost``,
        ``category`` and ``tax_analysis``.
    """
    results = {
        "auto_approve": False,
        "confidence_boost": 0,
        "category": None,
        "tax_analysis": {},
    }

    active_rules = (rule for rule in self.rules if rule.status == "active")
    for rule in active_rules:
        if self._evaluate_condition(rule.condition, receipt, transaction):
            self._execute_action(rule.action, results, receipt, transaction)

    return results
|
||||
|
||||
def _evaluate_condition(
    self, condition: str, receipt: Receipt, transaction: Transaction
) -> bool:
    """Evaluate a rule condition against a receipt/transaction pair.

    The built-in condition strings are answered from values computed
    here. Any other string is evaluated as a Python expression that can
    only see the names listed in ``namespace`` below.

    Args:
        condition: Condition text taken from a rule.
        receipt: Receipt being checked.
        transaction: Transaction being checked.

    Returns:
        The truth value of the condition. ``False`` when a custom
        expression is malformed or refers to an unknown name.
    """
    amount_diff = abs(receipt.amount - abs(transaction.amount))
    date_diff = abs((receipt.receipt_date - transaction.transaction_date).days)

    vendor_lower = receipt.vendor.lower()
    transaction_vendor_lower = transaction.vendor.lower()
    # "" is a substring of every string, so a blank vendor on either side
    # must not be treated as a match.
    vendor_match = bool(vendor_lower and transaction_vendor_lower) and (
        vendor_lower in transaction_vendor_lower
        or transaction_vendor_lower in vendor_lower
    )
    vendor_contains_gas_or_fuel = "gas" in vendor_lower or "fuel" in vendor_lower

    # Tax-related conditions
    currency_mismatch = receipt.currency != transaction.currency
    is_meals_entertainment = receipt.is_meals_entertainment
    has_address_info = (
        receipt.billing_address is not None or receipt.shipping_address is not None
    )

    # Built-in conditions never reach eval().
    builtin_conditions = {
        "amount_diff <= 0.01": amount_diff <= 0.01,
        "vendor_match and date_diff <= 1": vendor_match and date_diff <= 1,
        "vendor_contains_gas_or_fuel": vendor_contains_gas_or_fuel,
        "currency_mismatch": currency_mismatch,
        "is_meals_entertainment": is_meals_entertainment,
        "has_address_info": has_address_info,
    }
    if condition in builtin_conditions:
        return builtin_conditions[condition]

    namespace = {
        # Without an explicit "__builtins__" entry eval() inserts the real
        # builtins module, which would expose __import__, open, and so on.
        "__builtins__": {},
        "amount_diff": amount_diff,
        "date_diff": date_diff,
        "vendor_match": vendor_match,
        "vendor_contains_gas_or_fuel": vendor_contains_gas_or_fuel,
        "currency_mismatch": currency_mismatch,
        "is_meals_entertainment": is_meals_entertainment,
        "has_address_info": has_address_info,
        "receipt": receipt,
        "transaction": transaction,
        "abs": abs,
        "len": len,
        "min": min,
        "max": max,
        "sum": sum,
        "round": round,
    }
    try:
        # SECURITY NOTE(review): eval() is not a sandbox even with builtins
        # removed; attribute access on the exposed objects can still be
        # abused. If rule text can come from untrusted users, replace this
        # with a restricted expression parser.
        return bool(eval(condition, namespace, {}))
    except (SyntaxError, NameError, TypeError) as e:
        print(f"Warning: Invalid condition '{condition}': {e}")
        return False
|
||||
|
||||
def _execute_action(
    self,
    action: str,
    results: Dict[str, Any],
    receipt: Receipt,
    transaction: Transaction,
):
    """Apply one rule action to ``results`` in place.

    Unknown action names are ignored. Tax-related actions delegate to
    ``self.tax_rules_engine`` and store its output under
    ``results["tax_analysis"]``.
    """
    if action == "auto_approve":
        results["auto_approve"] = True
        return

    if action == "high_confidence":
        results["confidence_boost"] += 0.2
        return

    if action == "categorize_transport":
        results["category"] = "Transportation"
        return

    tax_engine = self.tax_rules_engine
    tax_analysis = results["tax_analysis"]

    if action == "flag_fx_review":
        # Apply FX rule and flag for review
        fx_result = tax_engine.apply_fx_rule(receipt, transaction)
        tax_analysis["fx"] = fx_result
        if fx_result.get("requires_manual_review", False):
            # Reduce confidence for FX issues
            results["confidence_boost"] -= 0.1
    elif action == "apply_me_tax_rule":
        # Apply meals & entertainment rule
        tax_analysis["meals_entertainment"] = (
            tax_engine.apply_meals_entertainment_rule(receipt)
        )
    elif action == "calculate_provincial_tax":
        # Calculate provincial tax
        tax_analysis["sales_tax"] = tax_engine.apply_sales_tax_rule(receipt)
||||
|
||||
def add_rule(self, rule: AIRule):
    """Register ``rule`` at the end of the rule list."""
    registered_rules = self.rules
    registered_rules.append(rule)
|
||||
|
||||
def remove_rule(self, rule_name: str):
    """Drop every rule named ``rule_name``; a no-op when none matches.

    ``self.rules`` is rebound to a new list; the old list is not mutated.
    """
    remaining = []
    for existing in self.rules:
        if existing.name != rule_name:
            remaining.append(existing)
    self.rules = remaining
|
||||
|
||||
def apply_tax_rules(
    self, receipt: Receipt, transaction: Transaction = None
) -> Dict[str, Any]:
    """Run all tax rules for a receipt and, when given, its transaction.

    Delegates to ``tax_rules_engine.apply_all_tax_rules`` and returns its
    result unchanged.
    """
    tax_engine = self.tax_rules_engine
    return tax_engine.apply_all_tax_rules(receipt, transaction)
|
||||
@@ -0,0 +1,547 @@
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
|
||||
import aiofiles
|
||||
import groq
|
||||
import PyPDF2
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentProcessor:
|
||||
def __init__(self):
    """Create the Groq client and select the vision-capable model."""
    # Vision model
    self.model = "meta-llama/llama-4-scout-17b-16e-instruct"
    self.client = groq.Groq(api_key=settings.GROQ_API_KEY)
|
||||
|
||||
async def process_file(self, file_path: str, file_type: str) -> Dict[str, Any]:
    """Extract receipt data from an uploaded file.

    Image types (jpg, jpeg, png, gif, bmp) go to ``_process_image`` and
    ``pdf`` goes to ``_process_pdf``; the type is compared
    case-insensitively.

    Returns:
        The extraction result, or ``{"error": <message>}`` for an
        unsupported type or any failure.
    """
    image_types = ("jpg", "jpeg", "png", "gif", "bmp")
    try:
        normalized_type = file_type.lower()
        if normalized_type in image_types:
            handler = self._process_image
        elif normalized_type == "pdf":
            handler = self._process_pdf
        else:
            raise ValueError(f"Unsupported file type: {file_type}")
        return await handler(file_path)
    except Exception as e:
        return {"error": str(e)}
|
||||
|
||||
async def _process_image(self, image_path: str) -> Dict[str, Any]:
    """Extract receipt fields from an image using the Groq vision model.

    The image is sent inline as a base64 data URL together with a prompt
    asking for a JSON object.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The dictionary produced by ``_parse_extraction_result``, or
        ``{"error": "Image processing error: ..."}`` when anything fails.
    """
    import mimetypes

    try:
        # Encode image to base64
        base64_image = self._encode_image(image_path)

        # Label the payload with the media type implied by the file
        # extension (png, gif and bmp are accepted upstream, not just
        # jpeg). Fall back to JPEG, the previously hard-coded value, when
        # the extension is missing or is not an image type.
        mime_type, _ = mimetypes.guess_type(image_path)
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/jpeg"

        # Create Groq vision prompt
        prompt = """
Analyze this receipt image and extract the following information in JSON format:
{
"vendor": "Store/company name",
"description": "Detailed description of items/services purchased",
"total_amount": 0.00,
"tax_amount": 0.00,
"date": "YYYY-MM-DD",
"category": "Food/Transport/Office/Other",
"confidence": 0.95,
"currency": "USD"
}

Rules:
- Extract vendor name as it appears on receipt
- Extract description of items/services purchased (e.g., "Coffee and sandwich", "Gasoline", "Office supplies")
- Total amount should be the final total including tax
- Tax amount is separate tax line if available
- Date should be the date on the receipt
- Categorize based on vendor type (Starbucks=Food, Shell=Transport, etc.)
- Confidence score 0-1 based on how clear the receipt is
- Currency should be the currency used on the receipt (e.g., "USD", "EUR")

Return only valid JSON.
"""

        # Call Groq vision API with correct format
        response = self.client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime_type};base64,{base64_image}",
                            },
                        },
                    ],
                }
            ],
            model=self.model,
            max_tokens=500,
            temperature=0.1,
        )

        # Parse response
        result_text = response.choices[0].message.content.strip()
        return self._parse_extraction_result(result_text)

    except Exception as e:
        return {"error": f"Image processing error: {str(e)}"}
|
||||
|
||||
def _encode_image(self, image_path: str) -> str:
    """Read the file at ``image_path`` and return its bytes as base64 text."""
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
|
||||
|
||||
async def _process_pdf(self, pdf_path: str) -> Dict[str, Any]:
    """Extract receipt fields from a PDF via its embedded text.

    The PDF's text is pulled out with ``_extract_text_from_pdf`` and
    analysed with ``_process_text_content``; no image conversion takes
    place.

    Args:
        pdf_path: Path of the PDF file to read.

    Returns:
        The extraction result, or ``{"error": "PDF processing error: ..."}``
        when the PDF yields no text or processing fails.
    """
    try:
        text_content = self._extract_text_from_pdf(pdf_path)

        # Sending an empty document to the model would only invite made-up
        # values, so report the problem instead.
        if not text_content or not text_content.strip():
            return {"error": "PDF processing error: no extractable text found in PDF"}

        return self._process_text_content(text_content)

    except Exception as e:
        return {"error": f"PDF processing error: {str(e)}"}
|
||||
|
||||
def _extract_text_from_pdf(self, pdf_path: str) -> str:
    """Return the text of every page in the PDF, one page per line group.

    Args:
        pdf_path: Path of the PDF file to read.

    Returns:
        The page texts, each followed by a newline. An empty string when
        the file cannot be opened or parsed; the failure is logged rather
        than raised.
    """
    try:
        with open(pdf_path, "rb") as file:
            pdf_reader = PyPDF2.PdfReader(file)
            # `or ""` guards against a page yielding no text object, so one
            # such page cannot abort extraction for the whole document.
            return "".join(
                (page.extract_text() or "") + "\n" for page in pdf_reader.pages
            )
    except Exception:
        # Best effort by design: callers treat "" as "no text available".
        # Record the cause so failures are not invisible.
        logger.exception("Failed to extract text from PDF %s", pdf_path)
        return ""
|
||||
|
||||
def _process_text_content(self, text_content: str) -> Dict[str, Any]:
    """Process text content using Groq (fallback for PDFs)

    Embeds ``text_content`` in an extraction prompt, sends it as a single
    user message through ``self.client.chat.completions.create`` and parses
    the reply with ``_parse_extraction_result``.

    Args:
        text_content: Receipt text to analyse.

    Returns:
        Whatever ``_parse_extraction_result`` returns for the reply, or
        ``{"error": "Text processing error: ..."}`` if anything raises.
    """
    try:
        # Doubled braces ({{ }}) are literal braces: this is an f-string and
        # only {text_content} is interpolated.
        prompt = f"""
        Analyze this receipt text and extract the following information in JSON format:

        Receipt Text:
        {text_content}

        Extract:
        {{
        "vendor": "Store/company name",
        "description": "Detailed description of items/services purchased",
        "total_amount": 0.00,
        "tax_amount": 0.00,
        "date": "YYYY-MM-DD",
        "category": "Food/Transport/Office/Other",
        "confidence": 0.95,
        "currency": "USD"
        }}

        Rules:
        - Extract vendor name as it appears on receipt
        - Extract description of items/services purchased (e.g., "Coffee and sandwich", "Gasoline", "Office supplies")
        - Total amount should be the final total including tax
        - Tax amount is separate tax line if available
        - Date should be the date on the receipt
        - Categorize based on vendor type
        - Confidence score 0-1 based on clarity
        - Currency should be the currency used on the receipt (e.g., "USD", "EUR")

        Return only valid JSON.
        """

        # NOTE(review): this is a synchronous client call; it is reached from
        # the async _process_pdf, so it presumably blocks the event loop
        # while waiting - confirm whether that is acceptable.
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=500,
            temperature=0.1,
        )

        # Only the first choice is used; a missing (None) content raises
        # here and is reported through the except branch below.
        result_text = response.choices[0].message.content.strip()
        return self._parse_extraction_result(result_text)

    except Exception as e:
        return {"error": f"Text processing error: {str(e)}"}
|
||||
|
||||
def _parse_extraction_result(self, result_text: str) -> Dict[str, Any]:
|
||||
"""Parse Groq response and extract JSON data"""
|
||||
try:
|
||||
# Clean up response and extract JSON
|
||||
import json
|
||||
import re
|
||||
|
||||
# Find JSON in response - try multiple patterns
|
||||
json_match = re.search(r"\{.*\}", result_text, re.DOTALL)
|
||||
if json_match:
|
||||
json_str = json_match.group()
|
||||
|
||||
# Clean up common JSON issues
|
||||
json_str = re.sub(
|
||||
r",\s*([}\]])", r"\1", json_str
|
||||
) # Remove trailing commas
|
||||
json_str = re.sub(
|
||||
r"([{,])\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*:", r'\1"\2":', json_str
|
||||
) # Quote unquoted keys
|
||||
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
except json.JSONDecodeError as e:
|
||||
# Try to fix common JSON issues
|
||||
logger.warning(f"Initial JSON parsing failed: {e}")
|
||||
|
||||
# Try to extract individual fields using regex
|
||||
vendor_match = re.search(r'"vendor"\s*:\s*"([^"]*)"', json_str)
|
||||
description_match = re.search(
|
||||
r'"description"\s*:\s*"([^"]*)"', json_str
|
||||
)
|
||||
total_amount_match = re.search(
|
||||
r'"total_amount"\s*:\s*([0-9.]+)', json_str
|
||||
)
|
||||
tax_amount_match = re.search(
|
||||
r'"tax_amount"\s*:\s*([0-9.]+)', json_str
|
||||
)
|
||||
date_match = re.search(r'"date"\s*:\s*"([^"]*)"', json_str)
|
||||
category_match = re.search(r'"category"\s*:\s*"([^"]*)"', json_str)
|
||||
confidence_match = re.search(
|
||||
r'"confidence"\s*:\s*([0-9.]+)', json_str
|
||||
)
|
||||
currency_match = re.search(
|
||||
r'"currency"\s*:\s*"([^"]*)"', json_str
|
||||
)
|
||||
|
||||
data = {
|
||||
"vendor": vendor_match.group(1) if vendor_match else "",
|
||||
"description": description_match.group(1)
|
||||
if description_match
|
||||
else "",
|
||||
"total_amount": float(total_amount_match.group(1))
|
||||
if total_amount_match
|
||||
else 0.0,
|
||||
"tax_amount": float(tax_amount_match.group(1))
|
||||
if tax_amount_match
|
||||
else 0.0,
|
||||
"date": date_match.group(1) if date_match else "",
|
||||
"category": category_match.group(1)
|
||||
if category_match
|
||||
else "Other",
|
||||
"confidence": float(confidence_match.group(1))
|
||||
if confidence_match
|
||||
else 0.5,
|
||||
"currency": currency_match.group(1) if currency_match else "CAD"
|
||||
}
|
||||
|
||||
# Validate and clean data
|
||||
return {
|
||||
"vendor": str(data.get("vendor", "")).strip(),
|
||||
"description": str(data.get("description", "")).strip(),
|
||||
"total_amount": float(data.get("total_amount", 0)),
|
||||
"tax_amount": float(data.get("tax_amount", 0)),
|
||||
"date": str(data.get("date", "")).strip(),
|
||||
"category": str(data.get("category", "Other")).strip(),
|
||||
"confidence": float(data.get("confidence", 0.5)),
|
||||
"extraction_success": True,
|
||||
"currency": data.get("currency", "CAD").strip(),
|
||||
}
|
||||
else:
|
||||
# Try to extract fields from plain text
|
||||
logger.warning("No JSON found in response, attempting text extraction")
|
||||
return self._extract_from_plain_text(result_text)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"JSON parsing error: {str(e)}")
|
||||
return {
|
||||
"error": f"JSON parsing error: {str(e)}",
|
||||
"extraction_success": False,
|
||||
}
|
||||
|
||||
def _extract_from_plain_text(self, text: str) -> Dict[str, Any]:
|
||||
"""Extract receipt data from plain text when JSON parsing fails"""
|
||||
try:
|
||||
import re
|
||||
|
||||
# Extract vendor (look for common patterns)
|
||||
vendor_patterns = [
|
||||
r"(?:vendor|store|merchant|company)\s*[:\-]?\s*([A-Za-z0-9\s&.,]+)",
|
||||
r"([A-Z][A-Za-z0-9\s&.,]{3,30})", # Capitalized words
|
||||
]
|
||||
|
||||
vendor = ""
|
||||
for pattern in vendor_patterns:
|
||||
match = re.search(pattern, text, re.IGNORECASE)
|
||||
if match:
|
||||
vendor = match.group(1).strip()
|
||||
break
|
||||
|
||||
# Extract amount (look for currency patterns)
|
||||
amount_patterns = [
|
||||
r"\$?\s*([0-9,]+\.?[0-9]*)",
|
||||
r"(?:total|amount|sum)\s*[:\-]?\s*\$?\s*([0-9,]+\.?[0-9]*)",
|
||||
]
|
||||
|
||||
total_amount = 0.0
|
||||
for pattern in amount_patterns:
|
||||
match = re.search(pattern, text, re.IGNORECASE)
|
||||
if match:
|
||||
try:
|
||||
total_amount = float(match.group(1).replace(",", ""))
|
||||
break
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Extract date
|
||||
date_patterns = [
|
||||
r"(\d{4}-\d{2}-\d{2})",
|
||||
r"(\d{1,2}/\d{1,2}/\d{2,4})",
|
||||
r"(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s+\d{1,2},?\s+\d{4}",
|
||||
]
|
||||
|
||||
date = ""
|
||||
for pattern in date_patterns:
|
||||
match = re.search(pattern, text, re.IGNORECASE)
|
||||
if match:
|
||||
date = match.group(0)
|
||||
break
|
||||
|
||||
return {
|
||||
"vendor": vendor or "Unknown",
|
||||
"total_amount": total_amount,
|
||||
"tax_amount": 0.0,
|
||||
"date": date or "",
|
||||
"category": "Other",
|
||||
"confidence": 0.3, # Low confidence for text extraction
|
||||
"extraction_success": True,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Text extraction error: {str(e)}")
|
||||
return {
|
||||
"vendor": "Unknown",
|
||||
"total_amount": 0.0,
|
||||
"tax_amount": 0.0,
|
||||
"date": "",
|
||||
"category": "Other",
|
||||
"confidence": 0.1,
|
||||
"extraction_success": False,
|
||||
"error": f"Text extraction failed: {str(e)}",
|
||||
}
|
||||
|
||||
async def save_uploaded_file(self, file_content: bytes, filename: str) -> str:
    """Save uploaded file to temporary storage.

    The file is written into the relative ``uploads`` directory (created
    on demand) under ``<YYYYmmdd_HHMMSS>_<name>``.

    Args:
        file_content: Raw bytes of the upload.
        filename: Name supplied with the upload. Only its final path
            component is used, with spaces replaced by underscores.

    Returns:
        The path of the written file.

    Raises:
        Exception: If the directory cannot be created or the file cannot
            be written; the original error is chained as ``__cause__``.
    """
    try:
        # Create uploads directory if it doesn't exist
        upload_dir = "uploads"
        os.makedirs(upload_dir, exist_ok=True)

        # The name comes from the client: drop any directory part (both
        # separator styles) so it cannot point outside upload_dir or into
        # a sub-directory that does not exist.
        base_name = os.path.basename(filename.replace("\\", "/")).replace(" ", "_")
        if base_name in ("", ".", ".."):
            base_name = "upload"

        # Generate unique filename
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        safe_filename = f"{timestamp}_{base_name}"
        file_path = os.path.join(upload_dir, safe_filename)

        # Save file
        async with aiofiles.open(file_path, "wb") as f:
            await f.write(file_content)

        return file_path

    except Exception as e:
        raise Exception(f"Failed to save file: {str(e)}") from e
|
||||
|
||||
async def extract_transactions_from_image(self, image_path: str) -> Dict[str, Any]:
    """Extract multiple transactions from an image (bank statement, credit card statement, etc.)

    The image is base64-encoded with ``_encode_image`` and sent together
    with an extraction prompt in a single user message; the reply is parsed
    by ``_parse_transaction_extraction_result``.

    Args:
        image_path: Path of the image file to analyse.

    Returns:
        The dictionary produced by ``_parse_transaction_extraction_result``,
        or a dict with ``extraction_success`` False, an ``error`` message
        and an empty ``transactions`` list if anything raises.
    """
    try:
        # Encode image to base64
        base64_image = self._encode_image(image_path)

        # Create Groq vision prompt for transaction extraction
        prompt = """
        Analyze this financial document image (bank statement, credit card statement, etc.) and extract ALL transactions in JSON format.

        Look for transaction lists, payment records, or any financial entries that show:
        - Date
        - Amount (positive or negative)
        - Vendor/Description/Payee name
        - Any additional notes or memo

        Return the transactions as a JSON array:
        {
        "extraction_success": true,
        "transactions": [
        {
        "date": "YYYY-MM-DD",
        "amount": 0.00,
        "vendor": "Vendor name",
        "memo": "Additional notes"
        },
        {
        "date": "YYYY-MM-DD",
        "amount": -0.00,
        "vendor": "Another vendor",
        "memo": "Payment or charge description"
        }
        ]
        }

        Rules:
        - Extract ALL visible transactions
        - Include both positive (credits) and negative (debits) amounts
        - Use the actual date format from the document
        - Vendor should be the merchant/payee name
        - Memo can include transaction type, reference numbers, etc.
        - If no transactions found, return empty array but set extraction_success to true

        Return only valid JSON.
        """

        # Call Groq vision API
        # NOTE(review): the data URL always declares image/jpeg, whatever
        # the actual type of the file at image_path - confirm that other
        # image formats are handled as intended.
        response = self.client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}",
                            },
                        },
                    ],
                }
            ],
            model=self.model,
            max_tokens=2000,  # Higher token limit for multiple transactions
            temperature=0.1,
        )

        # Parse response (only the first choice is used)
        result_text = response.choices[0].message.content.strip()
        return self._parse_transaction_extraction_result(result_text)

    except Exception as e:
        # Every failure above (file read, API call, missing content) is
        # reported in the same shape as a parse failure.
        return {
            "extraction_success": False,
            "error": f"Transaction extraction error: {str(e)}",
            "transactions": [],
        }
|
||||
|
||||
def _parse_transaction_extraction_result(self, result_text: str) -> Dict[str, Any]:
|
||||
"""Parse Groq response for transaction extraction"""
|
||||
try:
|
||||
import json
|
||||
import re
|
||||
|
||||
# Find the first '{' and last '}'
|
||||
start = result_text.find("{")
|
||||
end = result_text.rfind("}")
|
||||
if start == -1 or end == -1 or end <= start:
|
||||
return {
|
||||
"extraction_success": False,
|
||||
"error": "Could not find JSON object in AI response",
|
||||
"transactions": [],
|
||||
}
|
||||
json_str = result_text[start : end + 1]
|
||||
|
||||
# Remove trailing commas before } or ]
|
||||
json_str = re.sub(r",\s*([}\]])", r"\1", json_str)
|
||||
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.error(f"JSON parsing error: {str(e)}")
|
||||
logging.error(f"Offending JSON string:\n{json_str}")
|
||||
return {
|
||||
"extraction_success": False,
|
||||
"error": f"JSON parsing error: {str(e)}",
|
||||
"transactions": [],
|
||||
}
|
||||
|
||||
# Validate and clean data
|
||||
transactions = data.get("transactions", [])
|
||||
cleaned_transactions = []
|
||||
for txn in transactions:
|
||||
try:
|
||||
cleaned_txn = {
|
||||
"date": str(txn.get("date", "")).strip(),
|
||||
"amount": float(
|
||||
str(txn.get("amount", 0)).replace("$", "").replace(",", "")
|
||||
),
|
||||
"vendor": str(txn.get("vendor", "")).strip(),
|
||||
"memo": str(txn.get("memo", "")).strip(),
|
||||
}
|
||||
cleaned_transactions.append(cleaned_txn)
|
||||
except Exception:
|
||||
continue
|
||||
return {
|
||||
"extraction_success": data.get("extraction_success", True),
|
||||
"transactions": cleaned_transactions,
|
||||
"total_transactions": len(cleaned_transactions),
|
||||
}
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logging.error(f"JSON parsing error (outer): {str(e)}")
|
||||
return {
|
||||
"extraction_success": False,
|
||||
"error": f"JSON parsing error: {str(e)}",
|
||||
"transactions": [],
|
||||
}
|
||||
|
||||
def _parse_date_to_iso(self, date_str: str) -> str:
|
||||
"""Parse various date formats and convert to YYYY-MM-DD"""
|
||||
try:
|
||||
import re
|
||||
from datetime import datetime
|
||||
|
||||
date_str = date_str.strip().upper()
|
||||
|
||||
# Handle formats like "MAY 22", "JUN 01", "MAY 22, 2024"
|
||||
month_pattern = r"(JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC)\s+(\d{1,2})(?:,\s*(\d{4}))?"
|
||||
match = re.match(month_pattern, date_str)
|
||||
|
||||
if match:
|
||||
month_abbr, day, year = match.groups()
|
||||
month_map = {
|
||||
"JAN": 1,
|
||||
"FEB": 2,
|
||||
"MAR": 3,
|
||||
"APR": 4,
|
||||
"MAY": 5,
|
||||
"JUN": 6,
|
||||
"JUL": 7,
|
||||
"AUG": 8,
|
||||
"SEP": 9,
|
||||
"OCT": 10,
|
||||
"NOV": 11,
|
||||
"DEC": 12,
|
||||
}
|
||||
|
||||
month = month_map[month_abbr]
|
||||
day = int(day)
|
||||
year = int(year) if year else datetime.now().year
|
||||
|
||||
# Handle 2-digit years
|
||||
if year < 100:
|
||||
year += 2000
|
||||
|
||||
return f"{year:04d}-{month:02d}-{day:02d}"
|
||||
|
||||
# Handle YYYY-MM-DD format
|
||||
if re.match(r"\d{4}-\d{2}-\d{2}", date_str):
|
||||
return date_str
|
||||
|
||||
# Handle MM/DD/YYYY format
|
||||
if re.match(r"\d{1,2}/\d{1,2}/\d{4}", date_str):
|
||||
return datetime.strptime(date_str, "%m/%d/%Y").strftime("%Y-%m-%d")
|
||||
|
||||
# Handle MM/DD/YY format
|
||||
if re.match(r"\d{1,2}/\d{1,2}/\d{2}", date_str):
|
||||
return datetime.strptime(date_str, "%m/%d/%y").strftime("%Y-%m-%d")
|
||||
|
||||
return None
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
@@ -0,0 +1,76 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List
|
||||
|
||||
|
||||
@dataclass
class FeedbackLog:
    """One recorded user decision about a match."""

    transaction_id: str
    original_match: str
    correction: str
    reason: str
    timestamp: datetime
    user_id: str


class FeedbackLogger:
    """Keeps feedback entries in memory and mirrors them to a JSON file.

    The whole list is rewritten to ``log_file`` after every new entry.
    """

    def __init__(self, log_file: str = "feedback_logs.json"):
        """Remember ``log_file`` and load whatever entries it already holds."""
        self.log_file = log_file
        self.logs: List[FeedbackLog] = self._load_logs()

    def _load_logs(self) -> List[FeedbackLog]:
        """Read all entries from ``log_file``.

        Returns:
            The stored entries, or an empty list when the file does not
            exist or cannot be read/decoded.
        """
        if not os.path.exists(self.log_file):
            return []

        try:
            with open(self.log_file, "r", encoding="utf-8") as f:
                data = json.load(f)

            logs = []
            for entry in data:
                fields = dict(entry)
                # _save_logs stores the timestamp as ISO-8601 text; turn it
                # back into a datetime so isoformat() and the date
                # comparison in get_recent_logs keep working after a reload.
                fields["timestamp"] = datetime.fromisoformat(fields["timestamp"])
                logs.append(FeedbackLog(**fields))
            return logs
        except Exception:
            # Best effort: an unreadable or malformed file is treated as
            # "no history". NOTE(review): the next save then overwrites
            # that file - confirm this is intended.
            return []

    def _save_logs(self):
        """Write every in-memory entry to ``log_file`` as indented JSON."""
        payload = [
            {
                "transaction_id": log.transaction_id,
                "original_match": log.original_match,
                "correction": log.correction,
                "reason": log.reason,
                "timestamp": log.timestamp.isoformat(),
                "user_id": log.user_id,
            }
            for log in self.logs
        ]
        with open(self.log_file, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)

    def log_override(
        self,
        transaction_id: str,
        original_match: str,
        correction: str,
        reason: str,
        user_id: str,
    ):
        """Append a new entry stamped with the current local time and persist it."""
        log = FeedbackLog(
            transaction_id=transaction_id,
            original_match=original_match,
            correction=correction,
            reason=reason,
            timestamp=datetime.now(),
            user_id=user_id,
        )
        self.logs.append(log)
        self._save_logs()

    def get_logs_by_transaction(self, transaction_id: str) -> List[FeedbackLog]:
        """Return the entries whose ``transaction_id`` equals the given one."""
        return [log for log in self.logs if log.transaction_id == transaction_id]

    def get_recent_logs(self, days: int = 30) -> List[FeedbackLog]:
        """Return the entries recorded within the last ``days`` days."""
        cutoff = datetime.now() - timedelta(days=days)
        return [log for log in self.logs if log.timestamp > cutoff]
|
||||
@@ -0,0 +1,89 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from services.ai_matcher import AIMatcher
|
||||
from services.ai_rules import AIRulesEngine
|
||||
from services.feedback_logger import FeedbackLogger
|
||||
from schemas import Match, Receipt, Transaction
|
||||
|
||||
|
||||
class MatchingEngine:
    """Combines AI matching, rule-based adjustment and feedback logging."""

    def __init__(self):
        """Create the matcher, the rules engine and the feedback logger."""
        self.ai_matcher = AIMatcher()
        self.rules_engine = AIRulesEngine()
        self.feedback_logger = FeedbackLogger()

    def process_matching(
        self, receipts: List[Receipt], transactions: List[Transaction]
    ) -> List[Match]:
        """Match receipts to transactions, then run every match through the rules."""
        candidates = self.ai_matcher.match_receipts_to_transactions(
            receipts, transactions
        )
        return [self._enhance_match_with_rules(candidate) for candidate in candidates]

    def _enhance_match_with_rules(self, match: Match) -> Match:
        """Adjust ``match`` in place according to the rules engine and return it."""
        outcome = self.rules_engine.apply_rules(match.receipt, match.transaction)

        # A positive boost raises the score, capped at 1.0
        boost = outcome["confidence_boost"]
        if boost > 0:
            match.confidence_score = min(1.0, match.confidence_score + boost)

        # Auto-approval overrides the score and is noted in the reason
        if outcome["auto_approve"]:
            match.confidence_score = 1.0
            match.match_reason += " (Auto-approved by rules)"

        # Attach the tax analysis only when the rules produced one
        tax_analysis = outcome.get("tax_analysis")
        if tax_analysis:
            match.tax_analysis = tax_analysis

        return match

    def _log_decision(self, match: Match, correction: str, reason: str, user_id: str):
        """Record a user decision about ``match`` in the feedback log."""
        self.feedback_logger.log_override(
            transaction_id=match.transaction.id,
            original_match=f"AI Score: {match.confidence_score}",
            correction=correction,
            reason=reason,
            user_id=user_id,
        )

    def approve_match(self, match: Match, user_id: str):
        """Log that ``user_id`` approved ``match``."""
        self._log_decision(match, "Approved", "User approved match", user_id)

    def reject_match(self, match: Match, reason: str, user_id: str):
        """Log that ``user_id`` rejected ``match`` for ``reason``."""
        self._log_decision(match, "Rejected", reason, user_id)

    def get_matching_stats(self, matches: List[Match]) -> Dict[str, Any]:
        """Summarise ``matches``: count, high/low confidence split at 0.8, mean score."""
        if not matches:
            return {
                "total": 0,
                "high_confidence": 0,
                "low_confidence": 0,
                "avg_score": 0,
            }

        scores = [m.confidence_score for m in matches]
        high_count = sum(1 for score in scores if score >= 0.8)
        low_count = sum(1 for score in scores if score < 0.8)
        mean_score = sum(scores) / len(matches)

        return {
            "total": len(matches),
            "high_confidence": high_count,
            "low_confidence": low_count,
            "avg_score": round(mean_score, 3),
        }
|
||||
@@ -0,0 +1,276 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from schemas import Address, Asset, Receipt, Transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaxRulesEngine:
    """Engine to handle tax calculations based on the four tax rules.

    Every public method returns a plain dict with a ``success`` flag
    (``apply_all_tax_rules`` returns a dict of such results); errors are
    logged and reported in the dict instead of being raised.
    """

    # Combined sales-tax rate by province/territory code
    # (simplified - in production, use a tax rate API)
    PROVINCIAL_TAX_RATES = {
        "ON": 0.13,  # Ontario HST
        "QC": 0.14975,  # Quebec GST + QST
        "BC": 0.12,  # British Columbia
        "AB": 0.05,  # Alberta
        "SK": 0.11,  # Saskatchewan
        "MB": 0.12,  # Manitoba
        "NS": 0.15,  # Nova Scotia
        "NB": 0.15,  # New Brunswick
        "NL": 0.15,  # Newfoundland and Labrador
        "PE": 0.15,  # Prince Edward Island
        "NT": 0.05,  # Northwest Territories
        "NU": 0.05,  # Nunavut
        "YT": 0.05,  # Yukon
    }

    def __init__(self):
        self.logger = logging.getLogger(__name__)

    def apply_sales_tax_rule(self, receipt: Receipt) -> Dict[str, Any]:
        """
        Sales Tax Rule: Apply correct sales tax based on billing vs shipping addresses

        The rate is looked up by the province of the address chosen by
        ``_get_tax_address`` and multiplied by ``receipt.amount``. A province
        missing from the rate table yields a rate of 0.0 (and a warning).

        Returns:
            ``success``, ``calculated_tax``, ``tax_rate``, plus
            ``tax_address`` and ``rule_applied`` on success or ``error`` on
            failure.
        """
        try:
            # Determine which address to use for tax calculation
            tax_address = self._get_tax_address(receipt)

            if not tax_address:
                return {
                    "success": False,
                    "error": "No valid address found for tax calculation",
                    "calculated_tax": 0.0,
                    "tax_rate": 0.0,
                }

            # Province codes are matched regardless of case and surrounding
            # whitespace, so "on" or " ON " no longer falls through to 0%.
            province = tax_address.province
            province_key = (
                province.strip().upper() if isinstance(province, str) else province
            )
            if province_key not in self.PROVINCIAL_TAX_RATES:
                self.logger.warning(
                    "No tax rate configured for province %r; using 0.0", province
                )
            tax_rate = self.PROVINCIAL_TAX_RATES.get(province_key, 0.0)

            # Calculate tax amount
            calculated_tax = receipt.amount * tax_rate

            return {
                "success": True,
                "calculated_tax": calculated_tax,
                "tax_rate": tax_rate,
                "tax_address": tax_address.province,
                "rule_applied": "Sales Tax Rule",
            }

        except Exception as e:
            self.logger.error(f"Error applying sales tax rule: {str(e)}")
            return {
                "success": False,
                "error": str(e),
                "calculated_tax": 0.0,
                "tax_rate": 0.0,
            }

    def _get_tax_address(self, receipt: Receipt) -> Optional[Address]:
        """Determine which address to use for tax calculation.

        Rule: use the shipping address if it differs from billing, otherwise
        billing; if only one of them is present use that one; ``None`` if
        neither is.
        """
        billing = receipt.billing_address
        shipping = receipt.shipping_address

        if shipping and billing:
            if self._addresses_different(billing, shipping):
                return shipping
            return billing
        return shipping or billing or None

    def _addresses_different(self, billing: Address, shipping: Address) -> bool:
        """Check if billing and shipping addresses are different.

        Only province, city and postal code are compared.
        """
        return (
            billing.province != shipping.province
            or billing.city != shipping.city
            or billing.postal_code != shipping.postal_code
        )

    def apply_fx_rule(
        self, receipt: Receipt, transaction: Transaction
    ) -> Dict[str, Any]:
        """
        Foreign Exchange Rule: Handle currency mismatches

        When the currencies differ, the result is flagged for manual review.
        ``fx_discrepancy`` is the absolute difference of the two raw amounts;
        no exchange rate is applied, so it compares figures in different
        currencies.
        """
        try:
            # Check for currency mismatch
            if receipt.currency != transaction.currency:
                fx_discrepancy = abs(receipt.amount - abs(transaction.amount))

                return {
                    "success": True,
                    "fx_discrepancy": fx_discrepancy,
                    "receipt_currency": receipt.currency,
                    "transaction_currency": transaction.currency,
                    "receipt_amount": receipt.amount,
                    "transaction_amount": abs(transaction.amount),
                    "requires_manual_review": True,
                    "rule_applied": "Foreign Exchange Rule",
                }

            return {
                "success": True,
                "fx_discrepancy": 0.0,
                "requires_manual_review": False,
                "rule_applied": "No FX Rule (same currency)",
            }

        except Exception as e:
            self.logger.error(f"Error applying FX rule: {str(e)}")
            return {
                "success": False,
                "error": str(e),
                "fx_discrepancy": 0.0,
                "requires_manual_review": False,
            }

    def calculate_straight_line_depreciation(
        self, asset: Asset, year: int
    ) -> Dict[str, Any]:
        """
        Straight-Line Depreciation for accounting purposes

        Args:
            asset: Provides ``purchase_amount``, ``residual_value`` and
                ``useful_life_years``.
            year: 1-based year of the asset's life.

        Returns:
            ``depreciation`` (the constant annual amount) and ``book_value``
            at the end of ``year``; ``success`` False with an ``error`` when
            ``year`` is below 1 or beyond the useful life.
        """
        try:
            # Same lower bound as the CCA method; year 0 or a negative year
            # would otherwise report a book value above the purchase price.
            if year < 1:
                return {
                    "success": False,
                    "error": "Year must be at least 1",
                    "depreciation": 0.0,
                }

            if year > asset.useful_life_years:
                return {
                    "success": False,
                    "error": f"Year {year} exceeds useful life of {asset.useful_life_years} years",
                    "depreciation": 0.0,
                }

            # Straight-line formula: (Cost - Residual Value) / Useful Life
            annual_depreciation = (
                asset.purchase_amount - asset.residual_value
            ) / asset.useful_life_years

            return {
                "success": True,
                "depreciation": annual_depreciation,
                "book_value": asset.purchase_amount - (annual_depreciation * year),
                "method": "Straight-Line",
                "rule_applied": "Depreciation Rule (Accounting)",
            }

        except Exception as e:
            self.logger.error(f"Error calculating straight-line depreciation: {str(e)}")
            return {"success": False, "error": str(e), "depreciation": 0.0}

    def calculate_cca_depreciation(self, asset: Asset, year: int) -> Dict[str, Any]:
        """
        CCA (Capital Cost Allowance) Depreciation for tax purposes

        Declining-balance model: each year ``asset.cca_rate`` of the
        remaining book value is written off, but never below
        ``asset.residual_value``.

        Args:
            asset: Provides ``purchase_amount``, ``residual_value`` and
                ``cca_rate``.
            year: 1-based year to report.

        Returns:
            ``depreciation`` for ``year`` itself (0.0 once the residual value
            has been reached), ``total_depreciation`` up to and including
            ``year`` and the resulting ``book_value``.
        """
        try:
            if year < 1:
                return {
                    "success": False,
                    "error": "Year must be at least 1",
                    "depreciation": 0.0,
                }

            # CCA uses declining balance method
            book_value = asset.purchase_amount
            total_depreciation = 0.0
            cca_amount = 0.0

            for _ in range(year):
                remaining = book_value - asset.residual_value
                if remaining <= 0:
                    # Fully depreciated before the requested year: nothing
                    # is written off in that year.
                    cca_amount = 0.0
                    break

                # Cap the final write-off so the book value stops exactly at
                # the residual value and total + book value still add up to
                # the purchase amount.
                cca_amount = min(book_value * asset.cca_rate, remaining)
                book_value -= cca_amount
                total_depreciation += cca_amount

            return {
                "success": True,
                "depreciation": cca_amount,  # Current year depreciation
                "total_depreciation": total_depreciation,
                "book_value": max(book_value, asset.residual_value),
                "method": "CCA Declining Balance",
                "rule_applied": "Depreciation Rule (Tax)",
            }

        except Exception as e:
            self.logger.error(f"Error calculating CCA depreciation: {str(e)}")
            return {"success": False, "error": str(e), "depreciation": 0.0}

    def apply_meals_entertainment_rule(self, receipt: Receipt) -> Dict[str, Any]:
        """
        Meals & Entertainment Tax Deduction Rule

        Receipts not flagged ``is_meals_entertainment`` are fully deductible
        for both purposes. Flagged receipts are 50% deductible for tax and
        100% for accounting; ``receipt.tax`` is passed through as
        ``tax_on_meal``.
        """
        try:
            if not receipt.is_meals_entertainment:
                return {
                    "success": True,
                    "tax_deduction": receipt.amount,
                    "accounting_deduction": receipt.amount,
                    "rule_applied": "No M&E Rule (not meals/entertainment)",
                }

            return {
                "success": True,
                # For tax purposes: 50% deductible
                "tax_deduction": receipt.amount * 0.5,
                # For accounting purposes: 100% deductible
                "accounting_deduction": receipt.amount,
                # Sales tax is fully deductible for accounting
                "tax_on_meal": receipt.tax,
                "rule_applied": "Meals & Entertainment Rule",
            }

        except Exception as e:
            self.logger.error(f"Error applying meals & entertainment rule: {str(e)}")
            return {
                "success": False,
                "error": str(e),
                "tax_deduction": 0.0,
                "accounting_deduction": 0.0,
            }

    def apply_all_tax_rules(
        self, receipt: Receipt, transaction: Optional[Transaction] = None
    ) -> Dict[str, Any]:
        """
        Apply all tax rules to a receipt

        Args:
            receipt: Receipt to analyse.
            transaction: Optional matched transaction; the FX rule runs only
                when one is given.

        Returns:
            ``receipt_id``, the per-rule result dicts (``sales_tax``,
            ``fx_analysis``, ``meals_entertainment``) and ``rules_applied``,
            the names of the rules whose evaluation succeeded. NOTE: a rule
            is listed whenever it ran without error, even if its own
            ``rule_applied`` label says it did not apply (same currency,
            not a meal).
        """
        results = {
            "receipt_id": receipt.id,
            "rules_applied": [],
            "sales_tax": {},
            "fx_analysis": {},
            "meals_entertainment": {},
        }

        # Apply Sales Tax Rule
        sales_tax_result = self.apply_sales_tax_rule(receipt)
        results["sales_tax"] = sales_tax_result
        if sales_tax_result["success"]:
            results["rules_applied"].append("Sales Tax Rule")

        # Apply FX Rule (if transaction provided). Identity check rather
        # than truthiness, so a falsy-but-present object is still analysed.
        if transaction is not None:
            fx_result = self.apply_fx_rule(receipt, transaction)
            results["fx_analysis"] = fx_result
            if fx_result["success"]:
                results["rules_applied"].append("Foreign Exchange Rule")

        # Apply Meals & Entertainment Rule
        me_result = self.apply_meals_entertainment_rule(receipt)
        results["meals_entertainment"] = me_result
        if me_result["success"]:
            results["rules_applied"].append("Meals & Entertainment Rule")

        return results
|
||||
Reference in New Issue
Block a user