[RFC]: Unified, YAML-Driven Training Architecture for Video Diffusion Models #1158

@alexzms

Description

Motivation.

Training video diffusion models today involves a tangle of concerns: model loading, noise scheduling, distillation algorithms, distributed strategies, checkpointing, and validation. Existing training scripts tend to hard-wire these together, making it painful to:

  1. Try a new distillation algorithm on an existing model (requires forking the training loop).
  2. Add a new model to an existing algorithm (requires re-implementing boilerplate).
  3. Switch distributed strategies (FSDP, TP, SP) without touching algorithm code.
  4. Resume, checkpoint, and validate uniformly across all combinations.

Our goal is a single, extensible training framework where each axis of variation is an independent plugin.

Proposed Change.

Architecture Overview

YAML Config
    |
    v
+------------------+     +---------------------+     +------------------+
|   Models Layer   |     |   Methods Layer     |     | Infrastructure   |
|  (per-role)      |     |   (algorithm)       |     |   Layer          |
|                  |     |                     |     |                  |
| - ModelBase      |<----|  - TrainingMethod   |---->| - Trainer        |
| - CausalModelBase|     |  - single_train_step|     | - Callbacks      |
|                  |     |  - backward         |     | - Checkpoint     |
| Roles:           |     |  - optimizers       |     | - Tracker (W&B)  |
|  student         |     |                     |     | - Dataloader     |
|  teacher         |     | Algorithms:         |     |                  |
|  critic          |     |  DMD2, SelfForcing, |     | Distributed:     |
|                  |     |  SFT, DFSFT         |     |  HSDP, TP, SP    |
+------------------+     +---------------------+     +------------------+

Three Layers

  • Models (fastvideo/train/models/): Load transformer + scheduler; define predict_noise, predict_x0, add_noise, backward. Each training role (student/teacher/critic) is an independent instance. Extension point: subclass ModelBase (or CausalModelBase for streaming).
  • Methods (fastvideo/train/methods/): Implement the training algorithm: own the role models, define single_train_step + backward, manage optimizers/schedulers. Extension point: subclass TrainingMethod.
  • Infrastructure (fastvideo/train/trainer.py, utils/, callbacks/): Training loop, gradient accumulation, distributed setup, checkpointing (DCP), W&B tracking, validation, EMA, grad clipping. Extension point: add callbacks; everything else is shared.

YAML-Driven Configuration

Everything is configured declaratively. The _target_ field selects the Python class to instantiate:

models:
  student:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: true
  teacher:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: false
  critic:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: true

method:
  _target_: fastvideo.train.methods.distribution_matching.dmd2.DMD2Method
  rollout_mode: simulate
  dmd_denoising_steps: [1000, 850, 700, 550, 350, 275, 200, 125]
  generator_update_interval: 5
  real_score_guidance_scale: 3.5
  # ...

training:
  distributed: { num_gpus: 8, sp_size: 1, tp_size: 1 }
  data: { data_path: ..., num_latent_t: 20, num_frames: 77 }
  optimizer: { learning_rate: 2.0e-6, betas: [0.0, 0.999] }
  loop: { max_train_steps: 4000 }
  checkpoint: { output_dir: outputs/my_run }

callbacks:
  grad_clip: { max_grad_norm: 1.0 }
  validation: { pipeline_target: ..., every_steps: 100 }

To switch from DMD2 to SFT, change method._target_ and remove the teacher/critic roles; no code changes are required.
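
The _target_ lookup can be sketched as a small importlib-based factory (build_from_target is an illustrative name; the actual build_from_config in utils/builder.py may differ):

```python
import importlib
from typing import Any

def build_from_target(config: dict[str, Any]) -> Any:
    """Instantiate the class named by `_target_`, passing remaining keys as kwargs."""
    cfg = dict(config)                            # don't mutate the caller's dict
    module_path, _, class_name = cfg.pop("_target_").rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**cfg)
```

For example, {"_target_": "collections.Counter", "red": 2} instantiates collections.Counter with red=2.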


Model Abstraction

ModelBase — Standard (Bidirectional) Models

Every role gets its own ModelBase instance owning a transformer and noise_scheduler. The base class defines:

  • prepare_batch() — Convert raw dataloader output into forward-ready TrainingBatch.
  • add_noise() — Apply forward-process noise at a given timestep.
  • predict_noise() / predict_x0() — Run the transformer and return predictions.
  • backward() — Backward pass that restores forward context (attention metadata, timesteps).
  • init_preprocessors() — Lazy-load VAE, build dataloader (called only on the student).
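
A minimal sketch of this interface (signatures, argument names, and defaults are assumptions, not the actual fastvideo/train/models/base.py):

```python
from abc import ABC, abstractmethod
from typing import Any

class ModelBase(ABC):
    """One instance per training role (student/teacher/critic); sketch only."""

    def __init__(self, init_from: str, trainable: bool = True):
        self.init_from = init_from          # e.g. a HuggingFace model id
        self.trainable = trainable
        self.transformer = None             # loaded lazily by the framework
        self.noise_scheduler = None

    @abstractmethod
    def prepare_batch(self, raw_batch: Any) -> Any:
        """Convert raw dataloader output into a forward-ready TrainingBatch."""

    @abstractmethod
    def add_noise(self, latents: Any, noise: Any, timestep: int) -> Any:
        """Apply forward-process noise at the given timestep."""

    @abstractmethod
    def predict_noise(self, noisy_latents: Any, timestep: int, cond: Any) -> Any:
        """Run the transformer and return the noise prediction."""
```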

CausalModelBase — Streaming / Causal Models

Extends ModelBase with streaming inference primitives for causal video generation:

class CausalModelBase(ModelBase):
    def clear_caches(self, *, cache_tag: str = "pos") -> None: ...
    def predict_noise_streaming(self, ..., cache_tag, store_kv, cur_start_frame) -> Tensor | None: ...
    def predict_x0_streaming(self, ..., cache_tag, store_kv, cur_start_frame) -> Tensor | None: ...

Key design: KV caches are internal to the model instance, keyed by cache_tag. The method controls when to store (store_kv=True) vs. read-only (store_kv=False), enabling block-by-block causal rollout during training.
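
The cache_tag/store_kv contract can be illustrated with a toy cache (not the real implementation):

```python
class TaggedKVCache:
    """Toy model of per-tag KV caches internal to a model instance."""

    def __init__(self):
        self._kv: dict[str, list] = {}

    def clear_caches(self, *, cache_tag: str = "pos") -> None:
        self._kv[cache_tag] = []

    def attend(self, block: list, *, cache_tag: str, store_kv: bool) -> list:
        """Return the cached context; optionally append this block to the cache."""
        context = list(self._kv.get(cache_tag, []))
        if store_kv:
            self._kv.setdefault(cache_tag, []).extend(block)
        return context
```

A method can thus denoise a block read-only (store_kv=False) any number of times, then commit it once (store_kv=True) before moving on to the next block.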


Supported Training Methods

1. DMD2 (Distribution Matching Distillation)

Roles: student (trainable) + teacher (frozen) + critic (trainable)

The student learns to generate clean video in few steps by matching the teacher's score function, with a critic network providing a learned fake-score baseline.

  • Rollout modes:
    • simulate — Student starts from pure noise and iteratively denoises through the full step schedule.
    • data_latent — Student denoises from a single randomly-noised data sample.
  • Losses: Generator loss (DMD gradient) + critic flow-matching loss, with alternating updates (generator_update_interval).
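
The alternating-update schedule implied by generator_update_interval can be sketched as follows (illustrative only; the real method decides this internally):

```python
def roles_to_update(step: int, generator_update_interval: int) -> list[str]:
    """Critic trains every step; the generator trains every N-th step."""
    roles = ["critic"]
    if step % generator_update_interval == 0:
        roles.append("student")
    return roles
```

With generator_update_interval: 5, steps 0, 5, 10, ... update the student as well as the critic.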

2. Self-Forcing (Causal DMD)

Roles: student (causal, trainable) + teacher (frozen) + critic (trainable)

Extends DMD2 for streaming/causal video generation. The key idea: during training, the student processes video in temporal chunks, using its own previously-denoised outputs as context for future chunks — simulating online autoregressive rollout.

  • Video is split into blocks of chunk_size latent frames.
  • Each block is denoised through the student's step schedule; a random early-exit step is sampled per block.
  • After denoising a block, its output is fed back (with optional context_noise) as KV cache context for subsequent blocks via predict_noise_streaming(store_kv=True).
  • Supports SDE and ODE sampling during rollout.
  • Selective gradient control: enable_gradient_in_rollout, start_gradient_frame.
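
Putting the bullets above together, the rollout loop might look like this (a sketch: helper names, signatures, and the exact early-exit/commit order are assumptions):

```python
import random

def causal_rollout(student, noisy_video: list, chunk_size: int,
                   denoise_steps: list[int], rng: random.Random) -> list:
    """Block-by-block self-forcing rollout over a latent video (sketch)."""
    student.clear_caches(cache_tag="rollout")
    outputs: list = []
    for start in range(0, len(noisy_video), chunk_size):
        block = noisy_video[start:start + chunk_size]
        exit_idx = rng.randrange(len(denoise_steps))  # random early exit per block
        for i, t in enumerate(denoise_steps):
            block = student.predict_x0_streaming(
                block, t, cache_tag="rollout", store_kv=False,
                cur_start_frame=start)
            if i == exit_idx:
                break
        # Commit the denoised block as KV context for subsequent blocks.
        student.predict_noise_streaming(
            block, denoise_steps[-1], cache_tag="rollout", store_kv=True,
            cur_start_frame=start)
        outputs.extend(block)
    return outputs
```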

3. Supervised Fine-Tuning (SFT)

Roles: student only

Standard flow-matching loss between predicted and ground-truth noise/x0.
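
As a toy scalar version of that loss (rectified-flow velocity target assumed; the actual target depends on the scheduler):

```python
def flow_matching_loss(pred_velocity: list[float], x0: list[float],
                       noise: list[float]) -> float:
    """MSE between the predicted velocity and the target velocity (noise - x0)."""
    targets = [n - x for n, x in zip(noise, x0)]
    return sum((p - t) ** 2 for p, t in zip(pred_velocity, targets)) / len(targets)
```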

4. Diffusion-Forcing SFT (DFSFT)

Roles: student only

SFT with inhomogeneous (per-chunk) timesteps — each temporal chunk in a video gets a different noise level. This trains the model to handle mixed-noise inputs, which is a prerequisite for causal/streaming inference where earlier frames are cleaner than later ones.
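
Per-chunk timestep assignment can be sketched as follows (the chunking and sampler details are assumptions):

```python
import random

def sample_chunk_timesteps(num_frames: int, chunk_size: int,
                           num_train_timesteps: int,
                           rng: random.Random) -> list[int]:
    """One independently sampled timestep per temporal chunk of latent frames."""
    timesteps: list[int] = []
    while len(timesteps) < num_frames:
        t = rng.randrange(num_train_timesteps)
        timesteps.extend([t] * chunk_size)   # all frames in a chunk share t
    return timesteps[:num_frames]
```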


Training Loop

The Trainer runs a standard loop with pluggable method and callbacks:

for step in range(start_step, max_steps):
    for accum_iter in range(grad_accum_steps):
        batch = next(dataloader)
        loss_map, outputs, metrics = method.single_train_step(batch, step)
        method.backward(loss_map, outputs)

    callbacks.on_before_optimizer_step()   # grad clipping
    method.optimizers_schedulers_step()
    method.optimizers_zero_grad()
    callbacks.on_training_step_end()       # logging
    checkpoint_manager.maybe_save(step)
    callbacks.on_validation_begin()        # periodic inference

Callbacks

  • GradNormClipCallback — Per-module gradient norm logging + global clipping.
  • ValidationCallback — Periodic inference sampling with configurable pipeline, sampling steps, and guidance scale.
  • EMACallback — Exponential moving average of student weights.
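
The EMA update itself is the standard exponential moving average; a scalar sketch:

```python
def ema_update(ema_params: list[float], model_params: list[float],
               decay: float = 0.999) -> None:
    """In-place EMA: ema = decay * ema + (1 - decay) * model."""
    for i, p in enumerate(model_params):
        ema_params[i] = decay * ema_params[i] + (1.0 - decay) * p
```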

Checkpointing

  • DCP (Distributed Checkpoint) format, compatible with FSDP/HSDP.
  • Saves: model weights, optimizer states, scheduler states, RNG states (per role).
  • Full resume support: auto-restores step counter and all RNG states.

Current Status

  • Core framework (trainer, config, callbacks): implemented and tested
  • WanModel (Wan 2.1 T2V): implemented and tested
  • WanGameModel (WanGame 2.1 I2V): implemented and tested
  • WanGameCausalModel (streaming): implemented and tested
  • WanCausalModel (Wan T2V causal): in progress
  • DMD2 method: implemented and tested
  • Self-Forcing method: implemented
  • SFT method: implemented and tested
  • DFSFT method: implemented and tested
  • DCP checkpointing + resume: implemented
  • EMA callback: implemented
  • Validation callback: implemented and tested

Any Other Things.

File Structure Reference

fastvideo/train/
  trainer.py                  # Training loop
  models/
    base.py                   # ModelBase, CausalModelBase ABCs
    wan/wan.py                # Wan 2.1 T2V model plugin
    wangame/wangame.py        # WanGame 2.1 I2V model plugin
    wangame/wangame_causal.py # WanGame causal (streaming) plugin
  methods/
    base.py                   # TrainingMethod ABC
    distribution_matching/
      dmd2.py                 # DMD2 distillation
      self_forcing.py         # Self-Forcing (causal DMD)
    fine_tuning/
      finetune.py             # Supervised fine-tuning
      dfsft.py                # Diffusion-forcing SFT
  callbacks/
    grad_clip.py              # Gradient clipping + norm logging
    validation.py             # Periodic inference validation
    ema.py                    # EMA weight averaging
  entrypoint/
    train.py                  # CLI entrypoint (torchrun)
  utils/
    config.py                 # YAML parser -> RunConfig
    builder.py                # build_from_config: model/method instantiation
    training_config.py        # TrainingConfig dataclass
    dataloader.py             # Dataset/dataloader construction
    optimizer.py              # Optimizer/scheduler construction
    checkpoint.py             # DCP save/resume
    tracking.py               # W&B tracker

Thanks to @jzhang38 for extensive discussion, review, and code modifications!
