Author: Vladislav Kim
Contributors: Lisa Schneider, Adrian Wolny, Christian Bender
HeartMAE is a cardiac MRI foundation model that uses masked autoencoding with optical flow guidance to learn rich spatiotemporal representations from cine MRI sequences. This repository contains code for:
- Data preprocessing
- Model training
- Feature extraction
- Feature aggregation
- License and Code Attribution
Important: Model training and inference are supported only on Linux systems with NVIDIA GPUs.
First, ensure that the CUDA version specified in environment.yml (currently 11.8) matches the CUDA version installed on your system:
nvidia-smi | grep CUDA
Create and activate the conda environment:
conda env create -f environment.yml
conda activate heartmae
MRI data is typically stored in DICOM format. To begin preprocessing, create a CSV file that lists the full paths to each DICOM series in the following format:
| seriesid | path |
|---|---|
| series_1 | /path/to/series_1.dcm |
| ... | ... |
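Such a CSV can be assembled, for example, with pandas. The snippet below is a hypothetical helper; the glob pattern and the `seriesid` naming convention are assumptions that you should adapt to your own directory layout:

```python
# Hypothetical helper for building the series CSV; the glob pattern and the
# seriesid naming convention are assumptions, adapt them to your layout.
from pathlib import Path

import pandas as pd

dicom_root = Path("/path/to/dicom_root")
records = [
    {"seriesid": p.stem, "path": str(p.resolve())}
    for p in sorted(dicom_root.rglob("*.dcm"))
]
pd.DataFrame(records, columns=["seriesid", "path"]).to_csv("series_list.csv", index=False)
```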
We recommend converting DICOM images to .npy format for faster I/O operations.
Run the following script:
python dataprep/DICOM_to_numpy.py --csv path/to/csv_file \
-o data/numpy --print_freq 1000
- `--csv` specifies the path to your CSV file.
- `--print_freq` controls how often progress updates are printed.
Note: Loading `.npy` files is substantially faster than loading DICOMs.
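For orientation, the conversion essentially reads each DICOM series and saves its pixel array. The sketch below illustrates this with pydicom; it is not the script's exact logic, which also handles output naming and progress reporting via `--print_freq`:

```python
# Sketch of the DICOM -> .npy conversion step; the actual script also handles
# progress reporting (--print_freq) and its own output naming scheme.
from pathlib import Path

import numpy as np
import pandas as pd
import pydicom

out_dir = Path("data/numpy")
out_dir.mkdir(parents=True, exist_ok=True)

series = pd.read_csv("path/to/csv_file")          # CSV with seriesid, path columns
for _, row in series.iterrows():
    ds = pydicom.dcmread(row["path"])             # read the DICOM series
    np.save(out_dir / f"{row['seriesid']}.npy", ds.pixel_array)  # (T, H, W) for cine
```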
HeartMAE requires optical flow data for training. If you want to use only standard masked autoencoders without flow guidance, you can skip this step.
To precompute optical flows using the Farneback method from OpenCV:
python dataprep/optical_flow.py --img_dir data/numpy \
--csv path/to/csv_file \
-o data/opticalflow
- `--img_dir` specifies the directory with preprocessed `.npy` files.
- `--csv` specifies the same CSV file listing your series.
- `-o` specifies the output directory for optical flow files, saved in `.npy` format.
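For reference, the sketch below shows how Farneback flow between consecutive frames is computed with OpenCV's `calcOpticalFlowFarneback`. The parameter values and output naming are illustrative and may differ from the script's defaults:

```python
# Sketch of Farneback optical flow over a cine sequence stored as (T, H, W);
# the parameter values and output path are illustrative, not the script's.
import cv2
import numpy as np

frames = np.load("data/numpy/series_1.npy").astype(np.uint8)  # may need intensity rescaling
flows = []
for prev, nxt in zip(frames[:-1], frames[1:]):
    # args: prev, next, flow, pyr_scale, levels, winsize, iterations, poly_n, poly_sigma, flags
    flow = cv2.calcOpticalFlowFarneback(prev, nxt, None, 0.5, 3, 15, 3, 5, 1.2, 0)
    flows.append(flow)                            # (H, W, 2): per-pixel (dx, dy)
np.save("data/opticalflow/series_1.npy", np.stack(flows))
```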
Before launching training, make sure your data is properly prepared:
- A CSV file listing the cardiac MRI `.npy` paths.
- Preprocessed `.npy` CMR files.
- Precomputed optical flow `.npy` files (if using optical flow guidance).
First, export the environment variable pointing to your data directory:
export DATADIR=/path/to/your/dataset
This directory should contain:
- `dataload/CMR_trainset.csv`: CSV file listing your training samples.
- `cardiac_MRI/numpy`: directory with preprocessed `.npy` MRI files.
- `opticalflow`: directory with optical flow `.npy` files.
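Before starting a long training run, a quick pre-flight check can save time. The hypothetical snippet below only verifies that every file listed in the training CSV exists; it assumes the CSV has a `path` column, which may differ in your setup:

```python
# Hypothetical pre-flight check: verify that every file referenced in the
# training CSV exists on disk (assumes a "path" column; adjust if different).
import os

import pandas as pd

datadir = os.environ["DATADIR"]
df = pd.read_csv(os.path.join(datadir, "dataload/CMR_trainset.csv"))

missing = [p for p in df["path"] if not os.path.exists(p)]
print(f"{len(df)} samples listed, {len(missing)} missing files")
```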
To start distributed training (using 4 GPUs):
export CUDA_VISIBLE_DEVICES=0,1,2,3 && torchrun --nproc_per_node=4 \
--nnodes 1 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=localhost:47525 \
train_heartmae.py --world_size=4 --mask_strategy random \
--path_to_csv $DATADIR/dataload/CMR_trainset.csv \
--path_to_cmrs $DATADIR/cardiac_MRI/numpy \
--path_to_optical_flow $DATADIR/optical_flow \
--normalize_intensities \
--t_patch_size 4 --num_frames 24 --pred_t_dim 12 --num_workers 4 \
--accum_iter 4 --model mae_vit_base_patch16 --repeat_aug 4 --batch_size 10 \
--norm_pix_loss --cls_embed --sep_pos_embed --mask_ratio 0.9 \
--epochs 100 --warmup_epochs 30 \
--blr 1.5e-4 --weight_decay 0.05 --output_dir experiments/MAEst/opticalflow
Key arguments:
- `--mask_strategy`: Masking strategy for spatiotemporal patches during training (e.g., `random`, `tube`).
- `--normalize_intensities`: Normalize input pixel intensities.
  - Note: to apply dataset-specific normalization, set the `--chan_meansd` argument with the mean and standard deviation of your dataset.
- `--t_patch_size`, `--num_frames`, `--pred_t_dim`: Temporal modeling parameters:
  - `t_patch_size`: Temporal patch size (how many frames are grouped together per token).
  - `num_frames`: Number of frames sampled from each MRI sequence.
  - `pred_t_dim`: Temporal prediction length (used during reconstruction).
- `--accum_iter`: Gradient accumulation steps (to increase the effective batch size).
- `--repeat_aug`: Number of temporal clips sampled per cardiac MRI video during training.
- `--norm_pix_loss`: Use normalized pixel loss.
- `--mask_ratio`: Proportion of patches to mask during training (see the sketch below).
- `--blr`: Base learning rate.
- `--output_dir`: Where checkpoints and logs will be saved.
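For intuition, with `--mask_ratio 0.9` only 10% of the spatiotemporal tokens of each clip are visible to the encoder. The sketch below shows one common way to draw such a random mask; it is illustrative only and not the repository's implementation:

```python
# Illustrative random masking over spatiotemporal tokens (not the exact
# implementation in train_heartmae.py).
import torch

batch, num_tokens, mask_ratio = 2, 1176, 0.9     # e.g. 6 temporal x 14x14 spatial patches
len_keep = int(num_tokens * (1 - mask_ratio))

noise = torch.rand(batch, num_tokens)            # one random score per token
ids_shuffle = torch.argsort(noise, dim=1)        # ascending: lowest scores are kept
ids_keep = ids_shuffle[:, :len_keep]             # indices of visible tokens

mask = torch.ones(batch, num_tokens)             # 1 = masked, 0 = visible
mask.scatter_(1, ids_keep, 0)
```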
Once you have trained a HeartMAE model, you can extract feature embeddings for your MRI dataset.
python inference.py --arch vit_base_patch16 \
--ckpt path/to/checkpoint.pth \
--batch_size 32 --num_workers 6 \
-o $DATADIR/embeddings \
--path_to_csv $DATADIR/dataload/CMR_trainset.csv \
--path_to_cmrs $DATADIR/cardiac_MRI/numpy \
--num_frames 24 --t_patch_size 4 --crop_size 224 --gpus 0 1
Key parameters:
- `--ckpt`: Path to the `.pth` checkpoint file of a trained HeartMAE model.
- `--arch`: Vision transformer (ViT) architecture used during training (e.g., `vit_base_patch16`).
- `-o`: Output directory where extracted features will be saved.
- `--path_to_csv`: CSV listing the MRI `.npy` files for inference.
- `--num_frames`: Number of frames in each clip, matching the setting used during model training.
- `--t_patch_size`: Temporal patch size used during pretraining.
- `--crop_size`: Spatial crop size applied to MRI frames before feeding them to the model.
- `--gpus`: GPU device IDs to use for inference.
Extracted feature vectors are saved in a single .h5 (HDF5) file inside the specified output directory. The HDF5 file is organized by subject ID, with each subject containing multiple cardiac views:
- SAX (short-axis)
- SAX optical sections (if applicable)
- LAX (long-axis)
This output format allows easy access to all cardiac views per subject for downstream tasks.
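Assuming this layout, the file can be inspected with h5py as in the sketch below; the file name and the exact group/dataset names are assumptions, so check them against your own output:

```python
# Inspect the per-subject structure of the embeddings file; the file name and
# the exact group/dataset names are assumptions, check them against your output.
import h5py

with h5py.File("embeddings/features.h5", "r") as f:
    subject_ids = list(f.keys())
    print(f"{len(subject_ids)} subjects")
    for view_name, item in f[subject_ids[0]].items():
        print(view_name, item)   # datasets print their shape; groups may nest clips
```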
After extracting features, you will have an HDF5 file containing embeddings for each subject ID, with multiple cardiac views (e.g., SAX and LAX) and spatiotemporal clips. To obtain a single feature vector per subject, you need to aggregate these embeddings across views.
The simplest approach is to compute the element-wise mean or median across all available views and clips for each subject. This results in a fixed-size embedding per subject, independent of the number of views.
To perform mean aggregation:
python aggregate_features.py \
--embed_dir $DATADIR/embeddings \
--subject_csv /path/to/subject_list.csv \
--gpu 0 --embed_dim 768 --method mean \
-o results
To use median aggregation, change:
--method median
By default, this performs aggregation across all views and clips combined and saves the output to `outdir/embeddings_per_subject.csv`.
You can optionally aggregate separately for each cardiac view by providing, e.g.:
--cardiac_views SAX 2Ch 3Ch 4Ch
In this case, aggregation is performed within each view (across clips), and one output file is generated per view (e.g., `SAX_embeddings.csv`, `2Ch_embeddings.csv`). If you want to disable view-wise aggregation entirely, simply omit the `--cardiac_views` argument.
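Conceptually, mean aggregation amounts to stacking all view/clip embeddings of a subject and averaging along the instance axis. The sketch below illustrates this under the assumed HDF5 layout from the inspection snippet above; it is not the `aggregate_features.py` implementation:

```python
# Rough sketch of per-subject mean aggregation across all views and clips;
# assumes each subject group holds 2-D datasets of shape (n_clips, embed_dim).
import h5py
import numpy as np
import pandas as pd

rows = {}
with h5py.File("embeddings/features.h5", "r") as f:          # file name is an assumption
    for subject_id, views in f.items():
        stacked = np.concatenate([np.asarray(v).reshape(-1, 768) for v in views.values()])
        rows[subject_id] = stacked.mean(axis=0)               # or np.median(stacked, axis=0)

pd.DataFrame.from_dict(rows, orient="index").to_csv("embeddings_per_subject.csv")
```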
Alternatively, you can learn how to aggregate features using Multiple Instance Learning (MIL). In this setup, each subject is treated as a bag of instances (i.e., views and spatiotemporal clips), and the model learns to assign an attention weight to each instance based on its relevance to a supervised prediction task. We use attention-based MIL (ABMIL) with multitask regression, where the goal is to predict clinical metrics such as LVEF, LVM, LVEDV, etc.
The learned attention weights determine how strongly each instance contributes to the pooled subject embedding:

$$a_k = \frac{\exp\left(\mathbf{w}^\top \tanh(\mathbf{V}\mathbf{h}_k)\right)}{\sum_{j=1}^{K} \exp\left(\mathbf{w}^\top \tanh(\mathbf{V}\mathbf{h}_j)\right)}, \qquad \mathbf{z} = \sum_{k=1}^{K} a_k \mathbf{h}_k,$$

where $\mathbf{h}_k$ is the embedding of the $k$-th instance (view/clip), $K$ is the number of instances in the bag, $\mathbf{w}$ and $\mathbf{V}$ are learnable parameters, and $\mathbf{z}$ is the attention-weighted subject embedding.
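For reference, this attention-pooling step can be written as a small PyTorch module in the spirit of standard ABMIL. The sketch below is generic and not necessarily the exact module used in `abmil_regression.py`; the hidden size and absence of gating are assumptions:

```python
# Generic attention-based MIL pooling (sketch); the repository's ABMIL head may
# differ in hidden size, gating, and the multitask regression layers on top.
import torch
import torch.nn as nn

class AttentionPooling(nn.Module):
    def __init__(self, embed_dim=768, hidden_dim=128):
        super().__init__()
        self.V = nn.Linear(embed_dim, hidden_dim)   # projects each instance embedding
        self.w = nn.Linear(hidden_dim, 1)           # scores each instance

    def forward(self, h):                           # h: (num_instances, embed_dim)
        scores = self.w(torch.tanh(self.V(h)))      # (num_instances, 1)
        a = torch.softmax(scores, dim=0)            # attention weights, sum to 1
        z = (a * h).sum(dim=0)                      # attention-weighted bag embedding
        return z, a.squeeze(-1)
```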
To train the ABMIL model via multitask regression:
python abmil_regression.py \
--embed_dir data/embed_dir \
--regr_data data/CMR_regression.csv \
--train_data data/CMR_trainset.csv \
--val_data data/CMR_valset.csv \
--label_col_list LVEF LVM LVEDV \
--gpu 1 --lr 5e-4 --num_epochs 150 --embed_dim 768 \
--outdir experiments/ABMIL/regression \
--dropout 0.5 --scale_labels true
Key parameters:
- `--embed_dir`: Directory containing the extracted view-level features (HDF5 file).
- `--regr_data`: CSV file with subject-level regression labels.
- `--train_data` / `--val_data`: CSVs listing the training and validation subjects.
- `--label_col_list`: List of label columns to use for multitask regression.
- `--embed_dim`: Dimensionality of each input embedding (must match the feature extractor).
- `--scale_labels`: Normalize labels to zero mean and unit variance (recommended).
- `--dropout`: Dropout rate in ABMIL.
After training, the ABMIL model can be used to aggregate the embeddings of each subject into a single attention-weighted pooled representation:
python aggregate_features.py \
--embed_dir /path/to/embeddings \
--subject_csv /path/to/subject_list.csv \
--gpu 0 --embed_dim 768 --method abmil \
-o results \
--ckpt_path /path/to/ABMIL_model/checkpoint.pth
This outputs a CSV file, `results/embeddings_per_subject.csv`, in which each row contains the attention-weighted embedding of one subject computed with the ABMIL model.
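The resulting per-subject embeddings can then be loaded for downstream analysis; the snippet below assumes the first CSV column holds the subject ID:

```python
# Load the attention-weighted subject embeddings for downstream analysis;
# assumes the first CSV column holds the subject ID.
import pandas as pd

emb = pd.read_csv("results/embeddings_per_subject.csv", index_col=0)
print(emb.shape)   # (num_subjects, embed_dim)
```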
This code is distributed under the CC-BY-NC 4.0 (Creative Commons Attribution-NonCommercial 4.0 International) license.
This codebase builds upon MAE-ST (Spatiotemporal Masked Autoencoder). The core architecture and training framework are derived from mae_st by Meta AI Research.
- Paper: Masked Autoencoders for Spatiotemporal Representation Learning (NeurIPS 2022)
- Original license: CC-BY-NC 4.0
