feat: Enhance medical report generation with new features and improved data handling

- Added body fat percentage input and optional muscle oxygenation CSV upload in the upload form.
- Implemented TSI chart generation based on muscle oxygenation data.
- Updated report generation to include metabolism and fuel source charts.
- Refactored context generation to eliminate reliance on SECA data, using patient info directly instead.
- Improved error handling and logging for graph generation processes.
- Enhanced HTML templates for better user experience and functionality.
This commit is contained in:
bolade
2025-11-18 16:57:39 +01:00
parent 83f50882e2
commit 7e985c497e
12 changed files with 1256 additions and 262 deletions
+211 -99
View File
@@ -13,10 +13,13 @@ from pathlib import Path
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from jinja2 import Environment, FileSystemLoader
from pydantic import BaseModel
from starlette.middleware.sessions import SessionMiddleware
from services.report_generator import ReportGeneratorService
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import Request as StarletteRequest
app = FastAPI(
title="Medical Report Generation API",
@@ -25,7 +28,32 @@ app = FastAPI(
)
# Add session middleware
app.add_middleware(SessionMiddleware, secret_key=os.getenv("SECRET_KEY", "your-secret-key-change-in-production"))
app.add_middleware(
SessionMiddleware,
secret_key=os.getenv("SECRET_KEY", "your-secret-key-change-in-production"),
)
# Add security headers middleware to allow external scripts
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: StarletteRequest, call_next):
response = await call_next(request)
# Allow external scripts and styles (for Tailwind CDN)
# Only add CSP for HTML responses
content_type = response.headers.get("content-type", "").lower()
if "text/html" in content_type:
response.headers["Content-Security-Policy"] = (
"default-src 'self'; script-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com https:; style-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com https:; img-src 'self' data: https:;"
)
return response
app.add_middleware(SecurityHeadersMiddleware)
# Mount static files (if static directory exists)
static_dir = Path("static")
if static_dir.exists():
app.mount("/static", StaticFiles(directory="static"), name="static")
# Setup templates
jinja_env = Environment(loader=FileSystemLoader("app/templates"))
@@ -59,13 +87,15 @@ def render_template(template_name: str, context: dict) -> HTMLResponse:
"""Helper function to render Jinja2 templates"""
template = jinja_env.get_template(template_name)
html_content = template.render(**context)
return HTMLResponse(content=html_content)
return HTMLResponse(content=html_content, media_type="text/html")
@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
"""Root endpoint - Upload form page"""
return render_template("upload.html", {"request": request, "session": request.session})
return render_template(
"upload.html", {"request": request, "session": request.session}
)
@app.post("/upload")
@@ -77,56 +107,73 @@ async def upload_files(
height: str = Form(...),
weight: str = Form(...),
gender: str = Form(...),
fat_percentage: float = Form(...),
focus: str = Form(default="Endurance"),
session_id: str = Form(default="default"),
spirometry_pdf: UploadFile = File(...),
pnoe_csv: UploadFile = File(...),
seca_excel: UploadFile = File(...),
oxygenation_csv: UploadFile = File(None),
):
"""Handle file upload and generate report"""
# Validate file types
if not spirometry_pdf.filename.endswith(".pdf"):
return render_template("upload.html", {
"request": request,
"session": request.session,
"error": "Spirometry file must be a PDF"
})
return render_template(
"upload.html",
{
"request": request,
"session": request.session,
"error": "Spirometry file must be a PDF",
},
)
if not pnoe_csv.filename.endswith(".csv"):
return render_template("upload.html", {
"request": request,
"session": request.session,
"error": "Pnoe file must be a CSV"
})
if not seca_excel.filename.endswith((".xlsx", ".xls")):
return render_template("upload.html", {
"request": request,
"session": request.session,
"error": "SECA file must be an Excel file (.xlsx or .xls)"
})
return render_template(
"upload.html",
{
"request": request,
"session": request.session,
"error": "Pnoe file must be a CSV",
},
)
# Validate oxygenation CSV if provided
if oxygenation_csv and oxygenation_csv.filename:
if not oxygenation_csv.filename.endswith(".csv"):
return render_template(
"upload.html",
{
"request": request,
"session": request.session,
"error": "Oxygenation file must be a CSV",
},
)
# Create session-specific temp directory
session_uuid = str(uuid.uuid4())
session_temp_dir = TEMP_DIR / session_uuid
session_temp_dir.mkdir(exist_ok=True, parents=True)
# Save uploaded files
spirometry_path = session_temp_dir / f"spirometry_{spirometry_pdf.filename}"
pnoe_path = session_temp_dir / f"pnoe_{pnoe_csv.filename}"
seca_path = session_temp_dir / f"seca_{seca_excel.filename}"
oxygenation_path = None
try:
# Write files
with open(spirometry_path, "wb") as f:
shutil.copyfileobj(spirometry_pdf.file, f)
with open(pnoe_path, "wb") as f:
shutil.copyfileobj(pnoe_csv.file, f)
with open(seca_path, "wb") as f:
shutil.copyfileobj(seca_excel.file, f)
# Save oxygenation CSV if provided
if oxygenation_csv and oxygenation_csv.filename:
oxygenation_path = (
session_temp_dir / f"oxygenation_{oxygenation_csv.filename}"
)
with open(oxygenation_path, "wb") as f:
shutil.copyfileobj(oxygenation_csv.file, f)
# Prepare patient information
patient_name = f"{first_name} {last_name}"
patient_info = {
@@ -137,76 +184,100 @@ async def upload_files(
"height": height,
"weight": weight,
"gender": gender,
"fat_percentage": fat_percentage,
"focus": focus,
"session_id": session_id,
}
# Generate report
oxygenation_csv_path = str(oxygenation_path) if oxygenation_path else None
result = await report_service.generate_report(
spirometry_pdf_path=str(spirometry_path),
pnoe_csv_path=str(pnoe_path),
seca_excel_path=str(seca_path),
patient_info=patient_info,
oxygenation_csv_path=oxygenation_csv_path,
)
# Store in session
request.session["patient_info"] = patient_info
request.session["temp_dir"] = str(session_temp_dir)
request.session["report_path"] = result["report_path"]
request.session["graphs_generated"] = result["graphs_generated"]
request.session["analysis_data"] = result["analysis_data"]
# Extract spirometry CSV path (it's saved in data_dir by the service)
from services.spirometry_table_extractor import extract_spirometry_table_from_pdf
from services.context_generator import ContextGenerator
from pathlib import Path as PathLib
from services.context_generator import ContextGenerator
from services.spirometry_table_extractor import (
extract_spirometry_table_from_pdf,
)
# The spirometry CSV is extracted during report generation
# We need to find it or extract it again
data_dir = PathLib("data")
spirometry_csv_path = data_dir / f"spirometry_{Path(spirometry_pdf.filename).stem}.csv"
spirometry_csv_path = (
data_dir / f"spirometry_{Path(spirometry_pdf.filename).stem}.csv"
)
# If it doesn't exist, extract it
if not spirometry_csv_path.exists():
spirometry_csv_path = extract_spirometry_table_from_pdf(
str(spirometry_path), output_dir=str(data_dir)
)
spirometry_csv_path = PathLib(spirometry_csv_path)
# Get calculated metrics for display and editing
context_gen = ContextGenerator()
context_gen.load_data(
str(pnoe_path),
str(spirometry_csv_path),
str(seca_path)
None, # No SECA file needed anymore
)
context_gen.extract_patient_info(last_name) # Extract patient info
# Set patient info manually since we're not reading from SECA
weight_kg = float(weight.replace("lbs", "").replace("kg", "").strip())
if "lbs" in weight.lower():
weight_kg = weight_kg / 2.20462 # Convert lbs to kg
context_gen.patient_info = {
"name": first_name,
"last_name": last_name,
"age": age,
"weight": weight_kg,
"fat_percentage": fat_percentage,
"gender": gender,
}
spirometry_metrics = context_gen.calculate_spirometry_metrics()
pnoe_metrics = context_gen.calculate_pnoe_metrics()
# Store metrics in session
request.session["metrics"] = {
"spirometry": spirometry_metrics,
"pnoe": pnoe_metrics,
}
request.session["spirometry_csv_path"] = str(spirometry_csv_path)
return RedirectResponse(url="/preview", status_code=303)
except Exception as e:
import traceback
error_details = traceback.format_exc()
print(f"ERROR: {error_details}")
return render_template("upload.html", {
"request": request,
"session": request.session,
"error": f"Error generating report: {str(e)}"
})
return render_template(
"upload.html",
{
"request": request,
"session": request.session,
"error": f"Error generating report: {str(e)}",
},
)
finally:
# Close file handles
spirometry_pdf.file.close()
pnoe_csv.file.close()
seca_excel.file.close()
if oxygenation_csv and oxygenation_csv.filename:
oxygenation_csv.file.close()
@app.get("/preview", response_class=HTMLResponse)
@@ -214,7 +285,9 @@ async def preview(request: Request):
"""Preview generated report"""
if not request.session.get("report_path"):
return RedirectResponse(url="/", status_code=303)
return render_template("preview.html", {"request": request, "session": request.session})
return render_template(
"preview.html", {"request": request, "session": request.session}
)
@app.get("/graphs/{filename}")
@@ -231,7 +304,9 @@ async def edit_form(request: Request):
"""Display edit metrics form"""
if not request.session.get("metrics"):
return RedirectResponse(url="/", status_code=303)
return render_template("edit.html", {"request": request, "session": request.session})
return render_template(
"edit.html", {"request": request, "session": request.session}
)
@app.post("/edit")
@@ -239,16 +314,13 @@ async def edit_metrics(request: Request):
"""Handle metric edits and regenerate report"""
if not request.session.get("temp_dir") or not request.session.get("patient_info"):
return RedirectResponse(url="/", status_code=303)
# Get form data
form_data = await request.form()
# Build metric overrides
metric_overrides = {
"pnoe": {},
"spirometry": {}
}
metric_overrides = {"pnoe": {}, "spirometry": {}}
# Pnoe overrides
if form_data.get("vo2_max"):
metric_overrides["pnoe"]["vo2_max"] = float(form_data["vo2_max"])
@@ -262,28 +334,36 @@ async def edit_metrics(request: Request):
metric_overrides["pnoe"]["fat_max_value"] = float(form_data["fat_max_value"])
if form_data.get("fat_max_hr"):
metric_overrides["pnoe"]["fat_max_hr"] = float(form_data["fat_max_hr"])
# VT1 and VT2 overrides
if form_data.get("vt1_hr") or form_data.get("vt1_speed") or form_data.get("vt1_time"):
if (
form_data.get("vt1_hr")
or form_data.get("vt1_speed")
or form_data.get("vt1_time")
):
metric_overrides["pnoe"]["vt1"] = {
"HeartRate": float(form_data.get("vt1_hr", 0)),
"Speed": float(form_data.get("vt1_speed", 0)),
"Time": float(form_data.get("vt1_time", 0))
"Time": float(form_data.get("vt1_time", 0)),
}
if form_data.get("vt2_hr") or form_data.get("vt2_speed") or form_data.get("vt2_time"):
if (
form_data.get("vt2_hr")
or form_data.get("vt2_speed")
or form_data.get("vt2_time")
):
metric_overrides["pnoe"]["vt2"] = {
"HeartRate": float(form_data.get("vt2_hr", 0)),
"Speed": float(form_data.get("vt2_speed", 0)),
"Time": float(form_data.get("vt2_time", 0))
"Time": float(form_data.get("vt2_time", 0)),
}
# Heart rate zones
for i in range(1, 6):
zone_key = f"zone{i}_bpm"
if form_data.get(zone_key):
metric_overrides["pnoe"][zone_key] = form_data[zone_key]
# Spirometry overrides
if form_data.get("fvc_best"):
metric_overrides["spirometry"]["fvc_best"] = float(form_data["fvc_best"])
@@ -294,88 +374,120 @@ async def edit_metrics(request: Request):
if form_data.get("fev1_pred"):
metric_overrides["spirometry"]["fev1_pred"] = float(form_data["fev1_pred"])
if form_data.get("fev1_fvc_pct_best"):
metric_overrides["spirometry"]["fev1_fvc_pct_best"] = float(form_data["fev1_fvc_pct_best"])
metric_overrides["spirometry"]["fev1_fvc_pct_best"] = float(
form_data["fev1_fvc_pct_best"]
)
if form_data.get("fev1_fvc_pct_pred"):
metric_overrides["spirometry"]["fev1_fvc_pct_pred"] = float(form_data["fev1_fvc_pct_pred"])
metric_overrides["spirometry"]["fev1_fvc_pct_pred"] = float(
form_data["fev1_fvc_pct_pred"]
)
try:
# Get file paths from session
temp_dir = Path(request.session["temp_dir"])
patient_info = request.session["patient_info"]
# Find files in temp directory
spirometry_path = None
pnoe_path = None
seca_path = None
oxygenation_path = None
for file_path in temp_dir.iterdir():
if file_path.name.startswith("spirometry_"):
spirometry_path = file_path
elif file_path.name.startswith("pnoe_"):
pnoe_path = file_path
elif file_path.name.startswith("seca_"):
seca_path = file_path
if not all([spirometry_path, pnoe_path, seca_path]):
raise ValueError("Could not find all uploaded files")
elif file_path.name.startswith("oxygenation_"):
oxygenation_path = file_path
if not all([spirometry_path, pnoe_path]):
raise ValueError("Could not find all required uploaded files")
# Regenerate report with overrides
oxygenation_csv_path = str(oxygenation_path) if oxygenation_path else None
result = await report_service.generate_report(
spirometry_pdf_path=str(spirometry_path),
pnoe_csv_path=str(pnoe_path),
seca_excel_path=str(seca_path),
patient_info=patient_info,
metric_overrides=metric_overrides if (metric_overrides["pnoe"] or metric_overrides["spirometry"]) else None,
metric_overrides=metric_overrides
if (metric_overrides["pnoe"] or metric_overrides["spirometry"])
else None,
oxygenation_csv_path=oxygenation_csv_path,
)
# Update session with new report
request.session["report_path"] = result["report_path"]
request.session["graphs_generated"] = result["graphs_generated"]
request.session["analysis_data"] = result["analysis_data"]
# Recalculate metrics with overrides
from services.context_generator import ContextGenerator
context_gen = ContextGenerator()
spirometry_csv_path = request.session.get("spirometry_csv_path", "")
if not spirometry_csv_path or not Path(spirometry_csv_path).exists():
from services.spirometry_table_extractor import extract_spirometry_table_from_pdf
from pathlib import Path as PathLib
from services.spirometry_table_extractor import (
extract_spirometry_table_from_pdf,
)
data_dir = PathLib("data")
spirometry_csv_path = extract_spirometry_table_from_pdf(
str(spirometry_path), output_dir=str(data_dir)
)
spirometry_csv_path = str(PathLib(spirometry_csv_path))
context_gen.load_data(
str(pnoe_path),
spirometry_csv_path,
str(seca_path)
None, # No SECA file
)
# Set patient info manually
weight_str = patient_info.get("weight", "0")
weight_kg = float(weight_str.replace("lbs", "").replace("kg", "").strip())
if "lbs" in weight_str.lower():
weight_kg = weight_kg / 2.20462 # Convert lbs to kg
context_gen.patient_info = {
"name": patient_info.get("first_name", ""),
"last_name": patient_info.get("last_name", ""),
"age": patient_info.get("age", 25),
"weight": weight_kg,
"fat_percentage": patient_info.get("fat_percentage", 0),
"gender": patient_info.get("gender", "female"),
}
context_gen.extract_patient_info(patient_info.get("last_name", ""))
spirometry_overrides = metric_overrides.get("spirometry", {})
pnoe_overrides = metric_overrides.get("pnoe", {})
spirometry_metrics = context_gen.calculate_spirometry_metrics(spirometry_overrides)
spirometry_metrics = context_gen.calculate_spirometry_metrics(
spirometry_overrides
)
pnoe_metrics = context_gen.calculate_pnoe_metrics(pnoe_overrides)
# Update metrics in session
request.session["metrics"] = {
"spirometry": spirometry_metrics,
"pnoe": pnoe_metrics,
}
request.session["spirometry_csv_path"] = spirometry_csv_path
return RedirectResponse(url="/preview", status_code=303)
except Exception as e:
import traceback
error_details = traceback.format_exc()
print(f"ERROR: {error_details}")
return render_template("edit.html", {
"request": request,
"session": request.session,
"error": f"Error regenerating report: {str(e)}"
})
return render_template(
"edit.html",
{
"request": request,
"session": request.session,
"error": f"Error regenerating report: {str(e)}",
},
)
@app.get("/health")