Masked diffusion (MDM) and uniform diffusion (UDM) forward/backward processes.
A research-oriented PyTorch Lightning framework for discrete diffusion language modeling. The repository is intentionally organized so new diffusion processes, losses, samplers, metrics, callbacks, tokenizers, datasets, and model architectures can be added without rewriting the training loop.
Use the available CLAUDE.md file to vibecode through the codebase.
- Masked discrete diffusion language models (MDM / MDLM-style absorbing-mask diffusion)
- Uniform discrete diffusion language models (UDM / UDLM-style uniform corruption)
- Autoregressive GPT-style language model baselines
- AR-to-MDM attention-mask annealing via attn_mask_annealing_ratio
- Optional graph MDM and graph UDM experiments for molecular graphs
- Lightning training, validation, prediction, checkpointing, resume, and distributed strategies
- EMA weights for training and evaluation
- SDPA and FlashAttention-compatible DiT attention backends
- torch.compile support
- bf16/mixed precision through Lightning trainer configs
- Gradient checkpointing
- Gradient norm logging
- Generative perplexity, entropy, token distribution KL, sliced Wasserstein, sensitivity, and gradient-moment metrics
- CSV and Weights & Biases loggers
- Hydra config composition for experiments and variants
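The two diffusion families differ only in what a corrupted token becomes. The sketch below samples the forward process q(z_t | x) for both. It assumes a linear schedule alpha(t) = 1 - t (the probability that a token is still clean at time t) and plain Python lists of token ids. The schedules the repository actually provides live in dlm/modules/diffusionmodules/noise_sampling.py and may differ, and MASK = -1 is a hypothetical placeholder for the tokenizer's mask id.

```python
import random
from typing import List, Sequence

MASK = -1  # hypothetical id for the absorbing [MASK] token


def alpha(t: float) -> float:
    """Linear schedule (assumed): probability that a token is still clean at time t in [0, 1]."""
    return 1.0 - t


def corrupt_masked(x: Sequence[int], t: float, rng: random.Random) -> List[int]:
    """Sample z_t for MDM: keep each token w.p. alpha(t), otherwise absorb it into MASK."""
    a = alpha(t)
    return [tok if rng.random() < a else MASK for tok in x]


def corrupt_uniform(x: Sequence[int], t: float, vocab_size: int, rng: random.Random) -> List[int]:
    """Sample z_t for UDM: keep each token w.p. alpha(t), otherwise resample it uniformly."""
    a = alpha(t)
    return [tok if rng.random() < a else rng.randrange(vocab_size) for tok in x]


rng = random.Random(0)
x = [5, 2, 7, 1]
print(corrupt_masked(x, 0.5, rng))
print(corrupt_uniform(x, 0.5, vocab_size=10, rng=rng))
```

Under uniform corruption a resampled token can coincide with the original, so UDM keeps each token with probability alpha(t) + (1 - alpha(t)) / V rather than exactly alpha(t), where V is the vocabulary size.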
This repository follows the config-driven style popularized by Stability AI's
generative-models codebase:
modules are small, named in YAML, and assembled at runtime with instantiate_from_config().
In practice, a run is just a composition of replaceable parts:
- model_config builds the neural architecture.
- loss_config defines the training objective.
- sampler_config defines the reverse process used for generation.
- Data configs define tokenization, sequence length, batch size, and dataset loading.
- Lightning owns the training loop, checkpointing, logging, precision, and distributed execution.
The goal is to make research changes local. To try a new sampler, loss, metric, callback, or
architecture, add a Python class, point a Hydra config at it with target and params, and reuse
the same Lightning loop.
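An instantiate_from_config helper of this kind resolves the dotted target path and calls it with params as keyword arguments. Below is a minimal sketch modeled on the helper of the same name in Stability AI's generative-models. This repository's version may handle extra keys, so treat it as illustrative. The standard-library target in the usage line stands in for a loss, sampler, or model class.

```python
import importlib
from typing import Any, Mapping


def get_obj_from_str(path: str) -> Any:
    """Resolve a dotted path like 'package.module.ClassName' to the object it names."""
    module_name, attr_name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), attr_name)


def instantiate_from_config(config: Mapping[str, Any]) -> Any:
    """Call config['target'] with config['params'] (default: none) as keyword arguments."""
    if "target" not in config:
        raise KeyError("Expected key 'target' to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", {}))


# Usage: any importable callable can be a target; in a run the targets would be
# losses, samplers, models, callbacks, and so on.
delta = instantiate_from_config({"target": "datetime.timedelta", "params": {"hours": 2}})
```

Because the YAML only names a dotted path, swapping a component means editing one target line rather than the training loop.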
| Component | Location |
|---|---|
| Lightning entry point | hydra_main.py |
| Base language model loop and metrics | dlm/models/baselm.py |
| Diffusion LM Lightning module | dlm/models/diffusion.py |
| Autoregressive Lightning module | dlm/models/autoregressive.py |
| Graph Lightning module | dlm/models/graph.py |
| MDM loss and sampler | dlm/modules/diffusionmodules/masked_loss.py, masked_sampling.py |
| UDM loss and sampler | dlm/modules/diffusionmodules/uniform_loss.py, uniform_sampling.py |
| Noise schedules and time sampling | dlm/modules/diffusionmodules/noise_sampling.py, time_sampling.py |
| DiT text architectures | dlm/transformers/dit/ |
| GPT-2 AR architecture wrapper | dlm/transformers/gpt2/ |
| Graph architectures | dlm/transformers/graph/ |
| Text datamodules | dlm/data/base.py, dlm/data/text8.py |
| Molecule datamodules | dlm/data/molecules.py |
| Metrics | dlm/metrics/ |
| Callbacks | dlm/callbacks/ |
| Hydra configs | conf/ |
| Launch examples | bash_commands/ |
- OpenWebText defaults to the GPT-2 tokenizer: conf/data/owt/owt.yaml.
- LM1B defaults to the BERT tokenizer: conf/data/lm1b/lm1b_bert_packed.yaml.
Install dependencies:
uv sync --no-install-package flash-attn
uv sync
# make sure to log in to Hugging Face and W&B
uv run huggingface-cli login
uv run wandb login

Preprocess a dataset:
uv run python preprocess_dataset.py conf/data/owt/owt.yaml
uv run python preprocess_dataset.py conf/data/lm1b/lm1b_bert_packed.yaml

Train OWT MDM:
uv run python hydra_main.py +experiments/owt=mdlm_train

Train LM1B UDM:
uv run python hydra_main.py +experiments/lm1b=udlm_bert_train

Train an AR baseline:
uv run python hydra_main.py +experiments/owt=autoregressive_train

Enable AR-to-MDM attention-mask annealing:
uv run python hydra_main.py +experiments/owt=mdlm_train attn_mask_annealing_ratio=0.2

Experiment YAMLs define the run type, dataset, model family, and log root. Variant YAMLs define reusable model-size or sampler choices. Shared defaults live in files such as conf/experiments/owt/_train_common.yaml, conf/experiments/lm1b/_train_common.yaml, and conf/experiments/molecules/variants/_udlm_common.yaml.
OWT and LM1B MDM runs are checkpointed and evaluated every 20,000 optimizer steps by default. LM1B UDM, text8, and graph runs use their own defaults, usually 10,000 or 40,000 steps; see the experiment YAMLs and conf/lightning/callbacks/ckpt.yaml.
For local runs, use the dataset wrappers under bash_commands/<dataset>/:
bash bash_commands/owt/experiment_train_local.sh +experiments/owt=mdlm_train
bash bash_commands/lm1b/experiment_train_local.sh +experiments/lm1b=udlm_bert_train
bash bash_commands/owt/experiment_sample_local.sh +experiments/owt=mdlm_sample

For multi-GPU or multi-node jobs, define the scheduler environment variables and launch with torchrun. The example below runs under SLURM (srun, SLURM_NODEID, SLURM_JOB_ID):
export NUM_GPUS=${NUM_GPUS:-8}
export NUM_NODES=${NUM_NODES:-1}
export MASTER_ADDR=${MASTER_ADDR:-$(hostname)}
export MASTER_PORT=${MASTER_PORT:-29500}
srun .venv/bin/torchrun \
--nproc_per_node=${NUM_GPUS} \
--nnodes=${NUM_NODES} \
--node_rank=${SLURM_NODEID:-0} \
--rdzv_backend=c10d \
--rdzv_id=${SLURM_JOB_ID:-0} \
--rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
hydra_main.py \
num_gpu_devices=${NUM_GPUS} \
num_nodes=${NUM_NODES} \
+experiments/owt=mdlm_train

For offline jobs, pre-download datasets and models first, then add hf_offline=true and, if needed, hf_cache_dir=/path/to/cache to the Hydra overrides.
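With the torchrun launch above, Lightning runs one process per GPU, so the number of sequences per optimizer step grows with NUM_GPUS * NUM_NODES. The sketch below does that arithmetic under two assumptions. First, the data config's batch size is taken to be per device (check conf/data/ in case it is a global batch instead). Second, any gradient accumulation is set through the trainer's accumulate_grad_batches. The environment-variable defaults mirror the shell snippet. The per-GPU batch of 32 and sequence length of 1024 in the usage lines are hypothetical.

```python
import os
from typing import Mapping, Tuple


def launch_shape(env: Mapping[str, str] = os.environ) -> Tuple[int, int]:
    """Read NUM_GPUS and NUM_NODES with the same defaults as the launch script (8 and 1)."""
    return int(env.get("NUM_GPUS", "8")), int(env.get("NUM_NODES", "1"))


def effective_batch_size(per_device_batch: int, num_gpus: int, num_nodes: int,
                         accumulate_grad_batches: int = 1) -> int:
    """Sequences per optimizer step under DDP: one per-device batch per process, times accumulation."""
    return per_device_batch * num_gpus * num_nodes * accumulate_grad_batches


def tokens_between_checkpoints(every_n_steps: int, batch: int, seq_len: int) -> int:
    """Tokens processed between two periodic checkpoints (e.g. every 20,000 optimizer steps)."""
    return every_n_steps * batch * seq_len


# Usage with hypothetical numbers: 2 nodes, default 8 GPUs each, 32 sequences per GPU.
gpus, nodes = launch_shape({"NUM_NODES": "2"})
batch = effective_batch_size(32, gpus, nodes)
tokens = tokens_between_checkpoints(20_000, batch, 1024)
```

Keeping the per-device batch fixed while adding nodes therefore changes the optimization problem (larger global batch), which matters when comparing runs launched at different scales.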
Sampling and evaluation configs resume from the matching training log root by default. To choose a specific run or checkpoint, add:
resume.from_dir=/path/to/run_or_parent
resume.ckpt_path=/path/to/checkpoint.ckpt

Examples:
bash bash_commands/owt/experiment_sample_local.sh \
+experiments/owt=mdlm_sample \
resume.from_dir=logs/models/owt/elbo_false_mdlm_owt_BL
bash bash_commands/lm1b/experiment_sample_local.sh \
+experiments/lm1b=udlm_bert_sample \
resume.ckpt_path=/path/to/checkpoints/last.ckpt

Each run creates a timestamped directory under the configured log root:
logs/models/<dataset>/<experiment_root>/<timestamp>_<config_tag>/
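Resuming relies on this layout. The sketch below shows one way to find the newest run under a log root and a checkpoint inside it, preferring last.ckpt. It is not the repository's resume logic, for which resume.from_dir and resume.ckpt_path are the supported overrides. It also assumes the timestamp prefix sorts chronologically as a string, for example a hypothetical YYYY-MM-DD_HH-MM-SS format.

```python
from pathlib import Path
from typing import Optional


def latest_run_dir(log_root: Path) -> Optional[Path]:
    """Newest <timestamp>_<config_tag>/ directory, assuming timestamps sort lexicographically."""
    if not log_root.is_dir():
        return None
    runs = sorted(p for p in log_root.iterdir() if p.is_dir())
    return runs[-1] if runs else None


def pick_checkpoint(run_dir: Path) -> Optional[Path]:
    """Prefer checkpoints/last.ckpt; otherwise fall back to the most recently modified *.ckpt."""
    ckpt_dir = run_dir / "checkpoints"
    last = ckpt_dir / "last.ckpt"
    if last.is_file():
        return last
    ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=lambda p: p.stat().st_mtime)
    return ckpts[-1] if ckpts else None
```

Passing the returned path as resume.ckpt_path pins the checkpoint explicitly instead of relying on the default lookup from the training log root.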
Inside a run directory:
- checkpoints/ contains Lightning checkpoints. last.ckpt is saved when enabled, and periodic checkpoints follow the callback filename pattern in the experiment config.
- configs/ contains resolved Hydra config snapshots, including composed_config.yaml.
- csv/version_*/metrics.csv contains scalar training or evaluation metrics.
- wandb_id.txt stores the W&B run ID when W&B is enabled.
- wandb/ contains local W&B run files. For offline runs, sync with uv run wandb sync <run_dir>/wandb/offline-run-*.
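Lightning's CSVLogger writes one row per logging call into metrics.csv and leaves a cell empty when that metric was not logged in that row. Below is a small standard-library sketch for pulling one metric's curve out of a run directory. It assumes several csv/version_* directories may exist and reads them in version order. The metric names in the usage are hypothetical, so check the CSV header for the keys this repository actually logs.

```python
import csv
from pathlib import Path
from typing import List, Tuple


def _version_number(path: Path) -> int:
    # csv/version_3/metrics.csv -> 3 (numeric sort, so version_10 comes after version_2)
    return int(path.parent.name.split("_", 1)[1])


def read_metric(run_dir: Path, name: str) -> List[Tuple[int, float]]:
    """(step, value) pairs for one metric across every csv/version_*/metrics.csv in a run."""
    points: List[Tuple[int, float]] = []
    for path in sorted(run_dir.glob("csv/version_*/metrics.csv"), key=_version_number):
        with path.open(newline="") as f:
            for row in csv.DictReader(f):
                value = row.get(name) or ""
                if value:  # empty cell: metric not logged in this row
                    points.append((int(row["step"]), float(value)))
    return points
```

For example, read_metric(run_dir, "val/nll") would return the validation curve if the run logs a key with that (hypothetical) name.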
Sampling and evaluation runs use their own log roots under paths such as logs/models/owt/sampling/..., so their CSVs are separate from training CSVs. Prediction mode stores generated tensors under logs/generated_samples/.
To add a new diffusion method:

- Add the implementation under dlm/modules/diffusionmodules/ or in a new, focused module.
- Expose the loss through a Hydra loss_config in conf/base/<method>.yaml.
- Expose the sampler through sampler_config in the same base config.
- Add or reuse a model config under conf/model/.
- Add an experiment config under conf/experiments/<dataset>/.
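A sampler_config target implements the reverse process. For masked diffusion, the standard ancestral step (as in MDLM) keeps already-decoded tokens. Each remaining [MASK] between times t and s < t is unmasked with probability (alpha(s) - alpha(t)) / (1 - alpha(t)), and the new token is drawn from the denoiser's prediction. The sketch below assumes a linear schedule alpha(t) = 1 - t, a hypothetical MASK = -1 id, and a hypothetical denoiser callable that returns per-position probabilities over the clean vocabulary. It is not the code in masked_sampling.py.

```python
import random
from typing import Callable, List, Sequence

MASK = -1  # hypothetical id for the absorbing [MASK] token
Denoiser = Callable[[List[int]], List[Sequence[float]]]  # hypothetical model interface


def alpha(t: float) -> float:
    """Linear schedule (assumed): probability that a token is clean at time t."""
    return 1.0 - t


def ancestral_step(z: List[int], t: float, s: float, denoiser: Denoiser,
                   rng: random.Random) -> List[int]:
    """One MDM reverse step from time t to s < t."""
    p_unmask = (alpha(s) - alpha(t)) / (1.0 - alpha(t))
    probs = denoiser(z)  # per position: distribution over the clean vocabulary
    out = []
    for tok, p in zip(z, probs):
        if tok != MASK or rng.random() >= p_unmask:
            out.append(tok)  # decoded tokens are kept; some masks stay masked
        else:
            out.append(rng.choices(range(len(p)), weights=p)[0])
    return out


def sample(length: int, steps: int, denoiser: Denoiser, rng: random.Random) -> List[int]:
    """Start fully masked at t = 1 and take `steps` uniform steps down to t = 0."""
    z = [MASK] * length
    for i in range(steps, 0, -1):
        z = ancestral_step(z, i / steps, (i - 1) / steps, denoiser, rng)
    return z
```

At the final step s = 0 the unmasking probability is exactly 1, so the returned sequence never contains MASK.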
To add a new model architecture:

- Implement a torch.nn.Module or transformers.PreTrainedModel under dlm/transformers/<name>/.
- Have its forward pass return an object with a .logits tensor of shape [batch, sequence, vocab].
- Add a Hydra target in conf/model/<dataset>/<model>.yaml.
- Select it from an experiment with override /model: <dataset>/<model>.
For DiT-style text models, the main reusable implementation is dlm.transformers.dit.dit.DiscreteDiTModel.
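The architecture contract is small: token ids in, an object exposing .logits of shape [batch, sequence, vocab] out. A minimal PyTorch sketch that satisfies it follows. TinyDenoiser and DenoiserOutput are hypothetical names, and the model uses a bidirectional nn.TransformerEncoder rather than the repository's DiT. It also omits the noise-level conditioning a diffusion backbone normally receives, so check the forward signature the diffusion Lightning module calls before wiring a real model in.

```python
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class DenoiserOutput:
    logits: torch.Tensor  # [batch, sequence, vocab]


class TinyDenoiser(nn.Module):
    """Hypothetical backbone: token + position embeddings, bidirectional encoder, vocab head."""

    def __init__(self, vocab_size: int, max_seq_len: int, hidden_size: int = 64,
                 num_layers: int = 2, num_heads: int = 4):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, hidden_size)
        self.pos_emb = nn.Embedding(max_seq_len, hidden_size)
        layer = nn.TransformerEncoderLayer(hidden_size, num_heads,
                                           dim_feedforward=4 * hidden_size, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers)
        self.head = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> DenoiserOutput:
        positions = torch.arange(input_ids.shape[1], device=input_ids.device)
        h = self.tok_emb(input_ids) + self.pos_emb(positions)  # [B, L, H] + [L, H] broadcasts
        h = self.encoder(h)  # no attention mask, so every position attends to every other
        return DenoiserOutput(logits=self.head(h))


model = TinyDenoiser(vocab_size=11, max_seq_len=16)
out = model(torch.randint(0, 11, (2, 16)))
```

To select such a model, a conf/model/<dataset>/<model>.yaml would set target to the class's dotted path and params to its constructor arguments.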
Common metric flags live under model.params in the base configs:
- enable_entropy
- enable_generative_perplexity
- enable_gradient_moment_metric
- enable_sliced_wasserstein
- enable_token_distribution_kl
- enable_sensitivity_metric
- enable_grad_norms_logging
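Two of these flags map to simple token-count statistics. The sketch below uses common definitions, which are an assumption here: the implementations in dlm/metrics/ may normalize or aggregate differently. Entropy is the unigram entropy of the tokens within one generated sequence, where low values flag repetitive samples. Token distribution KL compares the unigram frequencies of generated text against a reference corpus. A standard-library sketch:

```python
import math
from collections import Counter
from typing import Dict, Iterable, Sequence


def sample_entropy(tokens: Sequence[int]) -> float:
    """Entropy (nats) of the unigram token distribution inside one generated sequence."""
    n = len(tokens)
    return -sum((c / n) * math.log(c / n) for c in Counter(tokens).values())


def unigram(samples: Iterable[Sequence[int]]) -> Dict[int, float]:
    """Pooled token frequencies over a collection of sequences."""
    counts = Counter(tok for seq in samples for tok in seq)
    total = sum(counts.values())
    return {tok: c / total for tok, c in counts.items()}


def token_kl(gen: Dict[int, float], ref: Dict[int, float], eps: float = 1e-10) -> float:
    """KL(P_gen || P_ref); ref probabilities are floored at eps so unseen tokens stay finite."""
    return sum(p * math.log(p / max(ref.get(tok, 0.0), eps)) for tok, p in gen.items() if p > 0)
```

A sample that repeats one token has entropy 0, and a generated distribution identical to the reference has KL 0.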
CSV logging is always enabled. Set enable_wandb=false for CSV-only runs.
This repository builds on the following works.
We are grateful for their work. We also thank Makoto Shing for providing a first iteration of the codebase and further help and advice.
If you use this framework in your research, please cite it as:
@software{shariatian2026simpleDiscreteDiffusionLightning,
title = {Discrete Diffusion Language Modeling with PyTorch Lightning},
author = {Shariatian, Dario},
year = {2026},
url = {https://github.com/darioShar/ddm}
}