Commit 83f0389

Update README.md
1 parent ce86a44 commit 83f0389

1 file changed

Lines changed: 53 additions & 43 deletions

README.md

@@ -12,7 +12,7 @@ This project implements a dual-architecture early warning system comparing gradi
1212

1313
Models were trained on the MIMIC-IV Clinical Demo v2.2 dataset (100 patients), using dual feature engineering pipelines: 171 timestamp-level temporal features (24-hour windows) for TCN, and 40 patient-level aggregated features for LightGBM.
1414

15-
The hybrid approach reveals complementary strengths: LightGBM achieves superior calibration and regression fidelity (68% Brier reduction, +17% AUC, +44% R²) for sustained risk assessment, while TCN demonstrates stronger acute event discrimination (+9.3% AUC, superior sensitivity) for detecting rapid deterioration. Together, they characterise short-term instability and longer-term exposure to physiological risk.
15+
**The hybrid approach reveals complementary strengths:** LightGBM achieves superior calibration and regression fidelity (68% Brier reduction, +17% AUC, +44% R²) for sustained risk assessment, while TCN demonstrates stronger acute event discrimination (+9.3% AUC, superior sensitivity) for detecting rapid deterioration. Together, they characterise short-term instability and longer-term exposure to physiological risk.
1616

1717
The complete pipeline includes clinically validated NEWS2 preprocessing (CO₂ retainer logic, GCS mapping, supplemental O₂ protocols), comprehensive feature engineering, robust evaluation, and model-specific interpretability (SHAP for LightGBM; gradient×input saliency for TCN).
1818

@@ -786,15 +786,15 @@ Maintaining both feature sets ensures flexibility and robustness in model select
786786
3. Permute back for pooling → `(B, L, C_last)`
787787
4. Apply masked mean pooling → `(B, C_last)`
788788
- If mask provided, ignore padding, average over real timestamps.
789-
6. **Optional dense head (if enabled)**
789+
5. **Optional dense head (if enabled)**
790790
- Linear → ReLU → Dropout
791791
- Adds non-linearity and regularisation after pooling
792792
- Output `(B, head_hidden)`
793-
7. **Pass to task-specific heads:**
793+
6. **Pass to task-specific heads:**
794794
- `classifier_max`: binary logit `(B,)`
795795
- `classifier_median`: binary logit `(B,)`
796796
- `regressor`: continuous risk `(B,)`
797-
8. **Output:**
797+
7. **Output:**
798798
- Return dictionary of patient-level predictions (ready for loss functions).
799799
- `logit_max`, `logit_median` → binary classification logits
800800
- `regressor` → continuous regression output
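The masked mean pooling step described above (`(B, L, C_last)` → `(B, C_last)`) could look roughly like the sketch below. This is an illustration, not the repository's code: the function name `masked_mean_pool`, the mask layout (1 for real timestamps, 0 for padding), and the `eps` guard are assumptions.

```python
from typing import Optional

import torch


def masked_mean_pool(x: torch.Tensor, mask: Optional[torch.Tensor] = None,
                     eps: float = 1e-8) -> torch.Tensor:
    """Average over the time axis, ignoring padded timesteps.

    x:    (B, L, C) features after the temporal blocks
    mask: (B, L), 1 for real timestamps and 0 for padding (assumed layout)
    """
    if mask is None:
        # No mask provided: plain mean over all timesteps.
        return x.mean(dim=1)
    m = mask.unsqueeze(-1).to(x.dtype)        # (B, L, 1), broadcasts over channels
    summed = (x * m).sum(dim=1)               # (B, C), padding contributes zero
    counts = m.sum(dim=1).clamp(min=eps)      # (B, 1), avoids division by zero
    return summed / counts


# Toy example: one patient, three timesteps, the last one is padding.
x = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = torch.tensor([[1, 1, 0]])
pooled = masked_mean_pool(x, mask)            # averages only the first two timesteps
```

Dividing by the number of real timesteps rather than the padded length keeps short stays from being diluted by padding values.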
@@ -932,47 +932,51 @@ The following hyperparameters were used when training the final TCN model and st
932932
```
933933

934934
#### 7.5.1 Pre-Training Refinements
935-
- Initial evaluation metrics and subsequent diagnostics identified issues with model performance:
936-
1. Poor learning on median-risk → class imbalance; poor calibration
937-
2. Poor regression performance → underfitting; skewed targets
938-
- Implemented minimal controlled fixes, keeping architecture/hyperparameters constant:
939-
1. Log-transform regression target → `log1p(y)` before tensor creation to reduce regression skew and stabilise variance
940-
2. Applying class weighting (`pos_weight = 2.889`) for median-risk BCE loss → correct class imbalance
935+
936+
Initial evaluation metrics and subsequent diagnostics identified issues with model performance:
937+
938+
1. Poor learning on median-risk → class imbalance; poor calibration
939+
2. Poor regression performance → underfitting; skewed targets
940+
941+
Implemented minimal controlled fixes, keeping architecture/hyperparameters constant:
942+
943+
1. Log-transform regression target → `log1p(y)` before tensor creation to reduce regression skew and stabilise variance
944+
2. Apply class weighting (`pos_weight = 2.889`) to the median-risk BCE loss → corrects class imbalance
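The two refinements can be sketched as follows. The label and target values are made-up toy data (the real counts are not given here), and reshaping `pos_weight` to a one-element tensor is an assumption about how it is passed to the loss.

```python
import torch
import torch.nn as nn

# Hypothetical toy targets, for illustration only.
y_median = torch.tensor([0.0, 0.0, 0.0, 1.0])   # binary median-risk labels
y_reg = torch.tensor([0.0, 1.5, 4.0, 9.0])      # skewed regression target

# Class weighting: ratio of negatives to positives.
num_pos = y_median.sum()
num_neg = (1 - y_median).sum()
pos_weight = (num_neg / num_pos).reshape(1)     # 3.0 for this toy data; the README reports 2.889
criterion_median = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Log-transform the regression target before training ...
y_reg_log = torch.log1p(y_reg)
# ... and invert with expm1 when reporting predictions on the original scale.
y_reg_back = torch.expm1(y_reg_log)
```

With `pos_weight > 1`, errors on positive (minority) examples cost more than errors on negatives, which counteracts the imbalance without resampling.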
941945

942946
#### 7.5.2 Setup Flow
943947
1. **Imports & Config**
944-
- Import custom `TCNModel` from `tcn_model.py`
945-
- **Define hyperparameters**: `DEVICE`, `BATCH_SIZE`, `EPOCHS`, `LR`, `EARLY_STOPPING_PATIENCE`
948+
- Import custom `TCNModel` from `tcn_model.py`
949+
- **Define hyperparameters**: `DEVICE`, `BATCH_SIZE`, `EPOCHS`, `LR`, `EARLY_STOPPING_PATIENCE`
946950
2. **Load Prepared Sequence Data**
947-
- Load padded tensors (`train.pt`, `val.pt`, `test.pt`) and their corresponding masks (valid timesteps vs padding)
948-
- These are the time-series features per patient, already standardised + padded to equal length by the preprocessing pipeline.
951+
- Load padded tensors (`train.pt`, `val.pt`, `test.pt`) and their corresponding masks (valid timesteps vs padding)
952+
- These are the time-series features per patient, already standardised + padded to equal length by the preprocessing pipeline.
949953
3. **Build Target Tensors (Patient Labels)**
950-
- Load patient-level CSV (`news2_features_patient.csv`).
951-
- **Recreate binary labels (same as LightGBM):** `max_risk_binary` (high vs not-high risk), `median_risk_binary` (low vs medium)
952-
- Load splits (`patient_splits.json`) so each patient is consistently assigned to train/val/test.
954+
- Load patient-level CSV (`news2_features_patient.csv`).
955+
- **Recreate binary labels (same as LightGBM):** `max_risk_binary` (high vs not-high risk), `median_risk_binary` (low vs medium)
956+
- Load splits (`patient_splits.json`) so each patient is consistently assigned to train/val/test.
953957
4. **Apply Target Transformations**
954958
- Log-transform regression target → `log1p()` for variance stabilisation
955959
- Compute class weights for `median_risk` weighted BCE loss → `pos_weight = num_neg / num_pos` = `2.889`
956960
5. **Build Target Tensors**
957961
- Create PyTorch tensors for all 3 targets in all 3 splits (train/val/test) → `y_<split>_max, y_<split>_median, y_<split>_reg`
958962
6. **Construct Datasets & Dataloaders**
959-
- `TensorDataset` groups together (inputs, masks, targets) into one dataset object per patient.
960-
- `DataLoader` creates shuffled mini-batches:
963+
- `TensorDataset` groups together (inputs, masks, targets) into one dataset object per patient.
964+
- `DataLoader` creates shuffled mini-batches:
961965
- `batch_size=32` (32 patients per step) → improves GPU efficiency; stabilises gradient descent.
962966
- `shuffle=True` → prevents learning artefacts from patient order.
963967
7. **Model Initialisation**
964968
- Instantiate `TCNModel(num_features=171, num_channels=[64,64,128], head_hidden=64)`
965-
- Move model to GPU/CPU device.
969+
- Move model to GPU/CPU device.
966970
8. **Loss Functions**
967971
- `criterion_max` = `BCEWithLogitsLoss` → binary classification.
968972
- `criterion_median` = `BCEWithLogitsLoss(pos_weight=2.889)` → binary classification with weighted BCE
969973
- `criterion_reg` = `MSELoss` → log-transformed regression task.
970974
9. **Optimiser + Scheduler**
971-
- Optimiser = `Adam` (LR=1e-3) adapts learning rate per parameter → faster convergence
972-
- Scheduler = `ReduceLROnPlateau` (patience=3, factor=0.5) → halves LR on plateau
975+
- Optimiser = `Adam` (LR=1e-3) adapts learning rate per parameter → faster convergence
976+
- Scheduler = `ReduceLROnPlateau` (patience=3, factor=0.5) → halves LR on plateau
973977
10. **Reproducibility Controls**
974-
- Fixed seeds for Python, NumPy, and PyTorch.
975-
- Deterministic CuDNN convolution settings.
978+
- Fixed seeds for Python, NumPy, and PyTorch.
979+
- Deterministic CuDNN convolution settings.
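A condensed sketch of the setup flow, covering seeding, dataset construction, loss functions, optimiser and scheduler. The toy tensors, the seed value, the small batch size, and the `nn.Linear` placeholder are assumptions made so that the snippet runs on its own; the actual pipeline loads the prepared `.pt` files and instantiates `TCNModel`.

```python
import random

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def set_seed(seed: int = 42) -> None:
    """Fix seeds for Python, NumPy and PyTorch, and request deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)  # seed value is an assumption

# Toy stand-ins for the padded sequences: 8 patients, 5 timesteps, 171 features.
x = torch.randn(8, 5, 171)
mask = torch.ones(8, 5)
y_max = torch.randint(0, 2, (8,)).float()
y_median = torch.randint(0, 2, (8,)).float()
y_reg = torch.log1p(torch.rand(8) * 10)

# One dataset per split; each index returns one patient's (inputs, mask, targets).
train_ds = TensorDataset(x, mask, y_max, y_median, y_reg)
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)  # README uses batch_size=32

criterion_max = nn.BCEWithLogitsLoss()
criterion_median = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([2.889]))
criterion_reg = nn.MSELoss()

model = nn.Linear(171, 3)  # placeholder only, NOT the TCNModel
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=3
)
```

`scheduler.step(val_loss)` is then called once per epoch with the validation loss, so the learning rate is only halved after the loss has plateaued for more than `patience` epochs.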
976980

977981
##
978982
### 7.6 Training & Validation Loop
@@ -1000,52 +1004,58 @@ The following hyperparameters were used when training the final TCN model and st
10001004
```
10011005

10021006
**Inner Loop (per batch) - Learning and Optimisation Cycle**
1007+
10031008
This is the core learning step: gradient computation and weight updates, run once per batch:
1009+
10041010
1. **Forward pass**: model computes predictions for the batch.
10051011
2. **Loss computation (BCE, MSE)**: predictions are compared to true labels → gives `loss_max, loss_median, loss_reg`.
10061012
3. **Combine losses**: summed to get overall batch loss.
10071013
4. **Backward pass**: compute gradients → tells how to adjust weights to reduce loss.
10081014
5. **Gradient clipping**: prevents exploding gradients and stabilises updates.
10091015
6. **Optimiser step**: update weights using the gradients; the gradients determine the direction, the learning rate determines the step size.
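The per-batch cycle above can be sketched as a single function. `TinyMultiTask` is a hypothetical placeholder for the TCN, and `train_one_epoch`, the `criteria` dictionary, the `"regression"` output key, and `max_norm=1.0` are illustrative assumptions rather than the repository's actual names or settings.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class TinyMultiTask(nn.Module):
    """Placeholder model: masked mean pooling followed by three linear heads."""

    def __init__(self, num_features: int):
        super().__init__()
        self.heads = nn.Linear(num_features, 3)

    def forward(self, x, mask):
        m = mask.unsqueeze(-1)
        pooled = (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
        out = self.heads(pooled)
        return {"logit_max": out[:, 0], "logit_median": out[:, 1], "regression": out[:, 2]}


def train_one_epoch(model, loader, optimizer, criteria, device, max_norm=1.0):
    model.train()
    total_loss, n_seen = 0.0, 0
    for x, mask, y_max, y_median, y_reg in loader:
        x, mask = x.to(device), mask.to(device)
        y_max, y_median, y_reg = y_max.to(device), y_median.to(device), y_reg.to(device)

        optimizer.zero_grad()
        out = model(x, mask)                                    # 1. forward pass
        loss = (criteria["max"](out["logit_max"], y_max)        # 2. per-task losses
                + criteria["median"](out["logit_median"], y_median)
                + criteria["reg"](out["regression"], y_reg))    # 3. unweighted sum
        loss.backward()                                         # 4. backward pass
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)  # 5. clip
        optimizer.step()                                        # 6. weight update

        total_loss += loss.item() * x.size(0)                   # weight by batch size
        n_seen += x.size(0)
    return total_loss / n_seen


# Toy run on random data.
torch.manual_seed(0)
x = torch.randn(16, 6, 4)
mask = torch.ones(16, 6)
mask[:, 4:] = 0                                                 # last two timesteps are padding
y_max = torch.randint(0, 2, (16,)).float()
y_median = torch.randint(0, 2, (16,)).float()
y_reg = torch.rand(16)
loader = DataLoader(TensorDataset(x, mask, y_max, y_median, y_reg), batch_size=8, shuffle=True)

model = TinyMultiTask(num_features=4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
criteria = {"max": nn.BCEWithLogitsLoss(), "median": nn.BCEWithLogitsLoss(), "reg": nn.MSELoss()}
before = model.heads.weight.detach().clone()
epoch_loss = train_one_epoch(model, loader, optimizer, criteria, torch.device("cpu"))
```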
1016+
10101017
**Outer Loop (per epoch) - Training Controller**
1018+
10111019
Runs once per epoch, and controls the training process by supervising the inner loop to prevent overfitting:
1020+
10121021
1. **Call inner loop:** goes through all training batches, updates weights, returns average training loss
10131022
2. **Validation loop:** evaluates the model’s generalisation (no updates, no gradients), returns average validation loss
10141023
3. **Learning rate scheduler:** modifies learning rate based on validation loss (`ReduceLROnPlateau`).
10151024
4. **Early stopping logic:** if validation hasn't improved for 7 epochs, terminate training to prevent overfitting
10161025
5. **Checkpoint saving:** if validation improves, save best model; once training ends we are left with best model.
1026+
10171027
Repeat until early stopping triggers. The inner loop runs once per batch, so many times per epoch.
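The early stopping and checkpointing logic of the outer loop can be illustrated without any model at all. The `EarlyStopping` class and the simulated validation losses below are hypothetical; only the patience of 7 epochs comes from the description above.

```python
class EarlyStopping:
    """Track the best validation loss and count epochs without improvement."""

    def __init__(self, patience: int = 7):
        self.patience = patience
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float):
        """Return (improved, should_stop) for this epoch's validation loss."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.bad_epochs = 0
            return True, False
        self.bad_epochs += 1
        return False, self.bad_epochs >= self.patience


# Simulated validation losses: three improvements, then a plateau.
val_losses = [1.0, 0.8, 0.7] + [0.75] * 10

stopper = EarlyStopping(patience=7)
best_epoch, stopped_epoch = None, None
for epoch, val_loss in enumerate(val_losses):
    improved, should_stop = stopper.step(val_loss)
    if improved:
        best_epoch = epoch      # a real loop would save the model checkpoint here
    if should_stop:
        stopped_epoch = epoch   # 7 consecutive epochs without improvement
        break
```

Because the checkpoint is only written on improvement, the saved model corresponds to `best_epoch`, not to the epoch at which training stopped.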
10181028

10191029
#### 7.6.2 Flow of Logic
10201030
1. **Training Loop**
1021-
- Loop over epochs (one full pass through the entire training dataset)
1031+
- Loop over epochs (one full pass through the entire training dataset)
10221032
- Each epoch allows the model to adjust weights progressively per batch
1023-
- **For each batch**:
1033+
- **For each batch**:
10241034
- Move input sequences (`x_batch`) and masks to device
1025-
- Forward pass → model predicts 3 outputs (`logit_max, logit_median, regression`).
1026-
- Compute individual losses with loss functions → compare predictions to true labels (`y_max, y_median, y_reg`).
1035+
- Forward pass → model predicts 3 outputs (`logit_max, logit_median, regression`)
1036+
- Compute individual losses with loss functions → compare predictions to true labels (`y_max, y_median, y_reg`)
10271037
- Combine losses into 1 (`loss = loss_max + loss_median + loss_reg`) → one scalar loss value means each task contributes equally (multi-task learning).
1028-
- Backward pass → calculate gradients of this total loss w.r.t. every model parameter.
1029-
- Gradient clipping (`clip_grad_norm_`) → prevents exploding gradients (if gradients get too large, clipping rescales gradients, keeps training stable).
1030-
- Optimiser step → updates weights in opposite direction of the gradients.
1031-
- **This is the deep learning itself**: forward pass → loss → backward pass → update weights.
1038+
- Backward pass → calculate gradients of this total loss w.r.t. every model parameter
1039+
- Gradient clipping (`clip_grad_norm_`) → prevents exploding gradients (if gradients get too large, clipping rescales gradients, keeps training stable)
1040+
- Optimiser step → updates weights in opposite direction of the gradients
1041+
- **This is the deep learning itself**: forward pass → loss → backward pass → update weights
10321042
2. **Track Average Training Loss per Epoch**
10331043
- Weighted average over batch sizes → mean epoch training loss per patient
1034-
- Logged and compared with validation loss for analysis → see if model is learning.
1044+
- Logged and compared with validation loss for analysis → see if model is learning
10351045
3. **Validation Loop**
10361046
- Set model to evaluation mode → disable dropout, batch norm updates
1037-
- Run the model on validation split (no gradients or optimiser step).
1038-
- Compute and track average validation loss per epoch .
1039-
- Scheduler step → Update LR scheduler based on validation loss.
1040-
- **Logic**:
1047+
- Run the model on validation split (no gradients or optimiser step)
1048+
- Compute and track average validation loss per epoch
1049+
- Scheduler step → Update LR scheduler based on validation loss
1050+
- **Logic**:
10411051
- When validation loss improves (validation loss ↓) → save model, final model state will be best one
1042-
- When validation loss stagnates/gets worse (validation loss ↑) → patience counter increases.
1043-
- Early stopping: Training stops early when overfitting begins (after 7 epochs of no improvement).
1044-
- **Rationale**: validation loss tells us if the model is generalising or just memorising training data .
1052+
- When validation loss stagnates/gets worse (validation loss ↑) → patience counter increases
1053+
- Early stopping: Training stops early when overfitting begins (after 7 epochs of no improvement)
1054+
- **Rationale**: validation loss tells us if the model is generalising or just memorising training data
10451055
4. **Early Stopping**
1046-
- If validation loss improves → save .pt model
1047-
- If no improvement for 7 epochs → stop training early.
1048-
- **Rationale**: protects against overfitting and wasted compute.
1056+
- If validation loss improves → save .pt model
1057+
- If no improvement for 7 epochs → stop training early
1058+
- **Rationale**: protects against overfitting and wasted compute
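The "weighted average over batch sizes" used for the epoch-level loss matters because the last batch is usually smaller than the others. A minimal sketch, with made-up batch losses and sizes and a hypothetical helper name:

```python
def epoch_mean_loss(batch_losses, batch_sizes):
    """Mean loss per patient: each batch's mean loss is weighted by its size."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# Hypothetical epoch: two full batches of 32 patients and a final batch of 6.
batch_losses = [0.5, 0.7, 1.0]
batch_sizes = [32, 32, 6]

weighted = epoch_mean_loss(batch_losses, batch_sizes)
unweighted = sum(batch_losses) / len(batch_losses)  # over-weights the small last batch
```

Here the unweighted mean is noticeably higher than the per-patient mean, because the six patients in the final batch would count as much as 32 patients in each of the others.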
10491059

10501060
#### 7.6.3 Loop Rationale
10511061
- **Multi-task learning:** Losses from 3 outputs combined for joint learning.
