darioShar/DDM-Lightning

Discrete Diffusion Language Modeling with PyTorch Lightning

Figure: masked diffusion (MDM) and uniform diffusion (UDM) forward/backward processes on text.

A research-oriented PyTorch Lightning framework for discrete diffusion language modeling. The repository is intentionally organized so new diffusion processes, losses, samplers, metrics, callbacks, tokenizers, datasets, and model architectures can be added without rewriting the training loop.

Vibecoding

Use the available CLAUDE.md file to vibecode through the codebase.

Supported Methods

  • Masked discrete diffusion language models (MDM / MDLM-style absorbing-mask diffusion)
  • Uniform discrete diffusion language models (UDM / UDLM-style uniform corruption)
  • Autoregressive GPT-style language model baselines
  • AR-to-MDM attention-mask annealing via attn_mask_annealing_ratio
  • Optional graph MDM and graph UDM experiments for molecular graphs
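
To make the difference between the two text corruption styles concrete, here is a minimal pure-Python sketch of one forward (noising) step for each. It assumes a linear schedule in which the corruption probability equals the time t in [0, 1]; the repository's actual schedules live in noise_sampling.py and may differ. MASK_ID, masked_forward, and uniform_forward are hypothetical names used only for illustration.

```python
import random

MASK_ID = -1  # hypothetical id standing in for the [MASK] token


def masked_forward(tokens, t, rng, mask_id=MASK_ID):
    """Absorbing-mask corruption: each token becomes mask_id with probability t."""
    return [mask_id if rng.random() < t else tok for tok in tokens]


def uniform_forward(tokens, t, vocab_size, rng):
    """Uniform corruption: with probability t, resample the token uniformly from the vocabulary."""
    return [rng.randrange(vocab_size) if rng.random() < t else tok for tok in tokens]


rng = random.Random(0)
x0 = [5, 2, 7, 7, 1, 3]  # toy token ids, assumed vocabulary of size 10
print(masked_forward(x0, 0.5, rng))
print(uniform_forward(x0, 0.5, vocab_size=10, rng=rng))
```

At t = 0 both functions return the clean sequence. At t = 1 the masked process yields only mask tokens, while the uniform process yields tokens drawn uniformly from the vocabulary.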

Features

  • Lightning training, validation, prediction, checkpointing, resume, and distributed strategies
  • EMA weights for training and evaluation
  • SDPA and FlashAttention-compatible DiT attention backends
  • torch.compile support
  • bf16/mixed precision through Lightning trainer configs
  • Gradient checkpointing
  • Gradient norm logging
  • Generative perplexity, entropy, token distribution KL, sliced Wasserstein, sensitivity, and gradient-moment metrics
  • CSV and Weights & Biases loggers
  • Hydra config composition for experiments and variants

Design Philosophy

This repository follows the config-driven style popularized by Stability AI's generative-models codebase: modules are small, named in YAML, and assembled at runtime with instantiate_from_config().

In practice, a run is just a composition of replaceable parts:

  • model_config builds the neural architecture.
  • loss_config defines the training objective.
  • sampler_config defines the reverse process used for generation.
  • data configs define tokenization, sequence length, batch size, and dataset loading.
  • Lightning owns the training loop, checkpointing, logging, precision, and distributed execution.

The goal is to make research changes local. To try a new sampler, loss, metric, callback, or architecture, add a Python class, point a Hydra config at it with target and params, and reuse the same Lightning loop.
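
The target/params pattern described above can be sketched in a few lines. This is an illustrative reimplementation of the general idea, not necessarily the repository's own instantiate_from_config(); the config dict stands in for a YAML node, and a standard-library class is used as the target so the example runs without the codebase installed.

```python
import importlib


def instantiate_from_config(config):
    """Import the class named by config["target"] and call it with config["params"]."""
    module_name, class_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**config.get("params", {}))


# Stand-in for a YAML node such as sampler_config. The target here is a
# stdlib class chosen only to keep the sketch runnable.
sampler_config = {
    "target": "fractions.Fraction",
    "params": {"numerator": 3, "denominator": 4},
}

obj = instantiate_from_config(sampler_config)
print(type(obj).__name__, obj)
```

Swapping a sampler, loss, or architecture then amounts to changing the target string and its params in the config, with no change to the training loop.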

Repository Map

| Component | Location |
| --- | --- |
| Lightning entry point | hydra_main.py |
| Base language model loop and metrics | dlm/models/baselm.py |
| Diffusion LM Lightning module | dlm/models/diffusion.py |
| Autoregressive Lightning module | dlm/models/autoregressive.py |
| Graph Lightning module | dlm/models/graph.py |
| MDM loss and sampler | dlm/modules/diffusionmodules/masked_loss.py, masked_sampling.py |
| UDM loss and sampler | dlm/modules/diffusionmodules/uniform_loss.py, uniform_sampling.py |
| Noise schedules and time sampling | dlm/modules/diffusionmodules/noise_sampling.py, time_sampling.py |
| DiT text architectures | dlm/transformers/dit/ |
| GPT-2 AR architecture wrapper | dlm/transformers/gpt2/ |
| Graph architectures | dlm/transformers/graph/ |
| Text datamodules | dlm/data/base.py, dlm/data/text8.py |
| Molecule datamodules | dlm/data/molecules.py |
| Metrics | dlm/metrics/ |
| Callbacks | dlm/callbacks/ |
| Hydra configs | conf/ |
| Launch examples | bash_commands/ |

Default Tokenizers

  • OpenWebText defaults to the GPT-2 tokenizer: conf/data/owt/owt.yaml.
  • LM1B defaults to the BERT tokenizer: conf/data/lm1b/lm1b_bert_packed.yaml.

Quick Start

Install dependencies:

uv sync --no-install-package flash-attn
uv sync
# log in to Hugging Face and Weights & Biases
uv run huggingface-cli login
uv run wandb login

Preprocess a dataset:

uv run python preprocess_dataset.py conf/data/owt/owt.yaml
uv run python preprocess_dataset.py conf/data/lm1b/lm1b_bert_packed.yaml

Train OWT MDM:

uv run python hydra_main.py +experiments/owt=mdlm_train

Train LM1B UDM:

uv run python hydra_main.py +experiments/lm1b=udlm_bert_train

Train an AR baseline:

uv run python hydra_main.py +experiments/owt=autoregressive_train

Enable AR-to-MDM attention-mask annealing:

uv run python hydra_main.py +experiments/owt=mdlm_train attn_mask_annealing_ratio=0.2

Running Experiments

Experiment YAMLs define the run type, dataset, model family, and log root. Variant YAMLs define reusable model-size or sampler choices. Shared defaults live in files such as conf/experiments/owt/_train_common.yaml, conf/experiments/lm1b/_train_common.yaml, and conf/experiments/molecules/variants/_udlm_common.yaml.

OWT and LM1B MDM runs are checkpointed and evaluated every 20,000 optimizer steps by default. LM1B UDM, text8, and graph runs use their own defaults, usually 10,000 or 40,000 steps; see the experiment YAMLs and conf/lightning/callbacks/ckpt.yaml.

For local runs, use the dataset wrappers under bash_commands/<dataset>/:

bash bash_commands/owt/experiment_train_local.sh +experiments/owt=mdlm_train
bash bash_commands/lm1b/experiment_train_local.sh +experiments/lm1b=udlm_bert_train
bash bash_commands/owt/experiment_sample_local.sh +experiments/owt=mdlm_sample

For multi-GPU or multi-node jobs, define the scheduler environment variables and launch with torchrun:

export NUM_GPUS=${NUM_GPUS:-8}
export NUM_NODES=${NUM_NODES:-1}
export MASTER_ADDR=${MASTER_ADDR:-$(hostname)}
export MASTER_PORT=${MASTER_PORT:-29500}

srun .venv/bin/torchrun \
  --nproc_per_node=${NUM_GPUS} \
  --nnodes=${NUM_NODES} \
  --node_rank=${SLURM_NODEID:-0} \
  --rdzv_backend=c10d \
  --rdzv_id=${SLURM_JOB_ID:-0} \
  --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
  hydra_main.py \
  num_gpu_devices=${NUM_GPUS} \
  num_nodes=${NUM_NODES} \
  +experiments/owt=mdlm_train

For offline jobs, pre-download datasets and models first, then add hf_offline=true and, if needed, hf_cache_dir=/path/to/cache to the Hydra overrides.

Sampling From A Checkpoint

Sampling and evaluation configs resume from the matching training log root by default. To choose a specific run or checkpoint, add:

resume.from_dir=/path/to/run_or_parent
resume.ckpt_path=/path/to/checkpoint.ckpt

Examples:

bash bash_commands/owt/experiment_sample_local.sh \
  +experiments/owt=mdlm_sample \
  resume.from_dir=logs/models/owt/elbo_false_mdlm_owt_BL

bash bash_commands/lm1b/experiment_sample_local.sh \
  +experiments/lm1b=udlm_bert_sample \
  resume.ckpt_path=/path/to/checkpoints/last.ckpt

Outputs And Checkpoints

Each run creates a timestamped directory under the configured log root:

logs/models/<dataset>/<experiment_root>/<timestamp>_<config_tag>/

Inside a run directory:

  • checkpoints/ contains Lightning checkpoints. last.ckpt is saved when enabled, and periodic checkpoints follow the callback filename pattern in the experiment config.
  • configs/ contains resolved Hydra config snapshots, including composed_config.yaml.
  • csv/version_*/metrics.csv contains scalar training or evaluation metrics.
  • wandb_id.txt stores the W&B run ID when W&B is enabled.
  • wandb/ contains local W&B run files. For offline runs, sync with uv run wandb sync <run_dir>/wandb/offline-run-*.

Sampling and evaluation runs use their own log roots under paths such as logs/models/owt/sampling/..., so their CSVs are separate from training CSVs. Prediction mode stores generated tensors under logs/generated_samples/.
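
As a rough illustration of working with this layout, the sketch below finds the newest csv/version_*/metrics.csv inside a run directory and reports the last logged value of each column. It relies only on the directory pattern described above; the column names (step, train_loss, val_loss) and both helper functions are hypothetical, and the demo builds a temporary run directory so it can be executed anywhere.

```python
import csv
import tempfile
from pathlib import Path


def latest_metrics_file(run_dir):
    """Return metrics.csv from the highest-numbered csv/version_* folder, or None."""
    candidates = []
    for path in Path(run_dir).glob("csv/version_*/metrics.csv"):
        suffix = path.parent.name.split("_")[-1]
        if suffix.isdigit():
            candidates.append((int(suffix), path))
    return max(candidates, key=lambda item: item[0])[1] if candidates else None


def last_logged_values(metrics_path):
    """Return the most recent non-empty value of every column in a metrics CSV."""
    latest = {}
    with open(metrics_path, newline="") as handle:
        for row in csv.DictReader(handle):
            for name, value in row.items():
                if value not in ("", None):
                    latest[name] = value
    return latest


# Demo on a throwaway run directory with made-up metric names and values.
with tempfile.TemporaryDirectory() as run_dir:
    version_dir = Path(run_dir) / "csv" / "version_0"
    version_dir.mkdir(parents=True)
    (version_dir / "metrics.csv").write_text(
        "step,train_loss,val_loss\n100,2.5,\n200,2.1,2.3\n300,1.9,\n"
    )
    summary = last_logged_values(latest_metrics_file(run_dir))
    print(summary)
```

Empty cells are skipped because CSV loggers typically leave a metric blank on steps where it was not logged, so the last row alone would miss validation metrics.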

How To Add A Method

  1. Add the implementation under dlm/modules/diffusionmodules/ or a new focused module.
  2. Expose the loss through a Hydra loss_config in conf/base/<method>.yaml.
  3. Expose the sampler through sampler_config in the same base config.
  4. Add or reuse a model config under conf/model/.
  5. Add an experiment config under conf/experiments/<dataset>/.

How To Add A Model Architecture

  1. Implement a torch.nn.Module or transformers.PreTrainedModel under dlm/transformers/<name>/.
  2. Return an object with a .logits tensor shaped [batch, sequence, vocab].
  3. Add a Hydra target in conf/model/<dataset>/<model>.yaml.
  4. Select it from an experiment with override /model: <dataset>/<model>.

For DiT-style text models, the main reusable implementation is dlm.transformers.dit.dit.DiscreteDiTModel.
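
The output contract in step 2 can be illustrated with a toy module. This is only a sketch of the .logits shape requirement: ModelOutput and TinyTokenModel are hypothetical names, and the arguments that the Lightning modules actually pass to forward (for example a noise level or timestep) are not specified here, so the signature below is an assumption.

```python
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class ModelOutput:
    logits: torch.Tensor  # shaped [batch, sequence, vocab]


class TinyTokenModel(nn.Module):
    """Hypothetical toy backbone: token embedding followed by a vocabulary projection."""

    def __init__(self, vocab_size: int, hidden_size: int = 32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.head = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> ModelOutput:
        hidden = self.embed(input_ids)  # [batch, sequence, hidden]
        return ModelOutput(logits=self.head(hidden))  # [batch, sequence, vocab]


model = TinyTokenModel(vocab_size=50)
out = model(torch.randint(0, 50, (2, 16)))
print(out.logits.shape)
```

A real architecture would replace the embedding-plus-linear body with attention blocks, but the returned object only needs to expose logits with the same three dimensions.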

Logging And Evaluation

Common metric flags live under model.params in the base configs:

  • enable_entropy
  • enable_generative_perplexity
  • enable_gradient_moment_metric
  • enable_sliced_wasserstein
  • enable_token_distribution_kl
  • enable_sensitivity_metric
  • enable_grad_norms_logging

CSV logging is always enabled. Set enable_wandb=false for CSV-only runs.
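
For intuition about two of these metrics, the sketch below computes a unigram token entropy and a token-distribution KL divergence from raw token ids in pure Python. It is one plausible reading of what such metrics measure, not the repository's implementation in dlm/metrics/; the smoothing constant, the KL direction (generated relative to reference), and all function names are assumptions.

```python
import math
from collections import Counter


def token_distribution(tokens, vocab_size, smoothing=1e-8):
    """Smoothed empirical unigram distribution over token ids 0..vocab_size-1."""
    counts = Counter(tokens)
    total = len(tokens) + smoothing * vocab_size
    return [(counts.get(i, 0) + smoothing) / total for i in range(vocab_size)]


def entropy(p):
    """Shannon entropy in nats."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def kl_divergence(p, q):
    """KL(p || q) in nats; q must be positive wherever p is."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# Toy token ids over an assumed vocabulary of size 4.
generated = token_distribution([0, 0, 1, 2], vocab_size=4)
reference = token_distribution([0, 1, 2, 3], vocab_size=4)

print("entropy(generated):", entropy(generated))
print("KL(generated || reference):", kl_divergence(generated, reference))
```

A low entropy relative to the reference suggests repetitive samples, while a large KL indicates that generated token frequencies drift from the reference data.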

Acknowledgements

This repository builds off of the following works.

We are grateful for their work. We also thank Makoto Shing for providing a first iteration of the codebase and further help and advice.

Citation

If you use this framework in your research, please cite it as:

@software{shariatian2026simpleDiscreteDiffusionLightning,
  title = {Discrete Diffusion Language Modeling with Pytorch Lightning},
  author = {Shariatian, Dario},
  year = {2026},
  url = {https://github.com/darioShar/ddm}
}
