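Refactor the document processing endpoint to accept file_id as a path parameter and update the related logic; make file_id optional in DocumentProcessRequest; clarify the tax processing rules in DocumentProcessor for receipts that show 0% tax; comment out the LLM-based batch tax analysis in MatchingEngine and the final_tax_amount override in match_specific_receipts.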

This commit is contained in:
bolade
2025-10-07 12:36:04 +01:00
parent 659ca4ff15
commit b2bf631448
4 changed files with 26 additions and 25 deletions
+9 -9
View File
@@ -365,11 +365,11 @@ async def upload_multiple_documents(
@app.post( @app.post(
"/process", "/process/{file_id}",
response_model=DocumentProcessResponse, response_model=DocumentProcessResponse,
tags=["Document Processing"], tags=["Document Processing"],
) )
async def process_document(request: DocumentProcessRequest, db: db_dependency): async def process_document(file_id: str, request: DocumentProcessRequest, db: db_dependency):
""" """
Process a previously uploaded document to extract receipt information. Process a previously uploaded document to extract receipt information.
@@ -381,10 +381,10 @@ async def process_document(request: DocumentProcessRequest, db: db_dependency):
""" """
try: try:
# Get file info from database # Get file info from database
db_uploaded_file = get_uploaded_file_from_db(db, request.file_id) db_uploaded_file = get_uploaded_file_from_db(db, file_id)
if not db_uploaded_file: if not db_uploaded_file:
raise HTTPException( raise HTTPException(
status_code=404, detail=f"File {request.file_id} not found" status_code=404, detail=f"File {file_id} not found"
) )
# Process the file using the stored file path # Process the file using the stored file path
@@ -406,8 +406,8 @@ async def process_document(request: DocumentProcessRequest, db: db_dependency):
# Create database receipt object # Create database receipt object
db_receipt = DBReceipt( db_receipt = DBReceipt(
receipt_id=f"receipt_{request.file_id}", receipt_id=f"receipt_{file_id}",
file_id=request.file_id, file_id=file_id,
amount=receipt_data.get("total_amount", 0.0), amount=receipt_data.get("total_amount", 0.0),
date=receipt_date, date=receipt_date,
vendor=receipt_data.get("vendor", ""), vendor=receipt_data.get("vendor", ""),
@@ -433,7 +433,7 @@ async def process_document(request: DocumentProcessRequest, db: db_dependency):
db.commit() db.commit()
return DocumentProcessResponse( return DocumentProcessResponse(
file_id=request.file_id, file_id=file_id,
receipt_id=db_receipt.receipt_id, receipt_id=db_receipt.receipt_id,
extraction_success=receipt_data.get("extraction_success", False), extraction_success=receipt_data.get("extraction_success", False),
vendor=receipt_data.get("vendor", ""), vendor=receipt_data.get("vendor", ""),
@@ -579,8 +579,8 @@ async def match_specific_receipts(request: MatchSpecificRequest, db: db_dependen
for result in matching_results: for result in matching_results:
# Get final tax amount from LLM analysis if available, otherwise use receipt's stated tax # Get final tax amount from LLM analysis if available, otherwise use receipt's stated tax
final_tax = result.receipt.tax final_tax = result.receipt.tax
if result.tax_analysis and "final_tax_amount" in result.tax_analysis: # if result.tax_analysis and "final_tax_amount" in result.tax_analysis:
final_tax = result.tax_analysis["final_tax_amount"] # final_tax = result.tax_analysis["final_tax_amount"]
match_response = MatchResponse( match_response = MatchResponse(
receipt_id=result.receipt.id, receipt_id=result.receipt.id,
+1 -1
View File
@@ -161,7 +161,7 @@ class DocumentUploadResponse(BaseModel):
class DocumentProcessRequest(BaseModel): class DocumentProcessRequest(BaseModel):
file_id: str file_id: Optional[str] = None
user_location: Optional[str] = ( user_location: Optional[str] = (
None # Format: "State/Province, Country" (e.g., "Ontario, Canada") None # Format: "State/Province, Country" (e.g., "Ontario, Canada")
) )
+1
View File
@@ -108,6 +108,7 @@ class DocumentProcessor:
* For other locations, estimate based on typical rates * For other locations, estimate based on typical rates
- Store calculated tax in "calculated_tax" field (set to null if tax clearly shown) - Store calculated tax in "calculated_tax" field (set to null if tax clearly shown)
- If tax is clearly shown on receipt, use that value for tax_amount and set calculated_tax to null - If tax is clearly shown on receipt, use that value for tax_amount and set calculated_tax to null
- If tax is clearly shown on the receipt as 0%, set tax_amount to 0 and calculated_tax to null
DEPRECIATION RULES: DEPRECIATION RULES:
- Determine if item is a depreciable asset (vehicles, machinery, equipment, computers, furniture, buildings) - Determine if item is a depreciable asset (vehicles, machinery, equipment, computers, furniture, buildings)
+15 -15
View File
@@ -51,22 +51,22 @@ class MatchingEngine:
enhanced_matches = self._apply_manual_tax_analysis( enhanced_matches = self._apply_manual_tax_analysis(
ai_matches, user_location ai_matches, user_location
) )
else: # else:
# Use LLM-based tax analysis in a SINGLE batch call # # Use LLM-based tax analysis in a SINGLE batch call
try: # try:
enhanced_matches = ( # enhanced_matches = (
self.llm_tax_analyzer.analyze_and_apply_tax_rules_batch( # self.llm_tax_analyzer.analyze_and_apply_tax_rules_batch(
ai_matches, user_location # ai_matches, user_location
) # )
) # )
except Exception as e: # except Exception as e:
# If batch LLM analysis fails, log it and continue with matches as-is # # If batch LLM analysis fails, log it and continue with matches as-is
import logging # import logging
logging.error(f"Batch LLM tax analysis failed: {str(e)}") # logging.error(f"Batch LLM tax analysis failed: {str(e)}")
for match in ai_matches: # for match in ai_matches:
match.match_reason += " (Note: Advanced tax analysis unavailable)" # match.match_reason += " (Note: Advanced tax analysis unavailable)"
enhanced_matches = ai_matches # enhanced_matches = ai_matches
return enhanced_matches return enhanced_matches