Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
9da01fa
full training infra implementation
alexzms Mar 8, 2026
c20b025
remove sampler_kind scheduler kind
alexzms Mar 8, 2026
36af66f
scripts and yamls
alexzms Mar 8, 2026
27fd7c1
fix ema with fsdp
alexzms Mar 8, 2026
cb3f85c
fix wandb tracker
alexzms Mar 8, 2026
7427510
no redundant predict x0 in model specific impl
alexzms Mar 8, 2026
1dceae4
fix validation dmd timestep missing
alexzms Mar 8, 2026
7353a4f
causal wan init impl
alexzms Mar 8, 2026
05b40f8
move dev file position
alexzms Mar 8, 2026
57a0ee9
Merge branch 'main' of https://github.com/hao-ai-lab/FastVideo into t…
jzhang38 Mar 8, 2026
4a56625
mv design to docs
jzhang38 Mar 8, 2026
9abae81
update ema
jzhang38 Mar 8, 2026
6ebcbdc
fix generator and reproducibility
jzhang38 Mar 8, 2026
86dcdad
testing self forcing
alexzms Mar 8, 2026
3898fbc
training infra doc init impl
alexzms Mar 8, 2026
36bff05
generic cli override of config. remove separate resume
alexzms Mar 8, 2026
aff3174
ode init conversion, deterministic unset
alexzms Mar 8, 2026
0e6d8e9
remove deprecated rollout_mode param
alexzms Mar 9, 2026
cfef92b
update agents
jzhang38 Mar 9, 2026
498a712
validation remove deprecated dmd_steps
alexzms Mar 9, 2026
ad4622a
fix ckpt resuming
alexzms Mar 9, 2026
0a0dc8f
Merge branch 'train-clean-refactor' of https://github.com/FoundationR…
jzhang38 Mar 9, 2026
5409a6a
Merge branch 'hao-ai-lab:main' into train-clean-refactor
alexzms Mar 9, 2026
4512ad5
revert unused randomstate wrapper
alexzms Mar 9, 2026
c32f616
precommit
alexzms Mar 9, 2026
f222068
slurm run script
alexzms Mar 9, 2026
2449cd7
minor
alexzms Mar 9, 2026
17f1e38
+ resume from checkpoint latest
jzhang38 Mar 9, 2026
b3ef4e0
Merge branch 'train-clean-refactor' of https://github.com/FoundationR…
jzhang38 Mar 9, 2026
680bc8e
precommit
jzhang38 Mar 9, 2026
0674a89
Merge remote-tracking branch 'origin/train-clean-refactor' into train…
alexzms Mar 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions examples/train/dfsft_wangame_causal_v3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# V3 config: WanGame causal Diffusion-Forcing SFT (DFSFT).
#
# Uses _target_-based instantiation — each model role is an independent
# class instance; the method class is resolved directly from the YAML.

models:
  student:
    _target_: fastvideo.train.models.wangame.WanGameCausalModel
    init_from: /mnt/weka/home/hao.zhang/kaiqin/wg_models/WanGame-2.1-0223-9000steps
    trainable: true

method:
  _target_: fastvideo.train.methods.fine_tuning.dfsft.DiffusionForcingSFTMethod
  attn_kind: dense
  # use_ema: true
  chunk_size: 3
  # Sampled training timesteps are clipped to [0.02, 0.98] of the schedule.
  min_timestep_ratio: 0.02
  max_timestep_ratio: 0.98

training:
  distributed:
    num_gpus: 8
    sp_size: 1
    tp_size: 1
    hsdp_replicate_dim: 8
    hsdp_shard_dim: 1

  data:
    # Folded scalar: one comma-separated "path:weight" list (joined with spaces).
    data_path: >-
      /mnt/weka/home/hao.zhang/mhuo/traindata_0204_2130/preprocessed:0,
      /mnt/weka/home/hao.zhang/mhuo/traindata_0204_1600/preprocessed:0,
      /mnt/weka/home/hao.zhang/mhuo/traindata_0205_1330/data/0_static_plus_w_only/preprocessed:1,
      /mnt/weka/home/hao.zhang/mhuo/traindata_0205_1330/data/1_wasd_only/preprocessed:1,
      /mnt/weka/home/hao.zhang/mhuo/traindata_0206_1200/data/wasdonly_alpha1/preprocessed:1,
      /mnt/weka/home/hao.zhang/mhuo/traindata_0206_1200/data/camera/preprocessed:1,
      /mnt/weka/home/hao.zhang/mhuo/traindata_0208_2000/data/camera4hold_alpha1/preprocessed:1,
      /mnt/weka/home/hao.zhang/mhuo/traindata_0208_2000/data/wasd4holdrandview_simple_1key1mouse1/preprocessed:1
    dataloader_num_workers: 4
    train_batch_size: 1
    training_cfg_rate: 0.0
    seed: 1000
    num_latent_t: 20
    num_height: 352
    num_width: 640
    num_frames: 77

  optimizer:
    learning_rate: 1.0e-5
    betas: [0.9, 0.999]
    weight_decay: 1.0e-4
    lr_scheduler: constant
    lr_warmup_steps: 0

  loop:
    max_train_steps: 20000
    gradient_accumulation_steps: 1

  checkpoint:
    output_dir: outputs/wangame_dfsft_causal_v3
    training_state_checkpointing_steps: 1000
    checkpoints_total_limit: 2

  tracker:
    project_name: distillation_wangame_r
    run_name: wangame_dfsft_causal_v3

  model:
    enable_gradient_checkpointing_type: full

callbacks:
  grad_clip:
    _target_: fastvideo.train.callbacks.grad_clip.GradNormClipCallback
    max_grad_norm: 1.0
  # ema:
  #   _target_: fastvideo.train.callbacks.ema.EMACallback
  #   beta: 0.9999
  validation:
    _target_: fastvideo.train.callbacks.validation.ValidationCallback
    pipeline_target: fastvideo.pipelines.basic.wan.wangame_causal_dmd_pipeline.WanGameCausalDMDPipeline
    dataset_file: examples/training/finetune/WanGame2.1_1.3b_i2v/validation_random_8.json
    every_steps: 100
    sampling_steps: [40]
    rollout_mode: streaming
    sampler_kind: ode
    scheduler_target: fastvideo.models.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler
    guidance_scale: 1.0
    num_frames: 69

pipeline:
  flow_shift: 3
  # NOTE(review): sampler_kind is not listed in the pipeline section of
  # example.yaml — confirm PipelineConfig still consumes this key.
  sampler_kind: ode
91 changes: 91 additions & 0 deletions examples/train/distill_wan2.1_t2v_1.3B_dmd2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# DMD2 distillation: Wan 2.1 T2V 1.3B (teacher 50-step -> student 4-step).
#
# - Teacher: frozen pretrained Wan 2.1 T2V 1.3B
# - Student: trainable, initialized from the same pretrained weights
# - Critic: trainable, initialized from the same pretrained weights
# - Validation: 4-step SDE sampling

models:
  student:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: true
  teacher:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: false
    disable_custom_init_weights: true
  critic:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: true
    disable_custom_init_weights: true

method:
  _target_: fastvideo.train.methods.distribution_matching.dmd2.DMD2Method
  rollout_mode: simulate
  generator_update_interval: 5
  real_score_guidance_scale: 4.5
  dmd_denoising_steps: [1000, 750, 500, 250]

  # Critic optimizer (required — no fallback to training.optimizer)
  fake_score_learning_rate: 8.0e-6
  fake_score_betas: [0.0, 0.999]
  fake_score_lr_scheduler: constant

training:
  distributed:
    num_gpus: 8
    sp_size: 1
    tp_size: 1
    hsdp_replicate_dim: 1
    hsdp_shard_dim: 8

  data:
    data_path: data/Wan-Syn_77x448x832_600k
    dataloader_num_workers: 4
    train_batch_size: 1
    training_cfg_rate: 0.0
    seed: 1000
    num_latent_t: 20
    num_height: 448
    num_width: 832
    num_frames: 77

  # Student optimizer; the critic optimizer lives in the method block above.
  optimizer:
    learning_rate: 2.0e-6
    betas: [0.0, 0.999]
    weight_decay: 0.01
    lr_scheduler: constant
    lr_warmup_steps: 0

  loop:
    max_train_steps: 4000
    gradient_accumulation_steps: 1

  checkpoint:
    output_dir: outputs/wan2.1_dmd2_4steps
    training_state_checkpointing_steps: 1000
    checkpoints_total_limit: 3

  tracker:
    project_name: distillation_wan
    run_name: wan2.1_dmd2_4steps

  model:
    enable_gradient_checkpointing_type: full

callbacks:
  grad_clip:
    max_grad_norm: 1.0
  validation:
    pipeline_target: fastvideo.pipelines.basic.wan.wan_pipeline.WanPipeline
    dataset_file: examples/training/finetune/Wan2.1-VSA/Wan-Syn-Data/validation_4.json
    every_steps: 50
    sampling_steps: [4]
    sampler_kind: sde
    # SDE validation uses the same explicit schedule the student was distilled on.
    sampling_timesteps: [1000, 750, 500, 250]
    guidance_scale: 6.0

pipeline:
  flow_shift: 8
208 changes: 208 additions & 0 deletions examples/train/example.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
# ==============================================================================
# Full configuration reference for fastvideo.train
#
# Legend:
#   [TYPED]    — parsed into a typed dataclass; fields are validated with
#                defaults. Unknown keys are silently ignored.
#   [FREE]     — free-form dict passed as-is to the target class / method.
#                Keys depend on the _target_ class constructor / method_config.
#   [RESOLVED] — parsed by PipelineConfig.from_kwargs(); auto-populated from
#                the model's config files. Only scalar overrides are useful.
# ==============================================================================

# ------------------------------------------------------------------------------
# models: [FREE]
#
# Each role is instantiated via _target_(*, training_config=..., **kwargs).
# Keys here are constructor kwargs of the _target_ class (e.g. WanModel).
# You can define any role name (student, teacher, critic, etc.).
# ------------------------------------------------------------------------------
models:
  student:
    _target_: fastvideo.train.models.wan.WanModel  # required
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers    # required: HF repo or local path
    trainable: true                                # default: true
    disable_custom_init_weights: false             # default: false
    flow_shift: 3.0                                # default: 3.0
    enable_gradient_checkpointing_type: null       # default: null (falls back to training.model)

  teacher:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: false
    disable_custom_init_weights: true

  critic:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: true
    disable_custom_init_weights: true

# ------------------------------------------------------------------------------
# method: [FREE]
#
# Instantiated via _target_(*, cfg=RunConfig, role_models=...).
# All keys besides _target_ are available in self.method_config (a plain dict).
# Keys depend entirely on the method class.
# ------------------------------------------------------------------------------
method:
  _target_: fastvideo.train.methods.distribution_matching.dmd2.DMD2Method  # required

  # --- DMD2-specific keys (read from self.method_config) ---
  rollout_mode: simulate                      # required: "simulate" or "data_latent"
  generator_update_interval: 5                # default: 1
  dmd_denoising_steps: [1000, 750, 500, 250]  # SDE timestep schedule

  # Critic optimizer (all required — no fallback)
  fake_score_learning_rate: 8.0e-6
  fake_score_betas: [0.0, 0.999]
  fake_score_lr_scheduler: constant

  # CFG conditioning policy (optional)
  # cfg_uncond:
  #   on_missing: error  # "error" or "ignore"
  #   text: keep         # "keep", "zero", "drop", "negative_prompt"
  #   image: keep        # "keep", "zero", "drop"
  #   action: keep       # "keep", "zero", "drop"

  # --- FineTuneMethod keys (if using finetune instead) ---
  # _target_: fastvideo.train.methods.fine_tuning.finetune.FineTuneMethod
  # attn_kind: vsa   # "dense" or "vsa"
  # use_ema: false

# ------------------------------------------------------------------------------
# training: [TYPED] -> TrainingConfig
#
# Every field below has a typed default. Unknown keys are ignored.
# ------------------------------------------------------------------------------
training:

  # --- training.distributed [TYPED] -> DistributedConfig ---
  distributed:
    num_gpus: 8            # default: 1
    tp_size: 1             # default: 1
    sp_size: 1             # default: 1 (defaults to num_gpus in loader)
    hsdp_replicate_dim: 1  # default: 1
    hsdp_shard_dim: 8      # default: -1 (defaults to num_gpus in loader)
    pin_cpu_memory: false  # default: false

  # --- training.data [TYPED] -> DataConfig ---
  data:
    data_path: data/my_dataset  # default: ""
    train_batch_size: 1         # default: 1
    dataloader_num_workers: 4   # default: 0
    training_cfg_rate: 0.1      # default: 0.0
    seed: 1000                  # default: 0
    num_height: 448             # default: 0
    num_width: 832              # default: 0
    num_latent_t: 20            # default: 0
    num_frames: 77              # default: 0

  # --- training.optimizer [TYPED] -> OptimizerConfig ---
  # Note: only for the student optimizer. Critic optimizer is in method config.
  optimizer:
    learning_rate: 2.0e-6    # default: 0.0
    betas: [0.9, 0.999]      # default: [0.9, 0.999]
    weight_decay: 0.01       # default: 0.0
    lr_scheduler: constant   # default: "constant"
    lr_warmup_steps: 0       # default: 0
    lr_num_cycles: 0         # default: 0
    lr_power: 0.0            # default: 0.0
    min_lr_ratio: 0.5        # default: 0.5

  # --- training.loop [TYPED] -> TrainingLoopConfig ---
  loop:
    max_train_steps: 10000           # default: 0
    gradient_accumulation_steps: 1   # default: 1

  # --- training.checkpoint [TYPED] -> CheckpointConfig ---
  checkpoint:
    output_dir: outputs/my_run                # default: ""
    resume_from_checkpoint: ""                # default: "" (or use --resume-from-checkpoint CLI)
    training_state_checkpointing_steps: 1000  # default: 0 (disabled)
    checkpoints_total_limit: 3                # default: 0 (keep all)

  # --- training.tracker [TYPED] -> TrackerConfig ---
  tracker:
    trackers: []             # default: [] (auto-adds "wandb" if project_name is set)
    project_name: my_project # default: "fastvideo"
    run_name: my_run         # default: ""

  # --- training.vsa [TYPED] -> VSAConfig ---
  vsa:
    sparsity: 0.0            # default: 0.0 (0.0 = disabled)
    decay_rate: 0.0          # default: 0.0
    decay_interval_steps: 0  # default: 0

  # --- training.model [TYPED] -> ModelTrainingConfig ---
  model:
    weighting_scheme: uniform                 # default: "uniform"
    logit_mean: 0.0                           # default: 0.0
    logit_std: 1.0                            # default: 1.0
    mode_scale: 1.0                           # default: 1.0
    precondition_outputs: false               # default: false
    moba_config: {}                           # default: {}
    enable_gradient_checkpointing_type: full  # default: null ("full" or null)

  # --- training top-level [TYPED] ---
  dit_precision: fp32  # default: "fp32" (master weight precision)
  # model_path: ...    # default: "" (auto-derived from models.student.init_from)

# ------------------------------------------------------------------------------
# callbacks: [FREE]
#
# Each callback is instantiated via _target_(*, **kwargs).
# The callback name (e.g. "grad_clip") is arbitrary — only _target_ matters.
# training_config is injected automatically (not from YAML).
# ------------------------------------------------------------------------------
callbacks:

  # --- GradNormClipCallback ---
  grad_clip:
    _target_: fastvideo.train.callbacks.grad_clip.GradNormClipCallback  # optional if using default registry
    max_grad_norm: 1.0     # default: 0.0 (0.0 = disabled)
    log_grad_norms: false  # default: false

  # --- EMACallback ---
  # ema:
  #   _target_: fastvideo.train.callbacks.ema.EMACallback
  #   type: constant            # default: "constant" ("constant", "power", "halflife")
  #   beta: 0.9999              # default: 0.9999 (for constant type)
  #   gamma: 16.97              # default: 16.97 (for power type)
  #   ema_halflife_kimg: 500.0  # default: 500.0 (for halflife type)
  #   ema_rampup_ratio: 0.05    # default: 0.05 (for halflife type)
  #   start_iter: 0             # default: 0
  #   batch_size: 1             # default: 1

  # --- ValidationCallback ---
  validation:
    _target_: fastvideo.train.callbacks.validation.ValidationCallback  # optional if using default registry
    pipeline_target: fastvideo.pipelines.basic.wan.wan_pipeline.WanPipeline  # required
    dataset_file: path/to/validation.json  # required
    every_steps: 100          # default: 100
    sampling_steps: [4]       # default: [40]
    sampler_kind: sde         # default: "ode" (use "sde" for few-step distilled models)
    scheduler_target: null    # default: null (_target_ for scheduler class, e.g.
                              # fastvideo.models.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler
                              # fastvideo.models.schedulers.scheduling_flow_unipc_multistep.FlowUniPCMultistepScheduler)
    guidance_scale: 5.0       # default: null (uses model default)
    num_frames: null          # default: null (derived from training.data)
    output_dir: null          # default: null (falls back to training.checkpoint.output_dir)
    sampling_timesteps: null  # default: null (explicit timestep list for SDE)
    rollout_mode: parallel    # default: "parallel" ("parallel" or "streaming")

# ------------------------------------------------------------------------------
# pipeline: [RESOLVED] -> PipelineConfig
#
# Parsed by PipelineConfig.from_kwargs(). Most fields are auto-populated from
# the model's config files (vae_config, dit_config, text_encoder_configs, etc.).
# Only scalar overrides are typically needed here.
# ------------------------------------------------------------------------------
pipeline:
  flow_shift: 3               # default: null (model-specific)
  # flow_shift_sr: null       # default: null (super-resolution shift)
  # embedded_cfg_scale: 6.0   # default: 6.0
  # is_causal: false          # default: false
  # vae_tiling: true          # default: true
  # vae_sp: true              # default: true
  # disable_autocast: false   # default: false
Loading