diff --git a/app/main.py b/app/main.py index 6010013..37f506d 100644 --- a/app/main.py +++ b/app/main.py @@ -5,10 +5,6 @@ import uuid from datetime import datetime from typing import List -from fastapi import FastAPI, File, Form, HTTPException, UploadFile -from fastapi.middleware.cors import CORSMiddleware -from sqlalchemy.orm import Session - from database import ( DBReceipt, DBTransaction, @@ -16,6 +12,8 @@ from database import ( create_db_tables, db_dependency, ) +from fastapi import FastAPI, File, Form, HTTPException, UploadFile +from fastapi.middleware.cors import CORSMiddleware from schemas import ( DocumentProcessResponse, DocumentUploadResponse, @@ -29,6 +27,7 @@ from schemas import ( from services.ai_rules import AIRule from services.document_processor import DocumentProcessor from services.matching_engine import MatchingEngine +from sqlalchemy.orm import Session create_db_tables() @@ -410,7 +409,7 @@ async def process_document(file_id: str, db: db_dependency): confidence=receipt_data.get("confidence", 0.0), extraction_success=str(receipt_data.get("extraction_success", False)), error_message=receipt_data.get("error"), - receipt_currency=receipt_data.get("currency") + receipt_currency=receipt_data.get("currency"), ) # Add to database @@ -429,7 +428,7 @@ async def process_document(file_id: str, db: db_dependency): category=receipt_data.get("category", ""), confidence=receipt_data.get("confidence", 0.0), error=receipt_data.get("error", None), - receipt_currency=receipt_data.get("currency") + receipt_currency=receipt_data.get("currency"), ) except Exception as e: @@ -536,13 +535,31 @@ async def match_specific_receipts(request: MatchSpecificRequest, db: db_dependen f"Starting matching with {len(receipts)} receipts and {len(transactions)} transactions" ) + # Extract user location from user_tax_info if provided + user_location = request.user_location # Default/fallback + if request.user_tax_info: + # Use state_code from user_tax_info (e.g., "ON", "QC", "BC") + user_location = request.user_tax_info.state.state_code + logger.info( + f"Using location from user_tax_info: {user_location} ({request.user_tax_info.state.name}, {request.user_tax_info.country.name})" + ) + else: + logger.info(f"Using default/provided user_location: {user_location}") + try: - matching_results = matching_engine.process_matching(receipts, transactions, user_location=request.user_location) + matching_results = matching_engine.process_matching( + receipts, transactions, user_location=user_location + ) logger.info(f"Matching completed, got {len(matching_results)} results") # Convert matching results to response format match_responses = [] for result in matching_results: + # Get final tax amount from LLM analysis if available, otherwise use receipt's stated tax + final_tax = result.receipt.tax + if result.tax_analysis and "final_tax_amount" in result.tax_analysis: + final_tax = result.tax_analysis["final_tax_amount"] + match_response = MatchResponse( receipt_id=result.receipt.id, transaction_id=result.transaction.id @@ -554,13 +571,14 @@ async def match_specific_receipts(request: MatchSpecificRequest, db: db_dependen receipt_amount=result.receipt.amount, receipt_description=result.receipt.description, receipt_category=result.receipt.category, - receipt_tax_amount=result.receipt.tax, + receipt_tax_amount=final_tax, transaction_vendor=result.transaction.vendor if result.transaction else "", transaction_amount=result.transaction.amount if result.transaction else 0.0, + tax_analysis=result.tax_analysis, ) match_responses.append(match_response) diff --git a/app/schemas.py b/app/schemas.py index 5173552..0910f5d 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -131,6 +131,7 @@ class MatchResponse(BaseModel): receipt_tax_amount: float transaction_vendor: str transaction_amount: float + tax_analysis: Optional[dict] = None class MatchingResponse(BaseModel): @@ -205,7 +206,67 @@ class DepreciationResponse(BaseModel): error: Optional[str] = None +class CityInfo(BaseModel): + """City information from user tax info""" + + id: int + name: str + state_id: int + state_code: str + country_id: int + country_code: str + latitude: Optional[str] = None + longitude: Optional[str] = None + + +class StateInfo(BaseModel): + """State/Province information from user tax info""" + + id: int + name: str + country_id: int + country_code: str + state_code: str + + +class CountryInfo(BaseModel): + """Country information from user tax info""" + + id: int + name: str + iso3: str + iso2: str + phone_code: str + capital: str + currency: str + native: Optional[str] = None + region: Optional[str] = None + subregion: Optional[str] = None + emoji: Optional[str] = None + emojiU: Optional[str] = None + + +class UserTaxInfo(BaseModel): + """User tax information for location-based tax calculations""" + + id: int + user_id: int + company_name: str + tax_id: Optional[str] = "" + tax_id_type: Optional[str] = "EIN" + address_line_1: Optional[str] = "" + address_line_2: Optional[str] = "" + city: CityInfo + state: StateInfo + zip_postal_code: Optional[str] = "" + country: CountryInfo + include_on_invoices: Optional[int] = 1 + created_at: Optional[str] = None + updated_at: Optional[str] = None + + class MatchSpecificRequest(BaseModel): file_ids: List[str] categorization_id: str - user_location: Optional[str] = "Canada" + user_location: Optional[str] = "Canada" # Kept for backward compatibility + user_tax_info: Optional[UserTaxInfo] = None diff --git a/app/services/llm_tax_analyzer.py b/app/services/llm_tax_analyzer.py index fee9cd0..2c2d37f 100644 --- a/app/services/llm_tax_analyzer.py +++ b/app/services/llm_tax_analyzer.py @@ -87,7 +87,13 @@ class LLMTaxAnalyzer: # Extract location information receipt_location = self._extract_receipt_location(receipt) - user_province = user_location.upper() + + # Normalize user_location to province code (handle "Canada", "Ontario", "ON", etc.) + user_province = self._normalize_location_to_province(user_location) + + logger.info( + f"Building tax analysis context - User Location: {user_location} → Province Code: {user_province}" + ) # Build tax rates reference tax_rates_info = json.dumps(self.PROVINCIAL_TAX_RATES, indent=2) @@ -130,6 +136,47 @@ CCA DEPRECIATION RATES BY ASSET CLASS: """ return context + def _normalize_location_to_province(self, location: str) -> str: + """ + Normalize various location formats to province code. + Handles: "ON", "Ontario", "Canada", etc. + """ + location_upper = location.upper().strip() + + # Direct province code match + if location_upper in self.PROVINCIAL_TAX_RATES: + return location_upper + + # Map full province names to codes + province_name_map = { + "ONTARIO": "ON", + "QUEBEC": "QC", + "BRITISH COLUMBIA": "BC", + "ALBERTA": "AB", + "SASKATCHEWAN": "SK", + "MANITOBA": "MB", + "NOVA SCOTIA": "NS", + "NEW BRUNSWICK": "NB", + "NEWFOUNDLAND AND LABRADOR": "NL", + "NEWFOUNDLAND": "NL", + "PRINCE EDWARD ISLAND": "PE", + "NORTHWEST TERRITORIES": "NT", + "NUNAVUT": "NU", + "YUKON": "YT", + } + + if location_upper in province_name_map: + return province_name_map[location_upper] + + # Default to Ontario if country is Canada or unspecified + if location_upper in ["CANADA", "CAN", "CA", ""]: + logger.warning(f"Location '{location}' is too generic, defaulting to ON") + return "ON" + + # If nothing matches, default to Ontario + logger.warning(f"Could not parse location '{location}', defaulting to ON") + return "ON" + def _extract_receipt_location(self, receipt: Receipt) -> str: """Extract and format receipt location information""" @@ -276,6 +323,7 @@ b) **CCA Depreciation** (for tax purposes - Canada): Provide a structured JSON response with the following format: {{ + "final_tax_amount": XX.XX, "sales_tax": {{ "applicable_province": "XX", "applicable_rate": 0.XX, @@ -321,6 +369,8 @@ Provide a structured JSON response with the following format: "overall_assessment": "Comprehensive summary: which rules applied, why, what location used for what purpose, and any required actions" }} +**IMPORTANT**: The "final_tax_amount" field at the top level must contain the final calculated tax amount. This should be the calculated_tax from sales_tax analysis. If this is a meals & entertainment expense, ensure you return the FULL tax amount here (not the 50% adjusted amount). + **Critical Reminders**: - Sales tax uses RECEIPT location (or user location if receipt has none) - Depreciation ALWAYS uses USER location @@ -356,6 +406,7 @@ Provide a structured JSON response with the following format: """Return fallback analysis if LLM fails""" return json.dumps( { + "final_tax_amount": 0.0, "sales_tax": { "applicable_province": "ON", "applicable_rate": 0.13, @@ -424,6 +475,7 @@ Provide a structured JSON response with the following format: # Return structured fallback return { + "final_tax_amount": receipt.tax if receipt.tax else 0.0, "sales_tax": { "requires_review": True, "reason": "Failed to parse LLM response",