Images generated by SR-DiT-B/1 (a 140M-parameter diffusion model) after 400k training steps.
This repository contains the reference implementation for SR-DiT (Speedrun Diffusion Transformer), a framework that combines representation alignment (REG-style), token routing (SPRINT), architectural improvements, and training modifications on top of a SiT-B/1 backbone with the INVAE tokenizer.
Links:
- Code: https://github.com/SwayStar123/SpeedrunDiT
- Checkpoints: https://huggingface.co/SwayStar123/SpeedrunDiT/tree/main
- W&B runs: https://wandb.ai/kagaku-ai/REG/
- Ablations (branches): https://github.com/SwayStar123/REG
Results:
- ImageNet-256 (400K iters, no CFG): FID 3.49, KDD 0.319, 140M params, sampling at NFE=250
- ImageNet-512 (400K iters, no CFG): FID 4.23, KDD 0.306, sampling at NFE=250
SR-DiT builds on top of a strong baseline (REG + INVAE) and then progressively adds:
- Semantic latent space via E2E-INVAE
- SPRINT token routing
- RMSNorm, RoPE, QK normalization, value residual learning (see the attention sketch after this list)
- Contrastive Flow Matching (CFM)
- Time shifting and balanced label sampling (for evaluation)
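To make the attention-level changes concrete, here is a minimal PyTorch sketch of an attention block with RMSNorm-based QK normalization and a value-residual connection. This is illustrative only: module and argument names are our own, RoPE is omitted for brevity, and the repo's actual implementation may differ in detail.

```python
import torch
import torch.nn.functional as F
from torch import nn

class QKNormAttention(nn.Module):
    """Self-attention with QK normalization and value residual learning.

    Sketch of the listed modifications, not the repo's exact code. RoPE
    would be applied to q and k after normalization; omitted here.
    """

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)
        # QK normalization: RMSNorm over each head's channel dimension.
        self.q_norm = nn.RMSNorm(self.head_dim)
        self.k_norm = nn.RMSNorm(self.head_dim)
        # Learned mixing weight for the value residual.
        self.v_lambda = nn.Parameter(torch.tensor(0.5))

    def forward(self, x, v_first=None):
        B, N, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        q, k = self.q_norm(q), self.k_norm(k)  # stabilizes attention logits
        if v_first is None:
            v_first = v  # first block: cache values for later residuals
        else:
            # Value residual learning: mix this block's values with block 1's.
            lam = self.v_lambda.sigmoid()
            v = lam * v + (1 - lam) * v_first
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(B, N, -1)
        return self.proj(out), v_first
```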
Code overview:
- `train.py`: training loop (Accelerate)
- `generate.py`: multi-GPU sampling to `.png` and `.npz`
- `evaluations/evaluator.py`: computes FID/sFID/IS/Precision/Recall from `.npz`
- `preprocessing/dataset_tools.py`: ImageNet preprocessing + INVAE encoding
- `train.sh`, `eval.sh`: example scripts used for our runs
Create an environment (Python 3.11) and install dependencies:

```bash
pip install -r requirements.txt
```

Training expects a directory (passed via `--data-dir`) containing:
```
dataset/
  images/   # preprocessed ImageNet images (256x256 or 512x512)
  vae-in/   # INVAE latents (.npy) + dataset.json labels
```
Follow the preprocessing guide in `preprocessing/README.md`. The minimal flow is:

```bash
# 1) Convert raw ImageNet to a resized/cropped PNG dataset
python preprocessing/dataset_tools.py convert --source /path/to/imagenet/train \
    --dest dataset/images --resolution=256x256 --transform=center-crop-dhariwal

# 2) Encode images to INVAE latents
python preprocessing/dataset_tools.py encode --source dataset/images \
    --dest dataset/vae-in
```

The preprocessed dataset is also uploaded here:
- https://huggingface.co/datasets/SwayStar123/repa-imagenet-256/blob/main/dataset.zip
- https://huggingface.co/datasets/SwayStar123/repa-imagenet-256/blob/main/vae-in.zip
Unzip `dataset.zip` first, then unzip `vae-in.zip` inside the newly created `dataset/` folder.
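After preprocessing (or unzipping), a quick sanity check can confirm that latents and labels line up. This is a sketch: it assumes `dataset.json` follows the edm2 `dataset_tools` convention of a `{"labels": [[path, label], ...]}` list (edm2 is the acknowledged upstream for the preprocessing utilities); verify against the actual file before relying on this structure.

```python
# Quick sanity check of the encoded latents (sketch; see caveats above).
import json
from pathlib import Path

import numpy as np

vae_dir = Path("dataset/vae-in")

# dataset.json is assumed to follow the edm2 dataset_tools convention:
# {"labels": [[relative_path, class_label], ...]}.
with open(vae_dir / "dataset.json") as f:
    labels = json.load(f)["labels"]

rel_path, label = labels[0]
# The stored path may reference the image filename; swap the suffix for .npy.
latent = np.load((vae_dir / rel_path).with_suffix(".npy"))
print(f"{rel_path}: latent shape {latent.shape}, class {label}")
```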
An example command is provided in `train.sh`:

```bash
bash train.sh
```

Key arguments:

- `--model`: use `SiT-B/1` for the SR-DiT-B/1 configuration
- `--data-dir`: directory containing `images/` and `vae-in/`
- `--qk-norm`: enables QK normalization
- `--cfm-coeff`, `--cfm-weighting`: Contrastive Flow Matching settings (see the sketch below)
- `--time-shifting`, `--shift-base`: time shifting during training
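For intuition about the CFM arguments, here is a minimal sketch of a Contrastive Flow Matching loss, assuming the common formulation in which the predicted velocity is pulled toward its true flow-matching target and pushed away from the velocity of a negative pair. `cfm_coeff` plays the role of `--cfm-coeff`; the exact weighting scheme behind `--cfm-weighting` is not reproduced here.

```python
import torch

def cfm_loss(v_pred, x0, x1, cfm_coeff=0.05):
    """Contrastive Flow Matching loss (sketch, not the repo's exact code).

    v_pred: model velocity prediction at x_t = (1 - t) * x0 + t * x1
    x0:     noise sample; x1: data latent (both [B, ...])
    """
    target = x1 - x0  # standard flow-matching target
    # Negative targets: pair each sample with another sample's data point
    # by rolling the batch (a stand-in for cross-class negatives).
    neg_target = torch.roll(x1, shifts=1, dims=0) - x0
    pos = (v_pred - target).pow(2).mean()
    neg = (v_pred - neg_target).pow(2).mean()
    # Pull toward the true velocity, push away from the negative pair's.
    return pos - cfm_coeff * neg
```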
Checkpoints are written to:
```
exps/<exp-name>/checkpoints/<step>.pt
```
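To inspect a checkpoint outside the training loop, something like the following should work; the contents of the saved dict are an assumption here, so check `train.py` for the actual keys.

```python
import torch

# Substitute your own run name and step number (illustrative path).
ckpt_path = "exps/my-run/checkpoints/400000.pt"
# weights_only=False is needed on newer PyTorch if the checkpoint stores
# optimizer state or other non-tensor objects (only load trusted files).
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
# The keys are whatever train.py saved (e.g. model / EMA / optimizer state);
# check train.py rather than relying on any particular layout.
print(list(ckpt.keys()))
```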
`eval.sh` runs sampling (`generate.py`) and then computes metrics (`evaluations/evaluator.py`):

```bash
bash eval.sh
```

Notes:

- `generate.py` currently supports `--mode sde` (the `ode` branch is not implemented).
- For metric computation, download the matching reference batch listed in `evaluations/README.md`.
- Balanced label sampling can be enabled via `--balanced-sampling` when generating samples (see the sketch below).
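For reference, balanced label sampling simply draws an equal number of samples per class instead of sampling labels i.i.d. uniformly. A minimal sketch (the flag name `--balanced-sampling` is from the repo; the function below is illustrative):

```python
import numpy as np

def balanced_labels(num_samples: int, num_classes: int = 1000) -> np.ndarray:
    """Evenly spread class labels across the requested number of samples.

    Illustrative sketch of what --balanced-sampling does: every class
    appears floor(num_samples / num_classes) times (plus one for the
    remainder), rather than being drawn uniformly at random.
    """
    labels = np.arange(num_samples) % num_classes
    return np.random.permutation(labels)

print(np.bincount(balanced_labels(50_000))[:5])  # -> [50 50 50 50 50]
```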
If you use this repository, please cite SR-DiT:
```bibtex
@misc{bhanded2025speedrundit,
  title         = {Speedrunning ImageNet Diffusion},
  author        = {Bhanded, Swayam},
  year          = {2025},
  eprint        = {2512.12386},
  archivePrefix = {arXiv},
  primaryClass  = {cs.CV},
  url           = {https://arxiv.org/abs/2512.12386},
}
```

Please open a GitHub issue for any questions or problems.
This codebase builds upon:
- REG / REPA
- SiT
- DINOv2
- ADM evaluations
- NVLabs edm2 preprocessing utilities
We gratefully acknowledge support from WayfarerLabs (Open World Labs) for sponsoring compute resources used in this work.