Authors: Haoyu Dong, Yuwen Chen, Hanxue Gu, Nicholas Konz, Yaqian Chen, Qihang Li, Maciej A. Mazurowski
This is the official code for our paper "MRI-CORE: A Foundation Model for Magnetic Resonance Imaging", in which we propose a new foundation model designed specifically for MRI.
Figure 1 provides an overview of MRI-CORE, including training data, training algorithm, and performance on the few-shot segmentation task:
MRI-CORE achieves strong performance across few-shot segmentation, linear probing, and zero-shot segmentation on multiple datasets:
If you find our work useful, please cite:
@article{dong2024mricore,
title={MRI-CORE: A Foundation Model for Magnetic Resonance Imaging},
author={Dong, Haoyu and Chen, Yuwen and Gu, Hanxue and Konz, Nicholas and Chen, Yaqian and Li, Qihang and Mazurowski, Maciej A},
journal={arXiv preprint arXiv:2404.09957},
year={2024}
}

We recommend creating a fresh conda environment and installing the required dependencies:
# Create environment
conda create --name myenv python=3.12
conda activate myenv
# Install PyTorch with CUDA 12.4
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
# Install other dependencies
pip install -r requirements.txt

Note: If you have a different CUDA version, adjust the PyTorch index URL accordingly (see the official PyTorch website for the correct wheel).
- MRI-CORE checkpoint: Download here
- (Optional) SAM weights (for few-shot segmentation): Download here
After downloading, place the checkpoints inside the pretrained_weights/ directory.
Example code to load the model and extract image embeddings:
import torch
from models.sam import sam_model_registry
import cfg

args = cfg.parse_args()

# Disable adapters when using the model only for feature extraction (no fine-tuning)
args.if_encoder_adapter = False
args.if_mask_decoder_adapter = False

model = sam_model_registry['vit_b'](
    args,
    checkpoint="PATH_TO_CHECKPOINT",
    num_classes=args.num_cls,
    image_size=args.image_size,
    pretrained_sam=False
)
model.eval()

# imgs should be a FloatTensor normalized to [0, 1]
# shape: [B, C, H, W]
imgs = torch.rand(1, 3, args.image_size, args.image_size)  # placeholder; replace with your preprocessed slices

# No gradients are needed for feature extraction
with torch.no_grad():
    img_emb = model.image_encoder(imgs)

Since MRI-CORE is based on SAM (a 2D model), each MRI volume must first be sliced into 2D images, and each slice must be normalized to the range [0, 1].
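The slicing and per-slice normalization described above can be sketched with NumPy. This is a minimal illustration under stated assumptions: the volume is already loaded as a 3D array, slices are taken along the first axis by default, and each slice is min-max normalized and repeated over 3 channels to match the (1, 3, 1024, 1024) input used in the inference example. The function name `volume_to_normalized_slices` is hypothetical, and this is not necessarily what the repository's `slice_norm` option does. Resizing to `--image_size` is not included.

```python
import numpy as np


def volume_to_normalized_slices(volume: np.ndarray, axis: int = 0, eps: float = 1e-8) -> np.ndarray:
    """Split a 3D MRI volume into 2D slices and min-max normalize each slice to [0, 1].

    Hypothetical helper. Returns a float32 array of shape [num_slices, 3, H, W],
    with the grayscale slice repeated over 3 channels (assumed input format).
    """
    if volume.ndim != 3:
        raise ValueError(f"expected a 3D volume, got shape {volume.shape}")
    # Move the slicing axis to the front: [num_slices, H, W]
    slices = np.moveaxis(volume, axis, 0).astype(np.float32)
    # Per-slice min and max, kept as [num_slices, 1, 1] for broadcasting
    mins = slices.min(axis=(1, 2), keepdims=True)
    maxs = slices.max(axis=(1, 2), keepdims=True)
    # Constant slices (max == min) map to 0 instead of dividing by zero
    normalized = (slices - mins) / np.maximum(maxs - mins, eps)
    # Repeat the single grayscale channel: [num_slices, 3, H, W]
    return np.repeat(normalized[:, None, :, :], 3, axis=1)
```

The result can be converted with `torch.from_numpy(...)` and passed to `model.image_encoder` in batches after resizing each slice to the model's input size.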
Expected directory structure:
datasets/
├── images/ # 2D slices
├── masks/ # segmentation masks
├── train.txt
├── val.txt
└── test.txt
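The split files can be generated from the list of slice filenames. The sketch below is hypothetical: it assumes slices are named `<volume_id>_<index>.<ext>` and that each list file holds one filename per line, which may differ from the format `main.py` expects (compare against the provided example dataset). It splits by volume so that slices from the same scan never appear in more than one split.

```python
import random
from collections import defaultdict


def split_by_volume(slice_names, val_frac=0.1, test_frac=0.2, seed=0):
    """Split 2D slice filenames into train/val/test lists, keeping every slice
    of a volume in the same split to avoid leakage between splits.

    Hypothetical naming convention: "<volume_id>_<index>.<ext>".
    """
    # Group slice filenames by the volume they came from
    groups = defaultdict(list)
    for name in slice_names:
        volume_id = name.rsplit("_", 1)[0]
        groups[volume_id].append(name)

    # Shuffle volume IDs reproducibly, then cut them into three parts
    volume_ids = sorted(groups)
    random.Random(seed).shuffle(volume_ids)
    n_test = round(len(volume_ids) * test_frac)
    n_val = round(len(volume_ids) * val_frac)
    test_ids = volume_ids[:n_test]
    val_ids = volume_ids[n_test:n_test + n_val]
    train_ids = volume_ids[n_test + n_val:]

    def collect(ids):
        return sorted(name for vid in ids for name in groups[vid])

    return collect(train_ids), collect(val_ids), collect(test_ids)


def write_list(path, names):
    """Write one filename per line (assumed list-file format)."""
    with open(path, "w") as f:
        f.writelines(name + "\n" for name in names)
```

For example, `write_list("datasets/train.txt", train)` writes the training split returned by `split_by_volume`.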
We provide an example of a pre-processed dataset here. For more details on the dataset, see the original repository.
Once the dataset is prepared, you can run few-shot segmentation with the provided training script (example):
python main.py \
    --img_folder datasets/images \
    --mask_folder datasets/masks \
    --train_img_list datasets/train.txt \
    --val_img_list datasets/val.txt \
    --test_img_list datasets/test.txt \
    --n_type slice_norm \
    --image_size 1024 \
    --b 4 \
    --num_cls 1 \
    --checkpoint pretrained_weights/MRI_CORE_vitb.pth

Adjust the arguments (batch size, image size, normalization type, etc.) to your environment and dataset.
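For a few-shot run, one option is to train on a small subset of the training list. The helper below is a hypothetical sketch, assuming `main.py` trains on every entry listed in `--train_img_list`; `sample_few_shot` and its seed-based sampling are illustrative choices, not the script's own few-shot mechanism. It samples at the slice level and does not inspect image or mask contents.

```python
import random


def sample_few_shot(lines, k, seed=0):
    """Return k distinct, non-empty entries drawn from a training list.

    Hypothetical helper: emulates a k-shot setting by subsampling the lines
    of a list file such as train.txt; blank lines are ignored.
    """
    entries = [line.strip() for line in lines if line.strip()]
    if k > len(entries):
        raise ValueError(f"requested {k} samples but only {len(entries)} available")
    # A dedicated Random instance keeps the subset reproducible for a given seed
    return random.Random(seed).sample(entries, k)
```

For example, read the lines of `datasets/train.txt`, keep the entries returned by `sample_few_shot(lines, k=5)`, write them one per line to a new file, and pass that file via `--train_img_list`.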
Example code for a full forward pass (image encoder, prompt encoder, and mask decoder) without prompts:
import torch
from models.sam import sam_model_registry
import cfg

# 1) Load args and model
args = cfg.parse_args()
model = sam_model_registry['vit_b'](
    args,
    checkpoint="pretrained_weights/MRI_CORE_vitb.pth",
    num_classes=1,
    image_size=1024,
    pretrained_sam=True
).eval().cuda()

# 2) Prepare input (B, C, H, W), normalized to [0, 1]
imgs = torch.rand(1, 3, 1024, 1024, device="cuda")  # placeholder; replace with your preprocessed image tensor

# 3) Forward pass (no gradients are needed for inference)
with torch.no_grad():
    img_emb = model.image_encoder(imgs)
    sparse_emb, dense_emb = model.prompt_encoder(points=None, boxes=None, masks=None)
    pred, _ = model.mask_decoder(
        image_embeddings=img_emb,
        image_pe=model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_emb,
        dense_prompt_embeddings=dense_emb,
        multimask_output=True
    )

The code and models are released under the Apache 2.0 License.

