Add user location support and tax analysis enhancements
- Introduced user location extraction from user tax info for improved matching. - Normalized user location to province codes for tax calculations. - Updated MatchResponse schema to include tax analysis data. - Enhanced LLMTaxAnalyzer to handle various location formats and provide fallback logic.
This commit is contained in:
+26
-8
@@ -5,10 +5,6 @@ import uuid
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from database import (
|
from database import (
|
||||||
DBReceipt,
|
DBReceipt,
|
||||||
DBTransaction,
|
DBTransaction,
|
||||||
@@ -16,6 +12,8 @@ from database import (
|
|||||||
create_db_tables,
|
create_db_tables,
|
||||||
db_dependency,
|
db_dependency,
|
||||||
)
|
)
|
||||||
|
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from schemas import (
|
from schemas import (
|
||||||
DocumentProcessResponse,
|
DocumentProcessResponse,
|
||||||
DocumentUploadResponse,
|
DocumentUploadResponse,
|
||||||
@@ -29,6 +27,7 @@ from schemas import (
|
|||||||
from services.ai_rules import AIRule
|
from services.ai_rules import AIRule
|
||||||
from services.document_processor import DocumentProcessor
|
from services.document_processor import DocumentProcessor
|
||||||
from services.matching_engine import MatchingEngine
|
from services.matching_engine import MatchingEngine
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
create_db_tables()
|
create_db_tables()
|
||||||
|
|
||||||
@@ -410,7 +409,7 @@ async def process_document(file_id: str, db: db_dependency):
|
|||||||
confidence=receipt_data.get("confidence", 0.0),
|
confidence=receipt_data.get("confidence", 0.0),
|
||||||
extraction_success=str(receipt_data.get("extraction_success", False)),
|
extraction_success=str(receipt_data.get("extraction_success", False)),
|
||||||
error_message=receipt_data.get("error"),
|
error_message=receipt_data.get("error"),
|
||||||
receipt_currency=receipt_data.get("currency")
|
receipt_currency=receipt_data.get("currency"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add to database
|
# Add to database
|
||||||
@@ -429,7 +428,7 @@ async def process_document(file_id: str, db: db_dependency):
|
|||||||
category=receipt_data.get("category", ""),
|
category=receipt_data.get("category", ""),
|
||||||
confidence=receipt_data.get("confidence", 0.0),
|
confidence=receipt_data.get("confidence", 0.0),
|
||||||
error=receipt_data.get("error", None),
|
error=receipt_data.get("error", None),
|
||||||
receipt_currency=receipt_data.get("currency")
|
receipt_currency=receipt_data.get("currency"),
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -536,13 +535,31 @@ async def match_specific_receipts(request: MatchSpecificRequest, db: db_dependen
|
|||||||
f"Starting matching with {len(receipts)} receipts and {len(transactions)} transactions"
|
f"Starting matching with {len(receipts)} receipts and {len(transactions)} transactions"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Extract user location from user_tax_info if provided
|
||||||
|
user_location = request.user_location # Default/fallback
|
||||||
|
if request.user_tax_info:
|
||||||
|
# Use state_code from user_tax_info (e.g., "ON", "QC", "BC")
|
||||||
|
user_location = request.user_tax_info.state.state_code
|
||||||
|
logger.info(
|
||||||
|
f"Using location from user_tax_info: {user_location} ({request.user_tax_info.state.name}, {request.user_tax_info.country.name})"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(f"Using default/provided user_location: {user_location}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
matching_results = matching_engine.process_matching(receipts, transactions, user_location=request.user_location)
|
matching_results = matching_engine.process_matching(
|
||||||
|
receipts, transactions, user_location=user_location
|
||||||
|
)
|
||||||
logger.info(f"Matching completed, got {len(matching_results)} results")
|
logger.info(f"Matching completed, got {len(matching_results)} results")
|
||||||
|
|
||||||
# Convert matching results to response format
|
# Convert matching results to response format
|
||||||
match_responses = []
|
match_responses = []
|
||||||
for result in matching_results:
|
for result in matching_results:
|
||||||
|
# Get final tax amount from LLM analysis if available, otherwise use receipt's stated tax
|
||||||
|
final_tax = result.receipt.tax
|
||||||
|
if result.tax_analysis and "final_tax_amount" in result.tax_analysis:
|
||||||
|
final_tax = result.tax_analysis["final_tax_amount"]
|
||||||
|
|
||||||
match_response = MatchResponse(
|
match_response = MatchResponse(
|
||||||
receipt_id=result.receipt.id,
|
receipt_id=result.receipt.id,
|
||||||
transaction_id=result.transaction.id
|
transaction_id=result.transaction.id
|
||||||
@@ -554,13 +571,14 @@ async def match_specific_receipts(request: MatchSpecificRequest, db: db_dependen
|
|||||||
receipt_amount=result.receipt.amount,
|
receipt_amount=result.receipt.amount,
|
||||||
receipt_description=result.receipt.description,
|
receipt_description=result.receipt.description,
|
||||||
receipt_category=result.receipt.category,
|
receipt_category=result.receipt.category,
|
||||||
receipt_tax_amount=result.receipt.tax,
|
receipt_tax_amount=final_tax,
|
||||||
transaction_vendor=result.transaction.vendor
|
transaction_vendor=result.transaction.vendor
|
||||||
if result.transaction
|
if result.transaction
|
||||||
else "",
|
else "",
|
||||||
transaction_amount=result.transaction.amount
|
transaction_amount=result.transaction.amount
|
||||||
if result.transaction
|
if result.transaction
|
||||||
else 0.0,
|
else 0.0,
|
||||||
|
tax_analysis=result.tax_analysis,
|
||||||
)
|
)
|
||||||
match_responses.append(match_response)
|
match_responses.append(match_response)
|
||||||
|
|
||||||
|
|||||||
+62
-1
@@ -131,6 +131,7 @@ class MatchResponse(BaseModel):
|
|||||||
receipt_tax_amount: float
|
receipt_tax_amount: float
|
||||||
transaction_vendor: str
|
transaction_vendor: str
|
||||||
transaction_amount: float
|
transaction_amount: float
|
||||||
|
tax_analysis: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
class MatchingResponse(BaseModel):
|
class MatchingResponse(BaseModel):
|
||||||
@@ -205,7 +206,67 @@ class DepreciationResponse(BaseModel):
|
|||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CityInfo(BaseModel):
|
||||||
|
"""City information from user tax info"""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
state_id: int
|
||||||
|
state_code: str
|
||||||
|
country_id: int
|
||||||
|
country_code: str
|
||||||
|
latitude: Optional[str] = None
|
||||||
|
longitude: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class StateInfo(BaseModel):
|
||||||
|
"""State/Province information from user tax info"""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
country_id: int
|
||||||
|
country_code: str
|
||||||
|
state_code: str
|
||||||
|
|
||||||
|
|
||||||
|
class CountryInfo(BaseModel):
|
||||||
|
"""Country information from user tax info"""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
iso3: str
|
||||||
|
iso2: str
|
||||||
|
phone_code: str
|
||||||
|
capital: str
|
||||||
|
currency: str
|
||||||
|
native: Optional[str] = None
|
||||||
|
region: Optional[str] = None
|
||||||
|
subregion: Optional[str] = None
|
||||||
|
emoji: Optional[str] = None
|
||||||
|
emojiU: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class UserTaxInfo(BaseModel):
|
||||||
|
"""User tax information for location-based tax calculations"""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
user_id: int
|
||||||
|
company_name: str
|
||||||
|
tax_id: Optional[str] = ""
|
||||||
|
tax_id_type: Optional[str] = "EIN"
|
||||||
|
address_line_1: Optional[str] = ""
|
||||||
|
address_line_2: Optional[str] = ""
|
||||||
|
city: CityInfo
|
||||||
|
state: StateInfo
|
||||||
|
zip_postal_code: Optional[str] = ""
|
||||||
|
country: CountryInfo
|
||||||
|
include_on_invoices: Optional[int] = 1
|
||||||
|
created_at: Optional[str] = None
|
||||||
|
updated_at: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class MatchSpecificRequest(BaseModel):
|
class MatchSpecificRequest(BaseModel):
|
||||||
file_ids: List[str]
|
file_ids: List[str]
|
||||||
categorization_id: str
|
categorization_id: str
|
||||||
user_location: Optional[str] = "Canada"
|
user_location: Optional[str] = "Canada" # Kept for backward compatibility
|
||||||
|
user_tax_info: Optional[UserTaxInfo] = None
|
||||||
|
|||||||
@@ -87,7 +87,13 @@ class LLMTaxAnalyzer:
|
|||||||
|
|
||||||
# Extract location information
|
# Extract location information
|
||||||
receipt_location = self._extract_receipt_location(receipt)
|
receipt_location = self._extract_receipt_location(receipt)
|
||||||
user_province = user_location.upper()
|
|
||||||
|
# Normalize user_location to province code (handle "Canada", "Ontario", "ON", etc.)
|
||||||
|
user_province = self._normalize_location_to_province(user_location)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Building tax analysis context - User Location: {user_location} → Province Code: {user_province}"
|
||||||
|
)
|
||||||
|
|
||||||
# Build tax rates reference
|
# Build tax rates reference
|
||||||
tax_rates_info = json.dumps(self.PROVINCIAL_TAX_RATES, indent=2)
|
tax_rates_info = json.dumps(self.PROVINCIAL_TAX_RATES, indent=2)
|
||||||
@@ -130,6 +136,47 @@ CCA DEPRECIATION RATES BY ASSET CLASS:
|
|||||||
"""
|
"""
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
def _normalize_location_to_province(self, location: str) -> str:
|
||||||
|
"""
|
||||||
|
Normalize various location formats to province code.
|
||||||
|
Handles: "ON", "Ontario", "Canada", etc.
|
||||||
|
"""
|
||||||
|
location_upper = location.upper().strip()
|
||||||
|
|
||||||
|
# Direct province code match
|
||||||
|
if location_upper in self.PROVINCIAL_TAX_RATES:
|
||||||
|
return location_upper
|
||||||
|
|
||||||
|
# Map full province names to codes
|
||||||
|
province_name_map = {
|
||||||
|
"ONTARIO": "ON",
|
||||||
|
"QUEBEC": "QC",
|
||||||
|
"BRITISH COLUMBIA": "BC",
|
||||||
|
"ALBERTA": "AB",
|
||||||
|
"SASKATCHEWAN": "SK",
|
||||||
|
"MANITOBA": "MB",
|
||||||
|
"NOVA SCOTIA": "NS",
|
||||||
|
"NEW BRUNSWICK": "NB",
|
||||||
|
"NEWFOUNDLAND AND LABRADOR": "NL",
|
||||||
|
"NEWFOUNDLAND": "NL",
|
||||||
|
"PRINCE EDWARD ISLAND": "PE",
|
||||||
|
"NORTHWEST TERRITORIES": "NT",
|
||||||
|
"NUNAVUT": "NU",
|
||||||
|
"YUKON": "YT",
|
||||||
|
}
|
||||||
|
|
||||||
|
if location_upper in province_name_map:
|
||||||
|
return province_name_map[location_upper]
|
||||||
|
|
||||||
|
# Default to Ontario if country is Canada or unspecified
|
||||||
|
if location_upper in ["CANADA", "CAN", "CA", ""]:
|
||||||
|
logger.warning(f"Location '{location}' is too generic, defaulting to ON")
|
||||||
|
return "ON"
|
||||||
|
|
||||||
|
# If nothing matches, default to Ontario
|
||||||
|
logger.warning(f"Could not parse location '{location}', defaulting to ON")
|
||||||
|
return "ON"
|
||||||
|
|
||||||
def _extract_receipt_location(self, receipt: Receipt) -> str:
|
def _extract_receipt_location(self, receipt: Receipt) -> str:
|
||||||
"""Extract and format receipt location information"""
|
"""Extract and format receipt location information"""
|
||||||
|
|
||||||
@@ -276,6 +323,7 @@ b) **CCA Depreciation** (for tax purposes - Canada):
|
|||||||
Provide a structured JSON response with the following format:
|
Provide a structured JSON response with the following format:
|
||||||
|
|
||||||
{{
|
{{
|
||||||
|
"final_tax_amount": XX.XX,
|
||||||
"sales_tax": {{
|
"sales_tax": {{
|
||||||
"applicable_province": "XX",
|
"applicable_province": "XX",
|
||||||
"applicable_rate": 0.XX,
|
"applicable_rate": 0.XX,
|
||||||
@@ -321,6 +369,8 @@ Provide a structured JSON response with the following format:
|
|||||||
"overall_assessment": "Comprehensive summary: which rules applied, why, what location used for what purpose, and any required actions"
|
"overall_assessment": "Comprehensive summary: which rules applied, why, what location used for what purpose, and any required actions"
|
||||||
}}
|
}}
|
||||||
|
|
||||||
|
**IMPORTANT**: The "final_tax_amount" field at the top level must contain the final calculated tax amount. This should be the calculated_tax from sales_tax analysis. If this is a meals & entertainment expense, ensure you return the FULL tax amount here (not the 50% adjusted amount).
|
||||||
|
|
||||||
**Critical Reminders**:
|
**Critical Reminders**:
|
||||||
- Sales tax uses RECEIPT location (or user location if receipt has none)
|
- Sales tax uses RECEIPT location (or user location if receipt has none)
|
||||||
- Depreciation ALWAYS uses USER location
|
- Depreciation ALWAYS uses USER location
|
||||||
@@ -356,6 +406,7 @@ Provide a structured JSON response with the following format:
|
|||||||
"""Return fallback analysis if LLM fails"""
|
"""Return fallback analysis if LLM fails"""
|
||||||
return json.dumps(
|
return json.dumps(
|
||||||
{
|
{
|
||||||
|
"final_tax_amount": 0.0,
|
||||||
"sales_tax": {
|
"sales_tax": {
|
||||||
"applicable_province": "ON",
|
"applicable_province": "ON",
|
||||||
"applicable_rate": 0.13,
|
"applicable_rate": 0.13,
|
||||||
@@ -424,6 +475,7 @@ Provide a structured JSON response with the following format:
|
|||||||
|
|
||||||
# Return structured fallback
|
# Return structured fallback
|
||||||
return {
|
return {
|
||||||
|
"final_tax_amount": receipt.tax if receipt.tax else 0.0,
|
||||||
"sales_tax": {
|
"sales_tax": {
|
||||||
"requires_review": True,
|
"requires_review": True,
|
||||||
"reason": "Failed to parse LLM response",
|
"reason": "Failed to parse LLM response",
|
||||||
|
|||||||
Reference in New Issue
Block a user