By Yaqian Chen, Hanxue Gu, Haoyu Dong, Qihang Li, Yuwen Chen, Nicholas Konz, Lin Li and Maciej Mazurowski
The Guided-Trans-DT weights and the ISPY2 external test dataset are available on Google Drive.
- update arxiv
- add feature extraction gpu version
This is the official code for our paper:
GuidedMorph: Two-Stage Deformable Registration for Breast MRI

- Installation
- Data Preparation
- Configuration
- Training
- Inference & Evaluation
- File Structure
- Citation
- License
Installation
- Clone the repository:
  git clone https://github.com/yourusername/GuidedMorph.git
  cd GuidedMorph
- Install dependencies (conda is recommended for environment management):
  - Create the environment:
    conda env create -f environment.yml
    conda activate guidedmorph
  - Or install the requirements manually:
    pip install torch torchvision numpy matplotlib natsort
Data Preparation
- Input format: all data should be preprocessed and saved as .pkl files. Each file contains:
  - x (moving image) and y (fixed image),
  - followed by pairs of segmentation masks, one pair per label: x_seg1, y_seg1, x_seg2, y_seg2, ...
- Example directory structure:
  demo/
  ├── train/
  │   ├── case1.pkl
  │   ├── case2.pkl
  │   └── ...
  └── test/
      ├── case1.pkl
      ├── case2.pkl
      └── ...
- The data loader automatically handles any number of label pairs per case.
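The following is a minimal sketch of writing and reading one case in this layout. It assumes each .pkl file holds a single pickled tuple in the order (x, y, x_seg1, y_seg1, x_seg2, y_seg2, ...). The helper names save_case and load_case are hypothetical, and the exact container type the repository's loader expects should be checked in data/datasets.py.

```python
import pickle

import numpy as np


def save_case(path, x, y, seg_pairs):
    """Hypothetical helper: pickle one case as (x, y, x_seg1, y_seg1, ...).

    Assumption: the loader reads a flat tuple in exactly this order.
    """
    items = [x, y]
    for x_seg, y_seg in seg_pairs:  # one (moving, fixed) mask pair per label
        items.extend([x_seg, y_seg])
    with open(path, "wb") as f:
        pickle.dump(tuple(items), f)


def load_case(path):
    """Hypothetical helper: read a case and regroup the masks into pairs."""
    with open(path, "rb") as f:
        items = pickle.load(f)
    x, y, *segs = items
    if len(segs) % 2 != 0:
        raise ValueError(f"{path}: expected pairs of masks, got {len(segs)} arrays")
    # segs[0::2] are moving masks, segs[1::2] the matching fixed masks
    pairs = list(zip(segs[0::2], segs[1::2]))
    return x, y, pairs
```

Because the number of label pairs is derived from the tuple length, the same reader works for cases with one label or many. Only unpickle files you trust, since pickle can execute arbitrary code on load.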
Configuration
All parameters are managed via JSON config files:
- Training: config.json
- Inference: infer_config.json
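Below is a minimal sketch of reading these files and turning the inference settings into a checkpoint path. It assumes that model_folder_template is filled with the three weights values, and that model_idx indexes checkpoint files under experiments/ in natural sort order (-1 = the last one). The helper names, the experiments/ layout and the "*.pth*" file pattern are all hypothetical; infer.py is the authority on how these keys are actually used.

```python
import json
import re
from pathlib import Path


def load_config(path):
    """Read a JSON config file (config.json or infer_config.json) into a dict."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def natural_key(name):
    """Sort key so that 'epoch10' sorts after 'epoch9' (a stand-in for natsort)."""
    return [int(tok) if tok.isdigit() else tok.lower() for tok in re.split(r"(\d+)", name)]


def resolve_checkpoint(cfg, experiments_dir="experiments", pattern="*.pth*"):
    """Hypothetical helper: fill model_folder_template with the weights, then
    return the checkpoint at position model_idx in natural sort order.
    The folder layout and file pattern are assumptions."""
    folder = Path(experiments_dir) / cfg["model_folder_template"].format(*cfg["weights"])
    checkpoints = sorted(folder.glob(pattern), key=lambda p: natural_key(p.name))
    if not checkpoints:
        raise FileNotFoundError(f"no checkpoints matching {pattern!r} in {folder}")
    return checkpoints[cfg["model_idx"]]
```

With the example infer_config.json, the template would expand to vxm_2_mse_1_diffusion_1_0.06_2/, and resolve_checkpoint(load_config("infer_config.json")) would pick the latest matching file in that folder.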
Example (config.json):
{
"GPU_iden": 0,
"batch_size": 1,
"train_dir": "demo/train/",
"save_frequency": 2,
"lr": 0.0005,
"epoch_start": 0,
"max_epoch": 15000,
"img_size": [128, 256, 256],
"cont_training": false,
"weights": [1.0, 1.0, 0.08],
"architecture": "UNet_Cbam_STN"
}

Example (infer_config.json):
{
"test_dir": "demo/test/",
"img_size": [128, 256, 256],
"weights": [1, 1, 0.06],
"model_idx": -1,
"model_type": "VxmDense_2",
"model_folder_template": "vxm_2_mse_{0}_diffusion_{1}_{2}_2/"
}

Training
To train the model:
  python train_vxm.py
- All training parameters are controlled by config.json.
- Checkpoints and logs are saved in the experiments/ directory.
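The three-element weights entry in config.json most likely scales individual loss terms. The inference folder template mentions "mse" and "diffusion", which points in that direction. The NumPy sketch below shows such a weighted sum under a hypothetical reading: an MSE image term, a Dice term on the segmentation masks and a diffusion (smoothness) regularizer on the displacement field, in that order. Both the term order and the Dice term are assumptions, and losses.py and train_vxm.py define what is actually optimized.

```python
import numpy as np


def mse(warped, fixed):
    """Image similarity: mean squared intensity difference."""
    return float(np.mean((warped - fixed) ** 2))


def dice_loss(warped_seg, fixed_seg, eps=1e-6):
    """1 - soft Dice between a warped moving mask and the fixed mask."""
    inter = np.sum(warped_seg * fixed_seg)
    denom = np.sum(warped_seg) + np.sum(fixed_seg)
    return float(1.0 - (2.0 * inter + eps) / (denom + eps))


def diffusion(flow):
    """Smoothness penalty: mean squared finite difference of the displacement
    field along each spatial axis; flow has shape (3, D, H, W)."""
    return float(np.mean([np.mean(np.diff(flow, axis=a) ** 2) for a in (1, 2, 3)]))


def total_loss(warped, fixed, warped_seg, fixed_seg, flow, weights):
    """Weighted sum of the three terms; the term order is an assumption."""
    terms = (mse(warped, fixed), dice_loss(warped_seg, fixed_seg), diffusion(flow))
    return sum(w * t for w, t in zip(weights, terms))
```

Under this reading, the small last weight (0.08 in the training example) keeps the regularizer from dominating. The field is still pushed toward smooth, plausible deformations.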
Inference & Evaluation
To run inference and evaluate Dice for each label:
  python infer.py
- All inference parameters are controlled by infer_config.json.
- The script computes the Dice score for each label in every test case and prints the mean and standard deviation.
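The evaluation step can be sketched in NumPy as follows. It computes Dice per label and then a per-label mean and standard deviation across test cases. Treating two empty masks as Dice 1.0 and using the population standard deviation (np.std with ddof=0) are assumptions. The function names are hypothetical and may not match what infer.py reports.

```python
import numpy as np


def dice(a, b):
    """Dice overlap between two binary masks; 1.0 when both are empty (assumption)."""
    a, b = a.astype(bool), b.astype(bool)
    denom = a.sum() + b.sum()
    if denom == 0:
        return 1.0
    return 2.0 * np.logical_and(a, b).sum() / denom


def summarize(per_case):
    """per_case: one list of per-label Dice scores for each test case.
    Returns the per-label mean and standard deviation across cases."""
    scores = np.asarray(per_case, dtype=float)  # shape (n_cases, n_labels)
    return scores.mean(axis=0), scores.std(axis=0)
```

For each test case, per_case would receive [dice(warped_x_seg, y_seg) for each label pair]. Here warped_x_seg is the moving mask resampled with the predicted deformation.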
File Structure
GuidedMorph/
├── data/
│ ├── data_utils.py
│ ├── datasets.py
│ └── ...
├── edge.py
├── feature_extract.py
├── infer.py
├── infer_config.json
├── losses.py
├── models.py
├── train_vxm.py
├── config.json
├── utils.py
├── experiments/
│ └── ... (checkpoints, logs)
├── demo/
│ ├── train/
│ └── test/
└── README.md
Citation
If you use this code or our method in your research, please cite:
@article{chen2024guidedmorph,
title={GuidedMorph: Two-Stage Deformable Registration for Breast MRI},
author={Chen, Yaqian and Gu, Hanxue and Dong, Haoyu and Li, Qihang and Chen, Yuwen and Konz, Nicholas and Li, Lin and Mazurowski, Maciej},
journal={arXiv preprint arXiv:2505.13414},
year={2025}
}

For questions or collaborations, please contact Yaqian Chen or open an issue on GitHub.