Skip to content

Commit 8353d5f

Browse files
authored
Dfp ecg2stroke zoo (#612)
* COMP: initialize git lfs for ecg2stroke model * ENH: Add ecg2stroke model weights * COMP: initialize git lfs for ecg2stroke figures * ENH: Add ecg2stroke figures * ENH: Add ECG2Stroke README
1 parent 0a45631 commit 8353d5f

File tree

8 files changed

+120
-1
lines changed

8 files changed

+120
-1
lines changed

.gitattributes

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,8 @@ model_zoo/ECG_PheWAS/encoder_median.keras filter=lfs diff=lfs merge=lfs -text
4949
model_zoo/ECG_PheWAS/mgh_biosppy_median_60bpm_autoencoder_256d_v2022_05_21.keras filter=lfs diff=lfs merge=lfs -text
5050
model_zoo/ECG2HF/ecg_5000_hf_quintuplet_dropout_v2023_04_17.keras filter=lfs diff=lfs merge=lfs -text
5151
model_zoo/ECG2Age/*.keras filter=lfs diff=lfs merge=lfs -text
52+
model_zoo/ECG2Stroke/ecg2stroke_dropout_2024_10_04_10_49_43.h5 filter=lfs diff=lfs merge=lfs -text
53+
model_zoo/ECG2Stroke/ecg2stroke_salience.png filter=lfs diff=lfs merge=lfs -text
54+
model_zoo/ECG2Stroke/ecg2stroke_study_design.png filter=lfs diff=lfs merge=lfs -text
55+
model_zoo/ECG2Stroke/ecg2stroke_architecture.png filter=lfs diff=lfs merge=lfs -text
56+
model_zoo/ECG2Stroke/ecg2stroke_performance.png filter=lfs diff=lfs merge=lfs -text

ml4h/tensormap/ukb/survival.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,17 @@ def _cox_tensor_from_file(tm: TensorMap, hd5: h5py.File, dependents=None):
421421
start_date_is_attribute=True, incidence_only=True,
422422
),
423423
)
424-
424+
mgb_stroke_wrt_instance2 = TensorMap(
425+
'stroke_event',
426+
Interpretation.SURVIVAL_CURVE,
427+
shape=(50,),
428+
days_window=DAYS_IN_10_YEARS,
429+
tensor_from_file=_survival_tensor(
430+
'ukb_ecg_rest/ecg_rest_text/instance_0', DAYS_IN_10_YEARS,
431+
disease_name_override='stroke',
432+
start_date_is_attribute=True, incidence_only=True,
433+
),
434+
)
425435

426436
prevalent_hf_wrt_instance2 = TensorMap(
427437
'heart_failure', Interpretation.CATEGORICAL, storage_type=StorageType.CATEGORICAL_FLAG,

model_zoo/ECG2Stroke/README.md

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# ECG2Stroke — Deep Learning to Predict Incident Ischemic Stroke
2+
3+
This directory contains models and code for predicting incident ischemic stroke from 12-lead resting ECGs, as described in <mark>INSERT LINK HERE</mark>.
4+
5+
**Input:**
6+
* **Modality**: 12‑lead resting ECG
7+
* **Expected shape**: `(batch, time, leads)` = `(B, 5000, 12)`
8+
9+
* Sampling rate: **500 Hz** (10 seconds ⇒ 5000 samples)
10+
* **Normalization**: Z-score normalized per‑lead (i.e., mean of 0 and standard deviation of 1).
11+
12+
**Outputs:**
13+
* **Survival curve prediction for incident ischemic stroke**
14+
* **Survival curve prediction for death**
15+
* **Sex classification**
16+
* **Age regression**
17+
* **Classification of atrial fibrillation at the time of ECG**
18+
19+
The raw model files are stored using `git lfs` so you must have it installed and localize the full ~225MB file with:
20+
```bash
21+
git lfs pull --include model_zoo/ECG2Stroke/ecg2stroke_dropout_2024_10_04_10_49_43.h5
22+
```
23+
24+
To load the model in a jupyter notebook (running with the ml4h docker), run:
25+
26+
```python
27+
import numpy as np
28+
from tensorflow.keras.models import load_model
29+
from ml4h.tensormap.ukb.demographics import sex_dummy1, age_in_days, af_dummy2
30+
from ml4h.tensormap.ukb.survival import mgb_stroke_wrt_instance2, mgb_death_wrt_instance2
31+
32+
output_tensormaps = {tm.output_name(): tm for tm in [mgb_stroke_wrt_instance2, mgb_death_wrt_instance2,
33+
sex_dummy1, age_in_days, af_dummy2]}
34+
model = load_model('./ecg2stroke_dropout_2024_10_04_10_49_43.h5')
35+
ecg = np.random.random((1, 5000, 12))
36+
prediction = model(ecg)
37+
```
38+
If above does not work you may need to use an absolute path in `load_model`.
39+
40+
The model has 5 output heads: the survival curve prediction for incident ischemic stroke, the survival curve prediction for death, sex classification, age regression, and classification of atrial fibrillation at the time of ECG. Those outputs can be accessed with:
41+
```python
42+
for name, pred in zip(model.output_names, prediction):
43+
otm = output_tensormaps[name]
44+
if otm.is_survival_curve():
45+
intervals = otm.shape[-1] // 2
46+
days_per_bin = 1 + (2*otm.days_window) // intervals
47+
predicted_survivals = np.cumprod(pred[:, :intervals], axis=1)
48+
print(f'Stroke Risk {otm} prediction is: {str(1 - predicted_survivals[0, -1])}')
49+
else:
50+
print(f'{otm} prediction is {pred}')
51+
```
52+
53+
54+
To perform command line inference with this model run:
55+
```bash
56+
python /path/to/ml4h/ml4h/recipes.py \
57+
--mode infer \
58+
--tensors /path/to/tensors/ \
59+
--input_tensors ecg.ecg_rest_mgb \
60+
--output_tensors survival.mgb_stroke_wrt_instance2 survival.mgb_death_wrt_instance2 \
61+
demographics.sex_dummy demographics.age_in_days demographics.af_dummy \
62+
--tensormap_prefix ml4h.tensormap.ukb \
63+
--id ecg2stroke_dropout_task_inference \
64+
--output_folder /path/to/ml4h_runs/ \
65+
--model_file /path/to/ml4h/model_zoo/ECG2Stroke/ecg2stroke_dropout_2024_10_04_10_49_43.h5'
66+
```
67+
68+
### Study flow diagram
69+
<div style="padding: 10px; background-color: white; display: inline-block;">
70+
<img src="./ecg2stroke_study_design.png" alt="Study flow diagram" />
71+
</div>
72+
73+
### Performance
74+
A) Smoothed calibration curves depicting predicted versus observed event probabilities for 10-year incident stroke for ECG2Stroke in MGH, BWH, and BIDMC test sets. Diagonal dashed line indicates perfect calibration. Curves are obtained using restricted cubic spliness22. B-D) Stroke-free survival stratified by quintiles of ECG2Stroke predicted risk for 10-year incident stroke in MGH, BWH, and BIDMC test sets. Transparent bands indicate 95% confidence intervals for survival probability.
75+
<div style="padding: 10px; background-color: white; display: inline-block;">
76+
<img src="./ecg2stroke_performance.png" alt="Calibration and risk stratification for 10-year incident stroke prediction" />
77+
</div>
78+
79+
### Salience
80+
<div style="padding: 10px; background-color: white; display: inline-block;">
81+
Saliency maps of ECG2Stroke demarcating regions of the ECG waveform having the greatest influence on stroke risk predictions. Blue shades depict the magnitude of the gradient of predicted stroke risk with respect to the ECG waveform amplitude, where darker shades illustrate regions of the waveform exerting greater salience or influence on stroke risk predictions. Saliency was averaged over 7080 random samples from the Brigham and Women’s Hospital (BWH) test set . The black waveform depicts the median waveform in each lead among the 7080 individuals. <div style="padding: 10px; background-color: white; display: inline-block;">
82+
<img src="./ecg2stroke_salience.png" alt="ECG Salience" />
83+
</div>
84+
85+
### Architecture
86+
A schematic of the neural network architecture. The model takes 10 seconds of 12-lead ECG waveform data as input, which is processed through a series of convolutional layers. The resulting learned features are passed to fully-connected layers to produce an estimate of time to stroke (primary) as well as predictions of time to death, age, sex, and presence of AF in the ECG diagnostic statement (secondary). Arrows indicate the flow of data between layers. Conv1D, one-dimensional convolution, MaxPooling1D, one-dimensional maximum pooling.
87+
<div style="padding: 10px; background-color: white; display: inline-block;">
88+
<img src="./ecg2stroke_architecture.png" alt="Neural network architecture" />
89+
</div>
Lines changed: 3 additions & 0 deletions
Loading
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:d3dc62afacbabc56a1150198e4ec1c69b97fa8cae3412fbbf5a820ab0216888f
3+
size 225430664
Lines changed: 3 additions & 0 deletions
Loading
Lines changed: 3 additions & 0 deletions
Loading
Lines changed: 3 additions & 0 deletions
Loading

0 commit comments

Comments
 (0)