HeartMAE

Author: Vladislav Kim

Contributors: Lisa Schneider, Adrian Wolny, Christian Bender

HeartMAE is a cardiac MRI foundation model that uses masked autoencoding with optical flow guidance to learn rich spatiotemporal representations from cine MRI sequences. This repository contains code for data preprocessing, model pretraining, feature extraction, and per-subject feature aggregation.

Conda environment

Important: Model training and inference are supported only on Linux systems with NVIDIA GPUs.

First, ensure that the CUDA version specified in environment.yml (currently 11.8) matches the CUDA version installed on your system:

nvidia-smi | grep CUDA

Create and activate the conda environment:

conda env create -f environment.yml
conda activate heartmae

Data preparation

CSV file with DICOM file paths

MRI data is typically stored in DICOM format. To begin preprocessing, create a CSV file that lists the full paths to each DICOM series in the following format:

seriesid path
series_1 /path/to/series_1.dcm
... ...
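If your DICOM series live under a single root directory, a small helper along these lines can generate the CSV. This is only a sketch: the root path, series-ID convention, and delimiter are assumptions, so adapt them to what dataprep/DICOM_to_numpy.py expects.

# build_series_csv.py -- hypothetical helper, not part of this repository.
# Walks a DICOM root directory, takes the file stem as the series ID, and
# writes the two-column table (seriesid, path) shown above.
import csv
from pathlib import Path

root = Path("/path/to/dicom_root")            # adjust to your storage layout
rows = [(p.stem, str(p.resolve())) for p in sorted(root.rglob("*.dcm"))]

with open("series_paths.csv", "w", newline="") as f:
    writer = csv.writer(f)                    # check which delimiter the dataprep scripts expect
    writer.writerow(["seriesid", "path"])
    writer.writerows(rows)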

Preprocess DICOM files

We recommend converting DICOM images to .npy format for faster I/O operations.

Run the following script:

python dataprep/DICOM_to_numpy.py --csv path/to/csv_file \
-o data/numpy --print_freq 1000
  • --csv specifies the path to your CSV file.
  • -o specifies the output directory where the converted .npy files are written.
  • --print_freq controls how often progress updates are printed.

Note: Loading .npy files is substantially faster than loading DICOMs:

[Figure: benchmark comparing .npy and DICOM loading times]
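If you want to verify this on your own data, a rough timing sketch (assuming pydicom, or another DICOM reader, is available in your environment) could look like this:

# Rough timing comparison of DICOM vs .npy loading (illustrative only).
import time
import numpy as np
import pydicom

def time_it(fn, n=100):
    start = time.perf_counter()
    for _ in range(n):
        fn()
    return (time.perf_counter() - start) / n

dcm_path = "/path/to/series_1.dcm"            # adjust to your data
npy_path = "data/numpy/series_1.npy"

t_dcm = time_it(lambda: pydicom.dcmread(dcm_path).pixel_array)
t_npy = time_it(lambda: np.load(npy_path))
print(f"DICOM: {t_dcm*1e3:.2f} ms/load, NumPy: {t_npy*1e3:.2f} ms/load")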

Precompute optical flows

HeartMAE requires optical flow data for training. If you want to use only standard masked autoencoders without flow guidance, you can skip this step.

To precompute optical flows using the Farneback method from OpenCV:

python dataprep/optical_flow.py --img_dir data/numpy \
--csv path/to/csv_file \
-o data/opticalflow
  • --img_dir specifies the directory with preprocessed .npy files.
  • --csv specifies the same CSV file listing your series.
  • -o specifies the output directory for optical flow files saved in .npy format.
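Under the hood, the Farneback computation between consecutive frames looks roughly like the sketch below; the assumed frame layout and the parameter values are illustrative, not necessarily what dataprep/optical_flow.py uses.

# Illustrative Farneback optical flow between consecutive cine frames.
import numpy as np
import cv2

frames = np.load("data/numpy/series_1.npy")   # assumed shape: (T, H, W)
flows = []
for t in range(len(frames) - 1):
    # Rescale each frame to 8-bit before flow estimation.
    prev = cv2.normalize(frames[t], None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    nxt = cv2.normalize(frames[t + 1], None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    flow = cv2.calcOpticalFlowFarneback(prev, nxt, None,
                                        pyr_scale=0.5, levels=3, winsize=15,
                                        iterations=3, poly_n=5, poly_sigma=1.2,
                                        flags=0)
    flows.append(flow)                         # (H, W, 2) displacement field
flows = np.stack(flows)                        # (T-1, H, W, 2)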

Training HeartMAE

Before launching training, make sure your data is properly prepared:

  • A CSV file listing the cardiac MRI .npy paths.
  • Preprocessed .npy CMR files.
  • Precomputed optical flow .npy files (if using optical flow guidance).

Data directory setup

First, export the environment variable pointing to your data directory:

export DATADIR=/path/to/your/dataset

This directory should contain:

  • dataload/CMR_trainset.csv — CSV file listing your training samples.
  • cardiac_MRI/numpy — directory with preprocessed .npy MRI files.
  • optical_flow — directory with optical flow .npy files (matching --path_to_optical_flow in the training command below).

Launch training

To start distributed training (using 4 GPUs):

export CUDA_VISIBLE_DEVICES=0,1,2,3 && torchrun --nproc_per_node=4 \
--nnodes 1 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=localhost:47525 \
train_heartmae.py --world_size=4 --mask_strategy random \
--path_to_csv $DATADIR/dataload/CMR_trainset.csv \
--path_to_cmrs $DATADIR/cardiac_MRI/numpy \
--path_to_optical_flow $DATADIR/optical_flow \
--normalize_intensities \
--t_patch_size 4 --num_frames 24 --pred_t_dim 12 --num_workers 4 \
--accum_iter 4  --model mae_vit_base_patch16 --repeat_aug 4 --batch_size 10 \
--norm_pix_loss --cls_embed --sep_pos_embed --mask_ratio 0.9 \
--epochs 100  --warmup_epochs 30 \
--blr 1.5e-4  --weight_decay 0.05 --output_dir experiments/MAEst/opticalflow

Key arguments:

  • --mask_strategy: Masking strategy for spatiotemporal patches during training (e.g., random, tube).
  • --normalize_intensities: Normalize input pixel intensities.
    • Note: to apply dataset-specific normalization, set the --chan_meansd argument with the mean and standard deviation of your dataset.
  • --t_patch_size, --num_frames, --pred_t_dim: Temporal modeling parameters:
    • t_patch_size: Temporal patch size (how many frames grouped together per token).
    • num_frames: Number of frames sampled from each MRI sequence.
    • pred_t_dim: Temporal prediction length (used during reconstruction).
  • --accum_iter: Gradient accumulation steps (to increase the effective batch size; see the quick calculation after this list).
  • --repeat_aug: Number of temporal clips to sample per cardiac MRI video during training.
  • --norm_pix_loss: Use normalized pixel loss.
  • --mask_ratio: Proportion of patches to mask during training.
  • --blr: Base learning rate.
  • --output_dir: Where checkpoints and logs will be saved.
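As a quick sanity check of the settings above, here is a rough effective-batch-size calculation. It assumes that repeated augmentations, accumulation steps, and GPUs all multiply the number of clips seen per optimizer step, which may differ from the exact bookkeeping in train_heartmae.py.

# Back-of-the-envelope effective batch size for the command above.
num_gpus = 4
batch_size = 10      # per GPU, per iteration
repeat_aug = 4       # clips sampled per video
accum_iter = 4       # gradient accumulation steps

effective_batch = num_gpus * batch_size * repeat_aug * accum_iter
print(effective_batch)   # 640 clips per optimizer step
# MAE-style code bases typically scale the learning rate as
# blr * effective_batch_size / 256; check train_heartmae.py for the exact rule.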

Extract features with a pretrained model

Once you have trained a HeartMAE model, you can extract feature embeddings for your MRI dataset.

python inference.py --arch vit_base_patch16 \
--ckpt path/to/checkpoint.pth \
--batch_size 32 --num_workers 6 \
-o $DATADIR/embeddings \
--path_to_csv $DATADIR/dataload/CMR_trainset.csv \
--path_to_cmrs $DATADIR/cardiac_MRI/numpy \
--num_frames 24 --t_patch_size 4 --crop_size 224 --gpus 0 1

Key parameters:

  • --ckpt: Path to the .pth checkpoint file of a trained HeartMAE model.
  • --arch: Vision transformer (ViT) architecture used during training (e.g., vit_base_patch16).
  • -o: Output directory where extracted features will be saved.
  • --path_to_csv: CSV listing the MRI .npy files for inference.
  • --num_frames: Number of frames in each clip, matching the setting used during model training.
  • --t_patch_size: Temporal patch size used during pretraining.
  • --crop_size: Spatial crop size applied to MRI frames before feeding them to the model.
  • --gpus: GPU device IDs to use for inference.

Extracted feature vectors are saved in a single .h5 (HDF5) file inside the specified output directory. The HDF5 file is organized by subject ID, with each subject containing multiple cardiac views:

  • SAX (short-axis)
    • SAX optical sections (if applicable)
  • LAX (long-axis)

This output format allows easy access to all cardiac views per subject for downstream tasks.
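For downstream analysis, the embeddings can be inspected with h5py. The sketch below assumes a subject-ID → view group layout and an illustrative file name; check the HDF5 file that inference.py actually writes.

# Inspect the extracted embeddings (file name and group nesting are assumptions).
import h5py

with h5py.File("embeddings/features.h5", "r") as f:
    for subject_id in list(f.keys())[:3]:              # first few subjects
        # Print every dataset path under this subject and its shape.
        f[subject_id].visititems(
            lambda name, obj: print(subject_id, name, obj.shape)
            if isinstance(obj, h5py.Dataset) else None
        )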

Feature aggregation per subject

After extracting features, you will have an HDF5 file containing embeddings for each subject ID, with multiple cardiac views (e.g., SAX and LAX) and spatiotemporal clips. To obtain a single feature vector per subject, you need to aggregate these embeddings across views.

Mean/median aggregation

The simplest approach is to compute the element-wise mean or median across all available views and clips for each subject. This results in a fixed-size embedding per subject, independent of the number of views.

To perform mean aggregation:

python aggregate_features.py \
--embed_dir $DATADIR/embeddings \
--subject_csv /path/to/subject_list.csv \
--gpu 0 --embed_dim 768 --method mean \
-o results

To use median aggregation, change:

--method median

By default, this performs aggregation across all views and clips combined, and saves the output to outdir/embeddings_per_subject.csv.

You can optionally aggregate separately for each cardiac view by providing, e.g.:

--cardiac_views SAX 2Ch 3Ch 4Ch

In this case, aggregation is performed within each view across clips, and one output file is generated per view (e.g., SAX_embeddings.csv, 2Ch_embeddings.csv). To disable view-wise aggregation completely, simply omit the --cardiac_views argument.
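For reference, the default combined mean aggregation corresponds roughly to the sketch below; the file name, group layout, and embedding shapes are assumptions, and aggregate_features.py remains the authoritative implementation.

# Rough equivalent of --method mean across all views and clips combined.
import h5py
import numpy as np

embed_dim = 768
aggregated = {}
with h5py.File("embeddings/features.h5", "r") as f:
    for subject_id in f:
        vectors = []
        # Gather every embedding stored under this subject, whatever the nesting.
        f[subject_id].visititems(
            lambda name, obj: vectors.append(np.asarray(obj).reshape(-1, embed_dim))
            if isinstance(obj, h5py.Dataset) else None
        )
        aggregated[subject_id] = np.concatenate(vectors, axis=0).mean(axis=0)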

Learnable aggregation: Attention-based MIL

Alternatively, you can learn how to aggregate features using Multiple Instance Learning (MIL). In this setup, each subject is treated as a bag of instances (i.e., views and spatiotemporal clips), and the model learns to assign an attention weight to each view based on its relevance to a supervised prediction task. We use attention-based MIL (ABMIL) with multitask regression, where the goal is to predict clinical metrics such as LVEF, LVM, and LVEDV.

The learned attention weights $a_v$ are used to compute a weighted sum over view embeddings:

$$\mathbf{z} = \sum_{v=1}^V a_v \mathbf{h}_v$$

where $\mathbf{h}_v$ is the feature vector for view $v$, and $\mathbf{z}$ is the pooled subject-level embedding. This yields an aggregator that is optimized directly for the downstream tasks.
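A minimal PyTorch sketch of this attention pooling is shown below; the hidden size, any gating, and other details of the model in abmil_regression.py may differ.

# Minimal attention-based MIL pooling over a bag of view/clip embeddings.
import torch
import torch.nn as nn

class AttentionPool(nn.Module):
    def __init__(self, embed_dim=768, hidden_dim=128):
        super().__init__()
        # One attention logit per instance in the bag.
        self.attention = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, h):                                # h: (V, embed_dim)
        a = torch.softmax(self.attention(h), dim=0)      # attention weights a_v, shape (V, 1)
        z = (a * h).sum(dim=0)                           # pooled embedding z, shape (embed_dim,)
        return z, a.squeeze(-1)

pool = AttentionPool()
views = torch.randn(6, 768)                              # e.g., 6 view/clip embeddings for one subject
z, weights = pool(views)
print(z.shape, weights)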

Train an ABMIL aggregator

To train the ABMIL model via multitask regression:

python abmil_regression.py \
--embed_dir data/embed_dir \
--regr_data data/CMR_regression.csv \
--train_data data/CMR_trainset.csv \
--val_data data/CMR_valset.csv \
--label_col_list LVEF LVM LVEDV \
--gpu 1 --lr 5e-4 --num_epochs 150 --embed_dim 768 \
--outdir experiments/ABMIL/regression \
--dropout 0.5 --scale_labels true

Key parameters:

  • --embed_dir: Directory containing the extracted view-level features (HDF5 file).
  • --regr_data: CSV file with subject-level regression labels.
  • --train_data / --val_data: CSVs listing the training and validation subjects.
  • --label_col_list: List of label columns to use for multitask regression.
  • --embed_dim: Dimensionality of each input embedding (must match feature extractor).
  • --scale_labels: Normalize labels to zero mean and unit variance (recommended).
  • --dropout: Dropout rate in ABMIL.

After training, the ABMIL model can be used to aggregate features for each subject by computing the attention-weighted pooled embedding.

ABMIL feature aggregation

Once trained, you can use the ABMIL model to aggregate embeddings for all subjects using attention weights:

python aggregate_features.py \
--embed_dir /path/to/embeddings \
--subject_csv /path/to/subject_list.csv \
--gpu 0 --embed_dim 768 --method abmil \
-o results \
--ckpt_path /path/to/ABMIL_model/checkpoint.pth

This writes a CSV file, results/embeddings_per_subject.csv, in which each row contains the attention-weighted embedding for one subject, computed with the ABMIL model.

License and Code Attribution

This code is distributed under the CC-BY-NC 4.0 (Creative Commons Attribution-NonCommercial 4.0 International) license.

This code base builds upon MAE-ST (Spatiotemporal Masked Autoencoder). The core architecture and training framework are derived from mae_st by Meta AI Research.
