MRI-CORE: A Foundation Model for MRI

Authors: Haoyu Dong, Yuwen Chen, Hanxue Gu, Nicholas Konz, Yaqian Chen, Qihang Li, Maciej A. Mazurowski

This is the official code for our paper MRI-CORE: A Foundation Model for Magnetic Resonance Imaging, in which we propose a new foundation model designed specifically for MRI.

Figure 1 provides an overview of MRI-CORE, including training data, training algorithm, and performance on the few-shot segmentation task:

Figure 1: Overview of fine-tuning strategies based on dataset availability.

MRI-CORE achieves strong performance across few-shot segmentation, linear probing, and zero-shot segmentation on multiple datasets:

Figure 2: Performance of MRI-CORE on multiple tasks and datasets.


📖 Citation

If you find our work useful, please cite:

@article{dong2024mricore,
  title={MRI-CORE: A Foundation Model for Magnetic Resonance Imaging},
  author={Dong, Haoyu and Chen, Yuwen and Gu, Hanxue and Konz, Nicholas and Chen, Yaqian and Li, Qihang and Mazurowski, Maciej A},
  journal={arXiv preprint arXiv:2404.09957},
  year={2024}
}

⚙️ Installation (Step 0)

We recommend creating a fresh conda environment and installing the required dependencies:

# Create environment
conda create --name myenv python=3.12
conda activate myenv

# Install PyTorch with CUDA 12.4
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124

# Install other dependencies
pip install -r requirements.txt

Note: If you have a different CUDA version, please adjust the PyTorch index URL accordingly (see the official PyTorch website for the correct wheel).
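As an optional sanity check, the short snippet below verifies that PyTorch was installed with CUDA support before you continue (the version string shown is just what the command above installs and may differ on your system):

import torch

print(torch.__version__)            # e.g. 2.6.0+cu124 with the install command above
print(torch.cuda.is_available())    # should be True if the CUDA wheel matches your driver
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))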

📥 Download Pre-Trained Models (Step 1)

Download the pre-trained MRI-CORE checkpoint and place it inside the pretrained_weights/ directory (the commands below assume pretrained_weights/MRI_CORE_vitb.pth).
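Optionally, you can check that the checkpoint file loads correctly; the key layout assumed below (a raw state_dict, possibly wrapped under a "model" key) is a guess, so treat this only as a rough inspection:

import torch

# Add weights_only=False if loading fails on newer PyTorch versions
ckpt = torch.load("pretrained_weights/MRI_CORE_vitb.pth", map_location="cpu")
state_dict = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
print(f"{len(state_dict)} tensors; first keys: {list(state_dict)[:3]}")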

🔍 Extract Features with MRI-CORE (Step 2)

Example code to load the model and extract image embeddings:

from models.sam import sam_model_registry
import cfg

args = cfg.parse_args()
# Disable adapters when using the model only for feature extraction (no fine-tuning)
args.if_encoder_adapter = False
args.if_mask_decoder_adapter = False

model = sam_model_registry['vit_b'](
    args,
    checkpoint="PATH_TO_CHECKPOINT",
    num_classes=args.num_cls,
    image_size=args.image_size,
    pretrained_sam=False
)

# imgs should be a FloatTensor normalized to [0, 1]
# shape: [B, C, H, W]
img_emb = model.image_encoder(imgs)
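For reference, here is one way to build the imgs tensor from a single 2D slice; the file name example_slice.png is hypothetical, and the per-slice min-max normalization, bilinear resize to 1024, and 3-channel replication are assumptions based on the [0, 1] requirement and the image size used later in this README:

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

# Load a single slice as grayscale and normalize it to [0, 1]
slice_2d = np.array(Image.open("example_slice.png").convert("L"), dtype=np.float32)
slice_2d = (slice_2d - slice_2d.min()) / (slice_2d.max() - slice_2d.min() + 1e-8)

# Resize to the model input size and replicate to 3 channels: [1, 3, 1024, 1024]
imgs = torch.from_numpy(slice_2d)[None, None]
imgs = F.interpolate(imgs, size=(1024, 1024), mode="bilinear", align_corners=False)
imgs = imgs.repeat(1, 3, 1, 1)

img_emb = model.image_encoder(imgs)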

🩻 Run Segmentation (Step 3 & Step 4)

Step 3 — Dataset Preprocessing

Since MRI-CORE is based on SAM (a 2D model), all MRI volumes must be sliced into 2D images. Normalize each slice to the range [0, 1].

Expected directory structure:

datasets/
  ├── images/        # 2D slices
  ├── masks/         # segmentation masks
  ├── train.txt
  ├── val.txt
  └── test.txt

We provide an example of a pre-processed dataset here. For more details on the dataset, see the original repository.
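If your data is still volumetric (e.g. NIfTI), a minimal slicing script could look like the sketch below; it assumes nibabel is available, slices along the last axis, applies per-slice min-max normalization, and writes every slice name into train.txt purely as an illustration. Masks would be exported the same way into datasets/masks/. Adapt all of this to your data and splits:

import os
import numpy as np
import nibabel as nib
from PIL import Image

vol = nib.load("volume.nii.gz").get_fdata()          # hypothetical input volume
os.makedirs("datasets/images", exist_ok=True)

names = []
for i in range(vol.shape[-1]):
    s = vol[..., i].astype(np.float32)
    s = (s - s.min()) / (s.max() - s.min() + 1e-8)   # normalize each slice to [0, 1]
    name = f"volume_slice{i:03d}.png"
    Image.fromarray((s * 255).astype(np.uint8)).save(os.path.join("datasets/images", name))
    names.append(name)

with open("datasets/train.txt", "w") as f:
    f.write("\n".join(names) + "\n")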

Step 4 — Few-Shot Segmentation

Once the dataset is prepared, you can run few-shot segmentation with the provided training script (example):

python main.py \
  --img_folder datasets/images \
  --mask_folder datasets/masks \
  --train_img_list datasets/train.txt \
  --val_img_list datasets/val.txt \
  --test_img_list datasets/test.txt \
  --n_type slice_norm \
  --image_size 1024 \
  --b 4 \
  --num_cls 1 \
  --checkpoint pretrained_weights/MRI_CORE_vitb.pth

Adjust arguments (batch size, image size, normalization type, etc.) to your environment and dataset.

💡 Minimal Inference Example

import torch
from models.sam import sam_model_registry
import cfg

# 1) Load args and model
args = cfg.parse_args()
model = sam_model_registry['vit_b'](
    args,
    checkpoint="pretrained_weights/MRI_CORE_vitb.pth",
    num_classes=1,
    image_size=1024,
    pretrained_sam=True
).eval().cuda()

# 2) Prepare input (B, C, H, W), normalized to [0, 1]
imgs = torch.rand(1, 3, 1024, 1024, device="cuda")  # replace with your preprocessed image tensor

# 3) Forward pass (no gradients needed at inference time)
with torch.no_grad():
    img_emb = model.image_encoder(imgs)
    sparse_emb, dense_emb = model.prompt_encoder(points=None, boxes=None, masks=None)
    pred, _ = model.mask_decoder(
        image_embeddings=img_emb,
        image_pe=model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_emb,
        dense_prompt_embeddings=dense_emb,
        multimask_output=True
    )
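The decoder returns low-resolution logits; a typical post-processing step is to upsample them to the input resolution and threshold, as sketched below (applying a sigmoid and a 0.5 threshold is an assumption about the training setup and may need to be adapted to your task):

import torch.nn.functional as F

pred_up = F.interpolate(pred, size=(1024, 1024), mode="bilinear", align_corners=False)
masks = (torch.sigmoid(pred_up) > 0.5).squeeze(0).cpu().numpy()   # one boolean mask per output channel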

📜 License

The code and models are released under the Apache 2.0 License.
