Add user location support and tax analysis enhancements

- Introduced user location extraction from user tax info for improved matching.
- Normalized user location to province codes for tax calculations.
- Updated MatchResponse schema to include tax analysis data.
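- Enhanced LLMTaxAnalyzer to normalize various location formats (province codes, full province names, or a generic "Canada") and to fall back to Ontario (ON) when the location cannot be resolved.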
This commit is contained in:
bolade
2025-10-05 18:34:35 +01:00
parent c78c4c6fe9
commit c45e3fa791
3 changed files with 141 additions and 10 deletions
+26 -8
View File
@@ -5,10 +5,6 @@ import uuid
from datetime import datetime
from typing import List
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.orm import Session
from database import (
DBReceipt,
DBTransaction,
@@ -16,6 +12,8 @@ from database import (
create_db_tables,
db_dependency,
)
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from schemas import (
DocumentProcessResponse,
DocumentUploadResponse,
@@ -29,6 +27,7 @@ from schemas import (
from services.ai_rules import AIRule
from services.document_processor import DocumentProcessor
from services.matching_engine import MatchingEngine
from sqlalchemy.orm import Session
create_db_tables()
@@ -410,7 +409,7 @@ async def process_document(file_id: str, db: db_dependency):
confidence=receipt_data.get("confidence", 0.0),
extraction_success=str(receipt_data.get("extraction_success", False)),
error_message=receipt_data.get("error"),
receipt_currency=receipt_data.get("currency")
receipt_currency=receipt_data.get("currency"),
)
# Add to database
@@ -429,7 +428,7 @@ async def process_document(file_id: str, db: db_dependency):
category=receipt_data.get("category", ""),
confidence=receipt_data.get("confidence", 0.0),
error=receipt_data.get("error", None),
receipt_currency=receipt_data.get("currency")
receipt_currency=receipt_data.get("currency"),
)
except Exception as e:
@@ -536,13 +535,31 @@ async def match_specific_receipts(request: MatchSpecificRequest, db: db_dependen
f"Starting matching with {len(receipts)} receipts and {len(transactions)} transactions"
)
# Extract user location from user_tax_info if provided
user_location = request.user_location # Default/fallback
if request.user_tax_info:
# Use state_code from user_tax_info (e.g., "ON", "QC", "BC")
user_location = request.user_tax_info.state.state_code
logger.info(
f"Using location from user_tax_info: {user_location} ({request.user_tax_info.state.name}, {request.user_tax_info.country.name})"
)
else:
logger.info(f"Using default/provided user_location: {user_location}")
try:
matching_results = matching_engine.process_matching(receipts, transactions, user_location=request.user_location)
matching_results = matching_engine.process_matching(
receipts, transactions, user_location=user_location
)
logger.info(f"Matching completed, got {len(matching_results)} results")
# Convert matching results to response format
match_responses = []
for result in matching_results:
# Get final tax amount from LLM analysis if available, otherwise use receipt's stated tax
final_tax = result.receipt.tax
if result.tax_analysis and "final_tax_amount" in result.tax_analysis:
final_tax = result.tax_analysis["final_tax_amount"]
match_response = MatchResponse(
receipt_id=result.receipt.id,
transaction_id=result.transaction.id
@@ -554,13 +571,14 @@ async def match_specific_receipts(request: MatchSpecificRequest, db: db_dependen
receipt_amount=result.receipt.amount,
receipt_description=result.receipt.description,
receipt_category=result.receipt.category,
receipt_tax_amount=result.receipt.tax,
receipt_tax_amount=final_tax,
transaction_vendor=result.transaction.vendor
if result.transaction
else "",
transaction_amount=result.transaction.amount
if result.transaction
else 0.0,
tax_analysis=result.tax_analysis,
)
match_responses.append(match_response)
+62 -1
View File
@@ -131,6 +131,7 @@ class MatchResponse(BaseModel):
receipt_tax_amount: float
transaction_vendor: str
transaction_amount: float
tax_analysis: Optional[dict] = None
class MatchingResponse(BaseModel):
@@ -205,7 +206,67 @@ class DepreciationResponse(BaseModel):
error: Optional[str] = None
class CityInfo(BaseModel):
"""City information from user tax info"""
# Schema of the city object nested under UserTaxInfo.city.
# Every id/name/code field below is required (no default).
id: int
name: str
state_id: int
state_code: str
country_id: int
country_code: str
# Coordinates are typed as strings (not floats) and may be omitted.
latitude: Optional[str] = None
longitude: Optional[str] = None
class StateInfo(BaseModel):
"""State/Province information from user tax info"""
# Schema of the state/province object nested under UserTaxInfo.state.
# All fields are required (no defaults).
id: int
# Human-readable name; the matching endpoint only uses it in a log message.
name: str
country_id: int
country_code: str
# Short code such as "ON", "QC", "BC"; the matching endpoint passes this
# value on as the user location when user_tax_info is supplied.
state_code: str
class CountryInfo(BaseModel):
"""Country information from user tax info"""
# Schema of the country object nested under UserTaxInfo.country.
# Fields without a default are required.
id: int
# Human-readable name; the matching endpoint only uses it in a log message.
name: str
iso3: str
iso2: str
phone_code: str
capital: str
currency: str
# Purely optional descriptive fields; each defaults to None when absent.
native: Optional[str] = None
region: Optional[str] = None
subregion: Optional[str] = None
emoji: Optional[str] = None
emojiU: Optional[str] = None
class UserTaxInfo(BaseModel):
"""User tax information for location-based tax calculations"""
# Optional payload of MatchSpecificRequest (field user_tax_info).
# id, user_id, company_name and the three nested objects are required.
id: int
user_id: int
company_name: str
# Free-text tax/address fields default to an empty string when omitted.
tax_id: Optional[str] = ""
# Defaults to "EIN" when the client does not specify a type.
tax_id_type: Optional[str] = "EIN"
address_line_1: Optional[str] = ""
address_line_2: Optional[str] = ""
# Nested location objects; state.state_code is the value the matching
# endpoint uses as the user's location.
city: CityInfo
state: StateInfo
zip_postal_code: Optional[str] = ""
country: CountryInfo
# Defaults to 1; presumably a 0/1 flag -- TODO confirm with the producer.
include_on_invoices: Optional[int] = 1
# Timestamps are carried as strings; format is not established here.
created_at: Optional[str] = None
updated_at: Optional[str] = None
class MatchSpecificRequest(BaseModel):
# Request body for the match-specific-receipts endpoint.
file_ids: List[str]
categorization_id: str
# NOTE(review): the next two lines declare the same field. In this diff
# view the first appears to be the pre-change line and the second its
# replacement; only one declaration should exist in the real file.
user_location: Optional[str] = "Canada"
user_location: Optional[str] = "Canada" # Kept for backward compatibility
# When provided, its state.state_code takes precedence over user_location
# in the matching endpoint.
user_tax_info: Optional[UserTaxInfo] = None
+53 -1
View File
@@ -87,7 +87,13 @@ class LLMTaxAnalyzer:
# Extract location information
receipt_location = self._extract_receipt_location(receipt)
user_province = user_location.upper()
# Normalize user_location to province code (handle "Canada", "Ontario", "ON", etc.)
user_province = self._normalize_location_to_province(user_location)
logger.info(
f"Building tax analysis context - User Location: {user_location} → Province Code: {user_province}"
)
# Build tax rates reference
tax_rates_info = json.dumps(self.PROVINCIAL_TAX_RATES, indent=2)
@@ -130,6 +136,47 @@ CCA DEPRECIATION RATES BY ASSET CLASS:
"""
return context
def _normalize_location_to_province(self, location: str) -> str:
"""
Normalize various location formats to province code.
Handles: "ON", "Ontario", "Canada", etc.
"""
location_upper = location.upper().strip()
# Direct province code match
if location_upper in self.PROVINCIAL_TAX_RATES:
return location_upper
# Map full province names to codes
province_name_map = {
"ONTARIO": "ON",
"QUEBEC": "QC",
"BRITISH COLUMBIA": "BC",
"ALBERTA": "AB",
"SASKATCHEWAN": "SK",
"MANITOBA": "MB",
"NOVA SCOTIA": "NS",
"NEW BRUNSWICK": "NB",
"NEWFOUNDLAND AND LABRADOR": "NL",
"NEWFOUNDLAND": "NL",
"PRINCE EDWARD ISLAND": "PE",
"NORTHWEST TERRITORIES": "NT",
"NUNAVUT": "NU",
"YUKON": "YT",
}
if location_upper in province_name_map:
return province_name_map[location_upper]
# Default to Ontario if country is Canada or unspecified
if location_upper in ["CANADA", "CAN", "CA", ""]:
logger.warning(f"Location '{location}' is too generic, defaulting to ON")
return "ON"
# If nothing matches, default to Ontario
logger.warning(f"Could not parse location '{location}', defaulting to ON")
return "ON"
def _extract_receipt_location(self, receipt: Receipt) -> str:
"""Extract and format receipt location information"""
@@ -276,6 +323,7 @@ b) **CCA Depreciation** (for tax purposes - Canada):
Provide a structured JSON response with the following format:
{{
"final_tax_amount": XX.XX,
"sales_tax": {{
"applicable_province": "XX",
"applicable_rate": 0.XX,
@@ -321,6 +369,8 @@ Provide a structured JSON response with the following format:
"overall_assessment": "Comprehensive summary: which rules applied, why, what location used for what purpose, and any required actions"
}}
**IMPORTANT**: The "final_tax_amount" field at the top level must contain the final calculated tax amount. This should be the calculated_tax from sales_tax analysis. If this is a meals & entertainment expense, ensure you return the FULL tax amount here (not the 50% adjusted amount).
**Critical Reminders**:
- Sales tax uses RECEIPT location (or user location if receipt has none)
- Depreciation ALWAYS uses USER location
@@ -356,6 +406,7 @@ Provide a structured JSON response with the following format:
"""Return fallback analysis if LLM fails"""
return json.dumps(
{
"final_tax_amount": 0.0,
"sales_tax": {
"applicable_province": "ON",
"applicable_rate": 0.13,
@@ -424,6 +475,7 @@ Provide a structured JSON response with the following format:
# Return structured fallback
return {
"final_tax_amount": receipt.tax if receipt.tax else 0.0,
"sales_tax": {
"requires_review": True,
"reason": "Failed to parse LLM response",