biomedia-mira/seg-cft

Seg-CFT: Segmentor-Guided Counterfactual Fine-Tuning

Official code for the MICCAI 2025 paper:

Segmentor-Guided Counterfactual Fine-Tuning for Locally Coherent and Targeted Image Synthesis

Tian Xia, Matthew Sinclair, Andreas Schuh, Fabio De Sousa Ribeiro, Raghav Mehta, Rajat Rasal, Esther Puyol-Antón, Samuel Gerber, Kersten Petersen, Michiel Schaap, Ben Glocker

[Paper]


Overview

Seg-CFT is a counterfactual fine-tuning strategy for Deep Structural Causal Models (DSCMs). Where the baseline, Reg-CFT, relies on pre-trained regressors, Seg-CFT uses frozen segmentors to provide spatially coherent supervision during fine-tuning. This enables locally targeted interventions on structure-specific variables (e.g., organ area) while avoiding undesirable global image changes.

Key contributions:

  • Segmentor-Guided Counterfactual Fine-Tuning (Seg-CFT) for structure-specific causal interventions
  • Comparison against Reg-CFT on PadChest chest radiographs (left lung, right lung, heart area)
  • Early results on coronary artery disease progression modeling from CCTA images
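To make the idea of segmentor-derived supervision concrete, here is a minimal pure-Python sketch. It assumes that a structure's area is read off a segmentation probability map by summing per-pixel foreground probabilities, and that the penalty is the relative absolute error to the intervened target area. The names `soft_area` and `area_loss` and the form of the penalty are illustrative assumptions; they are not taken from this repository, whose actual loss lives in causal_models/dscm/.

```python
from typing import Sequence


def soft_area(prob_mask: Sequence[Sequence[float]]) -> float:
    """Sum per-pixel foreground probabilities to get an area in pixels."""
    return sum(sum(row) for row in prob_mask)


def area_loss(prob_mask: Sequence[Sequence[float]], target_area: float) -> float:
    """Relative absolute error between the measured area and the target area.

    Hypothetical penalty, chosen only for illustration.
    """
    if target_area <= 0:
        raise ValueError("target_area must be positive")
    return abs(soft_area(prob_mask) - target_area) / target_area


# Toy 2x3 probability map standing in for a segmentor output
# on a counterfactual image.
mask = [
    [0.0, 1.0, 1.0],
    [0.0, 0.5, 0.5],
]
print(soft_area(mask))                    # 3.0
print(area_loss(mask, target_area=4.0))   # 0.25
```

In a real training loop the same quantities would be computed on tensors so that gradients flow back through the frozen segmentor into the image model.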

Repository Structure

.
├── causal_models/                  # HVAE and DSCM training
│   ├── main.py                     # Stage 1: HVAE training entry point
│   ├── hvae2.py                    # Hierarchical VAE architecture
│   ├── trainer.py                  # HVAE training loop
│   ├── train_setup.py              # Dataloaders, optimizer, logging setup
│   ├── hps.py                      # Hyperparameter sets (mimic256_64, padchest224_224, rsna224_224)
│   ├── utils.py                    # EMA, seeding, misc utilities
│   ├── plotting_utils.py           # Visualization helpers
│   ├── run_padchest_sa_causal.sh   # SLURM: Stage 1 (sex-age-seg causal graph)
│   ├── pgm/                        # Stages 2–3: PGM and auxiliary predictor training
│   │   ├── chest_pgm_segmentor.py  # PGM with U-Net segmentor (Seg-CFT path)
│   │   ├── chest_pgm_regressor.py  # PGM with ResNet regressor (Reg-CFT path)
│   │   ├── train_pgm_segmentor.py  # Stage 2+3 training script (Seg-CFT path)
│   │   ├── train_pgm_regressor.py  # Stage 2+3 training script (Reg-CFT path)
│   │   ├── unet.py                 # U-Net segmentor
│   │   ├── resnet.py               # ResNet regressor
│   │   ├── layers.py               # Normalizing flow layers
│   │   ├── utils_pgm.py            # PGM utilities
│   │   ├── train_padchest_pgm_slurm_segmentor.sh   # SLURM: Stage 2 PGM (Seg-CFT)
│   │   ├── train_padchest_pgm_slurm_regressor.sh   # SLURM: Stage 2 PGM (Reg-CFT)
│   │   ├── train_padchest_aux_slurm_segmentor.sh   # SLURM: Stage 3 U-Net auxiliary
│   │   └── train_padchest_aux_slurm_regressor.sh   # SLURM: Stage 3 ResNet auxiliary
│   └── dscm/                       # Stage 4: Counterfactual fine-tuning
│       ├── dscm_segmentor.py       # Seg-CFT DSCM
│       ├── dscm_regressor.py       # Reg-CFT DSCM
│       ├── train_cf_segmentor.py   # Seg-CFT fine-tuning script
│       ├── train_cf_regressor.py   # Reg-CFT fine-tuning script
│       ├── train_cf_segmentor.sh   # SLURM: Stage 4 (Seg-CFT)
│       └── train_cf_regressor.sh   # SLURM: Stage 4 (Reg-CFT)
├── data_handling/                  # Dataset loaders
│   ├── padchest.py                 # PadChest dataset
│   ├── base.py                     # Base data module
│   ├── augmentations.py            # Training augmentations
│   ├── caching.py                  # Shared memory cache
│   ├── sampler.py                  # Custom batch samplers
│   └── PadChest_meta/              # Dataset splits used in paper
│       ├── train_dataset.csv
│       ├── val_dataset.csv
│       └── test_dataset.csv
├── segmentation/                   # Segmentation mask generation
│   └── padchest_segmentation.py
└── configs/                        # Hydra config files
    ├── config.yaml
    └── data/
        ├── padchest224_224_with_seg.yaml
        └── padchest224_224.yaml

Installation

git clone https://github.com/biomedia-mira/seg-cft.git
cd seg-cft
pip install -r requirements.txt

Requirements: Python 3.9+, CUDA 11.x
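A quick way to confirm the interpreter meets the stated Python requirement before installing is sketched below. The helper name `python_is_supported` is hypothetical. The CUDA version is not checked here, since that would depend on the installed deep learning framework.

```python
import sys
from typing import Tuple

MIN_PYTHON = (3, 9)  # from the stated requirement "Python 3.9+"


def python_is_supported(version: Tuple[int, int] = sys.version_info[:2]) -> bool:
    """Return True if the given (major, minor) version is at least MIN_PYTHON."""
    return tuple(version) >= MIN_PYTHON


if not python_is_supported():
    print(f"Python {sys.version_info.major}.{sys.version_info.minor} is older than 3.9")
```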


Data Preparation

PadChest (Study 1: Chest Radiographs)

  1. Download the PadChest dataset.
  2. Generate segmentation masks using the TorchXRayVision package:
    python segmentation/padchest_segmentation.py --data_dir /path/to/padchest
  3. The dataset splits used in the paper are provided in data_handling/PadChest_meta/.
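Before training, it can help to confirm that the split files listed above are present in the checkout. The following is a small sketch using only the standard library; the helper name `missing_split_files` is hypothetical, and the file names are the ones shown in the repository structure.

```python
from pathlib import Path
from typing import List

# Locations taken from the repository structure in this README.
SPLIT_DIR = Path("data_handling") / "PadChest_meta"
SPLIT_FILES = ("train_dataset.csv", "val_dataset.csv", "test_dataset.csv")


def missing_split_files(repo_root: Path) -> List[str]:
    """Return the names of split CSVs that are not found under repo_root."""
    return [
        name
        for name in SPLIT_FILES
        if not (repo_root / SPLIT_DIR / name).is_file()
    ]


# Run from the repository root; an empty list means all splits were found.
print(missing_split_files(Path(".")))
```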

Note: The CCTA coronary artery dataset (Study 2) is an internal Heartflow dataset and cannot be released.


Training Pipeline

Training proceeds in four stages. All scripts use SLURM; remove the sbatch wrapper to run locally. The cd commands in each stage are relative to the repository root.

Stage 1 — Train HVAE

Train the base hierarchical variational autoencoder on PadChest:

cd causal_models
bash run_padchest_sa_causal.sh
# or directly:
python main.py \
    --hps padchest224_224_with_seg \
    --lr 1e-3 --batch_size 32 --wd 5e-2 --epochs 1000 \
    --exp_name padchest_224_224_beta_9_sa_causal \
    --parents sex age Left-Lung_volume Right-Lung_volume Heart_volume \
    --beta 9 --bottleneck 4 --z_max_res 32

Checkpoint saved to: causal_models/checkpoints/<parents>/<exp_name>/checkpoint.pt
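For local runs driven from Python rather than the shell, the Stage 1 command above can be assembled as an argument list. This is a sketch: the helper `stage1_command` is hypothetical, and it only reuses the flags and values shown in the command above.

```python
import shlex
from typing import List, Sequence


def stage1_command(
    exp_name: str,
    parents: Sequence[str],
    hps: str = "padchest224_224_with_seg",
) -> List[str]:
    """Build the argv for Stage 1 with the flag values shown in this README."""
    return [
        "python", "main.py",
        "--hps", hps,
        "--lr", "1e-3", "--batch_size", "32", "--wd", "5e-2", "--epochs", "1000",
        "--exp_name", exp_name,
        "--parents", *parents,
        "--beta", "9", "--bottleneck", "4", "--z_max_res", "32",
    ]


cmd = stage1_command(
    exp_name="padchest_224_224_beta_9_sa_causal",
    parents=["sex", "age", "Left-Lung_volume", "Right-Lung_volume", "Heart_volume"],
)
print(shlex.join(cmd))
# To launch from inside causal_models/: subprocess.run(cmd, check=True)
```

Passing a list to `subprocess.run` avoids shell quoting issues with the hyphenated parent names.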

Stage 2 — Train PGM (Probabilistic Graphical Model)

Train the normalizing-flow-based causal graph over attributes:

Seg-CFT path:

cd causal_models/pgm
bash train_padchest_pgm_slurm_segmentor.sh

Reg-CFT path:

bash train_padchest_pgm_slurm_regressor.sh

Stage 3 — Train Auxiliary Predictor

Seg-CFT — train U-Net segmentor:

bash train_padchest_aux_slurm_segmentor.sh

Reg-CFT — train ResNet regressor:

bash train_padchest_aux_slurm_regressor.sh

Checkpoints saved to: causal_models/pgm/checkpoints/<parents>/

Stage 4 — Counterfactual Fine-Tuning (CFT)

Fine-tune the DSCM using the pre-trained HVAE, PGM, and predictor:

Seg-CFT (proposed):

cd causal_models/dscm
bash train_cf_segmentor.sh

Reg-CFT (baseline):

bash train_cf_regressor.sh

Update --vae_path, --pgm_path, and --predictor_path in the shell scripts to point to the checkpoints from Stages 1–3.
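The four stages above can be summarized as an ordered list of (stage, working directory, script) entries for either path. The sketch below uses a hypothetical helper, `pipeline_steps`; the directories and script names are the ones given in this README, and nothing is executed.

```python
from typing import List, Tuple

VARIANTS = ("segmentor", "regressor")  # Seg-CFT path, Reg-CFT path


def pipeline_steps(variant: str) -> List[Tuple[str, str, str]]:
    """Return (stage, directory relative to repo root, script) in run order."""
    if variant not in VARIANTS:
        raise ValueError(f"variant must be one of {VARIANTS}, got {variant!r}")
    return [
        ("Stage 1", "causal_models", "run_padchest_sa_causal.sh"),
        ("Stage 2", "causal_models/pgm", f"train_padchest_pgm_slurm_{variant}.sh"),
        ("Stage 3", "causal_models/pgm", f"train_padchest_aux_slurm_{variant}.sh"),
        ("Stage 4", "causal_models/dscm", f"train_cf_{variant}.sh"),
    ]


for stage, directory, script in pipeline_steps("segmentor"):
    print(f"{stage}: (cd {directory} && bash {script})")
```

Stage 1 is shared by both paths; only Stages 2 to 4 differ between Seg-CFT and Reg-CFT.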


Evaluation

Effectiveness is measured as the mean absolute percentage error (MAPE, %) between the segmentor-predicted area of the counterfactual image and the target intervention value. The predicted area comes from a held-out evaluation segmentor, different from the one used for fine-tuning.
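As a reference for the metric, here is a minimal sketch of MAPE over paired area values, using the standard definition (mean of |predicted − target| / |target|, times 100). The function name `mape` and the toy numbers are illustrative; this is not the repository's evaluation code.

```python
from typing import Sequence


def mape(predicted: Sequence[float], target: Sequence[float]) -> float:
    """Mean absolute percentage error (in %) between predictions and targets."""
    if len(predicted) != len(target) or len(target) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    if any(t == 0 for t in target):
        raise ValueError("targets must be non-zero")
    total = sum(abs(p - t) / abs(t) for p, t in zip(predicted, target))
    return 100.0 * total / len(target)


# Toy example: areas measured on two counterfactuals vs. their intervention targets.
print(mape([125.0, 75.0], [100.0, 100.0]))  # 25.0
```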


Citation

@inproceedings{xia2025segcft,
  title     = {Segmentor-Guided Counterfactual Fine-Tuning for Locally Coherent and Targeted Image Synthesis},
  author    = {Xia, Tian and Sinclair, Matthew and Schuh, Andreas and {De Sousa Ribeiro}, Fabio and Mehta, Raghav and Rasal, Rajat and Puyol-Ant{\'o}n, Esther and Gerber, Samuel and Petersen, Kersten and Schaap, Michiel and Glocker, Ben},
  booktitle = {Medical Image Computing and Computer-Assisted Intervention (MICCAI)},
  year      = {2025},
  publisher = {Springer},
  doi       = {10.1007/978-3-032-04937-7_50}
}

Acknowledgements

This research was funded by Heartflow. B.G. received support from the Royal Academy of Engineering as part of his Kheiron/RAEng Research Chair. B.G. and F.R. acknowledge the support of the UKRI AI programme and the EPSRC for CHAI – EPSRC Causality in Healthcare AI Hub (grant no. EP/Y028856/1).

This codebase builds on Deep Structural Causal Models by Pawlowski et al. and the DSCM implementation from De Sousa Ribeiro et al.


License

MIT License

About

This is the repository for our Seg-CFT work, accepted at MICCAI 2025.
