[feat] Refactor training framework into fastvideo/train #1159
Merged: jzhang38 merged 31 commits into hao-ai-lab:main from FoundationResearch:train-clean-refactor on Mar 9, 2026.

Commits (31):
- 9da01fa  full training infra implementation (alexzms)
- c20b025  remove sampler_kind scheduler kind (alexzms)
- 36af66f  scrips and yamls (alexzms)
- 27fd7c1  fix ema with fsdp (alexzms)
- cb3f85c  fix wandb tracker (alexzms)
- 7427510  no redundant predict x0 in model specific impl (alexzms)
- 1dceae4  fix validation dmd timestep missing (alexzms)
- 7353a4f  causal wan init impl (alexzms)
- 05b40f8  move dev file position (alexzms)
- 57a0ee9  Merge branch 'main' of https://github.com/hao-ai-lab/FastVideo into t… (jzhang38)
- 4a56625  mv design to docs (jzhang38)
- 9abae81  update ema (jzhang38)
- 6ebcbdc  fix generator and reproducibility (jzhang38)
- 86dcdad  testing self forcing (alexzms)
- 3898fbc  training infra doc init impl (alexzms)
- 36bff05  generic cli override of config. remove seperate resume (alexzms)
- aff3174  ode init conversion, deterministic unset (alexzms)
- 0e6d8e9  remove deprecated rollout_mode param (alexzms)
- cfef92b  update agents (jzhang38)
- 498a712  validation remove deprecated dmd_steps (alexzms)
- ad4622a  fix ckpt resuming (alexzms)
- 0a0dc8f  Merge branch 'train-clean-refactor' of https://github.com/FoundationR… (jzhang38)
- 5409a6a  Merge branch 'hao-ai-lab:main' into train-clean-refactor (alexzms)
- 4512ad5  revert unused randomstate wrapper (alexzms)
- c32f616  precommit (alexzms)
- f222068  slurm run script (alexzms)
- 2449cd7  minor (alexzms)
- 17f1e38  + resume from checkpoint latest (jzhang38)
- b3ef4e0  Merge branch 'train-clean-refactor' of https://github.com/FoundationR… (jzhang38)
- 680bc8e  precommit (jzhang38)
- 0674a89  Merge remote-tracking branch 'origin/train-clean-refactor' into train… (alexzms)
New file: WanGame causal Diffusion-Forcing SFT (DFSFT) config.

```yaml
# V3 config: WanGame causal Diffusion-Forcing SFT (DFSFT).
#
# Uses _target_-based instantiation — each model role is an independent
# class instance; the method class is resolved directly from the YAML.

models:
  student:
    _target_: fastvideo.train.models.wangame.WanGameCausalModel
    init_from: /mnt/weka/home/hao.zhang/kaiqin/wg_models/WanGame-2.1-0223-9000steps
    trainable: true

method:
  _target_: fastvideo.train.methods.fine_tuning.dfsft.DiffusionForcingSFTMethod
  attn_kind: dense
  # use_ema: true
  chunk_size: 3
  min_timestep_ratio: 0.02
  max_timestep_ratio: 0.98

training:
  distributed:
    num_gpus: 8
    sp_size: 1
    tp_size: 1
    hsdp_replicate_dim: 8
    hsdp_shard_dim: 1

  data:
    data_path: >-
      /mnt/weka/home/hao.zhang/mhuo/traindata_0204_2130/preprocessed:0,
      /mnt/weka/home/hao.zhang/mhuo/traindata_0204_1600/preprocessed:0,
      /mnt/weka/home/hao.zhang/mhuo/traindata_0205_1330/data/0_static_plus_w_only/preprocessed:1,
      /mnt/weka/home/hao.zhang/mhuo/traindata_0205_1330/data/1_wasd_only/preprocessed:1,
      /mnt/weka/home/hao.zhang/mhuo/traindata_0206_1200/data/wasdonly_alpha1/preprocessed:1,
      /mnt/weka/home/hao.zhang/mhuo/traindata_0206_1200/data/camera/preprocessed:1,
      /mnt/weka/home/hao.zhang/mhuo/traindata_0208_2000/data/camera4hold_alpha1/preprocessed:1,
      /mnt/weka/home/hao.zhang/mhuo/traindata_0208_2000/data/wasd4holdrandview_simple_1key1mouse1/preprocessed:1
    dataloader_num_workers: 4
    train_batch_size: 1
    training_cfg_rate: 0.0
    seed: 1000
    num_latent_t: 20
    num_height: 352
    num_width: 640
    num_frames: 77

  optimizer:
    learning_rate: 1.0e-5
    betas: [0.9, 0.999]
    weight_decay: 1.0e-4
    lr_scheduler: constant
    lr_warmup_steps: 0

  loop:
    max_train_steps: 20000
    gradient_accumulation_steps: 1

  checkpoint:
    output_dir: outputs/wangame_dfsft_causal_v3
    training_state_checkpointing_steps: 1000
    checkpoints_total_limit: 2

  tracker:
    project_name: distillation_wangame_r
    run_name: wangame_dfsft_causal_v3

  model:
    enable_gradient_checkpointing_type: full

callbacks:
  grad_clip:
    _target_: fastvideo.train.callbacks.grad_clip.GradNormClipCallback
    max_grad_norm: 1.0
  # ema:
  #   _target_: fastvideo.train.callbacks.ema.EMACallback
  #   beta: 0.9999
  validation:
    _target_: fastvideo.train.callbacks.validation.ValidationCallback
    pipeline_target: fastvideo.pipelines.basic.wan.wangame_causal_dmd_pipeline.WanGameCausalDMDPipeline
    dataset_file: examples/training/finetune/WanGame2.1_1.3b_i2v/validation_random_8.json
    every_steps: 100
    sampling_steps: [40]
    rollout_mode: streaming
    sampler_kind: ode
    scheduler_target: fastvideo.models.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler
    guidance_scale: 1.0
    num_frames: 69

pipeline:
  flow_shift: 3
  sampler_kind: ode
```
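The config header above says each model role and the method class are resolved via `_target_`-based instantiation. A minimal sketch of how such Hydra-style resolution typically works is below; the helper name `instantiate_target` and the `DummyModel` stand-in are illustrative assumptions, not FastVideo's actual code.

```python
# Hedged sketch of Hydra-style `_target_` instantiation: resolve the
# dotted path to a class, then call it with the remaining keys as
# constructor kwargs. Names here are illustrative, not FastVideo's API.
import importlib
from typing import Any

def instantiate_target(spec: dict[str, Any], **extra_kwargs: Any) -> Any:
    """Instantiate spec["_target_"] with the remaining keys as kwargs."""
    spec = dict(spec)  # don't mutate the caller's config
    module_path, _, class_name = spec.pop("_target_").rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**spec, **extra_kwargs)

# Stand-in class so the sketch is self-contained and runnable:
class DummyModel:
    def __init__(self, init_from: str, trainable: bool = True):
        self.init_from, self.trainable = init_from, trainable

spec = {
    "_target_": f"{DummyModel.__module__}.{DummyModel.__qualname__}",
    "init_from": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    "trainable": False,
}
model = instantiate_target(spec)
```

Because only `_target_` is special, any role name (student, teacher, critic) can map to any class, which is what lets one YAML schema cover both SFT and distillation runs.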
New file: DMD2 distillation config for Wan 2.1 T2V 1.3B.

```yaml
# DMD2 distillation: Wan 2.1 T2V 1.3B (teacher 50-step -> student 4-step).
#
# - Teacher: frozen pretrained Wan 2.1 T2V 1.3B
# - Student: trainable, initialized from the same pretrained weights
# - Critic: trainable, initialized from the same pretrained weights
# - Validation: 4-step SDE sampling

models:
  student:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: true
  teacher:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: false
    disable_custom_init_weights: true
  critic:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: true
    disable_custom_init_weights: true

method:
  _target_: fastvideo.train.methods.distribution_matching.dmd2.DMD2Method
  rollout_mode: simulate
  generator_update_interval: 5
  real_score_guidance_scale: 4.5
  dmd_denoising_steps: [1000, 750, 500, 250]

  # Critic optimizer (required — no fallback to training.optimizer)
  fake_score_learning_rate: 8.0e-6
  fake_score_betas: [0.0, 0.999]
  fake_score_lr_scheduler: constant

training:
  distributed:
    num_gpus: 8
    sp_size: 1
    tp_size: 1
    hsdp_replicate_dim: 1
    hsdp_shard_dim: 8

  data:
    data_path: data/Wan-Syn_77x448x832_600k
    dataloader_num_workers: 4
    train_batch_size: 1
    training_cfg_rate: 0.0
    seed: 1000
    num_latent_t: 20
    num_height: 448
    num_width: 832
    num_frames: 77

  optimizer:
    learning_rate: 2.0e-6
    betas: [0.0, 0.999]
    weight_decay: 0.01
    lr_scheduler: constant
    lr_warmup_steps: 0

  loop:
    max_train_steps: 4000
    gradient_accumulation_steps: 1

  checkpoint:
    output_dir: outputs/wan2.1_dmd2_4steps
    training_state_checkpointing_steps: 1000
    checkpoints_total_limit: 3

  tracker:
    project_name: distillation_wan
    run_name: wan2.1_dmd2_4steps

  model:
    enable_gradient_checkpointing_type: full

callbacks:
  grad_clip:
    max_grad_norm: 1.0
  validation:
    pipeline_target: fastvideo.pipelines.basic.wan.wan_pipeline.WanPipeline
    dataset_file: examples/training/finetune/Wan2.1-VSA/Wan-Syn-Data/validation_4.json
    every_steps: 50
    sampling_steps: [4]
    sampler_kind: sde
    sampling_timesteps: [1000, 750, 500, 250]
    guidance_scale: 6.0

pipeline:
  flow_shift: 8
```
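The `generator_update_interval: 5` key above implies the two-time-scale cadence DMD2 is known for: the critic (fake score network) trains every step, while the generator (student) trains only on every fifth step. A minimal sketch of that scheduling logic, with an illustrative function name rather than FastVideo's actual loop:

```python
# Hedged sketch of the update cadence implied by
# `generator_update_interval: 5`. The critic updates every step; the
# generator updates on every Nth step. Illustrative, not FastVideo code.
from typing import Iterator

def dmd2_schedule(total_steps: int,
                  generator_update_interval: int) -> Iterator[tuple[int, bool]]:
    """Yield (step, update_generator) pairs for the training loop.

    On every step the caller would update the critic; it updates the
    generator only when update_generator is True.
    """
    for step in range(1, total_steps + 1):
        yield step, step % generator_update_interval == 0

updates = [gen for _, gen in dmd2_schedule(10, 5)]
# With interval 5 over 10 steps, the generator updates twice.
```

Giving the critic more updates per generator step keeps its score estimate fresh enough to provide a useful distribution-matching gradient, which is why the critic also gets its own (faster) optimizer via the `fake_score_*` keys.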
New file: full configuration reference for fastvideo.train.

```yaml
# ==============================================================================
# Full configuration reference for fastvideo.train
#
# Legend:
#   [TYPED]    — parsed into a typed dataclass; fields are validated with
#                defaults. Unknown keys are silently ignored.
#   [FREE]     — free-form dict passed as-is to the target class / method.
#                Keys depend on the _target_ class constructor / method_config.
#   [RESOLVED] — parsed by PipelineConfig.from_kwargs(); auto-populated from
#                the model's config files. Only scalar overrides are useful.
# ==============================================================================

# ------------------------------------------------------------------------------
# models: [FREE]
#
# Each role is instantiated via _target_(*, training_config=..., **kwargs).
# Keys here are constructor kwargs of the _target_ class (e.g. WanModel).
# You can define any role name (student, teacher, critic, etc.).
# ------------------------------------------------------------------------------
models:
  student:
    _target_: fastvideo.train.models.wan.WanModel  # required
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers    # required: HF repo or local path
    trainable: true                                # default: true
    disable_custom_init_weights: false             # default: false
    flow_shift: 3.0                                # default: 3.0
    enable_gradient_checkpointing_type: null       # default: null (falls back to training.model)

  teacher:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: false
    disable_custom_init_weights: true

  critic:
    _target_: fastvideo.train.models.wan.WanModel
    init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
    trainable: true
    disable_custom_init_weights: true

# ------------------------------------------------------------------------------
# method: [FREE]
#
# Instantiated via _target_(*, cfg=RunConfig, role_models=...).
# All keys besides _target_ are available in self.method_config (a plain dict).
# Keys depend entirely on the method class.
# ------------------------------------------------------------------------------
method:
  _target_: fastvideo.train.methods.distribution_matching.dmd2.DMD2Method  # required

  # --- DMD2-specific keys (read from self.method_config) ---
  rollout_mode: simulate                      # required: "simulate" or "data_latent"
  generator_update_interval: 5                # default: 1
  dmd_denoising_steps: [1000, 750, 500, 250]  # SDE timestep schedule

  # Critic optimizer (all required — no fallback)
  fake_score_learning_rate: 8.0e-6
  fake_score_betas: [0.0, 0.999]
  fake_score_lr_scheduler: constant

  # CFG conditioning policy (optional)
  # cfg_uncond:
  #   on_missing: error  # "error" or "ignore"
  #   text: keep         # "keep", "zero", "drop", "negative_prompt"
  #   image: keep        # "keep", "zero", "drop"
  #   action: keep       # "keep", "zero", "drop"

  # --- FineTuneMethod keys (if using finetune instead) ---
  # _target_: fastvideo.train.methods.fine_tuning.finetune.FineTuneMethod
  # attn_kind: vsa  # "dense" or "vsa"
  # use_ema: false

# ------------------------------------------------------------------------------
# training: [TYPED] -> TrainingConfig
#
# Every field below has a typed default. Unknown keys are ignored.
# ------------------------------------------------------------------------------
training:

  # --- training.distributed [TYPED] -> DistributedConfig ---
  distributed:
    num_gpus: 8            # default: 1
    tp_size: 1             # default: 1
    sp_size: 1             # default: 1 (defaults to num_gpus in loader)
    hsdp_replicate_dim: 1  # default: 1
    hsdp_shard_dim: 8      # default: -1 (defaults to num_gpus in loader)
    pin_cpu_memory: false  # default: false

  # --- training.data [TYPED] -> DataConfig ---
  data:
    data_path: data/my_dataset  # default: ""
    train_batch_size: 1         # default: 1
    dataloader_num_workers: 4   # default: 0
    training_cfg_rate: 0.1      # default: 0.0
    seed: 1000                  # default: 0
    num_height: 448             # default: 0
    num_width: 832              # default: 0
    num_latent_t: 20            # default: 0
    num_frames: 77              # default: 0

  # --- training.optimizer [TYPED] -> OptimizerConfig ---
  # Note: only for the student optimizer. Critic optimizer is in method config.
  optimizer:
    learning_rate: 2.0e-6   # default: 0.0
    betas: [0.9, 0.999]     # default: [0.9, 0.999]
    weight_decay: 0.01      # default: 0.0
    lr_scheduler: constant  # default: "constant"
    lr_warmup_steps: 0      # default: 0
    lr_num_cycles: 0        # default: 0
    lr_power: 0.0           # default: 0.0
    min_lr_ratio: 0.5       # default: 0.5

  # --- training.loop [TYPED] -> TrainingLoopConfig ---
  loop:
    max_train_steps: 10000          # default: 0
    gradient_accumulation_steps: 1  # default: 1

  # --- training.checkpoint [TYPED] -> CheckpointConfig ---
  checkpoint:
    output_dir: outputs/my_run                # default: ""
    resume_from_checkpoint: ""                # default: "" (or use --resume-from-checkpoint CLI)
    training_state_checkpointing_steps: 1000  # default: 0 (disabled)
    checkpoints_total_limit: 3                # default: 0 (keep all)

  # --- training.tracker [TYPED] -> TrackerConfig ---
  tracker:
    trackers: []              # default: [] (auto-adds "wandb" if project_name is set)
    project_name: my_project  # default: "fastvideo"
    run_name: my_run          # default: ""

  # --- training.vsa [TYPED] -> VSAConfig ---
  vsa:
    sparsity: 0.0            # default: 0.0 (0.0 = disabled)
    decay_rate: 0.0          # default: 0.0
    decay_interval_steps: 0  # default: 0

  # --- training.model [TYPED] -> ModelTrainingConfig ---
  model:
    weighting_scheme: uniform                 # default: "uniform"
    logit_mean: 0.0                           # default: 0.0
    logit_std: 1.0                            # default: 1.0
    mode_scale: 1.0                           # default: 1.0
    precondition_outputs: false               # default: false
    moba_config: {}                           # default: {}
    enable_gradient_checkpointing_type: full  # default: null ("full" or null)

  # --- training top-level [TYPED] ---
  dit_precision: fp32  # default: "fp32" (master weight precision)
  # model_path: ...    # default: "" (auto-derived from models.student.init_from)

# ------------------------------------------------------------------------------
# callbacks: [FREE]
#
# Each callback is instantiated via _target_(*, **kwargs).
# The callback name (e.g. "grad_clip") is arbitrary — only _target_ matters.
# training_config is injected automatically (not from YAML).
# ------------------------------------------------------------------------------
callbacks:

  # --- GradNormClipCallback ---
  grad_clip:
    _target_: fastvideo.train.callbacks.grad_clip.GradNormClipCallback  # optional if using default registry
    max_grad_norm: 1.0     # default: 0.0 (0.0 = disabled)
    log_grad_norms: false  # default: false

  # --- EMACallback ---
  # ema:
  #   _target_: fastvideo.train.callbacks.ema.EMACallback
  #   type: constant            # default: "constant" ("constant", "power", "halflife")
  #   beta: 0.9999              # default: 0.9999 (for constant type)
  #   gamma: 16.97              # default: 16.97 (for power type)
  #   ema_halflife_kimg: 500.0  # default: 500.0 (for halflife type)
  #   ema_rampup_ratio: 0.05    # default: 0.05 (for halflife type)
  #   start_iter: 0             # default: 0
  #   batch_size: 1             # default: 1

  # --- ValidationCallback ---
  validation:
    _target_: fastvideo.train.callbacks.validation.ValidationCallback       # optional if using default registry
    pipeline_target: fastvideo.pipelines.basic.wan.wan_pipeline.WanPipeline # required
    dataset_file: path/to/validation.json                                   # required
    every_steps: 100          # default: 100
    sampling_steps: [4]       # default: [40]
    sampler_kind: sde         # default: "ode" (use "sde" for few-step distilled models)
    scheduler_target: null    # default: null (_target_ for scheduler class, e.g.
                              #   fastvideo.models.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler
                              #   fastvideo.models.schedulers.scheduling_flow_unipc_multistep.FlowUniPCMultistepScheduler)
    guidance_scale: 5.0       # default: null (uses model default)
    num_frames: null          # default: null (derived from training.data)
    output_dir: null          # default: null (falls back to training.checkpoint.output_dir)
    sampling_timesteps: null  # default: null (explicit timestep list for SDE)
    rollout_mode: parallel    # default: "parallel" ("parallel" or "streaming")

# ------------------------------------------------------------------------------
# pipeline: [RESOLVED] -> PipelineConfig
#
# Parsed by PipelineConfig.from_kwargs(). Most fields are auto-populated from
# the model's config files (vae_config, dit_config, text_encoder_configs, etc.).
# Only scalar overrides are typically needed here.
# ------------------------------------------------------------------------------
pipeline:
  flow_shift: 3              # default: null (model-specific)
  # flow_shift_sr: null      # default: null (super-resolution shift)
  # embedded_cfg_scale: 6.0  # default: 6.0
  # is_causal: false         # default: false
  # vae_tiling: true         # default: true
  # vae_sp: true             # default: true
  # disable_autocast: false  # default: false
```
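One commit in this PR is titled "generic cli override of config", which suggests any key in the reference above can be overridden from the command line by a dotted path. The PR does not show the actual flag syntax, so the sketch below is an assumption about how such an override typically works, with an illustrative helper name:

```python
# Hedged sketch of a generic dotted-key config override, e.g. turning
# "training.optimizer.learning_rate=1e-5" into a nested dict update.
# The function name and syntax are assumptions, not FastVideo's CLI.
from typing import Any

def apply_override(config: dict[str, Any], dotted_key: str, value: Any) -> None:
    """Set config["a"]["b"]["c"] = value for dotted_key "a.b.c",
    creating intermediate dicts as needed."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value

cfg = {"training": {"optimizer": {"learning_rate": 2.0e-6}}}
apply_override(cfg, "training.optimizer.learning_rate", 1.0e-5)
apply_override(cfg, "training.loop.max_train_steps", 8000)
```

A scheme like this keeps the YAML as the single source of truth while letting sweep scripts vary individual scalars without generating one config file per run.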