Add user location support and tax analysis enhancements

- Introduced user location extraction from user tax info for improved matching.
- Normalized user location to province codes for tax calculations.
- Updated MatchResponse schema to include tax analysis data.
- Enhanced LLMTaxAnalyzer to handle various location formats and provide fallback logic.
This commit is contained in:
bolade
2025-10-05 18:34:35 +01:00
parent c78c4c6fe9
commit c45e3fa791
3 changed files with 141 additions and 10 deletions
+26 -8
View File
@@ -5,10 +5,6 @@ import uuid
from datetime import datetime from datetime import datetime
from typing import List from typing import List
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.orm import Session
from database import ( from database import (
DBReceipt, DBReceipt,
DBTransaction, DBTransaction,
@@ -16,6 +12,8 @@ from database import (
create_db_tables, create_db_tables,
db_dependency, db_dependency,
) )
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from schemas import ( from schemas import (
DocumentProcessResponse, DocumentProcessResponse,
DocumentUploadResponse, DocumentUploadResponse,
@@ -29,6 +27,7 @@ from schemas import (
from services.ai_rules import AIRule from services.ai_rules import AIRule
from services.document_processor import DocumentProcessor from services.document_processor import DocumentProcessor
from services.matching_engine import MatchingEngine from services.matching_engine import MatchingEngine
from sqlalchemy.orm import Session
create_db_tables() create_db_tables()
@@ -410,7 +409,7 @@ async def process_document(file_id: str, db: db_dependency):
confidence=receipt_data.get("confidence", 0.0), confidence=receipt_data.get("confidence", 0.0),
extraction_success=str(receipt_data.get("extraction_success", False)), extraction_success=str(receipt_data.get("extraction_success", False)),
error_message=receipt_data.get("error"), error_message=receipt_data.get("error"),
receipt_currency=receipt_data.get("currency") receipt_currency=receipt_data.get("currency"),
) )
# Add to database # Add to database
@@ -429,7 +428,7 @@ async def process_document(file_id: str, db: db_dependency):
category=receipt_data.get("category", ""), category=receipt_data.get("category", ""),
confidence=receipt_data.get("confidence", 0.0), confidence=receipt_data.get("confidence", 0.0),
error=receipt_data.get("error", None), error=receipt_data.get("error", None),
receipt_currency=receipt_data.get("currency") receipt_currency=receipt_data.get("currency"),
) )
except Exception as e: except Exception as e:
@@ -536,13 +535,31 @@ async def match_specific_receipts(request: MatchSpecificRequest, db: db_dependen
f"Starting matching with {len(receipts)} receipts and {len(transactions)} transactions" f"Starting matching with {len(receipts)} receipts and {len(transactions)} transactions"
) )
# Extract user location from user_tax_info if provided
user_location = request.user_location # Default/fallback
if request.user_tax_info:
# Use state_code from user_tax_info (e.g., "ON", "QC", "BC")
user_location = request.user_tax_info.state.state_code
logger.info(
f"Using location from user_tax_info: {user_location} ({request.user_tax_info.state.name}, {request.user_tax_info.country.name})"
)
else:
logger.info(f"Using default/provided user_location: {user_location}")
try: try:
matching_results = matching_engine.process_matching(receipts, transactions, user_location=request.user_location) matching_results = matching_engine.process_matching(
receipts, transactions, user_location=user_location
)
logger.info(f"Matching completed, got {len(matching_results)} results") logger.info(f"Matching completed, got {len(matching_results)} results")
# Convert matching results to response format # Convert matching results to response format
match_responses = [] match_responses = []
for result in matching_results: for result in matching_results:
# Get final tax amount from LLM analysis if available, otherwise use receipt's stated tax
final_tax = result.receipt.tax
if result.tax_analysis and "final_tax_amount" in result.tax_analysis:
final_tax = result.tax_analysis["final_tax_amount"]
match_response = MatchResponse( match_response = MatchResponse(
receipt_id=result.receipt.id, receipt_id=result.receipt.id,
transaction_id=result.transaction.id transaction_id=result.transaction.id
@@ -554,13 +571,14 @@ async def match_specific_receipts(request: MatchSpecificRequest, db: db_dependen
receipt_amount=result.receipt.amount, receipt_amount=result.receipt.amount,
receipt_description=result.receipt.description, receipt_description=result.receipt.description,
receipt_category=result.receipt.category, receipt_category=result.receipt.category,
receipt_tax_amount=result.receipt.tax, receipt_tax_amount=final_tax,
transaction_vendor=result.transaction.vendor transaction_vendor=result.transaction.vendor
if result.transaction if result.transaction
else "", else "",
transaction_amount=result.transaction.amount transaction_amount=result.transaction.amount
if result.transaction if result.transaction
else 0.0, else 0.0,
tax_analysis=result.tax_analysis,
) )
match_responses.append(match_response) match_responses.append(match_response)
+62 -1
View File
@@ -131,6 +131,7 @@ class MatchResponse(BaseModel):
receipt_tax_amount: float receipt_tax_amount: float
transaction_vendor: str transaction_vendor: str
transaction_amount: float transaction_amount: float
tax_analysis: Optional[dict] = None
class MatchingResponse(BaseModel): class MatchingResponse(BaseModel):
@@ -205,7 +206,67 @@ class DepreciationResponse(BaseModel):
error: Optional[str] = None error: Optional[str] = None
class CityInfo(BaseModel):
    """City information from user tax info (nested as ``UserTaxInfo.city``)."""

    id: int
    name: str
    state_id: int
    state_code: str
    country_id: int
    country_code: str
    # Coordinates are declared as optional strings (not floats) and default to None.
    latitude: Optional[str] = None
    longitude: Optional[str] = None
class StateInfo(BaseModel):
    """State/Province information from user tax info (nested as ``UserTaxInfo.state``)."""

    id: int
    # Human-readable name; the match endpoint only uses it in its log message.
    name: str
    country_id: int
    country_code: str
    # Short code (e.g. "ON", "QC", "BC"); the match endpoint passes this value
    # to the matching engine as the user's location.
    state_code: str
class CountryInfo(BaseModel):
    """Country information from user tax info (nested as ``UserTaxInfo.country``)."""

    # Required fields: a payload missing any of these fails validation.
    id: int
    name: str
    iso3: str  # presumably the ISO 3166-1 alpha-3 code -- TODO confirm
    iso2: str  # presumably the ISO 3166-1 alpha-2 code -- TODO confirm
    phone_code: str
    capital: str
    currency: str
    # Optional descriptive fields; each defaults to None when omitted.
    native: Optional[str] = None
    region: Optional[str] = None
    subregion: Optional[str] = None
    emoji: Optional[str] = None
    emojiU: Optional[str] = None
class UserTaxInfo(BaseModel):
    """User tax information for location-based tax calculations.

    When this object is supplied on a match request, the endpoint reads
    ``state.state_code`` from it and uses that as the user's location
    instead of the plain ``user_location`` string.
    """

    id: int
    user_id: int
    company_name: str
    tax_id: Optional[str] = ""
    # Defaults to "EIN" when the caller does not specify an identifier type.
    tax_id_type: Optional[str] = "EIN"
    address_line_1: Optional[str] = ""
    address_line_2: Optional[str] = ""
    # city, state and country are required nested objects.
    city: CityInfo
    state: StateInfo
    zip_postal_code: Optional[str] = ""
    country: CountryInfo
    # Integer flag defaulting to 1; presumably 0/1 -- TODO confirm with the producer.
    include_on_invoices: Optional[int] = 1
    # Timestamps are carried as plain strings; no format is enforced here.
    created_at: Optional[str] = None
    updated_at: Optional[str] = None
class MatchSpecificRequest(BaseModel): class MatchSpecificRequest(BaseModel):
file_ids: List[str] file_ids: List[str]
categorization_id: str categorization_id: str
user_location: Optional[str] = "Canada" user_location: Optional[str] = "Canada" # Kept for backward compatibility
user_tax_info: Optional[UserTaxInfo] = None
+53 -1
View File
@@ -87,7 +87,13 @@ class LLMTaxAnalyzer:
# Extract location information # Extract location information
receipt_location = self._extract_receipt_location(receipt) receipt_location = self._extract_receipt_location(receipt)
user_province = user_location.upper()
# Normalize user_location to province code (handle "Canada", "Ontario", "ON", etc.)
user_province = self._normalize_location_to_province(user_location)
logger.info(
f"Building tax analysis context - User Location: {user_location} → Province Code: {user_province}"
)
# Build tax rates reference # Build tax rates reference
tax_rates_info = json.dumps(self.PROVINCIAL_TAX_RATES, indent=2) tax_rates_info = json.dumps(self.PROVINCIAL_TAX_RATES, indent=2)
@@ -130,6 +136,47 @@ CCA DEPRECIATION RATES BY ASSET CLASS:
""" """
return context return context
def _normalize_location_to_province(self, location: str) -> str:
"""
Normalize various location formats to province code.
Handles: "ON", "Ontario", "Canada", etc.
"""
location_upper = location.upper().strip()
# Direct province code match
if location_upper in self.PROVINCIAL_TAX_RATES:
return location_upper
# Map full province names to codes
province_name_map = {
"ONTARIO": "ON",
"QUEBEC": "QC",
"BRITISH COLUMBIA": "BC",
"ALBERTA": "AB",
"SASKATCHEWAN": "SK",
"MANITOBA": "MB",
"NOVA SCOTIA": "NS",
"NEW BRUNSWICK": "NB",
"NEWFOUNDLAND AND LABRADOR": "NL",
"NEWFOUNDLAND": "NL",
"PRINCE EDWARD ISLAND": "PE",
"NORTHWEST TERRITORIES": "NT",
"NUNAVUT": "NU",
"YUKON": "YT",
}
if location_upper in province_name_map:
return province_name_map[location_upper]
# Default to Ontario if country is Canada or unspecified
if location_upper in ["CANADA", "CAN", "CA", ""]:
logger.warning(f"Location '{location}' is too generic, defaulting to ON")
return "ON"
# If nothing matches, default to Ontario
logger.warning(f"Could not parse location '{location}', defaulting to ON")
return "ON"
def _extract_receipt_location(self, receipt: Receipt) -> str: def _extract_receipt_location(self, receipt: Receipt) -> str:
"""Extract and format receipt location information""" """Extract and format receipt location information"""
@@ -276,6 +323,7 @@ b) **CCA Depreciation** (for tax purposes - Canada):
Provide a structured JSON response with the following format: Provide a structured JSON response with the following format:
{{ {{
"final_tax_amount": XX.XX,
"sales_tax": {{ "sales_tax": {{
"applicable_province": "XX", "applicable_province": "XX",
"applicable_rate": 0.XX, "applicable_rate": 0.XX,
@@ -321,6 +369,8 @@ Provide a structured JSON response with the following format:
"overall_assessment": "Comprehensive summary: which rules applied, why, what location used for what purpose, and any required actions" "overall_assessment": "Comprehensive summary: which rules applied, why, what location used for what purpose, and any required actions"
}} }}
**IMPORTANT**: The "final_tax_amount" field at the top level must contain the final calculated tax amount. This should be the calculated_tax from sales_tax analysis. If this is a meals & entertainment expense, ensure you return the FULL tax amount here (not the 50% adjusted amount).
**Critical Reminders**: **Critical Reminders**:
- Sales tax uses RECEIPT location (or user location if receipt has none) - Sales tax uses RECEIPT location (or user location if receipt has none)
- Depreciation ALWAYS uses USER location - Depreciation ALWAYS uses USER location
@@ -356,6 +406,7 @@ Provide a structured JSON response with the following format:
"""Return fallback analysis if LLM fails""" """Return fallback analysis if LLM fails"""
return json.dumps( return json.dumps(
{ {
"final_tax_amount": 0.0,
"sales_tax": { "sales_tax": {
"applicable_province": "ON", "applicable_province": "ON",
"applicable_rate": 0.13, "applicable_rate": 0.13,
@@ -424,6 +475,7 @@ Provide a structured JSON response with the following format:
# Return structured fallback # Return structured fallback
return { return {
"final_tax_amount": receipt.tax if receipt.tax else 0.0,
"sales_tax": { "sales_tax": {
"requires_review": True, "requires_review": True,
"reason": "Failed to parse LLM response", "reason": "Failed to parse LLM response",