Description
Motivation.
Training video diffusion models today involves a tangle of concerns: model loading, noise scheduling, distillation algorithms, distributed strategies, checkpointing, and validation. Existing training scripts tend to hard-wire these together, making it painful to:
- Try a new distillation algorithm on an existing model (requires forking the training loop).
- Add a new model to an existing algorithm (requires re-implementing boilerplate).
- Switch distributed strategies (FSDP, TP, SP) without touching algorithm code.
- Resume, checkpoint, and validate uniformly across all combinations.
Our goal is a single, extensible training framework where each axis of variation is an independent plugin.
Proposed Change.
Architecture Overview
```
                  YAML Config
                       |
                       v
+------------------+     +---------------------+     +------------------+
|   Models Layer   |     |    Methods Layer    |     |  Infrastructure  |
|   (per-role)     |     |     (algorithm)     |     |      Layer       |
|                  |     |                     |     |                  |
| - ModelBase      |<----| - TrainingMethod    |---->| - Trainer        |
| - CausalModelBase|     | - single_train_step |     | - Callbacks      |
|                  |     | - backward          |     | - Checkpoint     |
| Roles:           |     | - optimizers        |     | - Tracker (W&B)  |
|   student        |     |                     |     | - Dataloader     |
|   teacher        |     | Algorithms:         |     |                  |
|   critic         |     |  DMD2, SelfForcing, |     | Distributed:     |
|                  |     |  SFT, DFSFT         |     |  HSDP, TP, SP    |
+------------------+     +---------------------+     +------------------+
```
Three Layers
| Layer | Responsibility | Extension point |
|---|---|---|
| Models (`fastvideo/train/models/`) | Load transformer + scheduler; define `predict_noise`, `predict_x0`, `add_noise`, `backward`. Each training role (student/teacher/critic) is an independent instance. | Subclass `ModelBase` (or `CausalModelBase` for streaming). |
| Methods (`fastvideo/train/methods/`) | Implement the training algorithm: own the role models, define `single_train_step` + `backward`, manage optimizers/schedulers. | Subclass `TrainingMethod`. |
| Infrastructure (`fastvideo/train/trainer.py`, `utils/`, `callbacks/`) | Training loop, gradient accumulation, distributed setup, checkpointing (DCP), W&B tracking, validation, EMA, grad clipping. | Add callbacks; everything else is shared. |
YAML-Driven Configuration
Everything is configured declaratively. The `_target_` field selects the Python class to instantiate:
```yaml
models:
  student:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: true
  teacher:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: false
  critic:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: true

method:
  _target_: fastvideo.train.methods.distribution_matching.dmd2.DMD2Method
  rollout_mode: simulate
  dmd_denoising_steps: [1000, 850, 700, 550, 350, 275, 200, 125]
  generator_update_interval: 5
  real_score_guidance_scale: 3.5
  # ...

training:
  distributed: { num_gpus: 8, sp_size: 1, tp_size: 1 }
  data: { data_path: ..., num_latent_t: 20, num_frames: 77 }
  optimizer: { learning_rate: 2.0e-6, betas: [0.0, 0.999] }
  loop: { max_train_steps: 4000 }
  checkpoint: { output_dir: outputs/my_run }

callbacks:
  grad_clip: { max_grad_norm: 1.0 }
  validation: { pipeline_target: ..., every_steps: 100 }
```

To switch from DMD2 to SFT, change `method._target_` and remove the teacher/critic entries — no code changes.
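The `_target_` mechanism can be sketched in a few lines with `importlib`. The `instantiate` helper below is illustrative only; the actual resolution lives in `utils/builder.py` as `build_from_config`:

```python
import importlib


def instantiate(config: dict):
    """Resolve a config dict's `_target_` dotted path to a class and
    construct it with the remaining keys as keyword arguments.
    (Illustrative sketch; the real builder is build_from_config.)"""
    cfg = dict(config)  # don't mutate the caller's config
    module_path, _, class_name = cfg.pop("_target_").rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**cfg)


# Demo with a stand-in target from the standard library:
obj = instantiate({"_target_": "collections.Counter", "a": 2, "b": 1})
```

This is the same pattern Hydra popularized: the YAML stays purely declarative, and swapping an algorithm or model is a one-line config change.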
Model Abstraction
ModelBase — Standard (Bidirectional) Models
Every role gets its own ModelBase instance owning a transformer and noise_scheduler. The base class defines:
- `prepare_batch()` — Convert raw dataloader output into a forward-ready `TrainingBatch`.
- `add_noise()` — Apply forward-process noise at a given timestep.
- `predict_noise()` / `predict_x0()` — Run the transformer and return predictions.
- `backward()` — Backward pass that restores forward context (attention metadata, timesteps).
- `init_preprocessors()` — Lazy-load the VAE and build the dataloader (called only on the student).
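The interface above can be sketched as an ABC. Only the method names come from the list above; the signatures here are assumptions:

```python
from abc import ABC, abstractmethod


class ModelBase(ABC):
    """Per-role wrapper owning a transformer and a noise scheduler.
    Sketch of the interface described above; signatures are abbreviated."""

    trainable: bool = True

    @abstractmethod
    def prepare_batch(self, raw_batch):
        """Convert raw dataloader output into a forward-ready TrainingBatch."""

    @abstractmethod
    def add_noise(self, x0, noise, timestep):
        """Apply forward-process noise at the given timestep."""

    @abstractmethod
    def predict_noise(self, noisy, timestep, **cond):
        """Run the transformer and return a noise prediction."""

    @abstractmethod
    def predict_x0(self, noisy, timestep, **cond):
        """Run the transformer and return a clean-sample prediction."""

    def backward(self, loss):
        """Backward pass that restores forward context (attention metadata, timesteps)."""
        raise NotImplementedError

    def init_preprocessors(self):
        """Lazy-load the VAE and build the dataloader (student only)."""
```

A concrete model plugin (e.g. a Wan variant) subclasses this and fills in the transformer-specific pieces; the trainer and methods only ever see this surface.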
CausalModelBase — Streaming / Causal Models
Extends ModelBase with streaming inference primitives for causal video generation:
```python
class CausalModelBase(ModelBase):
    def clear_caches(self, *, cache_tag: str = "pos") -> None: ...
    def predict_noise_streaming(self, ..., cache_tag, store_kv, cur_start_frame) -> Tensor | None: ...
    def predict_x0_streaming(self, ..., cache_tag, store_kv, cur_start_frame) -> Tensor | None: ...
```

Key design: KV caches are internal to the model instance, keyed by `cache_tag`. The method controls when to store (`store_kv=True`) vs. read-only (`store_kv=False`), enabling block-by-block causal rollout during training.
Supported Training Methods
1. DMD2 (Distribution Matching Distillation)
Roles: student (trainable) + teacher (frozen) + critic (trainable)
The student learns to generate clean video in few steps by matching the teacher's score function, with a critic network providing a learned fake-score baseline.
- Rollout modes:
  - `simulate` — Student starts from pure noise and iteratively denoises through the full step schedule.
  - `data_latent` — Student denoises from a single randomly-noised data sample.
- Losses: Generator loss (DMD gradient) + critic flow-matching loss, with alternating updates (`generator_update_interval`).
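The alternating-update schedule can be sketched as follows. The convention that the critic updates every step while the generator updates every `generator_update_interval`-th step is an assumption inferred from the description above:

```python
def update_roles(step: int, generator_update_interval: int = 5) -> dict[str, bool]:
    """Which networks train at a given step under alternating DMD2 updates.
    Assumed convention: critic every step, generator every N-th step."""
    return {
        "critic": True,
        "student": step % generator_update_interval == 0,
    }


# Over 10 steps, the generator trains on steps 0 and 5 only:
schedule = [update_roles(s)["student"] for s in range(10)]
```

Training the critic more often than the generator keeps the learned fake-score baseline fresh relative to the moving student distribution.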
2. Self-Forcing (Causal DMD)
Roles: student (causal, trainable) + teacher (frozen) + critic (trainable)
Extends DMD2 for streaming/causal video generation. The key idea: during training, the student processes video in temporal chunks, using its own previously-denoised outputs as context for future chunks — simulating online autoregressive rollout.
- Video is split into blocks of `chunk_size` latent frames.
- Each block is denoised through the student's step schedule; a random early-exit step is sampled per block.
- After denoising a block, its output is fed back (with optional `context_noise`) as KV-cache context for subsequent blocks via `predict_noise_streaming(store_kv=True)`.
- Supports SDE and ODE sampling during rollout.
- Selective gradient control: `enable_gradient_in_rollout`, `start_gradient_frame`.
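The block scheduling described above can be sketched as follows (the function name and the flat `exit_steps` stand-in for the student's step schedule are illustrative):

```python
import random


def rollout_blocks(num_latent_frames: int, chunk_size: int, exit_steps: list[int]):
    """Sketch of Self-Forcing block scheduling: split the latent video into
    chunk_size-frame blocks and sample a random early-exit denoising step
    per block. Each block is later denoised with the previous blocks'
    outputs held in the KV cache."""
    blocks = []
    for start in range(0, num_latent_frames, chunk_size):
        end = min(start + chunk_size, num_latent_frames)
        blocks.append({
            "frames": (start, end),
            "exit_step": random.choice(exit_steps),  # random early exit per block
        })
    return blocks


blocks = rollout_blocks(num_latent_frames=20, chunk_size=8, exit_steps=[1000, 700, 350])
```

Sampling a different exit step per block exposes the student to the mixed-noise contexts it will see at streaming inference time.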
3. Supervised Fine-Tuning (SFT)
Roles: student only
Standard flow-matching loss between predicted and ground-truth noise/x0.
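As a scalar sketch under a rectified-flow parameterization (an assumption; the exact prediction target depends on the scheduler in use):

```python
def flow_matching_loss(pred_velocity: float, x0: float, noise: float) -> float:
    """Scalar sketch of a flow-matching SFT loss, assuming a rectified-flow
    parameterization: the noisy input is x_t = (1 - t) * x0 + t * noise and
    the regression target is the velocity v = noise - x0."""
    target = noise - x0
    return (pred_velocity - target) ** 2


# A model predicting the exact velocity incurs zero loss:
loss = flow_matching_loss(pred_velocity=-1.0, x0=1.0, noise=0.0)
```

In practice this is a mean-squared error over latent tensors with timesteps sampled per example; the scalar version only shows the target construction.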
4. Diffusion-Forcing SFT (DFSFT)
Roles: student only
SFT with inhomogeneous (per-chunk) timesteps — each temporal chunk in a video gets a different noise level. This trains the model to handle mixed-noise inputs, which is a prerequisite for causal/streaming inference where earlier frames are cleaner than later ones.
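The per-chunk timestep assignment can be sketched as follows (the uniform sampling scheme and function name are assumptions):

```python
import random


def sample_chunk_timesteps(num_latent_frames: int, chunk_size: int,
                           num_train_timesteps: int = 1000) -> list[int]:
    """Sketch of diffusion-forcing timestep assignment: every chunk_size-frame
    chunk shares one independently sampled noise level, so a single video
    mixes several noise levels in one forward pass."""
    timesteps = []
    for start in range(0, num_latent_frames, chunk_size):
        t = random.randrange(num_train_timesteps)  # one timestep per chunk
        end = min(start + chunk_size, num_latent_frames)
        timesteps.extend([t] * (end - start))      # broadcast over the chunk
    return timesteps


ts = sample_chunk_timesteps(num_latent_frames=10, chunk_size=4)
```

Because earlier chunks can carry lower noise levels than later ones, the model learns exactly the input statistics of causal streaming inference.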
Training Loop
The `Trainer` runs a standard loop with pluggable method and callbacks:

```python
for step in range(start_step, max_steps):
    for accum_iter in range(grad_accum_steps):
        batch = next(dataloader)
        loss_map, outputs, metrics = method.single_train_step(batch, step)
        method.backward(loss_map, outputs)
    callbacks.on_before_optimizer_step()   # grad clipping
    method.optimizers_schedulers_step()
    method.optimizers_zero_grad()
    callbacks.on_training_step_end()       # logging
    checkpoint_manager.maybe_save(step)
    callbacks.on_validation_begin()        # periodic inference
```
Callbacks
- GradNormClipCallback — Per-module gradient norm logging + global clipping.
- ValidationCallback — Periodic inference sampling with configurable pipeline, sampling steps, and guidance scale.
- EMACallback — Exponential moving average of student weights.
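The hook surface implied by the loop can be sketched as follows. The hook names follow the pseudocode in the training loop; signatures and the toy trainer state are assumptions:

```python
class Callback:
    """Sketch of the callback interface: no-op hooks that the Trainer
    invokes at fixed points in the loop."""
    def on_before_optimizer_step(self, trainer): ...
    def on_training_step_end(self, trainer): ...
    def on_validation_begin(self, trainer): ...


class GradNormClipCallback(Callback):
    """Toy global-norm clipping over plain-float 'gradients' to show where
    the hook fires; the real callback clips per-module tensors and logs norms."""

    def __init__(self, max_grad_norm: float = 1.0):
        self.max_grad_norm = max_grad_norm
        self.last_norm = None

    def on_before_optimizer_step(self, trainer):
        grads = trainer["grads"]
        norm = sum(g * g for g in grads) ** 0.5
        self.last_norm = norm  # logged in the real callback
        if norm > self.max_grad_norm:
            scale = self.max_grad_norm / norm
            trainer["grads"] = [g * scale for g in grads]


cb = GradNormClipCallback(max_grad_norm=1.0)
trainer = {"grads": [3.0, 4.0]}   # global norm 5.0, clipped down to 1.0
cb.on_before_optimizer_step(trainer)
```

Because clipping, logging, EMA, and validation all hang off these hooks, adding cross-cutting behavior never requires touching the loop or the algorithms.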
Checkpointing
- DCP (Distributed Checkpoint) format, compatible with FSDP/HSDP.
- Saves: model weights, optimizer states, scheduler states, RNG states (per role).
- Full resume support: auto-restores step counter and all RNG states.
Current Status
| Component | Status |
|---|---|
| Core framework (trainer, config, callbacks) | Implemented and tested |
| `WanModel` (Wan 2.1 T2V) | Implemented and tested |
| `WanGameModel` (WanGame 2.1 I2V) | Implemented and tested |
| `WanGameCausalModel` (streaming) | Implemented and tested |
| `WanCausalModel` (Wan T2V causal) | In progress |
| DMD2 method | Implemented and tested |
| Self-Forcing method | Implemented |
| SFT method | Implemented and tested |
| DFSFT method | Implemented and tested |
| DCP checkpointing + resume | Implemented |
| EMA callback | Implemented |
| Validation callback | Implemented and tested |
Feedback Period.
No response
CC List.
No response
Any Other Things.
File Structure Reference
```
fastvideo/train/
  trainer.py                     # Training loop
  models/
    base.py                      # ModelBase, CausalModelBase ABCs
    wan/wan.py                   # Wan 2.1 T2V model plugin
    wangame/wangame.py           # WanGame 2.1 I2V model plugin
    wangame/wangame_causal.py    # WanGame causal (streaming) plugin
  methods/
    base.py                      # TrainingMethod ABC
    distribution_matching/
      dmd2.py                    # DMD2 distillation
      self_forcing.py            # Self-Forcing (causal DMD)
    fine_tuning/
      finetune.py                # Supervised fine-tuning
      dfsft.py                   # Diffusion-forcing SFT
  callbacks/
    grad_clip.py                 # Gradient clipping + norm logging
    validation.py                # Periodic inference validation
    ema.py                       # EMA weight averaging
  entrypoint/
    train.py                     # CLI entrypoint (torchrun)
  utils/
    config.py                    # YAML parser -> RunConfig
    builder.py                   # build_from_config: model/method instantiation
    training_config.py           # TrainingConfig dataclass
    dataloader.py                # Dataset/dataloader construction
    optimizer.py                 # Optimizer/scheduler construction
    checkpoint.py                # DCP save/resume
    tracking.py                  # W&B tracker
```