# ECG2Stroke — Deep Learning to Predict Incident Ischemic Stroke

This directory contains models and code for predicting incident ischemic stroke from 12-lead resting ECGs, as described in <mark>INSERT LINK HERE</mark>.

**Input:**
* **Modality**: 12-lead resting ECG
* **Expected shape**: `(batch, time, leads)` = `(B, 5000, 12)`
* **Sampling rate**: 500 Hz (10 seconds ⇒ 5000 samples)
* **Normalization**: Z-score normalized per lead (i.e., mean of 0 and standard deviation of 1)
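
The per-lead z-scoring can be sketched in NumPy as follows (a minimal illustration; the function name and epsilon guard are ours, not part of the ml4h codebase):

```python
import numpy as np

def zscore_per_lead(ecg):
    """Z-score each lead independently; ecg has shape (time, leads) = (5000, 12)."""
    mean = ecg.mean(axis=0, keepdims=True)   # per-lead mean, shape (1, 12)
    std = ecg.std(axis=0, keepdims=True)     # per-lead std, shape (1, 12)
    return (ecg - mean) / (std + 1e-8)       # epsilon guards against a flat lead

ecg = np.random.randn(5000, 12) * 3.0 + 1.0  # stand-in for a raw 10-second ECG
normed = zscore_per_lead(ecg)
batch = normed[np.newaxis, ...]              # shape (1, 5000, 12), ready for the model
```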

**Outputs:**
* **Survival curve prediction for incident ischemic stroke**
* **Survival curve prediction for death**
* **Sex classification**
* **Age regression**
* **Classification of atrial fibrillation at the time of ECG**

The raw model files are stored with `git lfs`, so you must have it installed; fetch the full ~225 MB file with:
```bash
git lfs pull --include model_zoo/ECG2Stroke/ecg2stroke_dropout_2024_10_04_10_49_43.h5
```

To load the model in a Jupyter notebook (running with the ml4h Docker image), run:

```python
import numpy as np
from tensorflow.keras.models import load_model
from ml4h.tensormap.ukb.demographics import sex_dummy1, age_in_days, af_dummy2
from ml4h.tensormap.ukb.survival import mgb_stroke_wrt_instance2, mgb_death_wrt_instance2

output_tensormaps = {tm.output_name(): tm for tm in [mgb_stroke_wrt_instance2, mgb_death_wrt_instance2,
                                                     sex_dummy1, age_in_days, af_dummy2]}
model = load_model('./ecg2stroke_dropout_2024_10_04_10_49_43.h5')
ecg = np.random.random((1, 5000, 12))  # stand-in for a z-scored 10-second, 12-lead ECG
prediction = model(ecg)
```
If the above does not work, you may need to use an absolute path in `load_model`.

The model has five output heads: the survival curve prediction for incident ischemic stroke, the survival curve prediction for death, sex classification, age regression, and classification of atrial fibrillation at the time of ECG. These outputs can be accessed with:
```python
for name, pred in zip(model.output_names, prediction):
    otm = output_tensormaps[name]
    if otm.is_survival_curve():
        intervals = otm.shape[-1] // 2  # first half of the head: per-interval conditional survival
        days_per_bin = 1 + (2 * otm.days_window) // intervals
        predicted_survivals = np.cumprod(pred[:, :intervals], axis=1)
        print(f'{otm} predicted cumulative risk is: {1 - predicted_survivals[0, -1]}')
    else:
        print(f'{otm} prediction is {pred}')
```
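
To make the survival decoding concrete, here is a toy, self-contained example (the bin count and probabilities are made up, not the model's real output): each survival head emits per-interval conditional survival probabilities, their cumulative product is the survival curve, and cumulative risk at the horizon is one minus the curve's last entry.

```python
import numpy as np

intervals = 50                                     # hypothetical number of time bins
conditional = np.full((1, intervals), 0.99)        # toy per-bin conditional survival
survival_curve = np.cumprod(conditional, axis=1)   # S(t) = running product of conditionals
risk_at_horizon = 1 - survival_curve[0, -1]        # cumulative risk at the final bin
```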

To perform command line inference with this model, run:
```bash
python /path/to/ml4h/ml4h/recipes.py \
    --mode infer \
    --tensors /path/to/tensors/ \
    --input_tensors ecg.ecg_rest_mgb \
    --output_tensors survival.mgb_stroke_wrt_instance2 survival.mgb_death_wrt_instance2 \
        demographics.sex_dummy demographics.age_in_days demographics.af_dummy \
    --tensormap_prefix ml4h.tensormap.ukb \
    --id ecg2stroke_dropout_task_inference \
    --output_folder /path/to/ml4h_runs/ \
    --model_file /path/to/ml4h/model_zoo/ECG2Stroke/ecg2stroke_dropout_2024_10_04_10_49_43.h5
```

### Study flow diagram
<div style="padding: 10px; background-color: white; display: inline-block;">
  <img src="./ecg2stroke_study_design.png" alt="Study flow diagram" />
</div>

### Performance
A) Smoothed calibration curves depicting predicted versus observed event probabilities for 10-year incident stroke for ECG2Stroke in MGH, BWH, and BIDMC test sets. The diagonal dashed line indicates perfect calibration. Curves are obtained using restricted cubic splines. B-D) Stroke-free survival stratified by quintiles of ECG2Stroke-predicted risk for 10-year incident stroke in MGH, BWH, and BIDMC test sets. Transparent bands indicate 95% confidence intervals for survival probability.
<div style="padding: 10px; background-color: white; display: inline-block;">
  <img src="./ecg2stroke_performance.png" alt="Calibration and risk stratification for 10-year incident stroke prediction" />
</div>

### Salience
Saliency maps of ECG2Stroke demarcating regions of the ECG waveform having the greatest influence on stroke risk predictions. Blue shades depict the magnitude of the gradient of predicted stroke risk with respect to the ECG waveform amplitude, where darker shades illustrate regions of the waveform exerting greater salience or influence on stroke risk predictions. Saliency was averaged over 7080 random samples from the Brigham and Women’s Hospital (BWH) test set. The black waveform depicts the median waveform in each lead among the 7080 individuals.
<div style="padding: 10px; background-color: white; display: inline-block;">
  <img src="./ecg2stroke_salience.png" alt="ECG Salience" />
</div>

### Architecture
A schematic of the neural network architecture. The model takes 10 seconds of 12-lead ECG waveform data as input, which is processed through a series of convolutional layers. The resulting learned features are passed to fully connected layers to produce an estimate of time to stroke (primary) as well as predictions of time to death, age, sex, and presence of AF in the ECG diagnostic statement (secondary). Arrows indicate the flow of data between layers. Conv1D, one-dimensional convolution; MaxPooling1D, one-dimensional maximum pooling.
<div style="padding: 10px; background-color: white; display: inline-block;">
  <img src="./ecg2stroke_architecture.png" alt="Neural network architecture" />
</div>