Add user location support and tax analysis enhancements
- Introduced user location extraction from user tax info for improved matching. - Normalized user location to province codes for tax calculations. - Updated MatchResponse schema to include tax analysis data. - Enhanced LLMTaxAnalyzer to handle various location formats and provide fallback logic.
This commit is contained in:
+26
-8
@@ -5,10 +5,6 @@ import uuid
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import (
|
||||
DBReceipt,
|
||||
DBTransaction,
|
||||
@@ -16,6 +12,8 @@ from database import (
|
||||
create_db_tables,
|
||||
db_dependency,
|
||||
)
|
||||
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from schemas import (
|
||||
DocumentProcessResponse,
|
||||
DocumentUploadResponse,
|
||||
@@ -29,6 +27,7 @@ from schemas import (
|
||||
from services.ai_rules import AIRule
|
||||
from services.document_processor import DocumentProcessor
|
||||
from services.matching_engine import MatchingEngine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
create_db_tables()
|
||||
|
||||
@@ -410,7 +409,7 @@ async def process_document(file_id: str, db: db_dependency):
|
||||
confidence=receipt_data.get("confidence", 0.0),
|
||||
extraction_success=str(receipt_data.get("extraction_success", False)),
|
||||
error_message=receipt_data.get("error"),
|
||||
receipt_currency=receipt_data.get("currency")
|
||||
receipt_currency=receipt_data.get("currency"),
|
||||
)
|
||||
|
||||
# Add to database
|
||||
@@ -429,7 +428,7 @@ async def process_document(file_id: str, db: db_dependency):
|
||||
category=receipt_data.get("category", ""),
|
||||
confidence=receipt_data.get("confidence", 0.0),
|
||||
error=receipt_data.get("error", None),
|
||||
receipt_currency=receipt_data.get("currency")
|
||||
receipt_currency=receipt_data.get("currency"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -536,13 +535,31 @@ async def match_specific_receipts(request: MatchSpecificRequest, db: db_dependen
|
||||
f"Starting matching with {len(receipts)} receipts and {len(transactions)} transactions"
|
||||
)
|
||||
|
||||
# Extract user location from user_tax_info if provided
|
||||
user_location = request.user_location # Default/fallback
|
||||
if request.user_tax_info:
|
||||
# Use state_code from user_tax_info (e.g., "ON", "QC", "BC")
|
||||
user_location = request.user_tax_info.state.state_code
|
||||
logger.info(
|
||||
f"Using location from user_tax_info: {user_location} ({request.user_tax_info.state.name}, {request.user_tax_info.country.name})"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Using default/provided user_location: {user_location}")
|
||||
|
||||
try:
|
||||
matching_results = matching_engine.process_matching(receipts, transactions, user_location=request.user_location)
|
||||
matching_results = matching_engine.process_matching(
|
||||
receipts, transactions, user_location=user_location
|
||||
)
|
||||
logger.info(f"Matching completed, got {len(matching_results)} results")
|
||||
|
||||
# Convert matching results to response format
|
||||
match_responses = []
|
||||
for result in matching_results:
|
||||
# Get final tax amount from LLM analysis if available, otherwise use receipt's stated tax
|
||||
final_tax = result.receipt.tax
|
||||
if result.tax_analysis and "final_tax_amount" in result.tax_analysis:
|
||||
final_tax = result.tax_analysis["final_tax_amount"]
|
||||
|
||||
match_response = MatchResponse(
|
||||
receipt_id=result.receipt.id,
|
||||
transaction_id=result.transaction.id
|
||||
@@ -554,13 +571,14 @@ async def match_specific_receipts(request: MatchSpecificRequest, db: db_dependen
|
||||
receipt_amount=result.receipt.amount,
|
||||
receipt_description=result.receipt.description,
|
||||
receipt_category=result.receipt.category,
|
||||
receipt_tax_amount=result.receipt.tax,
|
||||
receipt_tax_amount=final_tax,
|
||||
transaction_vendor=result.transaction.vendor
|
||||
if result.transaction
|
||||
else "",
|
||||
transaction_amount=result.transaction.amount
|
||||
if result.transaction
|
||||
else 0.0,
|
||||
tax_analysis=result.tax_analysis,
|
||||
)
|
||||
match_responses.append(match_response)
|
||||
|
||||
|
||||
+62
-1
@@ -131,6 +131,7 @@ class MatchResponse(BaseModel):
|
||||
receipt_tax_amount: float
|
||||
transaction_vendor: str
|
||||
transaction_amount: float
|
||||
tax_analysis: Optional[dict] = None
|
||||
|
||||
|
||||
class MatchingResponse(BaseModel):
|
||||
@@ -205,7 +206,67 @@ class DepreciationResponse(BaseModel):
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class CityInfo(BaseModel):
|
||||
"""City information from user tax info"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
state_id: int
|
||||
state_code: str
|
||||
country_id: int
|
||||
country_code: str
|
||||
latitude: Optional[str] = None
|
||||
longitude: Optional[str] = None
|
||||
|
||||
|
||||
class StateInfo(BaseModel):
|
||||
"""State/Province information from user tax info"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
country_id: int
|
||||
country_code: str
|
||||
state_code: str
|
||||
|
||||
|
||||
class CountryInfo(BaseModel):
|
||||
"""Country information from user tax info"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
iso3: str
|
||||
iso2: str
|
||||
phone_code: str
|
||||
capital: str
|
||||
currency: str
|
||||
native: Optional[str] = None
|
||||
region: Optional[str] = None
|
||||
subregion: Optional[str] = None
|
||||
emoji: Optional[str] = None
|
||||
emojiU: Optional[str] = None
|
||||
|
||||
|
||||
class UserTaxInfo(BaseModel):
|
||||
"""User tax information for location-based tax calculations"""
|
||||
|
||||
id: int
|
||||
user_id: int
|
||||
company_name: str
|
||||
tax_id: Optional[str] = ""
|
||||
tax_id_type: Optional[str] = "EIN"
|
||||
address_line_1: Optional[str] = ""
|
||||
address_line_2: Optional[str] = ""
|
||||
city: CityInfo
|
||||
state: StateInfo
|
||||
zip_postal_code: Optional[str] = ""
|
||||
country: CountryInfo
|
||||
include_on_invoices: Optional[int] = 1
|
||||
created_at: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
|
||||
|
||||
class MatchSpecificRequest(BaseModel):
|
||||
file_ids: List[str]
|
||||
categorization_id: str
|
||||
user_location: Optional[str] = "Canada"
|
||||
user_location: Optional[str] = "Canada" # Kept for backward compatibility
|
||||
user_tax_info: Optional[UserTaxInfo] = None
|
||||
|
||||
@@ -87,7 +87,13 @@ class LLMTaxAnalyzer:
|
||||
|
||||
# Extract location information
|
||||
receipt_location = self._extract_receipt_location(receipt)
|
||||
user_province = user_location.upper()
|
||||
|
||||
# Normalize user_location to province code (handle "Canada", "Ontario", "ON", etc.)
|
||||
user_province = self._normalize_location_to_province(user_location)
|
||||
|
||||
logger.info(
|
||||
f"Building tax analysis context - User Location: {user_location} → Province Code: {user_province}"
|
||||
)
|
||||
|
||||
# Build tax rates reference
|
||||
tax_rates_info = json.dumps(self.PROVINCIAL_TAX_RATES, indent=2)
|
||||
@@ -130,6 +136,47 @@ CCA DEPRECIATION RATES BY ASSET CLASS:
|
||||
"""
|
||||
return context
|
||||
|
||||
def _normalize_location_to_province(self, location: str) -> str:
|
||||
"""
|
||||
Normalize various location formats to province code.
|
||||
Handles: "ON", "Ontario", "Canada", etc.
|
||||
"""
|
||||
location_upper = location.upper().strip()
|
||||
|
||||
# Direct province code match
|
||||
if location_upper in self.PROVINCIAL_TAX_RATES:
|
||||
return location_upper
|
||||
|
||||
# Map full province names to codes
|
||||
province_name_map = {
|
||||
"ONTARIO": "ON",
|
||||
"QUEBEC": "QC",
|
||||
"BRITISH COLUMBIA": "BC",
|
||||
"ALBERTA": "AB",
|
||||
"SASKATCHEWAN": "SK",
|
||||
"MANITOBA": "MB",
|
||||
"NOVA SCOTIA": "NS",
|
||||
"NEW BRUNSWICK": "NB",
|
||||
"NEWFOUNDLAND AND LABRADOR": "NL",
|
||||
"NEWFOUNDLAND": "NL",
|
||||
"PRINCE EDWARD ISLAND": "PE",
|
||||
"NORTHWEST TERRITORIES": "NT",
|
||||
"NUNAVUT": "NU",
|
||||
"YUKON": "YT",
|
||||
}
|
||||
|
||||
if location_upper in province_name_map:
|
||||
return province_name_map[location_upper]
|
||||
|
||||
# Default to Ontario if country is Canada or unspecified
|
||||
if location_upper in ["CANADA", "CAN", "CA", ""]:
|
||||
logger.warning(f"Location '{location}' is too generic, defaulting to ON")
|
||||
return "ON"
|
||||
|
||||
# If nothing matches, default to Ontario
|
||||
logger.warning(f"Could not parse location '{location}', defaulting to ON")
|
||||
return "ON"
|
||||
|
||||
def _extract_receipt_location(self, receipt: Receipt) -> str:
|
||||
"""Extract and format receipt location information"""
|
||||
|
||||
@@ -276,6 +323,7 @@ b) **CCA Depreciation** (for tax purposes - Canada):
|
||||
Provide a structured JSON response with the following format:
|
||||
|
||||
{{
|
||||
"final_tax_amount": XX.XX,
|
||||
"sales_tax": {{
|
||||
"applicable_province": "XX",
|
||||
"applicable_rate": 0.XX,
|
||||
@@ -321,6 +369,8 @@ Provide a structured JSON response with the following format:
|
||||
"overall_assessment": "Comprehensive summary: which rules applied, why, what location used for what purpose, and any required actions"
|
||||
}}
|
||||
|
||||
**IMPORTANT**: The "final_tax_amount" field at the top level must contain the final calculated tax amount. This should be the calculated_tax from sales_tax analysis. If this is a meals & entertainment expense, ensure you return the FULL tax amount here (not the 50% adjusted amount).
|
||||
|
||||
**Critical Reminders**:
|
||||
- Sales tax uses RECEIPT location (or user location if receipt has none)
|
||||
- Depreciation ALWAYS uses USER location
|
||||
@@ -356,6 +406,7 @@ Provide a structured JSON response with the following format:
|
||||
"""Return fallback analysis if LLM fails"""
|
||||
return json.dumps(
|
||||
{
|
||||
"final_tax_amount": 0.0,
|
||||
"sales_tax": {
|
||||
"applicable_province": "ON",
|
||||
"applicable_rate": 0.13,
|
||||
@@ -424,6 +475,7 @@ Provide a structured JSON response with the following format:
|
||||
|
||||
# Return structured fallback
|
||||
return {
|
||||
"final_tax_amount": receipt.tax if receipt.tax else 0.0,
|
||||
"sales_tax": {
|
||||
"requires_review": True,
|
||||
"reason": "Failed to parse LLM response",
|
||||
|
||||
Reference in New Issue
Block a user