diff --git a/.gitignore b/.gitignore index 05cc73ff3..80b999144 100644 --- a/.gitignore +++ b/.gitignore @@ -32,7 +32,9 @@ env **.pyc **.txt *.log +*.npy weights/ +slurm_outputs/ # SSIM test outputs fastvideo/tests/ssim/generated_videos/ diff --git a/docs/wangame/zero_init_fixes.md b/docs/wangame/zero_init_fixes.md new file mode 100644 index 000000000..3a397657f --- /dev/null +++ b/docs/wangame/zero_init_fixes.md @@ -0,0 +1,51 @@ +# Zero Initialization Fixes Summary + +## Problem +New parameters (`action_embedder`, `to_out_prope`) were not learning - weights stayed at zero after training. + +## Root Causes & Fixes + +### 1. FSDP Loader Overwriting Model Initialization + +**File:** `fastvideo/models/loader/fsdp_load.py` + +**Problem:** FSDP loader initialized ALL new parameters (not in checkpoint) with zeros, overwriting the model's `__init__` initialization. + +**Fix:** Added `KAIMING_INIT_PATTERNS` to selectively apply proper initialization: + +```python +ALLOWED_NEW_PARAM_PATTERNS = ["gate_compress", "proj_l", "to_out_prope", "action_embedder"] +KAIMING_INIT_PATTERNS = ["fc_in.weight", "lora_A"] # Input projections need non-zero init + +for new_param_name in unused_keys: + use_kaiming = any(pattern in new_param_name for pattern in KAIMING_INIT_PATTERNS) + if use_kaiming: + nn.init.kaiming_uniform_(tensor, a=math.sqrt(5)) # Non-zero for gradient flow + else: + torch.zeros_like(...) # Zero for output projections (residual behavior) +``` + +**Why:** +- Input projections (`fc_in.weight`) need non-zero weights for gradients to flow +- Output projections (`fc_out.weight`) should be zero-initialized for stable residual learning (ControlNet/adapter pattern) + +### 2. Attention Mask Shape Mismatch + +**File:** `fastvideo/models/dits/wangame/hyworld_action_module.py` + +**Problem:** Attention mask had shape `[B, L]` but query tensor had shape `[2*B, L, ...]` (rope + prope concatenated). The prope batch (second half) had no mask coverage → output was zeros. 
+ +**Fix:** + +```python +# Before (wrong): +attention_mask = torch.ones(batch_size, seq_len, ...) # [B, L] + +# After (correct): +attention_mask = torch.ones(batch_size * 2, seq_len, ...) # [2*B, L] +``` + +## Files Modified + +1. `fastvideo/models/loader/fsdp_load.py` - KAIMING_INIT_PATTERNS +2. `fastvideo/models/dits/wangame/hyworld_action_module.py` - attention mask shape diff --git a/examples/distill/SFWanGame2.1/distill_dmd.sh b/examples/distill/SFWanGame2.1/distill_dmd.sh new file mode 100644 index 000000000..c819ef1aa --- /dev/null +++ b/examples/distill/SFWanGame2.1/distill_dmd.sh @@ -0,0 +1,140 @@ +#!/bin/bash +#SBATCH --job-name=t2v +#SBATCH --partition=main +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:1 +#SBATCH --cpus-per-task=128 +#SBATCH --mem=1440G +#SBATCH --output=dmd_t2v_output/t2v_%j.out +#SBATCH --error=dmd_t2v_output/t2v_%j.err +#SBATCH --exclusive + +# Basic Info +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +# different cache dir for different processes +# export TRITON_CACHE_DIR=/tmp/triton_cache_${SLURM_PROCID} +export MASTER_PORT=29503 +export TOKENIZERS_PARALLELISM=false +# WANDB_API_KEY must be supplied by the submitting environment (or via `wandb login`); never commit credentials +export WANDB_BASE_URL="https://api.wandb.ai" +export WANDB_MODE=online +export FASTVIDEO_ATTENTION_BACKEND=FLASH_ATTN + +# Configs +NUM_GPUS=64 + +# Model paths for Self-Forcing DMD distillation: +GENERATOR_MODEL_PATH="../WanGame-2.1" +REAL_SCORE_MODEL_PATH="../WanGame-2.1" # Teacher model +FAKE_SCORE_MODEL_PATH="../WanGame-2.1" # Critic model (initialized from the teacher checkpoint) + +DATA_DIR="../FastvideoWorldModel-MC/preprocessed" +VALIDATION_DATASET_FILE="examples/distill/SFWanGame2.1/validation.json" + +training_args=( + --tracker_project_name wangame_distill_self_forcing_dmd + --output_dir "checkpoints/wangame_distill_self_forcing_dmd" + --wandb_run_name "0202_1010_steps2000_bs_64" + --max_train_steps 500 + --train_batch_size 1 + --train_sp_batch_size 1 + 
--gradient_accumulation_steps 1 + --num_latent_t 21 + --num_height 352 + --num_width 640 + --enable_gradient_checkpointing_type "full" + --log_visualization + --simulate_generator_forward + --num_frames 81 + --num_frame_per_block 3 # Frame generation block size for self-forcing + --enable_gradient_masking + --gradient_mask_last_n_frames 21 + # --resume_from_checkpoint "checkpoints/wangame_distill_self_forcing_dmd/checkpoint-100" +) + +parallel_args=( + --num_gpus $NUM_GPUS # 64 + --sp_size 1 + --tp_size 1 + --hsdp_replicate_dim 1 # 64 + --hsdp_shard_dim $NUM_GPUS +) + +model_args=( + --model_path $GENERATOR_MODEL_PATH # TODO: check if you can remove this in this script + --pretrained_model_name_or_path $GENERATOR_MODEL_PATH + --real_score_model_path $REAL_SCORE_MODEL_PATH + --fake_score_model_path $FAKE_SCORE_MODEL_PATH +) + +dataset_args=( + --data_path "$DATA_DIR" + --dataloader_num_workers 4 +) + +validation_args=( + --log_validation + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 50 + --validation_sampling_steps "4" + --validation_guidance_scale "6.0" # not used for dmd inference +) + +optimizer_args=( + --learning_rate 6e-6 + --mixed_precision "bf16" + --training_state_checkpointing_steps 50 + --weight_only_checkpointing_steps 50 + --weight_decay 0.01 + --betas '0.0,0.999' + --max_grad_norm 1.0 +) + +miscellaneous_args=( + --inference_mode False + --checkpoints_total_limit 3 + --training_cfg_rate 0.0 + --dit_precision "fp32" + --flow_shift 5 + --seed 1000 + --use_ema True + --ema_decay 0.99 + --ema_start_step 100 + --init_weights_from_safetensors "checkpoints/wangame_ode_init_64gpu/checkpoint-2000/transformer" +) + +dmd_args=( + --dmd_denoising_steps '1000,750,500,250' + --min_timestep_ratio 0.02 + --max_timestep_ratio 0.98 + --dfake_gen_update_ratio 5 + --real_score_guidance_scale 3.0 + --fake_score_learning_rate 8e-6 + --fake_score_betas '0.0,0.999' + --warp_denoising_step +) + +self_forcing_args=( + --independent_first_frame 
False # Whether to treat first frame independently + --same_step_across_blocks True # Whether to use same denoising step across all blocks + --last_step_only False # Whether to only use the last denoising step + --context_noise 0 # Amount of noise to add during context caching (0 = no noise) +) + +torchrun \ +--nnodes 1 \ +--master_port $MASTER_PORT \ +--nproc_per_node $NUM_GPUS \ + fastvideo/training/wangame_self_forcing_distillation_pipeline.py \ + "${parallel_args[@]}" \ + "${model_args[@]}" \ + "${dataset_args[@]}" \ + "${training_args[@]}" \ + "${optimizer_args[@]}" \ + "${validation_args[@]}" \ + "${miscellaneous_args[@]}" \ + "${dmd_args[@]}" \ + "${self_forcing_args[@]}" diff --git a/examples/distill/SFWanGame2.1/distill_dmd.slurm b/examples/distill/SFWanGame2.1/distill_dmd.slurm new file mode 100644 index 000000000..b3ec595d3 --- /dev/null +++ b/examples/distill/SFWanGame2.1/distill_dmd.slurm @@ -0,0 +1,161 @@ +#!/bin/bash +#SBATCH --job-name=wg-sf +#SBATCH --partition=main +#SBATCH --nodes=4 +#SBATCH --ntasks=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=128 +#SBATCH --mem=1440G +#SBATCH --output=log/sf_train_output/ode_%j.out +#SBATCH --error=log/sf_train_output/ode_%j.err +#SBATCH --exclusive + +set -e -x + +# Environment Setup +source ~/conda/miniconda/bin/activate +conda activate /mnt/weka/home/hao.zhang/conda/miniconda/envs/mhuo-fv +export PYTHONPATH="/mnt/weka/home/hao.zhang/kaiqin/FastVideo:$PYTHONPATH" + +# Basic Info +# WANDB_API_KEY must be supplied by the submitting environment (or via `wandb login`); never commit it, and `set -x` would echo it into the job log +export WANDB_MODE="online" +export NCCL_P2P_DISABLE=1 +export MASTER_PORT=29500 +export NODE_RANK=$SLURM_PROCID +nodes=( $(scontrol show hostnames $SLURM_JOB_NODELIST) ) +export MASTER_ADDR=${nodes[0]} +export TOKENIZERS_PARALLELISM=false + +echo "MASTER_ADDR: $MASTER_ADDR" +echo "NODE_RANK: $NODE_RANK" + +RUN_NAME=$(date +"%m%d_%H%M") +echo "RUN_NAME: $RUN_NAME" + +# Model paths for Self-Forcing DMD distillation: +# 
GENERATOR_MODEL_PATH="../wg_models/WanGame-2.1-Student-VizDoom1k-1000steps-Diffusers" +# GENERATOR_MODEL_PATH="../wg_models/SFWanGame-2.1-0223-9000steps" +GENERATOR_MODEL_PATH="../wg_models/SFWanGame-2.1-0224-4k5steps" +REAL_SCORE_MODEL_PATH="../wg_models/WanGame-2.1-0223-9000steps" # Teacher model +FAKE_SCORE_MODEL_PATH="../wg_models/WanGame-2.1-0223-9000steps" # Critic model + +# DATA_DIR="../traindata_0222_0030/ode_init_mc_with_mouse/preprocessed_wangame" +DATA_DIR="../traindata_0222_0030/ode_init_mc_Xonly_3k/preprocessed" +VALIDATION_DATASET_FILE="examples/training/consistency_finetune/causal_wangame_ode_init/validation.json" + +# Training arguments +training_args=( + --tracker_project_name "wangame_sf" + --output_dir "checkpoints/wangame_sf_${RUN_NAME}" + --wandb_run_name "${RUN_NAME}_bs32" + --max_train_steps 3000 + --train_batch_size 1 + --train_sp_batch_size 1 + --gradient_accumulation_steps 1 + --num_latent_t 21 + --num_height 352 + --num_width 640 + --enable_gradient_checkpointing_type "full" + --log_visualization + --simulate_generator_forward + --num_frames 81 + --num_frame_per_block 3 # Frame generation block size for self-forcing + --enable_gradient_masking + --gradient_mask_last_n_frames 21 + # --init_weights_from_safetensors $CKPT_SAFETENSOR +) + +# Parallel arguments +parallel_args=( + --num_gpus 32 + --sp_size 1 + --tp_size 1 + --hsdp_replicate_dim 1 + --hsdp_shard_dim 32 +) + +model_args=( + --model_path $GENERATOR_MODEL_PATH # TODO: check if you can remove this in this script + --pretrained_model_name_or_path $GENERATOR_MODEL_PATH + --real_score_model_path $REAL_SCORE_MODEL_PATH + --fake_score_model_path $FAKE_SCORE_MODEL_PATH +) + +dataset_args=( + --data_path "$DATA_DIR" + --dataloader_num_workers 4 +) + +# Validation arguments +validation_args=( + --log_validation + --log_visualization + --visualization-steps 100 + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 100 + --validation_sampling_steps "4" + 
--validation_guidance_scale "6.0" +) + +# Optimizer arguments +optimizer_args=( + --learning_rate 6e-6 + --mixed_precision "bf16" + --weight_only_checkpointing_steps 100 + --training_state_checkpointing_steps 100 + --weight_decay 0.01 + --betas '0.0,0.999' + --max_grad_norm 1.0 +) + +# Miscellaneous arguments +miscellaneous_args=( + --inference_mode False + --checkpoints_total_limit 3 + --training_cfg_rate 0.0 + --dit_precision "fp32" + --flow_shift 5 + --seed 1000 + --use_ema True + --ema_decay 0.99 + --ema_start_step 200 +) + +dmd_args=( + --dmd_denoising_steps '1000,750,500,250' + --min_timestep_ratio 0.02 + --max_timestep_ratio 0.98 + --dfake_gen_update_ratio 5 + --real_score_guidance_scale 3.0 + --fake_score_learning_rate 8e-6 + --fake_score_betas '0.0,0.999' + --warp_denoising_step +) + +self_forcing_args=( + --independent_first_frame False # Whether to treat first frame independently + --same_step_across_blocks True # Whether to use same denoising step across all blocks + --last_step_only False # Whether to only use the last denoising step + --context_noise 0 # Amount of noise to add during context caching (0 = no noise) +) + +mkdir -p log/sf_train_output + +srun torchrun \ +--nnodes $SLURM_JOB_NUM_NODES \ +--nproc_per_node 8 \ +--node_rank $SLURM_PROCID \ +--rdzv_backend=c10d \ +--rdzv_endpoint="$MASTER_ADDR:$MASTER_PORT" \ + fastvideo/training/wangame_self_forcing_distillation_pipeline.py \ + "${parallel_args[@]}" \ + "${model_args[@]}" \ + "${dataset_args[@]}" \ + "${training_args[@]}" \ + "${optimizer_args[@]}" \ + "${validation_args[@]}" \ + "${miscellaneous_args[@]}" \ + "${dmd_args[@]}" \ + "${self_forcing_args[@]}" diff --git a/examples/distill/SFWanGame2.1/validation.json b/examples/distill/SFWanGame2.1/validation.json new file mode 100644 index 000000000..d97352dd7 --- /dev/null +++ b/examples/distill/SFWanGame2.1/validation.json @@ -0,0 +1,164 @@ +{ + "data": [ + { + "caption": "Hold [W] + Static", + "image_path": 
"/mnt/weka/home/hao.zhang/kaiqin/vizdoom/gen/validate/000000.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/FastVideo_kaiqin/examples/training/finetune/MatrixGame2.0/action/000000_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "Hold [S] + Static", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/vizdoom/gen/validate/000000.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/FastVideo_kaiqin/examples/training/finetune/MatrixGame2.0/action/000001_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "Hold [A] + Static", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/vizdoom/gen/validate/000000.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/FastVideo_kaiqin/examples/training/finetune/MatrixGame2.0/action/000002_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "Hold [D] + Static", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/vizdoom/gen/validate/000000.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/FastVideo_kaiqin/examples/training/finetune/MatrixGame2.0/action/000003_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "Hold [W] + Static", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/vizdoom/gen/validate/000001.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/FastVideo_kaiqin/examples/training/finetune/MatrixGame2.0/action/000000_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "Hold [S] + Static", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/vizdoom/gen/validate/000001.jpg", + "action_path": 
"/mnt/weka/home/hao.zhang/kaiqin/FastVideo_kaiqin/examples/training/finetune/MatrixGame2.0/action/000001_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "Hold [A] + Static", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/vizdoom/gen/validate/000001.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/FastVideo_kaiqin/examples/training/finetune/MatrixGame2.0/action/000002_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "Hold [D] + Static", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/vizdoom/gen/validate/000001.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/FastVideo_kaiqin/examples/training/finetune/MatrixGame2.0/action/000003_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "Hold [W] + Static", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/vizdoom/gen/validate/000002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/FastVideo_kaiqin/examples/training/finetune/MatrixGame2.0/action/000000_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "Hold [S] + Static", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/vizdoom/gen/validate/000002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/FastVideo_kaiqin/examples/training/finetune/MatrixGame2.0/action/000001_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "Hold [A] + Static", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/vizdoom/gen/validate/000002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/FastVideo_kaiqin/examples/training/finetune/MatrixGame2.0/action/000002_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, 
+ "width": 640, + "num_frames": 81 + }, + { + "caption": "Hold [D] + Static", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/vizdoom/gen/validate/000002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/FastVideo_kaiqin/examples/training/finetune/MatrixGame2.0/action/000003_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "Hold [W] + Static", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/vizdoom/gen/validate/000003.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/FastVideo_kaiqin/examples/training/finetune/MatrixGame2.0/action/000000_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "Hold [S] + Static", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/vizdoom/gen/validate/000003.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/FastVideo_kaiqin/examples/training/finetune/MatrixGame2.0/action/000001_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "Hold [A] + Static", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/vizdoom/gen/validate/000003.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/FastVideo_kaiqin/examples/training/finetune/MatrixGame2.0/action/000002_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "Hold [D] + Static", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/vizdoom/gen/validate/000003.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/FastVideo_kaiqin/examples/training/finetune/MatrixGame2.0/action/000003_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + } + ] +} \ No newline at end of file diff --git a/examples/distill/WanGame2.1/distill_dmd.slurm b/examples/distill/WanGame2.1/distill_dmd.slurm new 
file mode 100644 index 000000000..6e4a800e0 --- /dev/null +++ b/examples/distill/WanGame2.1/distill_dmd.slurm @@ -0,0 +1,146 @@ +#!/bin/bash +#SBATCH --job-name=wg-dmd +#SBATCH --partition=main +#SBATCH --nodes=4 +#SBATCH --ntasks=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=128 +#SBATCH --mem=1440G +#SBATCH --output=log/dmd_train_output/dmd_%j.out +#SBATCH --error=log/dmd_train_output/dmd_%j.err +#SBATCH --exclusive + +set -e -x + +# Environment Setup +source ~/conda/miniconda/bin/activate +conda activate /mnt/weka/home/hao.zhang/conda/miniconda/envs/mhuo-fv +export PYTHONPATH="/mnt/weka/home/hao.zhang/kaiqin/FastVideo:$PYTHONPATH" + +# Basic Info +# WANDB_API_KEY must be supplied by the submitting environment (or via `wandb login`); never commit it, and `set -x` would echo it into the job log +export WANDB_MODE="online" +export NCCL_P2P_DISABLE=1 +export MASTER_PORT=29500 +export NODE_RANK=$SLURM_PROCID +nodes=( $(scontrol show hostnames $SLURM_JOB_NODELIST) ) +export MASTER_ADDR=${nodes[0]} +export TOKENIZERS_PARALLELISM=false + +echo "MASTER_ADDR: $MASTER_ADDR" +echo "NODE_RANK: $NODE_RANK" + +RUN_NAME=$(date +"%m%d_%H%M") +echo "RUN_NAME: $RUN_NAME" + +GENERATOR_MODEL_PATH="../wg_models/WanGame-2.1-0223-9000steps" +REAL_SCORE_MODEL_PATH="../wg_models/WanGame-2.1-0223-9000steps" # Teacher model +FAKE_SCORE_MODEL_PATH="../wg_models/WanGame-2.1-0223-9000steps" # Critic model + +DATA_DIR="../traindata_0208_2000/data/wasd4holdrandview_simple_1key1mouse1/preprocessed" +VALIDATION_DATASET_FILE="examples/distill/WanGame2.1/validation.json" + +# Training arguments +training_args=( + --tracker_project_name "wangame_dmd" + --output_dir "checkpoints/wangame_dmd_${RUN_NAME}" + --wandb_run_name "${RUN_NAME}_dmd" + --max_train_steps 3000 + --train_batch_size 1 + --train_sp_batch_size 1 + --gradient_accumulation_steps 1 + --num_latent_t 20 + --num_height 352 + --num_width 640 + --num_frames 77 + --enable_gradient_checkpointing_type "full" + --training_state_checkpointing_steps 500 + --weight_only_checkpointing_steps 500 +) + 
+# Parallel arguments +parallel_args=( + --num_gpus 32 + --sp_size 1 + --tp_size 1 + --hsdp_replicate_dim 1 + --hsdp_shard_dim 32 +) + +model_args=( + --model_path $GENERATOR_MODEL_PATH + --pretrained_model_name_or_path $GENERATOR_MODEL_PATH + --real_score_model_path $REAL_SCORE_MODEL_PATH + --fake_score_model_path $FAKE_SCORE_MODEL_PATH +) + +dataset_args=( + --data_path "$DATA_DIR" + --dataloader_num_workers 4 +) + +# Validation arguments +validation_args=( + --log_validation + --log_visualization + --visualization-steps 200 + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 200 + --validation_sampling_steps "4" + --validation_guidance_scale "6.0" +) + +# Optimizer arguments +optimizer_args=( + --learning_rate 2e-6 + --mixed_precision "bf16" + --weight_decay 0.01 + --betas '0.0,0.999' + --max_grad_norm 1.0 + --fake_score_learning_rate 8e-6 + --fake_score_betas '0.0,0.999' +) + +# Miscellaneous arguments +miscellaneous_args=( + --inference_mode False + --checkpoints_total_limit 3 + --training_cfg_rate 0.0 + --dit_precision "fp32" + --flow_shift 5 + --seed 1000 + --use_ema True + --ema_decay 0.99 + --ema_start_step 200 +) + +# DMD-specific arguments +dmd_args=( + --dmd_denoising_steps '1000,750,500,250' + --min_timestep_ratio 0.02 + --max_timestep_ratio 0.98 + --dfake_gen_update_ratio 5 + --real_score_guidance_scale 3.0 + --fake_score_learning_rate 8e-6 + --fake_score_betas '0.0,0.999' + --warp_denoising_step +) + +mkdir -p log/dmd_train_output + +srun torchrun \ +--nnodes $SLURM_JOB_NUM_NODES \ +--nproc_per_node 8 \ +--node_rank $SLURM_PROCID \ +--rdzv_backend=c10d \ +--rdzv_endpoint="$MASTER_ADDR:$MASTER_PORT" \ + fastvideo/training/wangame_distillation_pipeline.py \ + "${parallel_args[@]}" \ + "${model_args[@]}" \ + "${dataset_args[@]}" \ + "${training_args[@]}" \ + "${optimizer_args[@]}" \ + "${validation_args[@]}" \ + "${miscellaneous_args[@]}" \ + "${dmd_args[@]}" diff --git a/examples/distill/WanGame2.1/validation.json 
b/examples/distill/WanGame2.1/validation.json new file mode 100644 index 000000000..2012d50fe --- /dev/null +++ b/examples/distill/WanGame2.1/validation.json @@ -0,0 +1,324 @@ +{ + "data": [ + { + "caption": "51", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000051.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000051_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "229", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000229.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000229_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "250", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000250.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000250_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "380", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000380.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000380_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "382", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000382.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000382_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + 
"caption": "387", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000387.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000387_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "418", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000418.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000418_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "505", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000505.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000505_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "515", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000515.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000515_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "534", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000534.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000534_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "599", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000599.jpg", + "action_path": 
"/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000599_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "613", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000613.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000613_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "745", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000745.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000745_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "861", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000861.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000861_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "940", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000940.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000940_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "946", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000946.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000946_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + 
"num_frames": 77 + }, + { + "caption": "996", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000996.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000996_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "1011", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001011.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001011_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "1037", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001037.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001037_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "1057", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001057.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001057_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "1195", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001195.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001195_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "1236", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001236.jpg", + "action_path": 
"/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001236_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "1276", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001276.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001276_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "1368", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001368.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001368_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "1403", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001403.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001403_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "1417", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001417.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001417_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "1481", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001481.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001481_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + 
"num_frames": 77 + }, + { + "caption": "1489", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001489.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001489_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "1618", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001618.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001618_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "1779", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001779.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001779_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "1867", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001867.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001867_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "1949", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001949.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001949_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + } + ] +} \ No newline at end of file diff --git a/examples/inference/basic/basic_causal_wangame.py b/examples/inference/basic/basic_causal_wangame.py new file mode 100644 
index 000000000..f18056945 --- /dev/null +++ b/examples/inference/basic/basic_causal_wangame.py @@ -0,0 +1,49 @@ +from fastvideo import VideoGenerator +from fastvideo.configs.pipelines import SelfForcingWanGameI2V480PConfig +from fastvideo.models.dits.matrixgame.utils import create_action_presets +import torch + +BASE_MODEL_PATH = "Wan2.1-Fun-1.3B-InP-Diffusers" +WEIGHTS_PATH = "checkpoints/wangame_ode_init/checkpoint-1200/transformer" + +OUTPUT_PATH = "video_samples_wangame" +IMAGE_PATH = "https://raw.githubusercontent.com/SkyworkAI/Matrix-Game/main/Matrix-Game-2/demo_images/universal/0000.png" + + +def main(): + generator = VideoGenerator.from_pretrained( + BASE_MODEL_PATH, + pipeline_config=SelfForcingWanGameI2V480PConfig(), + num_gpus=1, + use_fsdp_inference=False, + dit_cpu_offload=False, + vae_cpu_offload=False, + text_encoder_cpu_offload=True, + pin_cpu_memory=True, + override_pipeline_cls_name="WanGameCausalDMDPipeline", + override_transformer_cls_name="CausalWanGameActionTransformer3DModel", + init_weights_from_safetensors=WEIGHTS_PATH, + ) + + num_frames = 81 + actions = create_action_presets(num_frames, keyboard_dim=4) + actions["keyboard"] = torch.tensor([[1.0, 0.0, 0.0, 0.0]] * num_frames) + actions["mouse"] = torch.tensor([[0.0, 0.0]] * num_frames) + + generator.generate_video( + prompt="", + image_path=IMAGE_PATH, + mouse_cond=actions["mouse"].unsqueeze(0), + keyboard_cond=actions["keyboard"].unsqueeze(0), + num_frames=num_frames, + height=352, + width=640, + num_inference_steps=40, + guidance_scale=1.0, + output_path=OUTPUT_PATH, + save_video=True, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/inference/basic/basic_wangame.py b/examples/inference/basic/basic_wangame.py new file mode 100644 index 000000000..d5154f824 --- /dev/null +++ b/examples/inference/basic/basic_wangame.py @@ -0,0 +1,46 @@ +from fastvideo import VideoGenerator +from fastvideo.configs.pipelines import WanGameI2V480PConfig +from 
fastvideo.models.dits.matrixgame.utils import create_action_presets + +BASE_MODEL_PATH = "Wan2.1-Fun-1.3B-InP-Diffusers" +# WEIGHTS_PATH = "wangame_1.3b_overfit/checkpoint-10000/transformer/diffusion_pytorch_model.safetensors" + +OUTPUT_PATH = "video_samples_wangame" +IMAGE_PATH = "/mnt/fast-disks/hao_lab/kaiqin/traindata_0209_1500/ode_init_mc/images/000000.jpg" + + +def main(): + generator = VideoGenerator.from_pretrained( + BASE_MODEL_PATH, + pipeline_config=WanGameI2V480PConfig(), + num_gpus=1, + use_fsdp_inference=False, + dit_cpu_offload=False, + vae_cpu_offload=False, + text_encoder_cpu_offload=True, + pin_cpu_memory=True, + override_pipeline_cls_name="WanGameActionImageToVideoPipeline", + override_transformer_cls_name="WanGameActionTransformer3DModel", + # init_weights_from_safetensors=WEIGHTS_PATH, + ) + + num_frames = 77 + actions = create_action_presets(num_frames, keyboard_dim=4) + + generator.generate_video( + prompt="", + image_path=IMAGE_PATH, + mouse_cond=actions["mouse"].unsqueeze(0), + keyboard_cond=actions["keyboard"].unsqueeze(0), + num_frames=num_frames, + height=352, + width=640, + num_inference_steps=40, + guidance_scale=1.0, + output_path=OUTPUT_PATH, + save_video=True, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/inference/basic/basic_wangame_lingbot.py b/examples/inference/basic/basic_wangame_lingbot.py new file mode 100644 index 000000000..b30d0f562 --- /dev/null +++ b/examples/inference/basic/basic_wangame_lingbot.py @@ -0,0 +1,46 @@ +from fastvideo import VideoGenerator +from fastvideo.configs.pipelines import WanLingBotI2V480PConfig +from fastvideo.models.dits.matrixgame.utils import create_action_presets + +BASE_MODEL_PATH = "weizhou03/Wan2.1-Game-Fun-1.3B-InP-Diffusers" +WEIGHTS_PATH = "wangame_lingbot_test/checkpoint-100/transformer/diffusion_pytorch_model.safetensors" + +OUTPUT_PATH = "video_samples_wangame_lingbot" +IMAGE_PATH = 
"https://raw.githubusercontent.com/SkyworkAI/Matrix-Game/main/Matrix-Game-2/demo_images/universal/0000.png" + + +def main(): + generator = VideoGenerator.from_pretrained( + BASE_MODEL_PATH, + pipeline_config=WanLingBotI2V480PConfig(), + num_gpus=1, + use_fsdp_inference=False, + dit_cpu_offload=False, + vae_cpu_offload=False, + text_encoder_cpu_offload=True, + pin_cpu_memory=True, + override_pipeline_cls_name="WanLingBotImageToVideoPipeline", + override_transformer_cls_name="WanLingBotTransformer3DModel", + init_weights_from_safetensors=WEIGHTS_PATH, + ) + + num_frames = 77 + actions = create_action_presets(num_frames, keyboard_dim=4) + + generator.generate_video( + prompt="", + image_path=IMAGE_PATH, + mouse_cond=actions["mouse"].unsqueeze(0), + keyboard_cond=actions["keyboard"].unsqueeze(0), + num_frames=num_frames, + height=352, + width=640, + num_inference_steps=40, + guidance_scale=1.0, + output_path=OUTPUT_PATH, + save_video=True, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/train/dfsft_wangame_causal_v3.yaml b/examples/train/dfsft_wangame_causal_v3.yaml new file mode 100644 index 000000000..e2e77db67 --- /dev/null +++ b/examples/train/dfsft_wangame_causal_v3.yaml @@ -0,0 +1,91 @@ +# V3 config: WanGame causal Diffusion-Forcing SFT (DFSFT). +# +# Uses _target_-based instantiation — each model role is an independent +# class instance; the method class is resolved directly from the YAML. 
+ +models: + student: + _target_: fastvideo.train.models.wangame.WanGameCausalModel + init_from: /mnt/weka/home/hao.zhang/kaiqin/wg_models/WanGame-2.1-0223-9000steps + trainable: true + +method: + _target_: fastvideo.train.methods.fine_tuning.dfsft.DiffusionForcingSFTMethod + attn_kind: dense + # use_ema: true + chunk_size: 3 + min_timestep_ratio: 0.02 + max_timestep_ratio: 0.98 + +training: + distributed: + num_gpus: 8 + sp_size: 1 + tp_size: 1 + hsdp_replicate_dim: 8 + hsdp_shard_dim: 1 + + data: + data_path: >- + /mnt/weka/home/hao.zhang/mhuo/traindata_0204_2130/preprocessed:0, + /mnt/weka/home/hao.zhang/mhuo/traindata_0204_1600/preprocessed:0, + /mnt/weka/home/hao.zhang/mhuo/traindata_0205_1330/data/0_static_plus_w_only/preprocessed:1, + /mnt/weka/home/hao.zhang/mhuo/traindata_0205_1330/data/1_wasd_only/preprocessed:1, + /mnt/weka/home/hao.zhang/mhuo/traindata_0206_1200/data/wasdonly_alpha1/preprocessed:1, + /mnt/weka/home/hao.zhang/mhuo/traindata_0206_1200/data/camera/preprocessed:1, + /mnt/weka/home/hao.zhang/mhuo/traindata_0208_2000/data/camera4hold_alpha1/preprocessed:1, + /mnt/weka/home/hao.zhang/mhuo/traindata_0208_2000/data/wasd4holdrandview_simple_1key1mouse1/preprocessed:1 + dataloader_num_workers: 4 + train_batch_size: 1 + training_cfg_rate: 0.0 + seed: 1000 + num_latent_t: 20 + num_height: 352 + num_width: 640 + num_frames: 77 + + optimizer: + learning_rate: 1.0e-5 + betas: [0.9, 0.999] + weight_decay: 1.0e-4 + lr_scheduler: constant + lr_warmup_steps: 0 + + loop: + max_train_steps: 20000 + gradient_accumulation_steps: 1 + + checkpoint: + output_dir: outputs/wangame_dfsft_causal_v3 + training_state_checkpointing_steps: 1000 + checkpoints_total_limit: 2 + + tracker: + project_name: distillation_wangame_r + run_name: wangame_dfsft_causal_v3 + + model: + enable_gradient_checkpointing_type: full + +callbacks: + grad_clip: + _target_: fastvideo.train.callbacks.grad_clip.GradNormClipCallback + max_grad_norm: 1.0 + # ema: + # _target_: 
fastvideo.train.callbacks.ema.EMACallback + # beta: 0.9999 + validation: + _target_: fastvideo.train.callbacks.validation.ValidationCallback + pipeline_target: fastvideo.pipelines.basic.wan.wangame_causal_dmd_pipeline.WanGameCausalDMDPipeline + dataset_file: examples/training/finetune/WanGame2.1_1.3b_i2v/validation_random_8.json + every_steps: 100 + sampling_steps: [40] + rollout_mode: streaming + sampler_kind: ode + scheduler_target: fastvideo.models.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler + guidance_scale: 1.0 + num_frames: 69 + +pipeline: + flow_shift: 3 + sampler_kind: ode diff --git a/examples/train/distill_wan2.1_t2v_1.3B_dmd2.yaml b/examples/train/distill_wan2.1_t2v_1.3B_dmd2.yaml new file mode 100644 index 000000000..1b161f541 --- /dev/null +++ b/examples/train/distill_wan2.1_t2v_1.3B_dmd2.yaml @@ -0,0 +1,91 @@ +# DMD2 distillation: Wan 2.1 T2V 1.3B (teacher 50-step -> student 4-step). +# +# - Teacher: frozen pretrained Wan 2.1 T2V 1.3B +# - Student: trainable, initialized from the same pretrained weights +# - Critic: trainable, initialized from the same pretrained weights +# - Validation: 4-step SDE sampling + +models: + student: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: true + teacher: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: false + disable_custom_init_weights: true + critic: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: true + disable_custom_init_weights: true + +method: + _target_: fastvideo.train.methods.distribution_matching.dmd2.DMD2Method + rollout_mode: simulate + generator_update_interval: 5 + real_score_guidance_scale: 4.5 + dmd_denoising_steps: [1000, 750, 500, 250] + + # Critic optimizer (required — no fallback to training.optimizer) + fake_score_learning_rate: 8.0e-6 + fake_score_betas: [0.0, 0.999] + 
fake_score_lr_scheduler: constant + +training: + distributed: + num_gpus: 8 + sp_size: 1 + tp_size: 1 + hsdp_replicate_dim: 1 + hsdp_shard_dim: 8 + + data: + data_path: data/Wan-Syn_77x448x832_600k + dataloader_num_workers: 4 + train_batch_size: 1 + training_cfg_rate: 0.0 + seed: 1000 + num_latent_t: 20 + num_height: 448 + num_width: 832 + num_frames: 77 + + optimizer: + learning_rate: 2.0e-6 + betas: [0.0, 0.999] + weight_decay: 0.01 + lr_scheduler: constant + lr_warmup_steps: 0 + + loop: + max_train_steps: 4000 + gradient_accumulation_steps: 1 + + checkpoint: + output_dir: outputs/wan2.1_dmd2_4steps + training_state_checkpointing_steps: 1000 + checkpoints_total_limit: 3 + + tracker: + project_name: distillation_wan + run_name: wan2.1_dmd2_4steps + + model: + enable_gradient_checkpointing_type: full + +callbacks: + grad_clip: + max_grad_norm: 1.0 + validation: + pipeline_target: fastvideo.pipelines.basic.wan.wan_pipeline.WanPipeline + dataset_file: examples/training/finetune/Wan2.1-VSA/Wan-Syn-Data/validation_4.json + every_steps: 50 + sampling_steps: [4] + sampler_kind: sde + sampling_timesteps: [1000, 750, 500, 250] + guidance_scale: 6.0 + +pipeline: + flow_shift: 8 diff --git a/examples/train/example.yaml b/examples/train/example.yaml new file mode 100644 index 000000000..f025d47b1 --- /dev/null +++ b/examples/train/example.yaml @@ -0,0 +1,208 @@ +# ============================================================================== +# Full configuration reference for fastvideo.train +# +# Legend: +# [TYPED] — parsed into a typed dataclass; fields are validated with +# defaults. Unknown keys are silently ignored. +# [FREE] — free-form dict passed as-is to the target class / method. +# Keys depend on the _target_ class constructor / method_config. +# [RESOLVED] — parsed by PipelineConfig.from_kwargs(); auto-populated from +# the model's config files. Only scalar overrides are useful. 
+# ============================================================================== + +# ------------------------------------------------------------------------------ +# models: [FREE] +# +# Each role is instantiated via _target_(*, training_config=..., **kwargs). +# Keys here are constructor kwargs of the _target_ class (e.g. WanModel). +# You can define any role name (student, teacher, critic, etc.). +# ------------------------------------------------------------------------------ +models: + student: + _target_: fastvideo.train.models.wan.WanModel # required + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers # required: HF repo or local path + trainable: true # default: true + disable_custom_init_weights: false # default: false + flow_shift: 3.0 # default: 3.0 + enable_gradient_checkpointing_type: null # default: null (falls back to training.model) + + teacher: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: false + disable_custom_init_weights: true + + critic: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: true + disable_custom_init_weights: true + +# ------------------------------------------------------------------------------ +# method: [FREE] +# +# Instantiated via _target_(*, cfg=RunConfig, role_models=...). +# All keys besides _target_ are available in self.method_config (a plain dict). +# Keys depend entirely on the method class. 
+# ------------------------------------------------------------------------------ +method: + _target_: fastvideo.train.methods.distribution_matching.dmd2.DMD2Method # required + + # --- DMD2-specific keys (read from self.method_config) --- + rollout_mode: simulate # required: "simulate" or "data_latent" + generator_update_interval: 5 # default: 1 + dmd_denoising_steps: [1000, 750, 500, 250] # SDE timestep schedule + + # Critic optimizer (all required — no fallback) + fake_score_learning_rate: 8.0e-6 + fake_score_betas: [0.0, 0.999] + fake_score_lr_scheduler: constant + + # CFG conditioning policy (optional) + # cfg_uncond: + # on_missing: error # "error" or "ignore" + # text: keep # "keep", "zero", "drop", "negative_prompt" + # image: keep # "keep", "zero", "drop" + # action: keep # "keep", "zero", "drop" + + # --- FineTuneMethod keys (if using finetune instead) --- + # _target_: fastvideo.train.methods.fine_tuning.finetune.FineTuneMethod + # attn_kind: vsa # "dense" or "vsa" + # use_ema: false + +# ------------------------------------------------------------------------------ +# training: [TYPED] -> TrainingConfig +# +# Every field below has a typed default. Unknown keys are ignored. 
+# ------------------------------------------------------------------------------ +training: + + # --- training.distributed [TYPED] -> DistributedConfig --- + distributed: + num_gpus: 8 # default: 1 + tp_size: 1 # default: 1 + sp_size: 1 # default: 1 (defaults to num_gpus in loader) + hsdp_replicate_dim: 1 # default: 1 + hsdp_shard_dim: 8 # default: -1 (defaults to num_gpus in loader) + pin_cpu_memory: false # default: false + + # --- training.data [TYPED] -> DataConfig --- + data: + data_path: data/my_dataset # default: "" + train_batch_size: 1 # default: 1 + dataloader_num_workers: 4 # default: 0 + training_cfg_rate: 0.1 # default: 0.0 + seed: 1000 # default: 0 + num_height: 448 # default: 0 + num_width: 832 # default: 0 + num_latent_t: 20 # default: 0 + num_frames: 77 # default: 0 + + # --- training.optimizer [TYPED] -> OptimizerConfig --- + # Note: only for the student optimizer. Critic optimizer is in method config. + optimizer: + learning_rate: 2.0e-6 # default: 0.0 + betas: [0.9, 0.999] # default: [0.9, 0.999] + weight_decay: 0.01 # default: 0.0 + lr_scheduler: constant # default: "constant" + lr_warmup_steps: 0 # default: 0 + lr_num_cycles: 0 # default: 0 + lr_power: 0.0 # default: 0.0 + min_lr_ratio: 0.5 # default: 0.5 + + # --- training.loop [TYPED] -> TrainingLoopConfig --- + loop: + max_train_steps: 10000 # default: 0 + gradient_accumulation_steps: 1 # default: 1 + + # --- training.checkpoint [TYPED] -> CheckpointConfig --- + checkpoint: + output_dir: outputs/my_run # default: "" + resume_from_checkpoint: "" # default: "" (or use --resume-from-checkpoint CLI) + training_state_checkpointing_steps: 1000 # default: 0 (disabled) + checkpoints_total_limit: 3 # default: 0 (keep all) + + # --- training.tracker [TYPED] -> TrackerConfig --- + tracker: + trackers: [] # default: [] (auto-adds "wandb" if project_name is set) + project_name: my_project # default: "fastvideo" + run_name: my_run # default: "" + + # --- training.vsa [TYPED] -> VSAConfig --- + vsa: + 
sparsity: 0.0 # default: 0.0 (0.0 = disabled) + decay_rate: 0.0 # default: 0.0 + decay_interval_steps: 0 # default: 0 + + # --- training.model [TYPED] -> ModelTrainingConfig --- + model: + weighting_scheme: uniform # default: "uniform" + logit_mean: 0.0 # default: 0.0 + logit_std: 1.0 # default: 1.0 + mode_scale: 1.0 # default: 1.0 + precondition_outputs: false # default: false + moba_config: {} # default: {} + enable_gradient_checkpointing_type: full # default: null ("full" or null) + + # --- training top-level [TYPED] --- + dit_precision: fp32 # default: "fp32" (master weight precision) + # model_path: ... # default: "" (auto-derived from models.student.init_from) + +# ------------------------------------------------------------------------------ +# callbacks: [FREE] +# +# Each callback is instantiated via _target_(*, **kwargs). +# The callback name (e.g. "grad_clip") is arbitrary — only _target_ matters. +# training_config is injected automatically (not from YAML). +# ------------------------------------------------------------------------------ +callbacks: + + # --- GradNormClipCallback --- + grad_clip: + _target_: fastvideo.train.callbacks.grad_clip.GradNormClipCallback # optional if using default registry + max_grad_norm: 1.0 # default: 0.0 (0.0 = disabled) + log_grad_norms: false # default: false + + # --- EMACallback --- + # ema: + # _target_: fastvideo.train.callbacks.ema.EMACallback + # type: constant # default: "constant" ("constant", "power", "halflife") + # beta: 0.9999 # default: 0.9999 (for constant type) + # gamma: 16.97 # default: 16.97 (for power type) + # ema_halflife_kimg: 500.0 # default: 500.0 (for halflife type) + # ema_rampup_ratio: 0.05 # default: 0.05 (for halflife type) + # start_iter: 0 # default: 0 + # batch_size: 1 # default: 1 + + # --- ValidationCallback --- + validation: + _target_: fastvideo.train.callbacks.validation.ValidationCallback # optional if using default registry + pipeline_target: 
fastvideo.pipelines.basic.wan.wan_pipeline.WanPipeline # required + dataset_file: path/to/validation.json # required + every_steps: 100 # default: 100 + sampling_steps: [4] # default: [40] + sampler_kind: sde # default: "ode" (use "sde" for few-step distilled models) + scheduler_target: null # default: null (_target_ for scheduler class, e.g. + # fastvideo.models.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler + # fastvideo.models.schedulers.scheduling_flow_unipc_multistep.FlowUniPCMultistepScheduler) + guidance_scale: 5.0 # default: null (uses model default) + num_frames: null # default: null (derived from training.data) + output_dir: null # default: null (falls back to training.checkpoint.output_dir) + sampling_timesteps: null # default: null (explicit timestep list for SDE) + rollout_mode: parallel # default: "parallel" ("parallel" or "streaming") + +# ------------------------------------------------------------------------------ +# pipeline: [RESOLVED] -> PipelineConfig +# +# Parsed by PipelineConfig.from_kwargs(). Most fields are auto-populated from +# the model's config files (vae_config, dit_config, text_encoder_configs, etc.). +# Only scalar overrides are typically needed here. 
+# ------------------------------------------------------------------------------ +pipeline: + flow_shift: 3 # default: null (model-specific) + # flow_shift_sr: null # default: null (super-resolution shift) + # embedded_cfg_scale: 6.0 # default: 6.0 + # is_causal: false # default: false + # vae_tiling: true # default: true + # vae_sp: true # default: true + # disable_autocast: false # default: false diff --git a/examples/train/finetune_wan2.1_t2v_1.3B_vsa_phase3.4_0.9sparsity.yaml b/examples/train/finetune_wan2.1_t2v_1.3B_vsa_phase3.4_0.9sparsity.yaml new file mode 100644 index 000000000..888ad2491 --- /dev/null +++ b/examples/train/finetune_wan2.1_t2v_1.3B_vsa_phase3.4_0.9sparsity.yaml @@ -0,0 +1,82 @@ +# V3 config: Wan 2.1 T2V 1.3B finetune with VSA (phase 3.4, 0.9 sparsity). +# +# Uses _target_-based instantiation — each model role is an independent +# class instance; the method class is resolved directly from the YAML. + +models: + student: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: true + +method: + _target_: fastvideo.train.methods.fine_tuning.finetune.FineTuneMethod + attn_kind: vsa + use_ema: true + +training: + distributed: + num_gpus: 8 + sp_size: 1 + tp_size: 1 + hsdp_replicate_dim: 8 + hsdp_shard_dim: 1 + + data: + data_path: data/Wan-Syn_77x448x832_600k + dataloader_num_workers: 4 + train_batch_size: 1 + training_cfg_rate: 0.1 + seed: 1000 + num_latent_t: 20 + num_height: 448 + num_width: 832 + num_frames: 77 + + optimizer: + learning_rate: 1.0e-6 + betas: [0.9, 0.999] + weight_decay: 0.01 + lr_scheduler: constant + lr_warmup_steps: 0 + + loop: + max_train_steps: 4000 + gradient_accumulation_steps: 1 + + checkpoint: + output_dir: outputs/phase3.4_wan2.1_finetune_vsa_0.9_v3 + training_state_checkpointing_steps: 1000 + weight_only_checkpointing_steps: 1000 + checkpoints_total_limit: 3 + + tracker: + project_name: distillation_wangame_r + run_name: phase3.4_wan_finetune_vsa_0.9_v3 + + model: 
+ enable_gradient_checkpointing_type: full + + vsa: + sparsity: 0.9 + decay_rate: 0.03 + decay_interval_steps: 1 + +callbacks: + grad_clip: + _target_: fastvideo.train.callbacks.grad_clip.GradNormClipCallback + max_grad_norm: 1.0 + ema: + _target_: fastvideo.train.callbacks.ema.EMACallback + beta: 0.9999 + validation: + _target_: fastvideo.train.callbacks.validation.ValidationCallback + pipeline_target: fastvideo.pipelines.basic.wan.wan_pipeline.WanPipeline + dataset_file: examples/training/finetune/Wan2.1-VSA/Wan-Syn-Data/validation_4.json + every_steps: 50 + sampling_steps: [50] + guidance_scale: 5.0 + +pipeline: + flow_shift: 3 + sampler_kind: ode diff --git a/examples/train/finetune_wangame2.1_i2v_1.3B.yaml b/examples/train/finetune_wangame2.1_i2v_1.3B.yaml new file mode 100644 index 000000000..4edc3f10b --- /dev/null +++ b/examples/train/finetune_wangame2.1_i2v_1.3B.yaml @@ -0,0 +1,86 @@ +# V3 config: WanGame 2.1 I2V 1.3B finetune (dense attention). +# +# Uses _target_-based instantiation — each model role is an independent +# class instance; the method class is resolved directly from the YAML. 
+ +models: + student: + _target_: fastvideo.train.models.wangame.WanGameModel + init_from: /mnt/weka/home/hao.zhang/kaiqin/wg_models/WanGame-2.1-0223-9000steps + trainable: true + +method: + _target_: fastvideo.train.methods.fine_tuning.finetune.FineTuneMethod + attn_kind: dense + # use_ema: true + +training: + distributed: + num_gpus: 8 + sp_size: 1 + tp_size: 1 + hsdp_replicate_dim: 8 + hsdp_shard_dim: 1 + + data: + data_path: >- + /mnt/weka/home/hao.zhang/mhuo/traindata_0204_2130/preprocessed:0, + /mnt/weka/home/hao.zhang/mhuo/traindata_0204_1600/preprocessed:0, + /mnt/weka/home/hao.zhang/mhuo/traindata_0205_1330/data/0_static_plus_w_only/preprocessed:1, + /mnt/weka/home/hao.zhang/mhuo/traindata_0205_1330/data/1_wasd_only/preprocessed:1, + /mnt/weka/home/hao.zhang/mhuo/traindata_0206_1200/data/wasdonly_alpha1/preprocessed:1, + /mnt/weka/home/hao.zhang/mhuo/traindata_0206_1200/data/camera/preprocessed:1, + /mnt/weka/home/hao.zhang/mhuo/traindata_0208_2000/data/camera4hold_alpha1/preprocessed:1, + /mnt/weka/home/hao.zhang/mhuo/traindata_0208_2000/data/wasd4holdrandview_simple_1key1mouse1/preprocessed:1 + dataloader_num_workers: 4 + train_batch_size: 1 + training_cfg_rate: 0.0 + seed: 1000 + num_latent_t: 20 + num_height: 352 + num_width: 640 + num_frames: 77 + + optimizer: + learning_rate: 1.0e-5 + betas: [0.9, 0.999] + weight_decay: 1.0e-4 + lr_scheduler: constant + lr_warmup_steps: 0 + + loop: + max_train_steps: 20000 + gradient_accumulation_steps: 1 + + checkpoint: + output_dir: outputs/wangame_finetune_v3 + training_state_checkpointing_steps: 1000 + weight_only_checkpointing_steps: 1000 + checkpoints_total_limit: 2 + + tracker: + project_name: distillation_wangame_r + run_name: wangame_finetune_v3 + + model: + enable_gradient_checkpointing_type: full + +callbacks: + grad_clip: + _target_: fastvideo.train.callbacks.grad_clip.GradNormClipCallback + max_grad_norm: 1.0 + # ema: + # _target_: fastvideo.train.callbacks.ema.EMACallback + # beta: 0.9999 + validation: 
+ _target_: fastvideo.train.callbacks.validation.ValidationCallback + pipeline_target: fastvideo.pipelines.basic.wan.wangame_i2v_pipeline.WanGameActionImageToVideoPipeline + dataset_file: examples/training/finetune/WanGame2.1_1.3b_i2v/validation_random_8.json + every_steps: 100 + sampling_steps: [40] + sampler_kind: ode + guidance_scale: 1.0 + +pipeline: + flow_shift: 3 + sampler_kind: ode diff --git a/examples/train/issue.md b/examples/train/issue.md new file mode 100644 index 000000000..68ede195b --- /dev/null +++ b/examples/train/issue.md @@ -0,0 +1,327 @@ +# [RFC] Unified, YAML-Driven Training Architecture for Video Diffusion Models + +## Summary + +We propose a new **pluggable training architecture** in FastVideo that cleanly separates **models**, **training methods**, and **infrastructure** into independent, composable layers. A single YAML config file is all that is needed to train any supported model with any supported algorithm — no code changes required to mix and match. + +This issue describes the design, current status, and open questions. **We welcome community feedback on the architecture, API surface, and planned extensions.** + +--- + +## Motivation + +Training video diffusion models today involves a tangle of concerns: model loading, noise scheduling, distillation algorithms, distributed strategies, checkpointing, and validation. Existing training scripts tend to hard-wire these together, making it painful to: + +1. **Try a new distillation algorithm** on an existing model (requires forking the training loop). +2. **Add a new model** to an existing algorithm (requires re-implementing boilerplate). +3. **Switch distributed strategies** (FSDP, TP, SP) without touching algorithm code. +4. **Resume, checkpoint, and validate** uniformly across all combinations. + +Our goal is a single, extensible training framework where each axis of variation is an independent plugin. 
+ +--- + +## Architecture Overview + +``` +YAML Config + | + v ++------------------+ +---------------------+ +------------------+ +| Models Layer | | Methods Layer | | Infrastructure | +| (per-role) | | (algorithm) | | Layer | +| | | | | | +| - ModelBase |<----| - TrainingMethod |---->| - Trainer | +| - CausalModelBase| | - single_train_step| | - Callbacks | +| | | - backward | | - Checkpoint | +| Roles: | | - optimizers | | - Tracker (W&B) | +| student | | | | - Dataloader | +| teacher | | Algorithms: | | | +| critic | | DMD2, SelfForcing, | | Distributed: | +| | | SFT, DFSFT | | HSDP, TP, SP | ++------------------+ +---------------------+ +------------------+ +``` + +### Three Layers + +| Layer | Responsibility | Extension point | +|-------|---------------|-----------------| +| **Models** (`fastvideo/train/models/`) | Load transformer + scheduler, define `predict_noise`, `predict_x0`, `add_noise`, `backward`. Each training role (student/teacher/critic) is an independent instance. | Subclass `ModelBase` (or `CausalModelBase` for streaming). | +| **Methods** (`fastvideo/train/methods/`) | Implement the training algorithm: own role models, define `single_train_step` + `backward`, manage optimizers/schedulers. | Subclass `TrainingMethod`. | +| **Infrastructure** (`fastvideo/train/trainer.py`, `utils/`, `callbacks/`) | Training loop, gradient accumulation, distributed setup, checkpointing (DCP), W&B tracking, validation, EMA, grad clipping. | Add callbacks; everything else is shared. | + +### YAML-Driven Configuration + +Everything is configured declaratively. 
The `_target_` field selects the Python class to instantiate: + +```yaml +models: + student: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: true + teacher: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: false + critic: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: true + +method: + _target_: fastvideo.train.methods.distribution_matching.dmd2.DMD2Method + rollout_mode: simulate + dmd_denoising_steps: [1000, 850, 700, 550, 350, 275, 200, 125] + generator_update_interval: 5 + real_score_guidance_scale: 3.5 + # ... + +training: + distributed: { num_gpus: 8, sp_size: 1, tp_size: 1 } + data: { data_path: ..., num_latent_t: 20, num_frames: 77 } + optimizer: { learning_rate: 2.0e-6, betas: [0.0, 0.999] } + loop: { max_train_steps: 4000 } + checkpoint: { output_dir: outputs/my_run } + +callbacks: + grad_clip: { max_grad_norm: 1.0 } + validation: { pipeline_target: ..., every_steps: 100 } +``` + +To switch from DMD2 to SFT, change the `method._target_` and remove the teacher/critic — no code changes. + +--- + +## Model Abstraction + +### `ModelBase` — Standard (Bidirectional) Models + +Every role gets its own `ModelBase` instance owning a `transformer` and `noise_scheduler`. The base class defines: + +- **`prepare_batch()`** — Convert raw dataloader output into forward-ready `TrainingBatch`. +- **`add_noise()`** — Apply forward-process noise at a given timestep. +- **`predict_noise()` / `predict_x0()`** — Run the transformer and return predictions. +- **`backward()`** — Backward pass that restores forward context (attention metadata, timesteps). +- **`init_preprocessors()`** — Lazy-load VAE, build dataloader (called only on the student). 
+ +### `CausalModelBase` — Streaming / Causal Models + +Extends `ModelBase` with streaming inference primitives for causal video generation: + +```python +class CausalModelBase(ModelBase): + def clear_caches(self, *, cache_tag: str = "pos") -> None: ... + def predict_noise_streaming(self, ..., cache_tag, store_kv, cur_start_frame) -> Tensor | None: ... + def predict_x0_streaming(self, ..., cache_tag, store_kv, cur_start_frame) -> Tensor | None: ... +``` + +Key design: KV caches are **internal** to the model instance, keyed by `cache_tag`. The method controls when to store (`store_kv=True`) vs. read-only (`store_kv=False`), enabling block-by-block causal rollout during training. + +--- + +## Supported Training Methods + +### 1. DMD2 (Distribution Matching Distillation) + +**Roles:** student (trainable) + teacher (frozen) + critic (trainable) + +The student learns to generate clean video in few steps by matching the teacher's score function, with a critic network providing a learned fake-score baseline. + +- **Rollout modes:** + - `simulate` — Student starts from pure noise and iteratively denoises through the full step schedule. + - `data_latent` — Student denoises from a single randomly-noised data sample. +- **Losses:** Generator loss (DMD gradient) + critic flow-matching loss, with alternating updates (`generator_update_interval`). + +### 2. Self-Forcing (Causal DMD) + +**Roles:** student (causal, trainable) + teacher (frozen) + critic (trainable) + +Extends DMD2 for **streaming/causal video generation**. The key idea: during training, the student processes video in temporal chunks, using its own previously-denoised outputs as context for future chunks — simulating online autoregressive rollout. + +- Video is split into blocks of `chunk_size` latent frames. +- Each block is denoised through the student's step schedule; a random early-exit step is sampled per block. 
+- After denoising a block, its output is fed back (with optional `context_noise`) as KV cache context for subsequent blocks via `predict_noise_streaming(store_kv=True)`. +- Supports SDE and ODE sampling during rollout. +- Selective gradient control: `enable_gradient_in_rollout`, `start_gradient_frame`. + +### 3. Supervised Fine-Tuning (SFT) + +**Roles:** student only + +Standard flow-matching loss between predicted and ground-truth noise/x0. + +### 4. Diffusion-Forcing SFT (DFSFT) + +**Roles:** student only + +SFT with **inhomogeneous (per-chunk) timesteps** — each temporal chunk in a video gets a different noise level. This trains the model to handle mixed-noise inputs, which is a prerequisite for causal/streaming inference where earlier frames are cleaner than later ones. + +--- + +## Training Loop + +The `Trainer` runs a standard loop with pluggable method and callbacks: + +``` +for step in range(start_step, max_steps): + for accum_iter in range(grad_accum_steps): + batch <- dataloader + loss_map, outputs, metrics <- method.single_train_step(batch, step) + method.backward(loss_map, outputs) + + callbacks.on_before_optimizer_step() # grad clipping + method.optimizers_schedulers_step() + method.optimizers_zero_grad() + callbacks.on_training_step_end() # logging + checkpoint_manager.maybe_save(step) + callbacks.on_validation_begin() # periodic inference +``` + +### Callbacks + +- **GradNormClipCallback** — Per-module gradient norm logging + global clipping. +- **ValidationCallback** — Periodic inference sampling with configurable pipeline, sampling steps, and guidance scale. +- **EMACallback** — Exponential moving average of student weights. + +### Checkpointing + +- DCP (Distributed Checkpoint) format, compatible with FSDP/HSDP. +- Saves: model weights, optimizer states, scheduler states, RNG states (per role). +- Full resume support: auto-restores step counter and all RNG states. 
+ +--- + +## Current Status + +| Component | Status | +|-----------|--------| +| Core framework (trainer, config, callbacks) | Implemented and tested | +| `WanModel` (Wan 2.1 T2V) | Implemented and tested | +| `WanGameModel` (WanGame 2.1 I2V) | Implemented and tested | +| `WanGameCausalModel` (streaming) | Implemented and tested | +| `WanCausalModel` (Wan T2V causal) | In progress | +| DMD2 method | Implemented and tested | +| Self-Forcing method | Implemented and tested | +| SFT method | Implemented and tested | +| DFSFT method | Implemented and tested | +| DCP checkpointing + resume | Implemented and tested | +| EMA callback | Implemented | +| Validation callback | Implemented and tested | +| Causal DMD inference pipeline | Implemented | + +--- + +## Open Questions for Discussion + +We'd love community input on the following: + +### 1. Model Plugin API + +The current `ModelBase` interface requires implementing 6 methods. Is this the right granularity? + +- Should `prepare_batch` be split into separate concerns (noise sampling, timestep sampling, attention metadata)? +- Should `backward` be lifted out of the model and into the method/trainer? + +### 2. Causal Streaming Interface + +`CausalModelBase` adds `predict_noise_streaming` / `predict_x0_streaming` with cache management. Alternatives considered: + +- **(a) Current:** Cache is internal to the model, keyed by `cache_tag`. Simple but couples cache lifecycle to model. +- **(b) External cache:** Method owns the cache dict, passes it into predict calls. More explicit but verbose. +- **(c) Context manager:** `with model.streaming_context(tag) as ctx: ...` — cleaner lifecycle but harder to compose. + +Which approach do you prefer? Are there use cases that would break the current design? + +### 3. Self-Forcing Training Strategy + +The self-forcing rollout has several configurable knobs: + +- **Early-exit sampling:** Random step per block vs. always last step (`last_step_only`). 
Random exit trains robustness but adds variance. +- **Context noise:** Re-noising completed blocks before feeding as context (`context_noise`). Improves robustness to prediction errors during inference. +- **Gradient scope:** Gradient only on exit-step prediction vs. through the entire rollout. Memory vs. signal quality trade-off. + +What are your experiences with these trade-offs? Are there strategies we should consider? + +### 4. Method Composition + +Currently, methods are monolithic classes. Should we support composing methods (e.g., DFSFT pre-training followed by Self-Forcing distillation) within a single config? Or is sequential training with checkpoint handoff sufficient? + +### 5. New Models and Methods + +What models and training methods should we prioritize next? + +- **Models:** HunyuanVideo, CogVideoX, other Wan variants? +- **Methods:** Consistency models, progressive distillation, reward-based fine-tuning? + +### 6. Distributed Strategy + +Currently supports HSDP (hybrid sharded data parallel) + TP + SP. Are there scenarios where the current distributed setup is insufficient? Should we add pipeline parallelism for very large models? + +--- + +## How to Try It + +```bash +# Install +uv pip install -e .[dev] + +# Run DMD2 distillation on Wan 2.1 +torchrun --nproc_per_node=8 -m fastvideo.train.entrypoint.train \ + --config examples/train/distill_wan2.1_t2v_1.3B_dmd2.yaml + +# Run SFT fine-tuning +torchrun --nproc_per_node=8 -m fastvideo.train.entrypoint.train \ + --config examples/train/finetune_wan2.1_t2v_1.3B_vsa_phase3.4_0.9sparsity.yaml +``` + +Example configs are in [`examples/train/`](../train/). 
+ +--- + +## File Structure Reference + +``` +fastvideo/train/ + trainer.py # Training loop + models/ + base.py # ModelBase, CausalModelBase ABCs + wan/wan.py # Wan 2.1 T2V model plugin + wangame/wangame.py # WanGame 2.1 I2V model plugin + wangame/wangame_causal.py # WanGame causal (streaming) plugin + methods/ + base.py # TrainingMethod ABC + distribution_matching/ + dmd2.py # DMD2 distillation + self_forcing.py # Self-Forcing (causal DMD) + fine_tuning/ + finetune.py # Supervised fine-tuning + dfsft.py # Diffusion-forcing SFT + callbacks/ + grad_clip.py # Gradient clipping + norm logging + validation.py # Periodic inference validation + ema.py # EMA weight averaging + entrypoint/ + train.py # CLI entrypoint (torchrun) + utils/ + config.py # YAML parser -> RunConfig + builder.py # build_from_config: model/method instantiation + training_config.py # TrainingConfig dataclass + dataloader.py # Dataset/dataloader construction + optimizer.py # Optimizer/scheduler construction + checkpoint.py # DCP save/resume + tracking.py # W&B tracker +``` + +--- + +## Related + +- [RFC: Training Architecture](rfc.md) — Original internal design document. +- [Self-Forcing paper](https://arxiv.org/abs/2506.08009) — Huang et al., 2025. +- [DMD2 paper](https://arxiv.org/abs/2405.14867) — Yin et al., 2024. +- [Diffusion Forcing paper](https://arxiv.org/abs/2407.01392) — Chen et al., 2024. + +--- + +**Feedback is highly welcome!** Please comment on this issue with your thoughts on the design, suggestions for improvement, or use cases that the current architecture doesn't cover well. diff --git a/examples/train/review.md b/examples/train/review.md new file mode 100644 index 000000000..308d08bc1 --- /dev/null +++ b/examples/train/review.md @@ -0,0 +1,239 @@ +# PR Review: `refactor/train` vs `upstream/main` + +**144 files changed, 24042 insertions, 451 deletions.** + +This PR mixes three categories of changes. 
This document categorizes +every file so you can quickly decide what belongs in this PR and what +should be split out or dropped. + +--- + +## Category A — WanGame Model & Infrastructure (60 new + 17 modified files) + +Pure WanGame/WanLingBot model, pipeline, data, and example additions. +None of these touch `fastvideo/train/`. + +### New model code + +| File | Summary | +|------|---------| +| `fastvideo/models/dits/wangame/__init__.py` | Package init | +| `fastvideo/models/dits/wangame/model.py` | WanGame transformer (action-conditioned Wan) | +| `fastvideo/models/dits/wangame/causal_model.py` | Causal WanGame transformer (streaming KV cache) | +| `fastvideo/models/dits/wangame/hyworld_action_module.py` | Keyboard/mouse action embedding module | +| `fastvideo/models/dits/wangame_lingbot/__init__.py` | Package init | +| `fastvideo/models/dits/wangame_lingbot/model.py` | WanLingBot transformer variant | +| `fastvideo/models/dits/wangame_lingbot/cam_utils.py` | Camera utility for LingBot | +| `fastvideo/configs/models/dits/wangamevideo.py` | WanGame/WanLingBot model configs | + +### New pipeline code + +| File | Summary | +|------|---------| +| `fastvideo/pipelines/basic/wan/wangame_i2v_pipeline.py` | WanGame I2V inference pipeline | +| `fastvideo/pipelines/basic/wan/wangame_causal_dmd_pipeline.py` | WanGame causal DMD pipeline | +| `fastvideo/pipelines/preprocess/wangame/wangame_preprocess_pipeline.py` | WanGame data preprocessing | +| `fastvideo/pipelines/preprocess/wangame/wangame_preprocess_pipeline_ode_trajectory.py` | ODE trajectory preprocessing | +| `fastvideo/pipelines/samplers/__init__.py` | New sampler package init | +| `fastvideo/pipelines/samplers/base.py` | Base sampler class | +| `fastvideo/pipelines/samplers/wan.py` | Wan-specific sampler utilities | + +### New legacy training pipelines (old-style, not `fastvideo/train/`) + +| File | Summary | +|------|---------| +| `fastvideo/training/wangame_training_pipeline.py` | WanGame SFT training (old pipeline 
style) | +| `fastvideo/training/wangame_distillation_pipeline.py` | WanGame DMD distillation (old style) | +| `fastvideo/training/wangame_self_forcing_distillation_pipeline.py` | WanGame self-forcing (old style, 952 lines) | +| `fastvideo/training/wangame_ar_diffusion_pipeline.py` | WanGame AR diffusion (old style) | +| `fastvideo/training/wangame_ode_causal_pipeline.py` | WanGame ODE causal (old style) | +| `fastvideo/training/wangame_lingbot_training_pipeline.py` | WanLingBot training (old style) | + +> **Question:** Are these legacy `fastvideo/training/wangame_*.py` files +> still needed now that `fastvideo/train/` exists? If they're +> superseded, consider removing them to avoid confusion. + +### New examples & scripts (WanGame-specific) + +| File | Summary | +|------|---------| +| `examples/inference/basic/basic_wangame.py` | WanGame inference example | +| `examples/inference/basic/basic_causal_wangame.py` | Causal WanGame inference example | +| `examples/inference/basic/basic_wangame_lingbot.py` | WanLingBot inference example | +| `examples/distill/WanGame2.1/distill_dmd.slurm` | WanGame DMD slurm job | +| `examples/distill/WanGame2.1/validation.json` | Validation prompts | +| `examples/distill/SFWanGame2.1/distill_dmd.sh` | Self-forcing distill script | +| `examples/distill/SFWanGame2.1/distill_dmd.slurm` | Self-forcing slurm job | +| `examples/distill/SFWanGame2.1/validation.json` | Validation prompts | +| `examples/training/finetune/WanGame2.1_1.3b_i2v/` | WanGame finetune scripts, slurm jobs, validation JSONs, helper scripts (~15 files) | +| `examples/training/finetune/WanGame2.1_1.3b_i2v_LingBot/` | LingBot finetune scripts (~8 files) | +| `examples/training/consistency_finetune/causal_wangame_ode_init/` | ODE-init consistency finetune scripts (~8 files) | +| `docs/wangame/zero_init_fixes.md` | Zero-init debugging notes | +| `visualize_trajectory.py` | Trajectory visualization tool (224 lines) | + +### Modified files (WanGame support) + +| File | What 
changed | +|------|-------------| +| `fastvideo/configs/models/dits/__init__.py` | Imports/exports `WanGameVideoConfig`, `WanLingBotVideoConfig` | +| `fastvideo/configs/pipelines/__init__.py` | Imports/exports WanGame/LingBot pipeline configs | +| `fastvideo/configs/pipelines/wan.py` | Adds WanGame/LingBot/SelfForcing pipeline config dataclasses | +| `fastvideo/dataset/dataloader/record_schema.py` | Adds `wangame_ode_record_creator()` | +| `fastvideo/dataset/dataloader/schema.py` | Adds `pyarrow_schema_wangame`, `_lingbot`, `_ode_trajectory_wangame` | +| `fastvideo/dataset/validation_dataset.py` | Adds action (keyboard/mouse) loading support | +| `fastvideo/models/dits/hyworld/pose.py` | Adds `reformat_keyboard_and_mouse_tensors()`, `process_custom_actions()` | +| `fastvideo/models/loader/fsdp_load.py` | Extends `ALLOWED_NEW_PARAM_PATTERNS` for WanGame modules; adds kaiming init | +| `fastvideo/models/registry.py` | Registers WanGame/LingBot transformer classes | +| `fastvideo/registry.py` | Registers `Wan2.1-Game-Fun-1.3B-InP-Diffusers` HF path | +| `fastvideo/pipelines/preprocess/v1_preprocess.py` | Adds `wangame` and `wangame_ode_trajectory` preprocess tasks | +| `fastvideo/pipelines/stages/denoising.py` | Adds action conditioning via `process_custom_actions()` | +| `fastvideo/pipelines/stages/matrixgame_denoising.py` | Adds action/camera kwargs for causal WanGame inference (+486 lines) | + +--- + +## Category B — New `fastvideo/train/` Architecture (40 new + 15 modified files) + +The core of this PR: the new pluggable, YAML-driven training framework. 
+ +### New files (all under `fastvideo/train/`) + +| File | Summary | +|------|---------| +| `fastvideo/train/__init__.py` | Package init | +| `fastvideo/train/.style.yapf` | YAPF config (80-col) | +| `fastvideo/train/trainer.py` | Training loop: gradient accumulation, callbacks, checkpointing | +| `fastvideo/train/models/base.py` | `ModelBase`, `CausalModelBase` ABCs | +| `fastvideo/train/models/wan/wan.py` | Wan 2.1 T2V model plugin | +| `fastvideo/train/models/wangame/wangame.py` | WanGame I2V model plugin | +| `fastvideo/train/models/wangame/wangame_causal.py` | WanGame causal (streaming) model plugin | +| `fastvideo/train/methods/base.py` | `TrainingMethod` ABC | +| `fastvideo/train/methods/distribution_matching/dmd2.py` | DMD2 distillation method | +| `fastvideo/train/methods/distribution_matching/self_forcing.py` | Self-Forcing method | +| `fastvideo/train/methods/fine_tuning/finetune.py` | SFT method | +| `fastvideo/train/methods/fine_tuning/dfsft.py` | Diffusion-forcing SFT method | +| `fastvideo/train/callbacks/callback.py` | `CallbackDict` registry | +| `fastvideo/train/callbacks/grad_clip.py` | Gradient clipping callback | +| `fastvideo/train/callbacks/validation.py` | Validation callback | +| `fastvideo/train/callbacks/ema.py` | EMA callback | +| `fastvideo/train/entrypoint/train.py` | CLI entrypoint (`torchrun -m fastvideo.train.entrypoint.train`) | +| `fastvideo/train/entrypoint/dcp_to_diffusers.py` | Checkpoint conversion | +| `fastvideo/train/utils/config.py` | YAML parser -> `RunConfig` | +| `fastvideo/train/utils/builder.py` | `build_from_config`: instantiate models + method | +| `fastvideo/train/utils/instantiate.py` | `_target_`-based class instantiation | +| `fastvideo/train/utils/training_config.py` | `TrainingConfig` dataclass | +| `fastvideo/train/utils/dataloader.py` | Dataset/dataloader construction | +| `fastvideo/train/utils/optimizer.py` | Optimizer/scheduler construction | +| `fastvideo/train/utils/checkpoint.py` | DCP save/resume | 
+| `fastvideo/train/utils/tracking.py` | W&B tracker | +| `fastvideo/train/utils/module_state.py` | `apply_trainable()` | +| `fastvideo/train/utils/moduleloader.py` | Dynamic module loading | +| `fastvideo/train/utils/validation.py` | Validation helpers | +| `fastvideo/train/methods/consistency_model/__init__.py` | Placeholder | +| `fastvideo/train/methods/knowledge_distillation/__init__.py` | Placeholder | +| Various `__init__.py` files | Package inits | + +### New example configs (for the new training architecture) + +| File | Summary | +|------|---------| +| `examples/train/rfc.md` | Training architecture RFC document | +| `examples/train/issue.md` | Public issue for community discussion | +| `examples/train/run.sh` | Example launch script | +| `examples/train/example.yaml` | Generic example config | +| `examples/train/distill_wan2.1_t2v_1.3B_dmd2.yaml` | DMD2 distillation on Wan 2.1 | +| `examples/train/finetune_wan2.1_t2v_1.3B_vsa_phase3.4_0.9sparsity.yaml` | VSA finetune on Wan 2.1 | +| `examples/train/finetune_wangame2.1_i2v_1.3B.yaml` | Finetune WanGame (new arch) | +| `examples/train/dfsft_wangame_causal_v3.yaml` | DFSFT on causal WanGame (new arch) | +| `examples/train/self_forcing_wangame_causal_v3.yaml` | Self-forcing on causal WanGame (new arch) | + +### Modified files outside `fastvideo/train/` (needed by the new architecture) + +| File | What changed | Necessary? 
| +|------|-------------|------------| +| `fastvideo/configs/pipelines/base.py` | Adds `sampler_kind` (ode/sde) and `ode_solver` config fields | Yes — pluggable sampler strategy | +| `fastvideo/dataset/parquet_dataset_map_style.py` | Multi-path with repeat counts (`/dir:2`), epoch reshuffling, hash-based caching | Yes — flexible dataset composition | +| `fastvideo/fastvideo_args.py` | Adds `reshuffle_each_epoch`, `validation_num_samples`, action training flags | Partially — some WanGame-specific | +| `fastvideo/models/loader/component_loader.py` | Removes unused imports; fixes FSDP exclusions | Partially — some cleanup, some WanGame | +| `fastvideo/pipelines/basic/wan/wan_pipeline.py` | Refactors to use `build_wan_scheduler()`, pluggable ODE/SDE sampler | Yes — sampler abstraction | +| `fastvideo/pipelines/basic/wan/wan_dmd_pipeline.py` | Thins to compatibility wrapper (sets `sampler_kind=sde`, delegates to WanPipeline) | Yes — aligns with sampler refactor | +| `fastvideo/pipelines/basic/wan/wan_i2v_dmd_pipeline.py` | Removes `TimestepPreparationStage` (not needed for SDE) | Yes — aligns with sampler refactor | +| `fastvideo/pipelines/pipeline_batch_info.py` | Adds `sampling_timesteps` field to `ForwardBatch` | Yes — SDE denoising needs explicit timesteps | +| `fastvideo/pipelines/stages/__init__.py` | Exports `SdeDenoisingStage`, `MatrixGameCausalOdeDenoisingStage` | Yes — new stages | +| `fastvideo/training/checkpointing_utils.py` | Fixes activation-checkpoint wrapper key renaming in `ModelWrapper.state_dict()` | Yes — bugfix for grad checkpointing + DCP | +| `fastvideo/training/distillation_pipeline.py` | Adds `num_samples` to validation, hasattr checks, video saving | Partially — some robustness, some WanGame | +| `fastvideo/training/training_pipeline.py` | Epoch reshuffling, deterministic seeds, trainable param counting | Yes — training improvements | +| `fastvideo/training/training_utils.py` | Adds `count_trainable_total()` for distributed param counting | 
Yes — FSDP-aware param logging | + +--- + +## Category C — Standalone Bugfixes / Improvements (6 files) + +Small fixes and improvements unrelated to either WanGame or the new +training architecture. These could be merged independently. + +| File | What changed | +|------|-------------| +| `.gitignore` | Adds `*.npy`, `slurm_outputs/` | +| `fastvideo/configs/sample/wan.py` | Updates `Wan2_1_Fun_1_3B_InP_SamplingParam` defaults (resolution 352x640, 77 frames, 25fps, guidance 1.0, 40 steps) | +| `fastvideo/configs/pipelines/wan.py` | Same sampling param updates | +| `fastvideo/training/trackers.py` | Adds `log_file()` to `BaseTracker` / `WandbTracker` / `SequentialTracker` | +| `fastvideo/utils.py` | Adds `.cpu()` before `.numpy()` on GPU tensor — **bugfix** | +| `fastvideo/models/dits/matrixgame/utils.py` | Code reformatting of already-commented drawing functions — **no functional change** | + +--- + +## Summary Table + +| Category | New files | Modified files | Lines added | +|----------|-----------|----------------|-------------| +| **A. WanGame** | ~60 | 13 | ~12,600 | +| **B. `fastvideo/train/`** | ~40 | 13 | ~8,000 | +| **C. Bugfixes** | 0 | 6 | ~100 | +| **Overlap (A+B)** | — | ~4 | — | + +--- + +## Recommended Review Order + +1. **Category C first** — 6 small, independent changes. Quick to review + and merge separately if desired. + +2. **Category B (`fastvideo/train/`)** — The core architecture. Start + with `base.py` (model + method ABCs), then `trainer.py`, then the + four method implementations, then utils. + +3. **Category A (WanGame)** — Larger but mostly additive. The key + question is whether the legacy `fastvideo/training/wangame_*.py` + pipelines (~3,500 lines) are still needed alongside the new + `fastvideo/train/` architecture. + +4. **Overlap files** — Files modified for both WanGame and the training + architecture (e.g., `fastvideo_args.py`, `component_loader.py`, + `distillation_pipeline.py`, `denoising.py`). 
Review these last + since they require understanding both contexts. + +--- + +## Open Questions + +1. **Legacy training pipelines**: 6 new files under + `fastvideo/training/wangame_*.py` (~3,500 lines) use the old + training pipeline pattern. Are these still needed, or are they + superseded by `fastvideo/train/` configs? + +2. **`matrixgame/utils.py`**: 351-line diff that appears to be + formatting-only on commented-out code. Drop? + +3. **`visualize_trajectory.py`**: Top-level script (224 lines). Should + this live under `examples/` or `scripts/` instead? + +4. **Sampling param changes** (`configs/sample/wan.py`, + `configs/pipelines/wan.py`): Changed resolution from 480x832 to + 352x640, frames 81->77, guidance 6.0->1.0. Is this intentional for + all users or WanGame-specific? + +5. **`fastvideo_args.py` additions**: Several new flags + (`train_action_only`, `action_train_target`, + `action_warmup_steps`, `best_checkpoint_start_step`) appear + WanGame-specific. Should these live in a WanGame-specific config + instead of the shared args? diff --git a/examples/train/rfc.md b/examples/train/rfc.md new file mode 100644 index 000000000..8ac81430d --- /dev/null +++ b/examples/train/rfc.md @@ -0,0 +1,142 @@ + + +## 1) File Structure + +fastvideo/train/ + trainer.py # Training loop; calls method.train_one_step() + models/ + base.py # BaseModel ABC: predict_x0, add_noise, backward, ... 
+ wan/ + wan.py # Wan model loader + wangame/ + wangame.py # WanGame model loader + wangame_causal.py + methods/ + base.py # DistillMethod base; methods provide train_one_step + distribution_matching/ + dmd2.py # DMD2 distillation (student/teacher/critic) + self_forcing.py # Self-forcing distillation + fine_tuning/ + finetune.py # SFT finetuning (student only) + dfsft.py # Diffusion-forcing SFT + knowledge_distillation/ + consistency_model/ + callbacks/ + callback.py # CallbackDict registry + grad_clip.py # Gradient clipping + optional per-module norm logging + validation.py # Periodic validation via inference pipeline + ema.py # EMA weight averaging + entrypoint/ + train.py # YAML-only CLI entrypoint (torchrun -m fastvideo.train.entrypoint.train) + dcp_to_diffusers.py # Checkpoint conversion + utils/ + config.py # YAML parser -> RunConfig + builder.py # build_from_config: instantiate models, method, dataloader + instantiate.py # _target_ based instantiation + training_config.py # TrainingConfig dataclass (all training settings with defaults) + dataloader.py # Dataset / dataloader construction + moduleloader.py # Dynamic module import + module_state.py # apply_trainable(): requires_grad + train/eval + optimizer.py # Optimizer construction + tracking.py # W&B tracker (owned by trainer) + checkpoint.py # Save/resume with DCP + validation.py # Validation helpers + +With this design, we only need a YAML config to train different models using different methods. +Models declare `_target_` to select the model class; methods declare `_target_` to select the method class. +Current code: https://github.com/FoundationResearch/FastVideo/tree/distill1/fastvideo/train + +DMD2 Distillation, Self-Forcing, SFT, and DFSFT are tested on Wan / WanGame. + +Currently supported models: Wan, WanGame. +Currently supported methods: DMD2, Self-Forcing, SFT, DFSFT. + +Feedback is highly welcome! 
+ + +## 2) Example YAML (DMD2 8-step) + +```yaml +models: + student: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: true + teacher: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: false + disable_custom_init_weights: true + critic: + _target_: fastvideo.train.models.wan.WanModel + init_from: Wan-AI/Wan2.1-T2V-1.3B-Diffusers + trainable: true + disable_custom_init_weights: true + +method: + _target_: fastvideo.train.methods.distribution_matching.dmd2.DMD2Method + rollout_mode: simulate + generator_update_interval: 5 + real_score_guidance_scale: 3.5 + dmd_denoising_steps: [1000, 850, 700, 550, 350, 275, 200, 125] + + # Critic optimizer (required) + fake_score_learning_rate: 8.0e-6 + fake_score_betas: [0.0, 0.999] + fake_score_lr_scheduler: constant + +training: + distributed: + num_gpus: 8 + sp_size: 1 + tp_size: 1 + + data: + data_path: data/Wan-Syn_77x448x832_600k + dataloader_num_workers: 4 + train_batch_size: 1 + training_cfg_rate: 0.0 + seed: 1000 + num_latent_t: 20 + num_height: 448 + num_width: 832 + num_frames: 77 + + optimizer: + learning_rate: 2.0e-6 + betas: [0.0, 0.999] + weight_decay: 0.01 + lr_scheduler: constant + lr_warmup_steps: 0 + + loop: + max_train_steps: 4000 + gradient_accumulation_steps: 1 + + checkpoint: + output_dir: outputs/wan2.1_dmd2_8steps + training_state_checkpointing_steps: 1000 + checkpoints_total_limit: 3 + + tracker: + project_name: distillation_wan + run_name: wan2.1_dmd2_8steps + + model: + enable_gradient_checkpointing_type: full + +callbacks: + grad_clip: + max_grad_norm: 1.0 + validation: + pipeline_target: fastvideo.pipelines.basic.wan.wan_pipeline.WanPipeline + dataset_file: examples/training/finetune/Wan2.1-VSA/Wan-Syn-Data/validation_4.json + every_steps: 100 + sampling_steps: [8] + sampler_kind: sde + sampling_timesteps: [1000, 850, 700, 550, 350, 275, 200, 125] + guidance_scale: 6.0 + +pipeline: + 
 flow_shift: 8 +``` diff --git a/examples/train/run.sh b/examples/train/run.sh new file mode 100755 index 000000000..da809abb9 --- /dev/null +++ b/examples/train/run.sh @@ -0,0 +1,61 @@ +#!/usr/bin/env bash +# Launch distillation training from a v3 YAML config. +# +# Usage: +# bash examples/train/run.sh <config.yaml> [extra flags] +# +# Examples: +# bash examples/train/run.sh examples/train/self_forcing_wangame_causal_v3.yaml +# bash examples/train/run.sh examples/train/dfsft_wangame_causal_v3.yaml --dry-run +# bash examples/train/run.sh examples/train/dfsft_wangame_causal_v3.yaml \ +# --override-output-dir outputs/my_run +# +# Logs are written to ${LOG_DIR}/<config_name>_<timestamp>.log (and also printed to stdout). + +set -euo pipefail + +CONFIG="${1:?Usage: $0 [extra flags...]}" +shift + +# ── GPU / node settings ────────────────────────────────────────────── +NUM_GPUS="${NUM_GPUS:-$(nvidia-smi -L 2>/dev/null | wc -l)}" +NUM_GPUS="${NUM_GPUS:-1}" +NNODES="${NNODES:-1}" +NODE_RANK="${NODE_RANK:-0}" +MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}" +MASTER_PORT="${MASTER_PORT:-29501}" + +# ── W&B ────────────────────────────────────────────────────────────── +export WANDB_API_KEY="${WANDB_API_KEY:-}" +export WANDB_MODE="${WANDB_MODE:-online}" + +# ── Log file ───────────────────────────────────────────────────────── +CONFIG_NAME="$(basename "${CONFIG}" .yaml)" +TIMESTAMP="$(date +%Y%m%d_%H%M%S)" +LOG_DIR="${LOG_DIR:-examples/distillation/refactor}" +mkdir -p "${LOG_DIR}" +LOG_FILE="${LOG_DIR}/${CONFIG_NAME}_${TIMESTAMP}.log" + +source ~/conda/miniconda/bin/activate +conda activate alexfv + +echo "=== Distillation Training ===" +echo "Config: ${CONFIG}" +echo "Num GPUs: ${NUM_GPUS}" +echo "Num Nodes: ${NNODES}" +echo "Node Rank: ${NODE_RANK}" +echo "Master: ${MASTER_ADDR}:${MASTER_PORT}" +echo "Extra args: $*" +echo "Log file: ${LOG_FILE}" +echo "==============================" + +torchrun \ + --nnodes 
"${NNODES}" \ + --node_rank "${NODE_RANK}" \ + --nproc_per_node "${NUM_GPUS}" \ + --master_addr "${MASTER_ADDR}" \ + --master_port "${MASTER_PORT}" \ + fastvideo/train/entrypoint/train.py \ + --config "${CONFIG}" \ + "$@" \ + 2>&1 | tee "${LOG_FILE}" diff --git a/examples/train/self_forcing_wangame_causal_v3.yaml b/examples/train/self_forcing_wangame_causal_v3.yaml new file mode 100644 index 000000000..c72a4dc8c --- /dev/null +++ b/examples/train/self_forcing_wangame_causal_v3.yaml @@ -0,0 +1,122 @@ +# V3 config: WanGame causal Self-Forcing distillation (40-step teacher -> 4-step student). +# +# Uses _target_-based instantiation — each model role is an independent +# class instance; the method class is resolved directly from the YAML. +# +# To warmstart from a DCP checkpoint, first convert it to diffusers format +# using `dcp_to_diffusers`, then point `init_from` at the converted directory. + +models: + student: + _target_: fastvideo.train.models.wangame.WanGameCausalModel + # TODO: update to converted diffusers path + init_from: /mnt/weka/home/hao.zhang/kaiqin/wg_models/WanGame-2.1-0223-9000steps + trainable: true + teacher: + _target_: fastvideo.train.models.wangame.WanGameCausalModel + # TODO: update to converted diffusers path + init_from: /mnt/weka/home/hao.zhang/kaiqin/wg_models/WanGame-2.1-0223-9000steps + trainable: false + disable_custom_init_weights: true + critic: + _target_: fastvideo.train.models.wangame.WanGameCausalModel + # TODO: update to converted diffusers path + init_from: /mnt/weka/home/hao.zhang/kaiqin/wg_models/WanGame-2.1-0223-9000steps + trainable: true + disable_custom_init_weights: true + +method: + _target_: fastvideo.train.methods.distribution_matching.self_forcing.SelfForcingMethod + # use_ema: true + rollout_mode: simulate + generator_update_interval: 5 + real_score_guidance_scale: 3.5 + + # Critic / fake-score optimizer + fake_score_learning_rate: 8.0e-6 + fake_score_betas: [0.0, 0.999] + fake_score_lr_scheduler: constant + + 
warp_denoising_step: true + dmd_denoising_steps: [1000,750,500,250] + + chunk_size: 3 + student_sample_type: sde + same_step_across_blocks: false + last_step_only: false + context_noise: 0.0 + enable_gradient_in_rollout: true + start_gradient_frame: 0 + + cfg_uncond: + on_missing: error + action: keep + image: keep + text: keep + +training: + distributed: + num_gpus: 32 + sp_size: 1 + tp_size: 1 + hsdp_replicate_dim: 1 + hsdp_shard_dim: 32 + + data: + data_path: >- + /mnt/weka/home/hao.zhang/mhuo/traindata_0205_1330/data/0_static_plus_w_only/preprocessed:1, + /mnt/weka/home/hao.zhang/mhuo/traindata_0205_1330/data/1_wasd_only/preprocessed:1 + dataloader_num_workers: 4 + train_batch_size: 1 + training_cfg_rate: 0.0 + seed: 1000 + num_latent_t: 20 + num_height: 352 + num_width: 640 + num_frames: 77 + + optimizer: + learning_rate: 2.0e-6 + betas: [0.0, 0.999] + weight_decay: 0.01 + lr_scheduler: constant + lr_warmup_steps: 0 + + loop: + max_train_steps: 4000 + gradient_accumulation_steps: 1 + + checkpoint: + output_dir: outputs/wangame_self_forcing_4steps_v3 + training_state_checkpointing_steps: 1000 + checkpoints_total_limit: 3 + + tracker: + project_name: distillation_wangame_r + run_name: wangame_self_forcing_4steps_v3 + + model: + enable_gradient_checkpointing_type: null + +callbacks: + grad_clip: + _target_: fastvideo.train.callbacks.grad_clip.GradNormClipCallback + max_grad_norm: 1.0 + # ema: + # _target_: fastvideo.train.callbacks.ema.EMACallback + # beta: 0.9999 + validation: + _target_: fastvideo.train.callbacks.validation.ValidationCallback + pipeline_target: fastvideo.pipelines.basic.wan.wangame_causal_dmd_pipeline.WanGameCausalDMDPipeline + dataset_file: examples/training/finetune/WanGame2.1_1.3b_i2v/validation_random_8.json + every_steps: 100 + sampling_steps: [4] + sampler_kind: sde + rollout_mode: streaming + guidance_scale: 1.0 + num_frames: 69 + dmd_denoising_steps: [1000, 750, 500, 250] + +pipeline: + flow_shift: 3 + sampler_kind: sde diff --git 
a/examples/training/consistency_finetune/causal_wangame_ode_init/ar_diff.slurm b/examples/training/consistency_finetune/causal_wangame_ode_init/ar_diff.slurm new file mode 100644 index 000000000..1fb83ad48 --- /dev/null +++ b/examples/training/consistency_finetune/causal_wangame_ode_init/ar_diff.slurm @@ -0,0 +1,126 @@ +#!/bin/bash +#SBATCH --job-name=wg-ar-diff +#SBATCH --partition=main +#SBATCH --nodes=4 +#SBATCH --ntasks=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=128 +#SBATCH --mem=1440G +#SBATCH --output=log/ar_diff_output/ar_diff_%j.out +#SBATCH --error=log/ar_diff_output/ar_diff_%j.err +#SBATCH --exclusive + +# Environment Setup +source ~/conda/miniconda/bin/activate +conda activate /mnt/weka/home/hao.zhang/conda/miniconda/envs/mhuo-fv +export PYTHONPATH="/mnt/weka/home/hao.zhang/kaiqin/FastVideo:$PYTHONPATH" + +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export MASTER_PORT=29503 +export TOKENIZERS_PARALLELISM=false +export WANDB_MODE=online +export FASTVIDEO_ATTENTION_BACKEND=FLASH_ATTN + +# Basic Info +export WANDB_API_KEY="7ff8b6e8356924f7a6dd51a0342dd1a422ea9352" +export WANDB_MODE="online" +export NCCL_P2P_DISABLE=1 +export MASTER_PORT=29500 +export NODE_RANK=$SLURM_PROCID +nodes=( $(scontrol show hostnames $SLURM_JOB_NODELIST) ) +export MASTER_ADDR=${nodes[0]} +export TOKENIZERS_PARALLELISM=false + +echo "MASTER_ADDR: $MASTER_ADDR" +echo "NODE_RANK: $NODE_RANK" + +RUN_NAME=$(date +"%m%d_%H%M") +echo "RUN_NAME: $RUN_NAME" + +MODEL_PATH="../wg_models/WanGame-2.1-0223-9000steps" + +DATA_DIR="../traindata_0208_2000/data/wasd4holdrandview_simple_1key1mouse1/preprocessed" +VALIDATION_DATASET_FILE="examples/training/consistency_finetune/causal_wangame_ode_init/validation_same.json" + +training_args=( + --tracker_project_name wangame_ar_diffusion + --output_dir "checkpoints/wangame_ar_diffusion_${RUN_NAME}" + --wandb_run_name "${RUN_NAME}_df_bs32" + --override_transformer_cls_name 
"CausalWanGameActionTransformer3DModel" + --max_train_steps 5000 + --train_batch_size 1 + --train_sp_batch_size 1 + --gradient_accumulation_steps 1 + --num_latent_t 15 + --num_height 352 + --num_width 640 + --enable_gradient_checkpointing_type "full" + --num_frames 57 + --num_frame_per_block 3 +) + +parallel_args=( + --num_gpus 32 + --sp_size 1 + --tp_size 1 + --hsdp_replicate_dim 1 + --hsdp_shard_dim 32 +) + +model_args=( + --model_path $MODEL_PATH + --pretrained_model_name_or_path $MODEL_PATH +) + +dataset_args=( + --data_path "$DATA_DIR" + --dataloader_num_workers 4 +) + +validation_args=( + --log_validation + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 200 + --validation_sampling_steps "50" + --validation_guidance_scale "1.0" +) + +optimizer_args=( + --learning_rate 2e-5 + --mixed_precision "bf16" + --training_state_checkpointing_steps 500 + --weight_only_checkpointing_steps 500 + --weight_decay 0.01 + --betas '0.9,0.999' + --max_grad_norm 1.0 + --lr_scheduler cosine + --lr_warmup_steps 100 +) + +miscellaneous_args=( + --inference_mode False + --checkpoints_total_limit 3 + --training_cfg_rate 0.0 + --dit_precision "fp32" + --flow_shift 8 + --seed 42 +) + +mkdir -p log/ar_diff_output + +srun torchrun \ +--nnodes $SLURM_JOB_NUM_NODES \ +--nproc_per_node 8 \ +--node_rank $SLURM_PROCID \ +--rdzv_backend=c10d \ +--rdzv_endpoint="$MASTER_ADDR:$MASTER_PORT" \ + fastvideo/training/wangame_ar_diffusion_pipeline.py \ + "${parallel_args[@]}" \ + "${model_args[@]}" \ + "${dataset_args[@]}" \ + "${training_args[@]}" \ + "${optimizer_args[@]}" \ + "${validation_args[@]}" \ + "${miscellaneous_args[@]}" diff --git a/examples/training/consistency_finetune/causal_wangame_ode_init/finetune_ode_init.sh b/examples/training/consistency_finetune/causal_wangame_ode_init/finetune_ode_init.sh new file mode 100644 index 000000000..9a351607b --- /dev/null +++ b/examples/training/consistency_finetune/causal_wangame_ode_init/finetune_ode_init.sh @@ -0,0 +1,99 
@@ +#!/bin/bash + +export PYTHONPATH="/mnt/fast-disks/hao_lab/kaiqin/FastVideo_wangame:$PYTHONPATH" +export WANDB_API_KEY="7ff8b6e8356924f7a6dd51a0342dd1a422ea9352" +export WANDB_BASE_URL="https://api.wandb.ai" +export WANDB_MODE=online +export TOKENIZERS_PARALLELISM=false + +MODEL_PATH="Wan2.1-Fun-1.3B-InP-Diffusers" +DATA_DIR="../traindata_0209_1500/ode_init_mc/preprocessed/combined_parquet_dataset/worker_0" +VALIDATION_DATASET_FILE="$(dirname "$0")/validation.json" +NUM_GPUS=1 +export CUDA_VISIBLE_DEVICES=4,5,6,7 +# IP=[MASTER NODE IP] + +# Training arguments +training_args=( + --tracker_project_name "wangame_ode_init" + --output_dir "checkpoints/wangame_ode_init" + --override_transformer_cls_name "CausalWanGameActionTransformer3DModel" + --wandb_run_name "0213_2100_test" + --max_train_steps 1 + --train_batch_size 1 + --train_sp_batch_size 1 + --gradient_accumulation_steps 1 + --num_latent_t 21 + --num_height 352 + --num_width 640 + --num_frames 81 + --warp_denoising_step + --enable_gradient_checkpointing_type "full" +) + +# Parallel arguments +parallel_args=( + --num_gpus $NUM_GPUS + --sp_size 1 + --tp_size 1 + --hsdp_replicate_dim 1 + --hsdp_shard_dim $NUM_GPUS +) + +# Model arguments +model_args=( + --model_path $MODEL_PATH + --pretrained_model_name_or_path $MODEL_PATH +) + +# Dataset arguments +dataset_args=( + --data_path "$DATA_DIR" + --dataloader_num_workers 1 +) + +# Validation arguments +validation_args=( + --log_validation + --log-visualization + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 100 + --visualization-steps 100 + --validation_sampling_steps "50" + --validation_guidance_scale "6.0" +) + +# Optimizer arguments +optimizer_args=( + --learning_rate 6e-6 + --mixed_precision "bf16" + --weight_only_checkpointing_steps 200 + --training_state_checkpointing_steps 200 + --weight_decay 1e-4 + --max_grad_norm 1.0 +) + +# Miscellaneous arguments +miscellaneous_args=( + --inference_mode False + --checkpoints_total_limit 3 + 
--training_cfg_rate 0.1 + --multi_phased_distill_schedule "4000-1" + --not_apply_cfg_solver + --dit_precision "fp32" + --num_euler_timesteps 50 + --ema_start_step 0 +) + +# If you do not have 32 GPUs and to fit in memory, you can: 1. increase sp_size. 2. reduce num_latent_t +torchrun \ + --nnodes 1 \ + --nproc_per_node $NUM_GPUS \ + fastvideo/training/wangame_ode_causal_pipeline.py \ + "${parallel_args[@]}" \ + "${model_args[@]}" \ + "${dataset_args[@]}" \ + "${training_args[@]}" \ + "${optimizer_args[@]}" \ + "${validation_args[@]}" \ + "${miscellaneous_args[@]}" diff --git a/examples/training/consistency_finetune/causal_wangame_ode_init/finetune_ode_init.slurm b/examples/training/consistency_finetune/causal_wangame_ode_init/finetune_ode_init.slurm new file mode 100644 index 000000000..cd02796fe --- /dev/null +++ b/examples/training/consistency_finetune/causal_wangame_ode_init/finetune_ode_init.slurm @@ -0,0 +1,131 @@ +#!/bin/bash +#SBATCH --job-name=wg-ode +#SBATCH --partition=main +#SBATCH --nodes=4 +#SBATCH --ntasks=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=128 +#SBATCH --mem=1440G +#SBATCH --output=ode_train_output/ode_%j.out +#SBATCH --error=ode_train_output/ode_%j.err +#SBATCH --exclusive + +set -e -x + +# Environment Setup +source ~/conda/miniconda/bin/activate /mnt/weka/home/hao.zhang/conda/miniconda/envs/mhuo-fv +export PYTHONPATH="/mnt/weka/home/hao.zhang/kaiqin/FastVideo:$PYTHONPATH" + +# Basic Info +export WANDB_API_KEY="7ff8b6e8356924f7a6dd51a0342dd1a422ea9352" +export WANDB_MODE=online +export NCCL_P2P_DISABLE=1 +export MASTER_PORT=29500 +export NODE_RANK=$SLURM_PROCID +nodes=( $(scontrol show hostnames $SLURM_JOB_NODELIST) ) +export MASTER_ADDR=${nodes[0]} +export TOKENIZERS_PARALLELISM=false + +echo "MASTER_ADDR: $MASTER_ADDR" +echo "NODE_RANK: $NODE_RANK" + +RUN_NAME=$(date +"%m%d_%H%M") +echo "RUN_NAME: $RUN_NAME" + +# Configs +MODEL_PATH="weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers" +# 
DATA_DIR="../traindata_0222_0030/ode_init_mc_Xonly_3k/preprocessed" +DATA_DIR="../traindata_0222_0030/ode_init_mc_random/preprocessed_wangame" +# VALIDATION_DATASET_FILE="examples/training/consistency_finetune/causal_wangame_ode_init/validation.json" +VALIDATION_DATASET_FILE="examples/training/consistency_finetune/causal_wangame_ode_init/validation_same.json" +CKPT_SAFETENSOR="/mnt/weka/home/hao.zhang/mhuo/FastVideo/wangame_1.3b_1action_rand_from_scratch/checkpoint-9000/transformer/diffusion_pytorch_model.safetensors" + +# Training arguments +training_args=( + --tracker_project_name "wangame_ode_init" + --output_dir "checkpoints/wangame_ode_init_${RUN_NAME}" + --override_transformer_cls_name "CausalWanGameActionTransformer3DModel" + --wandb_run_name "${RUN_NAME}_bs64_random" + --max_train_steps 5000 + --train_batch_size 1 + --train_sp_batch_size 1 + --gradient_accumulation_steps 2 + --num_latent_t 21 + --num_height 352 + --num_width 640 + --num_frames 81 + --warp_denoising_step + --enable_gradient_checkpointing_type "full" + --init_weights_from_safetensors $CKPT_SAFETENSOR +) + +# Parallel arguments +parallel_args=( + --num_gpus 32 + --sp_size 1 + --tp_size 1 + --hsdp_replicate_dim 1 + --hsdp_shard_dim 32 +) + +# Model arguments +model_args=( + --model_path $MODEL_PATH + --pretrained_model_name_or_path $MODEL_PATH +) + +# Dataset arguments +dataset_args=( + --data_path "$DATA_DIR" + --dataloader_num_workers 4 +) + +# Validation arguments +validation_args=( + --log_validation + --log-visualization + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 100 + --visualization-steps 100 + --validation_sampling_steps "50" + --validation_guidance_scale "6.0" +) + +# Optimizer arguments +optimizer_args=( + --learning_rate 6e-6 + --mixed_precision "bf16" + --weight_only_checkpointing_steps 500 + --training_state_checkpointing_steps 500 + --weight_decay 1e-4 + --max_grad_norm 1.0 +) + +# Miscellaneous arguments +miscellaneous_args=( + --inference_mode 
False + --checkpoints_total_limit 3 + --training_cfg_rate 0.1 + --multi_phased_distill_schedule "4000-1" + --not_apply_cfg_solver + --dit_precision "fp32" + --num_euler_timesteps 50 +) + +mkdir -p ode_train_output + +srun torchrun \ +--nnodes $SLURM_JOB_NUM_NODES \ +--nproc_per_node 8 \ +--node_rank $SLURM_PROCID \ +--rdzv_backend=c10d \ +--rdzv_endpoint="$MASTER_ADDR:$MASTER_PORT" \ + fastvideo/training/wangame_ode_causal_pipeline.py \ + "${parallel_args[@]}" \ + "${model_args[@]}" \ + "${dataset_args[@]}" \ + "${training_args[@]}" \ + "${optimizer_args[@]}" \ + "${validation_args[@]}" \ + "${miscellaneous_args[@]}" diff --git a/examples/training/consistency_finetune/causal_wangame_ode_init/launch_preprocess_slurm.sh b/examples/training/consistency_finetune/causal_wangame_ode_init/launch_preprocess_slurm.sh new file mode 100644 index 000000000..a07f03760 --- /dev/null +++ b/examples/training/consistency_finetune/causal_wangame_ode_init/launch_preprocess_slurm.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# Create output directory if it doesn't exist +mkdir -p preprocess_output + +# Launch 4 jobs, one for each node (Total 32 GPUs) +# Each node processes 8 consecutive files (32 total files / 4 nodes = 8 files per node) +for node_id in {0..3}; do + # Calculate the starting file number for this node + start_file=$((node_id * 8)) + + echo "Launching node $node_id with files merge_${start_file}.txt to merge_$((start_file + 7)).txt" + + sbatch --job-name=wg-pre-${node_id} \ + --output=preprocess_output/wg-node-${node_id}.out \ + --error=preprocess_output/wg-node-${node_id}.err \ + $(pwd)/FastVideo/examples/training/consistency_finetune/causal_wangame_ode_init/preprocess_worker.slurm $start_file $node_id +done + +echo "All 4 nodes (32 GPUs) launched successfully!" 
diff --git a/examples/training/consistency_finetune/causal_wangame_ode_init/ode_finetune_worker.slurm b/examples/training/consistency_finetune/causal_wangame_ode_init/ode_finetune_worker.slurm new file mode 100644 index 000000000..6311b44fa --- /dev/null +++ b/examples/training/consistency_finetune/causal_wangame_ode_init/ode_finetune_worker.slurm @@ -0,0 +1,61 @@ +#!/bin/bash +#SBATCH --partition=main +#SBATCH --qos=hao +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=8 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=960G +#SBATCH --exclusive +#SBATCH --time=72:00:00 + +# conda init +source ~/conda/miniconda/bin/activate +conda activate fastvideo_kaiqin + +# Accept parameters from launch script +START_FILE=${1:-0} # Starting file number for this node +NODE_ID=${2:-0} # Node identifier (0-7) + +MODEL_PATH="../Matrix-Game-2.0-Base-Diffusers" +OUTPUT_BASE="../FastvideoWorldModel-MC/preprocessed" + +# Port range calculation +base_port=$((29700 + NODE_ID * 100)) # Using a different port range to avoid collision with other tasks +gpu_ids=(0 1 2 3 4 5 6 7) + +for i in {1..8}; do + port=$((base_port + i)) + gpu=${gpu_ids[((i-1))]} + file_num=$((START_FILE + i - 1)) + + DATA_MERGE_PATH="../FastvideoWorldModel-MC/gen/merge_${file_num}.txt" + OUTPUT_DIR="${OUTPUT_BASE}/gpu_${gpu}_file_${file_num}" + + # CPU binding + start_cpu=$(( (i-1)*2 )) + end_cpu=$(( start_cpu+1 )) + + echo "Starting GPU $gpu processing file merge_${file_num}.txt on port $port" + + CUDA_VISIBLE_DEVICES=$gpu taskset -c ${start_cpu}-${end_cpu} torchrun --nnodes=1 --nproc_per_node=1 --master_port $port \ + FastVideo/fastvideo/pipelines/preprocess/v1_preprocess.py \ + --model_path $MODEL_PATH \ + --data_merge_path $DATA_MERGE_PATH \ + --preprocess_video_batch_size 1 \ + --seed 42 \ + --max_height 352 \ + --max_width 640 \ + --num_frames 81 \ + --flow_shift 5.0 \ + --dataloader_num_workers 0 \ + --output_dir=$OUTPUT_DIR \ + --train_fps 25 \ + --samples_per_file 8 \ + --flush_frequency 8 \ + 
--video_length_tolerance_range 5 \ + --preprocess_task "matrixgame_ode_trajectory" & +done + +wait +echo "Node $NODE_ID ODE preprocessing blocks completed!" diff --git a/examples/training/consistency_finetune/causal_wangame_ode_init/preprocess_data.sh b/examples/training/consistency_finetune/causal_wangame_ode_init/preprocess_data.sh new file mode 100644 index 000000000..172de86bf --- /dev/null +++ b/examples/training/consistency_finetune/causal_wangame_ode_init/preprocess_data.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +export PYTHONPATH="/mnt/fast-disks/hao_lab/kaiqin/FastVideo_wangame:$PYTHONPATH" + +GPU_NUM=1 # 2,4,8 +MODEL_PATH="./Wan2.1-Fun-1.3B-InP-Diffusers" +DATA_MERGE_PATH="../traindata_0209_1500/ode_init_mc/merge.txt" +OUTPUT_DIR="../traindata_0209_1500/ode_init_mc/preprocessed" + +torchrun --nproc_per_node=$GPU_NUM \ + fastvideo/pipelines/preprocess/v1_preprocess.py \ + --model_path $MODEL_PATH \ + --data_merge_path $DATA_MERGE_PATH \ + --preprocess_video_batch_size 1 \ + --seed 42 \ + --max_height 352 \ + --max_width 640 \ + --num_frames 81 \ + --dataloader_num_workers 0 \ + --output_dir=$OUTPUT_DIR \ + --samples_per_file 8 \ + --train_fps 25 \ + --flush_frequency 8 \ + --preprocess_task wangame_ode_trajectory & diff --git a/examples/training/consistency_finetune/causal_wangame_ode_init/preprocess_worker.slurm b/examples/training/consistency_finetune/causal_wangame_ode_init/preprocess_worker.slurm new file mode 100644 index 000000000..79a57304e --- /dev/null +++ b/examples/training/consistency_finetune/causal_wangame_ode_init/preprocess_worker.slurm @@ -0,0 +1,62 @@ +#!/bin/bash +#SBATCH --partition=main +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=8 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=960G +#SBATCH --exclusive +#SBATCH --time=72:00:00 + +# conda init +source ~/conda/miniconda/bin/activate /mnt/weka/home/hao.zhang/conda/miniconda/envs/mhuo-fv +export PYTHONPATH="/mnt/weka/home/hao.zhang/kaiqin/FastVideo:$PYTHONPATH" + +# Accept 
parameters from launch script +START_FILE=${1:-1} # Starting file number for this node +NODE_ID=${2:-0} # Node identifier (0-3) + +MODEL_PATH="./Wan2.1-Fun-1.3B-InP-Diffusers" +# OUTPUT_BASE="traindata_0222_0030/ode_init_mc_Xonly_3k/preprocessed" +OUTPUT_BASE="traindata_0222_0030/ode_init_mc_same/preprocessed" + +# Port range calculation +base_port=$((29500 + NODE_ID * 100)) +gpu_ids=(0 1 2 3 4 5 6 7) + +for i in {1..8}; do + port=$((base_port + i)) + gpu=${gpu_ids[((i-1))]} + file_num=$((START_FILE + i - 1)) + + # DATA_MERGE_PATH="traindata_0222_0030/ode_init_mc_Xonly_3k/merge_${file_num}.txt" + DATA_MERGE_PATH="traindata_0222_0030/ode_init_mc_same/merge_${file_num}.txt" + OUTPUT_DIR="${OUTPUT_BASE}/gpu_${gpu}_file_${file_num}" + echo "DATA_MERGE_PATH: $DATA_MERGE_PATH" + echo "OUTPUT_DIR: $OUTPUT_DIR" + + # CPU binding (optional, kept from syn.slurm logic) + start_cpu=$(( (i-1)*2 )) + end_cpu=$(( start_cpu+1 )) + + echo "Starting GPU $gpu processing file merge_${file_num}.txt on port $port" + + CUDA_VISIBLE_DEVICES=$gpu taskset -c ${start_cpu}-${end_cpu} torchrun --nnodes=1 --nproc_per_node=1 --master_port $port \ + FastVideo/fastvideo/pipelines/preprocess/v1_preprocess.py \ + --model_path $MODEL_PATH \ + --data_merge_path $DATA_MERGE_PATH \ + --preprocess_video_batch_size 1 \ + --seed 42 \ + --max_height 352 \ + --max_width 640 \ + --num_frames 81 \ + --dataloader_num_workers 0 \ + --output_dir=$OUTPUT_DIR \ + --samples_per_file 8 \ + --train_fps 25 \ + --flush_frequency 8 \ + --preprocess_task wangame_ode_trajectory & +done + +wait +echo "Node $NODE_ID processing blocks completed!" 
diff --git a/examples/training/consistency_finetune/causal_wangame_ode_init/validation.json b/examples/training/consistency_finetune/causal_wangame_ode_init/validation.json new file mode 100644 index 000000000..bdd9b6e45 --- /dev/null +++ b/examples/training/consistency_finetune/causal_wangame_ode_init/validation.json @@ -0,0 +1,324 @@ +{ + "data": [ + { + "caption": "51", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000051.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000051_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "229", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000229.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000229_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "250", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000250.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000250_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "380", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000380.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000380_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "382", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000382.jpg", + "action_path": 
"/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000382_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "387", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000387.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000387_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "418", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000418.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000418_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "505", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000505.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000505_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "515", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000515.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000515_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "534", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000534.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000534_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + 
}, + { + "caption": "599", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000599.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000599_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "613", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000613.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000613_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "745", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000745.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000745_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "861", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000861.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000861_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "940", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000940.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000940_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "946", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000946.jpg", + "action_path": 
"/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000946_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "996", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000996.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/000996_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1011", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001011.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001011_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1037", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001037.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001037_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1057", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001057.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001057_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1195", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001195.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001195_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 
81 + }, + { + "caption": "1236", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001236.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001236_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1276", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001276.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001276_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1368", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001368.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001368_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1403", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001403.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001403_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1417", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001417.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001417_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1481", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001481.jpg", + "action_path": 
"/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001481_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1489", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001489.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001489_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1618", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001618.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001618_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1779", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001779.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001779_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1867", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001867.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001867_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1949", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001949.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_Xonly_3k/images/001949_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 
81 + } + ] +} \ No newline at end of file diff --git a/examples/training/consistency_finetune/causal_wangame_ode_init/validation_same.json b/examples/training/consistency_finetune/causal_wangame_ode_init/validation_same.json new file mode 100644 index 000000000..612000782 --- /dev/null +++ b/examples/training/consistency_finetune/causal_wangame_ode_init/validation_same.json @@ -0,0 +1,324 @@ +{ + "data": [ + { + "caption": "51", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000051.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000051_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "229", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000229.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000229_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "250", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000250.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000250_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "380", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000380.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000380_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "382", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000382.jpg", + "action_path": 
"/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000382_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "387", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000387.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000387_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "418", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000418.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000418_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "505", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000505.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000505_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "515", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000515.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000515_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "534", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000534.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000534_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": 
"599", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000599.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000599_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "613", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000613.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000613_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "745", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000745.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000745_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "861", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000861.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000861_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "940", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000940.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000940_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "946", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000946.jpg", + "action_path": 
"/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000946_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "996", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000996.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/000996_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1011", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001011.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001011_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1037", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001037.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001037_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1057", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001057.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001057_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1195", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001195.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001195_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + 
"caption": "1236", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001236.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001236_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1276", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001276.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001276_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1368", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001368.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001368_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1403", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001403.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001403_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1417", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001417.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001417_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1481", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001481.jpg", + "action_path": 
"/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001481_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1489", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001489.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001489_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1618", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001618.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001618_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1779", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001779.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001779_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1867", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001867.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001867_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "1949", + "image_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001949.jpg", + "action_path": "/mnt/weka/home/hao.zhang/kaiqin/traindata_0222_0030/ode_init_mc_random/images/001949_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + } + ] +} \ No 
newline at end of file diff --git a/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/README.md b/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/README.md new file mode 100644 index 000000000..fa58e4bfe --- /dev/null +++ b/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/README.md @@ -0,0 +1,147 @@ +Total Files: 145 + +00: W +01: S +02: A +03: D +04: WA +05: WD +06: SA +07: SD +08: u +09: d +10: l +11: r +12: ur +13: ul +14: dr +15: dl +16: still +17: W_u +18: W_d +19: W_l +20: W_r +21: W_ur +22: W_ul +23: W_dr +24: W_dl +25: S_u +26: S_d +27: S_l +28: S_r +29: S_ur +30: S_ul +31: S_dr +32: S_dl +33: A_u +34: A_d +35: A_l +36: A_r +37: A_ur +38: A_ul +39: A_dr +40: A_dl +41: D_u +42: D_d +43: D_l +44: D_r +45: D_ur +46: D_ul +47: D_dr +48: D_dl +49: WA_u +50: WA_d +51: WA_l +52: WA_r +53: WA_ur +54: WA_ul +55: WA_dr +56: WA_dl +57: WD_u +58: WD_d +59: WD_l +60: WD_r +61: WD_ur +62: WD_ul +63: WD_dr +64: WD_dl +65: SA_u +66: SA_d +67: SA_l +68: SA_r +69: SA_ur +70: SA_ul +71: SA_dr +72: SA_dl +73: SD_u +74: SD_d +75: SD_l +76: SD_r +77: SD_ur +78: SD_ul +79: SD_dr +80: SD_dl +81: key_2_action_rand_1_f4 +82: key_2_action_rand_2_f4 +83: key_2_action_rand_3_f4 +84: key_2_action_rand_4_f4 +85: key_2_action_rand_1 +86: key_2_action_rand_2 +87: key_2_action_rand_3 +88: key_2_action_rand_4 +89: camera_2_action_rand_1_f4 +90: camera_2_action_rand_2_f4 +91: camera_2_action_rand_3_f4 +92: camera_2_action_rand_4_f4 +93: camera_2_action_rand_1 +94: camera_2_action_rand_2 +95: camera_2_action_rand_3 +96: camera_2_action_rand_4 +97: key_camera_2_action_rand_1_f4 +98: key_camera_2_action_rand_2_f4 +99: key_camera_2_action_rand_3_f4 +100: key_camera_2_action_rand_4_f4 +101: key_camera_2_action_rand_1 +102: key_camera_2_action_rand_2 +103: key_camera_2_action_rand_3 +104: key_camera_2_action_rand_4 +105: key_1_action_rand_1_f4 +106: key_1_action_rand_2_f4 +107: key_1_action_rand_3_f4 +108: key_1_action_rand_4_f4 +109: key_1_action_rand_1 +110: key_1_action_rand_2 
+111: key_1_action_rand_3 +112: key_1_action_rand_4 +113: camera_1_action_rand_1_f4 +114: camera_1_action_rand_2_f4 +115: camera_1_action_rand_3_f4 +116: camera_1_action_rand_4_f4 +117: camera_1_action_rand_1 +118: camera_1_action_rand_2 +119: camera_1_action_rand_3 +120: camera_1_action_rand_4 +121: key_camera_1_action_rand_1_f4 +122: key_camera_1_action_rand_2_f4 +123: key_camera_1_action_rand_3_f4 +124: key_camera_1_action_rand_4_f4 +125: key_camera_1_action_rand_1 +126: key_camera_1_action_rand_2 +127: key_camera_1_action_rand_3 +128: key_camera_1_action_rand_4 +129: key_camera_excl_2_action_rand_1_f4 +130: key_camera_excl_2_action_rand_2_f4 +131: key_camera_excl_2_action_rand_3_f4 +132: key_camera_excl_2_action_rand_4_f4 +133: key_camera_excl_2_action_rand_1 +134: key_camera_excl_2_action_rand_2 +135: key_camera_excl_2_action_rand_3 +136: key_camera_excl_2_action_rand_4 +137: key_camera_excl_1_action_rand_1_f4 +138: key_camera_excl_1_action_rand_2_f4 +139: key_camera_excl_1_action_rand_3_f4 +140: key_camera_excl_1_action_rand_4_f4 +141: key_camera_excl_1_action_rand_1 +142: key_camera_excl_1_action_rand_2 +143: key_camera_excl_1_action_rand_3 +144: key_camera_excl_1_action_rand_4 diff --git a/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/README.md b/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/README.md new file mode 100644 index 000000000..de15602b6 --- /dev/null +++ b/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/README.md @@ -0,0 +1,147 @@ +Total Files: 145 + +00: W +01: S +02: A +03: D +04: WA +05: WD +06: SA +07: SD +08: u +09: d +10: l +11: r +12: ur +13: ul +14: dr +15: dl +16: still +17: W_u +18: W_d +19: W_l +20: W_r +21: W_ur +22: W_ul +23: W_dr +24: W_dl +25: S_u +26: S_d +27: S_l +28: S_r +29: S_ur +30: S_ul +31: S_dr +32: S_dl +33: A_u +34: A_d +35: A_l +36: A_r +37: A_ur +38: A_ul +39: A_dr +40: A_dl +41: D_u +42: D_d +43: D_l +44: D_r +45: D_ur +46: D_ul +47: D_dr +48: D_dl +49: WA_u +50: WA_d +51: WA_l +52:
WA_r +53: WA_ur +54: WA_ul +55: WA_dr +56: WA_dl +57: WD_u +58: WD_d +59: WD_l +60: WD_r +61: WD_ur +62: WD_ul +63: WD_dr +64: WD_dl +65: SA_u +66: SA_d +67: SA_l +68: SA_r +69: SA_ur +70: SA_ul +71: SA_dr +72: SA_dl +73: SD_u +74: SD_d +75: SD_l +76: SD_r +77: SD_ur +78: SD_ul +79: SD_dr +80: SD_dl +81: key_2_action_rand_1_f4 +82: key_2_action_rand_2_f4 +83: key_2_action_rand_3_f4 +84: key_2_action_rand_4_f4 +85: key_2_action_rand_1 +86: key_2_action_rand_2 +87: key_2_action_rand_3 +88: key_2_action_rand_4 +89: camera_2_action_rand_1_f4 +90: camera_2_action_rand_2_f4 +91: camera_2_action_rand_3_f4 +92: camera_2_action_rand_4_f4 +93: camera_2_action_rand_1 +94: camera_2_action_rand_2 +95: camera_2_action_rand_3 +96: camera_2_action_rand_4 +97: key_camera_2_action_rand_1_f4 +98: key_camera_2_action_rand_2_f4 +99: key_camera_2_action_rand_3_f4 +100: key_camera_2_action_rand_4_f4 +101: key_camera_2_action_rand_1 +102: key_camera_2_action_rand_2 +103: key_camera_2_action_rand_3 +104: key_camera_2_action_rand_4 +105: key_1_action_rand_1_f4 +106: key_1_action_rand_2_f4 +107: key_1_action_rand_3_f4 +108: key_1_action_rand_4_f4 +109: key_1_action_rand_1 +110: key_1_action_rand_2 +111: key_1_action_rand_3 +112: key_1_action_rand_4 +113: camera_1_action_rand_1_f4 +114: camera_1_action_rand_2_f4 +115: camera_1_action_rand_3_f4 +116: camera_1_action_rand_4_f4 +117: camera_1_action_rand_1 +118: camera_1_action_rand_2 +119: camera_1_action_rand_3 +120: camera_1_action_rand_4 +121: key_camera_1_action_rand_1_f4 +122: key_camera_1_action_rand_2_f4 +123: key_camera_1_action_rand_3_f4 +124: key_camera_1_action_rand_4_f4 +125: key_camera_1_action_rand_1 +126: key_camera_1_action_rand_2 +127: key_camera_1_action_rand_3 +128: key_camera_1_action_rand_4 +129: key_camera_excl_2_action_rand_1_f4 +130: key_camera_excl_2_action_rand_2_f4 +131: key_camera_excl_2_action_rand_3_f4 +132: key_camera_excl_2_action_rand_4_f4 +133: key_camera_excl_2_action_rand_1 +134: 
key_camera_excl_2_action_rand_2 +135: key_camera_excl_2_action_rand_3 +136: key_camera_excl_2_action_rand_4 +137: key_camera_excl_1_action_rand_1_f4 +138: key_camera_excl_1_action_rand_2_f4 +139: key_camera_excl_1_action_rand_3_f4 +140: key_camera_excl_1_action_rand_4_f4 +141: key_camera_excl_1_action_rand_1 +142: key_camera_excl_1_action_rand_2 +143: key_camera_excl_1_action_rand_3 +144: key_camera_excl_1_action_rand_4 diff --git a/examples/training/finetune/WanGame2.1_1.3b_i2v/finetune_i2v.sh b/examples/training/finetune/WanGame2.1_1.3b_i2v/finetune_i2v.sh new file mode 100644 index 000000000..e9869e1f5 --- /dev/null +++ b/examples/training/finetune/WanGame2.1_1.3b_i2v/finetune_i2v.sh @@ -0,0 +1,93 @@ +#!/bin/bash + +export WANDB_BASE_URL="https://api.wandb.ai" +export WANDB_MODE=offline +export TOKENIZERS_PARALLELISM=false +export FASTVIDEO_ATTENTION_BACKEND=FLASH_ATTN + +MODEL_PATH="weizhou03/Wan2.1-Game-Fun-1.3B-InP-Diffusers" +DATA_DIR="mc_wasd_10/preprocessed/combined_parquet_dataset" +VALIDATION_DATASET_FILE="mc_wasd_10/validation.json" +NUM_GPUS=4 +# export CUDA_VISIBLE_DEVICES=0,1,2,3 +# IP=[MASTER NODE IP] + +# Training arguments +training_args=( + --tracker_project_name "wangame_1.3b_overfit" + --output_dir "wangame_1.3b_overfit" + --max_train_steps 1500 + --train_batch_size 1 + --train_sp_batch_size 1 + --gradient_accumulation_steps 1 + --num_latent_t 20 + --num_height 352 + --num_width 640 + --num_frames 77 + --enable_gradient_checkpointing_type "full" +) + +# Parallel arguments +parallel_args=( + --num_gpus $NUM_GPUS + --sp_size 1 + --tp_size 1 + --hsdp_replicate_dim 1 + --hsdp_shard_dim $NUM_GPUS +) + +# Model arguments +model_args=( + --model_path $MODEL_PATH + --pretrained_model_name_or_path $MODEL_PATH +) + +# Dataset arguments +dataset_args=( + --data_path "$DATA_DIR" + --dataloader_num_workers 1 +) + +# Validation arguments +validation_args=( + --log_validation + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 100 + 
--validation_sampling_steps "40" + --validation_guidance_scale "1.0" +) + +# Optimizer arguments +optimizer_args=( + --learning_rate 2e-5 + --mixed_precision "bf16" + --weight_only_checkpointing_steps 1000 + --training_state_checkpointing_steps 1000 + --weight_decay 1e-4 + --max_grad_norm 1.0 +) + +# Miscellaneous arguments +miscellaneous_args=( + --inference_mode False + --checkpoints_total_limit 3 + --training_cfg_rate 0.1 + --multi_phased_distill_schedule "4000-1" + --not_apply_cfg_solver + --dit_precision "fp32" + --num_euler_timesteps 50 + --ema_start_step 0 +) + +# If you do not have 32 GPUs and to fit in memory, you can: 1. increase sp_size. 2. reduce num_latent_t +torchrun \ + --nnodes 1 \ + --nproc_per_node $NUM_GPUS \ + fastvideo/training/wangame_training_pipeline.py \ + "${parallel_args[@]}" \ + "${model_args[@]}" \ + "${dataset_args[@]}" \ + "${training_args[@]}" \ + "${optimizer_args[@]}" \ + "${validation_args[@]}" \ + "${miscellaneous_args[@]}" \ No newline at end of file diff --git a/examples/training/finetune/WanGame2.1_1.3b_i2v/finetune_i2v.slurm b/examples/training/finetune/WanGame2.1_1.3b_i2v/finetune_i2v.slurm new file mode 100644 index 000000000..fc4eb79d2 --- /dev/null +++ b/examples/training/finetune/WanGame2.1_1.3b_i2v/finetune_i2v.slurm @@ -0,0 +1,120 @@ +#!/bin/bash +#SBATCH --job-name=wangame_1.3b_overfit +#SBATCH --partition=main +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=128 +#SBATCH --mem=1440G +#SBATCH --output=wangame_1.3b_overfit_output/wangame_1.3b_overfit_%j.out +#SBATCH --error=wangame_1.3b_overfit_output/wangame_1.3b_overfit_%j.err +#SBATCH --exclusive + +# Basic Info +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export NCCL_DEBUG_SUBSYS=INIT,NET +# different cache dir for different processes +export TRITON_CACHE_DIR=/tmp/triton_cache_${SLURM_PROCID} +export MASTER_PORT=29500 +export NODE_RANK=$SLURM_PROCID +nodes=( $(scontrol show 
hostnames $SLURM_JOB_NODELIST) ) +export MASTER_ADDR=${nodes[0]} +export TOKENIZERS_PARALLELISM=false +# Never commit a real W&B key: export WANDB_API_KEY in the submitting shell (sbatch forwards the caller's environment by default). +export WANDB_API_KEY="${WANDB_API_KEY:?export WANDB_API_KEY before submitting this job}" +# export WANDB_API_KEY='your_wandb_api_key_here' +export WANDB_BASE_URL="https://api.wandb.ai" +export WANDB_MODE=online +export FASTVIDEO_ATTENTION_BACKEND=FLASH_ATTN + +source ~/conda/miniconda/bin/activate +conda activate wei-fv-distill +export HOME="/mnt/weka/home/hao.zhang/wei" + +MODEL_PATH="weizhou03/Wan2.1-Game-Fun-1.3B-InP-Diffusers" +DATA_DIR="mc_wasd_10/preprocessed/combined_parquet_dataset" +VALIDATION_DATASET_FILE="examples/training/finetune/WanGame2.1_1.3b_i2v/validation.json" +# Configs +NUM_GPUS=8 + +# Training arguments +training_args=( + --tracker_project_name "wangame_1.3b_overfit" + --output_dir "wangame_1.3b_overfit" + --max_train_steps 15000 + --train_batch_size 1 + --train_sp_batch_size 1 + --gradient_accumulation_steps 1 + --num_latent_t 20 + --num_height 352 + --num_width 640 + --num_frames 77 + --enable_gradient_checkpointing_type "full" +) + +# Parallel arguments +parallel_args=( + --num_gpus $NUM_GPUS + --sp_size 1 + --tp_size 1 + --hsdp_replicate_dim 1 + --hsdp_shard_dim $NUM_GPUS +) + +# Model arguments +model_args=( + --model_path $MODEL_PATH + --pretrained_model_name_or_path $MODEL_PATH +) + +# Dataset arguments +dataset_args=( + --data_path "$DATA_DIR" + --dataloader_num_workers 1 +) + +# Validation arguments +validation_args=( + --log_validation + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 100 + --validation_sampling_steps "40" + --validation_guidance_scale "1.0" +) + +# Optimizer arguments +optimizer_args=( + --learning_rate 2e-5 + --mixed_precision "bf16" + --weight_only_checkpointing_steps 1000000 + --training_state_checkpointing_steps 10000000 + --weight_decay 1e-4 + --max_grad_norm 1.0 +) + +# Miscellaneous arguments +miscellaneous_args=( + --inference_mode False +
--checkpoints_total_limit 3 + --training_cfg_rate 0.1 + --multi_phased_distill_schedule "4000-1" + --not_apply_cfg_solver + --dit_precision "fp32" + --num_euler_timesteps 50 + --ema_start_step 0 +) + +# If you do not have 32 GPUs and to fit in memory, you can: 1. increase sp_size. 2. reduce num_latent_t +torchrun \ + --nnodes 1 \ + --nproc_per_node $NUM_GPUS \ + fastvideo/training/wangame_training_pipeline.py \ + "${parallel_args[@]}" \ + "${model_args[@]}" \ + "${dataset_args[@]}" \ + "${training_args[@]}" \ + "${optimizer_args[@]}" \ + "${validation_args[@]}" \ + "${miscellaneous_args[@]}" \ No newline at end of file diff --git a/examples/training/finetune/WanGame2.1_1.3b_i2v/finetune_wangame.slurm b/examples/training/finetune/WanGame2.1_1.3b_i2v/finetune_wangame.slurm new file mode 100644 index 000000000..a6051f192 --- /dev/null +++ b/examples/training/finetune/WanGame2.1_1.3b_i2v/finetune_wangame.slurm @@ -0,0 +1,182 @@ +#!/bin/bash +#SBATCH --job-name=wangame_1.3b +#SBATCH --partition=main +#SBATCH --nodes=4 +#SBATCH --ntasks=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=128 +#SBATCH --mem=1440G +#SBATCH --output=wangame_1.3b_output/wangame_1.3b_%j.out +#SBATCH --error=wangame_1.3b_output/wangame_1.3b_%j.err +#SBATCH --exclusive + +# Basic Info +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export NCCL_DEBUG_SUBSYS=INIT,NET +# different cache dir for different processes +export TRITON_CACHE_DIR=/tmp/triton_cache_${SLURM_PROCID} +export MASTER_PORT=29501 +export NODE_RANK=$SLURM_PROCID +nodes=( $(scontrol show hostnames $SLURM_JOB_NODELIST) ) +export MASTER_ADDR=${nodes[0]} +export TOKENIZERS_PARALLELISM=false +export WANDB_API_KEY="d5b02b05e30d8cb34c7b31c6ae10416fc26dcb66" +export WANDB_BASE_URL="https://api.wandb.ai" +export WANDB_MODE=online +export FASTVIDEO_ATTENTION_BACKEND=FLASH_ATTN +export FASTVIDEO_MAP_STYLE_CACHE_DIR="/mnt/weka/home/hao.zhang/mhuo/FastVideo/map_style_cache" + +source 
~/conda/miniconda/bin/activate +conda activate mhuo-fv +export HOME="/mnt/weka/home/hao.zhang/mhuo" + +# Configs +NUM_GPUS=8 +NUM_NODES=4 +NUM_TOTAL_GPUS=$((NUM_GPUS * NUM_NODES)) +BS_PER_GPU=1 +GRADIENT_ACCUMULATION_STEPS=1 +WANDB_RUN_NAME="MC_1action_rand_from_scratch" +FREEZE_DIT=False +RUN_DIR="wangame_1.3b_1action_rand_from_scratch" +CHECKPOINTING_STEPS=1000 +ACTION_WARMUP_STEPS=0 +LEARNING_RATE=1e-5 + +MODEL_PATH="weizhou03/Wan2.1-Game-Fun-1.3B-InP-Diffusers" +# CKPT_SAFETENSOR="wangame_1.3b_with_warmup_lr_1e-5/checkpoint-7000/transformer/diffusion_pytorch_model.safetensors" +# +# Data dirs (use one of the following): +# - DATA_DIR_ALL: all datasets below combined (comma-separated) +# - Or a single path / subset; optional ":N" = repeat, ":0" = skip +# +DATA_DIR="/mnt/weka/home/hao.zhang/mhuo/traindata_0204_2130/preprocessed:0" # Random +DATA_DIR="$DATA_DIR,/mnt/weka/home/hao.zhang/mhuo/traindata_0204_1600/preprocessed:0" # Doom +DATA_DIR="$DATA_DIR,/mnt/weka/home/hao.zhang/mhuo/traindata_0205_1330/data/0_static_plus_w_only/preprocessed:1" # Static + w only +DATA_DIR="$DATA_DIR,/mnt/weka/home/hao.zhang/mhuo/traindata_0205_1330/data/1_wasd_only/preprocessed:1" # w/s/a/d only +DATA_DIR="$DATA_DIR,/mnt/weka/home/hao.zhang/mhuo/traindata_0206_1200/data/wasdonly_alpha1/preprocessed:1" # wasd only +DATA_DIR="$DATA_DIR,/mnt/weka/home/hao.zhang/mhuo/traindata_0206_1200/data/camera/preprocessed:1" # camera l-only and r-only +DATA_DIR="$DATA_DIR,/mnt/weka/home/hao.zhang/mhuo/traindata_0208_2000/data/camera4hold_alpha1/preprocessed:1" # camera only +DATA_DIR="$DATA_DIR,/mnt/weka/home/hao.zhang/mhuo/traindata_0208_2000/data/wasd4holdrandview_simple_1key1mouse1/preprocessed:1" # key_camera_excl_1_action_rand + +VALIDATION_DATASET_FILE="examples/training/finetune/WanGame2.1_1.3b_i2v/validation_random.json" +# +# Single-dir / validation alternatives (comment out DATA_DIR above and uncomment one block): +# MC wasd only: +# 
DATA_DIR="/mnt/weka/home/hao.zhang/mhuo/traindata_0205_1330/data/1_wasd_only/preprocessed" +# VALIDATION_DATASET_FILE="examples/training/finetune/WanGame2.1_1.3b_i2v/validation_wsad.json" +# MC random: +# DATA_DIR="/mnt/weka/home/hao.zhang/mhuo/traindata_0206_1200/data/wasdonly_alpha1/preprocessed" +# VALIDATION_DATASET_FILE="examples/training/finetune/WanGame2.1_1.3b_i2v/validation_mc.json" +# Doom: +# DATA_DIR="/mnt/weka/home/hao.zhang/mhuo/traindata_0204_1600/preprocessed" +# VALIDATION_DATASET_FILE="examples/training/finetune/WanGame2.1_1.3b_i2v/validation_doom.json" +# Overfit: +# DATA_DIR="mc_wasd_10/preprocessed/combined_parquet_dataset" +# VALIDATION_DATASET_FILE="examples/training/finetune/WanGame2.1_1.3b_i2v/validation_overfit.json" + + +# Training arguments +training_args=( + --tracker_project_name "wangame_1.3b" + --output_dir $RUN_DIR + --wandb_run_name "$WANDB_RUN_NAME" + --max_train_steps 20000 + --train_batch_size $BS_PER_GPU + --train_sp_batch_size $BS_PER_GPU + --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS + --num_latent_t 20 + --num_height 352 + --num_width 640 + --num_frames 77 + --enable_gradient_checkpointing_type "full" + --train_action_only $FREEZE_DIT + --action_warmup_steps $ACTION_WARMUP_STEPS + # --init_weights_from_safetensors $CKPT_SAFETENSOR +) + +# Parallel arguments +parallel_args=( + --num_gpus $NUM_TOTAL_GPUS + --sp_size 1 + --tp_size 1 + --hsdp_replicate_dim 1 + --hsdp_shard_dim $NUM_TOTAL_GPUS +) + +# Model arguments +model_args=( + --model_path $MODEL_PATH + --pretrained_model_name_or_path $MODEL_PATH +) + +# Dataset arguments +dataset_args=( + --data_path "$DATA_DIR" + --dataloader_num_workers 1 +) + +# Validation arguments +validation_args=( + --log_validation + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 100 + --validation_sampling_steps "40" + --validation_guidance_scale "1.0" + --validation_num_samples $NUM_TOTAL_GPUS +) + +# Optimizer arguments +optimizer_args=( + --learning_rate 
$LEARNING_RATE + --mixed_precision "bf16" + --weight_only_checkpointing_steps 100000 + --training_state_checkpointing_steps $CHECKPOINTING_STEPS + --weight_decay 1e-4 + --max_grad_norm 1.0 +) + +# Miscellaneous arguments +miscellaneous_args=( + --inference_mode False + --checkpoints_total_limit 2 + --training_cfg_rate 0.1 + --multi_phased_distill_schedule "4000-1" + --not_apply_cfg_solver + --dit_precision "fp32" + --num_euler_timesteps 50 + --ema_start_step 0 +) + +# If you do not have 32 GPUs and to fit in memory, you can: 1. increase sp_size. 2. reduce num_latent_t + +if [ $NUM_NODES -eq 1 ]; then + torchrun \ + --nnodes $NUM_NODES \ + --nproc_per_node $NUM_GPUS \ + fastvideo/training/wangame_training_pipeline.py \ + "${parallel_args[@]}" \ + "${model_args[@]}" \ + "${dataset_args[@]}" \ + "${training_args[@]}" \ + "${optimizer_args[@]}" \ + "${validation_args[@]}" \ + "${miscellaneous_args[@]}" +else + srun torchrun \ + --nnodes $NUM_NODES \ + --nproc_per_node $NUM_GPUS \ + --rdzv_backend c10d \ + --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ + --node_rank $SLURM_PROCID \ + fastvideo/training/wangame_training_pipeline.py \ + "${parallel_args[@]}" \ + "${model_args[@]}" \ + "${dataset_args[@]}" \ + "${training_args[@]}" \ + "${optimizer_args[@]}" \ + "${validation_args[@]}" \ + "${miscellaneous_args[@]}" +fi \ No newline at end of file diff --git a/examples/training/finetune/WanGame2.1_1.3b_i2v/finetune_wangame_freeze_action.slurm b/examples/training/finetune/WanGame2.1_1.3b_i2v/finetune_wangame_freeze_action.slurm new file mode 100644 index 000000000..fb0d30b24 --- /dev/null +++ b/examples/training/finetune/WanGame2.1_1.3b_i2v/finetune_wangame_freeze_action.slurm @@ -0,0 +1,191 @@ +#!/bin/bash +#SBATCH --job-name=wangame_1.3b +#SBATCH --partition=main +#SBATCH --nodes=4 +#SBATCH --ntasks=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=128 +#SBATCH --mem=1440G +#SBATCH --output=wangame_1.3b_output/wangame_1.3b_%j.out +#SBATCH 
--error=wangame_1.3b_output/wangame_1.3b_%j.err +#SBATCH --exclusive + +# Basic Info +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +export NCCL_DEBUG_SUBSYS=INIT,NET +# different cache dir for different processes +export TRITON_CACHE_DIR=/tmp/triton_cache_${SLURM_PROCID} +export MASTER_PORT=29501 +export NODE_RANK=$SLURM_PROCID +nodes=( $(scontrol show hostnames $SLURM_JOB_NODELIST) ) +export MASTER_ADDR=${nodes[0]} +export TOKENIZERS_PARALLELISM=false +export WANDB_API_KEY="d5b02b05e30d8cb34c7b31c6ae10416fc26dcb66" +export WANDB_BASE_URL="https://api.wandb.ai" +export WANDB_MODE=online +export FASTVIDEO_ATTENTION_BACKEND=FLASH_ATTN +export FASTVIDEO_MAP_STYLE_CACHE_DIR="/mnt/weka/home/hao.zhang/mhuo/FastVideo/map_style_cache" + +source ~/conda/miniconda/bin/activate +conda activate mhuo-fv +export HOME="/mnt/weka/home/hao.zhang/mhuo" + +# Configs +NUM_GPUS=8 +NUM_NODES=4 # TODO: change this to 1 to debug +NUM_TOTAL_GPUS=$((NUM_GPUS * NUM_NODES)) +BS_PER_GPU=1 +GRADIENT_ACCUMULATION_STEPS=1 +WANDB_RUN_NAME="Doom from MC freeze action" +RUN_DIR="wangame_1.3b" +CHECKPOINTING_STEPS=100000 # This means checkpoint every 100000 steps, effectively no ckpt will be saved +ACTION_WARMUP_STEPS=100000 # This means the action modules will be frozen for the first 100000 steps. +# Effectively, during the total 100000 steps, action modules are always frozen. 
+LEARNING_RATE=1e-5 +# Freeze base DiT, only train action modules +Freeze_DiT=false + +MODEL_PATH="weizhou03/Wan2.1-Game-Fun-1.3B-InP-Diffusers" +# CKPT_SAFETENSOR="wangame_1.3b_wsad_random_lr_1e-5/checkpoint-2000/transformer/diffusion_pytorch_model.safetensors" +# CKPT_SAFETENSOR="wangame_1.3b_with_warmup_lr_1e-5/checkpoint-7000/transformer/diffusion_pytorch_model.safetensors" +CKPT_SAFETENSOR="wangame_1.3b_1action_rand_from_scratch/checkpoint-7000/transformer/diffusion_pytorch_model.safetensors" + +# Data dirs (use one of the following): +# - DATA_DIR_ALL: all datasets below combined (comma-separated) +# - Or a single path / subset; optional ":N" = repeat, ":0" = skip +# +DATA_DIR="/mnt/weka/home/hao.zhang/mhuo/traindata_0204_2130/preprocessed:0" # Random +DATA_DIR="$DATA_DIR,/mnt/weka/home/hao.zhang/mhuo/traindata_0204_1600/preprocessed:1" # Doom +DATA_DIR="$DATA_DIR,/mnt/weka/home/hao.zhang/mhuo/traindata_0205_1330/data/0_static_plus_w_only/preprocessed:0" # Static + w only +DATA_DIR="$DATA_DIR,/mnt/weka/home/hao.zhang/mhuo/traindata_0205_1330/data/1_wasd_only/preprocessed:0" # w/s/a/d only +DATA_DIR="$DATA_DIR,/mnt/weka/home/hao.zhang/mhuo/traindata_0206_1200/data/wasdonly_alpha1/preprocessed:0" # wasd only +DATA_DIR="$DATA_DIR,/mnt/weka/home/hao.zhang/mhuo/traindata_0206_1200/data/camera/preprocessed:0" # camera l-only and r-only +DATA_DIR="$DATA_DIR,/mnt/weka/home/hao.zhang/mhuo/traindata_0208_2000/data/camera4hold_alpha1/preprocessed:0" # camera only +DATA_DIR="$DATA_DIR,/mnt/weka/home/hao.zhang/mhuo/traindata_0208_2000/data/wasd4holdrandview_simple_1key1mouse1/preprocessed:0" # key_camera_excl_1_action_rand +DATA_DIR="$DATA_DIR,/mnt/weka/home/hao.zhang/alex/wm-lab/datas/cache/zelda_overfit:0" # zelda + +VALIDATION_DATASET_FILE="examples/training/finetune/WanGame2.1_1.3b_i2v/validation_random.json" # TODO: double check this, you may remove some MC image and add more doom image +# +# Single-dir / validation alternatives (comment out DATA_DIR above and 
uncomment one block): +# MC wasd only: +# DATA_DIR="/mnt/weka/home/hao.zhang/mhuo/traindata_0205_1330/data/1_wasd_only/preprocessed" +# VALIDATION_DATASET_FILE="examples/training/finetune/WanGame2.1_1.3b_i2v/validation_wsad.json" +# MC random: +# DATA_DIR="/mnt/weka/home/hao.zhang/mhuo/traindata_0206_1200/data/wasdonly_alpha1/preprocessed" +# VALIDATION_DATASET_FILE="examples/training/finetune/WanGame2.1_1.3b_i2v/validation_mc.json" +# Doom: +# DATA_DIR="/mnt/weka/home/hao.zhang/mhuo/traindata_0204_1600/preprocessed" +# VALIDATION_DATASET_FILE="examples/training/finetune/WanGame2.1_1.3b_i2v/validation_doom.json" +# Overfit: +# DATA_DIR="mc_wasd_10/preprocessed/combined_parquet_dataset" +# VALIDATION_DATASET_FILE="examples/training/finetune/WanGame2.1_1.3b_i2v/validation_overfit.json" + + +# Training arguments +training_args=( + --tracker_project_name "wangame_1.3b" + --output_dir $RUN_DIR + --wandb_run_name "$WANDB_RUN_NAME" + --max_train_steps 10000 + --train_batch_size $BS_PER_GPU + --train_sp_batch_size $BS_PER_GPU + --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS + --num_latent_t 20 + --num_height 352 + --num_width 640 + --num_frames 77 + --enable_gradient_checkpointing_type "full" + --action_warmup_steps $ACTION_WARMUP_STEPS + --init_weights_from_safetensors $CKPT_SAFETENSOR + # TODO: check terminal log or log file xxx.err, whether there is a output line saying "Starting training with 1.53 B trainable parameters (total)". If both action modules are frozen, it should be around 1.53B, otherwise it could be 1.6B. 
+) + +if [ "$Freeze_DiT" = "true" ]; then + training_args+=(--train_action_only) +fi + +# Parallel arguments +parallel_args=( + --num_gpus $NUM_TOTAL_GPUS + --sp_size 1 + --tp_size 1 + --hsdp_replicate_dim 1 + --hsdp_shard_dim $NUM_TOTAL_GPUS +) + +# Model arguments +model_args=( + --model_path $MODEL_PATH + --pretrained_model_name_or_path $MODEL_PATH +) + +# Dataset arguments +dataset_args=( + --data_path "$DATA_DIR" + --dataloader_num_workers 1 +) + +# Validation arguments +validation_args=( + --log_validation + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 100 + --validation_sampling_steps "40" + --validation_guidance_scale "1.0" + --validation_num_samples $NUM_TOTAL_GPUS +) + +# Optimizer arguments +optimizer_args=( + --learning_rate $LEARNING_RATE + --mixed_precision "bf16" + --weight_only_checkpointing_steps 100000 + --training_state_checkpointing_steps $CHECKPOINTING_STEPS + --weight_decay 1e-4 + --max_grad_norm 1.0 +) + +# Miscellaneous arguments +miscellaneous_args=( + --inference_mode False + --checkpoints_total_limit 2 + --training_cfg_rate 0.1 + --multi_phased_distill_schedule "4000-1" + --not_apply_cfg_solver + --dit_precision "fp32" + --num_euler_timesteps 50 + --ema_start_step 0 +) + +# If you do not have 32 GPUs and to fit in memory, you can: 1. increase sp_size. 2.
reduce num_latent_t + +if [ $NUM_NODES -eq 1 ]; then + torchrun \ + --nnodes $NUM_NODES \ + --nproc_per_node $NUM_GPUS \ + fastvideo/training/wangame_training_pipeline.py \ + "${parallel_args[@]}" \ + "${model_args[@]}" \ + "${dataset_args[@]}" \ + "${training_args[@]}" \ + "${optimizer_args[@]}" \ + "${validation_args[@]}" \ + "${miscellaneous_args[@]}" +else + srun torchrun \ + --nnodes $NUM_NODES \ + --nproc_per_node $NUM_GPUS \ + --rdzv_backend c10d \ + --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ + --node_rank $SLURM_PROCID \ + fastvideo/training/wangame_training_pipeline.py \ + "${parallel_args[@]}" \ + "${model_args[@]}" \ + "${dataset_args[@]}" \ + "${training_args[@]}" \ + "${optimizer_args[@]}" \ + "${validation_args[@]}" \ + "${miscellaneous_args[@]}" +fi \ No newline at end of file diff --git a/examples/training/finetune/WanGame2.1_1.3b_i2v/preprocess_wangame_data_i2v.sh b/examples/training/finetune/WanGame2.1_1.3b_i2v/preprocess_wangame_data_i2v.sh new file mode 100644 index 000000000..85a4fd0d2 --- /dev/null +++ b/examples/training/finetune/WanGame2.1_1.3b_i2v/preprocess_wangame_data_i2v.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +GPU_NUM=1 # 2,4,8 +MODEL_PATH="weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers" +DATA_MERGE_PATH="mc_wasd_10/merge.txt" +OUTPUT_DIR="mc_wasd_10/preprocessed/" + +# export CUDA_VISIBLE_DEVICES=0 +export MASTER_ADDR=localhost +export MASTER_PORT=29500 +export RANK=0 +export WORLD_SIZE=1 + +python fastvideo/pipelines/preprocess/v1_preprocess.py \ + --model_path $MODEL_PATH \ + --data_merge_path $DATA_MERGE_PATH \ + --preprocess_video_batch_size 10 \ + --seed 42 \ + --max_height 352 \ + --max_width 640 \ + --num_frames 77 \ + --dataloader_num_workers 0 \ + --output_dir=$OUTPUT_DIR \ + --samples_per_file 10 \ + --train_fps 25 \ + --flush_frequency 10 \ + --preprocess_task wangame \ No newline at end of file diff --git a/examples/training/finetune/WanGame2.1_1.3b_i2v/scripts/collect_samples_to_shao.py 
b/examples/training/finetune/WanGame2.1_1.3b_i2v/scripts/collect_samples_to_shao.py new file mode 100644 index 000000000..1b2e886ef --- /dev/null +++ b/examples/training/finetune/WanGame2.1_1.3b_i2v/scripts/collect_samples_to_shao.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 +""" +For each data dir (from finetune_wangame.slurm), randomly pick 10 samples (mp4 + action.npy), +copy to to_shao// as 01.mp4, 01_action.npy, ..., 10.mp4, 10_action.npy, +and extract first frame as 01.jpg, ..., 10.jpg. +""" +import os +import random +import shutil + +import cv2 + +# Data dirs from finetune_wangame.slurm (paths with "preprocessed"; we use "video" or "videos" for mp4/npy) +DATA_DIRS = [ + "/mnt/weka/home/hao.zhang/mhuo/traindata_0204_2130/preprocessed", + "/mnt/weka/home/hao.zhang/mhuo/traindata_0205_1330/data/1_wasd_only/preprocessed", + "/mnt/weka/home/hao.zhang/mhuo/traindata_0206_1200/data/wasdonly_alpha1/preprocessed", + "/mnt/weka/home/hao.zhang/mhuo/traindata_0206_1200/data/camera/preprocessed", + "/mnt/weka/home/hao.zhang/mhuo/traindata_0208_2000/data/camera4hold_alpha1/preprocessed", + "/mnt/weka/home/hao.zhang/mhuo/traindata_0208_2000/data/wasd4holdrandview_simple_1key1mouse1/preprocessed", +] + +OUT_ROOT = "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/to_shao" +NUM_SAMPLES = 10 + + +def get_video_dir(preprocessed_path: str) -> str | None: + """Replace 'preprocessed' with 'video' (or 'videos') to get the dir containing mp4/npy.""" + video_path = preprocessed_path.replace("preprocessed", "video") + if os.path.isdir(video_path): + return video_path + videos_path = preprocessed_path.replace("preprocessed", "videos") + if os.path.isdir(videos_path): + return videos_path + return None + + +# Override short name for specific data dirs (e.g. 
traindata_0204_2130 -> fully_random) +SHORT_NAME_OVERRIDES: dict[str, str] = { + "traindata_0204_2130": "fully_random", +} + + +def get_short_name(preprocessed_path: str) -> str: + """Short name = parent folder of the preprocessed dir, e.g. 1_wasd_only.""" + name = os.path.basename(os.path.normpath(os.path.dirname(preprocessed_path))) + return SHORT_NAME_OVERRIDES.get(name, name) + + +def find_samples(video_dir: str) -> list[str]: + """Return list of base names (no extension) that have both xxxxxx.mp4 and xxxxxx_action.npy.""" + samples = [] + for f in os.listdir(video_dir): + if f.endswith(".mp4"): + base = f[:-4] + action_path = os.path.join(video_dir, f"{base}_action.npy") + if os.path.isfile(action_path): + samples.append(base) + return samples + + +def extract_first_frame(mp4_path: str, jpg_path: str) -> None: + cap = cv2.VideoCapture(mp4_path) + ret, frame = cap.read() + cap.release() + if ret: + cv2.imwrite(jpg_path, frame) + + +def main() -> None: + random.seed(42) + os.makedirs(OUT_ROOT, exist_ok=True) + total_dir = os.path.join(OUT_ROOT, "total") + os.makedirs(total_dir, exist_ok=True) + total_idx = 0 + + for preprocessed_path in DATA_DIRS: + video_dir = get_video_dir(preprocessed_path) + if video_dir is None: + print(f"Skip (video dir not found): {preprocessed_path}") + continue + + short_name = get_short_name(preprocessed_path) + samples = find_samples(video_dir) + if len(samples) < NUM_SAMPLES: + print(f"Skip {short_name}: only {len(samples)} samples (need {NUM_SAMPLES})") + continue + + chosen = random.sample(samples, NUM_SAMPLES) + out_dir = os.path.join(OUT_ROOT, short_name) + os.makedirs(out_dir, exist_ok=True) + + for i, base in enumerate(chosen, start=1): + num_str = f"{i:02d}" + src_mp4 = os.path.join(video_dir, f"{base}.mp4") + src_npy = os.path.join(video_dir, f"{base}_action.npy") + dst_mp4 = os.path.join(out_dir, f"{num_str}.mp4") + dst_npy = os.path.join(out_dir, f"{num_str}_action.npy") + dst_jpg = os.path.join(out_dir, f"{num_str}.jpg") + 
+ shutil.copy2(src_mp4, dst_mp4) + shutil.copy2(src_npy, dst_npy) + extract_first_frame(dst_mp4, dst_jpg) + + # Copy into total/ with global numbering + total_idx += 1 + t_str = f"{total_idx:02d}" + shutil.copy2(dst_mp4, os.path.join(total_dir, f"{t_str}.mp4")) + shutil.copy2(dst_npy, os.path.join(total_dir, f"{t_str}_action.npy")) + shutil.copy2(dst_jpg, os.path.join(total_dir, f"{t_str}.jpg")) + + print(f"Done: {short_name} -> {out_dir} ({NUM_SAMPLES} samples)") + + print(f"Done: total -> {total_dir} ({total_idx} samples)") + + +if __name__ == "__main__": + main() diff --git a/examples/training/finetune/WanGame2.1_1.3b_i2v/scripts/generate_actions.py b/examples/training/finetune/WanGame2.1_1.3b_i2v/scripts/generate_actions.py new file mode 100644 index 000000000..ff8b4e76e --- /dev/null +++ b/examples/training/finetune/WanGame2.1_1.3b_i2v/scripts/generate_actions.py @@ -0,0 +1,278 @@ +import os +import numpy as np + +# Configuration +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +BASE_OUTPUT_DIR = os.path.abspath(os.path.join(SCRIPT_DIR, "..", "actions_81")) +VIDEO_OUTPUT_DIR = BASE_OUTPUT_DIR + +os.makedirs(VIDEO_OUTPUT_DIR, exist_ok=True) + +CAM_VALUE = 0.1 +FRAME_COUNT = 80 + +# Action Mapping +KEY_TO_INDEX = { + 'W': 0, 'S': 1, 'A': 2, 'D': 3, +} + +VIEW_ACTION_TO_MOUSE = { + "stop": [0.0, 0.0], + "up": [CAM_VALUE, 0.0], + "down": [-CAM_VALUE, 0.0], + "left": [0.0, -CAM_VALUE], + "right": [0.0, CAM_VALUE], + "up_right": [CAM_VALUE, CAM_VALUE], + "up_left": [CAM_VALUE, -CAM_VALUE], + "down_right": [-CAM_VALUE, CAM_VALUE], + "down_left": [-CAM_VALUE, -CAM_VALUE], +} + +def get_multihot_vector(keys_str): + """Convert string like 'WA' to [1, 0, 1, 0, 0, 0]""" + vector = [0.0] * 6 + if not keys_str: + return vector + for char in keys_str.upper(): + if char in KEY_TO_INDEX: + vector[KEY_TO_INDEX[char]] = 1.0 + return vector + +def get_mouse_vector(view_str): + """Convert view string to [x, y]""" + return VIEW_ACTION_TO_MOUSE.get(view_str.lower(), [0.0, 
0.0]) + +def generate_sequence(key_seq, mouse_seq): + """ + Generates action arrays based on sequences. + key_seq and mouse_seq must be length FRAME_COUNT. + Duplicates the first frame at the beginning, so output length is FRAME_COUNT + 1. + """ + if len(key_seq) != FRAME_COUNT or len(mouse_seq) != FRAME_COUNT: + raise ValueError("key_seq and mouse_seq must be length FRAME_COUNT") + + keyboard_arr = np.zeros((FRAME_COUNT, 6), dtype=np.float32) + mouse_arr = np.zeros((FRAME_COUNT, 2), dtype=np.float32) + + for i in range(FRAME_COUNT): + keyboard_arr[i] = get_multihot_vector(key_seq[i]) + mouse_arr[i] = get_mouse_vector(mouse_seq[i]) + + keyboard_arr = np.vstack([keyboard_arr[0:1], keyboard_arr]) + mouse_arr = np.vstack([mouse_arr[0:1], mouse_arr]) + + return keyboard_arr, mouse_arr + +def save_action(filename, keyboard_arr, mouse_arr): + if not filename.endswith(".npy"): + filename = f"{filename}.npy" + filepath = os.path.join(VIDEO_OUTPUT_DIR, filename) + + action_dict = { + 'keyboard': keyboard_arr, + 'mouse': mouse_arr + } + np.save(filepath, action_dict) + return filename + + +def build_constant_sequence(value): + return [value] * FRAME_COUNT + + +def build_random_sequence(actions, granularity, rng): + sequence = [] + remaining = FRAME_COUNT + while remaining > 0: + block = granularity if remaining >= granularity else remaining + action = rng.choice(actions) + sequence.extend([action] * block) + remaining -= block + return sequence + + +def build_random_sequence_either_or(key_actions, mouse_actions, granularity, rng): + """Build key_seq and mouse_seq where each block has either key OR mouse, not both.""" + key_seq = [] + mouse_seq = [] + remaining = FRAME_COUNT + while remaining > 0: + block = granularity if remaining >= granularity else remaining + use_key = rng.choice([True, False]) + if use_key: + key_action = rng.choice(key_actions) + mouse_action = "" + else: + key_action = "" + mouse_action = rng.choice(mouse_actions) + key_seq.extend([key_action] * block) 
+ mouse_seq.extend([mouse_action] * block) + remaining -= block + return key_seq, mouse_seq + + +def mouse_short_name(view_str): + mapping = { + "up": "u", + "down": "d", + "left": "l", + "right": "r", + "up_right": "ur", + "up_left": "ul", + "down_right": "dr", + "down_left": "dl", + } + return mapping.get(view_str, "NA") + + +if __name__ == "__main__": + configs = [] + readme_content = [] + rng = np.random.default_rng(42) + + # configs = list of entries + # a entry is a tuple of (key_seq, mouse_seq) + # key_seq is a list of strings, length of FRAME_COUNT, each string is a key in 'W', 'S', 'A', 'D', 'WA', 'WD', 'SA', 'SD' + # mouse_seq is a list of strings, length of FRAME_COUNT, each string is a mouse action in 'up', 'down', 'left', 'right', 'up_right', 'up_left', 'down_right', 'down_left' + + # Naming: 1=WASDudlr (key: W.npy, SA.npy; camera: u.npy; key+camera: W_u.npy, SA_dl.npy). Rand: 1_action=WASD/UDLR+still only, 2_action=full set. 2-6=rand names below. + # Group 1: Constant Keyboard, No Mouse. W.npy, S.npy, WA.npy, SA.npy, ... + keys_basic = ["W", "S", "A", "D", "WA", "WD", "SA", "SD"] + for key in keys_basic: + configs.append( + (key, build_constant_sequence(key), build_constant_sequence("")) + ) + + # Group 2: No Keyboard, Constant Mouse. u.npy, d.npy, ur.npy, ... + mouse_basic = [ + "up", + "down", + "left", + "right", + "up_right", + "up_left", + "down_right", + "down_left", + ] + for mouse in mouse_basic: + name = mouse_short_name(mouse) + configs.append( + (name, build_constant_sequence(""), build_constant_sequence(mouse)) + ) + + # Group 3: Still. still.npy + configs.append(("still", build_constant_sequence(""), build_constant_sequence(""))) + + # Group 4: Constant key + camera. W_u.npy, SA_dl.npy, ... 
+ for key in keys_basic: + for mouse in mouse_basic: + configs.append( + ( + f"{key}_{mouse_short_name(mouse)}", + build_constant_sequence(key), + build_constant_sequence(mouse), + ) + ) + + # Random groups: allow still ("") as an option (WASD+still, UDLR+still, and full sets+still) + keys_basic_still = keys_basic + [""] + mouse_basic_still = mouse_basic + [""] + + # Group 5: key_2_action_rand (full key set). key_2_action_rand_1..4, key_2_action_rand_1_f4..4_f4 + for granularity in (4, 12): + suffix = "_f4" if granularity == 4 else "" + for i in range(1, 5): + key_seq = build_random_sequence(keys_basic_still, granularity, rng) + configs.append( + (f"key_2_action_rand_{i}{suffix}", key_seq, build_constant_sequence("")) + ) + + # Group 6: camera_2_action_rand (full camera set) + for granularity in (4, 12): + suffix = "_f4" if granularity == 4 else "" + for i in range(1, 5): + mouse_seq = build_random_sequence(mouse_basic_still, granularity, rng) + configs.append( + (f"camera_2_action_rand_{i}{suffix}", build_constant_sequence(""), mouse_seq) + ) + + # Group 7: key_camera_2_action_rand (both full sets) + for granularity in (4, 12): + suffix = "_f4" if granularity == 4 else "" + for i in range(1, 5): + key_seq = build_random_sequence(keys_basic_still, granularity, rng) + mouse_seq = build_random_sequence(mouse_basic_still, granularity, rng) + configs.append( + (f"key_camera_2_action_rand_{i}{suffix}", key_seq, mouse_seq) + ) + + # WASD-only (no combined keys) and u/d/l/r-only (no combined directions), with still as option + keys_wasd_only = ["W", "S", "A", "D"] + mouse_udlr_only = ["up", "down", "left", "right"] + keys_wasd_still = keys_wasd_only + [""] + mouse_udlr_still = mouse_udlr_only + [""] + + # Group 8: key_1_action_rand (WASD+still only) + for granularity in (4, 12): + suffix = "_f4" if granularity == 4 else "" + for i in range(1, 5): + key_seq = build_random_sequence(keys_wasd_still, granularity, rng) + configs.append( + (f"key_1_action_rand_{i}{suffix}", 
key_seq, build_constant_sequence("")) + ) + + # Group 9: camera_1_action_rand (UDLR+still only) + for granularity in (4, 12): + suffix = "_f4" if granularity == 4 else "" + for i in range(1, 5): + mouse_seq = build_random_sequence(mouse_udlr_still, granularity, rng) + configs.append( + (f"camera_1_action_rand_{i}{suffix}", build_constant_sequence(""), mouse_seq) + ) + + # Group 10: key_camera_1_action_rand (WASD+still, UDLR+still) + for granularity in (4, 12): + suffix = "_f4" if granularity == 4 else "" + for i in range(1, 5): + key_seq = build_random_sequence(keys_wasd_still, granularity, rng) + mouse_seq = build_random_sequence(mouse_udlr_still, granularity, rng) + configs.append( + (f"key_camera_1_action_rand_{i}{suffix}", key_seq, mouse_seq) + ) + + # Group 11a: key_camera_excl_2_action_rand (either key OR camera per block, full key + full camera set) + for granularity in (4, 12): + suffix = "_f4" if granularity == 4 else "" + for i in range(1, 5): + key_seq, mouse_seq = build_random_sequence_either_or(keys_basic_still, mouse_basic_still, granularity, rng) + configs.append( + (f"key_camera_excl_2_action_rand_{i}{suffix}", key_seq, mouse_seq) + ) + + # Group 11b: key_camera_excl_1_action_rand (either key OR camera per block, WASD/UDLR+still) + for granularity in (4, 12): + suffix = "_f4" if granularity == 4 else "" + for i in range(1, 5): + key_seq, mouse_seq = build_random_sequence_either_or(keys_wasd_still, mouse_udlr_still, granularity, rng) + configs.append( + (f"key_camera_excl_1_action_rand_{i}{suffix}", key_seq, mouse_seq) + ) + + # Execution + print(f"Preparing to generate {len(configs)} action files...") + + for name, key_seq, mouse_seq in configs: + # Generate Data + kb_arr, ms_arr = generate_sequence(key_seq, mouse_seq) + filename = save_action(name, kb_arr, ms_arr) + readme_content.append(filename.replace(".npy", "")) + + print(f"Generated {filename}") + + readme_path = os.path.join(VIDEO_OUTPUT_DIR, "README.md") + with open(readme_path, "w") as f: 
+ f.write(f"Total Files: {len(readme_content)}\n\n") + for idx, name in enumerate(readme_content): + f.write(f"{idx:02d}: {name}\n") + + print(f"{len(configs)} .npy files generated in {VIDEO_OUTPUT_DIR}") \ No newline at end of file diff --git a/examples/training/finetune/WanGame2.1_1.3b_i2v/scripts/generate_validation.py b/examples/training/finetune/WanGame2.1_1.3b_i2v/scripts/generate_validation.py new file mode 100644 index 000000000..3444ba1eb --- /dev/null +++ b/examples/training/finetune/WanGame2.1_1.3b_i2v/scripts/generate_validation.py @@ -0,0 +1,271 @@ +import json +import os +import shutil + +import cv2 + +train = "zelda" + +if train == "zelda": + height = 480 + width = 832 + num_frames = 81 + action_dir = "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81" +elif train == "mc": + height = 352 + width = 640 + num_frames = 77 + action_dir = "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions" +else: + raise ValueError(f"Invalid train type: {train}") + +# Output path +output_path = ( + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/" + f"WanGame2.1_1.3b_i2v/validation_{train}.json" +) + +# Fixed fields +fixed_fields = { + "video_path": None, + "num_inference_steps": 40, + "height": height, + "width": width, + "num_frames": num_frames, +} + +# WASDudlr: single key W.npy, single camera u.npy, key+camera w_u.npy +still = os.path.join(action_dir, "still.npy") +key_W = os.path.join(action_dir, "W.npy") +key_S = os.path.join(action_dir, "S.npy") +key_A = os.path.join(action_dir, "A.npy") +key_D = os.path.join(action_dir, "D.npy") +key_wa = os.path.join(action_dir, "WA.npy") +key_s_u = os.path.join(action_dir, "S_u.npy") +camera_u = os.path.join(action_dir, "u.npy") +camera_d = os.path.join(action_dir, "d.npy") +camera_l = os.path.join(action_dir, "l.npy") +camera_r = os.path.join(action_dir, "r.npy") +# key_1_action_rand, camera_1_action_rand (full 
set); _f4 suffix for granularity 4 +key_1_action_rand_1 = os.path.join(action_dir, "key_1_action_rand_1.npy") +key_1_action_rand_2 = os.path.join(action_dir, "key_1_action_rand_2.npy") +key_1_action_rand_1_f4 = os.path.join(action_dir, "key_1_action_rand_1_f4.npy") +key_1_action_rand_2_f4 = os.path.join(action_dir, "key_1_action_rand_2_f4.npy") +camera_1_action_rand_1 = os.path.join(action_dir, "camera_1_action_rand_1.npy") +camera_1_action_rand_2 = os.path.join(action_dir, "camera_1_action_rand_2.npy") +camera_1_action_rand_1_f4 = os.path.join(action_dir, "camera_1_action_rand_1_f4.npy") +camera_1_action_rand_2_f4 = os.path.join(action_dir, "camera_1_action_rand_2_f4.npy") +key_camera_1_action_rand_1 = os.path.join(action_dir, "key_camera_1_action_rand_1.npy") +key_camera_1_action_rand_2 = os.path.join(action_dir, "key_camera_1_action_rand_2.npy") +key_camera_1_action_rand_1_f4 = os.path.join(action_dir, "key_camera_1_action_rand_1_f4.npy") +key_camera_1_action_rand_2_f4 = os.path.join(action_dir, "key_camera_1_action_rand_2_f4.npy") +key_camera_excl_1_action_rand_1 = os.path.join(action_dir, "key_camera_excl_1_action_rand_1.npy") +key_camera_excl_1_action_rand_2 = os.path.join(action_dir, "key_camera_excl_1_action_rand_2.npy") +key_camera_excl_1_action_rand_1_f4 = os.path.join(action_dir, "key_camera_excl_1_action_rand_1_f4.npy") +key_camera_excl_1_action_rand_2_f4 = os.path.join(action_dir, "key_camera_excl_1_action_rand_2_f4.npy") +# key_2_action_rand, camera_2_action_rand (WASD/UDLR+still) +key_2_action_rand_1 = os.path.join(action_dir, "key_2_action_rand_1.npy") +key_2_action_rand_1_f4 = os.path.join(action_dir, "key_2_action_rand_1_f4.npy") +camera_2_action_rand_1 = os.path.join(action_dir, "camera_2_action_rand_1.npy") +camera_2_action_rand_1_f4 = os.path.join(action_dir, "camera_2_action_rand_1_f4.npy") +key_camera_2_action_rand_1 = os.path.join(action_dir, "key_camera_2_action_rand_1.npy") +key_camera_2_action_rand_1_f4 = os.path.join(action_dir, 
"key_camera_2_action_rand_1_f4.npy") +key_camera_excl_2_action_rand_1 = os.path.join(action_dir, "key_camera_excl_2_action_rand_1.npy") +key_camera_excl_2_action_rand_1_f4 = os.path.join(action_dir, "key_camera_excl_2_action_rand_1_f4.npy") + + +train_img_zelda_list = [ + # "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/-BxyBxfDKA0_chunk_0292/segment0001.jpg", + # "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/-BxyBxfDKA0_chunk_0292/segment0003.jpg", + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/5TTrlqAguhQ_chunk_0006/segment0002.jpg", + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/5TTrlqAguhQ_chunk_0067/segment0002.jpg", + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/5TTrlqAguhQ_chunk_0484/segment0002.jpg", + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/N6ObBAt41bg_chunk_0019/segment0004.jpg", + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/N6ObBAt41bg_chunk_0140/segment0003.jpg", + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/N6ObBAt41bg_chunk_0300/segment0003.jpg", +] + +val_img_zelda_list = train_img_zelda_list +train_action_zelda_list = [] +for img in train_img_zelda_list: + img_dir = os.path.dirname(img) + basename = os.path.splitext(os.path.basename(img))[0] + action_path = os.path.join( + img_dir, + "postprocess/action/majority_voting/" + "81_frame_no_button", + f"{basename}.npy", + ) + train_action_zelda_list.append(action_path) + + +val_img_mc_list = [ + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000002.jpg", + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000003.jpg", + 
"/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000004.jpg", + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000005.jpg", + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000000.jpg", + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000001.jpg", + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000006.jpg", + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000007.jpg", + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/humanplay/000005.jpg", + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/humanplay/000013.jpg", +] + +# Get train data list +train_mc_data_dir = "/mnt/weka/home/hao.zhang/mhuo/traindata_0208_2000/data/wasd4holdrandview_simple_1key1mouse1" +train_mc_idx_list = ["000000", "000500", "001000", "001500", "002000", "002500", "003000", "003500"] +train_mc_img_list = [] +train_mc_action_list = [] + +for idx in train_mc_idx_list: + video_path = os.path.join(train_mc_data_dir, f"videos/{idx}.mp4") + # extract the first frame as image + image_path = os.path.join(train_mc_data_dir, f"first_frame/{idx}.jpg") + os.makedirs(os.path.dirname(image_path), exist_ok=True) + cap = cv2.VideoCapture(video_path) + ret, frame = cap.read() + cap.release() + if ret: + cv2.imwrite(image_path, frame) + train_mc_img_list.append(image_path) + train_mc_action_list.append(os.path.join(train_mc_data_dir, f"videos/{idx}_action.npy")) + + +# Get doom Val data list +val_img_doom_list = [ + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/doom/000000.jpg", + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/doom/000001.jpg", + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/doom/000002.jpg", + "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/doom/000003.jpg", +] + +if train == 
"mc": + val_img_list = val_img_mc_list + train_img_list = train_mc_img_list + train_action_list = train_mc_action_list +elif train == "zelda": + val_img_list = val_img_zelda_list + train_img_list = train_img_zelda_list + train_action_list = train_action_zelda_list +elif train == "doom": + val_img_list = val_img_doom_list +else: + raise ValueError(f"Invalid train type: {train}") + + +holder = 0 # placeholder +# 32 placeholders (idx 0-31). Fill in manually. +a0 = ["00 Val-00: W", val_img_list[0], key_W] +a1 = ["01 Val-01: S", val_img_list[1], key_S] +a2 = ["02 Val-02: A", val_img_list[2], key_A] +a3 = ["03 Val-03: D", val_img_list[3], key_D] +a4 = ["04 Val-04: u", val_img_list[4], camera_u] +a5 = ["05 Val-05: d", val_img_list[5], camera_d] +a6 = ["06 Val-06: l", val_img_list[4], camera_l] +a7 = ["07 Val-07: r", val_img_list[5], camera_r] +a8 = ["08 Val-00: key rand", val_img_list[0], key_1_action_rand_1] +a9 = ["09 Val-01: key rand", val_img_list[1], key_1_action_rand_2] +a10 = ["10 Val-02: camera rand", val_img_list[2], camera_1_action_rand_1] +a11 = ["11 Val-03: camera rand", val_img_list[3], camera_1_action_rand_2] +a12 = ["12 Val-00: key+camera excl rand", val_img_list[0], key_camera_excl_1_action_rand_1] +a13 = ["13 Val-01: key+camera excl rand", val_img_list[1], key_camera_excl_1_action_rand_2] +a14 = ["14 Val-02: key+camera rand", val_img_list[2], key_camera_1_action_rand_1] +a15 = ["15 Val-03: key+camera rand", val_img_list[3], key_camera_1_action_rand_2] +a16 = ["16 Val-04: (simultaneous) key rand", val_img_list[4], key_2_action_rand_1] +a17 = ["17 Val-05: (simultaneous) camera rand", val_img_list[5], camera_2_action_rand_1] +a18 = ["18 Val-06: (simultaneous) key+camera excl rand", val_img_list[5], key_camera_excl_2_action_rand_1] +a19 = ["19 Val-07: (simultaneous) key+camera rand", val_img_list[5], key_camera_2_action_rand_1] +a20 = ["20 Val-08: W+A", val_img_list[0], key_wa] +a21 = ["21 Val-09: S+u", val_img_list[1], key_s_u] +a22 = ["22 Val-08: Still", 
val_img_list[2], still] +a23 = ["23 Val-09: Still", val_img_list[3], still] +a24 = ["24 Val-06: key+camera excl rand Frame 4", val_img_list[4], key_camera_excl_1_action_rand_1_f4] +a25 = ["25 Val-07: key+camera excl rand Frame 4", val_img_list[5], key_camera_excl_1_action_rand_2_f4] +a26 = ["26 Train-00", train_img_list[0], train_action_list[0]] +a27 = ["27 Train-01", train_img_list[1], train_action_list[1]] +# a28 = ["28 Train-02", train_img_list[2], train_action_list[2]] +# a29 = ["29 Train-03", train_img_list[3], train_action_list[3]] +# a30 = ["30 Train-04", train_img_list[4], train_action_list[4]] +# a31 = ["31 Train-05", train_img_list[5], train_action_list[5]] +a28 = ["28 Doom-00: W", val_img_doom_list[0], key_W] +a29 = ["29 Doom-01: key rand", val_img_doom_list[1], key_1_action_rand_1] +a30 = ["30 Doom-02: camera rand", val_img_doom_list[2], camera_1_action_rand_1] +a31 = ["31 Doom-03: key+camera excl rand", val_img_doom_list[3], key_camera_excl_1_action_rand_1] + +Val_entries = { + 0: a0, + 1: a1, + 2: a2, + 3: a3, + 4: a4, + 5: a5, + 6: a6, + 7: a7, + 8: a8, + 9: a9, + 10: a10, + 11: a11, + 12: a12, + 13: a13, + 14: a14, + 15: a15, + 16: a16, + 17: a17, + 18: a18, + 19: a19, + 20: a20, + 21: a21, + 22: a22, + 23: a23, + 24: a24, + 25: a25, + 26: a26, + 27: a27, + 28: a28, + 29: a29, + 30: a30, + 31: a31, +} + +data = [] +for idx in range(32): + if idx not in Val_entries: + raise ValueError(f"Missing entry for idx {idx}") + caption, image_path, action_path = Val_entries[idx] + data.append( + { + "caption": caption, + "image_path": image_path, + "action_path": action_path, + **fixed_fields, + } + ) + +output = {"data": data} +with open(output_path, "w") as f: + json.dump(output, f, indent=4) + +print(f"Generated {len(data)} entries to {output_path}") + +# Check file all exists + +with open(output_path) as f: + data = json.load(f) + +missing = [] +for i, item in enumerate(data['data']): + for key in ('image_path', 'action_path'): + path = item.get(key) + if 
path: + import os + if not os.path.isfile(path): + missing.append((i, key, path)) +if missing: + print('Missing paths:') + for idx, key, path in missing: + print(f' [{idx}] {key}: {path}') +else: + print('All paths exist.') + + diff --git a/examples/training/finetune/WanGame2.1_1.3b_i2v/scripts/generate_validation_static_w.py b/examples/training/finetune/WanGame2.1_1.3b_i2v/scripts/generate_validation_static_w.py new file mode 100644 index 000000000..10e285750 --- /dev/null +++ b/examples/training/finetune/WanGame2.1_1.3b_i2v/scripts/generate_validation_static_w.py @@ -0,0 +1,71 @@ +import json +import os + +# Paths for two image directories +image_dir_val = "/mnt/weka/home/hao.zhang/kaiqin/traindata_0205_1330/data/0_static_plus_w_only/first_frame" +image_dir_train = "/mnt/weka/home/hao.zhang/kaiqin/traindata_0205_1330/data/0_same_1st_frame_static_plus_w_only/first_frame" + +# Action paths (used for both) +action_still = "/mnt/weka/home/hao.zhang/kaiqin/traindata_0205_1330/data/0_static_plus_w_only/videos/000000_action.npy" +action_w = "/mnt/weka/home/hao.zhang/kaiqin/traindata_0205_1330/data/0_static_plus_w_only/videos/001050_action.npy" + +# Output path +output_path = "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/validation_static_w.json" + +# Fixed fields +fixed_fields = { + "video_path": None, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 +} + +data = [] + +# 16 images from each directory, alternating: val (0,1), train (2,3), val (4,5), train (6,7), ... +for i in range(16): + # Val images: indices 0,1, 4,5, 8,9, ... (pair index 0, 2, 4, ...) 
+ image_path_val = os.path.join(image_dir_val, f"{i:06d}.png") + + # Still action for val + data.append({ + "caption": f"val {i:02d} - Still", + "image_path": image_path_val, + "action_path": action_still, + **fixed_fields + }) + + # W action for val + data.append({ + "caption": f"val {i:02d} - W", + "image_path": image_path_val, + "action_path": action_w, + **fixed_fields + }) + + # Train images: indices 2,3, 6,7, 10,11, ... (pair index 1, 3, 5, ...) + image_path_train = os.path.join(image_dir_train, f"{i:06d}.png") + + # Still action for train + data.append({ + "caption": f"train {i:02d} - Still", + "image_path": image_path_train, + "action_path": action_still, + **fixed_fields + }) + + # W action for train + data.append({ + "caption": f"train {i:02d} - W", + "image_path": image_path_train, + "action_path": action_w, + **fixed_fields + }) + +# Write to file +output = {"data": data} +with open(output_path, "w") as f: + json.dump(output, f, indent=4) + +print(f"Generated {len(data)} entries to {output_path}") diff --git a/examples/training/finetune/WanGame2.1_1.3b_i2v/scripts/generate_validation_to_shao.py b/examples/training/finetune/WanGame2.1_1.3b_i2v/scripts/generate_validation_to_shao.py new file mode 100644 index 000000000..9074c5e25 --- /dev/null +++ b/examples/training/finetune/WanGame2.1_1.3b_i2v/scripts/generate_validation_to_shao.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +""" +Generate a validation JSON where all image_path and action_path entries come from +examples/training/finetune/WanGame2.1_1.3b_i2v/to_shao. + +Scans each subfolder of to_shao for pairs (NN.jpg, NN_action.npy) and emits one +validation entry per pair. Uses the same fixed_fields as generate_validation.py. 
+""" +import json +import os + +# Paths relative to this script / repo +FINETUNE_DIR = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +TO_SHAO_DIR = os.path.join(FINETUNE_DIR, "to_shao") +OUTPUT_PATH = os.path.join(FINETUNE_DIR, "validation_to_shao.json") + +# Same fixed fields as generate_validation.py +FIXED_FIELDS = { + "video_path": None, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77, +} + + +def collect_samples_from_to_shao(): + """Yield (caption, image_path, action_path) for each sample under to_shao.""" + if not os.path.isdir(TO_SHAO_DIR): + raise FileNotFoundError(f"to_shao directory not found: {TO_SHAO_DIR}") + + for subdir_name in sorted(os.listdir(TO_SHAO_DIR)): + subdir = os.path.join(TO_SHAO_DIR, subdir_name) + if not os.path.isdir(subdir): + continue + # Find all NN.jpg (or NN.jpeg) and matching NN_action.npy + for f in sorted(os.listdir(subdir)): + if f.endswith(".jpg") or f.endswith(".jpeg"): + base = f[: -4] if f.endswith(".jpg") else f[:-5] + action_name = f"{base}_action.npy" + action_path = os.path.join(subdir, action_name) + if not os.path.isfile(action_path): + continue + image_path = os.path.join(subdir, f) + caption = f"to_shao/{subdir_name}/{base}" + yield caption, image_path, action_path + + +def main(): + data = [] + for caption, image_path, action_path in collect_samples_from_to_shao(): + data.append( + { + "caption": caption, + "image_path": image_path, + "action_path": action_path, + **FIXED_FIELDS, + } + ) + + output = {"data": data} + with open(OUTPUT_PATH, "w") as f: + json.dump(output, f, indent=4) + + print(f"Generated {len(data)} entries to {OUTPUT_PATH}") + + # Check all paths exist + missing = [] + with open(OUTPUT_PATH) as f: + loaded = json.load(f) + for i, item in enumerate(loaded["data"]): + for key in ("image_path", "action_path"): + path = item.get(key) + if path and not os.path.isfile(path): + missing.append((i, key, path)) + if missing: + print("Missing 
paths:") + for idx, key, path in missing: + print(f" [{idx}] {key}: {path}") + else: + print("All paths exist.") + + +if __name__ == "__main__": + main() diff --git a/examples/training/finetune/WanGame2.1_1.3b_i2v/validation.json b/examples/training/finetune/WanGame2.1_1.3b_i2v/validation.json new file mode 100644 index 000000000..5af4cca74 --- /dev/null +++ b/examples/training/finetune/WanGame2.1_1.3b_i2v/validation.json @@ -0,0 +1,404 @@ +{ + "data": [ + { + "caption": "0", + "image_path": "../../../../mc_wasd_10/validate/000000.jpg", + "action_path": "../../../../mc_wasd_10/videos/000000_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "1", + "image_path": "../../../../mc_wasd_10/validate/000001.jpg", + "action_path": "../../../../mc_wasd_10/videos/000001_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "2", + "image_path": "../../../../mc_wasd_10/validate/000002.jpg", + "action_path": "../../../../mc_wasd_10/videos/000002_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "3", + "image_path": "../../../../mc_wasd_10/validate/000003.jpg", + "action_path": "../../../../mc_wasd_10/videos/000003_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "4", + "image_path": "../../../../mc_wasd_10/validate/000004.jpg", + "action_path": "../../../../mc_wasd_10/videos/000004_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "5", + "image_path": "../../../../mc_wasd_10/validate/000005.jpg", + "action_path": "../../../../mc_wasd_10/videos/000005_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + 
"num_frames": 77 + }, + { + "caption": "6", + "image_path": "../../../../mc_wasd_10/validate/000006.jpg", + "action_path": "../../../../mc_wasd_10/videos/000006_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "7", + "image_path": "../../../../mc_wasd_10/validate/000007.jpg", + "action_path": "../../../../mc_wasd_10/videos/000007_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "00. Hold [W] + Static", + "image_path": "../../../../mc_wasd_10/validate/000000.jpg", + "action_path": "action/000000_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "01. Hold [S] + Static", + "image_path": "../../../../mc_wasd_10/validate/000001.jpg", + "action_path": "action/000001_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "02. Hold [A] + Static", + "image_path": "../../../../mc_wasd_10/validate/000002.jpg", + "action_path": "action/000002_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "03. Hold [D] + Static", + "image_path": "../../../../mc_wasd_10/validate/000003.jpg", + "action_path": "action/000003_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "04. Hold [WA] + Static", + "image_path": "../../../../mc_wasd_10/validate/000004.jpg", + "action_path": "action/000004_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "05. 
Hold [WD] + Static", + "image_path": "../../../../mc_wasd_10/validate/000005.jpg", + "action_path": "action/000005_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "06. Hold [SA] + Static", + "image_path": "../../../../mc_wasd_10/validate/000006.jpg", + "action_path": "action/000006_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "07. Hold [SD] + Static", + "image_path": "../../../../mc_wasd_10/validate/000007.jpg", + "action_path": "action/000007_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "08. No Key + Hold [up]", + "image_path": "../../../../mc_wasd_10/validate/000000.jpg", + "action_path": "action/000008_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "09. No Key + Hold [down]", + "image_path": "../../../../mc_wasd_10/validate/000001.jpg", + "action_path": "action/000009_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "10. No Key + Hold [left]", + "image_path": "../../../../mc_wasd_10/validate/000002.jpg", + "action_path": "action/000010_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "11. No Key + Hold [right]", + "image_path": "../../../../mc_wasd_10/validate/000003.jpg", + "action_path": "action/000011_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "12. 
No Key + Hold [up_right]", + "image_path": "../../../../mc_wasd_10/validate/000004.jpg", + "action_path": "action/000012_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "13. No Key + Hold [up_left]", + "image_path": "../../../../mc_wasd_10/validate/000005.jpg", + "action_path": "action/000013_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "14. No Key + Hold [down_right]", + "image_path": "../../../../mc_wasd_10/validate/000006.jpg", + "action_path": "action/000014_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "15. No Key + Hold [down_left]", + "image_path": "../../../../mc_wasd_10/validate/000007.jpg", + "action_path": "action/000015_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "00. Hold [W] + Static", + "image_path": "../../../../mc_wasd_10/validate/gen_000000.jpg", + "action_path": "action/000000_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "01. Hold [S] + Static", + "image_path": "../../../../mc_wasd_10/validate/gen_000001.jpg", + "action_path": "action/000001_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "02. Hold [A] + Static", + "image_path": "../../../../mc_wasd_10/validate/gen_000002.jpg", + "action_path": "action/000002_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "03. 
Hold [D] + Static", + "image_path": "../../../../mc_wasd_10/validate/gen_000003.jpg", + "action_path": "action/000003_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "04. Hold [WA] + Static", + "image_path": "../../../../mc_wasd_10/validate/gen_000004.jpg", + "action_path": "action/000004_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "05. Hold [WD] + Static", + "image_path": "../../../../mc_wasd_10/validate/gen_000005.jpg", + "action_path": "action/000005_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "06. Hold [SA] + Static", + "image_path": "../../../../mc_wasd_10/validate/gen_000006.jpg", + "action_path": "action/000006_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "07. Hold [SD] + Static", + "image_path": "../../../../mc_wasd_10/validate/gen_000007.jpg", + "action_path": "action/000007_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "08. No Key + Hold [up]", + "image_path": "../../../../mc_wasd_10/validate/gen_000000.jpg", + "action_path": "action/000008_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "09. No Key + Hold [down]", + "image_path": "../../../../mc_wasd_10/validate/gen_000001.jpg", + "action_path": "action/000009_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "10. 
No Key + Hold [left]", + "image_path": "../../../../mc_wasd_10/validate/gen_000002.jpg", + "action_path": "action/000010_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "11. No Key + Hold [right]", + "image_path": "../../../../mc_wasd_10/validate/gen_000003.jpg", + "action_path": "action/000011_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "12. No Key + Hold [up_right]", + "image_path": "../../../../mc_wasd_10/validate/gen_000004.jpg", + "action_path": "action/000012_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "13. No Key + Hold [up_left]", + "image_path": "../../../../mc_wasd_10/validate/gen_000005.jpg", + "action_path": "action/000013_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "14. No Key + Hold [down_right]", + "image_path": "../../../../mc_wasd_10/validate/gen_000006.jpg", + "action_path": "action/000014_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "15. 
No Key + Hold [down_left]", + "image_path": "../../../../mc_wasd_10/validate/gen_000007.jpg", + "action_path": "action/000015_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + } + ] +} \ No newline at end of file diff --git a/examples/training/finetune/WanGame2.1_1.3b_i2v/validation_random.json b/examples/training/finetune/WanGame2.1_1.3b_i2v/validation_random.json new file mode 100644 index 000000000..41e51fc79 --- /dev/null +++ b/examples/training/finetune/WanGame2.1_1.3b_i2v/validation_random.json @@ -0,0 +1,324 @@ +{ + "data": [ + { + "caption": "00 Val-00: W", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/W.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "01 Val-01: S", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000003.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/S.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "02 Val-02: A", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000004.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/A.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "03 Val-03: D", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000005.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/D.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + 
"num_frames": 77 + }, + { + "caption": "04 Val-04: u", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000000.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/u.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "05 Val-05: d", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000001.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/d.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "06 Val-06: l", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000006.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/l.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "07 Val-07: r", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000007.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/r.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "08 Val-00: key rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/key_1_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "09 Val-01: key rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000003.jpg", + "action_path": 
"/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/key_1_action_rand_2.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "10 Val-02: camera rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000004.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/camera_1_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "11 Val-03: camera rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000005.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/camera_1_action_rand_2.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "12 Val-00: key+camera excl rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/key_camera_excl_1_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "13 Val-01: key+camera excl rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000003.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/key_camera_excl_1_action_rand_2.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "14 Val-02: key+camera rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000004.jpg", + "action_path": 
"/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/key_camera_1_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "15 Val-03: key+camera rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000005.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/key_camera_1_action_rand_2.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "16 Val-04: (simultaneous) key rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000000.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/key_2_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "17 Val-05: (simultaneous) camera rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000001.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/camera_2_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "18 Val-06: (simultaneous) key+camera excl rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000006.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/key_camera_excl_2_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "19 Val-07: (simultaneous) key+camera rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000007.jpg", + 
"action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/key_camera_2_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "20 Val-08: W+A", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/humanplay/000005.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/WA.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "21 Val-09: S+u", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/humanplay/000013.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/S_u.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "22 Val-08: Still", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/humanplay/000005.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/still.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "23 Val-09: Still", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/humanplay/000013.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/still.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "24 Val-06: key+camera excl rand Frame 4", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000006.jpg", + "action_path": 
"/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/key_camera_excl_1_action_rand_1_f4.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "25 Val-07: key+camera excl rand Frame 4", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000007.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/key_camera_excl_1_action_rand_2_f4.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "26 Train-00", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/traindata_0208_2000/data/wasd4holdrandview_simple_1key1mouse1/first_frame/000500.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/traindata_0208_2000/data/wasd4holdrandview_simple_1key1mouse1/videos/000500_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "27 Train-01", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/traindata_0208_2000/data/wasd4holdrandview_simple_1key1mouse1/first_frame/001000.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/traindata_0208_2000/data/wasd4holdrandview_simple_1key1mouse1/videos/001000_action.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "28 Doom-00: W", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/doom/000000.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/W.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "29 Doom-01: key rand", + "image_path": 
"/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/doom/000001.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/key_1_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "30 Doom-02: camera rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/doom/000002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/camera_1_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "31 Doom-03: key+camera excl rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/doom/000003.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/key_camera_excl_1_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + } + ] +} \ No newline at end of file diff --git a/examples/training/finetune/WanGame2.1_1.3b_i2v/validation_random_8.json b/examples/training/finetune/WanGame2.1_1.3b_i2v/validation_random_8.json new file mode 100644 index 000000000..0bde93715 --- /dev/null +++ b/examples/training/finetune/WanGame2.1_1.3b_i2v/validation_random_8.json @@ -0,0 +1,84 @@ +{ + "data": [ + { + "caption": "00 Val-00: W", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/W.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "01 Val-01: S", + "image_path": 
"/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000003.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/S.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "02 Val-02: A", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000004.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/A.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "03 Val-03: D", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000005.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/D.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "04 Val-04: u", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000000.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/u.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "05 Val-05: d", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000001.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/d.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "06 Val-06: l", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000006.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/l.npy", + "video_path": null, + 
"num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + }, + { + "caption": "07 Val-07: r", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/mc_wasd_10/validate/000007.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions/r.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 352, + "width": 640, + "num_frames": 77 + } + ] +} \ No newline at end of file diff --git a/examples/training/finetune/WanGame2.1_1.3b_i2v/validation_zelda.json b/examples/training/finetune/WanGame2.1_1.3b_i2v/validation_zelda.json new file mode 100644 index 000000000..8ad289afe --- /dev/null +++ b/examples/training/finetune/WanGame2.1_1.3b_i2v/validation_zelda.json @@ -0,0 +1,324 @@ +{ + "data": [ + { + "caption": "00 Val-00: W", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/5TTrlqAguhQ_chunk_0006/segment0002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/W.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "01 Val-01: S", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/5TTrlqAguhQ_chunk_0067/segment0002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/S.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "02 Val-02: A", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/5TTrlqAguhQ_chunk_0484/segment0002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/A.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 
480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "03 Val-03: D", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/N6ObBAt41bg_chunk_0019/segment0004.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/D.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "04 Val-04: u", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/N6ObBAt41bg_chunk_0140/segment0003.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/u.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "05 Val-05: d", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/N6ObBAt41bg_chunk_0300/segment0003.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/d.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "06 Val-06: l", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/N6ObBAt41bg_chunk_0140/segment0003.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/l.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "07 Val-07: r", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/N6ObBAt41bg_chunk_0300/segment0003.jpg", + "action_path": 
"/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/r.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "08 Val-00: key rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/5TTrlqAguhQ_chunk_0006/segment0002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/key_1_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "09 Val-01: key rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/5TTrlqAguhQ_chunk_0067/segment0002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/key_1_action_rand_2.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "10 Val-02: camera rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/5TTrlqAguhQ_chunk_0484/segment0002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/camera_1_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "11 Val-03: camera rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/N6ObBAt41bg_chunk_0019/segment0004.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/camera_1_action_rand_2.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + 
"caption": "12 Val-00: key+camera excl rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/5TTrlqAguhQ_chunk_0006/segment0002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/key_camera_excl_1_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "13 Val-01: key+camera excl rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/5TTrlqAguhQ_chunk_0067/segment0002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/key_camera_excl_1_action_rand_2.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "14 Val-02: key+camera rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/5TTrlqAguhQ_chunk_0484/segment0002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/key_camera_1_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "15 Val-03: key+camera rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/N6ObBAt41bg_chunk_0019/segment0004.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/key_camera_1_action_rand_2.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "16 Val-04: (simultaneous) key rand", + "image_path": 
"/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/N6ObBAt41bg_chunk_0140/segment0003.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/key_2_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "17 Val-05: (simultaneous) camera rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/N6ObBAt41bg_chunk_0300/segment0003.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/camera_2_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "18 Val-06: (simultaneous) key+camera excl rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/N6ObBAt41bg_chunk_0300/segment0003.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/key_camera_excl_2_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "19 Val-07: (simultaneous) key+camera rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/N6ObBAt41bg_chunk_0300/segment0003.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/key_camera_2_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "20 Val-08: W+A", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/5TTrlqAguhQ_chunk_0006/segment0002.jpg", + "action_path": 
"/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/WA.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "21 Val-09: S+u", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/5TTrlqAguhQ_chunk_0067/segment0002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/S_u.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "22 Val-08: Still", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/5TTrlqAguhQ_chunk_0484/segment0002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/still.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "23 Val-09: Still", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/N6ObBAt41bg_chunk_0019/segment0004.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/still.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "24 Val-06: key+camera excl rand Frame 4", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/N6ObBAt41bg_chunk_0140/segment0003.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/key_camera_excl_1_action_rand_1_f4.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "25 Val-07: key+camera 
excl rand Frame 4", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/N6ObBAt41bg_chunk_0300/segment0003.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/key_camera_excl_1_action_rand_2_f4.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "26 Train-00", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/5TTrlqAguhQ_chunk_0006/segment0002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/5TTrlqAguhQ_chunk_0006/postprocess/action/majority_voting/81_frame_no_button/segment0002.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "27 Train-01", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/5TTrlqAguhQ_chunk_0067/segment0002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/zelda/5TTrlqAguhQ_chunk_0067/postprocess/action/majority_voting/81_frame_no_button/segment0002.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "28 Doom-00: W", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/doom/000000.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/W.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "29 Doom-01: key rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/doom/000001.jpg", + "action_path": 
"/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/key_1_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "30 Doom-02: camera rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/doom/000002.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/camera_1_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + }, + { + "caption": "31 Doom-03: key+camera excl rand", + "image_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/doom/000003.jpg", + "action_path": "/mnt/weka/home/hao.zhang/mhuo/FastVideo/examples/training/finetune/WanGame2.1_1.3b_i2v/actions_81/key_camera_excl_1_action_rand_1.npy", + "video_path": null, + "num_inference_steps": 40, + "height": 480, + "width": 832, + "num_frames": 81 + } + ] +} \ No newline at end of file diff --git a/examples/training/finetune/WanGame2.1_1.3b_i2v_LingBot/action/README.md b/examples/training/finetune/WanGame2.1_1.3b_i2v_LingBot/action/README.md new file mode 100644 index 000000000..cdddcfab6 --- /dev/null +++ b/examples/training/finetune/WanGame2.1_1.3b_i2v_LingBot/action/README.md @@ -0,0 +1,18 @@ +Total Files: 16 + +00. Hold [W] + Static +01. Hold [S] + Static +02. Hold [A] + Static +03. Hold [D] + Static +04. Hold [WA] + Static +05. Hold [WD] + Static +06. Hold [SA] + Static +07. Hold [SD] + Static +08. No Key + Hold [up] +09. No Key + Hold [down] +10. No Key + Hold [left] +11. No Key + Hold [right] +12. No Key + Hold [up_right] +13. No Key + Hold [up_left] +14. No Key + Hold [down_right] +15. 
#!/bin/bash
# Single-node fine-tuning launcher for the WanLingBot image-to-video pipeline.
# The training entry point and PYTHONPATH below are relative to the current
# working directory, so run this from the FastVideo repository root.

# SECURITY: do not commit a Weights & Biases API key here. Provide
# WANDB_API_KEY through the caller's environment (or `wandb login`) when
# switching WANDB_MODE to "online"; offline mode does not need a key.
export WANDB_BASE_URL="https://api.wandb.ai"
# export WANDB_MODE=online
export WANDB_MODE=offline
export TOKENIZERS_PARALLELISM=false
export FASTVIDEO_ATTENTION_BACKEND=FLASH_ATTN
export PYTHONPATH="$PYTHONPATH:$(pwd)"

MODEL_PATH="weizhou03/Wan2.1-Game-Fun-1.3B-InP-Diffusers"
DATA_DIR="../traindata_0205_1330/data/0_static_plus_w_only/preprocessed"
# Resolved relative to this script, so the validation set is found from any cwd.
VALIDATION_DATASET_FILE="$(dirname "$0")/validation.json"
NUM_GPUS=1
# export CUDA_VISIBLE_DEVICES=0,1,2,3
# IP=[MASTER NODE IP]

# NOTE(review): the conda environment and PYTHONPATH below are machine-specific
# absolute paths; adjust them for your own installation.
source ~/conda/miniconda/bin/activate
conda activate /mnt/weka/home/hao.zhang/conda/miniconda/envs/mhuo-fv
export PYTHONPATH="/mnt/weka/home/hao.zhang/kaiqin/FastVideo:$PYTHONPATH"

# Training arguments
training_args=(
  --override-pipeline-cls-name "WanLingBotImageToVideoPipeline"
  --override-transformer-cls-name "WanLingBotTransformer3DModel"
  --tracker_project_name "wangame_lingbot_test"
  --output_dir "wangame_lingbot_test"
  --max_train_steps 100
  --train_batch_size 1
  --train_sp_batch_size 1
  --gradient_accumulation_steps 1
  --num_latent_t 20
  --num_height 352
  --num_width 640
  --num_frames 77
  --enable_gradient_checkpointing_type "full"
)

# Parallel arguments
parallel_args=(
  --num_gpus "$NUM_GPUS"
  --sp_size 1
  --tp_size 1
  --hsdp_replicate_dim 1
  --hsdp_shard_dim "$NUM_GPUS"
)

# Model arguments
model_args=(
  --model_path "$MODEL_PATH"
  --pretrained_model_name_or_path "$MODEL_PATH"
)

# Dataset arguments
dataset_args=(
  --data_path "$DATA_DIR"
  --dataloader_num_workers 1
)

# Validation arguments
validation_args=(
  --log_validation
  --validation_dataset_file "$VALIDATION_DATASET_FILE"
  --validation_steps 100
  --validation_sampling_steps "40"
  --validation_guidance_scale "1.0"
)

# Optimizer arguments
optimizer_args=(
  --learning_rate 2e-5
  --mixed_precision "bf16"
  --weight_only_checkpointing_steps 1000
  --training_state_checkpointing_steps 1000
  --weight_decay 1e-4
  --max_grad_norm 1.0
)

# Miscellaneous arguments
miscellaneous_args=(
  --inference_mode False
  --checkpoints_total_limit 3
  --training_cfg_rate 0.1
  --multi_phased_distill_schedule "4000-1"
  --not_apply_cfg_solver
  --dit_precision "fp32"
  --num_euler_timesteps 50
  --ema_start_step 0
)

# If the run does not fit in memory on the GPUs you have, you can:
#   1. increase sp_size, or
#   2. reduce num_latent_t.
torchrun \
  --nnodes 1 \
  --nproc_per_node "$NUM_GPUS" \
  fastvideo/training/wangame_lingbot_training_pipeline.py \
  "${parallel_args[@]}" \
  "${model_args[@]}" \
  "${dataset_args[@]}" \
  "${training_args[@]}" \
  "${optimizer_args[@]}" \
  "${validation_args[@]}" \
  "${miscellaneous_args[@]}"
TRITON_CACHE_DIR=/tmp/triton_cache_${SLURM_PROCID} +export MASTER_PORT=29500 +export NODE_RANK=$SLURM_PROCID +nodes=( $(scontrol show hostnames $SLURM_JOB_NODELIST) ) +export MASTER_ADDR=${nodes[0]} +export TOKENIZERS_PARALLELISM=false +# export WANDB_API_KEY="8d9f4b39abd68eb4e29f6fc010b7ee71a2207cde" +export WANDB_API_KEY="50632ebd88ffd970521cec9ab4a1a2d7e85bfc45" +# export WANDB_API_KEY='your_wandb_api_key_here' +export WANDB_BASE_URL="https://api.wandb.ai" +export WANDB_MODE=online +export FASTVIDEO_ATTENTION_BACKEND=FLASH_ATTN + +source ~/conda/miniconda/bin/activate +conda activate wei-fv-distill +export HOME="/mnt/weka/home/hao.zhang/wei" + +MODEL_PATH="weizhou03/Wan2.1-Game-Fun-1.3B-InP-Diffusers" +DATA_DIR="mc_wasd_10/preprocessed/combined_parquet_dataset" +VALIDATION_DATASET_FILE="examples/training/finetune/WanGame2.1_1.3b_i2v_LingBot/validation.json" +# Configs +NUM_GPUS=8 + +# Training arguments +training_args=( + --tracker_project_name "wangame_1.3b_overfit" + --output_dir "wangame_1.3b_overfit" + --max_train_steps 15000 + --train_batch_size 1 + --train_sp_batch_size 1 + --gradient_accumulation_steps 1 + --num_latent_t 20 + --num_height 352 + --num_width 640 + --num_frames 77 + --enable_gradient_checkpointing_type "full" +) + +# Parallel arguments +parallel_args=( + --num_gpus $NUM_GPUS + --sp_size 1 + --tp_size 1 + --hsdp_replicate_dim 1 + --hsdp_shard_dim $NUM_GPUS +) + +# Model arguments +model_args=( + --model_path $MODEL_PATH + --pretrained_model_name_or_path $MODEL_PATH +) + +# Dataset arguments +dataset_args=( + --data_path "$DATA_DIR" + --dataloader_num_workers 1 +) + +# Validation arguments +validation_args=( + --log_validation + --validation_dataset_file "$VALIDATION_DATASET_FILE" + --validation_steps 100 + --validation_sampling_steps "40" + --validation_guidance_scale "1.0" +) + +# Optimizer arguments +optimizer_args=( + --learning_rate 2e-5 + --mixed_precision "bf16" + --weight_only_checkpointing_steps 1000000 + 
"""Generate synthetic keyboard/mouse action files for validation.

Each action file is a ``.npy`` holding a dict with two float32 arrays:

* ``keyboard``: shape ``(FRAME_COUNT, 6)`` multi-hot key vector per frame
  (only indices 0-3, i.e. W/S/A/D, are ever set by this script).
* ``mouse``: shape ``(FRAME_COUNT, 2)`` camera delta per frame.

Every configuration is a pair ``((key_first, key_second),
(mouse_first, mouse_second))``: the first element of each pair drives the
first half of the clip and the second element drives the rest.

Running the script writes the action files plus a ``README.md`` manifest into
the ``action`` directory next to this file.
"""
import os

import numpy as np

# Configuration
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_OUTPUT_DIR = os.path.join(SCRIPT_DIR, 'action')
VIDEO_OUTPUT_DIR = BASE_OUTPUT_DIR
os.makedirs(VIDEO_OUTPUT_DIR, exist_ok=True)

FRAME_COUNT = 81
CAM_VALUE = 0.1

# Only the first MAX_FILES configurations are written to disk.
# NOTE(review): validation_vizdoom.json references action files with indices
# >= 16 (e.g. 000021, 000029, 000050); raise this limit if those are needed.
MAX_FILES = 16

# Action Mapping
KEY_TO_INDEX = {
    'W': 0, 'S': 1, 'A': 2, 'D': 3,
}

VIEW_ACTION_TO_MOUSE = {
    "stop": [0.0, 0.0],
    "up": [CAM_VALUE, 0.0],
    "down": [-CAM_VALUE, 0.0],
    "left": [0.0, -CAM_VALUE],
    "right": [0.0, CAM_VALUE],
    "up_right": [CAM_VALUE, CAM_VALUE],
    "up_left": [CAM_VALUE, -CAM_VALUE],
    "down_right": [-CAM_VALUE, CAM_VALUE],
    "down_left": [-CAM_VALUE, -CAM_VALUE],
}


def get_multihot_vector(keys_str: str) -> list[float]:
    """Convert a key string such as 'WA' to a multi-hot list [1, 0, 1, 0, 0, 0].

    Matching is case-insensitive; characters not present in ``KEY_TO_INDEX``
    are ignored. An empty string yields an all-zero vector.
    """
    vector = [0.0] * 6
    if not keys_str:
        return vector
    for char in keys_str.upper():
        if char in KEY_TO_INDEX:
            vector[KEY_TO_INDEX[char]] = 1.0
    return vector


def get_mouse_vector(view_str: str) -> list[float]:
    """Convert a view name (case-insensitive) to its two-element mouse delta.

    Unknown names fall back to ``[0.0, 0.0]``.
    """
    return VIEW_ACTION_TO_MOUSE.get(view_str.lower(), [0.0, 0.0])


def generate_sequence(key_seq, mouse_seq):
    """Build per-frame keyboard and mouse arrays for one configuration.

    Args:
        key_seq: pair of key strings ``(first_half, second_half)``.
        mouse_seq: pair of view names ``(first_half, second_half)``.

    Returns:
        ``(keyboard_arr, mouse_arr)`` with shapes ``(FRAME_COUNT, 6)`` and
        ``(FRAME_COUNT, 2)``, both float32. Frames ``[0, FRAME_COUNT // 2)``
        use the first element of each pair; the remaining frames use the
        second.
    """
    keyboard_arr = np.zeros((FRAME_COUNT, 6), dtype=np.float32)
    mouse_arr = np.zeros((FRAME_COUNT, 2), dtype=np.float32)

    mid_point = FRAME_COUNT // 2

    # First half
    keyboard_arr[:mid_point] = get_multihot_vector(key_seq[0])
    mouse_arr[:mid_point] = get_mouse_vector(mouse_seq[0])

    # Second half (one frame longer when FRAME_COUNT is odd)
    keyboard_arr[mid_point:] = get_multihot_vector(key_seq[1])
    mouse_arr[mid_point:] = get_mouse_vector(mouse_seq[1])

    return keyboard_arr, mouse_arr


def save_action(index, keyboard_arr, mouse_arr):
    """Write one action dict to ``{index:06d}_action.npy`` and return the filename.

    The dict is stored through ``np.save``, i.e. as a pickled object array, so
    readers must load it with ``np.load(path, allow_pickle=True).item()``.
    """
    filename = f"{index:06d}_action.npy"
    filepath = os.path.join(VIDEO_OUTPUT_DIR, filename)

    action_dict = {
        'keyboard': keyboard_arr,
        'mouse': mouse_arr
    }
    np.save(filepath, action_dict)
    return filename


def generate_description(key_seq, mouse_seq):
    """Return a human-readable description of a configuration.

    Example: ``(('W', 'D'), ('stop', 'stop'))`` -> ``"Switch [W]->[D] + Static"``.
    """
    k1, k2 = key_seq
    m1, m2 = mouse_seq

    # Format keyboard description
    if not k1 and not k2:
        k_desc = "No Key"
    elif k1 == k2:
        k_desc = f"Hold [{k1}]"
    else:
        k_desc = f"Switch [{k1}]->[{k2}]"

    # Format mouse description
    if m1 == "stop" and m2 == "stop":
        m_desc = "Static"
    elif m1 == m2:
        m_desc = f"Hold [{m1}]"
    else:
        m_desc = f"Switch [{m1}]->[{m2}]"

    return f"{k_desc} + {m_desc}"


# ==========================================
# Main Generation Logic
# ==========================================

configs = []
readme_content = []

# Group 1: Constant Keyboard, No Mouse (0-7)
keys_basic = ['W', 'S', 'A', 'D', 'WA', 'WD', 'SA', 'SD']
for k in keys_basic:
    configs.append(((k, k), ("stop", "stop")))

# Group 2: No Keyboard, Constant Mouse (8-15)
mouse_basic = ['up', 'down', 'left', 'right', 'up_right', 'up_left', 'down_right', 'down_left']
for m in mouse_basic:
    configs.append((("", ""), (m, m)))

# Group 3: Split Keyboard, No Mouse (16-23)
split_keys = [
    ('W', 'S'), ('S', 'W'),
    ('A', 'D'), ('D', 'A'),
    ('W', 'A'), ('W', 'D'),
    ('S', 'A'), ('S', 'D')
]
for k1, k2 in split_keys:
    configs.append(((k1, k2), ("stop", "stop")))

# Group 4: No Keyboard, Split Mouse (24-31)
split_mouse = [
    ('left', 'right'), ('right', 'left'),
    ('up', 'down'), ('down', 'up'),
    ('up_left', 'up_right'), ('up_right', 'up_left'),
    ('left', 'up'), ('right', 'down')
]
for m1, m2 in split_mouse:
    configs.append((("", ""), (m1, m2)))

# Group 5: Constant Keyboard + Constant Mouse (32-47)
combo_keys = ['W', 'S', 'W', 'S', 'A', 'D', 'WA', 'WD', 'W', 'S', 'W', 'S', 'A', 'D', 'WA', 'WD']
combo_mice = ['left', 'left', 'right', 'right', 'up', 'up', 'down', 'down', 'up_left', 'up_left', 'up_right', 'up_right', 'down_left', 'down_right', 'right', 'left']
for key, mouse in zip(combo_keys, combo_mice):
    configs.append(((key, key), (mouse, mouse)))

# Group 6: Constant Keyboard, Split Mouse (48-55)
complex_1_keys = ['W'] * 8
complex_1_mice = [
    ('left', 'right'), ('right', 'left'),
    ('up', 'down'), ('down', 'up'),
    ('left', 'up'), ('right', 'up'),
    ('left', 'down'), ('right', 'down')
]
for key, mouse_pair in zip(complex_1_keys, complex_1_mice):
    configs.append(((key, key), mouse_pair))

# Group 7: Split Keyboard, Constant Mouse (56-63)
complex_2_keys = [
    ('W', 'S'), ('S', 'W'),
    ('A', 'D'), ('D', 'A'),
    ('W', 'A'), ('W', 'D'),
    ('S', 'A'), ('S', 'D')
]
complex_2_mouse = 'up'
for k1, k2 in complex_2_keys:
    configs.append(((k1, k2), (complex_2_mouse, complex_2_mouse)))


# Execution
selected_configs = configs[:MAX_FILES]
print(f"Preparing to generate {len(selected_configs)} of {len(configs)} action files...")

for i, (key_seq, mouse_seq) in enumerate(selected_configs):
    # Generate data
    kb_arr, ms_arr = generate_sequence(key_seq, mouse_seq)
    filename = save_action(i, kb_arr, ms_arr)

    # Generate description for README
    description = generate_description(key_seq, mouse_seq)
    readme_entry = f"{i:02d}. {description}"
    readme_content.append(readme_entry)

    print(f"Generated {filename} -> {description}")

# Write README
readme_path = os.path.join(BASE_OUTPUT_DIR, 'README.md')
with open(readme_path, 'w', encoding='utf-8') as f:
    f.write(f"Total Files: {len(readme_content)}\n\n")
    for line in readme_content:
        f.write(line + '\n')

print("\nProcessing complete.")
# Report the number of files actually written rather than a hard-coded total.
print(f"{len(readme_content)} .npy files generated in {VIDEO_OUTPUT_DIR}")
print(f"Manifest saved to {readme_path}")
#!/bin/bash
#SBATCH --partition=main
#SBATCH --qos=hao
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8
#SBATCH --cpus-per-task=16
#SBATCH --mem=960G
#SBATCH --exclusive
#SBATCH --time=72:00:00

# Per-node preprocessing worker: launches 8 background single-GPU torchrun
# processes, one per merge_<N>.txt file, and waits for all of them.
# Usage: sbatch preprocess_worker.slurm <START_FILE> <NODE_ID>

# conda init
source ~/conda/miniconda/bin/activate
conda activate fastvideo_kaiqin

# Accept parameters from launch script
# NOTE(review): launch_preprocess_slurm.sh passes zero-based start files
# (node_id * 8); the default of 1 only applies when run by hand -- confirm.
START_FILE=${1:-1}  # Starting file number for this node
NODE_ID=${2:-0}     # Node identifier (0-7)

MODEL_PATH="weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers"
OUTPUT_BASE="traindata_0204_2130/preprocessed"

# Port range calculation: each node gets its own block of 100 ports so the
# per-GPU torchrun rendezvous ports do not collide.
base_port=$((29500 + NODE_ID * 100))
gpu_ids=(0 1 2 3 4 5 6 7)

pids=()
for i in {1..8}; do
    port=$((base_port + i))
    gpu=${gpu_ids[i-1]}
    file_num=$((START_FILE + i - 1))

    DATA_MERGE_PATH="traindata_0204_2130/merge_${file_num}.txt"
    OUTPUT_DIR="${OUTPUT_BASE}/gpu_${gpu}_file_${file_num}"
    echo "DATA_MERGE_PATH: $DATA_MERGE_PATH"
    echo "OUTPUT_DIR: $OUTPUT_DIR"

    # CPU binding (optional, kept from syn.slurm logic): two cores per process.
    start_cpu=$(( (i-1)*2 ))
    end_cpu=$(( start_cpu+1 ))

    echo "Starting GPU $gpu processing file merge_${file_num}.txt on port $port"

    CUDA_VISIBLE_DEVICES=$gpu taskset -c "${start_cpu}-${end_cpu}" torchrun --nnodes=1 --nproc_per_node=1 --master_port "$port" \
        FastVideo/fastvideo/pipelines/preprocess/v1_preprocess.py \
        --model_path "$MODEL_PATH" \
        --data_merge_path "$DATA_MERGE_PATH" \
        --preprocess_video_batch_size 1 \
        --seed 42 \
        --max_height 352 \
        --max_width 640 \
        --num_frames 77 \
        --dataloader_num_workers 0 \
        --output_dir="$OUTPUT_DIR" \
        --samples_per_file 8 \
        --train_fps 25 \
        --flush_frequency 8 \
        --preprocess_task wangame &
    pids+=("$!")
done

# Wait on each child individually: a bare `wait` always returns 0, which would
# let the job report success even when some preprocessing processes failed.
failed=0
for pid in "${pids[@]}"; do
    wait "$pid" || failed=$((failed + 1))
done

if (( failed > 0 )); then
    echo "Node $NODE_ID: $failed preprocessing process(es) failed!" >&2
    exit 1
fi
echo "Node $NODE_ID processing blocks completed!"
Hold [W] + Static", + "image_path": "doom/000000.jpg", + "action_path": "action/000000_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 77 + } + ] +} \ No newline at end of file diff --git a/examples/training/finetune/WanGame2.1_1.3b_i2v_LingBot/validation_vizdoom.json b/examples/training/finetune/WanGame2.1_1.3b_i2v_LingBot/validation_vizdoom.json new file mode 100644 index 000000000..1535db987 --- /dev/null +++ b/examples/training/finetune/WanGame2.1_1.3b_i2v_LingBot/validation_vizdoom.json @@ -0,0 +1,84 @@ +{ + "data": [ + { + "caption": "00. Hold [W] + Static", + "image_path": "doom/000000.jpg", + "action_path": "action/000000_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "01. Hold [S] + Static", + "image_path": "doom/000001.jpg", + "action_path": "action/000001_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "02. Hold [A] + Static", + "image_path": "doom/000002.jpg", + "action_path": "action/000002_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "10. No Key + Hold [left]", + "image_path": "doom/000003.jpg", + "action_path": "action/000010_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "12. No Key + Hold [up_right]", + "image_path": "doom/000004.jpg", + "action_path": "action/000012_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "21. 
Switch [W]->[D] + Static", + "image_path": "doom/000005.jpg", + "action_path": "action/000021_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "29. No Key + Switch [up_right]->[up_left]", + "image_path": "doom/000006.jpg", + "action_path": "action/000029_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + }, + { + "caption": "50. Hold [W] + Switch [up]->[down]", + "image_path": "doom/000007.jpg", + "action_path": "action/000050_action.npy", + "video_path": null, + "num_inference_steps": 4, + "height": 352, + "width": 640, + "num_frames": 81 + } + ] +} \ No newline at end of file diff --git a/fastvideo/configs/models/dits/__init__.py b/fastvideo/configs/models/dits/__init__.py index 7ea6d1284..8caddca7e 100644 --- a/fastvideo/configs/models/dits/__init__.py +++ b/fastvideo/configs/models/dits/__init__.py @@ -7,9 +7,13 @@ from fastvideo.configs.models.dits.ltx2 import LTX2VideoConfig from fastvideo.configs.models.dits.wanvideo import WanVideoConfig from fastvideo.configs.models.dits.hyworld import HYWorldConfig +from fastvideo.configs.models.dits.wangamevideo import (WanGameVideoConfig, + WanLingBotVideoConfig) __all__ = [ - "HunyuanVideoConfig", "HunyuanVideo15Config", "HunyuanGameCraftConfig", - "WanVideoConfig", "CosmosVideoConfig", "Cosmos25VideoConfig", - "LongCatVideoConfig", "LTX2VideoConfig", "HYWorldConfig" + "HunyuanVideoConfig", "HunyuanVideo15Config", "WanVideoConfig", + "StepVideoConfig", "CosmosVideoConfig", "Cosmos25VideoConfig", + "LongCatVideoConfig", "LTX2VideoConfig", "HYWorldConfig", + "LingBotWorldVideoConfig", "WanGameVideoConfig", "WanLingBotVideoConfig", + "LingBotWorldVideoConfig", "HunyuanGameCraftConfig", "WanVideoConfig" ] diff --git a/fastvideo/configs/models/dits/wangamevideo.py b/fastvideo/configs/models/dits/wangamevideo.py new file mode 100644 index 000000000..429d7407c --- /dev/null 
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field

from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig


def is_blocks(n: str, m) -> bool:
    """Return True when ``n`` names a numbered transformer block.

    A name matches when it contains the substring "blocks" and its last
    dot-separated component consists only of digits (e.g. "blocks.12").
    ``m`` is accepted but not used by this predicate.
    """
    return "blocks" in n and str.isdigit(n.split(".")[-1])


@dataclass
class WanGameVideoArchConfig(DiTArchConfig):
    """Architecture configuration shared by the WanGame and WanLingBot DiT configs."""

    # Predicates called with (name, module); presumably used to pick the
    # modules FSDP shards individually -- TODO confirm against the loader.
    _fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks])

    # Regex pattern -> replacement template for parameter names. Judging by the
    # reverse-mapping comment below ("custom -> hf"), this direction is
    # presumably hf checkpoint names -> this model's names -- TODO confirm.
    param_names_mapping: dict = field(
        default_factory=lambda: {
            r"^patch_embedding\.(.*)$":
            r"patch_embedding.proj.\1",
            r"^condition_embedder\.text_embedder\.linear_1\.(.*)$":
            r"condition_embedder.text_embedder.fc_in.\1",
            r"^condition_embedder\.text_embedder\.linear_2\.(.*)$":
            r"condition_embedder.text_embedder.fc_out.\1",
            r"^condition_embedder\.time_embedder\.linear_1\.(.*)$":
            r"condition_embedder.time_embedder.mlp.fc_in.\1",
            r"^condition_embedder\.time_embedder\.linear_2\.(.*)$":
            r"condition_embedder.time_embedder.mlp.fc_out.\1",
            r"^condition_embedder\.time_proj\.(.*)$":
            r"condition_embedder.time_modulation.linear.\1",
            r"^condition_embedder\.image_embedder\.ff\.net\.0\.proj\.(.*)$":
            r"condition_embedder.image_embedder.ff.fc_in.\1",
            r"^condition_embedder\.image_embedder\.ff\.net\.2\.(.*)$":
            r"condition_embedder.image_embedder.ff.fc_out.\1",
            r"^blocks\.(\d+)\.attn1\.to_q\.(.*)$":
            r"blocks.\1.to_q.\2",
            r"^blocks\.(\d+)\.attn1\.to_k\.(.*)$":
            r"blocks.\1.to_k.\2",
            r"^blocks\.(\d+)\.attn1\.to_v\.(.*)$":
            r"blocks.\1.to_v.\2",
            r"^blocks\.(\d+)\.attn1\.to_out\.0\.(.*)$":
            r"blocks.\1.to_out.\2",
            r"^blocks\.(\d+)\.attn1\.norm_q\.(.*)$":
            r"blocks.\1.norm_q.\2",
            r"^blocks\.(\d+)\.attn1\.norm_k\.(.*)$":
            r"blocks.\1.norm_k.\2",
            r"^blocks\.(\d+)\.attn2\.to_out\.0\.(.*)$":
            r"blocks.\1.attn2.to_out.\2",
            r"^blocks\.(\d+)\.ffn\.net\.0\.proj\.(.*)$":
            r"blocks.\1.ffn.fc_in.\2",
            r"^blocks\.(\d+)\.ffn\.net\.2\.(.*)$":
            r"blocks.\1.ffn.fc_out.\2",
            r"^blocks\.(\d+)\.norm2\.(.*)$":
            r"blocks.\1.self_attn_residual_norm.norm.\2",
        })

    # Reverse mapping for saving checkpoints: custom -> hf
    # (left empty in this config)
    reverse_param_names_mapping: dict = field(default_factory=lambda: {})

    # Some LoRA adapters use the original official layer names instead of hf layer names,
    # so apply this before the param_names_mapping
    lora_param_names_mapping: dict = field(
        default_factory=lambda: {
            r"^blocks\.(\d+)\.self_attn\.q\.(.*)$": r"blocks.\1.attn1.to_q.\2",
            r"^blocks\.(\d+)\.self_attn\.k\.(.*)$": r"blocks.\1.attn1.to_k.\2",
            r"^blocks\.(\d+)\.self_attn\.v\.(.*)$": r"blocks.\1.attn1.to_v.\2",
            r"^blocks\.(\d+)\.self_attn\.o\.(.*)$":
            r"blocks.\1.attn1.to_out.0.\2",
            r"^blocks\.(\d+)\.cross_attn\.q\.(.*)$": r"blocks.\1.attn2.to_q.\2",
            r"^blocks\.(\d+)\.cross_attn\.k\.(.*)$": r"blocks.\1.attn2.to_k.\2",
            r"^blocks\.(\d+)\.cross_attn\.v\.(.*)$": r"blocks.\1.attn2.to_v.\2",
            r"^blocks\.(\d+)\.cross_attn\.o\.(.*)$":
            r"blocks.\1.attn2.to_out.0.\2",
            r"^blocks\.(\d+)\.ffn\.0\.(.*)$": r"blocks.\1.ffn.fc_in.\2",
            r"^blocks\.(\d+)\.ffn\.2\.(.*)$": r"blocks.\1.ffn.fc_out.\2",
        })

    patch_size: tuple[int, int, int] = (1, 2, 2)
    # NOTE: no type annotation, so @dataclass treats text_len as a plain class
    # attribute rather than a dataclass field (it is not an __init__ argument).
    text_len = 512
    num_attention_heads: int = 40
    attention_head_dim: int = 128
    in_channels: int = 16
    out_channels: int = 16
    text_dim: int = 4096
    freq_dim: int = 256
    ffn_dim: int = 13824
    num_layers: int = 40
    cross_attn_norm: bool = True
    qk_norm: str = "rms_norm_across_heads"
    eps: float = 1e-6
    image_dim: int | None = None
    added_kv_proj_dim: int | None = None
    rope_max_seq_len: int = 1024
    pos_embed_seq_len: int | None = None
    exclude_lora_layers: list[str] = field(default_factory=lambda: ["embedder"])

    # Wan MoE
    boundary_ratio: float | None = None

    # Causal Wan
    local_attn_size: int = -1  # Window size for temporal local attention (-1 indicates global attention)
    sink_size: int = 0  # Size of the attention sink, we keep the first `sink_size` frames unchanged when rolling the KV cache
    num_frames_per_block: int = 3
    sliding_window_num_frames: int = 21

    def __post_init__(self):
        """Fill in values derived from the declared fields.

        ``out_channels`` falls back to ``in_channels`` when falsy,
        ``hidden_size`` is ``num_attention_heads * attention_head_dim``, and
        ``num_channels_latents`` mirrors ``out_channels``.
        """
        super().__post_init__()
        self.out_channels = self.out_channels or self.in_channels
        self.hidden_size = self.num_attention_heads * self.attention_head_dim
        self.num_channels_latents = self.out_channels


@dataclass
class WanGameVideoConfig(DiTConfig):
    """DiT config whose default architecture is ``WanGameVideoArchConfig`` (prefix "WanGame")."""

    arch_config: DiTArchConfig = field(default_factory=WanGameVideoArchConfig)

    prefix: str = "WanGame"


@dataclass
class WanLingBotVideoConfig(DiTConfig):
    """DiT config reusing ``WanGameVideoArchConfig``; differs from ``WanGameVideoConfig`` only in its prefix."""

    arch_config: DiTArchConfig = field(default_factory=WanGameVideoArchConfig)

    prefix: str = "WanLingBot"
b/fastvideo/configs/pipelines/base.py index 83df65f8d..f60fb600a 100644 --- a/fastvideo/configs/pipelines/base.py +++ b/fastvideo/configs/pipelines/base.py @@ -69,6 +69,16 @@ class PipelineConfig: # DMD parameters dmd_denoising_steps: list[int] | None = field(default=None) + # Sampler kind (controls the denoising loop semantics). + # - "ode": deterministic solver-style loop (default) + # - "sde": stochastic loop with noise injection + sampler_kind: str = "ode" + + # ODE solver selection when `sampler_kind="ode"`. + # - "unipc": FlowUniPCMultistepScheduler (default) + # - "euler": FlowMatchEulerDiscreteScheduler + ode_solver: str = "unipc" + # Wan2.2 TI2V parameters ti2v_task: bool = False boundary_ratio: float | None = None @@ -175,6 +185,14 @@ def add_cli_args(parser: FlexibleArgumentParser, help= "Comma-separated list of denoising steps (e.g., '1000,757,522')", ) + parser.add_argument( + f"--{prefix_with_dot}sampler-kind", + type=str, + choices=["ode", "sde"], + dest=f"{prefix_with_dot.replace('-', '_')}sampler_kind", + default=PipelineConfig.sampler_kind, + help="Sampling loop kind: ode (default) or sde.", + ) # Add VAE configuration arguments from fastvideo.configs.models.vaes.base import VAEConfig diff --git a/fastvideo/configs/pipelines/wan.py b/fastvideo/configs/pipelines/wan.py index 434839dcc..996a734aa 100644 --- a/fastvideo/configs/pipelines/wan.py +++ b/fastvideo/configs/pipelines/wan.py @@ -7,6 +7,8 @@ from fastvideo.configs.models import DiTConfig, EncoderConfig, VAEConfig from fastvideo.configs.models.dits import WanVideoConfig from fastvideo.configs.models.dits.matrixgame import MatrixGameWanVideoConfig +from fastvideo.configs.models.dits.wangamevideo import (WanGameVideoConfig, + WanLingBotVideoConfig) from fastvideo.configs.models.encoders import (BaseEncoderOutput, CLIPVisionConfig, T5Config, WAN2_1ControlCLIPVisionConfig) @@ -112,6 +114,23 @@ class WANV2VConfig(WanI2V480PConfig): image_encoder_precision: str = 'bf16' +@dataclass +class 
WanLingBotI2V480PConfig(WanI2V480PConfig): + """Configuration for Wan LingBot image-to-video pipeline.""" + + dit_config: DiTConfig = field(default_factory=WanLingBotVideoConfig) + + +@dataclass +class WanGameI2V480PConfig(WanI2V480PConfig): + """Configuration for WanGame image-to-video pipeline.""" + + dit_config: DiTConfig = field(default_factory=WanGameVideoConfig) + flow_shift: float | None = 3.0 + dmd_denoising_steps: list[int] | None = field( + default_factory=lambda: [1000, 750, 500, 250, 0]) + + @dataclass class FastWan2_1_T2V_480P_Config(WanT2V480PConfig): """Base configuration for FastWan T2V 1.3B 480P pipeline architecture with DMD""" @@ -193,6 +212,15 @@ def __post_init__(self) -> None: self.vae_config.load_decoder = True +@dataclass +class SelfForcingWanGameI2V480PConfig(WanGameI2V480PConfig): + is_causal: bool = True + flow_shift: float | None = 3.0 + dmd_denoising_steps: list[int] | None = field( + default_factory=lambda: [1000, 750, 500, 250, 0]) + warp_denoising_step: bool = True + + # ============================================= # ============= Matrix Game =================== # ============================================= diff --git a/fastvideo/configs/sample/wan.py b/fastvideo/configs/sample/wan.py index a96cf0d29..b6eaf6989 100644 --- a/fastvideo/configs/sample/wan.py +++ b/fastvideo/configs/sample/wan.py @@ -63,13 +63,13 @@ class FastWanT2V480P_SamplingParam(WanT2V_1_3B_SamplingParam): @dataclass class Wan2_1_Fun_1_3B_InP_SamplingParam(SamplingParam): """Sampling parameters for Wan2.1 Fun 1.3B InP model.""" - height: int = 480 - width: int = 832 - num_frames: int = 81 - fps: int = 16 + height: int = 352 + width: int = 640 + num_frames: int = 77 + fps: int = 25 negative_prompt: str | None = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" - guidance_scale: float = 6.0 - num_inference_steps: int = 50 + guidance_scale: float = 1.0 + num_inference_steps: 
int = 40 @dataclass diff --git a/fastvideo/dataset/dataloader/record_schema.py b/fastvideo/dataset/dataloader/record_schema.py index 1bc86dd7d..0eea39adf 100644 --- a/fastvideo/dataset/dataloader/record_schema.py +++ b/fastvideo/dataset/dataloader/record_schema.py @@ -188,3 +188,96 @@ def text_only_record_creator(text_name: str, text_embedding: np.ndarray, "caption": caption, } return record + + +def wangame_ode_record_creator( + video_name: str, + clip_feature: np.ndarray, + first_frame_latent: np.ndarray, + trajectory_latents: np.ndarray, + trajectory_timesteps: np.ndarray, + pil_image: np.ndarray | None = None, + keyboard_cond: np.ndarray | None = None, + mouse_cond: np.ndarray | None = None, + caption: str = "") -> dict[str, Any]: + """Create a ODE trajectory record matching pyarrow_schema_wangame + """ + assert trajectory_latents is not None, "trajectory_latents is required" + assert trajectory_timesteps is not None, "trajectory_timesteps is required" + assert clip_feature is not None, "clip_feature is required" + assert first_frame_latent is not None, "first_frame_latent is required" + + record = { + "id": video_name, + "file_name": video_name, + "caption": caption, + "media_type": "video", + } + + # I2V features + record.update({ + "clip_feature_bytes": clip_feature.tobytes(), + "clip_feature_shape": list(clip_feature.shape), + "clip_feature_dtype": str(clip_feature.dtype), + }) + + record.update({ + "first_frame_latent_bytes": first_frame_latent.tobytes(), + "first_frame_latent_shape": list(first_frame_latent.shape), + "first_frame_latent_dtype": str(first_frame_latent.dtype), + }) + + # Optional PIL Image + if pil_image is not None: + record.update({ + "pil_image_bytes": pil_image.tobytes(), + "pil_image_shape": list(pil_image.shape), + "pil_image_dtype": str(pil_image.dtype), + }) + else: + record.update({ + "pil_image_bytes": b"", + "pil_image_shape": [], + "pil_image_dtype": "", + }) + + # Actions + if keyboard_cond is not None: + record.update({ + 
"keyboard_cond_bytes": keyboard_cond.tobytes(), + "keyboard_cond_shape": list(keyboard_cond.shape), + "keyboard_cond_dtype": str(keyboard_cond.dtype), + }) + else: + record.update({ + "keyboard_cond_bytes": b"", + "keyboard_cond_shape": [], + "keyboard_cond_dtype": "", + }) + + if mouse_cond is not None: + record.update({ + "mouse_cond_bytes": mouse_cond.tobytes(), + "mouse_cond_shape": list(mouse_cond.shape), + "mouse_cond_dtype": str(mouse_cond.dtype), + }) + else: + record.update({ + "mouse_cond_bytes": b"", + "mouse_cond_shape": [], + "mouse_cond_dtype": "", + }) + + record.update({ + "trajectory_latents_bytes": trajectory_latents.tobytes(), + "trajectory_latents_shape": list(trajectory_latents.shape), + "trajectory_latents_dtype": str(trajectory_latents.dtype), + }) + + record.update({ + "trajectory_timesteps_bytes": trajectory_timesteps.tobytes(), + "trajectory_timesteps_shape": list(trajectory_timesteps.shape), + "trajectory_timesteps_dtype": str(trajectory_timesteps.dtype), + }) + + return record diff --git a/fastvideo/dataset/dataloader/schema.py b/fastvideo/dataset/dataloader/schema.py index 048c7686c..cdf9b42fb 100644 --- a/fastvideo/dataset/dataloader/schema.py +++ b/fastvideo/dataset/dataloader/schema.py @@ -157,3 +157,85 @@ pa.field("duration_sec", pa.float64()), pa.field("fps", pa.float64()), ]) + +pyarrow_schema_wangame = pa.schema([ + pa.field("id", pa.string()), + # --- Image/Video VAE latents --- + # Tensors are stored as raw bytes with shape and dtype info for loading + pa.field("vae_latent_bytes", pa.binary()), + # e.g., [C, T, H, W] or [C, H, W] + pa.field("vae_latent_shape", pa.list_(pa.int64())), + # e.g., 'float32' + pa.field("vae_latent_dtype", pa.string()), + #I2V + pa.field("clip_feature_bytes", pa.binary()), + pa.field("clip_feature_shape", pa.list_(pa.int64())), + pa.field("clip_feature_dtype", pa.string()), + pa.field("first_frame_latent_bytes", pa.binary()), + pa.field("first_frame_latent_shape", pa.list_(pa.int64())), + 
pa.field("first_frame_latent_dtype", pa.string()), + # --- Action --- + pa.field("mouse_cond_bytes", pa.binary()), + pa.field("mouse_cond_shape", pa.list_(pa.int64())), # [T, 2] + pa.field("mouse_cond_dtype", pa.string()), + pa.field("keyboard_cond_bytes", pa.binary()), + pa.field("keyboard_cond_shape", pa.list_(pa.int64())), # [T, 4] + pa.field("keyboard_cond_dtype", pa.string()), + # I2V Validation + pa.field("pil_image_bytes", pa.binary()), + pa.field("pil_image_shape", pa.list_(pa.int64())), + pa.field("pil_image_dtype", pa.string()), + # --- Metadata --- + pa.field("file_name", pa.string()), + pa.field("caption", pa.string()), + pa.field("media_type", pa.string()), # 'image' or 'video' + pa.field("width", pa.int64()), + pa.field("height", pa.int64()), + # -- Video-specific (can be null/default for images) --- + # Number of frames processed (e.g., 1 for image, N for video) + pa.field("num_frames", pa.int64()), + pa.field("duration_sec", pa.float64()), + pa.field("fps", pa.float64()), +]) + +pyarrow_schema_wangame_lingbot = pyarrow_schema_wangame + +pyarrow_schema_ode_trajectory_wangame = pa.schema([ + pa.field("id", pa.string()), + #I2V + pa.field("clip_feature_bytes", pa.binary()), + pa.field("clip_feature_shape", pa.list_(pa.int64())), + pa.field("clip_feature_dtype", pa.string()), + pa.field("first_frame_latent_bytes", pa.binary()), + pa.field("first_frame_latent_shape", pa.list_(pa.int64())), + pa.field("first_frame_latent_dtype", pa.string()), + # --- Action --- + pa.field("mouse_cond_bytes", pa.binary()), + pa.field("mouse_cond_shape", pa.list_(pa.int64())), # [T, 2] + pa.field("mouse_cond_dtype", pa.string()), + pa.field("keyboard_cond_bytes", pa.binary()), + pa.field("keyboard_cond_shape", pa.list_(pa.int64())), # [T, 4] + pa.field("keyboard_cond_dtype", pa.string()), + # I2V Validation + pa.field("pil_image_bytes", pa.binary()), + pa.field("pil_image_shape", pa.list_(pa.int64())), + pa.field("pil_image_dtype", pa.string()), + # --- ODE Trajectory --- + 
pa.field("trajectory_latents_bytes", pa.binary()), + pa.field("trajectory_latents_shape", pa.list_(pa.int64())), + pa.field("trajectory_latents_dtype", pa.string()), + pa.field("trajectory_timesteps_bytes", pa.binary()), + pa.field("trajectory_timesteps_shape", pa.list_(pa.int64())), + pa.field("trajectory_timesteps_dtype", pa.string()), + # --- Metadata --- + pa.field("file_name", pa.string()), + pa.field("caption", pa.string()), + pa.field("media_type", pa.string()), # 'image' or 'video' + pa.field("width", pa.int64()), + pa.field("height", pa.int64()), + # -- Video-specific (can be null/default for images) --- + # Number of frames processed (e.g., 1 for image, N for video) + pa.field("num_frames", pa.int64()), + pa.field("duration_sec", pa.float64()), + pa.field("fps", pa.float64()), +]) diff --git a/fastvideo/dataset/parquet_dataset_map_style.py b/fastvideo/dataset/parquet_dataset_map_style.py index dac622497..d1ce92a9f 100644 --- a/fastvideo/dataset/parquet_dataset_map_style.py +++ b/fastvideo/dataset/parquet_dataset_map_style.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +import hashlib import os import pickle import random @@ -36,6 +37,7 @@ def __init__( global_rank: int, drop_last: bool = True, drop_first_row: bool = False, + reshuffle_each_epoch: bool = True, seed: int = 0, ): self.batch_size = batch_size @@ -45,34 +47,40 @@ def __init__( self.num_sp_groups = num_sp_groups self.global_rank = global_rank self.sp_world_size = sp_world_size + self.drop_first_row = drop_first_row + self.reshuffle_each_epoch = reshuffle_each_epoch + self._build_indices(0) # build indices for epoch 0 to initialize the sampler + + def _build_indices(self, epoch: int) -> None: # ── epoch-level RNG ──────────────────────────────────────────────── - rng = torch.Generator().manual_seed(self.seed) + rng = torch.Generator().manual_seed(self.seed + epoch) # Create a random permutation of all indices global_indices = torch.randperm(self.dataset_size, generator=rng) - if 
drop_first_row: + dataset_size = self.dataset_size + if self.drop_first_row: # drop 0 in global_indices global_indices = global_indices[global_indices != 0] - self.dataset_size = self.dataset_size - 1 + dataset_size = dataset_size - 1 if self.drop_last: # For drop_last=True, we: # 1. Ensure total samples is divisible by (batch_size * num_sp_groups) # 2. This guarantees each SP group gets same number of complete batches # 3. Prevents uneven batch sizes across SP groups at end of epoch - num_batches = self.dataset_size // self.batch_size + num_batches = dataset_size // self.batch_size num_global_batches = num_batches // self.num_sp_groups global_indices = global_indices[:num_global_batches * self.num_sp_groups * self.batch_size] else: - if self.dataset_size % (self.num_sp_groups * self.batch_size) != 0: + if dataset_size % (self.num_sp_groups * self.batch_size) != 0: # add more indices to make it divisible by (batch_size * num_sp_groups) padding_size = self.num_sp_groups * self.batch_size - ( - self.dataset_size % (self.num_sp_groups * self.batch_size)) + dataset_size % (self.num_sp_groups * self.batch_size)) logger.info("Padding the dataset from %d to %d", - self.dataset_size, self.dataset_size + padding_size) + dataset_size, dataset_size + padding_size) global_indices = torch.cat( [global_indices, global_indices[:padding_size]]) @@ -84,6 +92,11 @@ def __init__( logger.info("Dataset size for each sp group: %d", len(sp_group_local_indices)) + def set_epoch(self, epoch: int) -> None: + if not self.reshuffle_each_epoch: + return + self._build_indices(epoch) + def __iter__(self): indices = self.sp_group_local_indices for i in range(0, len(indices), self.batch_size): @@ -94,11 +107,90 @@ def __len__(self): return len(self.sp_group_local_indices) // self.batch_size +def _parse_data_path_specs(path: str) -> list[tuple[str, int]]: + """ + Parse data_path into a list of (directory, repeat_count). 
+ Syntax: comma-separated entries; each entry is "path" (default 1) or "path:N" (N = repeat count). + N=0 means skip that path (convenience to disable without removing). If no ":" present, default is 1. + Example: "/dir1:2,/dir2,/dir3:0" -> dir1 2x, dir2 1x, dir3 skipped. + """ + specs: list[tuple[str, int]] = [] + for part in path.split(","): + part = part.strip() + if not part: + continue + if ":" in part: + p, _, count_str = part.rpartition(":") + p = p.strip() + try: + count = int(count_str.strip()) + except ValueError: + raise ValueError( + f"data_path repeat count must be an integer, got {count_str!r}" + ) from None + if count < 0: + raise ValueError( + f"data_path repeat count must be >= 0, got {count}" + ) + specs.append((p, count)) + else: + specs.append((part, 1)) + return specs + + +def _scan_parquet_files_for_path(p: str) -> tuple[list[str], list[int]]: + """Return (file_paths, row_lengths) for a single directory.""" + file_names: list[str] = [] + for root, _, files in os.walk(p): + for file in sorted(files): + if file.endswith(".parquet"): + file_names.append(os.path.join(root, file)) + lengths = [] + for file_path in tqdm.tqdm( + file_names, desc="Reading parquet files to get lengths"): + lengths.append(pq.ParquetFile(file_path).metadata.num_rows) + logger.info("Found %d parquet files with %d total rows", len(file_names), sum(lengths)) + return file_names, lengths + + def get_parquet_files_and_length(path: str): - dataset_root = os.path.realpath(os.path.expanduser(path)) - # Check if cached info exists - cache_dir = os.path.join(dataset_root, "map_style_cache") - cache_file = os.path.join(cache_dir, "file_info.pkl") + """ + Collect parquet file paths and row lengths from one or more directories. + path: single directory, or comma-separated "path" or "path:N" (N = repeat count). + E.g. "/dir1:2,/dir2:1" -> dir1's files appear 2x (oversampled), dir2 once. 
+ """ + path_specs = _parse_data_path_specs(path) + if not path_specs: + raise ValueError( + "data_path must be a non-empty path or comma-separated path specs" + ) + # Use first path with count > 0 for cache_dir (single-path case only) + first_path = next( + (p for p, c in path_specs if c > 0), + path_specs[0][0], + ) + is_single_no_repeat = ( + len(path_specs) == 1 and path_specs[0][1] == 1 + ) + effective_path = path.strip() + # Single path, no repeat: cache under that path (backward compatible). + # Multi-path or repeat: cache in a neutral dir keyed by hash of full path spec, + # so we never reuse "first path's" cache and the cached list is the merged list. + if is_single_no_repeat: + cache_dir = os.path.join(first_path, "map_style_cache") + cache_suffix = "file_info.pkl" + else: + neutral_root = os.environ.get( + "FASTVIDEO_MAP_STYLE_CACHE_DIR", + os.path.join(os.path.expanduser("~"), ".cache", "fastvideo", "map_style_cache"), + ) + cache_dir = neutral_root + cache_suffix = ( + "file_info_" + + hashlib.md5(effective_path.encode()).hexdigest()[:16] + + ".pkl" + ) + cache_file = os.path.join(cache_dir, cache_suffix) # Only rank 0 checks for cache and scans files if needed if get_world_rank() == 0: @@ -152,30 +244,31 @@ def get_parquet_files_and_length(path: str): # If cache not loaded (either doesn't exist or failed to load), scan files if not cache_loaded: - logger.info("Scanning parquet files to get lengths") - lengths = [] - file_names = [] - for root, _, files in os.walk(dataset_root): - for file in sorted(files): - if file.endswith('.parquet'): - file_path = os.path.realpath(os.path.join(root, file)) - file_names.append(file_path) - if len(file_names) == 0: - raise FileNotFoundError( - "No parquet files found under dataset path: " - f"{path}. 
" - "Please verify this path points to preprocessed parquet " - "data.") - for file_path in tqdm.tqdm( - file_names, desc="Reading parquet files to get lengths"): - num_rows = pq.ParquetFile(file_path).metadata.num_rows - lengths.append(num_rows) - # sort according to file name to ensure all rank has the same order - file_names_sorted, lengths_sorted = zip(*sorted(zip(file_names, - lengths, - strict=True), - key=lambda x: x[0]), - strict=True) + logger.info( + "Scanning parquet files (path specs: %s)", + [(p, c) for p, c in path_specs], + ) + # Build list with repeats; use (path, length, sort_index) for stable order + # Skip paths with count 0 (no I/O for disabled paths) + combined: list[tuple[str, int, int]] = [] + sort_index = 0 + for p, count in path_specs: + if count == 0: + continue + fnames, lens = _scan_parquet_files_for_path(p) + for _ in range(count): + for f, ln in zip(fnames, lens, strict=True): + combined.append((f, ln, sort_index)) + sort_index += 1 + if not combined: + raise ValueError( + "No parquet files found in the dataset (paths: %s)" + % [p for p, _ in path_specs] + ) + combined.sort(key=lambda x: (x[0], x[2])) + file_names_sorted = tuple(x[0] for x in combined) + lengths_sorted = tuple(x[1] for x in combined) + # Save the cache os.makedirs(cache_dir, exist_ok=True) with open(cache_file, "wb") as f: @@ -275,6 +368,7 @@ def __init__( seed: int = 42, drop_last: bool = True, drop_first_row: bool = False, + reshuffle_each_epoch: bool = False, text_padding_length: int = 512, ): super().__init__() @@ -297,6 +391,7 @@ def __init__( global_rank=get_world_rank(), drop_last=drop_last, drop_first_row=drop_first_row, + reshuffle_each_epoch=reshuffle_each_epoch, seed=seed, ) logger.info("Dataset initialized with %d parquet files and %d rows", @@ -369,6 +464,7 @@ def build_parquet_map_style_dataloader( cfg_rate=0.0, drop_last=True, drop_first_row=False, + reshuffle_each_epoch=False, text_padding_length=512, seed=42) -> tuple[LatentsParquetMapStyleDataset, 
StatefulDataLoader]: dataset = LatentsParquetMapStyleDataset( @@ -377,6 +473,7 @@ def build_parquet_map_style_dataloader( cfg_rate=cfg_rate, drop_last=drop_last, drop_first_row=drop_first_row, + reshuffle_each_epoch=reshuffle_each_epoch, text_padding_length=text_padding_length, parquet_schema=parquet_schema, seed=seed) diff --git a/fastvideo/dataset/validation_dataset.py b/fastvideo/dataset/validation_dataset.py index 5ab467d75..cf97e8bc0 100644 --- a/fastvideo/dataset/validation_dataset.py +++ b/fastvideo/dataset/validation_dataset.py @@ -4,6 +4,7 @@ import pathlib import datasets +import numpy as np from torch.utils.data import IterableDataset from fastvideo.distributed import (get_sp_world_size, get_world_rank, @@ -16,8 +17,9 @@ class ValidationDataset(IterableDataset): - def __init__(self, filename: str): + def __init__(self, filename: str, num_samples: int | None = None): super().__init__() + self.num_samples = num_samples self.filename = pathlib.Path(filename) # get directory of filename @@ -58,6 +60,12 @@ def __init__(self, filename: str): # Convert to list to get total samples self.all_samples = list(data) + + # Limit number of samples if specified + if self.num_samples is not None and self.num_samples < len(self.all_samples): + self.all_samples = self.all_samples[:self.num_samples] + logger.info("Limiting validation samples to %s", self.num_samples) + self.original_total_samples = len(self.all_samples) # Extend samples to be a multiple of DP degree (num_sp_groups) @@ -160,5 +168,25 @@ def __iter__(self): else: sample["control_video"] = load_video(control_video_path) + if sample.get("action_path", None) is not None: + action_path = sample["action_path"] + action_path = os.path.join(self.dir, action_path) + sample["action_path"] = action_path + if not pathlib.Path(action_path).is_file(): + logger.warning("Action file %s does not exist.", action_path) + else: + try: + action_data = np.load(action_path, allow_pickle=True) + num_frames = sample["num_frames"] + 
if action_data.dtype == object: action_data = action_data.item() + if isinstance(action_data, dict): + sample["keyboard_cond"] = action_data["keyboard"][:num_frames] + sample["mouse_cond"] = action_data["mouse"][:num_frames] + else: + sample["keyboard_cond"] = action_data[:num_frames] + except Exception as e: + logger.error("Error loading action file %s: %s", + action_path, e) + sample = {k: v for k, v in sample.items() if v is not None} yield sample diff --git a/fastvideo/fastvideo_args.py b/fastvideo/fastvideo_args.py index 4c0a14d0d..f5d53d13c 100644 --- a/fastvideo/fastvideo_args.py +++ b/fastvideo/fastvideo_args.py @@ -798,6 +798,7 @@ class TrainingArgs(FastVideoArgs): """ data_path: str = "" dataloader_num_workers: int = 0 + reshuffle_each_epoch: bool = True num_height: int = 0 num_width: int = 0 num_frames: int = 0 @@ -826,6 +827,7 @@ class TrainingArgs(FastVideoArgs): validation_sampling_steps: str = "" validation_guidance_scale: str = "" validation_steps: float = 0.0 + validation_num_samples: int | None = None # Limit number of validation samples (None = use all) log_validation: bool = False trackers: list[str] = dataclasses.field(default_factory=list) tracker_project_name: str = "" @@ -887,6 +889,19 @@ class TrainingArgs(FastVideoArgs): lora_training: bool = False ltx2_first_frame_conditioning_p: float = 0.1 + # Action-only training: freeze base DiT, only train action modules + train_action_only: bool = False + + # Which action modules to train (only effective when train_action_only=True): + # "both" – action_embedder + prope_proj (default) + # "action_mlp" – action_embedder only + # "prope" – prope_proj only + action_train_target: str = "both" + + # Action warmup: keep action modules (action_embedder, to_out_prope) at zero + # for this many steps to let the base model stabilize first, then enable them. 
+ action_warmup_steps: int = 0 + # distillation args generator_update_interval: int = 5 dfake_gen_update_ratio: int = 5 # self-forcing: how often to train generator vs critic @@ -898,6 +913,7 @@ class TrainingArgs(FastVideoArgs): fake_score_betas: str = "0.9,0.999" # betas for fake score optimizer, format: "beta1,beta2" training_state_checkpointing_steps: int = 0 # for resuming training weight_only_checkpointing_steps: int = 0 # for inference + best_checkpoint_start_step: int = 0 # save best checkpoint (by mf_angle_err_mean) after this step; 0 = disabled log_visualization: bool = False visualization_steps: int = 0 # simulate generator forward to match inference @@ -963,14 +979,22 @@ def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs": @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: - parser.add_argument("--data-path", - type=str, - required=True, - help="Path to parquet files") + parser.add_argument( + "--data-path", + type=str, + required=True, + help= + "Path to parquet files (comma-separated for multiple; path:N for repeat count)" + ) parser.add_argument("--dataloader-num-workers", type=int, required=True, help="Number of workers for dataloader") + parser.add_argument( + "--reshuffle-each-epoch", + action=StoreBoolean, + default=TrainingArgs.reshuffle_each_epoch, + help="Whether to reshuffle dataset order each epoch") parser.add_argument("--num-height", type=int, required=True, @@ -1060,6 +1084,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument("--validation-steps", type=float, help="Number of validation steps") + parser.add_argument( + "--validation-num-samples", + type=int, + help="Limit number of validation samples (default: use all)") parser.add_argument("--log-validation", action=StoreBoolean, help="Whether to log validation results") @@ -1094,6 +1122,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 
"--weight-only-checkpointing-steps", type=int, help="Steps between weight-only checkpoints (for inference)") + parser.add_argument( + "--best-checkpoint-start-step", + type=int, + help="Save best checkpoint (by mf_angle_err_mean) after this " + "step; 0 = disabled") parser.add_argument("--resume-from-checkpoint", type=str, help="Path to checkpoint to resume from") @@ -1248,6 +1281,21 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "Probability of conditioning on the first frame during LTX-2 training", ) + # Action-only training (freeze base model, only train action params) + parser.add_argument( + "--train-action-only", + action=StoreBoolean, + help="Whether to only train action-related parameters " + "(action_embedder and to_out_prope) while freezing base model") + + # Action warmup: keep action modules frozen for N steps + parser.add_argument("--action-warmup-steps", + type=int, + default=0, + help="Number of steps to keep action modules " + "(action_embedder, to_out_prope) frozen to let " + "the base model stabilize first") + # V-MoBA parameters parser.add_argument( "--moba-config-path", @@ -1344,6 +1392,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=int, default=TrainingArgs.context_noise, help="Context noise level for cache updates") + parser.add_argument( + "--action-train-target", + type=str, + default=TrainingArgs.action_train_target, + choices=["both", "action_mlp", "prope"], + help="Which action modules to train while freezing the base model", + ) return parser diff --git a/fastvideo/models/dits/hyworld/pose.py b/fastvideo/models/dits/hyworld/pose.py index b1b3f5df3..99d535308 100644 --- a/fastvideo/models/dits/hyworld/pose.py +++ b/fastvideo/models/dits/hyworld/pose.py @@ -10,13 +10,15 @@ """ import json +import logging import numpy as np import torch from scipy.spatial.transform import Rotation as R from typing import Union, Optional -from .trajectory import 
def reformat_keyboard_and_mouse_tensors(keyboard_tensor, mouse_tensor):
    """Downsample per-frame keyboard/mouse tensors to one row per 4 frames.

    The first frame is dropped, the remaining frames are viewed as
    consecutive groups of four, and the first row of every group is
    returned. A warning is logged when the rows inside a group are not all
    identical; the first row of the group is still the one that is kept.

    Args:
        keyboard_tensor: 2-D tensor whose first dimension is the frame
            count; the frame count minus one must be divisible by 4.
        mouse_tensor: 2-D tensor with the same number of frames as
            ``keyboard_tensor``.

    Returns:
        Tuple ``(keyboard, mouse)`` holding rows 1, 5, 9, ... of the inputs.

    Raises:
        AssertionError: If ``num_frames - 1`` is not divisible by 4 or the
            two tensors disagree on the number of frames.
    """
    num_frames = keyboard_tensor.shape[0]
    # The message now states the condition that is actually checked
    # (num_frames = 4k + 1), instead of claiming num_frames itself must be
    # a multiple of 4.
    assert (num_frames - 1) % 4 == 0, (
        f"num_frames - 1 must be a multiple of 4, got num_frames={num_frames}")
    assert mouse_tensor.shape[0] == num_frames, "mouse_tensor must have the same number of frames as keyboard_tensor"

    def _warn_if_groups_vary(name, tensor):
        # Compare every row in a group of four against that group's first row.
        groups = tensor.view(-1, 4, tensor.shape[1])
        if not (groups == groups[:, 0:1]).all(dim=1).all():
            # Lazy %-formatting: the tensor is only rendered to text if the
            # record is actually emitted.
            logger.warning("%s has different values for each group: %s", name,
                           groups)

    keyboard_tensor = keyboard_tensor[1:, :]
    mouse_tensor = mouse_tensor[1:, :]
    _warn_if_groups_vary("keyboard_tensor", keyboard_tensor)
    _warn_if_groups_vary("mouse_tensor", mouse_tensor)

    return keyboard_tensor[::4], mouse_tensor[::4]
Translate tensors to motions for trajectory generation + for t in range(keyboard_tensor.shape[0]): + frame_motion = {} + + # --- Translation --- + # MatrixGame convention: 0:W, 1:S, 2:A, 3:D + fwd = 0.0 + if keyboard_tensor[t, 0] > 0.5: fwd += forward_speed # W + if keyboard_tensor[t, 1] > 0.5: fwd -= forward_speed # S + if fwd != 0: frame_motion["forward"] = fwd + + rgt = 0.0 + if keyboard_tensor[t, 2] > 0.5: rgt -= forward_speed # A (Left is negative Right) + if keyboard_tensor[t, 3] > 0.5: rgt += forward_speed # D (Right) + if rgt != 0: frame_motion["right"] = rgt + + # --- Rotation --- + # MatrixGame convention: mouse is [Pitch, Yaw] (or Y, X) + # Apply scaling (e.g. to match HyWorld distribution) + pitch = mouse_tensor[t, 0].item() + yaw = mouse_tensor[t, 1].item() + + if abs(pitch) > 1e-4: frame_motion["pitch"] = pitch + if abs(yaw) > 1e-4: frame_motion["yaw"] = yaw + + motions.append(frame_motion) + + # 2. Generate Camera Trajectory + # generate_camera_trajectory_local returns T+1 poses (starting at Identity) + # We take the first T poses to match the latent count. + # Pose 0 is Identity. Pose 1 is Identity + Motion[0]. + poses = generate_camera_trajectory_local(motions) + # poses = np.array(poses[:T]) + + # 3. Compute Viewmats (w2c) and Intrinsics + w2c_list = [] + intrinsic_list = [] + + # Setup default intrinsic (normalized) + K = np.array(DEFAULT_INTRINSIC) + K[0, 0] /= K[0, 2] * 2 + K[1, 1] /= K[1, 2] * 2 + K[0, 2] = 0.5 + K[1, 2] = 0.5 + + for i in range(len(poses)): + c2w = np.array(poses[i]) + w2c = np.linalg.inv(c2w) + w2c_list.append(w2c) + intrinsic_list.append(K) + + viewmats = torch.as_tensor(np.array(w2c_list)) + intrinsics = torch.as_tensor(np.array(intrinsic_list)) + + # 4. Generate Action Labels by analyzing the generated trajectory + # This ensures consistency with complex simultaneous movements, exactly as pose_to_input does. 
+ + # Calculate relative camera-to-world transforms + # c2ws = inverse(viewmats) + c2ws = np.linalg.inv(np.array(w2c_list)) + + # Calculate relative movement between frames + # relative_c2w[i] = inv(c2ws[i-1]) @ c2ws[i] + C_inv = np.linalg.inv(c2ws[:-1]) + relative_c2w = np.zeros_like(c2ws) + relative_c2w[0, ...] = c2ws[0, ...] # First is anchor + relative_c2w[1:, ...] = C_inv @ c2ws[1:, ...] + + # Initialize one-hot action encodings + trans_one_hot = np.zeros((relative_c2w.shape[0], 4), dtype=np.int32) + rotate_one_hot = np.zeros((relative_c2w.shape[0], 4), dtype=np.int32) + + move_norm_valid = 0.0001 + + # Skip index 0 (anchor/identity) + for i in range(1, relative_c2w.shape[0]): + move_dirs = relative_c2w[i, :3, 3] # direction vector + move_norms = np.linalg.norm(move_dirs) + + if move_norms > move_norm_valid: # threshold for movement + move_norm_dirs = move_dirs / move_norms + angles_rad = np.arccos(move_norm_dirs.clip(-1.0, 1.0)) + trans_angles_deg = angles_rad * (180.0 / np.pi) # convert to degrees + else: + trans_angles_deg = np.zeros(3) + + R_rel = relative_c2w[i, :3, :3] + r = R.from_matrix(R_rel) + rot_angles_deg = r.as_euler("xyz", degrees=True) + + # Determine movement actions based on trajectory + # Note: HyWorld logic checks if rotation is small before assigning translation labels + # to avoid ambiguity in TPS mode, but here we generally want to capture the dominant movement. 
+ tps = False # Default assumption, can be made an arg if needed + + if move_norms > move_norm_valid: + if (not tps) or ( + tps and abs(rot_angles_deg[1]) < 5e-2 and abs(rot_angles_deg[0]) < 5e-2 + ): + # Z-axis (Forward/Back) + if trans_angles_deg[2] < 60: + trans_one_hot[i, 0] = 1 # forward + elif trans_angles_deg[2] > 120: + trans_one_hot[i, 1] = 1 # backward + + # X-axis (Right/Left) + if trans_angles_deg[0] < 60: + trans_one_hot[i, 2] = 1 # right + elif trans_angles_deg[0] > 120: + trans_one_hot[i, 3] = 1 # left + + # Determine rotation actions + # Y-axis (Yaw) + if rot_angles_deg[1] > 5e-2: + rotate_one_hot[i, 0] = 1 # right + elif rot_angles_deg[1] < -5e-2: + rotate_one_hot[i, 1] = 1 # left + + # X-axis (Pitch) + if rot_angles_deg[0] > 5e-2: + rotate_one_hot[i, 2] = 1 # up + elif rot_angles_deg[0] < -5e-2: + rotate_one_hot[i, 3] = 1 # down + + trans_one_hot = torch.tensor(trans_one_hot) + rotate_one_hot = torch.tensor(rotate_one_hot) + + # Convert to single labels + trans_label = one_hot_to_one_dimension(trans_one_hot) + rotate_label = one_hot_to_one_dimension(rotate_one_hot) + action_labels = trans_label * 9 + rotate_label + + return viewmats, intrinsics, action_labels + +if __name__ == "__main__": + print("Running comparison test between process_custom_actions and pose_to_input...") + + def test_process_custom_actions(pose_string: str, keyboard: torch.Tensor, mouse: torch.Tensor, latent_num: int): + # Run process_custom_actions + # Note: We need to pass float tensors + print("Running process_custom_actions...") + viewmats_1, intrinsics_1, labels_1 = process_custom_actions( + keyboard, mouse + ) + + print(f"Running pose_to_input with string: '{pose_string}'...") + viewmats_2, intrinsics_2, labels_2 = pose_to_input( + pose_string, latent_num=latent_num + ) + + # print(f"Viewmats: {viewmats_1} vs \n {viewmats_2}") + # print(f"Intrinsics: {intrinsics_1} vs \n {intrinsics_2}") + # print(f"Labels: {labels_1} vs \n {labels_2}") + # 3. 
Compare Results + print("\nComparison Results:") + + # Check Shapes + print(f"Shapes (Viewmats): {viewmats_1.shape} vs {viewmats_2.shape}") + assert viewmats_1.shape == viewmats_2.shape, "Shape mismatch for viewmats" + + # Check Values + # Viewmats + diff_viewmats = (viewmats_1 - viewmats_2).abs().max().item() + print(f"Max difference in Viewmats: {diff_viewmats}") + if diff_viewmats < 1e-5: + print("✅ Viewmats match.") + else: + print("❌ Viewmats mismatch.") + + # Check intrinsics + diff_intrinsics = (intrinsics_1 - intrinsics_2).abs().max().item() + print(f"Max difference in Intrinsics: {diff_intrinsics}") + if diff_intrinsics < 1e-5: + print("✅ Intrinsics match.") + else: + print("❌ Intrinsics mismatch.") + + # Check labels + diff_labels = (labels_1 - labels_2).abs().max().item() + print(f"Max difference in Labels: {diff_labels}") + if diff_labels < 1e-5: + print("✅ Labels match.") + else: + print("❌ Labels mismatch.") + + print("Comparison finished.") + + # Define shared parameters + + latent_num = 13 + pose_string = "w-2, a-3, s-1, d-6" + + num_frames = 4 * (latent_num - 1) + 1 + keyboard = torch.zeros((num_frames, 6)) + mouse = torch.zeros((num_frames, 2)) + + # Frame 0 is ignored/start + # Frames 1-8: Press W (index 0) + keyboard[1:9, 0] = 1.0 + # Frames 9-20: Press A (index 2) + keyboard[9:21, 2] = 1.0 + # Frames 21-24: Press S (index 1) + keyboard[21:25, 1] = 1.0 + # Frames 25-48: Press D (index 3) + keyboard[25:49, 3] = 1.0 + + test_process_custom_actions(pose_string, keyboard, mouse, latent_num) + + # Test keyboard AND mouse + latent_num = 25 + pose_string = "w-2, up-2, a-3, down-4, s-1, left-2, d-6, right-4" + + num_frames = 4 * (latent_num - 1) + 1 + keyboard = torch.zeros((num_frames, 6)) + mouse = torch.zeros((num_frames, 2)) + + # Frame 0 is ignored/start + # Frames 1-8: Press W (index 0) + keyboard[1:9, 0] = 1.0 + # Frames 17-28: Press A (index 2) + keyboard[17:29, 2] = 1.0 + # Frames 45-48: Press S (index 1) + keyboard[45:49, 1] = 1.0 + # Frames 
57-80: Press D (index 3) + keyboard[57:81, 3] = 1.0 + + # Frames 9-16: Press Up (index 4) + mouse[9:17, 0] = DEFAULT_PITCH_SPEED + # Frames 25-32: Press Down (index 5) + mouse[29:45, 0] = -DEFAULT_PITCH_SPEED + # Frames 41-48: Press Left (index 6) + mouse[49:57, 1] = -DEFAULT_YAW_SPEED + # Frames 57-64: Press Right (index 7) + mouse[81:, 1] = DEFAULT_YAW_SPEED + + test_process_custom_actions(pose_string, keyboard, mouse, latent_num) diff --git a/fastvideo/models/dits/matrixgame/utils.py b/fastvideo/models/dits/matrixgame/utils.py index 4dd937699..c7dfd6743 100644 --- a/fastvideo/models/dits/matrixgame/utils.py +++ b/fastvideo/models/dits/matrixgame/utils.py @@ -301,119 +301,238 @@ def parse_config(config, mode="universal"): # NOTE: drawing functions are commented out to avoid cv2/libGL dependency. # -# def draw_rounded_rectangle(image, top_left, bottom_right, color, radius=10, alpha=0.5): -# overlay = image.copy() -# x1, y1 = top_left -# x2, y2 = bottom_right -# -# cv2.rectangle(overlay, (x1 + radius, y1), (x2 - radius, y2), color, -1) -# cv2.rectangle(overlay, (x1, y1 + radius), (x2, y2 - radius), color, -1) -# cv2.ellipse(overlay, (x1 + radius, y1 + radius), (radius, radius), 180, 0, 90, color, -1) -# cv2.ellipse(overlay, (x2 - radius, y1 + radius), (radius, radius), 270, 0, 90, color, -1) -# cv2.ellipse(overlay, (x1 + radius, y2 - radius), (radius, radius), 90, 0, 90, color, -1) -# cv2.ellipse(overlay, (x2 - radius, y2 - radius), (radius, radius), 0, 0, 90, color, -1) -# cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image) -# -# def draw_keys_on_frame(frame, keys, key_size=(80, 50), spacing=20, bottom_margin=30, mode='universal'): -# h, w, _ = frame.shape -# horison_shift = 90 -# vertical_shift = -20 -# horizon_shift_all = 50 -# key_positions = { -# "W": (w // 2 - key_size[0] // 2 - horison_shift - horizon_shift_all, -# h - bottom_margin - key_size[1] * 2 + vertical_shift - 20), -# "A": (w // 2 - key_size[0] * 2 + 5 - horison_shift - horizon_shift_all, -# 
h - bottom_margin - key_size[1] + vertical_shift), -# "S": (w // 2 - key_size[0] // 2 - horison_shift - horizon_shift_all, -# h - bottom_margin - key_size[1] + vertical_shift), -# "D": (w // 2 + key_size[0] - 5 - horison_shift - horizon_shift_all, -# h - bottom_margin - key_size[1] + vertical_shift), -# } -# key_icon = {"W": "W", "A": "A", "S": "S", "D": "D", "left": "left", "right": "right"} -# if mode == 'templerun': -# key_positions.update({ -# "left": (w // 2 + key_size[0] * 2 + spacing * 2 - horison_shift - horizon_shift_all, -# h - bottom_margin - key_size[1] + vertical_shift), -# "right": (w // 2 + key_size[0] * 3 + spacing * 7 - horison_shift - horizon_shift_all, -# h - bottom_margin - key_size[1] + vertical_shift) -# }) -# -# for key, (x, y) in key_positions.items(): -# is_pressed = keys.get(key, False) -# top_left = (x, y) -# if key in ["left", "right"]: -# bottom_right = (x + key_size[0] + 40, y + key_size[1]) -# else: -# bottom_right = (x + key_size[0], y + key_size[1]) -# -# color = (0, 255, 0) if is_pressed else (200, 200, 200) -# alpha = 0.8 if is_pressed else 0.5 -# draw_rounded_rectangle(frame, top_left, bottom_right, color, radius=10, alpha=alpha) -# -# text_size = cv2.getTextSize(key, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2)[0] -# if key in ["left", "right"]: -# text_x = x + (key_size[0] + 40 - text_size[0]) // 2 -# else: -# text_x = x + (key_size[0] - text_size[0]) // 2 -# text_y = y + (key_size[1] + text_size[1]) // 2 -# cv2.putText(frame, key_icon[key], (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2) -# -# def overlay_icon(frame, icon, position, scale=1.0, rotation=0): -# x, y = position -# h, w, _ = icon.shape -# -# scaled_width = int(w * scale) -# scaled_height = int(h * scale) -# icon_resized = cv2.resize(icon, (scaled_width, scaled_height), interpolation=cv2.INTER_AREA) -# -# center = (scaled_width // 2, scaled_height // 2) -# rotation_matrix = cv2.getRotationMatrix2D(center, rotation, 1.0) -# icon_rotated = cv2.warpAffine( -# 
icon_resized, rotation_matrix, (scaled_width, scaled_height), -# flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=(0, 0, 0, 0) -# ) -# -# h, w, _ = icon_rotated.shape -# frame_h, frame_w, _ = frame.shape -# -# top_left_x = max(0, int(x - w // 2)) -# top_left_y = max(0, int(y - h // 2)) -# bottom_right_x = min(frame_w, int(x + w // 2)) -# bottom_right_y = min(frame_h, int(y + h // 2)) -# -# icon_x_start = max(0, int(-x + w // 2)) -# icon_y_start = max(0, int(-y + h // 2)) -# icon_x_end = icon_x_start + (bottom_right_x - top_left_x) -# icon_y_end = icon_y_start + (bottom_right_y - top_left_y) -# -# icon_region = icon_rotated[icon_y_start:icon_y_end, icon_x_start:icon_x_end] -# alpha = icon_region[:, :, 3] / 255.0 -# icon_rgb = icon_region[:, :, :3] -# -# frame_region = frame[top_left_y:bottom_right_y, top_left_x:bottom_right_x] -# for c in range(3): -# frame_region[:, :, c] = (1 - alpha) * frame_region[:, :, c] + alpha * icon_rgb[:, :, c] -# frame[top_left_y:bottom_right_y, top_left_x:bottom_right_x] = frame_region -# -# def process_video(input_video, output_video, config, mouse_icon_path, -# mouse_scale=1.0, mouse_rotation=0, process_icon=True, mode='universal'): -# key_data, mouse_data = parse_config(config, mode=mode) -# fps = 12 -# -# mouse_icon = cv2.imread(mouse_icon_path, cv2.IMREAD_UNCHANGED) -# -# out_video = [] -# for frame_idx, frame in enumerate(input_video): -# frame = np.ascontiguousarray(frame) -# if process_icon: -# keys = key_data.get(frame_idx, {"W": False, "A": False, "S": False, "D": False, "left": False, "right": False}) -# draw_keys_on_frame(frame, keys, key_size=(50, 50), spacing=10, bottom_margin=20, mode=mode) -# if mode == 'universal': -# frame_width = frame.shape[1] -# frame_height = frame.shape[0] -# mouse_position = mouse_data.get(frame_idx, (frame_width // 2, frame_height // 2)) -# overlay_icon(frame, mouse_icon, mouse_position, scale=mouse_scale, rotation=mouse_rotation) -# out_video.append(frame / 255) -# -# 
export_to_video(out_video, output_video, fps=fps) -# logger.info(f"Video saved to {output_video}") +import cv2 +import numpy as np +from diffusers.utils import export_to_video + +def draw_rounded_rectangle(image, top_left, bottom_right, color, radius=10, alpha=0.5): + overlay = image.copy() + x1, y1 = top_left + x2, y2 = bottom_right + + cv2.rectangle(overlay, (x1 + radius, y1), (x2 - radius, y2), color, -1) + cv2.rectangle(overlay, (x1, y1 + radius), (x2, y2 - radius), color, -1) + cv2.ellipse(overlay, (x1 + radius, y1 + radius), (radius, radius), 180, 0, 90, color, -1) + cv2.ellipse(overlay, (x2 - radius, y1 + radius), (radius, radius), 270, 0, 90, color, -1) + cv2.ellipse(overlay, (x1 + radius, y2 - radius), (radius, radius), 90, 0, 90, color, -1) + cv2.ellipse(overlay, (x2 - radius, y2 - radius), (radius, radius), 0, 0, 90, color, -1) + cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image) + +def draw_keys_on_frame(frame, keys, key_size=(30, 30), spacing=5, top_margin=15, mode='universal'): + """Draw WASD keys on the left top of the frame.""" + h, w, _ = frame.shape + + # Left top positioning + left_margin = 15 + gap = 3 # Gap between keys + + key_positions = { + "W": (left_margin + key_size[0] + gap, + top_margin), + "A": (left_margin, + top_margin + key_size[1] + gap), + "S": (left_margin + key_size[0] + gap, + top_margin + key_size[1] + gap), + "D": (left_margin + (key_size[0] + gap) * 2, + top_margin + key_size[1] + gap), + } + key_icon = {"W": "W", "A": "A", "S": "S", "D": "D", "left": "L", "right": "R"} + if mode == 'templerun': + key_positions.update({ + "left": (left_margin + (key_size[0] + gap) * 3 + 10, + top_margin + key_size[1] + gap), + "right": (left_margin + (key_size[0] + gap) * 4 + 15, + top_margin + key_size[1] + gap) + }) + + for key, (x, y) in key_positions.items(): + is_pressed = keys.get(key, False) + top_left = (x, y) + bottom_right = (x + key_size[0], y + key_size[1]) + + color = (0, 255, 0) if is_pressed else (200, 200, 200) + 
alpha = 0.8 if is_pressed else 0.5 + draw_rounded_rectangle(frame, top_left, bottom_right, color, radius=5, alpha=alpha) + + text_size = cv2.getTextSize(key_icon[key], cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0] + text_x = x + (key_size[0] - text_size[0]) // 2 + text_y = y + (key_size[1] + text_size[1]) // 2 + cv2.putText(frame, key_icon[key], (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1) + +def overlay_icon(frame, icon, position, scale=1.0, rotation=0): + x, y = position + h, w, _ = icon.shape + + scaled_width = int(w * scale) + scaled_height = int(h * scale) + icon_resized = cv2.resize(icon, (scaled_width, scaled_height), interpolation=cv2.INTER_AREA) + + center = (scaled_width // 2, scaled_height // 2) + rotation_matrix = cv2.getRotationMatrix2D(center, rotation, 1.0) + icon_rotated = cv2.warpAffine( + icon_resized, rotation_matrix, (scaled_width, scaled_height), + flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=(0, 0, 0, 0) + ) + + h, w, _ = icon_rotated.shape + frame_h, frame_w, _ = frame.shape + + top_left_x = max(0, int(x - w // 2)) + top_left_y = max(0, int(y - h // 2)) + bottom_right_x = min(frame_w, int(x + w // 2)) + bottom_right_y = min(frame_h, int(y + h // 2)) + + icon_x_start = max(0, int(-x + w // 2)) + icon_y_start = max(0, int(-y + h // 2)) + icon_x_end = icon_x_start + (bottom_right_x - top_left_x) + icon_y_end = icon_y_start + (bottom_right_y - top_left_y) + + icon_region = icon_rotated[icon_y_start:icon_y_end, icon_x_start:icon_x_end] + alpha = icon_region[:, :, 3] / 255.0 + icon_rgb = icon_region[:, :, :3] + + frame_region = frame[top_left_y:bottom_right_y, top_left_x:bottom_right_x] + for c in range(3): + frame_region[:, :, c] = (1 - alpha) * frame_region[:, :, c] + alpha * icon_rgb[:, :, c] + frame[top_left_y:bottom_right_y, top_left_x:bottom_right_x] = frame_region + +def process_video(input_video, output_video, config, mouse_icon_path, + mouse_scale=1.0, mouse_rotation=0, process_icon=True, mode='universal'): 
+ key_data, mouse_data = parse_config(config, mode=mode) + fps = 12 + + mouse_icon = cv2.imread(mouse_icon_path, cv2.IMREAD_UNCHANGED) + + out_video = [] + for frame_idx, frame in enumerate(input_video): + frame = np.ascontiguousarray(frame) + if process_icon: + keys = key_data.get(frame_idx, {"W": False, "A": False, "S": False, "D": False, "left": False, "right": False}) + draw_keys_on_frame(frame, keys, key_size=(50, 50), spacing=10, top_margin=20, mode=mode) + if mode == 'universal': + frame_width = frame.shape[1] + frame_height = frame.shape[0] + mouse_position = mouse_data.get(frame_idx, (frame_width // 2, frame_height // 2)) + overlay_icon(frame, mouse_icon, mouse_position, scale=mouse_scale, rotation=mouse_rotation) + out_video.append(frame / 255) + + export_to_video(out_video, output_video, fps=fps) + logger.info(f"Video saved to {output_video}") + + +def parse_npy_action(action_path): + """Convert npy action file to key_data and mouse_data dict format.""" + action_data = np.load(action_path, allow_pickle=True).item() # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load trusted action files + keyboard_data = action_data['keyboard'] # shape: (num_frames, 6) -> [W, S, A, D, left, right] + mouse_data = action_data.get('mouse', None) # shape: (num_frames, 2) -> [Pitch, Yaw] + + # MatrixGame convention: 0:W, 1:S, 2:A, 3:D, 4:left, 5:right + key_names = ["W", "S", "A", "D", "left", "right"] + key_data = {} + for frame_idx, keys in enumerate(keyboard_data): + key_data[frame_idx] = {key_names[i]: bool(keys[i]) for i in range(len(key_names))} + + # MatrixGame convention: mouse is [Pitch, Yaw] + mouse_dict = {} + if mouse_data is not None: + for frame_idx, (pitch, yaw) in enumerate(mouse_data): + mouse_dict[frame_idx] = {"pitch": float(pitch), "yaw": float(yaw)} + + return key_data, mouse_dict + + +def draw_mouse_on_frame(frame, pitch, yaw, top_margin=15): + """Draw crosshair with direction arrow on the right top of the frame.""" + h, w, _ = frame.shape + + # Right top positioning + right_margin = 15 + crosshair_radius = 25 + + # 
Position crosshair on the right top + crosshair_x = w - right_margin - crosshair_radius + crosshair_y = top_margin + crosshair_radius + + # Yaw affects horizontal direction, pitch affects vertical + dx = int(yaw * crosshair_radius * 8) # Scale for visibility + dy = int(-pitch * crosshair_radius * 8) # Negative because y increases downward + + # Clamp arrow length + max_arrow = crosshair_radius - 5 + dx = max(-max_arrow, min(max_arrow, dx)) + dy = max(-max_arrow, min(max_arrow, dy)) + + # Draw crosshair background + cv2.circle(frame, (crosshair_x, crosshair_y), crosshair_radius, (50, 50, 50), -1) + cv2.circle(frame, (crosshair_x, crosshair_y), crosshair_radius, (200, 200, 200), 1) + cv2.line(frame, (crosshair_x - crosshair_radius + 5, crosshair_y), + (crosshair_x + crosshair_radius - 5, crosshair_y), (100, 100, 100), 1) + cv2.line(frame, (crosshair_x, crosshair_y - crosshair_radius + 5), + (crosshair_x, crosshair_y + crosshair_radius - 5), (100, 100, 100), 1) + + # Draw direction arrow + if abs(dx) > 1 or abs(dy) > 1: + cv2.arrowedLine(frame, (crosshair_x, crosshair_y), (crosshair_x + dx, crosshair_y + dy), + (0, 255, 0), 2, tipLength=0.3) + + +def process_video_with_npy(input_video, output_video, action_path, fps=12, mode='universal'): + """Process video with overlay using npy action file. + + Uses existing draw_keys_on_frame function. 
+ """ + key_data, mouse_data = parse_npy_action(action_path) + + out_video = [] + for frame_idx, frame in enumerate(input_video): + frame = np.ascontiguousarray(frame) + keys = key_data.get(frame_idx, {"W": False, "A": False, "S": False, "D": False, "left": False, "right": False}) + draw_keys_on_frame(frame, keys, mode=mode) + + # Draw pitch and yaw + mouse = mouse_data.get(frame_idx, {"pitch": 0.0, "yaw": 0.0}) + draw_mouse_on_frame(frame, mouse["pitch"], mouse["yaw"]) + + out_video.append(frame / 255.0) + + export_to_video(out_video, output_video, fps=fps) + logger.info(f"Video saved to {output_video}") + + +if __name__ == "__main__": + import argparse + import cv2 + + parser = argparse.ArgumentParser(description="Overlay keyboard actions on video") + parser.add_argument("--video", type=str, required=True, help="Path to input video (.mp4)") + parser.add_argument("--action", type=str, required=True, help="Path to action file (.npy)") + parser.add_argument("--output", type=str, default=None, help="Path to output video (default: input_with_overlay.mp4)") + parser.add_argument("--fps", type=int, default=12, help="Output video FPS") + args = parser.parse_args() + + # Load video frames using cv2 + cap = cv2.VideoCapture(args.video) + if not cap.isOpened(): + raise ValueError(f"Cannot open video: {args.video}") + + frames = [] + while True: + ret, frame = cap.read() + if not ret: + break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(frame) + cap.release() + + print(f"Loaded {len(frames)} frames from video") + + # Set output path + if args.output is None: + base_name = args.video.rsplit('.', 1)[0] + output_path = f"{base_name}_with_overlay.mp4" + else: + output_path = args.output + + # Process video with overlay using existing functions + process_video_with_npy(frames, output_path, args.action, fps=args.fps) + print(f"Video with overlay saved to: {output_path}") diff --git a/fastvideo/models/dits/wangame/__init__.py 
b/fastvideo/models/dits/wangame/__init__.py new file mode 100644 index 000000000..1d799e506 --- /dev/null +++ b/fastvideo/models/dits/wangame/__init__.py @@ -0,0 +1,10 @@ +from .model import WanGameActionTransformer3DModel +from .causal_model import CausalWanGameActionTransformer3DModel +from .hyworld_action_module import WanGameActionTimeImageEmbedding, WanGameActionSelfAttention + +__all__ = [ + "WanGameActionTransformer3DModel", + "CausalWanGameActionTransformer3DModel", + "WanGameActionTimeImageEmbedding", + "WanGameActionSelfAttention", +] diff --git a/fastvideo/models/dits/wangame/causal_model.py b/fastvideo/models/dits/wangame/causal_model.py new file mode 100644 index 000000000..64bfefafc --- /dev/null +++ b/fastvideo/models/dits/wangame/causal_model.py @@ -0,0 +1,856 @@ +# SPDX-License-Identifier: Apache-2.0 + +import math +from typing import Any + +import torch +import torch.nn as nn + +from torch.nn.attention.flex_attention import create_block_mask, flex_attention +from torch.nn.attention.flex_attention import BlockMask +# wan 1.3B model has a weird channel / head configurations and require max-autotune to work with flexattention +# see https://github.com/pytorch/pytorch/issues/133254 +# change to default for other models +flex_attention = torch.compile( + flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs") +import torch.distributed as dist + +from fastvideo.attention import LocalAttention +from fastvideo.configs.models.dits.wangamevideo import WanGameVideoConfig +from fastvideo.distributed.parallel_state import get_sp_world_size +from fastvideo.layers.layernorm import (FP32LayerNorm, LayerNormScaleShift, + RMSNorm, ScaleResidual, + ScaleResidualLayerNormScaleShift) +from fastvideo.layers.linear import ReplicatedLinear +from fastvideo.layers.mlp import MLP +from fastvideo.layers.rotary_embedding import (_apply_rotary_emb, + get_rotary_pos_embed) +from fastvideo.layers.visual_embedding import PatchEmbed +from fastvideo.logger import 
init_logger +from fastvideo.models.dits.base import BaseDiT +from fastvideo.models.dits.wanvideo import WanI2VCrossAttention +from fastvideo.platforms import AttentionBackendEnum, current_platform + +# Import ActionModule +from fastvideo.models.dits.wangame.hyworld_action_module import ( + WanGameActionTimeImageEmbedding, + WanGameActionSelfAttention +) +from fastvideo.models.dits.hyworld.camera_rope import prope_qkv + +logger = init_logger(__name__) + + +class CausalWanGameCrossAttention(WanI2VCrossAttention): + """Cross-attention for WanGame causal model""" + + def forward(self, x, context, context_lens=None, crossattn_cache=None): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + crossattn_cache: Optional cache dict for inference + """ + context_img = context + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.to_q(x)[0]).view(b, -1, n, d) + + if crossattn_cache is not None: + if not crossattn_cache["is_init"]: + crossattn_cache["is_init"] = True + k_img = self.norm_added_k(self.add_k_proj(context_img)[0]).view( + b, -1, n, d) + v_img = self.add_v_proj(context_img)[0].view(b, -1, n, d) + crossattn_cache["k"] = k_img + crossattn_cache["v"] = v_img + else: + k_img = crossattn_cache["k"] + v_img = crossattn_cache["v"] + else: + k_img = self.norm_added_k(self.add_k_proj(context_img)[0]).view( + b, -1, n, d) + v_img = self.add_v_proj(context_img)[0].view(b, -1, n, d) + + img_x = self.attn(q, k_img, v_img) + + # output + x = img_x.flatten(2) + x, _ = self.to_out(x) + return x + + +class CausalWanGameActionSelfAttention(WanGameActionSelfAttention): + + def __init__(self, + dim: int, + num_heads: int, + local_attn_size: int = -1, + sink_size: int = 0, + qk_norm=True, + eps=1e-6) -> None: + super().__init__( + dim=dim, + num_heads=num_heads, + local_attn_size=local_attn_size, + sink_size=sink_size, + qk_norm=qk_norm, + eps=eps, + ) + 
self.max_attention_size = 32760 if local_attn_size == -1 else local_attn_size * 1560 + + # Local attention for KV-cache inference + self.local_attn = LocalAttention( + num_heads=num_heads, + head_size=self.head_dim, + dropout_rate=0, + softmax_scale=None, + causal=False, + supported_attention_backends=(AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.TORCH_SDPA)) + + @staticmethod + def _masked_flex_attn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + block_mask: BlockMask, + ) -> torch.Tensor: + padded_length = math.ceil(query.shape[1] / 128) * 128 - query.shape[1] + if padded_length > 0: + query = torch.cat( + [ + query, + torch.zeros( + [query.shape[0], padded_length, query.shape[2], query.shape[3]], + device=query.device, + dtype=value.dtype, + ), + ], + dim=1, + ) + key = torch.cat( + [ + key, + torch.zeros( + [key.shape[0], padded_length, key.shape[2], key.shape[3]], + device=key.device, + dtype=value.dtype, + ), + ], + dim=1, + ) + value = torch.cat( + [ + value, + torch.zeros( + [value.shape[0], padded_length, value.shape[2], value.shape[3]], + device=value.device, + dtype=value.dtype, + ), + ], + dim=1, + ) + + out = flex_attention( + query=query.transpose(2, 1), + key=key.transpose(2, 1), + value=value.transpose(2, 1), + block_mask=block_mask, + ).transpose(2, 1) + + if padded_length > 0: + out = out[:, :-padded_length] + return out + + def forward(self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + freqs_cis: tuple[torch.Tensor, torch.Tensor], + block_mask: BlockMask | None = None, + kv_cache: dict | None = None, + current_start: int = 0, + cache_start: int | None = None, + viewmats: torch.Tensor | None = None, + Ks: torch.Tensor | None = None, + is_cache: bool = False): + """ + Forward pass with causal attention. 
+ """ + if cache_start is None: + cache_start = current_start + + if kv_cache is None: + if block_mask is None: + raise ValueError( + "block_mask must be provided for causal training attention") + if viewmats is None or Ks is None: + raise ValueError( + "viewmats and Ks must be provided for WanGame causal attention") + + cos, sin = freqs_cis + query_rope = _apply_rotary_emb( + q, cos, sin, is_neox_style=False).type_as(v) + key_rope = _apply_rotary_emb( + k, cos, sin, is_neox_style=False).type_as(v) + rope_output = self._masked_flex_attn( + query_rope, key_rope, v, block_mask) + + # PRoPE path with the same causal mask. + query_prope, key_prope, value_prope, apply_fn_o = prope_qkv( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + viewmats=viewmats, + Ks=Ks, + patches_x=40, + patches_y=22, + ) + query_prope = query_prope.transpose(1, 2) + key_prope = key_prope.transpose(1, 2) + value_prope = value_prope.transpose(1, 2) + prope_output = self._masked_flex_attn( + query_prope, key_prope, value_prope, block_mask) + prope_output = apply_fn_o( + prope_output.transpose(1, 2)).transpose(1, 2) + + return rope_output, prope_output + else: + # Inference mode with KV cache + if viewmats is None or Ks is None: + raise ValueError( + "viewmats and Ks must be provided for WanGame causal attention") + + cos, sin = freqs_cis + roped_query = _apply_rotary_emb(q, cos, sin, is_neox_style=False).type_as(v) + roped_key = _apply_rotary_emb(k, cos, sin, is_neox_style=False).type_as(v) + query_prope, key_prope, value_prope, apply_fn_o = prope_qkv( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + viewmats=viewmats, + Ks=Ks, + patches_x=40, + patches_y=22, + ) + query_prope = query_prope.transpose(1, 2).type_as(v) + key_prope = key_prope.transpose(1, 2).type_as(v) + value_prope = value_prope.transpose(1, 2).type_as(v) + + frame_seqlen = q.shape[1] + current_end = current_start + roped_query.shape[1] + sink_tokens = self.sink_size * frame_seqlen + # If we are 
using local attention and the current KV cache size is larger than the local attention size, we need to truncate the KV cache + kv_cache_size = kv_cache["k"].shape[1] + num_new_tokens = roped_query.shape[1] + + # rope+prope + cache_head_dim = kv_cache["k"].shape[-1] + local_end_index = kv_cache["local_end_index"].item() + + # read cache but never mutate it. + if not is_cache: + if cache_head_dim not in (self.head_dim, self.head_dim * 2): + raise ValueError( + f"Unexpected kv_cache head dim: {cache_head_dim}, " + f"expected {self.head_dim} or {self.head_dim * 2}") + + cache_k_rope = kv_cache["k"][..., :self.head_dim] + cache_v_rope = kv_cache["v"][..., :self.head_dim] + rope_k = torch.cat( + [cache_k_rope[:, :local_end_index], roped_key], dim=1) + rope_v = torch.cat( + [cache_v_rope[:, :local_end_index], v], dim=1) + rope_k = rope_k[:, -self.max_attention_size:] + rope_v = rope_v[:, -self.max_attention_size:] + rope_x = self.local_attn(roped_query, rope_k, rope_v) + + if cache_head_dim == self.head_dim * 2: + cache_k_prope = kv_cache["k"][..., self.head_dim:] + cache_v_prope = kv_cache["v"][..., self.head_dim:] + prope_k = torch.cat( + [cache_k_prope[:, :local_end_index], key_prope], dim=1) + prope_v = torch.cat( + [cache_v_prope[:, :local_end_index], value_prope], dim=1) + prope_k = prope_k[:, -self.max_attention_size:] + prope_v = prope_v[:, -self.max_attention_size:] + prope_x = self.local_attn(query_prope, prope_k, prope_v) + else: + prope_x = self.local_attn( + query_prope, key_prope, value_prope) + + prope_x = apply_fn_o(prope_x.transpose(1, 2)).transpose(1, 2) + return rope_x, prope_x + + # update cache. 
+ if cache_head_dim == self.head_dim: + kv_cache["k"] = torch.cat( + [kv_cache["k"], torch.zeros_like(kv_cache["k"])], dim=-1) + kv_cache["v"] = torch.cat( + [kv_cache["v"], torch.zeros_like(kv_cache["v"])], dim=-1) + elif cache_head_dim != self.head_dim * 2: + raise ValueError( + f"Unexpected kv_cache head dim: {cache_head_dim}, " + f"expected {self.head_dim} or {self.head_dim * 2}") + + cache_k_rope = kv_cache["k"][..., :self.head_dim] + cache_k_prope = kv_cache["k"][..., self.head_dim:] + cache_v_rope = kv_cache["v"][..., :self.head_dim] + cache_v_prope = kv_cache["v"][..., self.head_dim:] + + if self.local_attn_size != -1 and (current_end > kv_cache["global_end_index"].item()) and ( + num_new_tokens + kv_cache["local_end_index"].item() > kv_cache_size): + # Calculate the number of new tokens added in this step + # Shift existing cache content left to discard oldest tokens + # Clone the source slice to avoid overlapping memory error + num_evicted_tokens = num_new_tokens + kv_cache["local_end_index"].item() - kv_cache_size + num_rolled_tokens = kv_cache["local_end_index"].item() - num_evicted_tokens - sink_tokens + cache_k_rope[:, sink_tokens:sink_tokens + num_rolled_tokens] = \ + cache_k_rope[:, sink_tokens + num_evicted_tokens:sink_tokens + num_evicted_tokens + num_rolled_tokens].clone() + cache_v_rope[:, sink_tokens:sink_tokens + num_rolled_tokens] = \ + cache_v_rope[:, sink_tokens + num_evicted_tokens:sink_tokens + num_evicted_tokens + num_rolled_tokens].clone() + cache_k_prope[:, sink_tokens:sink_tokens + num_rolled_tokens] = \ + cache_k_prope[:, sink_tokens + num_evicted_tokens:sink_tokens + num_evicted_tokens + num_rolled_tokens].clone() + cache_v_prope[:, sink_tokens:sink_tokens + num_rolled_tokens] = \ + cache_v_prope[:, sink_tokens + num_evicted_tokens:sink_tokens + num_evicted_tokens + num_rolled_tokens].clone() + # Insert the new keys/values at the end + local_end_index = kv_cache["local_end_index"].item() + current_end - \ + 
class CausalWanGameActionTransformerBlock(nn.Module):
    """One causal transformer block: self-attention (RoPE + PRoPE), cross-attention, FFN.

    The self-attention module returns two outputs (a RoPE branch and a PRoPE
    branch). Each branch has its own output projection and the two projected
    results are summed before the gated residual.
    """

    def __init__(self,
                 dim: int,
                 ffn_dim: int,
                 num_heads: int,
                 local_attn_size: int = -1,
                 sink_size: int = 0,
                 qk_norm: str = "rms_norm_across_heads",
                 cross_attn_norm: bool = False,
                 eps: float = 1e-6,
                 added_kv_proj_dim: int | None = None,
                 supported_attention_backends: tuple[AttentionBackendEnum, ...] | None = None,
                 prefix: str = ""):
        """Build the block's layers.

        Args:
            dim: Model width; also the in/out size of every projection here.
            ffn_dim: Hidden width of the feed-forward MLP.
            num_heads: Number of attention heads.
            local_attn_size: Forwarded to the self-attention module.
            sink_size: Forwarded to the self-attention module.
            qk_norm: ``"rms_norm"`` (per-head) or ``"rms_norm_across_heads"``.
            cross_attn_norm: Must be truthy. The signature default (``False``)
                is rejected, so callers have to pass it explicitly.
            eps: Epsilon used by the normalization layers.
            added_kv_proj_dim: Accepted for interface compatibility; not used
                in this constructor.
            supported_attention_backends: Accepted for interface
                compatibility; not used in this constructor.
            prefix: Accepted for interface compatibility; not used in this
                constructor.

        Raises:
            ValueError: If ``qk_norm`` is not a supported value or
                ``cross_attn_norm`` is falsy.
        """
        super().__init__()

        # 1. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.to_q = ReplicatedLinear(dim, dim, bias=True)
        self.to_k = ReplicatedLinear(dim, dim, bias=True)
        self.to_v = ReplicatedLinear(dim, dim, bias=True)
        self.to_out = ReplicatedLinear(dim, dim, bias=True)

        self.attn1 = CausalWanGameActionSelfAttention(
            dim,
            num_heads,
            local_attn_size=local_attn_size,
            sink_size=sink_size,
            qk_norm=qk_norm,
            eps=eps)

        self.hidden_dim = dim
        self.num_attention_heads = num_heads
        self.local_attn_size = local_attn_size
        dim_head = dim // num_heads

        if qk_norm == "rms_norm":
            self.norm_q = RMSNorm(dim_head, eps=eps)
            self.norm_k = RMSNorm(dim_head, eps=eps)
        elif qk_norm == "rms_norm_across_heads":
            self.norm_q = RMSNorm(dim, eps=eps)
            self.norm_k = RMSNorm(dim, eps=eps)
        else:
            raise ValueError(f"QK Norm type {qk_norm} not supported")

        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let an unsupported config through.
        if not cross_attn_norm:
            raise ValueError(
                f"cross_attn_norm must be True for {type(self).__name__}, "
                f"got {cross_attn_norm!r}")
        self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift(
            dim,
            norm_type="layer",
            eps=eps,
            elementwise_affine=True,
            compute_dtype=torch.float32)

        # 2. Cross-attention (I2V only)
        self.attn2 = CausalWanGameCrossAttention(dim,
                                                 num_heads,
                                                 qk_norm=qk_norm,
                                                 eps=eps)
        # norm3 for FFN input
        self.norm3 = LayerNormScaleShift(dim, norm_type="layer", eps=eps,
                                         elementwise_affine=False)

        # 3. Feed-forward
        self.ffn = MLP(dim, ffn_dim, act_type="gelu_pytorch_tanh")
        self.mlp_residual = ScaleResidual()

        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

        # PRoPE output projection. Weight and bias are zeroed here, so the
        # PRoPE branch contributes nothing to the block output until these
        # parameters are trained or overwritten by a loaded checkpoint.
        self.to_out_prope = ReplicatedLinear(dim, dim, bias=True)
        nn.init.zeros_(self.to_out_prope.weight)
        if self.to_out_prope.bias is not None:
            nn.init.zeros_(self.to_out_prope.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        freqs_cis: tuple[torch.Tensor, torch.Tensor],
        block_mask: BlockMask | None = None,
        kv_cache: dict | None = None,
        crossattn_cache: dict | None = None,
        current_start: int = 0,
        cache_start: int | None = None,
        viewmats: torch.Tensor | None = None,
        Ks: torch.Tensor | None = None,
        is_cache: bool = False,
    ) -> torch.Tensor:
        """Apply self-attention, cross-attention and the FFN to ``hidden_states``.

        Args:
            hidden_states: Token tensor; a 4-D input has dim 1 squeezed so the
                rest of the method works on ``[bs, seq, dim]``.
            encoder_hidden_states: Context passed to the cross-attention.
            temb: Per-frame modulation; after adding ``scale_shift_table`` it
                must have shape ``(bs, num_frames, 6, hidden_dim)``.
            freqs_cis: ``(cos, sin)`` rotary tables for the self-attention.
            block_mask: Forwarded to the self-attention.
            kv_cache: Forwarded to the self-attention.
            crossattn_cache: Forwarded to the cross-attention.
            current_start: Forwarded to the self-attention.
            cache_start: Forwarded to the self-attention.
            viewmats: Forwarded to the self-attention (camera matrices).
            Ks: Forwarded to the self-attention (camera intrinsics).
            is_cache: Forwarded to the self-attention.

        Returns:
            The updated hidden states, cast back to the input dtype.
        """
        if hidden_states.dim() == 4:
            hidden_states = hidden_states.squeeze(1)

        num_frames = temb.shape[1]
        frame_seqlen = hidden_states.shape[1] // num_frames
        bs = hidden_states.shape[0]
        orig_dtype = hidden_states.dtype

        # Cast temb to float32 for scale/shift computation
        e = self.scale_shift_table + temb.float()
        assert e.shape == (bs, num_frames, 6, self.hidden_dim)
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = e.chunk(6, dim=2)

        # 1. Self-attention
        # Modulate per frame: split the token axis into (frames, tokens per
        # frame) so the per-frame scale/shift broadcast, then merge it back.
        norm_hidden_states = (self.norm1(hidden_states.float()).unflatten(dim=1, sizes=(num_frames, frame_seqlen)) *
                              (1 + scale_msa) + shift_msa).to(orig_dtype).flatten(1, 2)

        query, _ = self.to_q(norm_hidden_states)
        key, _ = self.to_k(norm_hidden_states)
        value, _ = self.to_v(norm_hidden_states)

        if self.norm_q is not None:
            query = self.norm_q.forward_native(query)
        if self.norm_k is not None:
            key = self.norm_k.forward_native(key)

        query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
        key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
        value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1))

        # Self-attention with camera PRoPE
        attn_output_rope, attn_output_prope = self.attn1(
            query, key, value, freqs_cis,
            block_mask, kv_cache, current_start, cache_start,
            viewmats, Ks, is_cache=is_cache
        )
        # Combine rope and prope outputs: each branch has its own projection.
        attn_output_rope = attn_output_rope.flatten(2)
        attn_output_rope, _ = self.to_out(attn_output_rope)
        attn_output_prope = attn_output_prope.flatten(2)
        attn_output_prope, _ = self.to_out_prope(attn_output_prope)
        attn_output = attn_output_rope.squeeze(1) + attn_output_prope.squeeze(1)

        # Self-attention residual + norm in float32
        null_shift = null_scale = torch.zeros(1, device=hidden_states.device, dtype=torch.float32)
        norm_hidden_states, hidden_states = self.self_attn_residual_norm(
            hidden_states.float(), attn_output.float(), gate_msa, null_shift, null_scale)
        hidden_states = hidden_states.type_as(attn_output)
        norm_hidden_states = norm_hidden_states.type_as(attn_output)

        # 2. Cross-attention
        attn_output = self.attn2(norm_hidden_states.to(orig_dtype),
                                 context=encoder_hidden_states,
                                 context_lens=None,
                                 crossattn_cache=crossattn_cache)
        # Cross-attention residual (added in the attention output dtype)
        hidden_states = hidden_states + attn_output

        # norm3 for FFN input in float32
        norm_hidden_states = self.norm3(
            hidden_states.float(), c_shift_msa, c_scale_msa
        ).type_as(hidden_states)

        # 3. Feed-forward
        ff_output = self.ffn(norm_hidden_states.to(orig_dtype))
        hidden_states = self.mlp_residual(hidden_states.float(), ff_output.float(), c_gate_msa)
        hidden_states = hidden_states.to(orig_dtype)

        return hidden_states
Condition embeddings + self.condition_embedder = WanGameActionTimeImageEmbedding( + dim=inner_dim, + time_freq_dim=config.freq_dim, + image_embed_dim=config.image_dim, + ) + + # 3. Transformer blocks + self.blocks = nn.ModuleList([ + CausalWanGameActionTransformerBlock( + inner_dim, + config.ffn_dim, + config.num_attention_heads, + config.local_attn_size, + config.sink_size, + config.qk_norm, + config.cross_attn_norm, + config.eps, + config.added_kv_proj_dim, + supported_attention_backends=self._supported_attention_backends, + prefix=f"{config.prefix}.blocks.{i}") + for i in range(config.num_layers) + ]) + + # 4. Output norm & projection + self.norm_out = LayerNormScaleShift(inner_dim, + norm_type="layer", + eps=config.eps, + elementwise_affine=False, + dtype=torch.float32) + self.proj_out = nn.Linear( + inner_dim, config.out_channels * math.prod(config.patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) + + self.gradient_checkpointing = False + + # Causal-specific + self.block_mask = None + self.num_frame_per_block = config.arch_config.num_frames_per_block + assert self.num_frame_per_block <= 3 + + self.__post_init__() + + @staticmethod + def _prepare_blockwise_causal_attn_mask( + device: torch.device | str, num_frames: int = 21, + frame_seqlen: int = 1560, num_frame_per_block=1, local_attn_size=-1 + ) -> BlockMask: + """ + we will divide the token sequence into the following format + [1 latent frame] [1 latent frame] ... 
[1 latent frame] + We use flexattention to construct the attention mask + """ + total_length = num_frames * frame_seqlen + + # we do right padding to get to a multiple of 128 + padded_length = math.ceil(total_length / 128) * 128 - total_length + + ends = torch.zeros(total_length + padded_length, + device=device, dtype=torch.long) + + # Block-wise causal mask will attend to all elements that are before the end of the current chunk + frame_indices = torch.arange( + start=0, + end=total_length, + step=frame_seqlen * num_frame_per_block, + device=device + ) + + for tmp in frame_indices: + ends[tmp:tmp + frame_seqlen * num_frame_per_block] = tmp + \ + frame_seqlen * num_frame_per_block + + def attention_mask(b, h, q_idx, kv_idx): + if local_attn_size == -1: + return (kv_idx < ends[q_idx]) | (q_idx == kv_idx) + else: + return ((kv_idx < ends[q_idx]) & (kv_idx >= (ends[q_idx] - local_attn_size * frame_seqlen))) | (q_idx == kv_idx) + + block_mask = create_block_mask(attention_mask, B=None, H=None, Q_LEN=total_length + padded_length, + KV_LEN=total_length + padded_length, _compile=False, device=device) + + if not dist.is_initialized() or dist.get_rank() == 0: + print( + f" cache a block wise causal mask with block size of {num_frame_per_block} frames") + print(block_mask) + + return block_mask + + def _forward_inference( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | list[torch.Tensor], + timestep: torch.LongTensor, + encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None, + guidance=None, + action: torch.Tensor | None = None, + viewmats: torch.Tensor | None = None, + Ks: torch.Tensor | None = None, + kv_cache: list[dict] | None = None, + crossattn_cache: list[dict] | None = None, + current_start: int = 0, + cache_start: int = 0, + start_frame: int = 0, + is_cache: bool = False, + **kwargs + ) -> torch.Tensor: + r""" + Run the diffusion model with kv caching. 
+ See Algorithm 2 of CausVid paper https://arxiv.org/abs/2412.07772 for details. + This function will be run for num_frame times. + Process the latent frames one by one (1560 tokens each) + """ + orig_dtype = hidden_states.dtype + if isinstance(encoder_hidden_states_image, list) and len(encoder_hidden_states_image) > 0: + encoder_hidden_states_image = encoder_hidden_states_image[0] + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # Get rotary embeddings + d = self.hidden_size // self.num_attention_heads + rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)] + freqs_cos, freqs_sin = get_rotary_pos_embed( + (post_patch_num_frames * get_sp_world_size(), post_patch_height, post_patch_width), + self.hidden_size, + self.num_attention_heads, + rope_dim_list, + dtype=torch.float32 if current_platform.is_mps() else torch.float64, + rope_theta=10000, + start_frame=start_frame + ) + freqs_cos = freqs_cos.to(hidden_states.device) + freqs_sin = freqs_sin.to(hidden_states.device) + freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None + + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + if timestep.dim() == 2: + timestep = timestep.flatten() + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, action, encoder_hidden_states, encoder_hidden_states_image=encoder_hidden_states_image) + + # condition_embedder returns: + # - temb: [B*T, dim] where T = post_patch_num_frames + # - timestep_proj: [B*T, 6*dim] + # Reshape to [B, T, 6, dim] for transformer blocks + timestep_proj = timestep_proj.unflatten(1, (6, self.hidden_size)) # [B*T, 6, dim] + timestep_proj = timestep_proj.view(batch_size, post_patch_num_frames, 6, self.hidden_size) # [B, T, 6, dim] + + 
def _forward_train(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor | list[torch.Tensor],
    timestep: torch.LongTensor,
    encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None,
    guidance=None,
    action: torch.Tensor | None = None,
    viewmats: torch.Tensor | None = None,
    Ks: torch.Tensor | None = None,
    start_frame: int = 0,
    **kwargs
) -> torch.Tensor:
    """Full-sequence forward pass using the block-wise causal mask (no KV cache).

    Args:
        hidden_states: 5-D input unpacked as
            ``(batch, channels, frames, height, width)``.
        encoder_hidden_states: Passed to the condition embedder; the value
            handed to the transformer blocks is the embedded image context.
        timestep: Diffusion timesteps; a 2-D tensor is flattened first.
        encoder_hidden_states_image: Image conditioning; if a non-empty list
            is given, only its first element is used.
        guidance: Unused here.
        action: Action labels passed to the condition embedder.
        viewmats: Camera matrices forwarded to every block.
        Ks: Camera intrinsics forwarded to every block.
        start_frame: Frame offset for the rotary position embedding.
        **kwargs: Ignored.

    Returns:
        Tensor of shape ``(batch, C_out, frames, height, width)`` obtained by
        un-patchifying the projected tokens.
    """
    if isinstance(encoder_hidden_states_image, list) and len(encoder_hidden_states_image) > 0:
        encoder_hidden_states_image = encoder_hidden_states_image[0]

    batch_size, _, num_frames, height, width = hidden_states.shape
    p_t, p_h, p_w = self.patch_size
    post_patch_num_frames = num_frames // p_t
    post_patch_height = height // p_h
    post_patch_width = width // p_w

    # Get rotary embeddings
    d = self.hidden_size // self.num_attention_heads
    rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]
    freqs_cos, freqs_sin = get_rotary_pos_embed(
        (post_patch_num_frames * get_sp_world_size(), post_patch_height, post_patch_width),
        self.hidden_size,
        self.num_attention_heads,
        rope_dim_list,
        dtype=torch.float32 if current_platform.is_mps() else torch.float64,
        rope_theta=10000,
        start_frame=start_frame
    )
    freqs_cis = (freqs_cos.to(hidden_states.device),
                 freqs_sin.to(hidden_states.device))

    # Construct blockwise causal attn mask.
    # The mask must be sized for the token sequence produced by the patch
    # embedding, i.e. post_patch_num_frames frames of H' * W' tokens each.
    # Using the raw frame count is only correct when p_t == 1.
    # NOTE: the mask is built on the first call and reused afterwards; it is
    # not rebuilt if later inputs have a different size.
    if self.block_mask is None:
        self.block_mask = self._prepare_blockwise_causal_attn_mask(
            device=hidden_states.device,
            num_frames=post_patch_num_frames,
            frame_seqlen=post_patch_height * post_patch_width,
            num_frame_per_block=self.num_frame_per_block,
            local_attn_size=self.local_attn_size
        )

    hidden_states = self.patch_embedding(hidden_states)
    hidden_states = hidden_states.flatten(2).transpose(1, 2)

    if timestep.dim() == 2:
        timestep = timestep.flatten()

    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
        timestep, action, encoder_hidden_states, encoder_hidden_states_image=encoder_hidden_states_image)

    # condition_embedder returns:
    # - temb: [B*T, dim] where T = post_patch_num_frames
    # - timestep_proj: [B*T, 6*dim]
    # Reshape to [B, T, 6, dim] for transformer blocks
    timestep_proj = timestep_proj.unflatten(1, (6, self.hidden_size))  # [B*T, 6, dim]
    timestep_proj = timestep_proj.view(batch_size, post_patch_num_frames, 6, self.hidden_size)  # [B, T, 6, dim]

    # The blocks cross-attend to the embedded image context only.
    encoder_hidden_states = encoder_hidden_states_image

    # Transformer blocks
    if torch.is_grad_enabled() and self.gradient_checkpointing:
        for block in self.blocks:
            hidden_states = self._gradient_checkpointing_func(
                block, hidden_states, encoder_hidden_states,
                timestep_proj, freqs_cis,
                self.block_mask,
                None, None,  # kv_cache, crossattn_cache
                0, None,  # current_start, cache_start
                viewmats, Ks, False)  # viewmats, Ks, is_cache
    else:
        for block in self.blocks:
            hidden_states = block(hidden_states, encoder_hidden_states,
                                  timestep_proj, freqs_cis,
                                  block_mask=self.block_mask,
                                  viewmats=viewmats, Ks=Ks)

    # Output norm, projection & unpatchify
    # temb is [B*T, dim], reshape to [B, T, 1, dim]
    temb = temb.view(batch_size, post_patch_num_frames, -1).unsqueeze(2)  # [B, T, 1, dim]

    shift, scale = (self.scale_shift_table.unsqueeze(1) + temb).chunk(2, dim=2)
    hidden_states = self.norm_out(hidden_states, shift, scale)
    hidden_states = self.proj_out(hidden_states)

    # [B, T', H', W', p_t, p_h, p_w, C] -> [B, C, T'*p_t, H'*p_h, W'*p_w]
    hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames,
                                          post_patch_height,
                                          post_patch_width, p_t, p_h, p_w,
                                          -1)
    hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
    output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

    return output
import TimestepEmbedder, ModulateProjection, timestep_embedding +from fastvideo.platforms import AttentionBackendEnum +from fastvideo.attention import DistributedAttention +from fastvideo.forward_context import set_forward_context +from fastvideo.models.dits.wanvideo import WanImageEmbedding + +from fastvideo.models.dits.hyworld.camera_rope import prope_qkv +from fastvideo.layers.rotary_embedding import _apply_rotary_emb +from fastvideo.layers.mlp import MLP + +class WanGameActionTimeImageEmbedding(nn.Module): + def __init__( + self, + dim: int, + time_freq_dim: int, + image_embed_dim: int | None = None, + ): + super().__init__() + + self.time_freq_dim = time_freq_dim + self.time_embedder = TimestepEmbedder( + dim, frequency_embedding_size=time_freq_dim, act_layer="silu") + self.time_modulation = ModulateProjection(dim, + factor=6, + act_layer="silu") + + self.image_embedder = None + if image_embed_dim is not None: + self.image_embedder = WanImageEmbedding(image_embed_dim, dim) + + self.action_embedder = MLP( + time_freq_dim, + dim, + dim, + bias=True, + act_type="silu" + ) + # Initialize fc_in with kaiming_uniform (same as nn.Linear default) + nn.init.kaiming_uniform_(self.action_embedder.fc_in.weight, a=math.sqrt(5)) + # Initialize fc_out with zeros for residual-like behavior + nn.init.zeros_(self.action_embedder.fc_out.weight) + if self.action_embedder.fc_out.bias is not None: + nn.init.zeros_(self.action_embedder.fc_out.bias) + + def forward( + self, + timestep: torch.Tensor, + action: torch.Tensor, + encoder_hidden_states: torch.Tensor, # Kept for interface compatibility + encoder_hidden_states_image: torch.Tensor | None = None, + timestep_seq_len: int | None = None, + ): + """ + Args: + timestep: [B] diffusion timesteps (one per batch sample) + action: [B, T] action labels (one per frame per batch sample) + + Returns: + temb: [B*T, dim] combined timestep + action embedding + timestep_proj: [B*T, 6*dim] modulation projection + """ + # timestep may be [B] (one 
class WanGameActionSelfAttention(nn.Module):
    """
    Self-attention module with support for:
    - Standard RoPE-based attention
    - Camera PRoPE-based attention (driven by viewmats and Ks)
    - Storing the current keys/values into a KV cache dict

    Both branches are evaluated in a single attention call by stacking them
    along the batch dimension.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 local_attn_size: int = -1,
                 sink_size: int = 0,
                 qk_norm=True,
                 eps=1e-6) -> None:
        """Store the configuration and create the attention backend.

        Args:
            dim: Model width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            local_attn_size: Window size in frames, or -1 for no window.
            sink_size: Stored on the module; not used in this class.
            qk_norm: Stored on the module; not used in this class.
            eps: Stored on the module; not used in this class.
        """
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.local_attn_size = local_attn_size
        self.sink_size = sink_size
        self.qk_norm = qk_norm
        self.eps = eps
        self.max_attention_size = 32760 if local_attn_size == -1 else local_attn_size * 1560

        # Scaled dot product attention (using DistributedAttention for SP support)
        self.attn = DistributedAttention(
            num_heads=num_heads,
            head_size=self.head_dim,
            softmax_scale=None,
            causal=False,
            supported_attention_backends=(AttentionBackendEnum.FLASH_ATTN,
                                          AttentionBackendEnum.TORCH_SDPA))

    def forward(self,
                q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                freqs_cis: tuple[torch.Tensor, torch.Tensor],
                kv_cache: dict | None = None,
                current_start: int = 0,
                cache_start: int | None = None,
                viewmats: torch.Tensor | None = None,
                Ks: torch.Tensor | None = None,
                is_cache: bool = False,
                attention_mask: torch.Tensor | None = None):
        """
        Forward pass with camera PRoPE attention combining standard RoPE and projective positional encoding.

        Args:
            q, k, v: Query, key, value tensors [B, L, num_heads, head_dim]
            freqs_cis: RoPE frequency cos/sin tensors
            kv_cache: KV cache dict (may have None values for training)
            current_start: Accepted for interface compatibility; unused here
            cache_start: Accepted for interface compatibility; unused here
            viewmats: Camera view matrices for PRoPE [B, cameras, 4, 4]
            Ks: Camera intrinsics for PRoPE [B, cameras, 3, 3]
            is_cache: Whether to store to KV cache (for inference)
            attention_mask: Mask passed to the attention call. When omitted,
                an all-ones mask of shape [2*B, L] is created, matching the
                rope+prope batch stacking.

        Returns:
            Tuple ``(hidden_states_rope, hidden_states_prope)``.

        Raises:
            ValueError: If, after prepending cached keys, the key sequence
                length differs from the query sequence length.
        """
        # Apply RoPE manually
        cos, sin = freqs_cis
        query_rope = _apply_rotary_emb(q, cos, sin, is_neox_style=False).type_as(v)
        key_rope = _apply_rotary_emb(k, cos, sin, is_neox_style=False).type_as(v)
        value_rope = v

        # Get PRoPE transformed q, k, v
        query_prope, key_prope, value_prope, apply_fn_o = prope_qkv(
            q.transpose(1, 2),  # [B, num_heads, L, head_dim]
            k.transpose(1, 2),
            v.transpose(1, 2),
            viewmats=viewmats,
            Ks=Ks,
            patches_x=40,  # hardcoded for now
            patches_y=22,
        )
        # PRoPE returns [B, num_heads, L, head_dim], convert to [B, L, num_heads, head_dim]
        query_prope = query_prope.transpose(1, 2)
        key_prope = key_prope.transpose(1, 2)
        value_prope = value_prope.transpose(1, 2)

        # KV cache handling
        if kv_cache is not None:
            cache_key = kv_cache.get("k", None)
            cache_value = kv_cache.get("v", None)

            if cache_value is not None and not is_cache:
                # The cache stores rope and prope halves side by side on the
                # last dim; split them and prepend to the current keys/values.
                cache_key_rope, cache_key_prope = cache_key.chunk(2, dim=-1)
                cache_value_rope, cache_value_prope = cache_value.chunk(2, dim=-1)

                key_rope = torch.cat([cache_key_rope, key_rope], dim=1)
                value_rope = torch.cat([cache_value_rope, value_rope], dim=1)
                key_prope = torch.cat([cache_key_prope, key_prope], dim=1)
                value_prope = torch.cat([cache_value_prope, value_prope], dim=1)

            if is_cache:
                # Store to cache (update input dict directly)
                kv_cache["k"] = torch.cat([key_rope, key_prope], dim=-1)
                kv_cache["v"] = torch.cat([value_rope, value_prope], dim=-1)

        # Stack rope and prope paths along the batch dim: [2*B, L, ...]
        query_all = torch.cat([query_rope, query_prope], dim=0)
        key_all = torch.cat([key_rope, key_prope], dim=0)
        value_all = torch.cat([value_rope, value_prope], dim=0)

        # Only equal Q/KV lengths are handled here. Prepending a non-empty
        # cache makes the keys longer than the queries, which this module
        # rejects instead of silently computing something else.
        if query_all.shape[1] != key_all.shape[1]:
            raise ValueError(
                "Q and KV have different sequence lengths "
                f"({query_all.shape[1]} vs {key_all.shape[1]}); attending "
                "over cached keys is not supported by this module")

        # Same sequence length: use DistributedAttention (supports SP)
        # NOTE: query_all has shape [2*B, L, ...] (rope+prope stacked), so the
        # default mask needs 2*B rows.
        if attention_mask is None:
            batch_size, seq_len = q.shape[0], q.shape[1]
            attention_mask = torch.ones(batch_size * 2, seq_len, device=q.device, dtype=q.dtype)

        # float32 inputs go through the SDPA metadata builder, everything
        # else through the flash-attention one.
        if q.dtype == torch.float32:
            from fastvideo.attention.backends.sdpa import SDPAMetadataBuilder
            attn_metadata_builder = SDPAMetadataBuilder
        else:
            from fastvideo.attention.backends.flash_attn import FlashAttnMetadataBuilder
            attn_metadata_builder = FlashAttnMetadataBuilder
        attn_metadata = attn_metadata_builder().build(
            current_timestep=0,
            attn_mask=attention_mask,
        )
        with set_forward_context(current_timestep=0, attn_metadata=attn_metadata):
            hidden_states_all, _ = self.attn(query_all, key_all, value_all, attention_mask=attention_mask)

        # Undo the batch stacking, then map the PRoPE output back with the
        # function returned by prope_qkv.
        hidden_states_rope, hidden_states_prope = hidden_states_all.chunk(2, dim=0)
        hidden_states_prope = apply_fn_o(hidden_states_prope.transpose(1, 2)).transpose(1, 2)

        return hidden_states_rope, hidden_states_prope
+from fastvideo.configs.models.dits.wangamevideo import WanGameVideoConfig +from fastvideo.distributed.parallel_state import get_sp_world_size +from fastvideo.layers.layernorm import (FP32LayerNorm, LayerNormScaleShift, + RMSNorm, ScaleResidual, + ScaleResidualLayerNormScaleShift) +from fastvideo.layers.linear import ReplicatedLinear +from fastvideo.layers.mlp import MLP +from fastvideo.layers.rotary_embedding import (_apply_rotary_emb, + get_rotary_pos_embed) +from fastvideo.layers.visual_embedding import PatchEmbed +from fastvideo.logger import init_logger +from fastvideo.models.dits.base import BaseDiT +from fastvideo.models.dits.wanvideo import WanI2VCrossAttention +from fastvideo.platforms import AttentionBackendEnum, current_platform + +# Import ActionModule +from fastvideo.models.dits.wangame.hyworld_action_module import WanGameActionTimeImageEmbedding, WanGameActionSelfAttention + +logger = init_logger(__name__) + + +class WanGameCrossAttention(WanI2VCrossAttention): + def forward(self, x, context, context_lens=None): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + context_img = context + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.to_q(x)[0]).view(b, -1, n, d) + k_img = self.norm_added_k(self.add_k_proj(context_img)[0]).view( + b, -1, n, d) + v_img = self.add_v_proj(context_img)[0].view(b, -1, n, d) + img_x = self.attn(q, k_img, v_img) + + # output + x = img_x.flatten(2) + x, _ = self.to_out(x) + return x + +class WanGameActionTransformerBlock(nn.Module): + """ + Transformer block for WAN Action model with support for: + - Self-attention with RoPE and camera PRoPE + - Cross-attention with text/image context + - Feed-forward network with AdaLN modulation + """ + + def __init__(self, + dim: int, + ffn_dim: int, + num_heads: int, + local_attn_size: int = -1, + sink_size: int = 0, + qk_norm: str = "rms_norm_across_heads", + 
cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: int | None = None, + supported_attention_backends: tuple[AttentionBackendEnum, ...] | None = None, + prefix: str = ""): + super().__init__() + + # 1. Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.to_q = ReplicatedLinear(dim, dim, bias=True) + self.to_k = ReplicatedLinear(dim, dim, bias=True) + self.to_v = ReplicatedLinear(dim, dim, bias=True) + self.to_out = ReplicatedLinear(dim, dim, bias=True) + + self.attn1 = WanGameActionSelfAttention( + dim, + num_heads, + local_attn_size=local_attn_size, + sink_size=sink_size, + qk_norm=qk_norm, + eps=eps) + + self.hidden_dim = dim + self.num_attention_heads = num_heads + self.local_attn_size = local_attn_size + dim_head = dim // num_heads + + if qk_norm == "rms_norm": + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + elif qk_norm == "rms_norm_across_heads": + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + else: + raise ValueError(f"QK Norm type {qk_norm} not supported") + + assert cross_attn_norm is True + self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift( + dim, + norm_type="layer", + eps=eps, + elementwise_affine=True, + compute_dtype=torch.float32) + + # 2. Cross-attention (I2V only for now) + self.attn2 = WanGameCrossAttention(dim, + num_heads, + qk_norm=qk_norm, + eps=eps) + # norm3 for FFN input + self.norm3 = LayerNormScaleShift(dim, norm_type="layer", eps=eps, + elementwise_affine=False) + + # 3. 
Feed-forward + self.ffn = MLP(dim, ffn_dim, act_type="gelu_pytorch_tanh") + self.mlp_residual = ScaleResidual() + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + # PRoPE output projection (initialized via add_discrete_action_parameters on the model) + self.to_out_prope = ReplicatedLinear(dim, dim, bias=True) + nn.init.zeros_(self.to_out_prope.weight) + if self.to_out_prope.bias is not None: + nn.init.zeros_(self.to_out_prope.bias) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + freqs_cis: tuple[torch.Tensor, torch.Tensor], + kv_cache: dict | None = None, + crossattn_cache: dict | None = None, + current_start: int = 0, + cache_start: int | None = None, + viewmats: torch.Tensor | None = None, + Ks: torch.Tensor | None = None, + is_cache: bool = False, + ) -> torch.Tensor: + if hidden_states.dim() == 4: + hidden_states = hidden_states.squeeze(1) + + num_frames = temb.shape[1] + frame_seqlen = hidden_states.shape[1] // num_frames + bs, seq_length, _ = hidden_states.shape + orig_dtype = hidden_states.dtype + + # Cast temb to float32 for scale/shift computation + e = self.scale_shift_table + temb.float() + assert e.shape == (bs, num_frames, 6, self.hidden_dim) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = e.chunk(6, dim=2) + + # 1. 
Self-attention + norm_hidden_states = (self.norm1(hidden_states.float()).unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * + (1 + scale_msa) + shift_msa).to(orig_dtype).flatten(1, 2) + + query, _ = self.to_q(norm_hidden_states) + key, _ = self.to_k(norm_hidden_states) + value, _ = self.to_v(norm_hidden_states) + + if self.norm_q is not None: + query = self.norm_q.forward_native(query) + if self.norm_k is not None: + key = self.norm_k.forward_native(key) + + query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) + key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) + value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) + + # Self-attention with camera PRoPE + attn_output_rope, attn_output_prope = self.attn1( + query, key, value, freqs_cis, + kv_cache, current_start, cache_start, viewmats, Ks, + is_cache=is_cache + ) + # Combine rope and prope outputs + attn_output_rope = attn_output_rope.flatten(2) + attn_output_rope, _ = self.to_out(attn_output_rope) + attn_output_prope = attn_output_prope.flatten(2) + + # # DEBUG: Check if prope input is zero + # if self.training and torch.distributed.get_rank() == 0: + # prope_nonzero = (attn_output_prope != 0).sum().item() + # prope_total = attn_output_prope.numel() + # if prope_nonzero == 0: + # print(f"[DEBUG] to_out_prope INPUT is ALL ZEROS! shape={attn_output_prope.shape}", flush=True) + + attn_output_prope, _ = self.to_out_prope(attn_output_prope) + attn_output = attn_output_rope.squeeze(1) + attn_output_prope.squeeze(1) + + # Self-attention residual + norm in float32 + null_shift = null_scale = torch.zeros(1, device=hidden_states.device, dtype=torch.float32) + norm_hidden_states, hidden_states = self.self_attn_residual_norm( + hidden_states.float(), attn_output.float(), gate_msa, null_shift, null_scale) + hidden_states = hidden_states.type_as(attn_output) + norm_hidden_states = norm_hidden_states.type_as(attn_output) + + # 2. 
Cross-attention + attn_output = self.attn2(norm_hidden_states.to(orig_dtype), + context=encoder_hidden_states, + context_lens=None) + # Cross-attention residual in bfloat16 + hidden_states = hidden_states + attn_output + + # norm3 for FFN input in float32 + norm_hidden_states = self.norm3( + hidden_states.float(), c_shift_msa, c_scale_msa + ).type_as(hidden_states) + + # 3. Feed-forward + ff_output = self.ffn(norm_hidden_states.to(orig_dtype)) + hidden_states = self.mlp_residual(hidden_states.float(), ff_output.float(), c_gate_msa) + hidden_states = hidden_states.to(orig_dtype) # Cast back to original dtype + + return hidden_states + +class WanGameActionTransformer3DModel(BaseDiT): + """ + WAN Action Transformer 3D Model for video generation with action conditioning. + + Extends the base WAN video model with: + - Action embedding support for controllable generation + - camera PRoPE attention for 3D-aware generation + - KV caching for autoregressive inference + """ + _fsdp_shard_conditions = WanGameVideoConfig()._fsdp_shard_conditions + _compile_conditions = WanGameVideoConfig()._compile_conditions + _supported_attention_backends = WanGameVideoConfig()._supported_attention_backends + param_names_mapping = WanGameVideoConfig().param_names_mapping + reverse_param_names_mapping = WanGameVideoConfig().reverse_param_names_mapping + lora_param_names_mapping = WanGameVideoConfig().lora_param_names_mapping + + def __init__(self, config: WanGameVideoConfig, hf_config: dict[str, Any]) -> None: + super().__init__(config=config, hf_config=hf_config) + + inner_dim = config.num_attention_heads * config.attention_head_dim + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.attention_head_dim = config.attention_head_dim + self.in_channels = config.in_channels + self.out_channels = config.out_channels + self.num_channels_latents = config.num_channels_latents + self.patch_size = config.patch_size + self.local_attn_size = 
config.local_attn_size + self.inner_dim = inner_dim + + # 1. Patch & position embedding + self.patch_embedding = PatchEmbed(in_chans=config.in_channels, + embed_dim=inner_dim, + patch_size=config.patch_size, + flatten=False) + + # 2. Condition embeddings (with action support) + self.condition_embedder = WanGameActionTimeImageEmbedding( + dim=inner_dim, + time_freq_dim=config.freq_dim, + image_embed_dim=config.image_dim, + ) + + # 3. Transformer blocks + self.blocks = nn.ModuleList([ + WanGameActionTransformerBlock( + inner_dim, + config.ffn_dim, + config.num_attention_heads, + config.local_attn_size, + config.sink_size, + config.qk_norm, + config.cross_attn_norm, + config.eps, + config.added_kv_proj_dim, + supported_attention_backends=self._supported_attention_backends, + prefix=f"{config.prefix}.blocks.{i}") + for i in range(config.num_layers) + ]) + + # 4. Output norm & projection + self.norm_out = LayerNormScaleShift(inner_dim, + norm_type="layer", + eps=config.eps, + elementwise_affine=False, + dtype=torch.float32) + self.proj_out = nn.Linear( + inner_dim, config.out_channels * math.prod(config.patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) + + self.gradient_checkpointing = False + + # Causal-specific + self.num_frame_per_block = config.arch_config.num_frames_per_block + assert self.num_frame_per_block <= 3 + + self.__post_init__() + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | list[torch.Tensor], + timestep: torch.LongTensor, + encoder_hidden_states_image: torch.Tensor | list[torch.Tensor], + guidance=None, + action: torch.Tensor | None = None, + viewmats: torch.Tensor | None = None, + Ks: torch.Tensor | None = None, + kv_cache: list[dict] | None = None, + crossattn_cache: list[dict] | None = None, + current_start: int = 0, + cache_start: int = 0, + start_frame: int = 0, + is_cache: bool = False, + **kwargs + ) -> torch.Tensor: + """ + Forward pass for both 
training and inference with KV caching. + + Args: + hidden_states: Video latents [B, C, T, H, W] + encoder_hidden_states: Text embeddings [B, L, D] + timestep: Timestep tensor + encoder_hidden_states_image: Optional image embeddings + action: Action tensor [B, T] for per-frame conditioning + viewmats: Camera view matrices for PRoPE [B, T, 4, 4] + Ks: Camera intrinsics for PRoPE [B, T, 3, 3] + kv_cache: KV cache for autoregressive inference (list of dicts per layer) + crossattn_cache: Cross-attention cache for inference + current_start: Current position for KV cache + cache_start: Cache start position + start_frame: RoPE offset for new frames in autoregressive mode + is_cache: If True, populate KV cache and return early (cache-only mode) + """ + orig_dtype = hidden_states.dtype + # if not isinstance(encoder_hidden_states, torch.Tensor): + # encoder_hidden_states = encoder_hidden_states[0] + if isinstance(encoder_hidden_states_image, list) and len(encoder_hidden_states_image) > 0: + encoder_hidden_states_image = encoder_hidden_states_image[0] + # else: + # encoder_hidden_states_image = None + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # Get rotary embeddings + d = self.hidden_size // self.num_attention_heads + rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)] + freqs_cos, freqs_sin = get_rotary_pos_embed( + (post_patch_num_frames * get_sp_world_size(), post_patch_height, post_patch_width), + self.hidden_size, + self.num_attention_heads, + rope_dim_list, + dtype=torch.float32 if current_platform.is_mps() else torch.float64, + rope_theta=10000, + start_frame=start_frame + ) + freqs_cos = freqs_cos.to(hidden_states.device) + freqs_sin = freqs_sin.to(hidden_states.device) + freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None + + hidden_states = 
self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + if timestep.dim() == 2: + timestep = timestep.flatten() + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, action, encoder_hidden_states, encoder_hidden_states_image=encoder_hidden_states_image) + + # condition_embedder returns: + # - temb: [B*T, dim] where T = post_patch_num_frames + # - timestep_proj: [B*T, 6*dim] + # Reshape to [B, T, 6, dim] for transformer blocks + timestep_proj = timestep_proj.unflatten(1, (6, self.hidden_size)) # [B*T, 6, dim] + timestep_proj = timestep_proj.view(batch_size, post_patch_num_frames, 6, self.hidden_size) # [B, T, 6, dim] + + encoder_hidden_states = encoder_hidden_states_image + + # Transformer blocks + for block_idx, block in enumerate(self.blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, freqs_cis, + kv_cache[block_idx] if kv_cache else None, + crossattn_cache[block_idx] if crossattn_cache else None, + current_start, cache_start, + viewmats, Ks, is_cache) + else: + hidden_states = block( + hidden_states, encoder_hidden_states, timestep_proj, freqs_cis, + kv_cache[block_idx] if kv_cache else None, + crossattn_cache[block_idx] if crossattn_cache else None, + current_start, cache_start, + viewmats, Ks, is_cache) + + # If cache-only mode, return early + if is_cache: + return kv_cache + + # Output norm, projection & unpatchify + # temb is [B*T, dim], reshape to [B, T, 1, dim] + temb = temb.view(batch_size, post_patch_num_frames, -1).unsqueeze(2) # [B, T, 1, dim] + + shift, scale = (self.scale_shift_table.unsqueeze(1) + temb).chunk(2, dim=2) + hidden_states = self.norm_out(hidden_states, shift, scale) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, + 
post_patch_height, + post_patch_width, p_t, p_h, p_w, + -1) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + return output \ No newline at end of file diff --git a/fastvideo/models/dits/wangame_lingbot/__init__.py b/fastvideo/models/dits/wangame_lingbot/__init__.py new file mode 100644 index 000000000..ad549c889 --- /dev/null +++ b/fastvideo/models/dits/wangame_lingbot/__init__.py @@ -0,0 +1,5 @@ +from .model import WanLingBotTransformer3DModel + +__all__ = [ + "WanLingBotTransformer3DModel", +] diff --git a/fastvideo/models/dits/wangame_lingbot/cam_utils.py b/fastvideo/models/dits/wangame_lingbot/cam_utils.py new file mode 100644 index 000000000..fb72ec84a --- /dev/null +++ b/fastvideo/models/dits/wangame_lingbot/cam_utils.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from LingBot World: https://github.com/Robbyant/lingbot-world/blob/main/wan/utils/cam_utils.py + +import numpy as np +import os +import torch +from scipy.interpolate import interp1d +from scipy.spatial.transform import Rotation, Slerp + + +# --- Official Code (Leave Unchanged) --- + +def interpolate_camera_poses( + src_indices: np.ndarray, + src_rot_mat: np.ndarray, + src_trans_vec: np.ndarray, + tgt_indices: np.ndarray, +) -> torch.Tensor: + # interpolate translation + interp_func_trans = interp1d( + src_indices, + src_trans_vec, + axis=0, + kind='linear', + bounds_error=False, + fill_value="extrapolate", + ) + interpolated_trans_vec = interp_func_trans(tgt_indices) + + # interpolate rotation + src_quat_vec = Rotation.from_matrix(src_rot_mat) + # ensure there is no sudden change in qw + quats = src_quat_vec.as_quat().copy() # [N, 4] + for i in range(1, len(quats)): + if np.dot(quats[i], quats[i-1]) < 0: + quats[i] = -quats[i] + src_quat_vec = Rotation.from_quat(quats) + slerp_func_rot = Slerp(src_indices, src_quat_vec) + interpolated_rot_quat = slerp_func_rot(tgt_indices) + 
interpolated_rot_mat = interpolated_rot_quat.as_matrix() + + poses = np.zeros((len(tgt_indices), 4, 4)) + poses[:, :3, :3] = interpolated_rot_mat + poses[:, :3, 3] = interpolated_trans_vec + poses[:, 3, 3] = 1.0 + return torch.from_numpy(poses).float() + + +def SE3_inverse(T: torch.Tensor) -> torch.Tensor: + Rot = T[:, :3, :3] # [B,3,3] + trans = T[:, :3, 3:] # [B,3,1] + R_inv = Rot.transpose(-1, -2) + t_inv = -torch.bmm(R_inv, trans) + T_inv = torch.eye(4, device=T.device, dtype=T.dtype)[None, :, :].repeat(T.shape[0], 1, 1) + T_inv[:, :3, :3] = R_inv + T_inv[:, :3, 3:] = t_inv + return T_inv + + +def compute_relative_poses( + c2ws_mat: torch.Tensor, + framewise: bool = False, + normalize_trans: bool = True, +) -> torch.Tensor: + ref_w2cs = SE3_inverse(c2ws_mat[0:1]) + relative_poses = torch.matmul(ref_w2cs, c2ws_mat) + # ensure identity matrix for 1st frame + relative_poses[0] = torch.eye(4, device=c2ws_mat.device, dtype=c2ws_mat.dtype) + if framewise: + # compute pose between i and i+1 + relative_poses_framewise = torch.bmm(SE3_inverse(relative_poses[:-1]), relative_poses[1:]) + relative_poses[1:] = relative_poses_framewise + if normalize_trans: # note refer to camctrl2: "we scale the coordinate inputs to roughly 1 standard deviation to simplify model learning." 
+ translations = relative_poses[:, :3, 3] # [f, 3] + max_norm = torch.norm(translations, dim=-1).max() + # only normlaize when moving + if max_norm > 0: + relative_poses[:, :3, 3] = translations / max_norm + return relative_poses + + +@torch.no_grad() +def create_meshgrid(n_frames: int, height: int, width: int, bias: float = 0.5, device='cuda', dtype=torch.float32) -> torch.Tensor: + x_range = torch.arange(width, device=device, dtype=dtype) + y_range = torch.arange(height, device=device, dtype=dtype) + grid_y, grid_x = torch.meshgrid(y_range, x_range, indexing='ij') + grid_xy = torch.stack([grid_x, grid_y], dim=-1).view([-1, 2]) + bias # [h*w, 2] + grid_xy = grid_xy[None, ...].repeat(n_frames, 1, 1) # [f, h*w, 2] + return grid_xy + + +def get_plucker_embeddings( + c2ws_mat: torch.Tensor, + Ks: torch.Tensor, + height: int, + width: int, +): + n_frames = c2ws_mat.shape[0] + grid_xy = create_meshgrid(n_frames, height, width, device=c2ws_mat.device, dtype=c2ws_mat.dtype) # [f, h*w, 2] + fx, fy, cx, cy = Ks.chunk(4, dim=-1) # [f, 1] + + i = grid_xy[..., 0] # [f, h*w] + j = grid_xy[..., 1] # [f, h*w] + zs = torch.ones_like(i) # [f, h*w] + xs = (i - cx) / fx * zs + ys = (j - cy) / fy * zs + + directions = torch.stack([xs, ys, zs], dim=-1) # [f, h*w, 3] + directions = directions / directions.norm(dim=-1, keepdim=True) # [f, h*w, 3] + + rays_d = directions @ c2ws_mat[:, :3, :3].transpose(-1, -2) # [f, h*w, 3] + rays_o = c2ws_mat[:, :3, 3] # [f, 3] + rays_o = rays_o[:, None, :].expand_as(rays_d) # [f, h*w, 3] + # rays_dxo = torch.cross(rays_o, rays_d, dim=-1) # [f, h*w, 3] + # note refer to: apt2 + plucker_embeddings = torch.cat([rays_o, rays_d], dim=-1) # [f, h*w, 6] + plucker_embeddings = plucker_embeddings.view([n_frames, height, width, 6]) # [f*h*w, 6] + return plucker_embeddings + + +def get_Ks_transformed( + Ks: torch.Tensor, + height_org: int, + width_org: int, + height_resize: int, + width_resize: int, + height_final: int, + width_final: int, +): + fx, fy, cx, cy = 
Ks.chunk(4, dim=-1) # [f, 1] + + scale_x = width_resize / width_org + scale_y = height_resize / height_org + + fx_resize = fx * scale_x + fy_resize = fy * scale_y + cx_resize = cx * scale_x + cy_resize = cy * scale_y + + crop_offset_x = (width_resize - width_final) / 2 + crop_offset_y = (height_resize - height_final) / 2 + + cx_final = cx_resize - crop_offset_x + cy_final = cy_resize - crop_offset_y + + Ks_transformed = torch.zeros_like(Ks) + Ks_transformed[:, 0:1] = fx_resize + Ks_transformed[:, 1:2] = fy_resize + Ks_transformed[:, 2:3] = cx_final + Ks_transformed[:, 3:4] = cy_final + + return Ks_transformed + + +# --- Custom --- + +def prepare_camera_embedding( + action_path: str, + num_frames: int, + height: int, + width: int, + spatial_scale: int = 8, +) -> tuple[torch.Tensor, int]: + c2ws = np.load(os.path.join(action_path, "poses.npy")) + len_c2ws = ((len(c2ws) - 1) // 4) * 4 + 1 + num_frames = min(num_frames, len_c2ws) + c2ws = c2ws[:num_frames] + + Ks = torch.from_numpy( + np.load(os.path.join(action_path, "intrinsics.npy")) + ).float() + Ks = get_Ks_transformed( + Ks, + height_org=480, + width_org=832, + height_resize=height, + width_resize=width, + height_final=height, + width_final=width, + ) + Ks = Ks[0] # use first frame + + len_c2ws = len(c2ws) + num_latent_frames = (len_c2ws - 1) // 4 + 1 + c2ws_infer = interpolate_camera_poses( + src_indices=np.linspace(0, len_c2ws - 1, len_c2ws), + src_rot_mat=c2ws[:, :3, :3], + src_trans_vec=c2ws[:, :3, 3], + tgt_indices=np.linspace(0, len_c2ws - 1, num_latent_frames), + ) + c2ws_infer = compute_relative_poses(c2ws_infer, framewise=True) + Ks = Ks.repeat(num_latent_frames, 1) + plucker = get_plucker_embeddings(c2ws_infer, Ks, height, width) # [F, H, W, 6] + + # reshpae + latent_height = height // spatial_scale + latent_width = width // spatial_scale + plucker = plucker.view(num_latent_frames, latent_height, spatial_scale, latent_width, spatial_scale, 6) + plucker = plucker.permute(0, 1, 3, 5, 2, 4).contiguous() + 
plucker = plucker.view(num_latent_frames, latent_height, latent_width, 6 * spatial_scale * spatial_scale) + c2ws_plucker_emb = plucker.permute(3, 0, 1, 2).contiguous().unsqueeze(0) + + return c2ws_plucker_emb, num_frames \ No newline at end of file diff --git a/fastvideo/models/dits/wangame_lingbot/model.py b/fastvideo/models/dits/wangame_lingbot/model.py new file mode 100644 index 000000000..a19c9b351 --- /dev/null +++ b/fastvideo/models/dits/wangame_lingbot/model.py @@ -0,0 +1,451 @@ +# SPDX-License-Identifier: Apache-2.0 + +import math +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fastvideo.attention import DistributedAttention +from fastvideo.configs.models.dits.wangamevideo import WanGameVideoConfig +from fastvideo.distributed.parallel_state import get_sp_world_size +from fastvideo.layers.layernorm import (FP32LayerNorm, LayerNormScaleShift, + RMSNorm, ScaleResidual, + ScaleResidualLayerNormScaleShift) +from fastvideo.layers.linear import ReplicatedLinear +from fastvideo.layers.mlp import MLP +from fastvideo.layers.rotary_embedding import get_rotary_pos_embed +from fastvideo.layers.visual_embedding import (PatchEmbed, + WanCamControlPatchEmbedding) +from fastvideo.logger import init_logger +from fastvideo.models.dits.base import BaseDiT +from fastvideo.models.dits.wanvideo import (WanI2VCrossAttention, + WanTimeTextImageEmbedding) +from fastvideo.platforms import AttentionBackendEnum, current_platform + + +logger = init_logger(__name__) + + +class LingBotWorldCamConditioner(nn.Module): + + def __init__(self, dim: int) -> None: + super().__init__() + self.cam_injector = MLP(dim, dim, dim, bias=True, act_type="silu") + self.cam_scale_layer = nn.Linear(dim, dim) + self.cam_shift_layer = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + c2ws_plucker_emb: torch.Tensor | None, + ) -> torch.Tensor: + if c2ws_plucker_emb is None: + return hidden_states + assert c2ws_plucker_emb.shape 
== hidden_states.shape, ( + f"c2ws_plucker_emb shape must match hidden_states shape, got " + f"{tuple(c2ws_plucker_emb.shape)} vs {tuple(hidden_states.shape)}" + ) + c2ws_hidden_states = self.cam_injector(c2ws_plucker_emb) + c2ws_hidden_states = c2ws_hidden_states + c2ws_plucker_emb + cam_scale = self.cam_scale_layer(c2ws_hidden_states) + cam_shift = self.cam_shift_layer(c2ws_hidden_states) + return (1.0 + cam_scale) * hidden_states + cam_shift + + +class WanGameCrossAttention(WanI2VCrossAttention): + def forward(self, x, context, context_lens=None): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + context_img = context + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.to_q(x)[0]).view(b, -1, n, d) + k_img = self.norm_added_k(self.add_k_proj(context_img)[0]).view( + b, -1, n, d) + v_img = self.add_v_proj(context_img)[0].view(b, -1, n, d) + img_x = self.attn(q, k_img, v_img) + + # output + x = img_x.flatten(2) + x, _ = self.to_out(x) + return x + + +class WanGameActionTransformerBlock(nn.Module): + """ + Transformer block for WAN Action model with support for: + - Self-attention with RoPE and camera PRoPE + - Cross-attention with text/image context + - Feed-forward network with AdaLN modulation + """ + + def __init__(self, + dim: int, + ffn_dim: int, + num_heads: int, + local_attn_size: int = -1, + sink_size: int = 0, + qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: int | None = None, + supported_attention_backends: tuple[AttentionBackendEnum, ...] | None = None, + prefix: str = ""): + super().__init__() + + # 1. 
Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.to_q = ReplicatedLinear(dim, dim, bias=True) + self.to_k = ReplicatedLinear(dim, dim, bias=True) + self.to_v = ReplicatedLinear(dim, dim, bias=True) + + self.to_out = ReplicatedLinear(dim, dim, bias=True) + self.attn1 = DistributedAttention( + num_heads=num_heads, + head_size=dim // num_heads, + causal=False, + supported_attention_backends=supported_attention_backends, + prefix=f"{prefix}.attn1") + self.hidden_dim = dim + self.num_attention_heads = num_heads + self.local_attn_size = local_attn_size + dim_head = dim // num_heads + if qk_norm == "rms_norm": + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + elif qk_norm == "rms_norm_across_heads": + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + else: + print("QK Norm type not supported") + raise Exception + assert cross_attn_norm is True + self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift( + dim, + norm_type="layer", + eps=eps, + elementwise_affine=True, + dtype=torch.float32, + compute_dtype=torch.float32) + + # 2. Cross-attention (I2V only for now) + self.attn2 = WanGameCrossAttention(dim, + num_heads, + qk_norm=qk_norm, + eps=eps) + # norm3 for FFN input + self.norm3 = LayerNormScaleShift(dim, norm_type="layer", eps=eps, + elementwise_affine=False) + + # 3. 
Feed-forward + self.ffn = MLP(dim, ffn_dim, act_type="gelu_pytorch_tanh") + self.mlp_residual = ScaleResidual() + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + self.cam_conditioner = LingBotWorldCamConditioner(dim) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + freqs_cis: tuple[torch.Tensor, torch.Tensor], + kv_cache: dict | None = None, + crossattn_cache: dict | None = None, + current_start: int = 0, + cache_start: int | None = None, + viewmats: torch.Tensor | None = None, + Ks: torch.Tensor | None = None, + c2ws_plucker_emb: torch.Tensor | None = None, + is_cache: bool = False, + ) -> torch.Tensor: + if hidden_states.dim() == 4: + hidden_states = hidden_states.squeeze(1) + + num_frames = temb.shape[1] + frame_seqlen = hidden_states.shape[1] // num_frames + bs, seq_length, _ = hidden_states.shape + orig_dtype = hidden_states.dtype + + # Cast temb to float32 for scale/shift computation + e = self.scale_shift_table + temb.float() + assert e.shape == (bs, num_frames, 6, self.hidden_dim) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = e.chunk(6, dim=2) + + # 1. 
Self-attention + norm_hidden_states = (self.norm1(hidden_states.float()).unflatten( + dim=1, sizes=(num_frames, frame_seqlen)) * + (1 + scale_msa) + shift_msa).to(orig_dtype).flatten(1, 2) + query, _ = self.to_q(norm_hidden_states) + key, _ = self.to_k(norm_hidden_states) + value, _ = self.to_v(norm_hidden_states) + + if self.norm_q is not None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) + key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) + value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) + + attn_output, _ = self.attn1(query, key, value, freqs_cis=freqs_cis) + attn_output = attn_output.flatten(2) + attn_output, _ = self.to_out(attn_output) + attn_output = attn_output.squeeze(1) + + # Self-attention residual + norm in float32 + null_shift = null_scale = torch.tensor([0], device=hidden_states.device) + norm_hidden_states, hidden_states = self.self_attn_residual_norm( + hidden_states, attn_output, gate_msa, null_shift, null_scale) + norm_hidden_states, hidden_states = norm_hidden_states.to( + orig_dtype), hidden_states.to(orig_dtype) + # Inject camera condition + # must be applied after the self-attention residual update. + hidden_states = self.cam_conditioner(hidden_states, c2ws_plucker_emb) + norm_hidden_states = self.self_attn_residual_norm.norm(hidden_states) + norm_hidden_states = norm_hidden_states.to(orig_dtype) + + + # 2. Cross-attention + attn_output = self.attn2(norm_hidden_states.to(orig_dtype), + context=encoder_hidden_states, + context_lens=None) + # Cross-attention residual in bfloat16 + hidden_states = hidden_states + attn_output + + # norm3 for FFN input in float32 + norm_hidden_states = self.norm3( + hidden_states.float(), c_shift_msa, c_scale_msa + ).type_as(hidden_states) + + # 3. 
Feed-forward + ff_output = self.ffn(norm_hidden_states.to(orig_dtype)) + hidden_states = self.mlp_residual(hidden_states.float(), ff_output.float(), c_gate_msa) + hidden_states = hidden_states.to(orig_dtype) # Cast back to original dtype + + return hidden_states + +class WanLingBotTransformer3DModel(BaseDiT): + """ + WAN Action Transformer 3D Model for video generation with action conditioning. + + Extends the base WAN video model with: + - Action embedding support for controllable generation + - camera PRoPE attention for 3D-aware generation + - KV caching for autoregressive inference + """ + _fsdp_shard_conditions = WanGameVideoConfig()._fsdp_shard_conditions + _compile_conditions = WanGameVideoConfig()._compile_conditions + _supported_attention_backends = WanGameVideoConfig()._supported_attention_backends + param_names_mapping = WanGameVideoConfig().param_names_mapping + reverse_param_names_mapping = WanGameVideoConfig().reverse_param_names_mapping + lora_param_names_mapping = WanGameVideoConfig().lora_param_names_mapping + + def __init__(self, config: WanGameVideoConfig, hf_config: dict[str, Any]) -> None: + super().__init__(config=config, hf_config=hf_config) + + inner_dim = config.num_attention_heads * config.attention_head_dim + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.attention_head_dim = config.attention_head_dim + self.in_channels = config.in_channels + self.out_channels = config.out_channels + self.num_channels_latents = config.num_channels_latents + self.patch_size = config.patch_size + self.local_attn_size = config.local_attn_size + self.inner_dim = inner_dim + + # 1. 
Patch & position embedding + self.patch_embedding = PatchEmbed(in_chans=config.in_channels, + embed_dim=inner_dim, + patch_size=config.patch_size, + flatten=False) + self.patch_embedding_wancamctrl = WanCamControlPatchEmbedding(in_chans=6 * 64, + embed_dim=inner_dim, + patch_size=config.patch_size) + self.c2ws_mlp = MLP(inner_dim, inner_dim, inner_dim, bias=True, act_type="silu") + + # 2. Condition embeddings (image-only) + self.condition_embedder = WanTimeTextImageEmbedding( + dim=inner_dim, + time_freq_dim=config.freq_dim, + text_embed_dim=0, + image_embed_dim=config.image_dim, + ) + + # 3. Transformer blocks + self.blocks = nn.ModuleList([ + WanGameActionTransformerBlock( + inner_dim, + config.ffn_dim, + config.num_attention_heads, + config.local_attn_size, + config.sink_size, + config.qk_norm, + config.cross_attn_norm, + config.eps, + config.added_kv_proj_dim, + supported_attention_backends=self._supported_attention_backends, + prefix=f"{config.prefix}.blocks.{i}") + for i in range(config.num_layers) + ]) + + # 4. 
Output norm & projection + self.norm_out = LayerNormScaleShift(inner_dim, + norm_type="layer", + eps=config.eps, + elementwise_affine=False, + dtype=torch.float32) + self.proj_out = nn.Linear( + inner_dim, config.out_channels * math.prod(config.patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) + + self.gradient_checkpointing = False + + # Causal-specific + self.num_frame_per_block = config.arch_config.num_frames_per_block + assert self.num_frame_per_block <= 3 + + self.__post_init__() + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | list[torch.Tensor], + timestep: torch.LongTensor, + encoder_hidden_states_image: torch.Tensor | list[torch.Tensor], + guidance=None, + action: torch.Tensor | None = None, + viewmats: torch.Tensor | None = None, + Ks: torch.Tensor | None = None, + c2ws_plucker_emb: torch.Tensor | None = None, + kv_cache: list[dict] | None = None, + crossattn_cache: list[dict] | None = None, + current_start: int = 0, + cache_start: int = 0, + start_frame: int = 0, + is_cache: bool = False, + **kwargs + ) -> torch.Tensor: + """ + Forward pass for both training and inference with KV caching. 
+ + Args: + hidden_states: Video latents [B, C, T, H, W] + encoder_hidden_states: Text embeddings [B, L, D] + timestep: Timestep tensor + encoder_hidden_states_image: Optional image embeddings + action: Action tensor [B, T] for per-frame conditioning + viewmats: Camera view matrices for PRoPE [B, T, 4, 4] + Ks: Camera intrinsics for PRoPE [B, T, 3, 3] + c2ws_plucker_emb: Camera plucker embedding [B, C, T, H, W] + kv_cache: KV cache for autoregressive inference (list of dicts per layer) + crossattn_cache: Cross-attention cache for inference + current_start: Current position for KV cache + cache_start: Cache start position + start_frame: RoPE offset for new frames in autoregressive mode + is_cache: If True, populate KV cache and return early (cache-only mode) + """ + orig_dtype = hidden_states.dtype + if isinstance(encoder_hidden_states, list) and len(encoder_hidden_states) > 0: + encoder_hidden_states = encoder_hidden_states[0] + if isinstance(encoder_hidden_states_image, list) and len(encoder_hidden_states_image) > 0: + encoder_hidden_states_image = encoder_hidden_states_image[0] + # else: + # encoder_hidden_states_image = None + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # Get rotary embeddings + d = self.hidden_size // self.num_attention_heads + rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)] + freqs_cos, freqs_sin = get_rotary_pos_embed( + (post_patch_num_frames * get_sp_world_size(), post_patch_height, post_patch_width), + self.hidden_size, + self.num_attention_heads, + rope_dim_list, + dtype=torch.float32 if current_platform.is_mps() else torch.float64, + rope_theta=10000, + start_frame=start_frame + ) + freqs_cos = freqs_cos.to(hidden_states.device) + freqs_sin = freqs_sin.to(hidden_states.device) + freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else 
None + + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + c2ws_hidden_states = None + if c2ws_plucker_emb is not None: + c2ws_plucker_emb = self.patch_embedding_wancamctrl( + c2ws_plucker_emb.to(device=hidden_states.device, dtype=hidden_states.dtype) + ) + c2ws_hidden_states = self.c2ws_mlp(c2ws_plucker_emb) + c2ws_plucker_emb = c2ws_plucker_emb + c2ws_hidden_states + + if timestep.dim() == 1: + timestep = timestep.unsqueeze(1).expand(-1, post_patch_num_frames) + if timestep.dim() == 2: + timestep = timestep.flatten() + + if encoder_hidden_states is None or ( + isinstance(encoder_hidden_states, torch.Tensor) + and encoder_hidden_states.numel() == 0): + encoder_hidden_states = hidden_states.new_zeros((batch_size, 0, self.hidden_size)) + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image=encoder_hidden_states_image) + timestep_proj = timestep_proj.unflatten(1, (6, self.hidden_size)) + timestep_proj = timestep_proj.view(batch_size, post_patch_num_frames, 6, + self.hidden_size) + + encoder_hidden_states = encoder_hidden_states_image + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states.new_zeros((batch_size, 0, self.hidden_size)) + + # Transformer blocks + for block_idx, block in enumerate(self.blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, freqs_cis, + kv_cache[block_idx] if kv_cache else None, + crossattn_cache[block_idx] if crossattn_cache else None, + current_start, cache_start, + viewmats, Ks, c2ws_hidden_states, is_cache) + else: + hidden_states = block( + hidden_states, encoder_hidden_states, timestep_proj, freqs_cis, + kv_cache[block_idx] if kv_cache else None, + crossattn_cache[block_idx] if crossattn_cache else None, + 
current_start, cache_start, + viewmats, Ks, c2ws_hidden_states, is_cache) + + # If cache-only mode, return early + if is_cache: + return kv_cache + + # Output norm, projection & unpatchify + temb = temb.view(batch_size, post_patch_num_frames, -1).unsqueeze(2) + + shift, scale = (self.scale_shift_table.unsqueeze(1) + temb).chunk(2, dim=2) + hidden_states = self.norm_out(hidden_states, shift, scale) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, + post_patch_height, + post_patch_width, p_t, p_h, p_w, + -1) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + return output diff --git a/fastvideo/models/loader/component_loader.py b/fastvideo/models/loader/component_loader.py index 6ee1b28c3..35dccb014 100644 --- a/fastvideo/models/loader/component_loader.py +++ b/fastvideo/models/loader/component_loader.py @@ -23,7 +23,6 @@ from fastvideo.fastvideo_args import FastVideoArgs from fastvideo.layers.quantization import get_quantization_config from fastvideo.logger import init_logger -from fastvideo.models.encoders.base import TextEncoder from fastvideo.models.hf_transformer_utils import get_diffusers_config from fastvideo.models.loader.fsdp_load import maybe_load_fsdp_model, shard_model from fastvideo.models.loader.utils import set_default_torch_dtype @@ -268,11 +267,12 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs): gemma_path = candidate gemma_path_from_candidate = True model_config["gemma_model_path"] = gemma_path - if gemma_path and not gemma_path_from_candidate: - if not os.path.isabs(gemma_path): - model_config["gemma_model_path"] = os.path.normpath( - os.path.join(repo_root, gemma_path) - ) + if gemma_path and not gemma_path_from_candidate and not os.path.isabs( + gemma_path + ): + model_config["gemma_model_path"] = os.path.normpath( + os.path.join(repo_root, gemma_path) + ) 
transformer_config_path = os.path.join( repo_root, "transformer", "config.json" ) @@ -280,12 +280,11 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs): try: with open(transformer_config_path, encoding="utf-8") as f: transformer_config = json.load(f) - if ( + if (( "connector_double_precision_rope" not in model_config or not model_config["connector_double_precision_rope"] - ): - if transformer_config.get("double_precision_rope") is True: - model_config["connector_double_precision_rope"] = True + ) and transformer_config.get("double_precision_rope") is True): + model_config["connector_double_precision_rope"] = True if "connector_rope_type" not in model_config: rope_type = transformer_config.get("rope_type") if rope_type is not None: @@ -539,7 +538,7 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs): tokenizer_cfg_path = os.path.join(resolved_model_path, "config.json") if os.path.exists(tokenizer_cfg_path): try: - with open(tokenizer_cfg_path, "r") as f: + with open(tokenizer_cfg_path) as f: tokenizer_cfg = json.load(f) if isinstance(tokenizer_cfg, dict) and ( tokenizer_cfg.get("_class_name") == "AutoProcessor" @@ -844,6 +843,15 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs): cls_name.startswith("Cosmos25") or cls_name == "Cosmos25Transformer3DModel" or getattr(fastvideo_args.pipeline_config, "prefix", "") == "Cosmos25" + ) and not ( + cls_name.startswith("WanGame") + or cls_name == "WanGameActionTransformer3DModel" + or cls_name.startswith("CausalWan") + or getattr(fastvideo_args.pipeline_config, "prefix", "") == "WanGame" + or cls_name.startswith("WanLingBot") + or cls_name == "WanLingBotTransformer3DModel" + or getattr(fastvideo_args.pipeline_config, "prefix", "") == "WanLingBot" + or cls_name.startswith("CausalWanGameActionTransformer3DModel") ) model = maybe_load_fsdp_model( model_cls=model_cls, @@ -928,7 +936,7 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs): try: upsampler_cfg = 
deepcopy(fastvideo_args.pipeline_config.upsampler_config[0]) upsampler_cfg.update_model_config(config_dict) - except Exception as e: + except Exception: upsampler_cfg = deepcopy(fastvideo_args.pipeline_config.upsampler_config[1]) upsampler_cfg.update_model_config(config_dict) diff --git a/fastvideo/models/loader/fsdp_load.py b/fastvideo/models/loader/fsdp_load.py index 9ba60320a..d9a3b6150 100644 --- a/fastvideo/models/loader/fsdp_load.py +++ b/fastvideo/models/loader/fsdp_load.py @@ -138,7 +138,7 @@ def maybe_load_fsdp_model( weight_iterator = safetensors_weights_iterator(weight_dir_list) param_names_mapping_fn = get_param_names_mapping(model.param_names_mapping) - load_model_from_full_model_state_dict( + incompatible_keys, unexpected_keys = load_model_from_full_model_state_dict( model, weight_iterator, device, @@ -147,6 +147,9 @@ def maybe_load_fsdp_model( cpu_offload=cpu_offload, param_names_mapping=param_names_mapping_fn, ) + if incompatible_keys or unexpected_keys: + logger.warning("Incompatible keys: %s", incompatible_keys) + logger.warning("Unexpected keys: %s", unexpected_keys) for n, p in chain(model.named_parameters(), model.named_buffers()): if p.is_meta: raise RuntimeError( @@ -339,8 +342,19 @@ def load_model_from_full_model_state_dict( logger.warning("Found unloaded parameters in meta state dict: %s", unused_keys) - # List of allowed parameter name patterns - ALLOWED_NEW_PARAM_PATTERNS = ["gate_compress", "proj_l"] # Can be extended as needed + # List of allowed parameter name patterns (whitelist for new params not in checkpoint) + ALLOWED_NEW_PARAM_PATTERNS = [ + "gate_compress", + "proj_l", + "to_out_prope", + "action_embedder", + "patch_embedding_wancamctrl", + "cam_conditioner", + ] # Can be extended as needed + + # Patterns for params that need kaiming_uniform init (input projections need non-zero for gradient flow) + KAIMING_INIT_PATTERNS = ["fc_in.weight"] + for new_param_name in unused_keys: if not any(pattern in new_param_name for pattern in 
ALLOWED_NEW_PARAM_PATTERNS): @@ -350,17 +364,31 @@ def load_model_from_full_model_state_dict( f"New parameter '{new_param_name}' is not supported. " f"Currently only parameters containing {ALLOWED_NEW_PARAM_PATTERNS} are allowed." ) + + # Check if this param needs kaiming init (non-zero) for gradient flow + use_kaiming = any(pattern in new_param_name for pattern in KAIMING_INIT_PATTERNS) + meta_sharded_param = meta_sd.get(new_param_name) if not hasattr(meta_sharded_param, "device_mesh"): - # Initialize with zeros - sharded_tensor = torch.zeros_like(meta_sharded_param, - device=device, - dtype=param_dtype) + # Non-sharded tensor + if use_kaiming: + import math + sharded_tensor = torch.empty_like(meta_sharded_param, device=device, dtype=param_dtype) + nn.init.kaiming_uniform_(sharded_tensor, a=math.sqrt(5)) + logger.info(f"Initialized {new_param_name} with kaiming_uniform_") + else: + # Initialize with zeros (output projections for residual behavior) + sharded_tensor = torch.zeros_like(meta_sharded_param, device=device, dtype=param_dtype) else: - # Initialize with zeros and distribute - full_tensor = torch.zeros_like(meta_sharded_param, - device=device, - dtype=param_dtype) + # Sharded tensor (DTensor) + if use_kaiming: + import math + full_tensor = torch.empty_like(meta_sharded_param, device=device, dtype=param_dtype) + nn.init.kaiming_uniform_(full_tensor, a=math.sqrt(5)) + logger.info(f"Initialized {new_param_name} with kaiming_uniform_") + else: + # Initialize with zeros and distribute + full_tensor = torch.zeros_like(meta_sharded_param, device=device, dtype=param_dtype) sharded_tensor = distribute_tensor( full_tensor, meta_sharded_param.device_mesh, diff --git a/fastvideo/models/registry.py b/fastvideo/models/registry.py index a22a582e0..d9f2e8691 100644 --- a/fastvideo/models/registry.py +++ b/fastvideo/models/registry.py @@ -46,6 +46,12 @@ # "HunyuanVideoTransformer3DModel": ("dits", "hunyuanvideo", "HunyuanVideoDiT"), "WanTransformer3DModel": ("dits", 
"wanvideo", "WanTransformer3DModel"), "CausalWanTransformer3DModel": ("dits", "causal_wanvideo", "CausalWanTransformer3DModel"), + "CausalWanGameTransformer3DModel": + ("dits", "wangame", "CausalWanGameActionTransformer3DModel"), + "CausalWanGameActionTransformer3DModel": + ("dits", "wangame", "CausalWanGameActionTransformer3DModel"), + "WanGameActionTransformer3DModel": ("dits", "wangame", "WanGameActionTransformer3DModel"), + "WanLingBotTransformer3DModel": ("dits", "wangame_lingbot", "WanLingBotTransformer3DModel"), "MatrixGameWanModel": ("dits", "matrixgame", "MatrixGameWanModel"), "CausalMatrixGameWanModel": ("dits", "matrixgame", "CausalMatrixGameWanModel"), } diff --git a/fastvideo/pipelines/basic/wan/wan_dmd_pipeline.py b/fastvideo/pipelines/basic/wan/wan_dmd_pipeline.py index 7cffcb0f3..7a5106126 100644 --- a/fastvideo/pipelines/basic/wan/wan_dmd_pipeline.py +++ b/fastvideo/pipelines/basic/wan/wan_dmd_pipeline.py @@ -1,74 +1,27 @@ # SPDX-License-Identifier: Apache-2.0 """ -Wan video diffusion pipeline implementation. +Legacy Wan DMD pipeline entrypoint. -This module contains an implementation of the Wan video diffusion pipeline -using the modular pipeline architecture. +Historically FastVideo exposed a dedicated `WanDMDPipeline` class that wired a +stochastic (SDE-style) denoising loop. Phase 3.2 makes sampling loop selection +explicit via `pipeline_config.sampler_kind`, so this file becomes a thin +compatibility wrapper around `WanPipeline`. 
""" from fastvideo.fastvideo_args import FastVideoArgs -from fastvideo.logger import init_logger -from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import ( - FlowMatchEulerDiscreteScheduler) -from fastvideo.pipelines import ComposedPipelineBase, LoRAPipeline +from fastvideo.pipelines.basic.wan.wan_pipeline import WanPipeline -# isort: off -from fastvideo.pipelines.stages import (ConditioningStage, DecodingStage, - DmdDenoisingStage, InputValidationStage, - LatentPreparationStage, - TextEncodingStage, - TimestepPreparationStage) -# isort: on -logger = init_logger(__name__) - - -class WanDMDPipeline(LoRAPipeline, ComposedPipelineBase): - """ - Wan video diffusion pipeline with LoRA support. - """ - - _required_config_modules = [ - "text_encoder", "tokenizer", "vae", "transformer", "scheduler" - ] +class WanDMDPipeline(WanPipeline): + """Compatibility wrapper for SDE sampling on Wan.""" def initialize_pipeline(self, fastvideo_args: FastVideoArgs): - - self.modules["scheduler"] = FlowMatchEulerDiscreteScheduler( - shift=fastvideo_args.pipeline_config.flow_shift) + fastvideo_args.pipeline_config.sampler_kind = "sde" + return super().initialize_pipeline(fastvideo_args) def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None: - """Set up pipeline stages with proper dependency injection.""" - - self.add_stage(stage_name="input_validation_stage", - stage=InputValidationStage()) - - self.add_stage(stage_name="prompt_encoding_stage", - stage=TextEncodingStage( - text_encoders=[self.get_module("text_encoder")], - tokenizers=[self.get_module("tokenizer")], - )) - - self.add_stage(stage_name="conditioning_stage", - stage=ConditioningStage()) - - self.add_stage(stage_name="timestep_preparation_stage", - stage=TimestepPreparationStage( - scheduler=self.get_module("scheduler"))) - - self.add_stage(stage_name="latent_preparation_stage", - stage=LatentPreparationStage( - scheduler=self.get_module("scheduler"), - 
transformer=self.get_module("transformer", None), - use_btchw_layout=True)) - - self.add_stage(stage_name="denoising_stage", - stage=DmdDenoisingStage( - transformer=self.get_module("transformer"), - scheduler=self.get_module("scheduler"))) - - self.add_stage(stage_name="decoding_stage", - stage=DecodingStage(vae=self.get_module("vae"))) + fastvideo_args.pipeline_config.sampler_kind = "sde" + return super().create_pipeline_stages(fastvideo_args) EntryClass = WanDMDPipeline diff --git a/fastvideo/pipelines/basic/wan/wan_i2v_dmd_pipeline.py b/fastvideo/pipelines/basic/wan/wan_i2v_dmd_pipeline.py index ed4d870c6..dd3ff1538 100644 --- a/fastvideo/pipelines/basic/wan/wan_i2v_dmd_pipeline.py +++ b/fastvideo/pipelines/basic/wan/wan_i2v_dmd_pipeline.py @@ -12,10 +12,12 @@ from fastvideo.pipelines.lora_pipeline import LoRAPipeline # isort: off -from fastvideo.pipelines.stages import ( - ImageEncodingStage, ConditioningStage, DecodingStage, DmdDenoisingStage, - ImageVAEEncodingStage, InputValidationStage, LatentPreparationStage, - TextEncodingStage, TimestepPreparationStage) +from fastvideo.pipelines.stages import (ImageEncodingStage, ConditioningStage, + DecodingStage, DmdDenoisingStage, + ImageVAEEncodingStage, + InputValidationStage, + LatentPreparationStage, + TextEncodingStage) # isort: on from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import ( FlowMatchEulerDiscreteScheduler) @@ -55,10 +57,6 @@ def create_pipeline_stages(self, fastvideo_args: FastVideoArgs): self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage()) - self.add_stage(stage_name="timestep_preparation_stage", - stage=TimestepPreparationStage( - scheduler=self.get_module("scheduler"))) - self.add_stage(stage_name="latent_preparation_stage", stage=LatentPreparationStage( scheduler=self.get_module("scheduler"), diff --git a/fastvideo/pipelines/basic/wan/wan_pipeline.py b/fastvideo/pipelines/basic/wan/wan_pipeline.py index 64c4a0685..78c2e4dad 100644 --- 
a/fastvideo/pipelines/basic/wan/wan_pipeline.py +++ b/fastvideo/pipelines/basic/wan/wan_pipeline.py @@ -8,13 +8,16 @@ from fastvideo.fastvideo_args import FastVideoArgs from fastvideo.logger import init_logger -from fastvideo.models.schedulers.scheduling_flow_unipc_multistep import ( - FlowUniPCMultistepScheduler) +from fastvideo.pipelines.samplers.wan import ( + build_wan_scheduler, + get_wan_sampler_kind, + wan_use_btchw_layout, +) from fastvideo.pipelines import ComposedPipelineBase, LoRAPipeline from fastvideo.pipelines.stages import (ConditioningStage, DecodingStage, DenoisingStage, InputValidationStage, LatentPreparationStage, - TextEncodingStage, + SdeDenoisingStage, TextEncodingStage, TimestepPreparationStage) logger = init_logger(__name__) @@ -30,12 +33,14 @@ class WanPipeline(LoRAPipeline, ComposedPipelineBase): ] def initialize_pipeline(self, fastvideo_args: FastVideoArgs): - # We use UniPCMScheduler from Wan2.1 official repo, not the one in diffusers. - self.modules["scheduler"] = FlowUniPCMultistepScheduler( - shift=fastvideo_args.pipeline_config.flow_shift) + sampler_kind = get_wan_sampler_kind(fastvideo_args) + self.modules["scheduler"] = build_wan_scheduler(fastvideo_args, + sampler_kind) def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None: """Set up pipeline stages with proper dependency injection.""" + sampler_kind = get_wan_sampler_kind(fastvideo_args) + use_btchw_layout = wan_use_btchw_layout(sampler_kind) self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage()) @@ -49,22 +54,32 @@ def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None: self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage()) - self.add_stage(stage_name="timestep_preparation_stage", - stage=TimestepPreparationStage( - scheduler=self.get_module("scheduler"))) + if sampler_kind == "ode": + self.add_stage(stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage( + 
scheduler=self.get_module("scheduler"))) self.add_stage(stage_name="latent_preparation_stage", stage=LatentPreparationStage( scheduler=self.get_module("scheduler"), - transformer=self.get_module("transformer", None))) - - self.add_stage(stage_name="denoising_stage", - stage=DenoisingStage( - transformer=self.get_module("transformer"), - transformer_2=self.get_module("transformer_2", None), - scheduler=self.get_module("scheduler"), - vae=self.get_module("vae"), - pipeline=self)) + transformer=self.get_module("transformer", None), + use_btchw_layout=use_btchw_layout)) + + if sampler_kind == "sde": + self.add_stage(stage_name="denoising_stage", + stage=SdeDenoisingStage( + transformer=self.get_module("transformer"), + scheduler=self.get_module("scheduler"), + )) + else: + self.add_stage(stage_name="denoising_stage", + stage=DenoisingStage( + transformer=self.get_module("transformer"), + transformer_2=self.get_module( + "transformer_2", None), + scheduler=self.get_module("scheduler"), + vae=self.get_module("vae"), + pipeline=self)) self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"), diff --git a/fastvideo/pipelines/basic/wan/wangame_causal_dmd_pipeline.py b/fastvideo/pipelines/basic/wan/wangame_causal_dmd_pipeline.py new file mode 100644 index 000000000..af6191b4c --- /dev/null +++ b/fastvideo/pipelines/basic/wan/wangame_causal_dmd_pipeline.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 +"""WanGame causal DMD pipeline implementation.""" + +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.logger import init_logger +from fastvideo.pipelines import ComposedPipelineBase, LoRAPipeline +from fastvideo.pipelines.samplers.wan import get_wan_sampler_kind + +from fastvideo.pipelines.stages import ( + ConditioningStage, DecodingStage, MatrixGameCausalDenoisingStage, + MatrixGameCausalOdeDenoisingStage, MatrixGameImageEncodingStage, + InputValidationStage, LatentPreparationStage, TextEncodingStage, + 
TimestepPreparationStage) +from fastvideo.pipelines.stages.image_encoding import ( + MatrixGameImageVAEEncodingStage) + +logger = init_logger(__name__) + + +class WanGameCausalDMDPipeline(LoRAPipeline, ComposedPipelineBase): + _required_config_modules = [ + "vae", "transformer", "scheduler", "image_encoder", "image_processor" + ] + + def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None: + sampler_kind = get_wan_sampler_kind(fastvideo_args) + self.add_stage(stage_name="input_validation_stage", + stage=InputValidationStage()) + + if (self.get_module("text_encoder", None) is not None + and self.get_module("tokenizer", None) is not None): + self.add_stage(stage_name="prompt_encoding_stage", + stage=TextEncodingStage( + text_encoders=[self.get_module("text_encoder")], + tokenizers=[self.get_module("tokenizer")], + )) + + if (self.get_module("image_encoder", None) is not None + and self.get_module("image_processor", None) is not None): + self.add_stage( + stage_name="image_encoding_stage", + stage=MatrixGameImageEncodingStage( + image_encoder=self.get_module("image_encoder"), + image_processor=self.get_module("image_processor"), + )) + + self.add_stage(stage_name="conditioning_stage", + stage=ConditioningStage()) + + if sampler_kind == "ode": + self.add_stage(stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage( + scheduler=self.get_module("scheduler"))) + + self.add_stage(stage_name="latent_preparation_stage", + stage=LatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer", None))) + + self.add_stage( + stage_name="image_latent_preparation_stage", + stage=MatrixGameImageVAEEncodingStage(vae=self.get_module("vae"))) + + if sampler_kind == "ode": + denoising_stage = MatrixGameCausalOdeDenoisingStage( + transformer=self.get_module("transformer"), + transformer_2=self.get_module("transformer_2", None), + scheduler=self.get_module("scheduler"), + pipeline=self, + 
vae=self.get_module("vae"), + ) + else: + denoising_stage = MatrixGameCausalDenoisingStage( + transformer=self.get_module("transformer"), + transformer_2=self.get_module("transformer_2", None), + scheduler=self.get_module("scheduler"), + pipeline=self, + vae=self.get_module("vae"), + ) + + self.add_stage(stage_name="denoising_stage", stage=denoising_stage) + + self.add_stage(stage_name="decoding_stage", + stage=DecodingStage(vae=self.get_module("vae"))) + + logger.info("WanGameCausalDMDPipeline initialized with action support") + + +EntryClass = WanGameCausalDMDPipeline diff --git a/fastvideo/pipelines/basic/wan/wangame_i2v_pipeline.py b/fastvideo/pipelines/basic/wan/wangame_i2v_pipeline.py new file mode 100644 index 000000000..307bf48f9 --- /dev/null +++ b/fastvideo/pipelines/basic/wan/wangame_i2v_pipeline.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +"""WanGame image-to-video pipeline implementation. + +This module contains an implementation of the WanGame image-to-video pipeline +using the modular pipeline architecture. 
+""" + +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.logger import init_logger +from fastvideo.pipelines.composed_pipeline_base import ComposedPipelineBase +from fastvideo.pipelines.lora_pipeline import LoRAPipeline +from fastvideo.pipelines.samplers.wan import ( + build_wan_scheduler, + get_wan_sampler_kind, + wan_use_btchw_layout, +) + +# isort: off +from fastvideo.pipelines.stages import ( + ConditioningStage, + DecodingStage, + DenoisingStage, + ImageEncodingStage, + ImageVAEEncodingStage, + InputValidationStage, + LatentPreparationStage, + SdeDenoisingStage, + TimestepPreparationStage, +) + +# isort: on + +logger = init_logger(__name__) + + +class WanGameActionImageToVideoPipeline(LoRAPipeline, ComposedPipelineBase): + + _required_config_modules = [ + "vae", + "transformer", + "scheduler", + "image_encoder", + "image_processor", + ] + + def initialize_pipeline(self, fastvideo_args: FastVideoArgs): + sampler_kind = get_wan_sampler_kind(fastvideo_args) + self.modules["scheduler"] = build_wan_scheduler(fastvideo_args, + sampler_kind) + + def create_pipeline_stages(self, fastvideo_args: FastVideoArgs): + """Set up pipeline stages with proper dependency injection.""" + + sampler_kind = get_wan_sampler_kind(fastvideo_args) + use_btchw_layout = wan_use_btchw_layout(sampler_kind) + + self.add_stage(stage_name="input_validation_stage", + stage=InputValidationStage()) + + self.add_stage( + stage_name="image_encoding_stage", + stage=ImageEncodingStage( + image_encoder=self.get_module("image_encoder"), + image_processor=self.get_module("image_processor"), + ), + ) + + self.add_stage(stage_name="conditioning_stage", + stage=ConditioningStage()) + + if sampler_kind == "ode": + self.add_stage(stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage( + scheduler=self.get_module("scheduler"))) + + self.add_stage(stage_name="latent_preparation_stage", + stage=LatentPreparationStage( + scheduler=self.get_module("scheduler"), + 
transformer=self.get_module("transformer"), + use_btchw_layout=use_btchw_layout)) + + self.add_stage(stage_name="image_latent_preparation_stage", + stage=ImageVAEEncodingStage(vae=self.get_module("vae"))) + + if sampler_kind == "sde": + self.add_stage(stage_name="denoising_stage", + stage=SdeDenoisingStage( + transformer=self.get_module("transformer"), + scheduler=self.get_module("scheduler"))) + else: + self.add_stage(stage_name="denoising_stage", + stage=DenoisingStage( + transformer=self.get_module("transformer"), + scheduler=self.get_module("scheduler"))) + + self.add_stage(stage_name="decoding_stage", + stage=DecodingStage(vae=self.get_module("vae"))) + + +class WanLingBotImageToVideoPipeline(WanGameActionImageToVideoPipeline): + pass + + +EntryClass = [WanGameActionImageToVideoPipeline, WanLingBotImageToVideoPipeline] diff --git a/fastvideo/pipelines/pipeline_batch_info.py b/fastvideo/pipelines/pipeline_batch_info.py index ee9086e82..b6065dc7a 100644 --- a/fastvideo/pipelines/pipeline_batch_info.py +++ b/fastvideo/pipelines/pipeline_batch_info.py @@ -160,6 +160,10 @@ class ForwardBatch: # Timesteps timesteps: torch.Tensor | None = None + # Optional explicit denoising-loop timesteps (sampler-specific). + # When set, some samplers (e.g. SDE-style rollout) will iterate this list + # instead of `timesteps` produced by `TimestepPreparationStage`. 
+ sampling_timesteps: torch.Tensor | None = None timestep: torch.Tensor | float | int | None = None step_index: int | None = None boundary_ratio: float | None = None diff --git a/fastvideo/pipelines/preprocess/v1_preprocess.py b/fastvideo/pipelines/preprocess/v1_preprocess.py index 18455d70f..e15e2239c 100644 --- a/fastvideo/pipelines/preprocess/v1_preprocess.py +++ b/fastvideo/pipelines/preprocess/v1_preprocess.py @@ -18,6 +18,10 @@ PreprocessPipeline_Text) from fastvideo.pipelines.preprocess.matrixgame.matrixgame_preprocess_pipeline import ( PreprocessPipeline_MatrixGame) +from fastvideo.pipelines.preprocess.wangame.wangame_preprocess_pipeline import ( + PreprocessPipeline_WanGame) +from fastvideo.pipelines.preprocess.wangame.wangame_preprocess_pipeline_ode_trajectory import ( + PreprocessPipeline_WanGame_ODE_Trajectory) from fastvideo.utils import maybe_download_model logger = init_logger(__name__) @@ -64,10 +68,16 @@ def main(args) -> None: PreprocessPipeline = PreprocessPipeline_ODE_Trajectory elif args.preprocess_task == "matrixgame": PreprocessPipeline = PreprocessPipeline_MatrixGame + elif args.preprocess_task == "wangame": + PreprocessPipeline = PreprocessPipeline_WanGame + elif args.preprocess_task == "wangame_ode_trajectory": + fastvideo_args.pipeline_config.flow_shift = args.flow_shift if args.flow_shift is not None else 5.0 + PreprocessPipeline = PreprocessPipeline_WanGame_ODE_Trajectory else: raise ValueError( f"Invalid preprocess task: {args.preprocess_task}. 
" - f"Valid options: t2v, i2v, ode_trajectory, text_only, matrixgame") + f"Valid options: t2v, i2v, ode_trajectory, text_only, matrixgame, wangame, wangame_ode_trajectory" + ) logger.info("Preprocess task: %s using %s", args.preprocess_task, PreprocessPipeline.__name__) @@ -111,12 +121,14 @@ def main(args) -> None: parser.add_argument("--group_frame", action="store_true") # TODO parser.add_argument("--group_resolution", action="store_true") # TODO parser.add_argument("--flow_shift", type=float, default=None) - parser.add_argument( - "--preprocess_task", - type=str, - default="t2v", - choices=["t2v", "i2v", "text_only", "ode_trajectory", "matrixgame"], - help="Type of preprocessing task to run") + parser.add_argument("--preprocess_task", + type=str, + default="t2v", + choices=[ + "t2v", "i2v", "text_only", "ode_trajectory", + "matrixgame", "wangame", "wangame_ode_trajectory" + ], + help="Type of preprocessing task to run") parser.add_argument("--train_fps", type=int, default=30) parser.add_argument("--use_image_num", type=int, default=0) parser.add_argument("--text_max_length", type=int, default=256) diff --git a/fastvideo/pipelines/preprocess/wangame/wangame_preprocess_pipeline.py b/fastvideo/pipelines/preprocess/wangame/wangame_preprocess_pipeline.py new file mode 100644 index 000000000..40c958d76 --- /dev/null +++ b/fastvideo/pipelines/preprocess/wangame/wangame_preprocess_pipeline.py @@ -0,0 +1,303 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import numpy as np +import torch +from PIL import Image + +from fastvideo.dataset.dataloader.schema import pyarrow_schema_wangame +from fastvideo.distributed import get_local_torch_device +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.forward_context import set_forward_context +from fastvideo.pipelines.preprocess.preprocess_pipeline_base import ( + BasePreprocessPipeline) +from fastvideo.pipelines.stages import ImageEncodingStage + + +class 
PreprocessPipeline_WanGame(BasePreprocessPipeline): + """I2V preprocessing pipeline implementation.""" + + _required_config_modules = ["vae", "image_encoder", "image_processor"] + + def create_pipeline_stages(self, fastvideo_args: FastVideoArgs): + self.add_stage(stage_name="image_encoding_stage", + stage=ImageEncodingStage( + image_encoder=self.get_module("image_encoder"), + image_processor=self.get_module("image_processor"), + )) + + def get_pyarrow_schema(self): + """Return the PyArrow schema for I2V pipeline.""" + return pyarrow_schema_wangame + + def get_extra_features(self, valid_data: dict[str, Any], + fastvideo_args: FastVideoArgs) -> dict[str, Any]: + + # TODO(will): move these to cpu at some point + self.get_module("image_encoder").to(get_local_torch_device()) + self.get_module("vae").to(get_local_torch_device()) + + features = {} + """Get CLIP features from the first frame of each video.""" + first_frame = valid_data["pixel_values"][:, :, 0, :, :].permute( + 0, 2, 3, 1) # (B, C, T, H, W) -> (B, H, W, C) + _, _, num_frames, height, width = valid_data["pixel_values"].shape + # latent_height = height // self.get_module( + # "vae").spatial_compression_ratio + # latent_width = width // self.get_module("vae").spatial_compression_ratio + + processed_images = [] + # Frame has values between -1 and 1 + for frame in first_frame: + frame = (frame + 1) * 127.5 + frame_pil = Image.fromarray(frame.cpu().numpy().astype(np.uint8)) + processed_img = self.get_module("image_processor")( + images=frame_pil, return_tensors="pt") + processed_images.append(processed_img) + + # Get CLIP features + pixel_values = torch.cat( + [img['pixel_values'] for img in processed_images], + dim=0).to(get_local_torch_device()) + with torch.no_grad(): + image_inputs = {'pixel_values': pixel_values} + with set_forward_context(current_timestep=0, attn_metadata=None): + clip_features = self.get_module("image_encoder")(**image_inputs) + clip_features = clip_features.last_hidden_state + + 
features["clip_feature"] = clip_features + """Get VAE features from the first frame of each video""" + video_conditions = [] + for frame in first_frame: + processed_img = frame.to(device="cpu", dtype=torch.float32) + processed_img = processed_img.unsqueeze(0).permute(0, 3, 1, + 2).unsqueeze(2) + # (B, H, W, C) -> (B, C, 1, H, W) + video_condition = torch.cat([ + processed_img, + processed_img.new_zeros(processed_img.shape[0], + processed_img.shape[1], num_frames - 1, + height, width) + ], + dim=2) + video_condition = video_condition.to( + device=get_local_torch_device(), dtype=torch.float32) + video_conditions.append(video_condition) + + video_conditions = torch.cat(video_conditions, dim=0) + + with torch.autocast(device_type="cuda", + dtype=torch.float32, + enabled=True): + encoder_outputs = self.get_module("vae").encode(video_conditions) + + latent_condition = encoder_outputs.mean + if (hasattr(self.get_module("vae"), "shift_factor") + and self.get_module("vae").shift_factor is not None): + if isinstance(self.get_module("vae").shift_factor, torch.Tensor): + latent_condition -= self.get_module("vae").shift_factor.to( + latent_condition.device, latent_condition.dtype) + else: + latent_condition -= self.get_module("vae").shift_factor + + if isinstance(self.get_module("vae").scaling_factor, torch.Tensor): + latent_condition = latent_condition * self.get_module( + "vae").scaling_factor.to(latent_condition.device, + latent_condition.dtype) + else: + latent_condition = latent_condition * self.get_module( + "vae").scaling_factor + + # mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, + # latent_width) + # mask_lat_size[:, :, list(range(1, num_frames))] = 0 + # first_frame_mask = mask_lat_size[:, :, 0:1] + # first_frame_mask = torch.repeat_interleave( + # first_frame_mask, + # dim=2, + # repeats=self.get_module("vae").temporal_compression_ratio) + # mask_lat_size = torch.concat( + # [first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + # 
mask_lat_size = mask_lat_size.view( + # batch_size, -1, + # self.get_module("vae").temporal_compression_ratio, latent_height, + # latent_width) + # mask_lat_size = mask_lat_size.transpose(1, 2) + # mask_lat_size = mask_lat_size.to(latent_condition.device) + + # image_latent = torch.concat([mask_lat_size, latent_condition], dim=1) + + features["first_frame_latent"] = latent_condition + + if "action_path" in valid_data and valid_data["action_path"]: + keyboard_cond_list = [] + mouse_cond_list = [] + num_bits = 6 + for action_path in valid_data["action_path"]: + if action_path: + action_data = np.load(action_path, allow_pickle=True) + if isinstance( + action_data, + np.ndarray) and action_data.dtype == np.dtype('O'): + action_dict = action_data.item() + if "keyboard" in action_dict: + keyboard_raw = action_dict["keyboard"] + # Convert 1D bit-flag values to 2D multi-hot encoding + if isinstance(keyboard_raw, np.ndarray): + if keyboard_raw.ndim == 1: + # [T] -> [T, num_bits] + T = len(keyboard_raw) + multi_hot = np.zeros((T, num_bits), + dtype=np.float32) + action_values = keyboard_raw.astype(int) + for bit_idx in range(num_bits): + target_idx = ( + 2 - + (bit_idx % 3)) + 3 * (bit_idx // 3) + if target_idx < num_bits: + multi_hot[:, target_idx] = ( + (action_values >> bit_idx) + & 1).astype(np.float32) + keyboard_cond_list.append(multi_hot) + else: + # If already 2D, pad to num_bits if necessary + k_data = keyboard_raw.astype(np.float32) + if k_data.ndim == 2 and k_data.shape[ + -1] < num_bits: + padding = np.zeros( + (k_data.shape[0], + num_bits - k_data.shape[-1]), + dtype=np.float32) + k_data = np.concatenate( + [k_data, padding], axis=-1) + keyboard_cond_list.append(k_data) + else: + keyboard_cond_list.append(keyboard_raw) + if "mouse" in action_dict: + mouse_cond_list.append(action_dict["mouse"]) + else: + if isinstance(action_data, + np.ndarray) and action_data.ndim == 1: + T = len(action_data) + multi_hot = np.zeros((T, num_bits), + dtype=np.float32) + 
action_values = action_data.astype(int) + for bit_idx in range(num_bits): + target_idx = ( + 2 - (bit_idx % 3)) + 3 * (bit_idx // 3) + if target_idx < num_bits: + multi_hot[:, target_idx] = ( + (action_values >> bit_idx) & 1).astype( + np.float32) + keyboard_cond_list.append(multi_hot) + else: + # If already 2D, pad to num_bits if necessary + k_data = action_data.astype(np.float32) + if k_data.ndim == 2 and k_data.shape[-1] < num_bits: + padding = np.zeros( + (k_data.shape[0], + num_bits - k_data.shape[-1]), + dtype=np.float32) + k_data = np.concatenate([k_data, padding], + axis=-1) + keyboard_cond_list.append(k_data) + if keyboard_cond_list: + features["keyboard_cond"] = keyboard_cond_list + if mouse_cond_list: + features["mouse_cond"] = mouse_cond_list + + return features + + def create_record( + self, + video_name: str, + vae_latent: np.ndarray, + text_embedding: np.ndarray, + valid_data: dict[str, Any], + idx: int, + extra_features: dict[str, Any] | None = None) -> dict[str, Any]: + """Create a record for the Parquet dataset with CLIP features.""" + record = super().create_record(video_name=video_name, + vae_latent=vae_latent, + text_embedding=text_embedding, + valid_data=valid_data, + idx=idx, + extra_features=extra_features) + + if extra_features and "clip_feature" in extra_features: + clip_feature = extra_features["clip_feature"] + record.update({ + "clip_feature_bytes": clip_feature.tobytes(), + "clip_feature_shape": list(clip_feature.shape), + "clip_feature_dtype": str(clip_feature.dtype), + }) + else: + record.update({ + "clip_feature_bytes": b"", + "clip_feature_shape": [], + "clip_feature_dtype": "", + }) + + if extra_features and "first_frame_latent" in extra_features: + first_frame_latent = extra_features["first_frame_latent"] + record.update({ + "first_frame_latent_bytes": + first_frame_latent.tobytes(), + "first_frame_latent_shape": + list(first_frame_latent.shape), + "first_frame_latent_dtype": + str(first_frame_latent.dtype), + }) + else: + 
record.update({ + "first_frame_latent_bytes": b"", + "first_frame_latent_shape": [], + "first_frame_latent_dtype": "", + }) + + if extra_features and "pil_image" in extra_features: + pil_image = extra_features["pil_image"] + record.update({ + "pil_image_bytes": pil_image.tobytes(), + "pil_image_shape": list(pil_image.shape), + "pil_image_dtype": str(pil_image.dtype), + }) + else: + record.update({ + "pil_image_bytes": b"", + "pil_image_shape": [], + "pil_image_dtype": "", + }) + + if extra_features and "keyboard_cond" in extra_features: + keyboard_cond = extra_features["keyboard_cond"] + record.update({ + "keyboard_cond_bytes": keyboard_cond.tobytes(), + "keyboard_cond_shape": list(keyboard_cond.shape), + "keyboard_cond_dtype": str(keyboard_cond.dtype), + }) + else: + record.update({ + "keyboard_cond_bytes": b"", + "keyboard_cond_shape": [], + "keyboard_cond_dtype": "", + }) + + if extra_features and "mouse_cond" in extra_features: + mouse_cond = extra_features["mouse_cond"] + record.update({ + "mouse_cond_bytes": mouse_cond.tobytes(), + "mouse_cond_shape": list(mouse_cond.shape), + "mouse_cond_dtype": str(mouse_cond.dtype), + }) + else: + record.update({ + "mouse_cond_bytes": b"", + "mouse_cond_shape": [], + "mouse_cond_dtype": "", + }) + + return record + + +EntryClass = PreprocessPipeline_WanGame diff --git a/fastvideo/pipelines/preprocess/wangame/wangame_preprocess_pipeline_ode_trajectory.py b/fastvideo/pipelines/preprocess/wangame/wangame_preprocess_pipeline_ode_trajectory.py new file mode 100644 index 000000000..6335d187e --- /dev/null +++ b/fastvideo/pipelines/preprocess/wangame/wangame_preprocess_pipeline_ode_trajectory.py @@ -0,0 +1,497 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +ODE Trajectory Data Preprocessing pipeline implementation. + +This module contains an implementation of the ODE Trajectory Data Preprocessing pipeline +using the modular pipeline architecture. 
+ +Sec 4.3 of CausVid paper: https://arxiv.org/pdf/2412.07772 +""" + +import os +from collections.abc import Iterator +from typing import Any + +import numpy as np +import pyarrow as pa +import torch +from PIL import Image +from torch.utils.data import DataLoader +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm import tqdm + +from fastvideo.configs.sample import SamplingParam +from fastvideo.dataset import getdataset +from fastvideo.dataset.dataloader.parquet_io import (ParquetDatasetWriter, + records_to_table) +from fastvideo.dataset.dataloader.record_schema import ( + wangame_ode_record_creator) +from fastvideo.dataset.dataloader.schema import ( + pyarrow_schema_ode_trajectory_wangame) +from fastvideo.distributed import get_local_torch_device +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.forward_context import set_forward_context +from fastvideo.logger import init_logger +from fastvideo.models.schedulers.scheduling_self_forcing_flow_match import ( + SelfForcingFlowMatchScheduler) +from fastvideo.pipelines.pipeline_batch_info import ForwardBatch +from fastvideo.pipelines.preprocess.preprocess_pipeline_base import ( + BasePreprocessPipeline) +from fastvideo.pipelines.stages import (DecodingStage, DenoisingStage, + InputValidationStage, + LatentPreparationStage, + ImageEncodingStage, + TimestepPreparationStage) +from fastvideo.utils import save_decoded_latents_as_video, shallow_asdict + +logger = init_logger(__name__) + + +class PreprocessPipeline_WanGame_ODE_Trajectory(BasePreprocessPipeline): + """ODE Trajectory preprocessing pipeline implementation.""" + + _required_config_modules = [ + "vae", "image_encoder", "image_processor", "transformer", "scheduler" + ] + + preprocess_dataloader: StatefulDataLoader + preprocess_loader_iter: Iterator[dict[str, Any]] + pbar: Any + num_processed_samples: int + + def get_pyarrow_schema(self) -> pa.Schema: + """Return the PyArrow schema for ODE Trajectory pipeline.""" + return 
pyarrow_schema_ode_trajectory_wangame + + def create_pipeline_stages(self, fastvideo_args: FastVideoArgs): + """Set up pipeline stages with proper dependency injection.""" + assert fastvideo_args.pipeline_config.flow_shift == 5 + self.modules["scheduler"] = SelfForcingFlowMatchScheduler( + shift=fastvideo_args.pipeline_config.flow_shift, + sigma_min=0.0, + extra_one_step=True) + self.modules["scheduler"].set_timesteps(num_inference_steps=48, + denoising_strength=1.0) + + self.add_stage(stage_name="input_validation_stage", + stage=InputValidationStage()) + self.add_stage(stage_name="image_encoding_stage", + stage=ImageEncodingStage( + image_encoder=self.get_module("image_encoder"), + image_processor=self.get_module("image_processor"), + )) + self.add_stage(stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage( + scheduler=self.get_module("scheduler"))) + self.add_stage(stage_name="latent_preparation_stage", + stage=LatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer", None))) + self.add_stage(stage_name="denoising_stage", + stage=DenoisingStage( + transformer=self.get_module("transformer"), + scheduler=self.get_module("scheduler"), + pipeline=self, + )) + self.add_stage(stage_name="decoding_stage", + stage=DecodingStage(vae=self.get_module("vae"))) + + def get_extra_features(self, valid_data: dict[str, Any], + fastvideo_args: FastVideoArgs) -> dict[str, Any]: + + # TODO(will): move these to cpu at some point + self.get_module("image_encoder").to(get_local_torch_device()) + self.get_module("vae").to(get_local_torch_device()) + + features = {} + """Get CLIP features from the first frame of each video.""" + first_frame = valid_data["pixel_values"][:, :, 0, :, :].permute( + 0, 2, 3, 1) # (B, C, T, H, W) -> (B, H, W, C) + _, _, num_frames, height, width = valid_data["pixel_values"].shape + # latent_height = height // self.get_module( + # "vae").spatial_compression_ratio + # latent_width = width 
// self.get_module("vae").spatial_compression_ratio + + processed_images = [] + # Frame has values between -1 and 1 + for frame in first_frame: + frame = (frame + 1) * 127.5 + frame_pil = Image.fromarray(frame.cpu().numpy().astype(np.uint8)) + processed_img = self.get_module("image_processor")( + images=frame_pil, return_tensors="pt") + processed_images.append(processed_img) + + # Get CLIP features + pixel_values = torch.cat( + [img['pixel_values'] for img in processed_images], + dim=0).to(get_local_torch_device()) + with torch.no_grad(): + image_inputs = {'pixel_values': pixel_values} + with set_forward_context(current_timestep=0, attn_metadata=None): + clip_features = self.get_module("image_encoder")(**image_inputs) + clip_features = clip_features.last_hidden_state + + features["clip_feature"] = clip_features + features["pil_image"] = first_frame + """Get VAE features from the first frame of each video""" + video_conditions = [] + for frame in first_frame: + processed_img = frame.to(device="cpu", dtype=torch.float32) + processed_img = processed_img.unsqueeze(0).permute(0, 3, 1, + 2).unsqueeze(2) + # (B, H, W, C) -> (B, C, 1, H, W) + video_condition = torch.cat([ + processed_img, + processed_img.new_zeros(processed_img.shape[0], + processed_img.shape[1], num_frames - 1, + height, width) + ], + dim=2) + video_condition = video_condition.to( + device=get_local_torch_device(), dtype=torch.float32) + video_conditions.append(video_condition) + + video_conditions = torch.cat(video_conditions, dim=0) + + with torch.autocast(device_type="cuda", + dtype=torch.float32, + enabled=True): + encoder_outputs = self.get_module("vae").encode(video_conditions) + + # Use mode() instead of mean + latent_condition = encoder_outputs.mode() + + # Use latents_mean/latents_std normalization to match + vae = self.get_module("vae") + if (hasattr(vae.config, 'latents_mean') + and hasattr(vae.config, 'latents_std')): + latents_mean = torch.tensor(vae.config.latents_mean, + 
device=latent_condition.device, + dtype=latent_condition.dtype).view( + 1, -1, 1, 1, 1) + latents_std = torch.tensor(vae.config.latents_std, + device=latent_condition.device, + dtype=latent_condition.dtype).view( + 1, -1, 1, 1, 1) + latent_condition = (latent_condition - latents_mean) / latents_std + elif (hasattr(vae, "shift_factor") and vae.shift_factor is not None): + if isinstance(vae.shift_factor, torch.Tensor): + latent_condition -= vae.shift_factor.to(latent_condition.device, + latent_condition.dtype) + else: + latent_condition -= vae.shift_factor + + if isinstance(vae.scaling_factor, torch.Tensor): + latent_condition = latent_condition * vae.scaling_factor.to( + latent_condition.device, latent_condition.dtype) + else: + latent_condition = latent_condition * vae.scaling_factor + + # mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, + # latent_width) + # mask_lat_size[:, :, list(range(1, num_frames))] = 0 + # first_frame_mask = mask_lat_size[:, :, 0:1] + # first_frame_mask = torch.repeat_interleave( + # first_frame_mask, + # dim=2, + # repeats=self.get_module("vae").temporal_compression_ratio) + # mask_lat_size = torch.concat( + # [first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + # mask_lat_size = mask_lat_size.view( + # batch_size, -1, + # self.get_module("vae").temporal_compression_ratio, latent_height, + # latent_width) + # mask_lat_size = mask_lat_size.transpose(1, 2) + # mask_lat_size = mask_lat_size.to(latent_condition.device) + + # image_latent = torch.concat([mask_lat_size, latent_condition], dim=1) + + # Create mask_cond: ones for first frame, zeros for rest + # Shape: (B, 16, latent_frames, latent_height, latent_width) + mask_cond = torch.ones_like(latent_condition) + mask_cond[:, :, 1:] = 0 # Set all frames except first to 0 + # Create cond_concat: first 4 channels of mask + all 16 channels of img_cond + # Shape: (B, 20, latent_frames, latent_height, latent_width) + cond_concat = torch.cat([mask_cond[:, :4], 
latent_condition], dim=1) + features["first_frame_latent"] = cond_concat + + if "action_path" in valid_data and valid_data["action_path"]: + keyboard_cond_list = [None] * len(valid_data["action_path"]) + mouse_cond_list = [None] * len(valid_data["action_path"]) + arch_cfg = self.get_module("transformer").config.arch_config + action_cfg = getattr(arch_cfg, "action_config", {}) or {} + keyboard_dim = action_cfg.get("keyboard_dim_in", None) + for idx, action_path in enumerate(valid_data["action_path"]): + if action_path: + action_data = np.load(action_path, allow_pickle=True) + if isinstance( + action_data, + np.ndarray) and action_data.dtype == np.dtype('O'): + action_dict = action_data.item() + if "keyboard" in action_dict: + keyboard = action_dict["keyboard"].astype( + np.float32) + if keyboard_dim is not None: + if keyboard.ndim >= 2: + keyboard = keyboard[:, :keyboard_dim] + else: + keyboard = keyboard[:keyboard_dim] + keyboard_cond_list[idx] = keyboard + if "mouse" in action_dict: + mouse_cond_list[idx] = action_dict["mouse"].astype( + np.float32) + else: + keyboard = action_data.astype(np.float32) + if keyboard_dim is not None: + if keyboard.ndim >= 2: + keyboard = keyboard[:, :keyboard_dim] + else: + keyboard = keyboard[:keyboard_dim] + keyboard_cond_list[idx] = keyboard + features["keyboard_cond"] = keyboard_cond_list + features["mouse_cond"] = mouse_cond_list + return features + + def preprocess_action_and_trajectory(self, fastvideo_args: FastVideoArgs, + args): + """Preprocess data and generate trajectory information.""" + + for batch_idx, data in enumerate(self.pbar): + if data is None: + continue + + with torch.inference_mode(): + # Filter out invalid samples (those with all zeros) + valid_indices = [] + for i, pixel_values in enumerate(data["pixel_values"]): + if not torch.all( + pixel_values == 0): # Check if all values are zero + valid_indices.append(i) + self.num_processed_samples += len(valid_indices) + + if not valid_indices: + continue + + # Create 
new batch with only valid samples + valid_data = { + "pixel_values": + torch.stack( + [data["pixel_values"][i] for i in valid_indices]), + "path": [data["path"][i] for i in valid_indices], + } + + if "fps" in data: + valid_data["fps"] = [data["fps"][i] for i in valid_indices] + if "duration" in data: + valid_data["duration"] = [ + data["duration"][i] for i in valid_indices + ] + if "action_path" in data: + valid_data["action_path"] = [ + data["action_path"][i] for i in valid_indices + ] + + pixel_values = valid_data["pixel_values"] + if pixel_values.shape[2] == 1 and args.num_frames is not None: + pixel_values = pixel_values.repeat(1, 1, args.num_frames, 1, + 1) + valid_data["pixel_values"] = pixel_values + + # Get extra features if needed + extra_features = self.get_extra_features( + valid_data, fastvideo_args) + + clip_features = extra_features['clip_feature'] + image_latents = extra_features['first_frame_latent'] + image_latents = image_latents[:, :, :args.num_latent_t] + pil_image = extra_features['pil_image'] + if "keyboard_cond" in extra_features: + keyboard_cond = extra_features['keyboard_cond'] + else: + keyboard_cond = None + if "mouse_cond" in extra_features: + mouse_cond = extra_features['mouse_cond'] + else: + mouse_cond = None + + sampling_params = SamplingParam.from_pretrained(args.model_path) + + trajectory_latents = [] + trajectory_timesteps = [] + trajectory_decoded = [] + + device = get_local_torch_device() + for i in range(len(valid_indices)): + # Collect the trajectory data + batch = ForwardBatch(**shallow_asdict(sampling_params), ) + batch.image_embeds = [clip_features[i].unsqueeze(0)] + batch.image_latent = image_latents[i].unsqueeze(0) + sample_keyboard = keyboard_cond[ + i] if keyboard_cond is not None else None + sample_mouse = mouse_cond[ + i] if mouse_cond is not None else None + if sample_keyboard is not None and sample_mouse is not None: + batch.keyboard_cond = torch.from_numpy( + sample_keyboard).unsqueeze(0).to(device) + 
batch.mouse_cond = torch.from_numpy( + sample_mouse).unsqueeze(0).to(device) + else: + batch.keyboard_cond = None + batch.mouse_cond = None + batch.num_inference_steps = 48 + batch.return_trajectory_latents = True + # Enabling this will save the decoded trajectory videos. + # Used for debugging. + batch.return_trajectory_decoded = False + batch.height = args.max_height + batch.width = args.max_width + batch.fps = args.train_fps + batch.num_frames = valid_data["pixel_values"].shape[2] + batch.guidance_scale = 6.0 + batch.do_classifier_free_guidance = False + batch.prompt = "" + batch.prompt_embeds = [ + torch.zeros( + (1, 0, self.get_module("transformer").hidden_size), + dtype=torch.bfloat16, + device=device) + ] + + result_batch = self.input_validation_stage( + batch, fastvideo_args) + result_batch = self.timestep_preparation_stage( + result_batch, fastvideo_args) + result_batch.timesteps = result_batch.timesteps.to(device) + result_batch = self.latent_preparation_stage( + result_batch, fastvideo_args) + result_batch = self.denoising_stage(result_batch, + fastvideo_args) + result_batch = self.decoding_stage(result_batch, + fastvideo_args) + + trajectory_latents.append( + result_batch.trajectory_latents.cpu()) + trajectory_timesteps.append( + result_batch.trajectory_timesteps.cpu()) + trajectory_decoded.append(result_batch.trajectory_decoded) + + # Prepare extra features + extra_features = { + "trajectory_latents": trajectory_latents, + "trajectory_timesteps": trajectory_timesteps + } + + if batch.return_trajectory_decoded: + for i, decoded_frames in enumerate(trajectory_decoded): + for j, decoded_frame in enumerate(decoded_frames): + save_decoded_latents_as_video( + decoded_frame, + f"decoded_videos/trajectory_decoded_{i}_{j}.mp4", + args.train_fps) + + # Prepare batch data for Parquet dataset + batch_data: list[dict[str, Any]] = [] + + # Add progress bar for saving outputs + save_pbar = tqdm(enumerate(valid_data["path"]), + desc="Saving outputs", + unit="item", + 
leave=False) + + for idx, video_path in save_pbar: + video_name = os.path.basename(video_path).split(".")[0] + + clip_feature_np = clip_features[idx].cpu().numpy() + first_frame_latent_np = image_latents[idx].cpu().numpy() + pil_image_np = pil_image[idx].cpu().numpy() + keyboard_cond_np = keyboard_cond[ + idx] if keyboard_cond is not None else None + mouse_cond_np = mouse_cond[ + idx] if mouse_cond is not None else None + + # Get trajectory features for this sample + traj_latents = extra_features["trajectory_latents"][idx] + traj_timesteps = extra_features["trajectory_timesteps"][idx] + if isinstance(traj_latents, torch.Tensor): + traj_latents = traj_latents.cpu().float().numpy() + if isinstance(traj_timesteps, torch.Tensor): + traj_timesteps = traj_timesteps.cpu().float().numpy() + + # Create record for Parquet dataset + record: dict[str, Any] = wangame_ode_record_creator( + video_name=video_name, + clip_feature=clip_feature_np, + first_frame_latent=first_frame_latent_np, + trajectory_latents=traj_latents, + trajectory_timesteps=traj_timesteps, + pil_image=pil_image_np, + keyboard_cond=keyboard_cond_np, + mouse_cond=mouse_cond_np, + caption="") + batch_data.append(record) + + if batch_data: + write_pbar = tqdm(total=1, + desc="Writing to Parquet dataset", + unit="batch") + table = records_to_table(batch_data, + self.get_pyarrow_schema()) + write_pbar.update(1) + write_pbar.close() + + if not hasattr(self, 'dataset_writer'): + self.dataset_writer = ParquetDatasetWriter( + out_dir=self.combined_parquet_dir, + samples_per_file=args.samples_per_file, + ) + self.dataset_writer.append_table(table) + + logger.info("Collected batch with %s samples", len(table)) + + if self.num_processed_samples >= args.flush_frequency: + written = self.dataset_writer.flush() + logger.info("Flushed %s samples to parquet", written) + self.num_processed_samples = 0 + + # Final flush for any remaining samples + if hasattr(self, 'dataset_writer'): + written = 
self.dataset_writer.flush(write_remainder=True) + if written: + logger.info("Final flush wrote %s samples", written) + + def forward(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs, args): + if not self.post_init_called: + self.post_init() + + self.local_rank = int(os.getenv("RANK", 0)) + os.makedirs(args.output_dir, exist_ok=True) + # Create directory for combined data + self.combined_parquet_dir = os.path.join(args.output_dir, + "combined_parquet_dataset") + os.makedirs(self.combined_parquet_dir, exist_ok=True) + + # Loading dataset + train_dataset = getdataset(args) + + self.preprocess_dataloader = DataLoader( + train_dataset, + batch_size=args.preprocess_video_batch_size, + num_workers=args.dataloader_num_workers, + ) + + self.preprocess_loader_iter = iter(self.preprocess_dataloader) + + self.num_processed_samples = 0 + # Add progress bar for video preprocessing + self.pbar = tqdm(self.preprocess_loader_iter, + desc="Processing videos", + unit="batch", + disable=self.local_rank != 0) + + # Initialize class variables for data sharing + self.video_data: dict[str, Any] = {} # Store video metadata and paths + self.latent_data: dict[str, Any] = {} # Store latent tensors + self.preprocess_action_and_trajectory(fastvideo_args, args) + + +EntryClass = PreprocessPipeline_WanGame_ODE_Trajectory diff --git a/fastvideo/pipelines/samplers/__init__.py b/fastvideo/pipelines/samplers/__init__.py new file mode 100644 index 000000000..638a9e532 --- /dev/null +++ b/fastvideo/pipelines/samplers/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 + +from fastvideo.pipelines.samplers.base import SamplerKind + +__all__ = [ + "SamplerKind", +] diff --git a/fastvideo/pipelines/samplers/base.py b/fastvideo/pipelines/samplers/base.py new file mode 100644 index 000000000..da4ef502c --- /dev/null +++ b/fastvideo/pipelines/samplers/base.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Literal + 
+SamplerKind = Literal["ode", "sde"] + + +def normalize_sampler_kind( + raw: str | None, + *, + where: str, + default: SamplerKind = "ode", +) -> SamplerKind: + if raw is None: + return default + + kind = str(raw).strip().lower() + if kind == "ode": + return "ode" + if kind == "sde": + return "sde" + + raise ValueError( + f"Unknown sampler kind at {where}: {raw!r} (expected ode|sde)") diff --git a/fastvideo/pipelines/samplers/wan.py b/fastvideo/pipelines/samplers/wan.py new file mode 100644 index 000000000..22fd9d3f3 --- /dev/null +++ b/fastvideo/pipelines/samplers/wan.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, ) +from fastvideo.models.schedulers.scheduling_flow_unipc_multistep import ( + FlowUniPCMultistepScheduler, ) +from fastvideo.pipelines.samplers.base import SamplerKind, normalize_sampler_kind + + +def get_wan_sampler_kind(fastvideo_args: FastVideoArgs) -> SamplerKind: + raw = getattr(fastvideo_args.pipeline_config, "sampler_kind", None) + return normalize_sampler_kind(raw, where="pipeline_config.sampler_kind") + + +def build_wan_scheduler(fastvideo_args: FastVideoArgs, kind: SamplerKind): + shift = fastvideo_args.pipeline_config.flow_shift + if kind == "sde": + return FlowMatchEulerDiscreteScheduler(shift=shift) + + ode_solver_raw = getattr(fastvideo_args.pipeline_config, "ode_solver", + "unipc") + ode_solver = str(ode_solver_raw).strip().lower( + ) if ode_solver_raw is not None else "unipc" + if ode_solver in {"unipc", "unipc_multistep", "multistep"}: + return FlowUniPCMultistepScheduler(shift=shift) + if ode_solver in {"euler", "flowmatch", "flowmatch_euler"}: + return FlowMatchEulerDiscreteScheduler(shift=shift) + + raise ValueError("Unknown pipeline_config.ode_solver for wan pipelines: " + f"{ode_solver_raw!r} (expected 'unipc' or 
'euler').") + + +def wan_use_btchw_layout(kind: SamplerKind) -> bool: + return kind == "sde" diff --git a/fastvideo/pipelines/stages/__init__.py b/fastvideo/pipelines/stages/__init__.py index 9896539c0..e59bd2a4d 100644 --- a/fastvideo/pipelines/stages/__init__.py +++ b/fastvideo/pipelines/stages/__init__.py @@ -13,7 +13,7 @@ from fastvideo.pipelines.stages.denoising import ( Cosmos25AutoDenoisingStage, Cosmos25DenoisingStage, Cosmos25V2WDenoisingStage, Cosmos25T2WDenoisingStage, CosmosDenoisingStage, - DenoisingStage, DmdDenoisingStage) + DenoisingStage, DmdDenoisingStage, SdeDenoisingStage) from fastvideo.pipelines.stages.sr_denoising import SRDenoisingStage from fastvideo.pipelines.stages.encoding import EncodingStage from fastvideo.pipelines.stages.image_encoding import ( @@ -33,7 +33,9 @@ LTX2LatentPreparationStage) from fastvideo.pipelines.stages.ltx2_text_encoding import LTX2TextEncodingStage from fastvideo.pipelines.stages.matrixgame_denoising import ( - MatrixGameCausalDenoisingStage) + MatrixGameCausalDenoisingStage, + MatrixGameCausalOdeDenoisingStage, +) from fastvideo.pipelines.stages.hyworld_denoising import HYWorldDenoisingStage from fastvideo.pipelines.stages.gamecraft_denoising import GameCraftDenoisingStage from fastvideo.pipelines.stages.text_encoding import (Cosmos25TextEncodingStage, @@ -61,9 +63,11 @@ "LTX2AudioDecodingStage", "ConditioningStage", "DenoisingStage", + "SdeDenoisingStage", "DmdDenoisingStage", "CausalDMDDenosingStage", "MatrixGameCausalDenoisingStage", + "MatrixGameCausalOdeDenoisingStage", "HYWorldDenoisingStage", "GameCraftDenoisingStage", "CosmosDenoisingStage", diff --git a/fastvideo/pipelines/stages/denoising.py b/fastvideo/pipelines/stages/denoising.py index 1b11478ce..38cf77f14 100644 --- a/fastvideo/pipelines/stages/denoising.py +++ b/fastvideo/pipelines/stages/denoising.py @@ -17,8 +17,6 @@ from fastvideo.forward_context import set_forward_context from fastvideo.logger import init_logger from 
fastvideo.models.loader.component_loader import TransformerLoader -from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import ( - FlowMatchEulerDiscreteScheduler) from fastvideo.models.utils import pred_noise_to_pred_video from fastvideo.pipelines.pipeline_batch_info import ForwardBatch from fastvideo.pipelines.stages.base import PipelineStage @@ -159,6 +157,35 @@ def forward( }, ) + if batch.mouse_cond is not None and batch.keyboard_cond is not None: + from fastvideo.models.dits.hyworld.pose import process_custom_actions + viewmats, intrinsics, action_labels = process_custom_actions( + batch.keyboard_cond, batch.mouse_cond) + camera_action_kwargs = self.prepare_extra_func_kwargs( + self.transformer.forward, + { + "viewmats": + viewmats.unsqueeze(0).to(get_local_torch_device(), + dtype=target_dtype), + "Ks": + intrinsics.unsqueeze(0).to(get_local_torch_device(), + dtype=target_dtype), + "action": + action_labels.unsqueeze(0).to(get_local_torch_device(), + dtype=target_dtype), + }, + ) + # from fastvideo.models.dits.wangame_lingbot.cam_utils import process_custom_actions as process_lingbot_actions + # num_frames = batch.num_frames + # latent_height = batch.height // 8 + # latent_width = batch.width // 8 + # c2ws_plucker_emb = process_lingbot_actions( + # num_frames, batch.keyboard_cond, batch.mouse_cond, + # latent_height=latent_height, latent_width=latent_width + # ).to(get_local_torch_device(), dtype=target_dtype) + else: + camera_action_kwargs = {} + action_kwargs = self.prepare_extra_func_kwargs( self.transformer.forward, { @@ -419,8 +446,8 @@ def forward( **image_kwargs, **pos_cond_kwargs, **action_kwargs, - **camera_kwargs, **timesteps_r_kwarg, + **camera_action_kwargs, ) if batch.do_classifier_free_guidance: @@ -438,8 +465,7 @@ def forward( **image_kwargs, **neg_cond_kwargs, **action_kwargs, - **camera_kwargs, - **timesteps_r_kwarg, + **camera_action_kwargs, ) noise_pred_text = noise_pred @@ -527,10 +553,19 @@ def 
prepare_extra_func_kwargs(self, func, kwargs) -> dict[str, Any]: Returns: The prepared kwargs. """ - extra_step_kwargs = {} + signature = inspect.signature(func) + if any(p.kind == inspect.Parameter.VAR_KEYWORD + for p in signature.parameters.values()): + # If the callee accepts `**kwargs`, do not filter by signature. + # This is important for models that route parameters internally via + # `forward(*args, **kwargs)` (e.g. causal Wangame), where filtering + # would incorrectly drop conditioning kwargs like `action`. + return dict(kwargs) + + accepted = set(signature.parameters.keys()) + extra_step_kwargs: dict[str, Any] = {} for k, v in kwargs.items(): - accepts = k in set(inspect.signature(func).parameters.keys()) - if accepts: + if k in accepted: extra_step_kwargs[k] = v return extra_step_kwargs @@ -1171,14 +1206,16 @@ def verify_output(self, batch: ForwardBatch, return self._t2w.verify_output(batch, fastvideo_args) -class DmdDenoisingStage(DenoisingStage): - """ - Denoising stage for DMD. +class SdeDenoisingStage(DenoisingStage): + """Denoising stage for SDE-style sampling. 
+ + This stage runs a stochastic rollout loop: + - predict x0 at timestep t + - inject fresh noise to reach the next timestep """ def __init__(self, transformer, scheduler) -> None: super().__init__(transformer, scheduler) - self.scheduler = FlowMatchEulerDiscreteScheduler(shift=8.0) def forward( self, @@ -1202,16 +1239,6 @@ def forward( autocast_enabled = (target_dtype != torch.float32 ) and not fastvideo_args.disable_autocast - # Get timesteps and calculate warmup steps - timesteps = batch.timesteps - - # TODO(will): remove this once we add input/output validation for stages - if timesteps is None: - raise ValueError("Timesteps must be provided") - num_inference_steps = batch.num_inference_steps - num_warmup_steps = len( - timesteps) - num_inference_steps * self.scheduler.order - # Prepare image latents and embeddings for I2V generation image_embeds = batch.image_embeds if len(image_embeds) > 0: @@ -1237,6 +1264,37 @@ def forward( }, ) + if batch.mouse_cond is not None and batch.keyboard_cond is not None: + from fastvideo.models.dits.hyworld.pose import process_custom_actions + + viewmats, intrinsics, action_labels = process_custom_actions( + batch.keyboard_cond, batch.mouse_cond) + camera_action_kwargs = self.prepare_extra_func_kwargs( + self.transformer.forward, + { + "viewmats": + viewmats.unsqueeze(0).to(get_local_torch_device(), + dtype=target_dtype), + "Ks": + intrinsics.unsqueeze(0).to(get_local_torch_device(), + dtype=target_dtype), + "action": + action_labels.unsqueeze(0).to(get_local_torch_device(), + dtype=target_dtype), + }, + ) + else: + camera_action_kwargs = {} + + action_kwargs = self.prepare_extra_func_kwargs( + self.transformer.forward, + { + "mouse_cond": batch.mouse_cond, + "keyboard_cond": batch.keyboard_cond, + "c2ws_plucker_emb": batch.c2ws_plucker_emb, + }, + ) + # Get latents and embeddings assert batch.latents is not None, "latents must be provided" latents = batch.latents @@ -1245,14 +1303,29 @@ def forward( prompt_embeds = 
batch.prompt_embeds assert not torch.isnan( prompt_embeds[0]).any(), "prompt_embeds contains nan" - timesteps = torch.tensor( - fastvideo_args.pipeline_config.dmd_denoising_steps, - dtype=torch.long, - device=get_local_torch_device()) + loop_timesteps = batch.sampling_timesteps + if loop_timesteps is None: + legacy = getattr(fastvideo_args.pipeline_config, + "dmd_denoising_steps", None) + if legacy is not None: + loop_timesteps = torch.tensor(legacy, dtype=torch.long) + else: + loop_timesteps = batch.timesteps + + if loop_timesteps is None: + raise ValueError( + "SDE sampling requires `batch.sampling_timesteps` (preferred) " + "or `pipeline_config.dmd_denoising_steps`.") + if not isinstance(loop_timesteps, torch.Tensor): + loop_timesteps = torch.tensor(loop_timesteps, dtype=torch.long) + if loop_timesteps.ndim != 1: + raise ValueError("Expected 1D `sampling_timesteps`, got shape " + f"{tuple(loop_timesteps.shape)}") + loop_timesteps = loop_timesteps.to(get_local_torch_device()) # Run denoising loop - with self.progress_bar(total=len(timesteps)) as progress_bar: - for i, t in enumerate(timesteps): + with self.progress_bar(total=len(loop_timesteps)) as progress_bar: + for i, t in enumerate(loop_timesteps): # Skip if interrupted if hasattr(self, 'interrupt') and self.interrupt: continue @@ -1326,6 +1399,8 @@ def forward( guidance=guidance_expand, **image_kwargs, **pos_cond_kwargs, + **action_kwargs, + **camera_action_kwargs, ).permute(0, 2, 1, 3, 4) pred_video = pred_noise_to_pred_video( @@ -1335,13 +1410,15 @@ def forward( scheduler=self.scheduler).unflatten( 0, pred_noise.shape[:2]) - if i < len(timesteps) - 1: - next_timestep = timesteps[i + 1] * torch.ones( + if i < len(loop_timesteps) - 1: + next_timestep = loop_timesteps[i + 1] * torch.ones( [1], dtype=torch.long, device=pred_video.device) - noise = torch.randn(video_raw_latent_shape, - dtype=pred_video.dtype, - generator=batch.generator[0]).to( - self.device) + noise = torch.randn( + video_raw_latent_shape, + 
dtype=pred_video.dtype, + generator=batch.generator[0] if isinstance( + batch.generator, list) else batch.generator).to( + self.device) latents = self.scheduler.add_noise( pred_video.flatten(0, 1), noise.flatten(0, 1), next_timestep).unflatten(0, pred_video.shape[:2]) @@ -1349,11 +1426,7 @@ def forward( latents = pred_video # Update progress bar - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and - (i + 1) % self.scheduler.order == 0 - and progress_bar is not None): - progress_bar.update() + progress_bar.update() # Gather results if using sequence parallelism latents = latents.permute(0, 2, 1, 3, 4) @@ -1361,3 +1434,7 @@ def forward( batch.latents = latents return batch + + +# Backwards-compatible alias (legacy pipelines still import this symbol). +DmdDenoisingStage = SdeDenoisingStage diff --git a/fastvideo/pipelines/stages/matrixgame_denoising.py b/fastvideo/pipelines/stages/matrixgame_denoising.py index cd9599bbb..85f42f6b9 100644 --- a/fastvideo/pipelines/stages/matrixgame_denoising.py +++ b/fastvideo/pipelines/stages/matrixgame_denoising.py @@ -54,6 +54,9 @@ class BlockProcessingContext: image_kwargs: dict[str, Any] pos_cond_kwargs: dict[str, Any] + viewmats_full: torch.Tensor | None = None + intrinsics_full: torch.Tensor | None = None + action_full: torch.Tensor | None = None def get_kv_cache(self, timestep_val: float) -> list[dict[Any, Any]]: if self.boundary_timestep is not None: @@ -97,10 +100,12 @@ def __init__(self, -1) except Exception: self.local_attn_size = -1 + try: + self.local_attn_size = getattr(self.transformer.model, + "local_attn_size", -1) + except Exception: + self.local_attn_size = -1 - assert self.local_attn_size != -1, ( - f"local_attn_size must be set for Matrix-Game causal inference, " - f"got {self.local_attn_size}. 
Check MatrixGameWanVideoArchConfig.") assert self.num_frame_per_block > 0, ( f"num_frame_per_block must be positive, got {self.num_frame_per_block}" ) @@ -126,7 +131,10 @@ def forward( ) and not fastvideo_args.disable_autocast latent_seq_length = batch.latents.shape[-1] * batch.latents.shape[-2] - patch_size = self.transformer.patch_size + if hasattr(self.transformer, "patch_size"): + patch_size = self.transformer.patch_size + else: + patch_size = self.transformer.config.arch_config.patch_size patch_ratio = patch_size[-1] * patch_size[-2] self.frame_seq_length = latent_seq_length // patch_ratio @@ -166,6 +174,31 @@ def forward( prompt_embeds = batch.prompt_embeds assert torch.isnan(prompt_embeds[0]).sum() == 0 + viewmats_full = None + intrinsics_full = None + action_full = None + if batch.mouse_cond is not None and batch.keyboard_cond is not None: + from fastvideo.models.dits.hyworld.pose import process_custom_actions + + viewmats_list = [] + intrinsics_list = [] + action_list = [] + for bi in range(b): + vm, ks, action = process_custom_actions(batch.keyboard_cond[bi], + batch.mouse_cond[bi]) + viewmats_list.append(vm) + intrinsics_list.append(ks) + action_list.append(action) + viewmats_full = torch.stack(viewmats_list, + dim=0).to(device=latents.device, + dtype=target_dtype) + intrinsics_full = torch.stack(intrinsics_list, + dim=0).to(device=latents.device, + dtype=target_dtype) + action_full = torch.stack(action_list, + dim=0).to(device=latents.device, + dtype=target_dtype) + kv_cache1 = self._initialize_kv_cache(batch_size=latents.shape[0], dtype=target_dtype, device=latents.device) @@ -225,6 +258,9 @@ def forward( "context_noise", 0), image_kwargs=image_kwargs, pos_cond_kwargs=pos_cond_kwargs, + viewmats_full=viewmats_full, + intrinsics_full=intrinsics_full, + action_full=action_full, ) context_noise = getattr(fastvideo_args.pipeline_config, "context_noise", @@ -240,6 +276,8 @@ def forward( action_kwargs = self._prepare_action_kwargs( batch, start_index, 
current_num_frames) + camera_action_kwargs = self._prepare_camera_action_kwargs( + ctx, start_index, current_num_frames) current_latents = self._process_single_block( current_latents=current_latents, @@ -249,6 +287,7 @@ def forward( timesteps=timesteps, ctx=ctx, action_kwargs=action_kwargs, + camera_action_kwargs=camera_action_kwargs, progress_bar=progress_bar, ) @@ -263,6 +302,7 @@ def forward( current_num_frames=current_num_frames, ctx=ctx, action_kwargs=action_kwargs, + camera_action_kwargs=camera_action_kwargs, context_noise=context_noise, ) @@ -324,9 +364,9 @@ def _initialize_kv_cache(self, batch_size: int, dtype: torch.dtype, dtype=dtype, device=device), "global_end_index": - torch.tensor([0], dtype=torch.long, device=device), + torch.zeros((), dtype=torch.long, device=device), "local_end_index": - torch.tensor([0], dtype=torch.long, device=device), + torch.zeros((), dtype=torch.long, device=device), }) return kv_cache @@ -362,9 +402,9 @@ def _initialize_action_kv_cache(self, batch_size: int, dtype: torch.dtype, dtype=dtype, device=device), "global_end_index": - torch.tensor([0], dtype=torch.long, device=device), + torch.zeros((), dtype=torch.long, device=device), "local_end_index": - torch.tensor([0], dtype=torch.long, device=device), + torch.zeros((), dtype=torch.long, device=device), }) kv_cache_mouse.append({ "k": @@ -382,9 +422,9 @@ def _initialize_action_kv_cache(self, batch_size: int, dtype: torch.dtype, dtype=dtype, device=device), "global_end_index": - torch.tensor([0], dtype=torch.long, device=device), + torch.zeros((), dtype=torch.long, device=device), "local_end_index": - torch.tensor([0], dtype=torch.long, device=device), + torch.zeros((), dtype=torch.long, device=device), }) return kv_cache_mouse, kv_cache_keyboard @@ -418,6 +458,18 @@ def _initialize_crossattn_cache(self, batch_size: int, max_text_len: int, }) return crossattn_cache + def _prepare_camera_action_kwargs( + self, ctx: BlockProcessingContext, start_index: int, + current_num_frames: 
int) -> dict[str, Any]: + if ctx.action_full is None or ctx.viewmats_full is None or ctx.intrinsics_full is None: + return {} + end_index = start_index + current_num_frames + return { + "viewmats": ctx.viewmats_full[:, start_index:end_index], + "Ks": ctx.intrinsics_full[:, start_index:end_index], + "action": ctx.action_full[:, start_index:end_index], + } + def _process_single_block( self, current_latents: torch.Tensor, @@ -427,6 +479,7 @@ def _process_single_block( timesteps: torch.Tensor, ctx: BlockProcessingContext, action_kwargs: dict[str, Any], + camera_action_kwargs: dict[str, Any], noise_generator: Callable[[tuple, torch.dtype, int], torch.Tensor] | None = None, progress_bar: Any | None = None, @@ -445,7 +498,16 @@ def _process_single_block( independent_first_frame = getattr(self.transformer, 'independent_first_frame', False) - if batch.image_latent is not None and independent_first_frame and start_index == 0: + if batch.image_latent is not None and not independent_first_frame: + image_latent_chunk = batch.image_latent[:, :, start_index: + start_index + + current_num_frames, :, :] + latent_model_input = torch.cat([ + latent_model_input, + image_latent_chunk.to(ctx.target_dtype) + ], + dim=1) + elif batch.image_latent is not None and independent_first_frame and start_index == 0: latent_model_input = torch.cat([ latent_model_input, batch.image_latent.to(ctx.target_dtype) @@ -495,6 +557,7 @@ def _process_single_block( "crossattn_cache": ctx.crossattn_cache, "current_start": start_index * self.frame_seq_length, "start_frame": start_index, + "is_cache": False, } if self.use_action_module and current_model == self.transformer: @@ -510,6 +573,7 @@ def _process_single_block( latent_model_input, prompt_embeds, t_expanded_noise, + **camera_action_kwargs, **ctx.image_kwargs, **ctx.pos_cond_kwargs, **model_kwargs, @@ -582,6 +646,7 @@ def _update_context_cache( current_num_frames: int, ctx: BlockProcessingContext, action_kwargs: dict[str, Any], + camera_action_kwargs: 
dict[str, Any], context_noise: float, ) -> None: prompt_embeds = batch.prompt_embeds @@ -592,6 +657,17 @@ def _update_context_cache( device=latents_device, dtype=torch.long) * int(context_noise) context_bcthw = current_latents.to(ctx.target_dtype) + context_input = context_bcthw + independent_first_frame = getattr(self.transformer, + "independent_first_frame", False) + if batch.image_latent is not None and not independent_first_frame: + image_context_chunk = batch.image_latent[:, :, + start_index:start_index + + current_num_frames, :, :] + context_input = torch.cat( + [context_input, + image_context_chunk.to(ctx.target_dtype)], + dim=1) with torch.autocast(device_type="cuda", dtype=ctx.target_dtype, @@ -605,6 +681,7 @@ def _update_context_cache( "crossattn_cache": ctx.crossattn_cache, "current_start": start_index * self.frame_seq_length, "start_frame": start_index, + "is_cache": True, } if self.use_action_module: @@ -617,26 +694,398 @@ def _update_context_cache( context_model_kwargs.update(action_kwargs) if ctx.boundary_timestep is not None and self.transformer_2 is not None: - self.transformer_2( - context_bcthw, + cache_update_ret_2 = self.transformer_2( + context_input, prompt_embeds, t_context, kv_cache=ctx.kv_cache2, crossattn_cache=ctx.crossattn_cache, current_start=start_index * self.frame_seq_length, start_frame=start_index, + is_cache=True, + **camera_action_kwargs, **ctx.image_kwargs, **ctx.pos_cond_kwargs, ) + if isinstance(cache_update_ret_2, + list) and len(cache_update_ret_2) > 0: + ctx.kv_cache2 = cache_update_ret_2 - self.transformer( - context_bcthw, + cache_update_ret = self.transformer( + context_input, prompt_embeds, t_context, + **camera_action_kwargs, **ctx.image_kwargs, **ctx.pos_cond_kwargs, **context_model_kwargs, ) + if isinstance(cache_update_ret, list) and len(cache_update_ret) > 0: + ctx.kv_cache1 = cache_update_ret + + +class MatrixGameCausalOdeDenoisingStage(MatrixGameCausalDenoisingStage): + """Causal ODE denoising for 
WanGame/MatrixGame. + + This is the deterministic counterpart of `MatrixGameCausalDenoisingStage`. + It performs block-by-block causal rollout, but uses the scheduler's ODE-style + `step()` update (no re-noising between steps). + """ + + def forward( + self, + batch: ForwardBatch, + fastvideo_args: FastVideoArgs, + ) -> ForwardBatch: + timesteps = batch.timesteps + if timesteps is None: + raise ValueError( + "MatrixGameCausalOdeDenoisingStage requires batch.timesteps. " + "Make sure TimestepPreparationStage runs before this stage.") + + target_dtype = torch.bfloat16 + autocast_enabled = (target_dtype != torch.float32 + ) and not fastvideo_args.disable_autocast + + latent_seq_length = batch.latents.shape[-1] * batch.latents.shape[-2] + if hasattr(self.transformer, "patch_size"): + patch_size = self.transformer.patch_size + else: + patch_size = self.transformer.config.arch_config.patch_size + patch_ratio = patch_size[-1] * patch_size[-2] + self.frame_seq_length = latent_seq_length // patch_ratio + + timesteps = timesteps.to(get_local_torch_device()) + + boundary_ratio = getattr(fastvideo_args.pipeline_config.dit_config, + "boundary_ratio", None) + if boundary_ratio is not None: + boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps + else: + boundary_timestep = None + + image_embeds = batch.image_embeds + if len(image_embeds) > 0: + assert torch.isnan(image_embeds[0]).sum() == 0 + image_embeds = [ + image_embed.to(target_dtype) for image_embed in image_embeds + ] + + # directly set the kwarg. 
+ image_kwargs = {"encoder_hidden_states_image": image_embeds} + pos_cond_kwargs: dict[str, Any] = {} + + assert batch.latents is not None, "latents must be provided" + latents = batch.latents + b, c, t, h, w = latents.shape + + prompt_embeds = batch.prompt_embeds + assert torch.isnan(prompt_embeds[0]).sum() == 0 + + viewmats_full = None + intrinsics_full = None + action_full = None + if batch.mouse_cond is not None and batch.keyboard_cond is not None: + from fastvideo.models.dits.hyworld.pose import process_custom_actions + + viewmats_list = [] + intrinsics_list = [] + action_list = [] + for bi in range(b): + vm, ks, action = process_custom_actions(batch.keyboard_cond[bi], + batch.mouse_cond[bi]) + viewmats_list.append(vm) + intrinsics_list.append(ks) + action_list.append(action) + viewmats_full = torch.stack(viewmats_list, + dim=0).to(device=latents.device, + dtype=target_dtype) + intrinsics_full = torch.stack(intrinsics_list, + dim=0).to(device=latents.device, + dtype=target_dtype) + action_full = torch.stack(action_list, + dim=0).to(device=latents.device, + dtype=target_dtype) + + kv_cache1 = self._initialize_kv_cache(batch_size=latents.shape[0], + dtype=target_dtype, + device=latents.device) + kv_cache2 = None + if boundary_timestep is not None: + kv_cache2 = self._initialize_kv_cache(batch_size=latents.shape[0], + dtype=target_dtype, + device=latents.device) + + kv_cache_mouse = None + kv_cache_keyboard = None + if self.use_action_module: + kv_cache_mouse, kv_cache_keyboard = self._initialize_action_kv_cache( + batch_size=latents.shape[0], + dtype=target_dtype, + device=latents.device) + + crossattn_cache = self._initialize_crossattn_cache( + batch_size=latents.shape[0], + max_text_len=257, # 1 CLS + 256 patch tokens + dtype=target_dtype, + device=latents.device) + + if t % self.num_frame_per_block != 0: + raise ValueError( + "num_frames must be divisible by num_frame_per_block for causal denoising" + ) + num_blocks = t // self.num_frame_per_block + 
block_sizes = [self.num_frame_per_block] * num_blocks + start_index = 0 + + if boundary_timestep is not None: + block_sizes[0] = 1 + + ctx = BlockProcessingContext( + batch=batch, + block_idx=0, + start_index=0, + kv_cache1=kv_cache1, + kv_cache2=kv_cache2, + kv_cache_mouse=kv_cache_mouse, + kv_cache_keyboard=kv_cache_keyboard, + crossattn_cache=crossattn_cache, + timesteps=timesteps, + block_sizes=block_sizes, + noise_pool=None, + fastvideo_args=fastvideo_args, + target_dtype=target_dtype, + autocast_enabled=autocast_enabled, + boundary_timestep=boundary_timestep, + high_noise_timesteps=None, + context_noise=getattr(fastvideo_args.pipeline_config, + "context_noise", 0), + image_kwargs=image_kwargs, + pos_cond_kwargs=pos_cond_kwargs, + viewmats_full=viewmats_full, + intrinsics_full=intrinsics_full, + action_full=action_full, + ) + + context_noise = getattr(fastvideo_args.pipeline_config, "context_noise", + 0) + + with self.progress_bar(total=len(block_sizes) * + len(timesteps)) as progress_bar: + for block_idx, current_num_frames in enumerate(block_sizes): + ctx.block_idx = block_idx + ctx.start_index = start_index + current_latents = latents[:, :, start_index:start_index + + current_num_frames, :, :] + + # The scheduler maintains an internal `step_index` (and potentially + # additional multistep state, e.g. UniPC). Since causal streaming runs + # a full denoising trajectory *per block*, reset that state before + # each block rollout. 
+ self._reset_scheduler_state_for_new_rollout() + + action_kwargs = self._prepare_action_kwargs( + batch, start_index, current_num_frames) + camera_action_kwargs = self._prepare_camera_action_kwargs( + ctx, start_index, current_num_frames) + + current_latents = self._process_single_block_ode( + current_latents=current_latents, + batch=batch, + start_index=start_index, + current_num_frames=current_num_frames, + timesteps=timesteps, + ctx=ctx, + action_kwargs=action_kwargs, + camera_action_kwargs=camera_action_kwargs, + progress_bar=progress_bar, + ) + + latents[:, :, start_index:start_index + + current_num_frames, :, :] = current_latents + + # Update KV caches with clean context + self._update_context_cache( + current_latents=current_latents, + batch=batch, + start_index=start_index, + current_num_frames=current_num_frames, + ctx=ctx, + action_kwargs=action_kwargs, + camera_action_kwargs=camera_action_kwargs, + context_noise=context_noise, + ) + + start_index += current_num_frames + + if boundary_timestep is not None: + num_frames_to_remove = self.num_frame_per_block - 1 + if num_frames_to_remove > 0: + latents = latents[:, :, :-num_frames_to_remove, :, :] + + batch.latents = latents + return batch + + def _reset_scheduler_state_for_new_rollout(self) -> None: + scheduler = self.scheduler + + # Common diffusers-like state. + if hasattr(scheduler, "_step_index"): + scheduler._step_index = None # type: ignore[attr-defined] + if hasattr(scheduler, "_begin_index"): + scheduler._begin_index = None # type: ignore[attr-defined] + + # UniPC multistep state (FlowUniPCMultistepScheduler) needs additional reset + # between independent trajectories. 
+ if hasattr(scheduler, "model_outputs") and hasattr(scheduler, "config"): + try: + solver_order = int( + getattr(scheduler.config, "solver_order", 0) or 0) + except Exception: + solver_order = 0 + if solver_order > 0: + scheduler.model_outputs = [ + None + ] * solver_order # type: ignore[attr-defined] + if hasattr(scheduler, "timestep_list") and hasattr(scheduler, "config"): + try: + solver_order = int( + getattr(scheduler.config, "solver_order", 0) or 0) + except Exception: + solver_order = 0 + if solver_order > 0: + scheduler.timestep_list = [ + None + ] * solver_order # type: ignore[attr-defined] + if hasattr(scheduler, "lower_order_nums"): + scheduler.lower_order_nums = 0 # type: ignore[attr-defined] + if hasattr(scheduler, "last_sample"): + scheduler.last_sample = None # type: ignore[attr-defined] + + def _process_single_block_ode( + self, + *, + current_latents: torch.Tensor, + batch: ForwardBatch, + start_index: int, + current_num_frames: int, + timesteps: torch.Tensor, + ctx: BlockProcessingContext, + action_kwargs: dict[str, Any], + camera_action_kwargs: dict[str, Any], + progress_bar: Any | None = None, + ) -> torch.Tensor: + prompt_embeds = batch.prompt_embeds + extra_step_kwargs = self.prepare_extra_func_kwargs( + self.scheduler.step, + { + "generator": batch.generator, + "eta": batch.eta, + }, + ) + + for i, t_cur in enumerate(timesteps): + if ctx.boundary_timestep is not None and t_cur < ctx.boundary_timestep: + current_model = self.transformer_2 if self.transformer_2 is not None else self.transformer + else: + current_model = self.transformer + + latent_model_input = current_latents.to(ctx.target_dtype) + + independent_first_frame = getattr(self.transformer, + "independent_first_frame", False) + if batch.image_latent is not None and not independent_first_frame: + image_latent_chunk = batch.image_latent[:, :, start_index: + start_index + + current_num_frames, :, :] + latent_model_input = torch.cat([ + latent_model_input, + 
image_latent_chunk.to(ctx.target_dtype) + ], + dim=1) + elif (batch.image_latent is not None and independent_first_frame + and start_index == 0): + latent_model_input = torch.cat([ + latent_model_input, + batch.image_latent.to(ctx.target_dtype) + ], + dim=2) + + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t_cur) + + # Build attention metadata if VSA is available + if vsa_available and self.attn_backend == VideoSparseAttentionBackend: + self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls( + ) + if self.attn_metadata_builder_cls is not None: + self.attn_metadata_builder = self.attn_metadata_builder_cls( + ) + h, w = current_latents.shape[-2:] + attn_metadata = self.attn_metadata_builder.build( + current_timestep=i, + raw_latent_shape=(current_num_frames, h, w), + patch_size=ctx.fastvideo_args.pipeline_config. + dit_config.patch_size, + VSA_sparsity=ctx.fastvideo_args.VSA_sparsity, + device=get_local_torch_device(), + ) + assert attn_metadata is not None, "attn_metadata cannot be None" + else: + attn_metadata = None + else: + attn_metadata = None + + with torch.autocast(device_type="cuda", + dtype=ctx.target_dtype, + enabled=ctx.autocast_enabled), \ + set_forward_context(current_timestep=i, + attn_metadata=attn_metadata, + forward_batch=batch): + t_expanded = t_cur * torch.ones( + (latent_model_input.shape[0], current_num_frames), + device=latent_model_input.device, + dtype=t_cur.dtype) + + model_kwargs = { + "kv_cache": ctx.get_kv_cache(t_cur), + "crossattn_cache": ctx.crossattn_cache, + "current_start": start_index * self.frame_seq_length, + "start_frame": start_index, + "is_cache": False, + } + + if self.use_action_module and current_model == self.transformer: + model_kwargs.update({ + "kv_cache_mouse": + ctx.kv_cache_mouse, + "kv_cache_keyboard": + ctx.kv_cache_keyboard, + }) + model_kwargs.update(action_kwargs) + + noise_pred = current_model( + latent_model_input, + prompt_embeds, + t_expanded, + 
**camera_action_kwargs, + **ctx.image_kwargs, + **ctx.pos_cond_kwargs, + **model_kwargs, + ) + + current_latents = self.scheduler.step( + noise_pred, + t_cur, + current_latents, + **extra_step_kwargs, + return_dict=False, + )[0] + + if progress_bar is not None: + progress_bar.update() + + return current_latents def streaming_reset(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch: @@ -645,7 +1094,10 @@ def streaming_reset(self, batch: ForwardBatch, ) and not fastvideo_args.disable_autocast latent_seq_length = batch.latents.shape[-1] * batch.latents.shape[-2] - patch_size = self.transformer.patch_size + if hasattr(self.transformer, "patch_size"): + patch_size = self.transformer.patch_size + else: + patch_size = self.transformer.config.arch_config.patch_size patch_ratio = patch_size[-1] * patch_size[-2] self.frame_seq_length = latent_seq_length // patch_ratio @@ -821,6 +1273,7 @@ def streaming_noise_generator(shape: tuple, dtype: torch.dtype, timesteps=ctx.timesteps, ctx=ctx, action_kwargs=action_kwargs, + camera_action_kwargs={}, noise_generator=streaming_noise_generator, ) @@ -835,6 +1288,7 @@ def streaming_noise_generator(shape: tuple, dtype: torch.dtype, current_num_frames=current_num_frames, ctx=ctx, action_kwargs=action_kwargs, + camera_action_kwargs={}, context_noise=ctx.context_noise, ) diff --git a/fastvideo/registry.py b/fastvideo/registry.py index 230d878fa..2d8c1f29f 100644 --- a/fastvideo/registry.py +++ b/fastvideo/registry.py @@ -32,20 +32,11 @@ TurboDiffusionT2V_1_3B_Config, ) from fastvideo.configs.pipelines.wan import ( - FastWan2_1_T2V_480P_Config, - FastWan2_2_TI2V_5B_Config, - MatrixGameI2V480PConfig, - SelfForcingWan2_2_T2V480PConfig, - SelfForcingWanT2V480PConfig, - WANV2VConfig, - Wan2_2_I2V_A14B_Config, - Wan2_2_T2V_A14B_Config, - Wan2_2_TI2V_5B_Config, - WanI2V480PConfig, - WanI2V720PConfig, - WanT2V480PConfig, - WanT2V720PConfig, -) + FastWan2_1_T2V_480P_Config, FastWan2_2_TI2V_5B_Config, + MatrixGameI2V480PConfig, 
SelfForcingWan2_2_T2V480PConfig, + SelfForcingWanT2V480PConfig, WANV2VConfig, Wan2_2_I2V_A14B_Config, + Wan2_2_T2V_A14B_Config, Wan2_2_TI2V_5B_Config, WanI2V480PConfig, + WanI2V720PConfig, WanT2V480PConfig, WanT2V720PConfig, WanGameI2V480PConfig) from fastvideo.configs.pipelines.sd35 import SD35Config from fastvideo.configs.sample.base import SamplingParam from fastvideo.configs.sample.cosmos import ( @@ -558,6 +549,14 @@ def _register_configs() -> None: "FastVideo/SFWan2.2-I2V-A14B-Preview-Diffusers", ], ) + register_configs( + sampling_param_cls=Wan2_1_Fun_1_3B_InP_SamplingParam, + pipeline_config_cls=WanGameI2V480PConfig, + hf_model_paths=[ + "weizhou03/Wan2.1-Game-Fun-1.3B-InP-Diffusers", + ], + ) + # TODO: Need to add Lingbot # SD3.5 register_configs( diff --git a/fastvideo/train/.style.yapf b/fastvideo/train/.style.yapf new file mode 100644 index 000000000..c9a88d5a6 --- /dev/null +++ b/fastvideo/train/.style.yapf @@ -0,0 +1,3 @@ +[style] +based_on_style = pep8 +column_limit = 120 diff --git a/fastvideo/train/__init__.py b/fastvideo/train/__init__.py new file mode 100644 index 000000000..fed6b183c --- /dev/null +++ b/fastvideo/train/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 + +from fastvideo.train.trainer import Trainer + +__all__ = [ + "Trainer", +] diff --git a/fastvideo/train/callbacks/__init__.py b/fastvideo/train/callbacks/__init__.py new file mode 100644 index 000000000..23334280a --- /dev/null +++ b/fastvideo/train/callbacks/__init__.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 + +from fastvideo.train.callbacks.callback import ( + Callback, + CallbackDict, +) +from fastvideo.train.callbacks.ema import ( + EMACallback, +) +from fastvideo.train.callbacks.grad_clip import ( + GradNormClipCallback, +) +from fastvideo.train.callbacks.validation import ( + ValidationCallback, +) + +__all__ = [ + "Callback", + "CallbackDict", + "EMACallback", + "GradNormClipCallback", + "ValidationCallback", +] diff --git 
a/fastvideo/train/callbacks/callback.py b/fastvideo/train/callbacks/callback.py new file mode 100644 index 000000000..b44c7dc07 --- /dev/null +++ b/fastvideo/train/callbacks/callback.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Callback base class and CallbackDict manager. + +Adapted from FastGen's callback pattern to FastVideo's types. +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any, TYPE_CHECKING + +from fastvideo.logger import init_logger +from fastvideo.train.utils.instantiate import instantiate + +if TYPE_CHECKING: + from fastvideo.train.methods.base import TrainingMethod + from fastvideo.train.utils.training_config import ( + TrainingConfig, ) + +logger = init_logger(__name__) + +# Well-known callback names that don't need ``_target_`` in YAML. +_BUILTIN_CALLBACKS: dict[str, str] = { + "grad_clip": "fastvideo.train.callbacks.grad_clip.GradNormClipCallback", + "validation": "fastvideo.train.callbacks.validation.ValidationCallback", + "ema": "fastvideo.train.callbacks.ema.EMACallback", +} + + +class Callback: + """Base callback with no-op hooks. + + Subclasses override whichever hooks they need. The + ``training_config`` and ``method`` attributes are set by + ``CallbackDict`` after instantiation. 
+ """ + + training_config: TrainingConfig + method: TrainingMethod + + def on_train_start( + self, + method: TrainingMethod, + iteration: int = 0, + ) -> None: + pass + + def on_training_step_end( + self, + method: TrainingMethod, + loss_dict: dict[str, Any], + iteration: int = 0, + ) -> None: + pass + + def on_before_optimizer_step( + self, + method: TrainingMethod, + iteration: int = 0, + ) -> None: + pass + + def on_validation_begin( + self, + method: TrainingMethod, + iteration: int = 0, + ) -> None: + pass + + def on_validation_end( + self, + method: TrainingMethod, + iteration: int = 0, + ) -> None: + pass + + def on_train_end( + self, + method: TrainingMethod, + iteration: int = 0, + ) -> None: + pass + + def state_dict(self) -> dict[str, Any]: + return {} + + def load_state_dict( + self, state_dict: dict[str, Any], + ) -> None: + pass + + +class CallbackDict: + """Manages a collection of named callbacks. + + Instantiates each callback from its ``_target_`` config and + dispatches hook calls to all registered callbacks. + """ + + def __init__( + self, + callback_configs: dict[str, dict[str, Any]], + training_config: TrainingConfig, + ) -> None: + self._callbacks: dict[str, Callback] = {} + if not callback_configs: + return + for name, cb_cfg in callback_configs.items(): + cb_cfg = dict(cb_cfg) + if "_target_" not in cb_cfg: + if name in _BUILTIN_CALLBACKS: + cb_cfg["_target_"] = ( + _BUILTIN_CALLBACKS[name] + ) + else: + logger.warning( + "Callback %r is missing " + "'_target_', skipping: %s", + name, + cb_cfg, + ) + continue + logger.info( + "Instantiating callback %r: %s", + name, + cb_cfg, + ) + cb = instantiate(cb_cfg) + if not isinstance(cb, Callback): + raise TypeError( + f"Callback {name!r} resolved to " + f"{type(cb).__name__}, expected a " + f"Callback subclass." 
+ ) + cb.training_config = training_config + self._callbacks[name] = cb + + def __getattr__( + self, method_name: str, + ) -> Callable[..., Any]: + if method_name.startswith("_"): + raise AttributeError(method_name) + + if method_name == "state_dict": + + def _state_dict() -> dict[str, Any]: + return { + n: cb.state_dict() + for n, cb in self._callbacks.items() + } + + return _state_dict + + if method_name == "load_state_dict": + + def _load_state_dict( + state_dict: dict[str, Any], + ) -> None: + for n, cb in self._callbacks.items(): + if n in state_dict: + cb.load_state_dict(state_dict[n]) + else: + logger.warning( + "Callback %r not found in " + "checkpoint.", + n, + ) + + return _load_state_dict + + def _dispatch(*args: Any, **kwargs: Any) -> None: + for cb in self._callbacks.values(): + fn = getattr(cb, method_name, None) + if fn is None: + continue + if not callable(fn): + continue + fn(*args, **kwargs) + + return _dispatch diff --git a/fastvideo/train/callbacks/ema.py b/fastvideo/train/callbacks/ema.py new file mode 100644 index 000000000..328e39a76 --- /dev/null +++ b/fastvideo/train/callbacks/ema.py @@ -0,0 +1,194 @@ +# SPDX-License-Identifier: Apache-2.0 +"""EMA (Exponential Moving Average) callback. + +Updates EMA shadow weights after each training step. The model owns +the EMA network (created by ``ModelBase._setup_ema``); this callback +only performs the ``lerp_`` update. +""" + +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +import torch + +from fastvideo.logger import init_logger +from fastvideo.train.callbacks.callback import Callback + +if TYPE_CHECKING: + from fastvideo.train.methods.base import TrainingMethod + +logger = init_logger(__name__) + + +class EMACallback(Callback): + """Update EMA parameters after each optimizer step. + + The EMA network lives on the method (``method.ema``). + If the method was created with ``use_ema: false``, the callback + detects this at train start and disables itself gracefully. 
+ + Supports three beta strategies: + - ``constant``: fixed ``beta`` every step. + - ``power``: ``(1 - 1/t)^(gamma+1)``. + - ``halflife``: half-life in k-images with optional ramp-up. + """ + + def __init__( + self, + *, + type: str = "constant", + beta: float = 0.9999, + gamma: float = 16.97, + ema_halflife_kimg: float = 500.0, + ema_rampup_ratio: float | None = 0.05, + start_iter: int = 0, + batch_size: int = 1, + ) -> None: + self._type = str(type) + self._beta = float(beta) + self._gamma = float(gamma) + self._ema_halflife_kimg = float(ema_halflife_kimg) + self._ema_rampup_ratio = ( + float(ema_rampup_ratio) + if ema_rampup_ratio is not None + else None + ) + self._start_iter = int(start_iter) + self._batch_size = int(batch_size) + self._enabled = True + + # ---------------------------------------------------------- + # Hooks + # ---------------------------------------------------------- + + def on_train_start( + self, + method: TrainingMethod, + iteration: int = 0, + ) -> None: + ema = getattr(method, "ema", None) + if ema is None: + self._enabled = False + logger.info( + "EMA not found on method; " + "EMA callback disabled.", + ) + return + + assert not ema.training, ( + "EMA should be in eval mode" + ) + for name, p in ema.named_parameters(): + assert not p.requires_grad, ( + f"EMA parameter {name} should not " + f"require gradients" + ) + + def on_training_step_end( + self, + method: TrainingMethod, + loss_dict: dict[str, Any], + iteration: int = 0, + ) -> None: + if not self._enabled: + return + + if iteration < self._start_iter: + return + if iteration == self._start_iter: + logger.info( + "Starting EMA %r updates at iteration %d.", + "ema", + iteration, + ) + + beta = self._compute_beta(iteration) + ema = method.ema + ema_state = ema.state_dict() + + with torch.no_grad(): + for name, p_net in ( + method.student.transformer.named_parameters() + ): + full = self._gather_full(p_net) + ema_key = name.replace( + "_checkpoint_wrapped_module.", "", + ) + if 
ema_key not in ema_state: + if iteration == self._start_iter: + logger.warning( + "EMA param %r not found, " + "skipping.", + ema_key, + ) + continue + ema_p = ema_state[ema_key] + val = full.to( + device=ema_p.device, + dtype=ema_p.dtype, + ) + if iteration == self._start_iter: + ema_p.copy_(val) + else: + ema_p.lerp_(val, 1.0 - beta) + + for name, buf in ( + method.student.transformer.named_buffers() + ): + if name in ema_state: + ema_state[name].copy_( + buf.to( + device=ema_state[name].device, + dtype=ema_state[name].dtype, + ) + ) + + tracker = getattr(method, "tracker", None) + if tracker is not None: + tracker.log( + {"ema/beta": beta}, + iteration, + ) + + # ---------------------------------------------------------- + # Beta strategies + # ---------------------------------------------------------- + + def _compute_beta(self, iteration: int) -> float: + if self._type == "constant": + return self._beta + if self._type == "power": + it = max(iteration, 1) + return (1.0 - 1.0 / it) ** (self._gamma + 1) + if self._type == "halflife": + return self._halflife_beta(iteration) + raise ValueError( + f"Invalid EMA type: {self._type!r}" + ) + + def _halflife_beta(self, iteration: int) -> float: + hl_nimg = self._ema_halflife_kimg * 1000.0 + cur_nimg = iteration * self._batch_size + if self._ema_rampup_ratio is not None: + hl_nimg = min( + hl_nimg, + cur_nimg * self._ema_rampup_ratio, + ) + return 0.5 ** ( + self._batch_size / max(hl_nimg, 1e-8) + ) + + # ---------------------------------------------------------- + # FSDP helper + # ---------------------------------------------------------- + + @staticmethod + def _gather_full( + param: torch.Tensor, + ) -> torch.Tensor: + if hasattr(param, "full_tensor"): + if param.device.type == "cpu": + return param.to("cuda").full_tensor() + return param.full_tensor() + return param diff --git a/fastvideo/train/callbacks/grad_clip.py b/fastvideo/train/callbacks/grad_clip.py new file mode 100644 index 000000000..f8d445422 --- 
/dev/null +++ b/fastvideo/train/callbacks/grad_clip.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Gradient norm clipping callback. + +Clips gradients on modules returned by +``method.get_grad_clip_targets()`` before the optimizer step. +Optionally logs per-module grad norms to the tracker. +""" + +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +from fastvideo.logger import init_logger +from fastvideo.train.callbacks.callback import Callback +from fastvideo.train.utils.optimizer import ( + clip_grad_norm_if_needed, +) + +if TYPE_CHECKING: + from fastvideo.train.methods.base import TrainingMethod + +logger = init_logger(__name__) + + +class GradNormClipCallback(Callback): + """Clip gradient norms before the optimizer step. + + ``max_grad_norm`` must be set explicitly in the callback + config (``callbacks.grad_clip.max_grad_norm``). + """ + + def __init__( + self, + *, + max_grad_norm: float = 0.0, + log_grad_norms: bool = False, + ) -> None: + self._max_grad_norm = float(max_grad_norm) + self._log_grad_norms = bool(log_grad_norms) + + def on_before_optimizer_step( + self, + method: TrainingMethod, + iteration: int = 0, + ) -> None: + max_norm = self._max_grad_norm + if max_norm <= 0.0: + return + + targets = method.get_grad_clip_targets(iteration) + tracker = getattr(method, "tracker", None) + + for name, module in targets.items(): + grad_norm = clip_grad_norm_if_needed( + module, max_norm, + ) + if ( + self._log_grad_norms + and tracker is not None + and grad_norm > 0.0 + ): + tracker.log( + {f"grad_norm/{name}": grad_norm}, + iteration, + ) diff --git a/fastvideo/train/callbacks/validation.py b/fastvideo/train/callbacks/validation.py new file mode 100644 index 000000000..b0b152591 --- /dev/null +++ b/fastvideo/train/callbacks/validation.py @@ -0,0 +1,768 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Validation callback (unified replacement for WanValidator +and WanGameValidator). 
+ +All configuration is read from the YAML ``callbacks.validation`` +section. The pipeline class is resolved from +``pipeline_target``. +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from typing import Any, TYPE_CHECKING + +import imageio +import numpy as np +import torch +import torchvision +from einops import rearrange +from torch.utils.data import DataLoader + +from fastvideo.configs.sample import SamplingParam +from fastvideo.dataset.validation_dataset import ( + ValidationDataset, ) +from fastvideo.distributed import ( + get_sp_group, + get_world_group, +) +from fastvideo.logger import init_logger +from fastvideo.pipelines import ForwardBatch +from fastvideo.train.callbacks.callback import Callback +from fastvideo.train.utils.instantiate import resolve_target +from fastvideo.train.utils.moduleloader import ( + make_inference_args, ) +from fastvideo.training.trackers import DummyTracker +from fastvideo.utils import shallow_asdict + +if TYPE_CHECKING: + from fastvideo.train.methods.base import TrainingMethod + from fastvideo.train.utils.training_config import ( + TrainingConfig, ) + +logger = init_logger(__name__) + + +@dataclass(slots=True) +class _ValidationStepResult: + videos: list[list[np.ndarray]] + captions: list[str] + + +class ValidationCallback(Callback): + """Generic validation callback driven entirely by YAML + config. + + Works with any pipeline that follows the + ``PipelineCls.from_pretrained(...)`` + ``pipeline.forward()`` + contract (Wan, WanGame parallel, WanGame causal/DMD, etc.). 
+ """ + + def __init__( + self, + *, + pipeline_target: str, + dataset_file: str, + every_steps: int = 100, + sampling_steps: list[int] | None = None, + sampler_kind: str = "ode", + scheduler_target: str | None = None, + guidance_scale: float | None = None, + num_frames: int | None = None, + output_dir: str | None = None, + sampling_timesteps: list[int] | None = None, + rollout_mode: str = "parallel", + **pipeline_kwargs: Any, + ) -> None: + self.pipeline_target = str(pipeline_target) + self.dataset_file = str(dataset_file) + self.every_steps = int(every_steps) + self.sampling_steps = ( + [int(s) for s in sampling_steps] + if sampling_steps + else [40] + ) + self.sampler_kind = str(sampler_kind) + self.scheduler_target = ( + str(scheduler_target) + if scheduler_target is not None + else None + ) + self.guidance_scale = ( + float(guidance_scale) + if guidance_scale is not None + else None + ) + self.num_frames = ( + int(num_frames) if num_frames is not None + else None + ) + self.output_dir = ( + str(output_dir) if output_dir is not None + else None + ) + self.sampling_timesteps = ( + [int(s) for s in sampling_timesteps] + if sampling_timesteps is not None + else None + ) + self.rollout_mode = str(rollout_mode) + self.pipeline_kwargs = dict(pipeline_kwargs) + + # Set after on_train_start. + self._pipeline: Any | None = None + self._pipeline_key: tuple[Any, ...] 
| None = None + self._sampling_param: SamplingParam | None = None + self.tracker: Any = DummyTracker() + self.validation_random_generator: ( + torch.Generator | None + ) = None + self.seed: int = 0 + + # ---------------------------------------------------------- + # Callback hooks + # ---------------------------------------------------------- + + def on_train_start( + self, + method: TrainingMethod, + iteration: int = 0, + ) -> None: + self.method = method + tc = self.training_config + + self.world_group = get_world_group() + self.sp_group = get_sp_group() + self.global_rank = self.world_group.rank + self.rank_in_sp_group = ( + self.sp_group.rank_in_group + ) + self.sp_world_size = self.sp_group.world_size + + seed = tc.data.seed + if seed is None: + raise ValueError( + "training.data.seed must be set " + "for validation" + ) + self.seed = int(seed) + self.validation_random_generator = ( + torch.Generator(device="cpu").manual_seed( + self.seed + ) + ) + + tracker = getattr(method, "tracker", None) + if tracker is not None: + self.tracker = tracker + + def on_validation_begin( + self, + method: TrainingMethod, + iteration: int = 0, + ) -> None: + if self.every_steps <= 0: + return + if iteration % self.every_steps != 0: + return + + self._run_validation(method, iteration) + + # ---------------------------------------------------------- + # Core validation logic + # ---------------------------------------------------------- + + def _run_validation( + self, + method: TrainingMethod, + step: int, + ) -> None: + tc = self.training_config + # Use EMA transformer for validation when available. + transformer = method.transformer_inference + was_training = bool( + getattr(transformer, "training", False) + ) + + output_dir = ( + self.output_dir + or tc.checkpoint.output_dir + ) + + # For streaming SDE pipelines we may need to + # temporarily set dmd_denoising_steps on + # pipeline_config. 
+ old_dmd_denoising_steps = getattr( + tc.pipeline_config, + "dmd_denoising_steps", + None, + ) + try: + transformer.eval() + num_sp_groups = ( + self.world_group.world_size + // self.sp_group.world_size + ) + + for num_inference_steps in self.sampling_steps: + self._maybe_set_dmd_denoising_steps( + tc, + num_inference_steps, + ) + + result = self._run_validation_for_steps( + num_inference_steps, + transformer=transformer, + ) + + if self.rank_in_sp_group != 0: + continue + + if self.global_rank == 0: + all_videos = list(result.videos) + all_captions = list(result.captions) + for sp_idx in range( + 1, num_sp_groups + ): + src = ( + sp_idx * self.sp_world_size + ) + recv_v = ( + self.world_group.recv_object( + src=src + ) + ) + recv_c = ( + self.world_group.recv_object( + src=src + ) + ) + all_videos.extend(recv_v) + all_captions.extend(recv_c) + + os.makedirs( + output_dir, exist_ok=True, + ) + video_filenames: list[str] = [] + sp = self._get_sampling_param() + for i, video in enumerate(all_videos): + fname = os.path.join( + output_dir, + f"validation_step_{step}" + f"_inference_steps_" + f"{num_inference_steps}" + f"_video_{i}.mp4", + ) + imageio.mimsave( + fname, + video, + fps=sp.fps, + ) + video_filenames.append(fname) + + video_logs = [] + for fname, cap in zip( + video_filenames, + all_captions, + strict=True, + ): + art = self.tracker.video( + fname, caption=cap, + ) + if art is not None: + video_logs.append(art) + if video_logs: + logs = { + f"validation_videos_" + f"{num_inference_steps}" + f"_steps": video_logs + } + self.tracker.log_artifacts( + logs, step, + ) + else: + self.world_group.send_object( + result.videos, dst=0, + ) + self.world_group.send_object( + result.captions, dst=0, + ) + finally: + if hasattr(tc.pipeline_config, "dmd_denoising_steps"): + tc.pipeline_config.dmd_denoising_steps = ( + old_dmd_denoising_steps + ) + if was_training: + transformer.train() + + def _maybe_set_dmd_denoising_steps( + self, + tc: TrainingConfig, + 
num_inference_steps: int, + ) -> None: + """Set dmd_denoising_steps on pipeline_config for + streaming SDE validation.""" + if self.rollout_mode != "streaming": + return + if self.sampler_kind != "sde": + return + if self.sampling_timesteps is not None: + tc.pipeline_config.dmd_denoising_steps = ( # type: ignore[union-attr] + list(self.sampling_timesteps) + ) + else: + timesteps = np.linspace( + 1000, 0, int(num_inference_steps), + ) + tc.pipeline_config.dmd_denoising_steps = [ # type: ignore[union-attr] + int(max(0, min(1000, round(t)))) + for t in timesteps + ] + + # Also set any pipeline-specific kwargs from + # YAML (e.g. dmd_denoising_steps override). + pk = self.pipeline_kwargs + if "dmd_denoising_steps" in pk: + tc.pipeline_config.dmd_denoising_steps = [ # type: ignore[union-attr] + int(s) + for s in pk["dmd_denoising_steps"] + ] + + # ---------------------------------------------------------- + # Pipeline management + # ---------------------------------------------------------- + + def _get_sampling_param(self) -> SamplingParam: + if self._sampling_param is None: + self._sampling_param = ( + SamplingParam.from_pretrained( + self.training_config.model_path + ) + ) + return self._sampling_param + + def _get_pipeline( + self, + *, + transformer: torch.nn.Module, + ) -> Any: + key = ( + id(transformer), + self.rollout_mode, + self.sampler_kind, + self.scheduler_target, + ) + if ( + self._pipeline is not None + and self._pipeline_key == key + ): + return self._pipeline + + tc = self.training_config + PipelineCls = resolve_target(self.pipeline_target) + flow_shift = getattr( + tc.pipeline_config, "flow_shift", None, + ) + + kwargs: dict[str, Any] = { + "inference_mode": True, + "sampler_kind": self.sampler_kind, + "loaded_modules": { + "transformer": transformer, + }, + "tp_size": tc.distributed.tp_size, + "sp_size": tc.distributed.sp_size, + "num_gpus": tc.distributed.num_gpus, + "pin_cpu_memory": ( + tc.distributed.pin_cpu_memory + ), + "dit_cpu_offload": True, 
+ } + if flow_shift is not None: + kwargs["flow_shift"] = float(flow_shift) + + # Build and inject a scheduler if target is set. + scheduler = self._build_scheduler(flow_shift) + if scheduler is not None: + kwargs["loaded_modules"]["scheduler"] = ( + scheduler + ) + + self._pipeline = PipelineCls.from_pretrained( + tc.model_path, **kwargs, + ) + self._pipeline_key = key + return self._pipeline + + def _build_scheduler( + self, flow_shift: float | None, + ) -> Any | None: + """Build scheduler from ``scheduler_target``.""" + if self.scheduler_target is None: + return None + if flow_shift is None: + return None + + SchedulerCls = resolve_target( + self.scheduler_target + ) + return SchedulerCls(shift=float(flow_shift)) + + # ---------------------------------------------------------- + # Batch preparation + # ---------------------------------------------------------- + + def _prepare_validation_batch( + self, + sampling_param: SamplingParam, + validation_batch: dict[str, Any], + num_inference_steps: int, + ) -> ForwardBatch: + tc = self.training_config + + sampling_param.prompt = validation_batch["prompt"] + sampling_param.height = tc.data.num_height + sampling_param.width = tc.data.num_width + sampling_param.num_inference_steps = int( + num_inference_steps + ) + sampling_param.data_type = "video" + if self.guidance_scale is not None: + sampling_param.guidance_scale = float( + self.guidance_scale + ) + sampling_param.seed = self.seed + + # image_path for I2V pipelines. 
+ img_path = ( + validation_batch.get("image_path") + or validation_batch.get("video_path") + ) + if img_path is not None and ( + img_path.startswith("http") + or os.path.isfile(img_path) + ): + sampling_param.image_path = img_path + + temporal_compression_factor = int( + tc.pipeline_config.vae_config.arch_config.temporal_compression_ratio # type: ignore[union-attr] + ) + default_num_frames = ( + (tc.data.num_latent_t - 1) + * temporal_compression_factor + + 1 + ) + if self.num_frames is not None: + sampling_param.num_frames = int( + self.num_frames + ) + else: + sampling_param.num_frames = int( + default_num_frames + ) + + latents_size = [ + (sampling_param.num_frames - 1) // 4 + 1, + sampling_param.height // 8, + sampling_param.width // 8, + ] + n_tokens = ( + latents_size[0] + * latents_size[1] + * latents_size[2] + ) + + sampling_timesteps_tensor = ( + torch.tensor( + [int(s) for s in self.sampling_timesteps], + dtype=torch.long, + ) + if self.sampling_timesteps is not None + else None + ) + + inference_args = make_inference_args( + tc, model_path=tc.model_path, + ) + + batch = ForwardBatch( + **shallow_asdict(sampling_param), + latents=None, + generator=self.validation_random_generator, + n_tokens=n_tokens, + eta=0.0, + VSA_sparsity=tc.vsa.sparsity, + timesteps=sampling_timesteps_tensor, + sampling_timesteps=sampling_timesteps_tensor, + ) + batch._inference_args = inference_args # type: ignore[attr-defined] + + # Conditionally set I2V / WanGame fields. 
+ if ( + "image" in validation_batch + and validation_batch["image"] is not None + ): + batch.pil_image = validation_batch["image"] + + self._maybe_set_action_conds( + batch, validation_batch, sampling_param, + ) + return batch + + def _maybe_set_action_conds( + self, + batch: ForwardBatch, + validation_batch: dict[str, Any], + sampling_param: SamplingParam, + ) -> None: + """Set keyboard_cond / mouse_cond on the batch if + present in the dataset.""" + target_len = int(sampling_param.num_frames) + + if ( + "keyboard_cond" in validation_batch + and validation_batch["keyboard_cond"] + is not None + ): + kb = torch.as_tensor( + validation_batch["keyboard_cond"] + ).to(dtype=torch.bfloat16) + if kb.ndim == 3 and kb.shape[0] == 1: + kb = kb.squeeze(0) + if kb.ndim != 2: + raise ValueError( + "validation keyboard_cond must have" + " shape (T, K), got " + f"{tuple(kb.shape)}" + ) + if kb.shape[0] > target_len: + kb = kb[:target_len] + elif kb.shape[0] < target_len: + pad = torch.zeros( + ( + target_len - kb.shape[0], + kb.shape[1], + ), + dtype=kb.dtype, + device=kb.device, + ) + kb = torch.cat([kb, pad], dim=0) + batch.keyboard_cond = kb.unsqueeze(0) + + if ( + "mouse_cond" in validation_batch + and validation_batch["mouse_cond"] + is not None + ): + mc = torch.as_tensor( + validation_batch["mouse_cond"] + ).to(dtype=torch.bfloat16) + if mc.ndim == 3 and mc.shape[0] == 1: + mc = mc.squeeze(0) + if mc.ndim != 2: + raise ValueError( + "validation mouse_cond must have " + "shape (T, 2), got " + f"{tuple(mc.shape)}" + ) + if mc.shape[0] > target_len: + mc = mc[:target_len] + elif mc.shape[0] < target_len: + pad = torch.zeros( + ( + target_len - mc.shape[0], + mc.shape[1], + ), + dtype=mc.dtype, + device=mc.device, + ) + mc = torch.cat([mc, pad], dim=0) + batch.mouse_cond = mc.unsqueeze(0) + + # ---------------------------------------------------------- + # Post-processing + # ---------------------------------------------------------- + + def _post_process_validation_frames( 
+ self, + frames: list[np.ndarray], + batch: ForwardBatch, + ) -> list[np.ndarray]: + """Overlay action indicators if conditions present.""" + keyboard_cond = getattr(batch, "keyboard_cond", None) + mouse_cond = getattr(batch, "mouse_cond", None) + if keyboard_cond is None and mouse_cond is None: + return frames + + try: + from fastvideo.models.dits.matrixgame.utils import ( + draw_keys_on_frame, + draw_mouse_on_frame, + ) + except Exception as e: + logger.warning( + "Action overlay unavailable: %s", e, + ) + return frames + + if ( + keyboard_cond is not None + and torch.is_tensor(keyboard_cond) + ): + keyboard_np = ( + keyboard_cond.squeeze(0) + .detach() + .cpu() + .float() + .numpy() + ) + else: + keyboard_np = None + + if ( + mouse_cond is not None + and torch.is_tensor(mouse_cond) + ): + mouse_np = ( + mouse_cond.squeeze(0) + .detach() + .cpu() + .float() + .numpy() + ) + else: + mouse_np = None + + key_names = ["W", "S", "A", "D", "left", "right"] + processed: list[np.ndarray] = [] + for fi, frame in enumerate(frames): + frame = np.ascontiguousarray(frame.copy()) + if ( + keyboard_np is not None + and fi < len(keyboard_np) + ): + keys = { + key_names[i]: bool( + keyboard_np[fi, i] + ) + for i in range( + min( + len(key_names), + int(keyboard_np.shape[1]), + ) + ) + } + draw_keys_on_frame( + frame, keys, mode="universal", + ) + if ( + mouse_np is not None + and fi < len(mouse_np) + ): + pitch = float(mouse_np[fi, 0]) + yaw = float(mouse_np[fi, 1]) + draw_mouse_on_frame(frame, pitch, yaw) + processed.append(frame) + return processed + + # ---------------------------------------------------------- + # Validation loop + # ---------------------------------------------------------- + + def _run_validation_for_steps( + self, + num_inference_steps: int, + *, + transformer: torch.nn.Module, + ) -> _ValidationStepResult: + tc = self.training_config + pipeline = self._get_pipeline( + transformer=transformer, + ) + sampling_param = self._get_sampling_param() + + dataset 
= ValidationDataset(self.dataset_file) + dataloader = DataLoader( + dataset, batch_size=None, num_workers=0, + ) + + inference_args = make_inference_args( + tc, model_path=tc.model_path, + ) + + videos: list[list[np.ndarray]] = [] + captions: list[str] = [] + + for validation_batch in dataloader: + batch = self._prepare_validation_batch( + sampling_param, + validation_batch, + num_inference_steps, + ) + + assert ( + batch.prompt is not None + and isinstance(batch.prompt, str) + ) + captions.append(batch.prompt) + + with torch.no_grad(): + output_batch = pipeline.forward( + batch, inference_args, + ) + + samples = output_batch.output.cpu() + if self.rank_in_sp_group != 0: + continue + + video = rearrange( + samples, "b c t h w -> t b c h w", + ) + frames: list[np.ndarray] = [] + for x in video: + x = torchvision.utils.make_grid( + x, nrow=6, + ) + x = ( + x.transpose(0, 1) + .transpose(1, 2) + .squeeze(-1) + ) + frames.append( + (x * 255).numpy().astype(np.uint8) + ) + frames = ( + self._post_process_validation_frames( + frames, batch, + ) + ) + videos.append(frames) + + return _ValidationStepResult( + videos=videos, captions=captions, + ) + + # ---------------------------------------------------------- + # State management + # ---------------------------------------------------------- + + def state_dict(self) -> dict[str, Any]: + state: dict[str, Any] = {} + if self.validation_random_generator is not None: + state["validation_rng"] = ( + self.validation_random_generator.get_state() + ) + return state + + def load_state_dict( + self, state_dict: dict[str, Any], + ) -> None: + rng_state = state_dict.get("validation_rng") + if ( + rng_state is not None + and self.validation_random_generator is not None + ): + self.validation_random_generator.set_state( + rng_state + ) diff --git a/fastvideo/train/entrypoint/__init__.py b/fastvideo/train/entrypoint/__init__.py new file mode 100644 index 000000000..988131360 --- /dev/null +++ b/fastvideo/train/entrypoint/__init__.py @@ 
-0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/fastvideo/train/entrypoint/dcp_to_diffusers.py b/fastvideo/train/entrypoint/dcp_to_diffusers.py new file mode 100644 index 000000000..a62dde639 --- /dev/null +++ b/fastvideo/train/entrypoint/dcp_to_diffusers.py @@ -0,0 +1,416 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Convert a DCP training checkpoint to a diffusers-style model directory. + +Works on a single GPU regardless of how many GPUs were used for training +(DCP handles resharding automatically). + +Usage (no torchrun needed):: + + python -m fastvideo.train.entrypoint.dcp_to_diffusers \ + --checkpoint /path/to/checkpoint-1000 \ + --output-dir /path/to/diffusers_output + +Or with torchrun (also fine):: + + torchrun --nproc_per_node=1 \ + -m fastvideo.train.entrypoint.dcp_to_diffusers \ + --checkpoint ... --output-dir ... + +The checkpoint must contain ``metadata.json`` (written by +``CheckpointManager``). If the checkpoint predates metadata +support, pass ``--config`` explicitly to provide the training +YAML. +""" + +from __future__ import annotations + +import argparse +import os +import sys +from typing import Any + +from fastvideo.logger import init_logger + +logger = init_logger(__name__) + + +def _ensure_distributed() -> None: + """Set up a single-process distributed env if needed. + + When running under ``torchrun`` the env vars are already set. + For plain ``python`` we fill in the minimum required vars so + that ``init_process_group`` succeeds with world_size=1. + """ + for key, default in [ + ("RANK", "0"), + ("LOCAL_RANK", "0"), + ("WORLD_SIZE", "1"), + ("MASTER_ADDR", "127.0.0.1"), + ("MASTER_PORT", "29500"), + ]: + os.environ.setdefault(key, default) + + +def _save_role_pretrained( + *, + role: str, + base_model_path: str, + output_dir: str, + module_names: list[str] | None = None, + overwrite: bool = False, + model: Any, +) -> str: + """Export a role's modules into a diffusers-style model dir. 
+ + Produces a ``model_path`` loadable by + ``PipelineComponentLoader`` (``model_index.json``, + ``transformer/``, ``vae/``, etc. copied from + ``base_model_path``). + """ + import shutil + from pathlib import Path + + import torch + import torch.distributed as dist + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + ) + + from fastvideo.utils import maybe_download_model + + def _rank() -> int: + if dist.is_available() and dist.is_initialized(): + return int(dist.get_rank()) + return 0 + + def _barrier() -> None: + if dist.is_available() and dist.is_initialized(): + dist.barrier() + + local_base = Path( + maybe_download_model(str(base_model_path)) + ).resolve() + dst = Path( + os.path.expanduser(str(output_dir)) + ).resolve() + + if _rank() == 0: + if dst.exists(): + if overwrite: + shutil.rmtree(dst, ignore_errors=True) + else: + raise FileExistsError( + f"Refusing to overwrite existing " + f"directory: {dst}. " + "Pass --overwrite to replace it." + ) + + def _copy_or_link(src: str, dest: str) -> None: + try: + os.link(src, dest) + except OSError: + shutil.copy2(src, dest) + + logger.info( + "Creating pretrained export dir at %s " + "(base=%s)", dst, local_base, + ) + shutil.copytree( + local_base, dst, symlinks=True, + copy_function=_copy_or_link, + ) + + _barrier() + + modules: dict[str, torch.nn.Module] = {} + if model.transformer is not None: + modules["transformer"] = model.transformer + + if module_names is None: + module_names = sorted(modules.keys()) + + for module_name in module_names: + if module_name not in modules: + raise KeyError( + f"Role {role!r} does not have module " + f"{module_name!r}. 
" + f"Available: {sorted(modules.keys())}" + ) + + module_dir = dst / module_name + if not module_dir.is_dir(): + raise FileNotFoundError( + f"Export directory missing component " + f"dir {module_name!r}: {module_dir}" + ) + + options = StateDictOptions( + full_state_dict=True, cpu_offload=True, + ) + state_dict = get_model_state_dict( + modules[module_name], options=options, + ) + + if _rank() == 0: + for path in module_dir.glob("*.safetensors"): + path.unlink(missing_ok=True) + + tensor_state: dict[str, torch.Tensor] = {} + for key, value in state_dict.items(): + if not isinstance(value, torch.Tensor): + raise TypeError( + f"Expected tensor in state_dict " + f"for {module_name}.{key}, " + f"got {type(value).__name__}" + ) + tensor_state[key] = value.detach().cpu() + + from safetensors.torch import save_file + + out_path = module_dir / "model.safetensors" + logger.info( + "Saving %s weights to %s (%s tensors)", + module_name, out_path, + len(tensor_state), + ) + save_file(tensor_state, str(out_path)) + + _barrier() + + return str(dst) + + +def convert( + *, + checkpoint_dir: str, + output_dir: str, + config_path: str | None = None, + role: str = "student", + overwrite: bool = False, +) -> str: + """Load a DCP checkpoint and export as a diffusers model. + + Returns the path to the exported model directory. 
+ """ + _ensure_distributed() + + from fastvideo.distributed import ( + maybe_init_distributed_environment_and_model_parallel, + ) + from fastvideo.train.utils.builder import build_from_config + from fastvideo.train.utils.checkpoint import ( + CheckpointManager, + _resolve_resume_checkpoint, + ) + from fastvideo.train.utils.config import ( + RunConfig, + load_run_config, + ) + + import torch.distributed.checkpoint as dcp + + # -- Resolve checkpoint directory -- + resolved = _resolve_resume_checkpoint( + checkpoint_dir, output_dir=checkpoint_dir, + ) + dcp_dir = resolved / "dcp" + if not dcp_dir.is_dir(): + raise FileNotFoundError( + f"Missing dcp/ under {resolved}" + ) + + # -- Obtain config -- + cfg: RunConfig + if config_path is not None: + cfg = load_run_config(config_path) + else: + metadata = CheckpointManager.load_metadata( + resolved + ) + raw_config = metadata.get("config") + if raw_config is None: + raise ValueError( + "Checkpoint metadata.json does not " + "contain 'config'. Pass --config " + "explicitly." + ) + cfg = _run_config_from_raw(raw_config) + + tc = cfg.training + + # -- Init distributed (1 GPU is enough; DCP reshards) -- + maybe_init_distributed_environment_and_model_parallel( + tp_size=1, sp_size=1, + ) + + # Override distributed config so model loading uses 1 GPU. + tc.distributed.tp_size = 1 + tc.distributed.sp_size = 1 + tc.distributed.num_gpus = 1 + tc.distributed.hsdp_replicate_dim = 1 + tc.distributed.hsdp_shard_dim = 1 + + # -- Build model (loads pretrained weights + FSDP) -- + _, method, _, _ = build_from_config(cfg) + + # -- Load DCP weights into the model -- + states = method.checkpoint_state() + logger.info( + "Loading DCP checkpoint from %s", resolved, + ) + dcp.load(states, checkpoint_id=str(dcp_dir)) + + # -- Export to diffusers format -- + model = method._role_models[role] + base_model_path = str(tc.model_path) + if not base_model_path: + raise ValueError( + "Cannot determine base_model_path from " + "config. 
Ensure models.student.init_from " + "is set." + ) + + logger.info( + "Exporting role=%s to %s (base=%s)", + role, + output_dir, + base_model_path, + ) + result = _save_role_pretrained( + role=role, + base_model_path=base_model_path, + output_dir=output_dir, + overwrite=overwrite, + model=model, + ) + logger.info("Export complete: %s", result) + return result + + +def _run_config_from_raw( + raw: dict[str, Any], +) -> Any: + """Reconstruct a RunConfig from a raw config dict. + + This mirrors ``load_run_config`` but operates on an + already-parsed dict (from metadata.json) instead of + reading from a YAML file. + """ + from fastvideo.train.utils.config import ( + RunConfig, + _build_training_config, + _parse_pipeline_config, + _require_mapping, + _require_str, + ) + + models_raw = _require_mapping( + raw.get("models"), where="models", + ) + models: dict[str, dict[str, Any]] = {} + for role_key, model_cfg_raw in models_raw.items(): + role_str = _require_str( + role_key, where="models.", + ) + model_cfg = _require_mapping( + model_cfg_raw, + where=f"models.{role_str}", + ) + models[role_str] = dict(model_cfg) + + method_raw = _require_mapping( + raw.get("method"), where="method", + ) + method = dict(method_raw) + + callbacks_raw = raw.get("callbacks", None) + callbacks: dict[str, dict[str, Any]] = ( + _require_mapping( + callbacks_raw, where="callbacks", + ) + if callbacks_raw is not None + else {} + ) + + pipeline_config = _parse_pipeline_config( + raw, models=models, + ) + + training_raw = _require_mapping( + raw.get("training"), where="training", + ) + t = dict(training_raw) + training = _build_training_config( + t, + models=models, + pipeline_config=pipeline_config, + ) + + return RunConfig( + models=models, + method=method, + training=training, + callbacks=callbacks, + raw=raw, + ) + + +def main() -> None: + parser = argparse.ArgumentParser( + description=( + "Convert a DCP training checkpoint to a " + "diffusers-style model directory. 
" + "Only 1 GPU needed (DCP reshards " + "automatically)." + ), + ) + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help=( + "Path to checkpoint- dir, its dcp/ " + "subdir, or an output_dir (auto-picks " + "latest)." + ), + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Destination for the diffusers model.", + ) + parser.add_argument( + "--config", + type=str, + default=None, + help=( + "Training YAML config. If omitted, read " + "from checkpoint metadata.json." + ), + ) + parser.add_argument( + "--role", + type=str, + default="student", + help="Role to export (default: student).", + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="Overwrite output-dir if it exists.", + ) + args = parser.parse_args(sys.argv[1:]) + + convert( + checkpoint_dir=args.checkpoint, + output_dir=args.output_dir, + config_path=args.config, + role=args.role, + overwrite=args.overwrite, + ) + + +if __name__ == "__main__": + main() diff --git a/fastvideo/train/entrypoint/train.py b/fastvideo/train/entrypoint/train.py new file mode 100644 index 000000000..1c253004e --- /dev/null +++ b/fastvideo/train/entrypoint/train.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +"""YAML-only training entrypoint. 
+ +Usage:: + + torchrun --nproc_per_node= -m fastvideo.train.entrypoint.train \ + --config path/to/run.yaml +""" + +from __future__ import annotations + +import argparse +import os +import sys +from typing import Any + +from fastvideo.logger import init_logger + +logger = init_logger(__name__) + + +def run_training_from_config( + config_path: str, + *, + dry_run: bool = False, + resume_from_checkpoint: str | None = None, + override_output_dir: str | None = None, +) -> None: + """YAML-only training entrypoint (schema v2).""" + + from fastvideo.distributed import ( + maybe_init_distributed_environment_and_model_parallel, + ) + from fastvideo.train import Trainer + from fastvideo.train.utils.checkpoint import ( + CheckpointConfig, + CheckpointManager, + ) + from fastvideo.train.utils.builder import build_from_config + from fastvideo.train.utils.config import load_run_config + + cfg = load_run_config(config_path) + tc = cfg.training + + if resume_from_checkpoint is not None: + tc.checkpoint.resume_from_checkpoint = str( + resume_from_checkpoint + ) + if override_output_dir is not None: + tc.checkpoint.output_dir = str(override_output_dir) + + maybe_init_distributed_environment_and_model_parallel( + tc.distributed.tp_size, + tc.distributed.sp_size, + ) + + _, method, dataloader, start_step = build_from_config( + cfg + ) + + if dry_run: + logger.info( + "Dry-run: config parsed and " + "build_from_config succeeded." + ) + return + + trainer = Trainer( + tc, + config=cfg.resolved_config(), + callback_configs=cfg.callbacks, + ) + + # Attach the exact YAML used for this run to the + # tracker (e.g., W&B Files). 
+ trainer.tracker.log_file( + os.path.abspath(os.path.expanduser(config_path)), + name="run.yaml", + ) + + ckpt_config = CheckpointConfig( + save_steps=int( + tc.checkpoint.training_state_checkpointing_steps + or 0 + ), + keep_last=int( + tc.checkpoint.checkpoints_total_limit or 0 + ), + ) + + checkpoint_manager = CheckpointManager( + method=method, + dataloader=dataloader, + output_dir=tc.checkpoint.output_dir, + config=ckpt_config, + callbacks=trainer.callbacks, + raw_config=cfg.raw, + ) + + trainer.run( + method, + dataloader=dataloader, + max_steps=tc.loop.max_train_steps, + start_step=start_step, + checkpoint_manager=checkpoint_manager, + ) + + +def main(args: Any) -> None: + config_path = str(args.config) + dry_run = bool(args.dry_run) + resume_from_checkpoint = getattr( + args, "resume_from_checkpoint", None + ) + override_output_dir = getattr( + args, "override_output_dir", None + ) + logger.info( + "Starting training from config=%s", + config_path, + ) + run_training_from_config( + config_path, + dry_run=dry_run, + resume_from_checkpoint=resume_from_checkpoint, + override_output_dir=override_output_dir, + ) + logger.info("Training completed") + + +if __name__ == "__main__": + argv = sys.argv + parser = argparse.ArgumentParser( + description="YAML-only training entrypoint.", + ) + parser.add_argument( + "--config", + type=str, + required=True, + help=( + "Path to training YAML config (schema v2)." + ), + ) + parser.add_argument( + "--dry-run", + action="store_true", + help=( + "Parse config and build runtime, " + "but do not start training." + ), + ) + parser.add_argument( + "--resume-from-checkpoint", + type=str, + default=None, + help=( + "Path to a checkpoint directory " + "(checkpoint-), its 'dcp/' subdir, " + "or an output_dir containing checkpoints " + "(auto-picks latest)." + ), + ) + parser.add_argument( + "--override-output-dir", + type=str, + default=None, + help=( + "Override training.output_dir from YAML " + "(useful for repeated runs)." 
+ ), + ) + args = parser.parse_args(argv[1:]) + main(args) diff --git a/fastvideo/train/methods/__init__.py b/fastvideo/train/methods/__init__.py new file mode 100644 index 000000000..61fd6ef2e --- /dev/null +++ b/fastvideo/train/methods/__init__.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 + +from fastvideo.train.methods.base import TrainingMethod + +__all__ = [ + "TrainingMethod", + "DMD2Method", + "FineTuneMethod", + "SelfForcingMethod", + "DiffusionForcingSFTMethod", +] + + +def __getattr__(name: str) -> object: + if name == "DMD2Method": + from fastvideo.train.methods.distribution_matching.dmd2 import DMD2Method + return DMD2Method + if name == "FineTuneMethod": + from fastvideo.train.methods.fine_tuning.finetune import FineTuneMethod + return FineTuneMethod + if name == "SelfForcingMethod": + from fastvideo.train.methods.distribution_matching.self_forcing import SelfForcingMethod + return SelfForcingMethod + if name == "DiffusionForcingSFTMethod": + from fastvideo.train.methods.fine_tuning.dfsft import DiffusionForcingSFTMethod + return DiffusionForcingSFTMethod + raise AttributeError(name) diff --git a/fastvideo/train/methods/base.py b/fastvideo/train/methods/base.py new file mode 100644 index 000000000..06b63a3d7 --- /dev/null +++ b/fastvideo/train/methods/base.py @@ -0,0 +1,263 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import copy +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import Any, Literal, cast + +import torch + +from fastvideo.logger import init_logger +from fastvideo.train.models.base import ModelBase +from fastvideo.train.utils.checkpoint import _RoleModuleContainer +from fastvideo.training.checkpointing_utils import ( + ModelWrapper, + OptimizerWrapper, + RandomStateWrapper, + SchedulerWrapper, +) + +logger = init_logger(__name__) + +LogScalar = float | int | torch.Tensor + + +class TrainingMethod(torch.nn.Module, ABC): + """Base training method 
(algorithm layer). + + Subclasses own their role models (student, teacher, critic, …) as + plain attributes and manage optimizers directly — no ``RoleManager`` + or ``RoleHandle``. + + The constructor receives *role_models* (a ``dict[str, ModelBase]``) + and a *cfg* object. It calls ``init_preprocessors`` on the student + and builds ``self.role_modules`` for FSDP wrapping. + """ + + def __init__( + self, + *, + cfg: Any, + role_models: dict[str, ModelBase], + ) -> None: + super().__init__() + self.tracker: Any | None = None + self._role_models: dict[str, ModelBase] = dict(role_models) + + self.student = role_models["student"] + self.training_config = cfg.training + self.method_config: dict[str, Any] = dict(cfg.method) + self.validation_config: dict[str, Any] = dict( + getattr(cfg, "validation", {}) or {} + ) + self._use_ema: bool = bool( + self.method_config.get("use_ema", False) + ) + + # Build nn.ModuleDict for FSDP / checkpoint visibility. + self.role_modules = torch.nn.ModuleDict() + for role, model in role_models.items(): + mods: dict[str, torch.nn.Module] = {} + transformer = getattr(model, "transformer", None) + if isinstance(transformer, torch.nn.Module): + mods["transformer"] = transformer + if mods: + self.role_modules[role] = torch.nn.ModuleDict(mods) + + self._setup_ema() + + # ------------------------------------------------------------------ + # EMA + # ------------------------------------------------------------------ + + def _setup_ema(self) -> None: + """Create EMA copy of student transformer. + + Called at the end of ``__init__``, before FSDP wrapping. + Only acts when ``use_ema: true`` is set in method config. + """ + if not self._use_ema: + return + logger.info( + "Initializing EMA from student transformer", + ) + ema = copy.deepcopy(self.student.transformer) + ema.eval().requires_grad_(False) + self.ema = ema + # Register in role_modules for FSDP / checkpoint. 
+ if "student" not in self.role_modules: + self.role_modules["student"] = ( + torch.nn.ModuleDict() + ) + self.role_modules["student"]["ema"] = ema # type: ignore[index] + + @property + def transformer_inference(self) -> torch.nn.Module: + """Return EMA transformer for inference if available.""" + if self._use_ema: + ema = getattr(self, "ema", None) + if ema is not None: + return ema + return self.student.transformer + + # ------------------------------------------------------------------ + + def set_tracker(self, tracker: Any) -> None: + self.tracker = tracker + + @abstractmethod + def single_train_step( + self, + batch: dict[str, Any], + iteration: int, + *, + current_vsa_sparsity: float = 0.0, + ) -> tuple[ + dict[str, torch.Tensor], + dict[str, Any], + dict[str, LogScalar], + ]: + raise NotImplementedError + + @abstractmethod + def get_optimizers( + self, iteration: int, + ) -> Sequence[torch.optim.Optimizer]: + raise NotImplementedError + + @abstractmethod + def get_lr_schedulers( + self, iteration: int, + ) -> Sequence[Any]: + raise NotImplementedError + + @property + @abstractmethod + def _optimizer_dict(self) -> dict[str, Any]: + ... + + @property + @abstractmethod + def _lr_scheduler_dict(self) -> dict[str, Any]: + ... + + def checkpoint_state(self) -> dict[str, Any]: + """Return DCP-ready checkpoint state for all trainable roles. + + Keys follow the convention: + ``roles..``, ``optimizers.``, + ``schedulers.``, ``random_state.*``. 
+ """ + states: dict[str, Any] = {} + + for role, model in self._role_models.items(): + if not getattr(model, "_trainable", False): + continue + + modules: dict[str, torch.nn.Module] = {} + if model.transformer is not None: + modules["transformer"] = model.transformer + ema = getattr(self, "ema", None) + if role == "student" and ema is not None: + modules["ema"] = ema + + container = _RoleModuleContainer(modules) + + for module_name, module in modules.items(): + states[ + f"roles.{role}.{module_name}" + ] = ModelWrapper(module) + + opt = self._optimizer_dict.get(role) + if opt is not None: + states[ + f"optimizers.{role}" + ] = OptimizerWrapper(container, opt) + + sched = self._lr_scheduler_dict.get(role) + if sched is not None: + states[ + f"schedulers.{role}" + ] = SchedulerWrapper(sched) + + # RNG states. + states["random_state"] = RandomStateWrapper(None) + for name, gen in ( + self.get_rng_generators() or {} + ).items(): + if gen is not None: + states[ + f"random_state.{name}" + ] = RandomStateWrapper(gen) + + return states + + def backward( + self, + loss_map: dict[str, torch.Tensor], + outputs: dict[str, Any], + *, + grad_accum_rounds: int = 1, + ) -> None: + del outputs + grad_accum_rounds = max(1, int(grad_accum_rounds)) + (loss_map["total_loss"] / grad_accum_rounds).backward() + + def optimizers_schedulers_step( + self, iteration: int, + ) -> None: + for optimizer in self.get_optimizers(iteration): + optimizer.step() + for scheduler in self.get_lr_schedulers(iteration): + scheduler.step() + + def optimizers_zero_grad( + self, iteration: int, + ) -> None: + for optimizer in self.get_optimizers(iteration): + try: + optimizer.zero_grad(set_to_none=True) + except TypeError: + optimizer.zero_grad() + + # -- Shared hooks (override in subclasses as needed) -- + + def get_grad_clip_targets( + self, iteration: int, + ) -> dict[str, torch.nn.Module]: + """Return modules whose gradients should be clipped. 
+ + Override in subclasses to add/conditionally include + modules (e.g. critic, conditionally student). + Default: student transformer. + """ + return {"student": self.student.transformer} + + def on_train_start(self) -> None: + self.student.on_train_start() + + def get_rng_generators( + self, + ) -> dict[str, torch.Generator]: + generators: dict[str, torch.Generator] = {} + + student_gens = self.student.get_rng_generators() + generators.update(student_gens) + + return generators + + @staticmethod + def _parse_attn_kind( + raw: Any, + ) -> Literal["dense", "vsa"]: + if raw in (None, ""): + return "dense" + kind = str(raw).strip().lower() + if kind not in {"dense", "vsa"}: + raise ValueError( + "method_config.attn_kind must be one of " + f"{{'dense', 'vsa'}}, got {raw!r}." + ) + return cast(Literal["dense", "vsa"], kind) diff --git a/fastvideo/train/methods/consistency_model/__init__.py b/fastvideo/train/methods/consistency_model/__init__.py new file mode 100644 index 000000000..324710b84 --- /dev/null +++ b/fastvideo/train/methods/consistency_model/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 + +__all__: list[str] = [] diff --git a/fastvideo/train/methods/distribution_matching/__init__.py b/fastvideo/train/methods/distribution_matching/__init__.py new file mode 100644 index 000000000..4edb43cf7 --- /dev/null +++ b/fastvideo/train/methods/distribution_matching/__init__.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 + +from fastvideo.train.methods.distribution_matching.dmd2 import DMD2Method +from fastvideo.train.methods.distribution_matching.self_forcing import ( + SelfForcingMethod, ) + +__all__ = [ + "DMD2Method", + "SelfForcingMethod", +] diff --git a/fastvideo/train/methods/distribution_matching/dmd2.py b/fastvideo/train/methods/distribution_matching/dmd2.py new file mode 100644 index 000000000..9c2e07ef9 --- /dev/null +++ b/fastvideo/train/methods/distribution_matching/dmd2.py @@ -0,0 +1,745 @@ +# SPDX-License-Identifier: 
Apache-2.0 +"""DMD2 distillation method (algorithm layer).""" + +from __future__ import annotations + +from typing import Any, Literal + +import torch +import torch.nn.functional as F + +from fastvideo.train.methods.base import TrainingMethod, LogScalar +from fastvideo.train.models.base import ModelBase +from fastvideo.train.utils.optimizer import ( + build_optimizer_and_scheduler, +) +from fastvideo.train.utils.config import ( + get_optional_float, + get_optional_int, + parse_betas, +) + + +class DMD2Method(TrainingMethod): + """DMD2 distillation algorithm (method layer). + + Owns role model instances directly: + - ``self.student`` — trainable student :class:`ModelBase` + - ``self.teacher`` — frozen teacher :class:`ModelBase` + - ``self.critic`` — trainable critic :class:`ModelBase` + """ + + def __init__( + self, + *, + cfg: Any, + role_models: dict[str, ModelBase], + ) -> None: + super().__init__(cfg=cfg, role_models=role_models) + + if "student" not in role_models: + raise ValueError( + "DMD2Method requires role 'student'" + ) + if "teacher" not in role_models: + raise ValueError( + "DMD2Method requires role 'teacher'" + ) + if "critic" not in role_models: + raise ValueError( + "DMD2Method requires role 'critic'" + ) + + self.teacher = role_models["teacher"] + self.critic = role_models["critic"] + + if not self.student._trainable: + raise ValueError( + "DMD2Method requires student to be trainable" + ) + if self.teacher._trainable: + raise ValueError( + "DMD2Method requires teacher to be " + "non-trainable" + ) + if not self.critic._trainable: + raise ValueError( + "DMD2Method requires critic to be trainable" + ) + self._cfg_uncond = self._parse_cfg_uncond() + self._rollout_mode = self._parse_rollout_mode() + self._denoising_step_list: torch.Tensor | None = ( + None + ) + + # Initialize preprocessors on student. 
+ self.student.init_preprocessors(self.training_config) + + self._init_optimizers_and_schedulers() + + @property + def _optimizer_dict( + self, + ) -> dict[str, torch.optim.Optimizer]: + return { + "student": self._student_optimizer, + "critic": self._critic_optimizer, + } + + @property + def _lr_scheduler_dict(self) -> dict[str, Any]: + return { + "student": self._student_lr_scheduler, + "critic": self._critic_lr_scheduler, + } + + # TrainingMethod override: single_train_step + def single_train_step( + self, + batch: dict[str, Any], + iteration: int, + *, + current_vsa_sparsity: float = 0.0, + ) -> tuple[ + dict[str, torch.Tensor], + dict[str, Any], + dict[str, LogScalar], + ]: + latents_source: Literal["data", "zeros"] = "data" + if self._rollout_mode == "simulate": + latents_source = "zeros" + + training_batch = self.student.prepare_batch( + batch, + current_vsa_sparsity=current_vsa_sparsity, + latents_source=latents_source, + ) + + update_student = self._should_update_student( + iteration + ) + + generator_loss = torch.zeros( + (), + device=training_batch.latents.device, + dtype=training_batch.latents.dtype, + ) + student_ctx = None + if update_student: + generator_pred_x0 = self._student_rollout( + training_batch, with_grad=True + ) + student_ctx = ( + training_batch.timesteps, + training_batch.attn_metadata_vsa, + ) + generator_loss = self._dmd_loss( + generator_pred_x0, training_batch + ) + + ( + fake_score_loss, + critic_ctx, + critic_outputs, + ) = self._critic_flow_matching_loss(training_batch) + + total_loss = generator_loss + fake_score_loss + loss_map = { + "total_loss": total_loss, + "generator_loss": generator_loss, + "fake_score_loss": fake_score_loss, + } + + outputs: dict[str, Any] = dict(critic_outputs) + outputs["_fv_backward"] = { + "update_student": update_student, + "student_ctx": student_ctx, + "critic_ctx": critic_ctx, + } + metrics: dict[str, LogScalar] = { + "update_student": float(update_student) + } + return loss_map, outputs, metrics + 
+ # TrainingMethod override: backward + def backward( + self, + loss_map: dict[str, torch.Tensor], + outputs: dict[str, Any], + *, + grad_accum_rounds: int = 1, + ) -> None: + grad_accum_rounds = max(1, int(grad_accum_rounds)) + backward_ctx = outputs.get("_fv_backward") + if not isinstance(backward_ctx, dict): + super().backward( + loss_map, + outputs, + grad_accum_rounds=grad_accum_rounds, + ) + return + + update_student = bool( + backward_ctx.get("update_student", False) + ) + if update_student: + student_ctx = backward_ctx.get("student_ctx") + if student_ctx is None: + raise RuntimeError( + "Missing student backward context" + ) + self.student.backward( + loss_map["generator_loss"], + student_ctx, + grad_accum_rounds=grad_accum_rounds, + ) + + critic_ctx = backward_ctx.get("critic_ctx") + if critic_ctx is None: + raise RuntimeError( + "Missing critic backward context" + ) + self.critic.backward( + loss_map["fake_score_loss"], + critic_ctx, + grad_accum_rounds=grad_accum_rounds, + ) + + # TrainingMethod override: get_optimizers + def get_optimizers( + self, iteration: int, + ) -> list[torch.optim.Optimizer]: + optimizers: list[torch.optim.Optimizer] = [] + optimizers.append(self._critic_optimizer) + if self._should_update_student(iteration): + optimizers.append(self._student_optimizer) + return optimizers + + # TrainingMethod override: get_lr_schedulers + def get_lr_schedulers( + self, iteration: int, + ) -> list[Any]: + schedulers: list[Any] = [] + schedulers.append(self._critic_lr_scheduler) + if self._should_update_student(iteration): + schedulers.append(self._student_lr_scheduler) + return schedulers + + # TrainingMethod override: get_grad_clip_targets + def get_grad_clip_targets( + self, iteration: int, + ) -> dict[str, torch.nn.Module]: + targets: dict[str, torch.nn.Module] = {} + if self._should_update_student(iteration): + targets["student"] = ( + self.student.transformer + ) + targets["critic"] = self.critic.transformer + return targets + + def 
_parse_rollout_mode( + self, + ) -> Literal["simulate", "data_latent"]: + raw = self.method_config.get( + "rollout_mode", None + ) + if raw is None: + raise ValueError( + "method_config.rollout_mode must be set " + "for DMD2" + ) + if not isinstance(raw, str): + raise ValueError( + "method_config.rollout_mode must be a " + "string, " + f"got {type(raw).__name__}" + ) + mode = raw.strip().lower() + if mode in ("simulate", "sim"): + return "simulate" + if mode in ("data_latent", "data", "vae_latent"): + return "data_latent" + raise ValueError( + "method_config.rollout_mode must be one of " + "{simulate, data_latent}, got " + f"{raw!r}" + ) + + def _parse_cfg_uncond( + self, + ) -> dict[str, Any] | None: + raw = self.method_config.get("cfg_uncond", None) + if raw is None: + return None + if not isinstance(raw, dict): + raise ValueError( + "method_config.cfg_uncond must be a dict " + f"when set, got {type(raw).__name__}" + ) + + cfg: dict[str, Any] = dict(raw) + + on_missing_raw = cfg.get("on_missing", "error") + if on_missing_raw is None: + on_missing_raw = "error" + if not isinstance(on_missing_raw, str): + raise ValueError( + "method_config.cfg_uncond.on_missing must " + "be a string, got " + f"{type(on_missing_raw).__name__}" + ) + on_missing = on_missing_raw.strip().lower() + if on_missing not in {"error", "ignore"}: + raise ValueError( + "method_config.cfg_uncond.on_missing must " + "be one of {error, ignore}, got " + f"{on_missing_raw!r}" + ) + cfg["on_missing"] = on_missing + + for channel, policy_raw in list(cfg.items()): + if channel == "on_missing": + continue + if policy_raw is None: + continue + if not isinstance(policy_raw, str): + raise ValueError( + "method_config.cfg_uncond values must " + "be strings, got " + f"{channel}=" + f"{type(policy_raw).__name__}" + ) + policy = policy_raw.strip().lower() + allowed = {"keep", "zero", "drop"} + if channel == "text": + allowed = {*allowed, "negative_prompt"} + if policy not in allowed: + raise ValueError( + 
"method_config.cfg_uncond values must " + "be one of " + f"{sorted(allowed)}, got " + f"{channel}={policy_raw!r}" + ) + cfg[channel] = policy + + return cfg + + def _init_optimizers_and_schedulers(self) -> None: + tc = self.training_config + + # Student optimizer/scheduler. + student_lr = float(tc.optimizer.learning_rate) + student_betas = tc.optimizer.betas + student_sched = str(tc.optimizer.lr_scheduler) + student_params = [ + p + for p in self.student.transformer.parameters() + if p.requires_grad + ] + ( + self._student_optimizer, + self._student_lr_scheduler, + ) = build_optimizer_and_scheduler( + params=student_params, + optimizer_config=tc.optimizer, + loop_config=tc.loop, + learning_rate=student_lr, + betas=student_betas, + scheduler_name=student_sched, + ) + + # Critic optimizer/scheduler — must be set in + # method config. + critic_lr_raw = get_optional_float( + self.method_config, + "fake_score_learning_rate", + where="method.fake_score_learning_rate", + ) + if critic_lr_raw is None or critic_lr_raw == 0.0: + raise ValueError( + "method.fake_score_learning_rate must " + "be set to a positive value" + ) + critic_lr = float(critic_lr_raw) + + critic_betas_raw = self.method_config.get( + "fake_score_betas", None + ) + if critic_betas_raw is None: + raise ValueError( + "method.fake_score_betas must be set " + "(e.g. [0.0, 0.999])" + ) + critic_betas = parse_betas( + critic_betas_raw, + where="method.fake_score_betas", + ) + + critic_sched_raw = self.method_config.get( + "fake_score_lr_scheduler", None + ) + if critic_sched_raw is None: + raise ValueError( + "method.fake_score_lr_scheduler must " + "be set (e.g. 
'constant')" + ) + critic_sched = str(critic_sched_raw) + critic_params = [ + p + for p in self.critic.transformer.parameters() + if p.requires_grad + ] + ( + self._critic_optimizer, + self._critic_lr_scheduler, + ) = build_optimizer_and_scheduler( + params=critic_params, + optimizer_config=tc.optimizer, + loop_config=tc.loop, + learning_rate=critic_lr, + betas=critic_betas, + scheduler_name=critic_sched, + ) + + def _should_update_student( + self, iteration: int, + ) -> bool: + interval = get_optional_int( + self.method_config, + "generator_update_interval", + where="method.generator_update_interval", + ) + if interval is None: + interval = 1 + if interval <= 0: + return True + return iteration % interval == 0 + + def _get_denoising_step_list( + self, device: torch.device, + ) -> torch.Tensor: + if ( + self._denoising_step_list is not None + and self._denoising_step_list.device == device + ): + return self._denoising_step_list + + raw = self.method_config.get( + "dmd_denoising_steps", None + ) + if not isinstance(raw, list) or not raw: + raise ValueError( + "method_config.dmd_denoising_steps must " + "be set for DMD2 distillation" + ) + + steps = torch.tensor( + [int(s) for s in raw], + dtype=torch.long, + device=device, + ) + + warp = self.method_config.get( + "warp_denoising_step", None + ) + if warp is None: + warp = False + if bool(warp): + timesteps = torch.cat(( + self.student.noise_scheduler.timesteps.to( + "cpu" + ), + torch.tensor( + [0], dtype=torch.float32 + ), + )).to(device) + steps = timesteps[1000 - steps] + + self._denoising_step_list = steps + return steps + + def _sample_rollout_timestep( + self, device: torch.device, + ) -> torch.Tensor: + step_list = self._get_denoising_step_list(device) + index = torch.randint( + 0, + len(step_list), + [1], + device=device, + dtype=torch.long, + ) + return step_list[index] + + def _student_rollout( + self, batch: Any, *, with_grad: bool, + ) -> torch.Tensor: + latents = batch.latents + device = latents.device 
+ dtype = latents.dtype + step_list = self._get_denoising_step_list(device) + + if self._rollout_mode != "simulate": + timestep = self._sample_rollout_timestep( + device + ) + noise = torch.randn( + latents.shape, device=device, dtype=dtype + ) + noisy_latents = self.student.add_noise( + latents, noise, timestep + ) + pred_x0 = self.student.predict_x0( + noisy_latents, + timestep, + batch, + conditional=True, + cfg_uncond=self._cfg_uncond, + attn_kind="vsa", + ) + batch.dmd_latent_vis_dict[ + "generator_timestep" + ] = timestep + return pred_x0 + + target_timestep_idx = torch.randint( + 0, + len(step_list), + [1], + device=device, + dtype=torch.long, + ) + target_timestep_idx_int = int( + target_timestep_idx.item() + ) + target_timestep = step_list[target_timestep_idx] + + current_noise_latents = torch.randn( + latents.shape, device=device, dtype=dtype + ) + current_noise_latents_copy = ( + current_noise_latents.clone() + ) + + max_target_idx = len(step_list) - 1 + noise_latents: list[torch.Tensor] = [] + noise_latent_index = target_timestep_idx_int - 1 + + if max_target_idx > 0: + with torch.no_grad(): + for step_idx in range(max_target_idx): + current_timestep = step_list[step_idx] + current_timestep_tensor = ( + current_timestep + * torch.ones( + 1, + device=device, + dtype=torch.long, + ) + ) + + pred_clean = self.student.predict_x0( + current_noise_latents, + current_timestep_tensor, + batch, + conditional=True, + cfg_uncond=self._cfg_uncond, + attn_kind="vsa", + ) + + next_timestep = step_list[step_idx + 1] + next_timestep_tensor = ( + next_timestep + * torch.ones( + 1, + device=device, + dtype=torch.long, + ) + ) + noise = torch.randn( + latents.shape, + device=device, + dtype=pred_clean.dtype, + ) + current_noise_latents = ( + self.student.add_noise( + pred_clean, + noise, + next_timestep_tensor, + ) + ) + noise_latents.append( + current_noise_latents.clone() + ) + + if noise_latent_index >= 0: + if noise_latent_index >= len(noise_latents): + raise 
RuntimeError( + "noise_latent_index is out of bounds" + ) + noisy_input = noise_latents[noise_latent_index] + else: + noisy_input = current_noise_latents_copy + + if with_grad: + pred_x0 = self.student.predict_x0( + noisy_input, + target_timestep, + batch, + conditional=True, + cfg_uncond=self._cfg_uncond, + attn_kind="vsa", + ) + else: + with torch.no_grad(): + pred_x0 = self.student.predict_x0( + noisy_input, + target_timestep, + batch, + conditional=True, + cfg_uncond=self._cfg_uncond, + attn_kind="vsa", + ) + + batch.dmd_latent_vis_dict[ + "generator_timestep" + ] = target_timestep.float().detach() + return pred_x0 + + def _critic_flow_matching_loss( + self, batch: Any, + ) -> tuple[torch.Tensor, Any, dict[str, Any]]: + with torch.no_grad(): + generator_pred_x0 = self._student_rollout( + batch, with_grad=False + ) + + device = generator_pred_x0.device + fake_score_timestep = torch.randint( + 0, + int(self.student.num_train_timesteps), + [1], + device=device, + dtype=torch.long, + ) + fake_score_timestep = ( + self.student.shift_and_clamp_timestep( + fake_score_timestep + ) + ) + + noise = torch.randn( + generator_pred_x0.shape, + device=device, + dtype=generator_pred_x0.dtype, + ) + noisy_x0 = self.student.add_noise( + generator_pred_x0, noise, fake_score_timestep + ) + + pred_noise = self.critic.predict_noise( + noisy_x0, + fake_score_timestep, + batch, + conditional=True, + cfg_uncond=self._cfg_uncond, + attn_kind="dense", + ) + target = noise - generator_pred_x0 + flow_matching_loss = torch.mean( + (pred_noise - target)**2 + ) + + batch.fake_score_latent_vis_dict = { + "generator_pred_video": generator_pred_x0, + "fake_score_timestep": fake_score_timestep, + } + outputs = { + "fake_score_latent_vis_dict": ( + batch.fake_score_latent_vis_dict + ) + } + return ( + flow_matching_loss, + (batch.timesteps, batch.attn_metadata), + outputs, + ) + + def _dmd_loss( + self, + generator_pred_x0: torch.Tensor, + batch: Any, + ) -> torch.Tensor: + guidance_scale = 
get_optional_float( + self.method_config, + "real_score_guidance_scale", + where="method.real_score_guidance_scale", + ) + if guidance_scale is None: + guidance_scale = 1.0 + device = generator_pred_x0.device + + with torch.no_grad(): + timestep = torch.randint( + 0, + int(self.student.num_train_timesteps), + [1], + device=device, + dtype=torch.long, + ) + timestep = ( + self.student.shift_and_clamp_timestep( + timestep + ) + ) + + noise = torch.randn( + generator_pred_x0.shape, + device=device, + dtype=generator_pred_x0.dtype, + ) + noisy_latents = self.student.add_noise( + generator_pred_x0, noise, timestep + ) + + faker_x0 = self.critic.predict_x0( + noisy_latents, + timestep, + batch, + conditional=True, + cfg_uncond=self._cfg_uncond, + attn_kind="dense", + ) + real_cond_x0 = self.teacher.predict_x0( + noisy_latents, + timestep, + batch, + conditional=True, + cfg_uncond=self._cfg_uncond, + attn_kind="dense", + ) + real_uncond_x0 = self.teacher.predict_x0( + noisy_latents, + timestep, + batch, + conditional=False, + cfg_uncond=self._cfg_uncond, + attn_kind="dense", + ) + real_cfg_x0 = real_uncond_x0 + ( + real_cond_x0 - real_uncond_x0 + ) * guidance_scale + + denom = torch.abs( + generator_pred_x0 - real_cfg_x0 + ).mean() + grad = (faker_x0 - real_cfg_x0) / denom + grad = torch.nan_to_num(grad) + + loss = 0.5 * F.mse_loss( + generator_pred_x0.float(), + ( + generator_pred_x0.float() - grad.float() + ).detach(), + ) + return loss diff --git a/fastvideo/train/methods/distribution_matching/self_forcing.py b/fastvideo/train/methods/distribution_matching/self_forcing.py new file mode 100644 index 000000000..ae2547bd5 --- /dev/null +++ b/fastvideo/train/methods/distribution_matching/self_forcing.py @@ -0,0 +1,571 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Self-Forcing distillation method (algorithm layer).""" + +from __future__ import annotations + +from typing import Any, Literal, TYPE_CHECKING + +import torch +import torch.distributed as dist + +from 
fastvideo.train.models.base import ( + CausalModelBase, + ModelBase, +) +from fastvideo.train.methods.distribution_matching.dmd2 import ( + DMD2Method, ) +from fastvideo.train.utils.config import ( + get_optional_float, + get_optional_int, +) +from fastvideo.models.schedulers.scheduling_self_forcing_flow_match import ( + SelfForcingFlowMatchScheduler, ) +from fastvideo.models.utils import pred_noise_to_pred_video + +if TYPE_CHECKING: + from fastvideo.pipelines import TrainingBatch + + +def _require_bool(raw: Any, *, where: str) -> bool: + if isinstance(raw, bool): + return raw + raise ValueError(f"Expected bool at {where}, got {type(raw).__name__}") + + +def _require_str(raw: Any, *, where: str) -> str: + if not isinstance(raw, str) or not raw.strip(): + raise ValueError(f"Expected non-empty string at {where}") + return raw + + +class SelfForcingMethod(DMD2Method): + """Self-Forcing DMD2 (distribution matching) method. + + Requires a causal student implementing ``CausalModelBase``. + """ + + def __init__( + self, + *, + cfg: Any, + role_models: dict[str, ModelBase], + ) -> None: + super().__init__( + cfg=cfg, + role_models=role_models, + ) + + # Validate causal student. 
+ if not isinstance(self.student, CausalModelBase): + raise ValueError("SelfForcingMethod requires a causal student " + "implementing CausalModelBase.") + + if self._rollout_mode != "simulate": + raise ValueError("SelfForcingMethod only supports " + "method_config.rollout_mode='simulate'") + + mcfg = self.method_config + + chunk_size = get_optional_int( + mcfg, + "chunk_size", + where="method_config.chunk_size", + ) + if chunk_size is None: + chunk_size = 3 + if chunk_size <= 0: + raise ValueError("method_config.chunk_size must be a positive " + f"integer, got {chunk_size}") + self._chunk_size = int(chunk_size) + + sample_type_raw = mcfg.get("student_sample_type", "sde") + sample_type = _require_str( + sample_type_raw, + where="method_config.student_sample_type", + ) + sample_type = sample_type.strip().lower() + if sample_type not in {"sde", "ode"}: + raise ValueError("method_config.student_sample_type must be one " + f"of {{sde, ode}}, got {sample_type_raw!r}") + self._student_sample_type: Literal["sde", "ode"] = ( + sample_type # type: ignore[assignment] + ) + + same_step_raw = mcfg.get("same_step_across_blocks", False) + if same_step_raw is None: + same_step_raw = False + self._same_step_across_blocks = _require_bool( + same_step_raw, + where="method_config.same_step_across_blocks", + ) + + last_step_raw = mcfg.get("last_step_only", False) + if last_step_raw is None: + last_step_raw = False + self._last_step_only = _require_bool( + last_step_raw, + where="method_config.last_step_only", + ) + + context_noise = get_optional_float( + mcfg, + "context_noise", + where="method_config.context_noise", + ) + if context_noise is None: + context_noise = 0.0 + if context_noise < 0.0: + raise ValueError("method_config.context_noise must be >= 0, " + f"got {context_noise}") + self._context_noise = float(context_noise) + + enable_grad_raw = mcfg.get("enable_gradient_in_rollout", True) + if enable_grad_raw is None: + enable_grad_raw = True + self._enable_gradient_in_rollout = 
_require_bool( + enable_grad_raw, + where="method_config.enable_gradient_in_rollout", + ) + + start_grad_frame = get_optional_int( + mcfg, + "start_gradient_frame", + where="method_config.start_gradient_frame", + ) + if start_grad_frame is None: + start_grad_frame = 0 + if start_grad_frame < 0: + raise ValueError("method_config.start_gradient_frame must be " + f">= 0, got {start_grad_frame}") + self._start_gradient_frame = int(start_grad_frame) + + shift = float(getattr( + self.training_config.pipeline_config, + "flow_shift", + 0.0, + ) or 0.0) + self._sf_scheduler = SelfForcingFlowMatchScheduler( + num_inference_steps=1000, + num_train_timesteps=int(self.student.num_train_timesteps), + shift=shift, + sigma_min=0.0, + extra_one_step=True, + training=True, + ) + + self._sf_denoising_step_list: torch.Tensor | None = None + + def _get_denoising_step_list(self, device: torch.device) -> torch.Tensor: + if (self._sf_denoising_step_list is not None and self._sf_denoising_step_list.device == device): + return self._sf_denoising_step_list + + raw = self.method_config.get("dmd_denoising_steps", None) + if not isinstance(raw, list) or not raw: + raise ValueError("method_config.dmd_denoising_steps must be set " + "for self_forcing") + steps = torch.tensor( + [int(s) for s in raw], + dtype=torch.long, + device=device, + ) + + warp = self.method_config.get("warp_denoising_step", None) + if warp is None: + warp = False + if bool(warp): + timesteps = torch.cat(( + self._sf_scheduler.timesteps.to("cpu"), + torch.tensor([0], dtype=torch.float32), + )).to(device) + steps = timesteps[int(self.student.num_train_timesteps) - steps] + + self._sf_denoising_step_list = steps + return steps + + def _predict_x0_with_scheduler( + self, + model: ModelBase, + noisy_latents: torch.Tensor, + timestep: torch.Tensor, + batch: TrainingBatch, + *, + conditional: bool, + attn_kind: Literal["dense", "vsa"], + ) -> torch.Tensor: + pred_noise = model.predict_noise( + noisy_latents, + timestep, + batch, + 
conditional=conditional, + cfg_uncond=self._cfg_uncond, + attn_kind=attn_kind, + ) + pred_x0 = pred_noise_to_pred_video( + pred_noise=pred_noise.flatten(0, 1), + noise_input_latent=noisy_latents.flatten(0, 1), + timestep=timestep, + scheduler=self._sf_scheduler, + ).unflatten(0, pred_noise.shape[:2]) + return pred_x0 + + def _sf_add_noise( + self, + clean_latents: torch.Tensor, + noise: torch.Tensor, + timestep: torch.Tensor, + ) -> torch.Tensor: + b, t = clean_latents.shape[:2] + noisy = self._sf_scheduler.add_noise( + clean_latents.flatten(0, 1), + noise.flatten(0, 1), + timestep, + ).unflatten(0, (b, t)) + return noisy + + def _timestep_to_sigma(self, timestep: torch.Tensor) -> torch.Tensor: + sigmas = self._sf_scheduler.sigmas.to(device=timestep.device, dtype=torch.float32) + timesteps = self._sf_scheduler.timesteps.to(device=timestep.device, dtype=torch.float32) + t = timestep.to(device=timestep.device, dtype=torch.float32) + if t.ndim == 2: + t = t.flatten(0, 1) + elif t.ndim == 1 and t.numel() == 1: + t = t.expand(1) + elif t.ndim != 1: + raise ValueError("Invalid timestep shape: " + f"{tuple(timestep.shape)}") + idx = torch.argmin( + (timesteps.unsqueeze(0) - t.unsqueeze(1)).abs(), + dim=1, + ) + return sigmas[idx] + + def _sample_exit_indices( + self, + *, + num_blocks: int, + num_steps: int, + device: torch.device, + ) -> list[int]: + if num_blocks <= 0: + return [] + if num_steps <= 0: + raise ValueError("num_steps must be positive") + + shape = ((1, ) if self._same_step_across_blocks else (num_blocks, )) + + if not dist.is_initialized() or dist.get_rank() == 0: + if self._last_step_only: + indices = torch.full( + shape, + num_steps - 1, + dtype=torch.long, + device=device, + ) + else: + indices = torch.randint( + low=0, + high=num_steps, + size=shape, + device=device, + ) + else: + indices = torch.empty(shape, dtype=torch.long, device=device) + + if dist.is_initialized(): + dist.broadcast(indices, src=0) + + if self._same_step_across_blocks: + return 
[int(indices.item()) for _ in range(num_blocks)] + return [int(i) for i in indices.tolist()] + + def _student_rollout(self, batch: Any, *, with_grad: bool) -> torch.Tensor: + if not isinstance(self.student, CausalModelBase): + raise ValueError("SelfForcingMethod requires a causal student " + "implementing CausalModelBase.") + return self._student_rollout_streaming(batch, with_grad=with_grad) + + def _student_rollout_streaming(self, batch: Any, *, with_grad: bool) -> torch.Tensor: + assert isinstance(self.student, CausalModelBase) + latents = batch.latents + if latents is None: + raise RuntimeError("TrainingBatch.latents is required for " + "self-forcing rollout") + if latents.ndim != 5: + raise ValueError("TrainingBatch.latents must be [B, T, C, H, W]" + f", got shape={tuple(latents.shape)}") + + device = latents.device + dtype = latents.dtype + batch_size = int(latents.shape[0]) + num_frames = int(latents.shape[1]) + + denoising_steps = self._get_denoising_step_list(device) + num_steps = int(denoising_steps.numel()) + + noise_full = torch.randn_like(latents, device=device, dtype=dtype) + + chunk = int(self._chunk_size) + if chunk <= 0: + raise ValueError("chunk_size must be positive") + + remaining = num_frames % chunk + num_blocks = num_frames // chunk + if num_blocks == 0: + num_blocks = 1 + remaining = num_frames + + exit_indices = self._sample_exit_indices( + num_blocks=num_blocks, + num_steps=num_steps, + device=device, + ) + + denoised_blocks: list[torch.Tensor] = [] + + cache_tag = "pos" + self.student.clear_caches(cache_tag=cache_tag) + + for block_idx in range(num_blocks): + if block_idx == 0: + start = 0 + end = remaining + chunk if remaining else chunk + else: + start = remaining + block_idx * chunk + end = remaining + (block_idx + 1) * chunk + start = int(start) + end = int(min(end, num_frames)) + if start >= end: + break + + noisy_block = noise_full[:, start:end] + exit_idx = int(exit_indices[block_idx]) + + for step_idx, current_timestep in 
enumerate(denoising_steps): + exit_flag = step_idx == exit_idx + + timestep_block = (current_timestep * torch.ones( + (batch_size, end - start), + device=device, + dtype=torch.float32, + )) + + enable_grad = (bool(with_grad) and bool(self._enable_gradient_in_rollout) and torch.is_grad_enabled() + and start >= int(self._start_gradient_frame)) + + if not exit_flag: + with torch.no_grad(): + pred_noise = (self.student.predict_noise_streaming( + noisy_block, + timestep_block, + batch, + conditional=True, + cache_tag=cache_tag, + store_kv=False, + cur_start_frame=start, + cfg_uncond=self._cfg_uncond, + attn_kind="vsa", + )) + if pred_noise is None: + raise RuntimeError("predict_noise_streaming " + "returned None " + "(store_kv=False)") + pred_x0_chunk = pred_noise_to_pred_video( + pred_noise=pred_noise.flatten(0, 1), + noise_input_latent=(noisy_block.flatten(0, 1)), + timestep=timestep_block, + scheduler=self._sf_scheduler, + ).unflatten(0, pred_noise.shape[:2]) + + if step_idx + 1 >= num_steps: + break + next_timestep = denoising_steps[step_idx + 1] + if self._student_sample_type == "sde": + noisy_block = self._sf_add_noise( + pred_x0_chunk, + torch.randn_like(pred_x0_chunk), + next_timestep * torch.ones( + (batch_size, end - start), + device=device, + dtype=torch.float32, + ), + ) + else: + sigma_cur = self._timestep_to_sigma(timestep_block).view(batch_size, end - start, 1, 1, 1) + sigma_next = self._timestep_to_sigma(next_timestep * torch.ones( + (batch_size, end - start), + device=device, + dtype=torch.float32, + )).view(batch_size, end - start, 1, 1, 1) + eps = (noisy_block - (1 - sigma_cur) * pred_x0_chunk) / sigma_cur.clamp_min(1e-8) + noisy_block = ((1 - sigma_next) * pred_x0_chunk + sigma_next * eps) + continue + + with torch.set_grad_enabled(enable_grad): + pred_noise = (self.student.predict_noise_streaming( + noisy_block, + timestep_block, + batch, + conditional=True, + cache_tag=cache_tag, + store_kv=False, + cur_start_frame=start, + 
cfg_uncond=self._cfg_uncond, + attn_kind="vsa", + )) + if pred_noise is None: + raise RuntimeError("predict_noise_streaming returned " + "None (store_kv=False)") + pred_x0_chunk = pred_noise_to_pred_video( + pred_noise=pred_noise.flatten(0, 1), + noise_input_latent=(noisy_block.flatten(0, 1)), + timestep=timestep_block, + scheduler=self._sf_scheduler, + ).unflatten(0, pred_noise.shape[:2]) + break + + denoised_blocks.append(pred_x0_chunk) + + with torch.no_grad(): + if self._context_noise > 0.0: + context_timestep = torch.ones( + (batch_size, end - start), + device=device, + dtype=torch.float32, + ) * float(self._context_noise) + context_latents = self._sf_add_noise( + pred_x0_chunk.detach(), + torch.randn_like(pred_x0_chunk), + context_timestep, + ) + else: + context_timestep = torch.zeros( + (batch_size, end - start), + device=device, + dtype=torch.float32, + ) + context_latents = pred_x0_chunk.detach() + + _ = self.student.predict_noise_streaming( + context_latents, + context_timestep, + batch, + conditional=True, + cache_tag=cache_tag, + store_kv=True, + cur_start_frame=start, + cfg_uncond=self._cfg_uncond, + attn_kind="vsa", + ) + + if not denoised_blocks: + raise RuntimeError("Self-forcing rollout produced no blocks") + + self.student.clear_caches(cache_tag=cache_tag) + return torch.cat(denoised_blocks, dim=1) + + def _critic_flow_matching_loss(self, batch: Any) -> tuple[torch.Tensor, Any, dict[str, Any]]: + with torch.no_grad(): + generator_pred_x0 = self._student_rollout(batch, with_grad=False) + + device = generator_pred_x0.device + fake_score_timestep = torch.randint( + 0, + int(self.student.num_train_timesteps), + [1], + device=device, + dtype=torch.long, + ) + fake_score_timestep = (self.student.shift_and_clamp_timestep(fake_score_timestep)) + + noise = torch.randn( + generator_pred_x0.shape, + device=device, + dtype=generator_pred_x0.dtype, + ) + noisy_x0 = self._sf_add_noise(generator_pred_x0, noise, fake_score_timestep) + + pred_noise = 
self.critic.predict_noise( + noisy_x0, + fake_score_timestep, + batch, + conditional=True, + cfg_uncond=self._cfg_uncond, + attn_kind="dense", + ) + target = noise - generator_pred_x0 + flow_matching_loss = torch.mean((pred_noise - target)**2) + + batch.fake_score_latent_vis_dict = { + "generator_pred_video": generator_pred_x0, + "fake_score_timestep": fake_score_timestep, + } + outputs = {"fake_score_latent_vis_dict": (batch.fake_score_latent_vis_dict)} + return ( + flow_matching_loss, + (batch.timesteps, batch.attn_metadata), + outputs, + ) + + def _dmd_loss( + self, + generator_pred_x0: torch.Tensor, + batch: Any, + ) -> torch.Tensor: + guidance_scale = get_optional_float( + self.method_config, + "real_score_guidance_scale", + where="method.real_score_guidance_scale", + ) + if guidance_scale is None: + guidance_scale = 1.0 + device = generator_pred_x0.device + + with torch.no_grad(): + timestep = torch.randint( + 0, + int(self.student.num_train_timesteps), + [1], + device=device, + dtype=torch.long, + ) + timestep = self.student.shift_and_clamp_timestep(timestep) + + noise = torch.randn( + generator_pred_x0.shape, + device=device, + dtype=generator_pred_x0.dtype, + ) + noisy_latents = self._sf_add_noise(generator_pred_x0, noise, timestep) + + faker_x0 = self._predict_x0_with_scheduler( + self.critic, + noisy_latents, + timestep, + batch, + conditional=True, + attn_kind="dense", + ) + real_cond_x0 = self._predict_x0_with_scheduler( + self.teacher, + noisy_latents, + timestep, + batch, + conditional=True, + attn_kind="dense", + ) + real_uncond_x0 = self._predict_x0_with_scheduler( + self.teacher, + noisy_latents, + timestep, + batch, + conditional=False, + attn_kind="dense", + ) + real_cfg_x0 = real_uncond_x0 + (real_cond_x0 - real_uncond_x0) * guidance_scale + + denom = torch.abs(generator_pred_x0 - real_cfg_x0).mean() + grad = (faker_x0 - real_cfg_x0) / denom + grad = torch.nan_to_num(grad) + + loss = 0.5 * torch.mean((generator_pred_x0.float() - 
(generator_pred_x0.float() - grad.float()).detach())**2) + return loss diff --git a/fastvideo/train/methods/fine_tuning/__init__.py b/fastvideo/train/methods/fine_tuning/__init__.py new file mode 100644 index 000000000..6f862df4e --- /dev/null +++ b/fastvideo/train/methods/fine_tuning/__init__.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from fastvideo.train.methods.fine_tuning.dfsft import DiffusionForcingSFTMethod + from fastvideo.train.methods.fine_tuning.finetune import FineTuneMethod + +__all__ = [ + "DiffusionForcingSFTMethod", + "FineTuneMethod", +] + + +def __getattr__(name: str) -> object: + # Lazy import to avoid circular imports during registry bring-up. + if name == "DiffusionForcingSFTMethod": + from fastvideo.train.methods.fine_tuning.dfsft import ( + DiffusionForcingSFTMethod, ) + + return DiffusionForcingSFTMethod + if name == "FineTuneMethod": + from fastvideo.train.methods.fine_tuning.finetune import FineTuneMethod + + return FineTuneMethod + raise AttributeError(name) diff --git a/fastvideo/train/methods/fine_tuning/dfsft.py b/fastvideo/train/methods/fine_tuning/dfsft.py new file mode 100644 index 000000000..4f91110a2 --- /dev/null +++ b/fastvideo/train/methods/fine_tuning/dfsft.py @@ -0,0 +1,408 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Diffusion-forcing SFT method (DFSFT; algorithm layer).""" + +from __future__ import annotations + +from typing import Any, Literal + +import torch +import torch.nn.functional as F + +from fastvideo.train.methods.base import TrainingMethod, LogScalar +from fastvideo.train.models.base import ModelBase +from fastvideo.train.utils.optimizer import ( + build_optimizer_and_scheduler, +) + + +class DiffusionForcingSFTMethod(TrainingMethod): + """Diffusion-forcing SFT (DFSFT): train only ``student`` + with inhomogeneous timesteps. 
+ """ + + def __init__( + self, + *, + cfg: Any, + role_models: dict[str, ModelBase], + ) -> None: + super().__init__(cfg=cfg, role_models=role_models) + + if "student" not in role_models: + raise ValueError("DFSFT requires role 'student'") + if not self.student._trainable: + raise ValueError( + "DFSFT requires student to be trainable" + ) + self._attn_kind: Literal["dense", "vsa"] = ( + self._parse_attn_kind( + self.method_config.get("attn_kind", None) + ) + ) + + self._chunk_size = self._parse_chunk_size( + self.method_config.get("chunk_size", None) + ) + self._timestep_index_range = ( + self._parse_timestep_index_range() + ) + + # Initialize preprocessors on student. + self.student.init_preprocessors(self.training_config) + + self._init_optimizers_and_schedulers() + + @property + def _optimizer_dict(self) -> dict[str, Any]: + return {"student": self._student_optimizer} + + @property + def _lr_scheduler_dict(self) -> dict[str, Any]: + return {"student": self._student_lr_scheduler} + + # TrainingMethod override: single_train_step + def single_train_step( + self, + batch: dict[str, Any], + iteration: int, + *, + current_vsa_sparsity: float = 0.0, + ) -> tuple[ + dict[str, torch.Tensor], + dict[str, Any], + dict[str, LogScalar], + ]: + del iteration + training_batch = self.student.prepare_batch( + batch, + current_vsa_sparsity=current_vsa_sparsity, + latents_source="data", + ) + + if training_batch.latents is None: + raise RuntimeError( + "prepare_batch() must set TrainingBatch.latents" + ) + + clean_latents = training_batch.latents + if not torch.is_tensor(clean_latents): + raise TypeError( + "TrainingBatch.latents must be a torch.Tensor" + ) + if clean_latents.ndim != 5: + raise ValueError( + "TrainingBatch.latents must be " + "[B, T, C, H, W], got " + f"shape={tuple(clean_latents.shape)}" + ) + + batch_size, num_latents = ( + int(clean_latents.shape[0]), + int(clean_latents.shape[1]), + ) + + expected_chunk = getattr( + self.student.transformer, + 
"num_frame_per_block", + None, + ) + if ( + expected_chunk is not None + and int(expected_chunk) != int(self._chunk_size) + ): + raise ValueError( + "DFSFT chunk_size must match " + "transformer.num_frame_per_block for " + f"causal training (got {self._chunk_size}, " + f"expected {expected_chunk})." + ) + + timestep_indices = self._sample_t_inhom_indices( + batch_size=batch_size, + num_latents=num_latents, + device=clean_latents.device, + ) + sp_size = int( + self.training_config.distributed.sp_size + ) + sp_group = getattr(self.student, "sp_group", None) + if ( + sp_size > 1 + and sp_group is not None + and hasattr(sp_group, "broadcast") + ): + sp_group.broadcast(timestep_indices, src=0) + + scheduler = self.student.noise_scheduler + if scheduler is None: + raise ValueError( + "DFSFT requires student.noise_scheduler" + ) + + schedule_timesteps = scheduler.timesteps.to( + device=clean_latents.device, dtype=torch.float32 + ) + schedule_sigmas = scheduler.sigmas.to( + device=clean_latents.device, + dtype=clean_latents.dtype, + ) + t_inhom = schedule_timesteps[timestep_indices] + + # Override the homogeneous timesteps from prepare_batch + # so that set_forward_context (in predict_noise and + # backward) receives the correct per-chunk timesteps. 
+ training_batch.timesteps = t_inhom + + noise = getattr(training_batch, "noise", None) + if noise is None: + noise = torch.randn_like(clean_latents) + else: + if not torch.is_tensor(noise): + raise TypeError( + "TrainingBatch.noise must be a " + "torch.Tensor when set" + ) + noise = noise.permute(0, 2, 1, 3, 4).to( + dtype=clean_latents.dtype + ) + + noisy_latents = self.student.add_noise( + clean_latents, + noise, + t_inhom.flatten(), + ) + + pred = self.student.predict_noise( + noisy_latents, + t_inhom, + training_batch, + conditional=True, + attn_kind=self._attn_kind, + ) + + if bool( + self.training_config.model.precondition_outputs + ): + sigmas = schedule_sigmas[timestep_indices] + sigmas = sigmas.unsqueeze(-1).unsqueeze( + -1 + ).unsqueeze(-1) + pred_x0 = noisy_latents - pred * sigmas + loss = F.mse_loss( + pred_x0.float(), clean_latents.float() + ) + else: + target = noise - clean_latents + loss = F.mse_loss( + pred.float(), target.float() + ) + + if self._attn_kind == "vsa": + attn_metadata = training_batch.attn_metadata_vsa + else: + attn_metadata = training_batch.attn_metadata + + loss_map = {"total_loss": loss, "dfsft_loss": loss} + outputs: dict[str, Any] = { + "_fv_backward": ( + training_batch.timesteps, + attn_metadata, + ) + } + metrics: dict[str, LogScalar] = {} + return loss_map, outputs, metrics + + # TrainingMethod override: backward + def backward( + self, + loss_map: dict[str, torch.Tensor], + outputs: dict[str, Any], + *, + grad_accum_rounds: int = 1, + ) -> None: + grad_accum_rounds = max(1, int(grad_accum_rounds)) + ctx = outputs.get("_fv_backward") + if ctx is None: + super().backward( + loss_map, + outputs, + grad_accum_rounds=grad_accum_rounds, + ) + return + self.student.backward( + loss_map["total_loss"], + ctx, + grad_accum_rounds=grad_accum_rounds, + ) + + # TrainingMethod override: get_optimizers + def get_optimizers( + self, iteration: int, + ) -> list[torch.optim.Optimizer]: + del iteration + return [self._student_optimizer] + + 
# TrainingMethod override: get_lr_schedulers + def get_lr_schedulers( + self, iteration: int, + ) -> list[Any]: + del iteration + return [self._student_lr_scheduler] + + def _parse_chunk_size(self, raw: Any) -> int: + if raw in (None, ""): + return 3 + if isinstance(raw, bool): + raise ValueError( + "method_config.chunk_size must be an int, " + "got bool" + ) + if isinstance(raw, float) and not raw.is_integer(): + raise ValueError( + "method_config.chunk_size must be an int, " + "got float" + ) + if isinstance(raw, str) and not raw.strip(): + raise ValueError( + "method_config.chunk_size must be an int, " + "got empty string" + ) + try: + value = int(raw) + except (TypeError, ValueError) as e: + raise ValueError( + "method_config.chunk_size must be an int, " + f"got {type(raw).__name__}" + ) from e + if value <= 0: + raise ValueError( + "method_config.chunk_size must be > 0" + ) + return value + + def _parse_ratio( + self, + raw: Any, + *, + where: str, + default: float, + ) -> float: + if raw in (None, ""): + return float(default) + if isinstance(raw, bool): + raise ValueError( + f"{where} must be a number/string, got bool" + ) + if isinstance(raw, int | float): + return float(raw) + if isinstance(raw, str) and raw.strip(): + return float(raw) + raise ValueError( + f"{where} must be a number/string, " + f"got {type(raw).__name__}" + ) + + def _parse_timestep_index_range( + self, + ) -> tuple[int, int]: + scheduler = self.student.noise_scheduler + if scheduler is None: + raise ValueError( + "DFSFT requires student.noise_scheduler" + ) + num_steps = int( + getattr( + scheduler, "config", scheduler + ).num_train_timesteps + ) + + min_ratio = self._parse_ratio( + self.method_config.get( + "min_timestep_ratio", None + ), + where="method.min_timestep_ratio", + default=0.0, + ) + max_ratio = self._parse_ratio( + self.method_config.get( + "max_timestep_ratio", None + ), + where="method.max_timestep_ratio", + default=1.0, + ) + + if not ( + 0.0 <= min_ratio <= 1.0 + and 
0.0 <= max_ratio <= 1.0 + ): + raise ValueError( + "DFSFT timestep ratios must be in [0,1], " + f"got min={min_ratio}, max={max_ratio}" + ) + if max_ratio < min_ratio: + raise ValueError( + "method_config.max_timestep_ratio must be " + ">= min_timestep_ratio" + ) + + min_index = int(min_ratio * num_steps) + max_index = int(max_ratio * num_steps) + min_index = max(0, min(min_index, num_steps - 1)) + max_index = max(0, min(max_index, num_steps - 1)) + + if max_index <= min_index: + max_index = min(num_steps - 1, min_index + 1) + + return min_index, max_index + 1 + + def _init_optimizers_and_schedulers(self) -> None: + tc = self.training_config + student_lr = float(tc.optimizer.learning_rate) + if student_lr <= 0.0: + raise ValueError( + "training.learning_rate must be > 0 " + "for dfsft" + ) + + student_betas = tc.optimizer.betas + student_sched = str(tc.optimizer.lr_scheduler) + student_params = [ + p + for p in self.student.transformer.parameters() + if p.requires_grad + ] + ( + self._student_optimizer, + self._student_lr_scheduler, + ) = build_optimizer_and_scheduler( + params=student_params, + optimizer_config=tc.optimizer, + loop_config=tc.loop, + learning_rate=student_lr, + betas=student_betas, + scheduler_name=student_sched, + ) + + def _sample_t_inhom_indices( + self, + *, + batch_size: int, + num_latents: int, + device: torch.device, + ) -> torch.Tensor: + chunk_size = self._chunk_size + num_chunks = ( + (num_latents + chunk_size - 1) // chunk_size + ) + low, high = self._timestep_index_range + chunk_indices = torch.randint( + low=low, + high=high, + size=(batch_size, num_chunks), + device=device, + dtype=torch.long, + ) + expanded = chunk_indices.repeat_interleave( + chunk_size, dim=1 + ) + return expanded[:, :num_latents] diff --git a/fastvideo/train/methods/fine_tuning/finetune.py b/fastvideo/train/methods/fine_tuning/finetune.py new file mode 100644 index 000000000..cf7dc3139 --- /dev/null +++ b/fastvideo/train/methods/fine_tuning/finetune.py @@ -0,0 
+1,217 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Supervised finetuning method (algorithm layer).""" + +from __future__ import annotations + +from typing import Any, Literal + +import torch +import torch.nn.functional as F + +from fastvideo.train.methods.base import TrainingMethod, LogScalar +from fastvideo.train.models.base import ModelBase +from fastvideo.train.utils.optimizer import ( + build_optimizer_and_scheduler, +) + + +class FineTuneMethod(TrainingMethod): + """Supervised finetuning: only ``student`` participates.""" + + def __init__( + self, + *, + cfg: Any, + role_models: dict[str, ModelBase], + ) -> None: + super().__init__(cfg=cfg, role_models=role_models) + + if "student" not in role_models: + raise ValueError( + "FineTuneMethod requires role 'student'" + ) + if not self.student._trainable: + raise ValueError( + "FineTuneMethod requires student to be " + "trainable" + ) + self._attn_kind: Literal["dense", "vsa"] = ( + self._parse_attn_kind( + self.method_config.get("attn_kind", None) + ) + ) + + # Initialize preprocessors on student. 
+ self.student.init_preprocessors(self.training_config) + + self._init_optimizers_and_schedulers() + + @property + def _optimizer_dict(self) -> dict[str, Any]: + return {"student": self._student_optimizer} + + @property + def _lr_scheduler_dict(self) -> dict[str, Any]: + return {"student": self._student_lr_scheduler} + + # TrainingMethod override: single_train_step + def single_train_step( + self, + batch: dict[str, Any], + iteration: int, + *, + current_vsa_sparsity: float = 0.0, + ) -> tuple[ + dict[str, torch.Tensor], + dict[str, Any], + dict[str, LogScalar], + ]: + del iteration + training_batch = self.student.prepare_batch( + batch, + current_vsa_sparsity=current_vsa_sparsity, + latents_source="data", + ) + + if training_batch.latents is None: + raise RuntimeError( + "prepare_batch() must set " + "TrainingBatch.latents" + ) + if training_batch.noisy_model_input is None: + raise RuntimeError( + "prepare_batch() must set " + "TrainingBatch.noisy_model_input" + ) + if training_batch.noise is None: + raise RuntimeError( + "prepare_batch() must set " + "TrainingBatch.noise" + ) + if training_batch.sigmas is None: + raise RuntimeError( + "prepare_batch() must set " + "TrainingBatch.sigmas" + ) + if training_batch.timesteps is None: + raise RuntimeError( + "prepare_batch() must set " + "TrainingBatch.timesteps" + ) + + clean_latents = training_batch.latents + noisy_latents = ( + training_batch.noisy_model_input.permute( + 0, 2, 1, 3, 4 + ) + ) + noise = training_batch.noise.permute( + 0, 2, 1, 3, 4 + ) + sigmas = training_batch.sigmas + timesteps = training_batch.timesteps + + pred = self.student.predict_noise( + noisy_latents, + timesteps, + training_batch, + conditional=True, + attn_kind=self._attn_kind, + ) + + if bool( + self.training_config.model.precondition_outputs + ): + pred_x0 = noisy_latents - pred * sigmas + loss = F.mse_loss( + pred_x0.float(), clean_latents.float() + ) + else: + target = noise - clean_latents + loss = F.mse_loss( + pred.float(), 
target.float() + ) + + if self._attn_kind == "vsa": + attn_metadata = training_batch.attn_metadata_vsa + else: + attn_metadata = training_batch.attn_metadata + + loss_map = { + "total_loss": loss, + "finetune_loss": loss, + } + outputs: dict[str, Any] = { + "_fv_backward": ( + training_batch.timesteps, + attn_metadata, + ) + } + metrics: dict[str, LogScalar] = {} + return loss_map, outputs, metrics + + # TrainingMethod override: backward + def backward( + self, + loss_map: dict[str, torch.Tensor], + outputs: dict[str, Any], + *, + grad_accum_rounds: int = 1, + ) -> None: + grad_accum_rounds = max(1, int(grad_accum_rounds)) + ctx = outputs.get("_fv_backward") + if ctx is None: + super().backward( + loss_map, + outputs, + grad_accum_rounds=grad_accum_rounds, + ) + return + self.student.backward( + loss_map["total_loss"], + ctx, + grad_accum_rounds=grad_accum_rounds, + ) + + # TrainingMethod override: get_optimizers + def get_optimizers( + self, iteration: int, + ) -> list[torch.optim.Optimizer]: + del iteration + return [self._student_optimizer] + + # TrainingMethod override: get_lr_schedulers + def get_lr_schedulers( + self, iteration: int, + ) -> list[Any]: + del iteration + return [self._student_lr_scheduler] + + def _init_optimizers_and_schedulers(self) -> None: + tc = self.training_config + + student_lr = float(tc.optimizer.learning_rate) + if student_lr <= 0.0: + raise ValueError( + "training.learning_rate must be > 0 " + "for finetune" + ) + + student_betas = tc.optimizer.betas + student_sched = str(tc.optimizer.lr_scheduler) + student_params = [ + p + for p in self.student.transformer.parameters() + if p.requires_grad + ] + ( + self._student_optimizer, + self._student_lr_scheduler, + ) = build_optimizer_and_scheduler( + params=student_params, + optimizer_config=tc.optimizer, + loop_config=tc.loop, + learning_rate=student_lr, + betas=student_betas, + scheduler_name=student_sched, + ) diff --git a/fastvideo/train/methods/knowledge_distillation/__init__.py 
b/fastvideo/train/methods/knowledge_distillation/__init__.py new file mode 100644 index 000000000..324710b84 --- /dev/null +++ b/fastvideo/train/methods/knowledge_distillation/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 + +__all__: list[str] = [] diff --git a/fastvideo/train/models/__init__.py b/fastvideo/train/models/__init__.py new file mode 100644 index 000000000..56b47b1af --- /dev/null +++ b/fastvideo/train/models/__init__.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Model build plugins for Phase 2/2.9 distillation. + +These are "model plugins" selected by ``recipe.family`` / ``roles..family``. +""" diff --git a/fastvideo/train/models/base.py b/fastvideo/train/models/base.py new file mode 100644 index 000000000..d74406278 --- /dev/null +++ b/fastvideo/train/models/base.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Literal, TYPE_CHECKING + +import torch + +if TYPE_CHECKING: + from fastvideo.train.utils.training_config import ( + TrainingConfig, ) + from fastvideo.pipelines import TrainingBatch + + +class ModelBase(ABC): + """Per-role model instance. + + Every role (student, teacher, critic, …) gets its own ``ModelBase`` + instance. Each instance owns its own ``transformer`` and + ``noise_scheduler``. Heavyweight resources (VAE, dataloader, RNG + seeds) are loaded lazily via :meth:`init_preprocessors`, which the + method calls **only on the student**. + """ + + transformer: torch.nn.Module + noise_scheduler: Any + _trainable: bool + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def init_preprocessors(self, training_config: TrainingConfig) -> None: + """Load VAE, build dataloader, seed RNGs. + + Called only on the student by the method's ``__init__``. 
+ Default is a no-op so teacher/critic instances skip this. + """ + + def on_train_start(self) -> None: + """Called once before the training loop begins.""" + + def get_rng_generators(self) -> dict[str, torch.Generator]: + """Return RNG generators for checkpoint resume.""" + return {} + + # ------------------------------------------------------------------ + # Timestep helpers + # ------------------------------------------------------------------ + + @property + def num_train_timesteps(self) -> int: + """Return the scheduler's training timestep horizon.""" + return int(self.noise_scheduler.num_train_timesteps) + + def shift_and_clamp_timestep(self, timestep: torch.Tensor) -> torch.Tensor: + """Apply model/pipeline timestep shifting and clamp.""" + return timestep + + # ------------------------------------------------------------------ + # Runtime primitives + # ------------------------------------------------------------------ + + @abstractmethod + def prepare_batch( + self, + raw_batch: dict[str, Any], + *, + current_vsa_sparsity: float = 0.0, + latents_source: Literal["data", "zeros"] = "data", + ) -> TrainingBatch: + """Convert a dataloader batch into forward primitives.""" + + @abstractmethod + def add_noise( + self, + clean_latents: torch.Tensor, + noise: torch.Tensor, + timestep: torch.Tensor, + ) -> torch.Tensor: + """Apply forward-process noise at *timestep*.""" + + @abstractmethod + def predict_noise( + self, + noisy_latents: torch.Tensor, + timestep: torch.Tensor, + batch: TrainingBatch, + *, + conditional: bool, + cfg_uncond: dict[str, Any] | None = None, + attn_kind: Literal["dense", "vsa"] = "dense", + ) -> torch.Tensor: + """Predict noise/flow for the given noisy latents.""" + + @abstractmethod + def predict_x0( + self, + noisy_latents: torch.Tensor, + timestep: torch.Tensor, + batch: TrainingBatch, + *, + conditional: bool, + cfg_uncond: dict[str, Any] | None = None, + attn_kind: Literal["dense", "vsa"] = "dense", + ) -> torch.Tensor: + """Predict x0 
for the given noisy latents.""" + + @abstractmethod + def backward( + self, + loss: torch.Tensor, + ctx: Any, + *, + grad_accum_rounds: int, + ) -> None: + """Backward that may restore forward-context.""" + + +class CausalModelBase(ModelBase): + """Extension for causal / streaming model plugins. + + Cache state is internal to the model instance and keyed by + *cache_tag* (no role handle needed). + """ + + @abstractmethod + def clear_caches(self, *, cache_tag: str = "pos") -> None: + """Clear internal caches before starting a new rollout.""" + + @abstractmethod + def predict_noise_streaming( + self, + noisy_latents: torch.Tensor, + timestep: torch.Tensor, + batch: TrainingBatch, + *, + conditional: bool, + cache_tag: str = "pos", + store_kv: bool = False, + cur_start_frame: int = 0, + cfg_uncond: dict[str, Any] | None = None, + attn_kind: Literal["dense", "vsa"] = "dense", + ) -> torch.Tensor | None: + """Streaming predict-noise that may update internal caches.""" + + @abstractmethod + def predict_x0_streaming( + self, + noisy_latents: torch.Tensor, + timestep: torch.Tensor, + batch: TrainingBatch, + *, + conditional: bool, + cache_tag: str = "pos", + store_kv: bool = False, + cur_start_frame: int = 0, + cfg_uncond: dict[str, Any] | None = None, + attn_kind: Literal["dense", "vsa"] = "dense", + ) -> torch.Tensor | None: + """Streaming predict-x0 that may update internal caches.""" diff --git a/fastvideo/train/models/wan/__init__.py b/fastvideo/train/models/wan/__init__.py new file mode 100644 index 000000000..9a8113ac1 --- /dev/null +++ b/fastvideo/train/models/wan/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Wan model plugin package.""" + +from fastvideo.train.models.wan.wan import ( + WanModel as WanModel, ) +from fastvideo.train.models.wan.wan_causal import ( + WanCausalModel as WanCausalModel, ) diff --git a/fastvideo/train/models/wan/wan.py b/fastvideo/train/models/wan/wan.py new file mode 100644 index 000000000..4ff0b1eeb --- 
/dev/null +++ b/fastvideo/train/models/wan/wan.py @@ -0,0 +1,735 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Wan model plugin (per-role instance).""" + +from __future__ import annotations + +import copy +import gc +from typing import Any, Literal, TYPE_CHECKING + +import torch + +import fastvideo.envs as envs +from fastvideo.configs.sample import SamplingParam +from fastvideo.distributed import ( + get_local_torch_device, + get_sp_group, + get_world_group, +) +from fastvideo.forward_context import set_forward_context +from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, ) +from fastvideo.models.utils import pred_noise_to_pred_video +from fastvideo.pipelines import TrainingBatch +from fastvideo.pipelines.basic.wan.wan_pipeline import ( + WanPipeline, ) +from fastvideo.pipelines.pipeline_batch_info import ( + ForwardBatch, ) +from fastvideo.training.activation_checkpoint import ( + apply_activation_checkpointing, ) +from fastvideo.training.training_utils import ( + compute_density_for_timestep_sampling, + get_sigmas, + normalize_dit_input, + shift_timestep, +) +from fastvideo.utils import ( + is_vmoba_available, + is_vsa_available, + set_random_seed, +) + +from fastvideo.train.models.base import ModelBase +from fastvideo.train.utils.module_state import ( + apply_trainable, ) +from fastvideo.train.utils.moduleloader import ( + load_module_from_path, ) + +if TYPE_CHECKING: + from fastvideo.train.utils.training_config import ( + TrainingConfig, ) + +try: + from fastvideo.attention.backends.video_sparse_attn import ( + VideoSparseAttentionMetadataBuilder, ) + from fastvideo.attention.backends.vmoba import ( + VideoMobaAttentionMetadataBuilder, ) +except Exception: + VideoSparseAttentionMetadataBuilder = None # type: ignore[assignment] + VideoMobaAttentionMetadataBuilder = None # type: ignore[assignment] + + +class WanModel(ModelBase): + """Wan per-role model: owns transformer + noise_scheduler.""" + + 
_transformer_cls_name: str = "WanTransformer3DModel" + + def __init__( + self, + *, + init_from: str, + training_config: TrainingConfig, + trainable: bool = True, + disable_custom_init_weights: bool = False, + flow_shift: float = 3.0, + enable_gradient_checkpointing_type: str + | None = None, + ) -> None: + self._init_from = str(init_from) + self._trainable = bool(trainable) + + self.transformer = self._load_transformer( + init_from=self._init_from, + trainable=self._trainable, + disable_custom_init_weights=(disable_custom_init_weights), + enable_gradient_checkpointing_type=(enable_gradient_checkpointing_type), + training_config=training_config, + ) + + self.noise_scheduler = (FlowMatchEulerDiscreteScheduler(shift=float(flow_shift))) + + # Filled by init_preprocessors (student only). + self.vae: Any = None + self.training_config: TrainingConfig = training_config + self.dataloader: Any = None + self.validator: Any = None + self.start_step: int = 0 + + self.world_group: Any = None + self.sp_group: Any = None + self.device: Any = get_local_torch_device() + + self.noise_random_generator: (torch.Generator | None) = None + self.noise_gen_cuda: torch.Generator | None = None + + self.negative_prompt_embeds: (torch.Tensor | None) = None + self.negative_prompt_attention_mask: (torch.Tensor | None) = None + + # Timestep mechanics. 
+ self.timestep_shift: float = float(flow_shift) + self.num_train_timestep: int = int(self.noise_scheduler.num_train_timesteps) + self.min_timestep: int = 0 + self.max_timestep: int = self.num_train_timestep + + def _load_transformer( + self, + *, + init_from: str, + trainable: bool, + disable_custom_init_weights: bool, + enable_gradient_checkpointing_type: str | None, + training_config: TrainingConfig, + ) -> torch.nn.Module: + transformer = load_module_from_path( + model_path=init_from, + module_type="transformer", + training_config=training_config, + disable_custom_init_weights=(disable_custom_init_weights), + override_transformer_cls_name=(self._transformer_cls_name), + ) + transformer = apply_trainable(transformer, trainable=trainable) + # Fall back to training_config.model if not set on the + # model YAML section directly. + ckpt_type = ( + enable_gradient_checkpointing_type + or getattr( + getattr(training_config, "model", None), + "enable_gradient_checkpointing_type", + None, + ) + ) + if trainable and ckpt_type: + transformer = apply_activation_checkpointing( + transformer, + checkpointing_type=ckpt_type, + ) + return transformer + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def init_preprocessors(self, training_config: TrainingConfig) -> None: + self.vae = load_module_from_path( + model_path=str(training_config.model_path), + module_type="vae", + training_config=training_config, + ) + + self.world_group = get_world_group() + self.sp_group = get_sp_group() + + self._init_timestep_mechanics() + + from fastvideo.dataset.dataloader.schema import ( + pyarrow_schema_t2v, ) + from fastvideo.train.utils.dataloader import ( + build_parquet_t2v_train_dataloader, ) + + text_len = ( + training_config.pipeline_config.text_encoder_configs[ # type: ignore[union-attr] + 0].arch_config.text_len) + self.dataloader = build_parquet_t2v_train_dataloader( + 
training_config.data, + text_len=int(text_len), + parquet_schema=pyarrow_schema_t2v, + ) + self.start_step = 0 + + @property + def num_train_timesteps(self) -> int: + return int(self.num_train_timestep) + + def shift_and_clamp_timestep(self, timestep: torch.Tensor) -> torch.Tensor: + timestep = shift_timestep( + timestep, + self.timestep_shift, + self.num_train_timestep, + ) + return timestep.clamp(self.min_timestep, self.max_timestep) + + def on_train_start(self) -> None: + assert self.training_config is not None + seed = self.training_config.data.seed + if seed is None: + raise ValueError("training.data.seed must be set " + "for training") + + global_rank = int(getattr(self.world_group, "rank", 0)) + sp_world_size = int(self.training_config.distributed.sp_size or 1) + if sp_world_size > 1: + sp_group_seed = int(seed) + (global_rank // sp_world_size) + set_random_seed(sp_group_seed) + else: + set_random_seed(int(seed) + global_rank) + + self.noise_random_generator = torch.Generator(device="cpu").manual_seed(int(seed)) + self.noise_gen_cuda = torch.Generator(device=self.device).manual_seed(int(seed)) + + self.ensure_negative_conditioning() + + def get_rng_generators(self, ) -> dict[str, torch.Generator]: + generators: dict[str, torch.Generator] = {} + if self.noise_random_generator is not None: + generators["noise_cpu"] = (self.noise_random_generator) + if self.noise_gen_cuda is not None: + generators["noise_cuda"] = self.noise_gen_cuda + return generators + + # ------------------------------------------------------------------ + # Runtime primitives + # ------------------------------------------------------------------ + + def prepare_batch( + self, + raw_batch: dict[str, Any], + *, + current_vsa_sparsity: float = 0.0, + latents_source: Literal["data", "zeros"] = "data", + ) -> TrainingBatch: + self.ensure_negative_conditioning() + assert self.training_config is not None + tc = self.training_config + + dtype = self._get_training_dtype() + device = self.device + + 
training_batch = TrainingBatch(current_vsa_sparsity=current_vsa_sparsity) + encoder_hidden_states = raw_batch["text_embedding"] + encoder_attention_mask = raw_batch["text_attention_mask"] + infos = raw_batch.get("info_list") + + if latents_source == "zeros": + batch_size = encoder_hidden_states.shape[0] + vae_config = ( + tc.pipeline_config.vae_config.arch_config # type: ignore[union-attr] + ) + num_channels = vae_config.z_dim + spatial_compression_ratio = (vae_config.spatial_compression_ratio) + latent_height = (tc.data.num_height // spatial_compression_ratio) + latent_width = (tc.data.num_width // spatial_compression_ratio) + latents = torch.zeros( + batch_size, + num_channels, + tc.data.num_latent_t, + latent_height, + latent_width, + device=device, + dtype=dtype, + ) + elif latents_source == "data": + if "vae_latent" not in raw_batch: + raise ValueError("vae_latent not found in batch " + "and latents_source='data'") + latents = raw_batch["vae_latent"] + latents = latents[:, :, :tc.data.num_latent_t] + latents = latents.to(device, dtype=dtype) + else: + raise ValueError(f"Unknown latents_source: " + f"{latents_source!r}") + + training_batch.latents = latents + training_batch.encoder_hidden_states = (encoder_hidden_states.to(device, dtype=dtype)) + training_batch.encoder_attention_mask = (encoder_attention_mask.to(device, dtype=dtype)) + training_batch.infos = infos + + training_batch.latents = normalize_dit_input("wan", training_batch.latents, self.vae) + training_batch = self._prepare_dit_inputs(training_batch) + training_batch = self._build_attention_metadata(training_batch) + + training_batch.attn_metadata_vsa = copy.deepcopy(training_batch.attn_metadata) + if training_batch.attn_metadata is not None: + training_batch.attn_metadata.VSA_sparsity = 0.0 # type: ignore[attr-defined] + + return training_batch + + def add_noise( + self, + clean_latents: torch.Tensor, + noise: torch.Tensor, + timestep: torch.Tensor, + ) -> torch.Tensor: + b, t = 
clean_latents.shape[:2] + noisy = self.noise_scheduler.add_noise( + clean_latents.flatten(0, 1), + noise.flatten(0, 1), + timestep, + ).unflatten(0, (b, t)) + return noisy + + def predict_x0( + self, + noisy_latents: torch.Tensor, + timestep: torch.Tensor, + batch: TrainingBatch, + *, + conditional: bool, + cfg_uncond: dict[str, Any] | None = None, + attn_kind: Literal["dense", "vsa"] = "dense", + ) -> torch.Tensor: + device_type = self.device.type + dtype = noisy_latents.dtype + if conditional: + text_dict = batch.conditional_dict + if text_dict is None: + raise RuntimeError("Missing conditional_dict in " + "TrainingBatch") + else: + text_dict = self._get_uncond_text_dict(batch, cfg_uncond=cfg_uncond) + + if attn_kind == "dense": + attn_metadata = batch.attn_metadata + elif attn_kind == "vsa": + attn_metadata = batch.attn_metadata_vsa + else: + raise ValueError(f"Unknown attn_kind: {attn_kind!r}") + + with torch.autocast(device_type, dtype=dtype), set_forward_context( + current_timestep=batch.timesteps, + attn_metadata=attn_metadata, + ): + input_kwargs = (self._build_distill_input_kwargs(noisy_latents, timestep, text_dict)) + transformer = self._get_transformer(timestep) + pred_noise = transformer(**input_kwargs).permute(0, 2, 1, 3, 4) + pred_x0 = pred_noise_to_pred_video( + pred_noise=pred_noise.flatten(0, 1), + noise_input_latent=noisy_latents.flatten(0, 1), + timestep=timestep, + scheduler=self.noise_scheduler, + ).unflatten(0, pred_noise.shape[:2]) + return pred_x0 + + def predict_noise( + self, + noisy_latents: torch.Tensor, + timestep: torch.Tensor, + batch: TrainingBatch, + *, + conditional: bool, + cfg_uncond: dict[str, Any] | None = None, + attn_kind: Literal["dense", "vsa"] = "dense", + ) -> torch.Tensor: + device_type = self.device.type + dtype = noisy_latents.dtype + if conditional: + text_dict = batch.conditional_dict + if text_dict is None: + raise RuntimeError("Missing conditional_dict in " + "TrainingBatch") + else: + text_dict = 
self._get_uncond_text_dict(batch, cfg_uncond=cfg_uncond) + + if attn_kind == "dense": + attn_metadata = batch.attn_metadata + elif attn_kind == "vsa": + attn_metadata = batch.attn_metadata_vsa + else: + raise ValueError(f"Unknown attn_kind: {attn_kind!r}") + + with torch.autocast(device_type, dtype=dtype), set_forward_context( + current_timestep=batch.timesteps, + attn_metadata=attn_metadata, + ): + input_kwargs = (self._build_distill_input_kwargs(noisy_latents, timestep, text_dict)) + transformer = self._get_transformer(timestep) + pred_noise = transformer(**input_kwargs).permute(0, 2, 1, 3, 4) + return pred_noise + + def backward( + self, + loss: torch.Tensor, + ctx: Any, + *, + grad_accum_rounds: int, + ) -> None: + timesteps, attn_metadata = ctx + with set_forward_context( + current_timestep=timesteps, + attn_metadata=attn_metadata, + ): + (loss / max(1, int(grad_accum_rounds))).backward() + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _get_training_dtype(self) -> torch.dtype: + return torch.bfloat16 + + def _init_timestep_mechanics(self) -> None: + assert self.training_config is not None + tc = self.training_config + self.timestep_shift = float(tc.pipeline_config.flow_shift # type: ignore[union-attr] + ) + self.num_train_timestep = int(self.noise_scheduler.num_train_timesteps) + # min/max timestep ratios now come from method_config; + # default to full range. 
+ self.min_timestep = 0 + self.max_timestep = self.num_train_timestep + + def ensure_negative_conditioning(self) -> None: + if self.negative_prompt_embeds is not None: + return + + assert self.training_config is not None + tc = self.training_config + world_group = self.world_group + device = self.device + dtype = self._get_training_dtype() + + from fastvideo.train.utils.moduleloader import ( + make_inference_args, ) + + neg_embeds: torch.Tensor | None = None + neg_mask: torch.Tensor | None = None + + if world_group.rank_in_group == 0: + sampling_param = SamplingParam.from_pretrained(tc.model_path) + negative_prompt = sampling_param.negative_prompt + + inference_args = make_inference_args(tc, model_path=tc.model_path) + + prompt_pipeline = WanPipeline.from_pretrained( + tc.model_path, + args=inference_args, + inference_mode=True, + loaded_modules={"transformer": self.transformer}, + tp_size=tc.distributed.tp_size, + sp_size=tc.distributed.sp_size, + num_gpus=tc.distributed.num_gpus, + pin_cpu_memory=(tc.distributed.pin_cpu_memory), + dit_cpu_offload=True, + ) + + batch_negative = ForwardBatch( + data_type="video", + prompt=negative_prompt, + prompt_embeds=[], + prompt_attention_mask=[], + ) + result_batch = prompt_pipeline.prompt_encoding_stage( # type: ignore[attr-defined] + batch_negative, + inference_args, + ) + + neg_embeds = result_batch.prompt_embeds[0].to(device=device, dtype=dtype) + neg_mask = (result_batch.prompt_attention_mask[0].to(device=device, dtype=dtype)) + + del prompt_pipeline + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + meta = torch.zeros((2, ), device=device, dtype=torch.int64) + if world_group.rank_in_group == 0: + assert neg_embeds is not None + assert neg_mask is not None + meta[0] = neg_embeds.ndim + meta[1] = neg_mask.ndim + world_group.broadcast(meta, src=0) + embed_ndim, mask_ndim = ( + int(meta[0].item()), + int(meta[1].item()), + ) + + max_ndim = 8 + embed_shape = torch.full((max_ndim, ), -1, 
device=device, dtype=torch.int64) + mask_shape = torch.full((max_ndim, ), -1, device=device, dtype=torch.int64) + if world_group.rank_in_group == 0: + assert neg_embeds is not None + assert neg_mask is not None + embed_shape[:embed_ndim] = torch.tensor( + list(neg_embeds.shape), + device=device, + dtype=torch.int64, + ) + mask_shape[:mask_ndim] = torch.tensor( + list(neg_mask.shape), + device=device, + dtype=torch.int64, + ) + world_group.broadcast(embed_shape, src=0) + world_group.broadcast(mask_shape, src=0) + + embed_sizes = tuple(int(x) for x in embed_shape[:embed_ndim].tolist()) + mask_sizes = tuple(int(x) for x in mask_shape[:mask_ndim].tolist()) + + if world_group.rank_in_group != 0: + neg_embeds = torch.empty(embed_sizes, device=device, dtype=dtype) + neg_mask = torch.empty(mask_sizes, device=device, dtype=dtype) + assert neg_embeds is not None + assert neg_mask is not None + + world_group.broadcast(neg_embeds, src=0) + world_group.broadcast(neg_mask, src=0) + + self.negative_prompt_embeds = neg_embeds + self.negative_prompt_attention_mask = neg_mask + + def _sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor: + if self.noise_random_generator is None: + raise RuntimeError("on_train_start() must be called before " + "prepare_batch()") + assert self.training_config is not None + tc = self.training_config + + u = compute_density_for_timestep_sampling( + weighting_scheme=tc.model.weighting_scheme, + batch_size=batch_size, + generator=self.noise_random_generator, + logit_mean=tc.model.logit_mean, + logit_std=tc.model.logit_std, + mode_scale=tc.model.mode_scale, + ) + indices = (u * self.noise_scheduler.config.num_train_timesteps).long() + return self.noise_scheduler.timesteps[indices].to(device=device) + + def _build_attention_metadata(self, training_batch: TrainingBatch) -> TrainingBatch: + assert self.training_config is not None + tc = self.training_config + latents_shape = training_batch.raw_latent_shape + patch_size = ( + 
tc.pipeline_config.dit_config.patch_size # type: ignore[union-attr] + ) + current_vsa_sparsity = (training_batch.current_vsa_sparsity) + assert latents_shape is not None + assert training_batch.timesteps is not None + + if (envs.FASTVIDEO_ATTENTION_BACKEND == "VIDEO_SPARSE_ATTN"): + if (not is_vsa_available() or VideoSparseAttentionMetadataBuilder is None): + raise ImportError("FASTVIDEO_ATTENTION_BACKEND is " + "VIDEO_SPARSE_ATTN, but " + "fastvideo_kernel is not correctly " + "installed or detected.") + training_batch.attn_metadata = VideoSparseAttentionMetadataBuilder().build( # type: ignore[misc] + raw_latent_shape=latents_shape[2:5], + current_timestep=(training_batch.timesteps), + patch_size=patch_size, + VSA_sparsity=current_vsa_sparsity, + device=self.device, + ) + elif (envs.FASTVIDEO_ATTENTION_BACKEND == "VMOBA_ATTN"): + if (not is_vmoba_available() or VideoMobaAttentionMetadataBuilder is None): + raise ImportError("FASTVIDEO_ATTENTION_BACKEND is " + "VMOBA_ATTN, but fastvideo_kernel " + "(or flash_attn>=2.7.4) is not " + "correctly installed.") + moba_params = tc.model.moba_config.copy() + moba_params.update({ + "current_timestep": (training_batch.timesteps), + "raw_latent_shape": (training_batch.raw_latent_shape[2:5]), + "patch_size": patch_size, + "device": self.device, + }) + training_batch.attn_metadata = VideoMobaAttentionMetadataBuilder().build(** + moba_params) # type: ignore[misc] + else: + training_batch.attn_metadata = None + + return training_batch + + def _prepare_dit_inputs(self, training_batch: TrainingBatch) -> TrainingBatch: + assert self.training_config is not None + tc = self.training_config + latents = training_batch.latents + assert isinstance(latents, torch.Tensor) + batch_size = latents.shape[0] + + if self.noise_gen_cuda is None: + raise RuntimeError("on_train_start() must be called before " + "prepare_batch()") + + noise = torch.randn( + latents.shape, + generator=self.noise_gen_cuda, + device=latents.device, + 
dtype=latents.dtype, + ) + timesteps = self._sample_timesteps(batch_size, latents.device) + if int(tc.distributed.sp_size or 1) > 1: + self.sp_group.broadcast(timesteps, src=0) + + sigmas = get_sigmas( + self.noise_scheduler, + latents.device, + timesteps, + n_dim=latents.ndim, + dtype=latents.dtype, + ) + noisy_model_input = ((1.0 - sigmas) * latents + sigmas * noise) + + training_batch.noisy_model_input = (noisy_model_input) + training_batch.timesteps = timesteps + training_batch.sigmas = sigmas + training_batch.noise = noise + training_batch.raw_latent_shape = latents.shape + + training_batch.conditional_dict = { + "encoder_hidden_states": (training_batch.encoder_hidden_states), + "encoder_attention_mask": (training_batch.encoder_attention_mask), + } + + if (self.negative_prompt_embeds is not None and self.negative_prompt_attention_mask is not None): + neg_embeds = self.negative_prompt_embeds + neg_mask = (self.negative_prompt_attention_mask) + if (neg_embeds.shape[0] == 1 and batch_size > 1): + neg_embeds = neg_embeds.expand(batch_size, *neg_embeds.shape[1:]).contiguous() + if (neg_mask.shape[0] == 1 and batch_size > 1): + neg_mask = neg_mask.expand(batch_size, *neg_mask.shape[1:]).contiguous() + training_batch.unconditional_dict = { + "encoder_hidden_states": neg_embeds, + "encoder_attention_mask": neg_mask, + } + + training_batch.latents = (training_batch.latents.permute(0, 2, 1, 3, 4)) + return training_batch + + def _build_distill_input_kwargs( + self, + noise_input: torch.Tensor, + timestep: torch.Tensor, + text_dict: dict[str, torch.Tensor] | None, + ) -> dict[str, Any]: + if text_dict is None: + raise ValueError("text_dict cannot be None for " + "Wan distillation") + return { + "hidden_states": noise_input.permute(0, 2, 1, 3, 4), + "encoder_hidden_states": text_dict["encoder_hidden_states"], + "encoder_attention_mask": text_dict["encoder_attention_mask"], + "timestep": timestep, + "return_dict": False, + } + + def _get_transformer(self, timestep: 
torch.Tensor) -> torch.nn.Module: + return self.transformer + + def _get_uncond_text_dict( + self, + batch: TrainingBatch, + *, + cfg_uncond: dict[str, Any] | None, + ) -> dict[str, torch.Tensor]: + if cfg_uncond is None: + text_dict = getattr(batch, "unconditional_dict", None) + if text_dict is None: + raise RuntimeError("Missing unconditional_dict; " + "ensure_negative_conditioning() " + "may have failed") + return text_dict + + on_missing_raw = cfg_uncond.get("on_missing", "error") + if not isinstance(on_missing_raw, str): + raise ValueError("method_config.cfg_uncond.on_missing " + "must be a string, got " + f"{type(on_missing_raw).__name__}") + on_missing = on_missing_raw.strip().lower() + if on_missing not in {"error", "ignore"}: + raise ValueError("method_config.cfg_uncond.on_missing " + "must be one of {error, ignore}, got " + f"{on_missing_raw!r}") + + for channel, policy_raw in cfg_uncond.items(): + if channel in {"on_missing", "text"}: + continue + if policy_raw is None: + continue + if not isinstance(policy_raw, str): + raise ValueError("method_config.cfg_uncond values " + "must be strings, got " + f"{channel}=" + f"{type(policy_raw).__name__}") + policy = policy_raw.strip().lower() + if policy == "keep": + continue + if on_missing == "ignore": + continue + raise ValueError("WanModel does not support " + "cfg_uncond channel " + f"{channel!r} (policy={policy!r}). 
" + "Set cfg_uncond.on_missing=ignore or " + "remove the channel.") + + text_policy_raw = cfg_uncond.get("text", None) + if text_policy_raw is None: + text_policy = "negative_prompt" + elif not isinstance(text_policy_raw, str): + raise ValueError("method_config.cfg_uncond.text must be " + "a string, got " + f"{type(text_policy_raw).__name__}") + else: + text_policy = (text_policy_raw.strip().lower()) + + if text_policy in {"negative_prompt"}: + text_dict = getattr(batch, "unconditional_dict", None) + if text_dict is None: + raise RuntimeError("Missing unconditional_dict; " + "ensure_negative_conditioning() " + "may have failed") + return text_dict + if text_policy == "keep": + if batch.conditional_dict is None: + raise RuntimeError("Missing conditional_dict in " + "TrainingBatch") + return batch.conditional_dict + if text_policy == "zero": + if batch.conditional_dict is None: + raise RuntimeError("Missing conditional_dict in " + "TrainingBatch") + cond = batch.conditional_dict + enc = cond["encoder_hidden_states"] + mask = cond["encoder_attention_mask"] + if not torch.is_tensor(enc) or not torch.is_tensor(mask): + raise TypeError("conditional_dict must contain " + "tensor text inputs") + return { + "encoder_hidden_states": (torch.zeros_like(enc)), + "encoder_attention_mask": (torch.zeros_like(mask)), + } + if text_policy == "drop": + raise ValueError("cfg_uncond.text=drop is not supported " + "for Wan. 
Use " + "{negative_prompt, keep, zero}.") + raise ValueError("cfg_uncond.text must be one of " + "{negative_prompt, keep, zero, drop}, got " + f"{text_policy_raw!r}") diff --git a/fastvideo/train/models/wan/wan_causal.py b/fastvideo/train/models/wan/wan_causal.py new file mode 100644 index 000000000..ba4f2edf8 --- /dev/null +++ b/fastvideo/train/models/wan/wan_causal.py @@ -0,0 +1,570 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Wan causal model plugin (per-role instance, streaming/cache).""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal, TYPE_CHECKING + +import torch + +from fastvideo.forward_context import set_forward_context +from fastvideo.models.utils import pred_noise_to_pred_video + +from fastvideo.train.models.base import CausalModelBase +from fastvideo.train.models.wan.wan import WanModel + +if TYPE_CHECKING: + from fastvideo.train.utils.training_config import ( + TrainingConfig, ) + + +@dataclass(slots=True) +class _StreamingCaches: + kv_cache: list[dict[str, Any]] + crossattn_cache: list[dict[str, Any]] | None + frame_seq_length: int + local_attn_size: int + sliding_window_num_frames: int + batch_size: int + dtype: torch.dtype + device: torch.device + + +class WanCausalModel(WanModel, CausalModelBase): + """Wan per-role model with causal/streaming primitives.""" + + _transformer_cls_name: str = ( + "CausalWanTransformer3DModel") + + def __init__( + self, + *, + init_from: str, + training_config: TrainingConfig, + trainable: bool = True, + disable_custom_init_weights: bool = False, + flow_shift: float = 3.0, + enable_gradient_checkpointing_type: str + | None = None, + ) -> None: + super().__init__( + init_from=init_from, + training_config=training_config, + trainable=trainable, + disable_custom_init_weights=( + disable_custom_init_weights), + flow_shift=flow_shift, + enable_gradient_checkpointing_type=( + enable_gradient_checkpointing_type), + ) + self._streaming_caches: ( + dict[tuple[int, 
str], _StreamingCaches] + ) = {} + + # --- CausalModelBase override: clear_caches --- + def clear_caches( + self, *, cache_tag: str = "pos", + ) -> None: + self._streaming_caches.pop( + (id(self), str(cache_tag)), None) + + # --- CausalModelBase override: predict_noise_streaming --- + def predict_noise_streaming( + self, + noisy_latents: torch.Tensor, + timestep: torch.Tensor, + batch: Any, + *, + conditional: bool, + cache_tag: str = "pos", + store_kv: bool = False, + cur_start_frame: int = 0, + cfg_uncond: dict[str, Any] | None = None, + attn_kind: Literal["dense", "vsa"] = "dense", + ) -> torch.Tensor | None: + if attn_kind == "dense": + attn_metadata = batch.attn_metadata + elif attn_kind == "vsa": + attn_metadata = batch.attn_metadata_vsa + else: + raise ValueError( + f"Unknown attn_kind: {attn_kind!r}") + + cache_tag = str(cache_tag) + cur_start_frame = int(cur_start_frame) + if cur_start_frame < 0: + raise ValueError( + "cur_start_frame must be >= 0") + + batch_size = int(noisy_latents.shape[0]) + num_frames = int(noisy_latents.shape[1]) + timestep_full = self._ensure_per_frame_timestep( + timestep=timestep, + batch_size=batch_size, + num_frames=num_frames, + device=noisy_latents.device, + ) + + transformer = self._get_transformer( + timestep_full) + caches = self._get_or_init_streaming_caches( + cache_tag=cache_tag, + transformer=transformer, + noisy_latents=noisy_latents, + ) + + frame_seq_length = int(caches.frame_seq_length) + kv_cache = caches.kv_cache + crossattn_cache = caches.crossattn_cache + + if (self._should_snapshot_streaming_cache() + and torch.is_grad_enabled()): + kv_cache = self._snapshot_kv_cache_indices( + kv_cache) + + model_kwargs: dict[str, Any] = { + "kv_cache": kv_cache, + "crossattn_cache": crossattn_cache, + "current_start": ( + cur_start_frame * frame_seq_length), + "start_frame": cur_start_frame, + "is_cache": bool(store_kv), + } + + device_type = self.device.type + dtype = noisy_latents.dtype + + if conditional: + text_dict = 
batch.conditional_dict + if text_dict is None: + raise RuntimeError( + "Missing conditional_dict in " + "TrainingBatch") + else: + text_dict = self._get_uncond_text_dict( + batch, cfg_uncond=cfg_uncond) + + with ( + torch.autocast(device_type, dtype=dtype), + set_forward_context( + current_timestep=batch.timesteps, + attn_metadata=attn_metadata, + ), + ): + input_kwargs = ( + self._build_distill_input_kwargs( + noisy_latents, + timestep_full, + text_dict, + )) + input_kwargs["timestep"] = ( + timestep_full.to( + device=self.device, + dtype=torch.long, + )) + input_kwargs.update(model_kwargs) + + if store_kv: + with torch.no_grad(): + _ = transformer(**input_kwargs) + return None + + pred_noise = transformer( + **input_kwargs, + ).permute(0, 2, 1, 3, 4) + return pred_noise + + # --- CausalModelBase override: predict_x0_streaming --- + def predict_x0_streaming( + self, + noisy_latents: torch.Tensor, + timestep: torch.Tensor, + batch: Any, + *, + conditional: bool, + cache_tag: str = "pos", + store_kv: bool = False, + cur_start_frame: int = 0, + cfg_uncond: dict[str, Any] | None = None, + attn_kind: Literal["dense", "vsa"] = "dense", + ) -> torch.Tensor | None: + pred_noise = self.predict_noise_streaming( + noisy_latents, + timestep, + batch, + conditional=conditional, + cache_tag=cache_tag, + store_kv=store_kv, + cur_start_frame=cur_start_frame, + cfg_uncond=cfg_uncond, + attn_kind=attn_kind, + ) + if pred_noise is None: + return None + + pred_x0 = pred_noise_to_pred_video( + pred_noise=pred_noise.flatten(0, 1), + noise_input_latent=( + noisy_latents.flatten(0, 1)), + timestep=self.shift_and_clamp_timestep( + self._ensure_per_frame_timestep( + timestep=timestep, + batch_size=int( + noisy_latents.shape[0]), + num_frames=int( + noisy_latents.shape[1]), + device=noisy_latents.device, + ).flatten()), + scheduler=self.noise_scheduler, + ).unflatten(0, pred_noise.shape[:2]) + return pred_x0 + + # ------------------------------------------------------------------ + # 
Internal helpers + # ------------------------------------------------------------------ + + def _ensure_per_frame_timestep( + self, + *, + timestep: torch.Tensor, + batch_size: int, + num_frames: int, + device: torch.device, + ) -> torch.Tensor: + if timestep.ndim == 0: + return ( + timestep.view(1, 1) + .expand(batch_size, num_frames) + .to(device=device)) + if timestep.ndim == 1: + if int(timestep.shape[0]) == batch_size: + return ( + timestep.view(batch_size, 1) + .expand(batch_size, num_frames) + .to(device=device)) + raise ValueError( + "streaming timestep must be scalar, " + "[B], or [B, T]; got shape=" + f"{tuple(timestep.shape)}") + if timestep.ndim == 2: + return timestep.to(device=device) + raise ValueError( + "streaming timestep must be scalar, " + "[B], or [B, T]; got ndim=" + f"{int(timestep.ndim)}") + + def _get_or_init_streaming_caches( + self, + *, + cache_tag: str, + transformer: torch.nn.Module, + noisy_latents: torch.Tensor, + ) -> _StreamingCaches: + key = (id(self), cache_tag) + cached = self._streaming_caches.get(key) + + batch_size = int(noisy_latents.shape[0]) + dtype = noisy_latents.dtype + device = noisy_latents.device + + frame_seq_length = ( + self._compute_frame_seq_length( + transformer, noisy_latents)) + local_attn_size = self._get_local_attn_size( + transformer) + sliding_window_num_frames = ( + self._get_sliding_window_num_frames( + transformer)) + + meta = ( + frame_seq_length, + local_attn_size, + sliding_window_num_frames, + batch_size, + dtype, + device, + ) + + if cached is not None: + cached_meta = ( + cached.frame_seq_length, + cached.local_attn_size, + cached.sliding_window_num_frames, + cached.batch_size, + cached.dtype, + cached.device, + ) + if cached_meta == meta: + return cached + + kv_cache = self._initialize_kv_cache( + transformer=transformer, + batch_size=batch_size, + dtype=dtype, + device=device, + frame_seq_length=frame_seq_length, + local_attn_size=local_attn_size, + sliding_window_num_frames=( + 
sliding_window_num_frames), + checkpoint_safe=( + self + ._should_use_checkpoint_safe_kv_cache() + ), + ) + crossattn_cache = ( + self._initialize_crossattn_cache( + transformer=transformer, + device=device, + )) + + caches = _StreamingCaches( + kv_cache=kv_cache, + crossattn_cache=crossattn_cache, + frame_seq_length=frame_seq_length, + local_attn_size=local_attn_size, + sliding_window_num_frames=( + sliding_window_num_frames), + batch_size=batch_size, + dtype=dtype, + device=device, + ) + self._streaming_caches[key] = caches + return caches + + def _compute_frame_seq_length( + self, + transformer: torch.nn.Module, + noisy_latents: torch.Tensor, + ) -> int: + latent_seq_length = ( + int(noisy_latents.shape[-1]) + * int(noisy_latents.shape[-2])) + patch_size = getattr( + transformer, "patch_size", None) + if patch_size is None: + patch_size = getattr( + getattr( + getattr(transformer, "config", None), + "arch_config", + None, + ), + "patch_size", + None, + ) + if patch_size is None: + raise ValueError( + "Unable to determine " + "transformer.patch_size " + "for causal streaming") + patch_ratio = ( + int(patch_size[-1]) * int(patch_size[-2])) + if patch_ratio <= 0: + raise ValueError( + "Invalid patch_size for causal " + "streaming") + return latent_seq_length // patch_ratio + + def _get_sliding_window_num_frames( + self, transformer: torch.nn.Module, + ) -> int: + cfg = getattr(transformer, "config", None) + arch_cfg = getattr(cfg, "arch_config", None) + value = ( + getattr( + arch_cfg, + "sliding_window_num_frames", + None, + ) + if arch_cfg is not None + else None) + if value is None: + return 15 + return int(value) + + def _get_local_attn_size( + self, transformer: torch.nn.Module, + ) -> int: + try: + value = getattr( + transformer, "local_attn_size", -1) + except Exception: + value = -1 + if value is None: + return -1 + return int(value) + + def _initialize_kv_cache( + self, + *, + transformer: torch.nn.Module, + batch_size: int, + dtype: torch.dtype, + device: 
torch.device, + frame_seq_length: int, + local_attn_size: int, + sliding_window_num_frames: int, + checkpoint_safe: bool, + ) -> list[dict[str, Any]]: + num_blocks = len( + getattr(transformer, "blocks", [])) + if num_blocks <= 0: + raise ValueError( + "Unexpected transformer.blocks " + "for causal streaming") + + try: + num_attention_heads = int( + transformer.num_attention_heads) # type: ignore[attr-defined] + except AttributeError as e: + raise ValueError( + "Transformer is missing " + "num_attention_heads") from e + + try: + attention_head_dim = int( + transformer.attention_head_dim) # type: ignore[attr-defined] + except AttributeError: + try: + hidden_size = int( + transformer.hidden_size) # type: ignore[attr-defined] + except AttributeError as e: + raise ValueError( + "Transformer is missing " + "attention_head_dim and " + "hidden_size") from e + attention_head_dim = ( + hidden_size + // max(1, num_attention_heads)) + + if local_attn_size != -1: + kv_cache_size = ( + int(local_attn_size) + * int(frame_seq_length)) + else: + kv_cache_size = ( + int(frame_seq_length) + * int(sliding_window_num_frames)) + + if checkpoint_safe: + tc = getattr( + self, "training_config", None) + total_frames = int( + tc.data.num_frames + if tc is not None + else 0) + if total_frames <= 0: + raise ValueError( + "training.num_frames must be set " + "to enable checkpoint-safe " + "streaming KV cache; got " + f"{total_frames}") + kv_cache_size = max( + kv_cache_size, + int(frame_seq_length) + * total_frames, + ) + + kv_cache: list[dict[str, Any]] = [] + for _ in range(num_blocks): + kv_cache.append({ + "k": + torch.zeros( + [ + batch_size, + kv_cache_size, + num_attention_heads, + attention_head_dim, + ], + dtype=dtype, + device=device, + ), + "v": + torch.zeros( + [ + batch_size, + kv_cache_size, + num_attention_heads, + attention_head_dim, + ], + dtype=dtype, + device=device, + ), + "global_end_index": + torch.zeros( + (), dtype=torch.long, + device=device), + "local_end_index": + 
torch.zeros( + (), dtype=torch.long, + device=device), + }) + + return kv_cache + + def _should_use_checkpoint_safe_kv_cache( + self, + ) -> bool: + tc = getattr( + self, "training_config", None) + if tc is not None: + checkpointing_type = ( + tc.model + .enable_gradient_checkpointing_type) + else: + checkpointing_type = None + return (bool(checkpointing_type) + and bool(self._trainable)) + + def _should_snapshot_streaming_cache( + self, + ) -> bool: + return ( + self + ._should_use_checkpoint_safe_kv_cache()) + + def _snapshot_kv_cache_indices( + self, + kv_cache: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + snapshot: list[dict[str, Any]] = [] + for block_cache in kv_cache: + global_end_index = block_cache.get( + "global_end_index") + local_end_index = block_cache.get( + "local_end_index") + if ( + not isinstance( + global_end_index, torch.Tensor) + or not isinstance( + local_end_index, torch.Tensor) + ): + raise ValueError( + "Unexpected kv_cache index " + "tensors; expected tensors at " + "kv_cache[*].{global_end_index, " + "local_end_index}") + + copied = dict(block_cache) + copied["global_end_index"] = ( + global_end_index.detach().clone()) + copied["local_end_index"] = ( + local_end_index.detach().clone()) + snapshot.append(copied) + return snapshot + + def _initialize_crossattn_cache( + self, + *, + transformer: torch.nn.Module, + device: torch.device, + ) -> list[dict[str, Any]] | None: + num_blocks = len( + getattr(transformer, "blocks", [])) + if num_blocks <= 0: + return None + return [{ + "is_init": False, + "k": torch.empty(0, device=device), + "v": torch.empty(0, device=device), + } for _ in range(num_blocks)] diff --git a/fastvideo/train/models/wangame/__init__.py b/fastvideo/train/models/wangame/__init__.py new file mode 100644 index 000000000..101e3e7a8 --- /dev/null +++ b/fastvideo/train/models/wangame/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 +"""WanGame model plugin package.""" + +from 
fastvideo.train.models.wangame.wangame import ( + WanGameModel as WanGameModel, ) +from fastvideo.train.models.wangame.wangame_causal import ( + WanGameCausalModel as WanGameCausalModel, ) diff --git a/fastvideo/train/models/wangame/wangame.py b/fastvideo/train/models/wangame/wangame.py new file mode 100644 index 000000000..1d7a25855 --- /dev/null +++ b/fastvideo/train/models/wangame/wangame.py @@ -0,0 +1,816 @@ +# SPDX-License-Identifier: Apache-2.0 +"""WanGame bidirectional model plugin (per-role instance).""" + +from __future__ import annotations + +import copy +from typing import Any, Literal, TYPE_CHECKING + +import torch + +import fastvideo.envs as envs +from fastvideo.distributed import ( + get_local_torch_device, + get_sp_group, + get_world_group, +) +from fastvideo.forward_context import set_forward_context +from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, ) +from fastvideo.models.utils import pred_noise_to_pred_video +from fastvideo.pipelines import TrainingBatch +from fastvideo.training.activation_checkpoint import ( + apply_activation_checkpointing, ) +from fastvideo.training.training_utils import ( + compute_density_for_timestep_sampling, + get_sigmas, + normalize_dit_input, + shift_timestep, +) +from fastvideo.utils import ( + is_vmoba_available, + is_vsa_available, + set_random_seed, +) + +from fastvideo.train.models.base import ModelBase +from fastvideo.train.utils.module_state import ( + apply_trainable, ) +from fastvideo.train.utils.moduleloader import ( + load_module_from_path, ) + +if TYPE_CHECKING: + from fastvideo.train.utils.training_config import ( + TrainingConfig, ) + +try: + from fastvideo.attention.backends.video_sparse_attn import ( + VideoSparseAttentionMetadataBuilder, ) + from fastvideo.attention.backends.vmoba import ( + VideoMobaAttentionMetadataBuilder, ) +except Exception: + VideoSparseAttentionMetadataBuilder = None # type: ignore[assignment] + 
VideoMobaAttentionMetadataBuilder = None # type: ignore[assignment] + + +class WanGameModel(ModelBase): + """WanGame per-role model: owns transformer + noise_scheduler.""" + + _transformer_cls_name: str = ("WanGameActionTransformer3DModel") + + def __init__( + self, + *, + init_from: str, + training_config: TrainingConfig, + trainable: bool = True, + disable_custom_init_weights: bool = False, + flow_shift: float = 3.0, + enable_gradient_checkpointing_type: str | None = None, + ) -> None: + self._init_from = str(init_from) + self._trainable = bool(trainable) + + self.transformer = self._load_transformer( + init_from=self._init_from, + trainable=self._trainable, + disable_custom_init_weights=(disable_custom_init_weights), + enable_gradient_checkpointing_type=(enable_gradient_checkpointing_type), + training_config=training_config, + ) + + self.noise_scheduler = (FlowMatchEulerDiscreteScheduler(shift=float(flow_shift))) + + # Filled by init_preprocessors (student only). + self.vae: Any = None + self.training_config: TrainingConfig = training_config + self.dataloader: Any = None + self.validator: Any = None + self.start_step: int = 0 + + self.world_group: Any = None + self.sp_group: Any = None + self.device: Any = get_local_torch_device() + + self.noise_random_generator: (torch.Generator | None) = None + self.noise_gen_cuda: torch.Generator | None = None + + self.timestep_shift: float = float(flow_shift) + self.num_train_timestep: int = int(self.noise_scheduler.num_train_timesteps) + self.min_timestep: int = 0 + self.max_timestep: int = self.num_train_timestep + + def _load_transformer( + self, + *, + init_from: str, + trainable: bool, + disable_custom_init_weights: bool, + enable_gradient_checkpointing_type: str | None, + training_config: TrainingConfig, + ) -> torch.nn.Module: + transformer = load_module_from_path( + model_path=init_from, + module_type="transformer", + training_config=training_config, + disable_custom_init_weights=(disable_custom_init_weights), + 
override_transformer_cls_name=(self._transformer_cls_name), + ) + transformer = apply_trainable(transformer, trainable=trainable) + if (trainable and enable_gradient_checkpointing_type): + transformer = apply_activation_checkpointing( + transformer, + checkpointing_type=(enable_gradient_checkpointing_type), + ) + return transformer + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def init_preprocessors(self, training_config: TrainingConfig) -> None: + """Load VAE, build dataloader, seed RNGs.""" + self.vae = load_module_from_path( + model_path=str(training_config.model_path), + module_type="vae", + training_config=training_config, + ) + + self.world_group = get_world_group() + self.sp_group = get_sp_group() + + self._init_timestep_mechanics() + + from fastvideo.dataset.dataloader.schema import ( + pyarrow_schema_wangame, ) + from fastvideo.train.utils.dataloader import ( + build_parquet_wangame_train_dataloader, ) + + self.dataloader = (build_parquet_wangame_train_dataloader( + training_config.data, + parquet_schema=pyarrow_schema_wangame, + )) + self.start_step = 0 + + # ------------------------------------------------------------------ + # ModelBase overrides: timestep helpers + # ------------------------------------------------------------------ + + @property + def num_train_timesteps(self) -> int: + return int(self.num_train_timestep) + + def shift_and_clamp_timestep(self, timestep: torch.Tensor) -> torch.Tensor: + timestep = shift_timestep( + timestep, + self.timestep_shift, + self.num_train_timestep, + ) + return timestep.clamp(self.min_timestep, self.max_timestep) + + # ------------------------------------------------------------------ + # ModelBase overrides: lifecycle hooks + # ------------------------------------------------------------------ + + def on_train_start(self) -> None: + assert self.training_config is not None + tc = 
self.training_config + seed = tc.data.seed + if seed is None: + raise ValueError("training.data.seed must be set " + "for training") + + global_rank = int(getattr(self.world_group, "rank", 0)) + sp_world_size = int(tc.distributed.sp_size or 1) + if sp_world_size > 1: + sp_group_seed = int(seed) + (global_rank // sp_world_size) + set_random_seed(sp_group_seed) + else: + set_random_seed(int(seed) + global_rank) + + self.noise_random_generator = torch.Generator(device="cpu").manual_seed(int(seed)) + self.noise_gen_cuda = torch.Generator(device=self.device).manual_seed(int(seed)) + + def get_rng_generators(self, ) -> dict[str, torch.Generator]: + generators: dict[str, torch.Generator] = {} + if self.noise_random_generator is not None: + generators["noise_cpu"] = (self.noise_random_generator) + if self.noise_gen_cuda is not None: + generators["noise_cuda"] = self.noise_gen_cuda + return generators + + # ------------------------------------------------------------------ + # ModelBase overrides: runtime primitives + # ------------------------------------------------------------------ + + def prepare_batch( + self, + raw_batch: dict[str, Any], + *, + current_vsa_sparsity: float = 0.0, + latents_source: Literal["data", "zeros"] = "data", + ) -> TrainingBatch: + assert self.training_config is not None + tc = self.training_config + dtype = self._get_training_dtype() + device = self.device + + training_batch = TrainingBatch(current_vsa_sparsity=current_vsa_sparsity) + infos = raw_batch.get("info_list") + + if latents_source == "zeros": + clip_feature = raw_batch["clip_feature"] + batch_size = int(clip_feature.shape[0]) + vae_config = ( + tc.pipeline_config.vae_config.arch_config # type: ignore[union-attr] + ) + num_channels = int(vae_config.z_dim) + spatial_compression_ratio = int(vae_config.spatial_compression_ratio) + latent_height = (int(tc.data.num_height) // spatial_compression_ratio) + latent_width = (int(tc.data.num_width) // spatial_compression_ratio) + latents = 
torch.zeros( + batch_size, + num_channels, + int(tc.data.num_latent_t), + latent_height, + latent_width, + device=device, + dtype=dtype, + ) + elif latents_source == "data": + if "vae_latent" not in raw_batch: + raise ValueError("vae_latent not found in batch " + "and latents_source='data'") + latents = raw_batch["vae_latent"] + latents = latents[:, :, :tc.data.num_latent_t] + latents = latents.to(device, dtype=dtype) + else: + raise ValueError(f"Unknown latents_source: " + f"{latents_source!r}") + + if "clip_feature" not in raw_batch: + raise ValueError("clip_feature must be present for WanGame") + image_embeds = raw_batch["clip_feature"].to(device, dtype=dtype) + + if "first_frame_latent" not in raw_batch: + raise ValueError("first_frame_latent must be present " + "for WanGame") + image_latents = raw_batch["first_frame_latent"] + image_latents = image_latents[:, :, :tc.data.num_latent_t] + image_latents = image_latents.to(device, dtype=dtype) + + pil_image = raw_batch.get("pil_image") + if isinstance(pil_image, torch.Tensor): + training_batch.preprocessed_image = (pil_image.to(device=device)) + else: + training_batch.preprocessed_image = pil_image + + keyboard_cond = raw_batch.get("keyboard_cond") + if (isinstance(keyboard_cond, torch.Tensor) and keyboard_cond.numel() > 0): + training_batch.keyboard_cond = (keyboard_cond.to(device, dtype=dtype)) + else: + training_batch.keyboard_cond = None + + mouse_cond = raw_batch.get("mouse_cond") + if (isinstance(mouse_cond, torch.Tensor) and mouse_cond.numel() > 0): + training_batch.mouse_cond = mouse_cond.to(device, dtype=dtype) + else: + training_batch.mouse_cond = None + + temporal_compression_ratio = ( + tc.pipeline_config.vae_config.arch_config.temporal_compression_ratio # type: ignore[union-attr] + ) + expected_num_frames = ((tc.data.num_latent_t - 1) * temporal_compression_ratio + 1) + if (training_batch.keyboard_cond is not None + and int(training_batch.keyboard_cond.shape[1]) != int(expected_num_frames)): + raise 
ValueError("keyboard_cond temporal dim mismatch: " + f"got {int(training_batch.keyboard_cond.shape[1])}, " + f"expected {int(expected_num_frames)}") + if (training_batch.mouse_cond is not None + and int(training_batch.mouse_cond.shape[1]) != int(expected_num_frames)): + raise ValueError("mouse_cond temporal dim mismatch: " + f"got {int(training_batch.mouse_cond.shape[1])}, " + f"expected {int(expected_num_frames)}") + + training_batch.latents = latents + training_batch.encoder_hidden_states = None + training_batch.encoder_attention_mask = None + training_batch.image_embeds = image_embeds + training_batch.image_latents = image_latents + training_batch.infos = infos + + training_batch.latents = normalize_dit_input("wan", training_batch.latents, self.vae) + training_batch = self._prepare_dit_inputs(training_batch) + training_batch = self._build_attention_metadata(training_batch) + + training_batch.attn_metadata_vsa = copy.deepcopy(training_batch.attn_metadata) + if training_batch.attn_metadata is not None: + training_batch.attn_metadata.VSA_sparsity = 0.0 # type: ignore[attr-defined] + + training_batch.mask_lat_size = (self._build_i2v_mask_latents(image_latents)) + viewmats, intrinsics, action_labels = (self._process_actions(training_batch)) + training_batch.viewmats = viewmats + training_batch.Ks = intrinsics + training_batch.action = action_labels + + return training_batch + + def add_noise( + self, + clean_latents: torch.Tensor, + noise: torch.Tensor, + timestep: torch.Tensor, + ) -> torch.Tensor: + b, t = clean_latents.shape[:2] + noisy = self.noise_scheduler.add_noise( + clean_latents.flatten(0, 1), + noise.flatten(0, 1), + timestep, + ).unflatten(0, (b, t)) + return noisy + + def predict_x0( + self, + noisy_latents: torch.Tensor, + timestep: torch.Tensor, + batch: TrainingBatch, + *, + conditional: bool, + cfg_uncond: dict[str, Any] | None = None, + attn_kind: Literal["dense", "vsa"] = "dense", + ) -> torch.Tensor: + device_type = self.device.type + dtype = 
noisy_latents.dtype + + if attn_kind == "dense": + attn_metadata = batch.attn_metadata + elif attn_kind == "vsa": + attn_metadata = batch.attn_metadata_vsa + else: + raise ValueError(f"Unknown attn_kind: {attn_kind!r}") + + with torch.autocast(device_type, dtype=dtype), set_forward_context( + current_timestep=batch.timesteps, + attn_metadata=attn_metadata, + ): + cond_inputs = (self._select_cfg_condition_inputs( + batch, + conditional=conditional, + cfg_uncond=cfg_uncond, + )) + input_kwargs = (self._build_distill_input_kwargs( + noisy_latents, + timestep, + image_embeds=cond_inputs["image_embeds"], + image_latents=cond_inputs["image_latents"], + mask_lat_size=cond_inputs["mask_lat_size"], + viewmats=cond_inputs["viewmats"], + Ks=cond_inputs["Ks"], + action=cond_inputs["action"], + mouse_cond=cond_inputs["mouse_cond"], + keyboard_cond=cond_inputs["keyboard_cond"], + )) + transformer = self._get_transformer(timestep) + pred_noise = transformer(**input_kwargs).permute(0, 2, 1, 3, 4) + pred_x0 = pred_noise_to_pred_video( + pred_noise=pred_noise.flatten(0, 1), + noise_input_latent=noisy_latents.flatten(0, 1), + timestep=timestep, + scheduler=self.noise_scheduler, + ).unflatten(0, pred_noise.shape[:2]) + return pred_x0 + + def predict_noise( + self, + noisy_latents: torch.Tensor, + timestep: torch.Tensor, + batch: TrainingBatch, + *, + conditional: bool, + cfg_uncond: dict[str, Any] | None = None, + attn_kind: Literal["dense", "vsa"] = "dense", + ) -> torch.Tensor: + device_type = self.device.type + dtype = noisy_latents.dtype + + if attn_kind == "dense": + attn_metadata = batch.attn_metadata + elif attn_kind == "vsa": + attn_metadata = batch.attn_metadata_vsa + else: + raise ValueError(f"Unknown attn_kind: {attn_kind!r}") + + with torch.autocast(device_type, dtype=dtype), set_forward_context( + current_timestep=batch.timesteps, + attn_metadata=attn_metadata, + ): + cond_inputs = (self._select_cfg_condition_inputs( + batch, + conditional=conditional, + 
cfg_uncond=cfg_uncond, + )) + input_kwargs = (self._build_distill_input_kwargs( + noisy_latents, + timestep, + image_embeds=cond_inputs["image_embeds"], + image_latents=cond_inputs["image_latents"], + mask_lat_size=cond_inputs["mask_lat_size"], + viewmats=cond_inputs["viewmats"], + Ks=cond_inputs["Ks"], + action=cond_inputs["action"], + mouse_cond=cond_inputs["mouse_cond"], + keyboard_cond=cond_inputs["keyboard_cond"], + )) + transformer = self._get_transformer(timestep) + pred_noise = transformer(**input_kwargs).permute(0, 2, 1, 3, 4) + return pred_noise + + def backward( + self, + loss: torch.Tensor, + ctx: Any, + *, + grad_accum_rounds: int, + ) -> None: + timesteps, attn_metadata = ctx + with set_forward_context( + current_timestep=timesteps, + attn_metadata=attn_metadata, + ): + (loss / max(1, int(grad_accum_rounds))).backward() + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _get_training_dtype(self) -> torch.dtype: + return torch.bfloat16 + + def _init_timestep_mechanics(self) -> None: + assert self.training_config is not None + tc = self.training_config + self.timestep_shift = float(tc.pipeline_config.flow_shift # type: ignore[union-attr] + ) + self.num_train_timestep = int(self.noise_scheduler.num_train_timesteps) + self.min_timestep = 0 + self.max_timestep = self.num_train_timestep + + def _sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor: + if self.noise_random_generator is None: + raise RuntimeError("on_train_start() must be called " + "before prepare_batch()") + assert self.training_config is not None + tc = self.training_config + + u = compute_density_for_timestep_sampling( + weighting_scheme=tc.model.weighting_scheme, + batch_size=batch_size, + generator=self.noise_random_generator, + logit_mean=tc.model.logit_mean, + logit_std=tc.model.logit_std, + mode_scale=tc.model.mode_scale, + ) + indices = (u 
* self.noise_scheduler.config.num_train_timesteps).long() + return self.noise_scheduler.timesteps[indices].to(device=device) + + def _build_attention_metadata(self, training_batch: TrainingBatch) -> TrainingBatch: + assert self.training_config is not None + tc = self.training_config + latents_shape = training_batch.raw_latent_shape + patch_size = ( + tc.pipeline_config.dit_config.patch_size # type: ignore[union-attr] + ) + current_vsa_sparsity = (training_batch.current_vsa_sparsity) + assert latents_shape is not None + assert training_batch.timesteps is not None + + if (envs.FASTVIDEO_ATTENTION_BACKEND == "VIDEO_SPARSE_ATTN"): + if (not is_vsa_available() or VideoSparseAttentionMetadataBuilder is None): + raise ImportError("FASTVIDEO_ATTENTION_BACKEND is " + "VIDEO_SPARSE_ATTN, but " + "fastvideo_kernel is not correctly " + "installed or detected.") + training_batch.attn_metadata = VideoSparseAttentionMetadataBuilder().build( # type: ignore[misc] + raw_latent_shape=latents_shape[2:5], + current_timestep=(training_batch.timesteps), + patch_size=patch_size, + VSA_sparsity=current_vsa_sparsity, + device=self.device, + ) + elif (envs.FASTVIDEO_ATTENTION_BACKEND == "VMOBA_ATTN"): + if (not is_vmoba_available() or VideoMobaAttentionMetadataBuilder is None): + raise ImportError("FASTVIDEO_ATTENTION_BACKEND is " + "VMOBA_ATTN, but fastvideo_kernel " + "(or flash_attn>=2.7.4) is not " + "correctly installed.") + moba_params = tc.model.moba_config.copy() + moba_params.update({ + "current_timestep": (training_batch.timesteps), + "raw_latent_shape": (training_batch.raw_latent_shape[2:5]), + "patch_size": patch_size, + "device": self.device, + }) + training_batch.attn_metadata = VideoMobaAttentionMetadataBuilder().build(** + moba_params) # type: ignore[misc] + else: + training_batch.attn_metadata = None + + return training_batch + + def _prepare_dit_inputs(self, training_batch: TrainingBatch) -> TrainingBatch: + assert self.training_config is not None + tc = 
self.training_config + latents = training_batch.latents + assert isinstance(latents, torch.Tensor) + batch_size = latents.shape[0] + + if self.noise_gen_cuda is None: + raise RuntimeError("on_train_start() must be called " + "before prepare_batch()") + + noise = torch.randn( + latents.shape, + generator=self.noise_gen_cuda, + device=latents.device, + dtype=latents.dtype, + ) + timesteps = self._sample_timesteps(batch_size, latents.device) + if int(tc.distributed.sp_size or 1) > 1: + self.sp_group.broadcast(timesteps, src=0) + + sigmas = get_sigmas( + self.noise_scheduler, + latents.device, + timesteps, + n_dim=latents.ndim, + dtype=latents.dtype, + ) + noisy_model_input = ((1.0 - sigmas) * latents + sigmas * noise) + + training_batch.noisy_model_input = (noisy_model_input) + training_batch.timesteps = timesteps + training_batch.sigmas = sigmas + training_batch.noise = noise + training_batch.raw_latent_shape = latents.shape + + training_batch.latents = (training_batch.latents.permute(0, 2, 1, 3, 4)) + return training_batch + + def _build_i2v_mask_latents(self, image_latents: torch.Tensor) -> torch.Tensor: + assert self.training_config is not None + tc = self.training_config + temporal_compression_ratio = ( + tc.pipeline_config.vae_config.arch_config.temporal_compression_ratio # type: ignore[union-attr] + ) + num_frames = ((tc.data.num_latent_t - 1) * temporal_compression_ratio + 1) + + ( + batch_size, + _num_channels, + _t, + latent_height, + latent_width, + ) = image_latents.shape + mask_lat_size = torch.ones( + batch_size, + 1, + num_frames, + latent_height, + latent_width, + ) + mask_lat_size[:, :, 1:] = 0 + + first_frame_mask = mask_lat_size[:, :, :1] + first_frame_mask = torch.repeat_interleave( + first_frame_mask, + dim=2, + repeats=temporal_compression_ratio, + ) + mask_lat_size = torch.cat( + [first_frame_mask, mask_lat_size[:, :, 1:]], + dim=2, + ) + mask_lat_size = mask_lat_size.view( + batch_size, + -1, + temporal_compression_ratio, + latent_height, + 
latent_width, + ) + mask_lat_size = mask_lat_size.transpose(1, 2) + return mask_lat_size.to( + device=image_latents.device, + dtype=image_latents.dtype, + ) + + def _process_actions(self, training_batch: TrainingBatch) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + keyboard_cond = getattr(training_batch, "keyboard_cond", None) + mouse_cond = getattr(training_batch, "mouse_cond", None) + if keyboard_cond is None or mouse_cond is None: + raise ValueError("WanGame batch must provide " + "keyboard_cond and mouse_cond") + + from fastvideo.models.dits.hyworld.pose import ( + process_custom_actions, ) + + batch_size = int(training_batch.noisy_model_input.shape[0] # type: ignore[union-attr] + ) + viewmats_list: list[torch.Tensor] = [] + intrinsics_list: list[torch.Tensor] = [] + action_labels_list: list[torch.Tensor] = [] + for b in range(batch_size): + v, i, a = process_custom_actions(keyboard_cond[b], mouse_cond[b]) + viewmats_list.append(v) + intrinsics_list.append(i) + action_labels_list.append(a) + + viewmats = torch.stack(viewmats_list, dim=0).to(device=self.device, dtype=torch.bfloat16) + intrinsics = torch.stack(intrinsics_list, dim=0).to(device=self.device, dtype=torch.bfloat16) + action_labels = torch.stack(action_labels_list, dim=0).to(device=self.device, dtype=torch.bfloat16) + + num_latent_t = int(training_batch.noisy_model_input.shape[2] # type: ignore[union-attr] + ) + if int(action_labels.shape[1]) != num_latent_t: + raise ValueError("Action conditioning temporal dim " + "mismatch: " + f"action={tuple(action_labels.shape)} " + f"vs latent_t={num_latent_t}") + if int(viewmats.shape[1]) != num_latent_t: + raise ValueError("Viewmats temporal dim mismatch: " + f"viewmats={tuple(viewmats.shape)} " + f"vs latent_t={num_latent_t}") + + return viewmats, intrinsics, action_labels + + def _build_distill_input_kwargs( + self, + noisy_video_latents: torch.Tensor, + timestep: torch.Tensor, + *, + image_embeds: torch.Tensor, + image_latents: torch.Tensor, + 
mask_lat_size: torch.Tensor, + viewmats: torch.Tensor | None, + Ks: torch.Tensor | None, + action: torch.Tensor | None, + mouse_cond: torch.Tensor | None, + keyboard_cond: torch.Tensor | None, + ) -> dict[str, Any]: + hidden_states = torch.cat( + [ + noisy_video_latents.permute(0, 2, 1, 3, 4), + mask_lat_size, + image_latents, + ], + dim=1, + ) + return { + "hidden_states": hidden_states, + "encoder_hidden_states": None, + "timestep": timestep.to(device=self.device, dtype=torch.bfloat16), + "encoder_hidden_states_image": image_embeds, + "viewmats": viewmats, + "Ks": Ks, + "action": action, + "mouse_cond": mouse_cond, + "keyboard_cond": keyboard_cond, + "return_dict": False, + } + + def _select_cfg_condition_inputs( + self, + batch: TrainingBatch, + *, + conditional: bool, + cfg_uncond: dict[str, Any] | None, + ) -> dict[str, Any]: + image_embeds = batch.image_embeds + image_latents = batch.image_latents + mask_lat_size = batch.mask_lat_size + if image_embeds is None: + raise RuntimeError("WanGameModel requires " + "TrainingBatch.image_embeds") + if image_latents is None: + raise RuntimeError("WanGameModel requires " + "TrainingBatch.image_latents") + if mask_lat_size is None: + raise RuntimeError("WanGameModel requires " + "TrainingBatch.mask_lat_size") + + viewmats = getattr(batch, "viewmats", None) + Ks = getattr(batch, "Ks", None) + action = getattr(batch, "action", None) + mouse_cond = getattr(batch, "mouse_cond", None) + keyboard_cond = getattr(batch, "keyboard_cond", None) + + if conditional or cfg_uncond is None: + return { + "image_embeds": image_embeds, + "image_latents": image_latents, + "mask_lat_size": mask_lat_size, + "viewmats": viewmats, + "Ks": Ks, + "action": action, + "mouse_cond": mouse_cond, + "keyboard_cond": keyboard_cond, + } + + on_missing_raw = cfg_uncond.get("on_missing", "error") + if not isinstance(on_missing_raw, str): + raise ValueError("method_config.cfg_uncond.on_missing " + "must be a string, got " + 
f"{type(on_missing_raw).__name__}") + on_missing = on_missing_raw.strip().lower() + if on_missing not in {"error", "ignore"}: + raise ValueError("method_config.cfg_uncond.on_missing " + "must be one of {error, ignore}, got " + f"{on_missing_raw!r}") + + supported_channels = {"image", "action"} + for channel, policy_raw in cfg_uncond.items(): + if channel in {"on_missing"}: + continue + if channel in supported_channels: + continue + if policy_raw is None: + continue + if not isinstance(policy_raw, str): + raise ValueError("method_config.cfg_uncond values " + "must be strings, got " + f"{channel}=" + f"{type(policy_raw).__name__}") + policy = policy_raw.strip().lower() + if policy == "keep": + continue + if on_missing == "ignore": + continue + raise ValueError("WanGameModel does not support " + "cfg_uncond channel " + f"{channel!r} (policy={policy!r}). " + "Set cfg_uncond.on_missing=ignore or " + "remove the channel.") + + def _get_policy(channel: str) -> str: + raw = cfg_uncond.get(channel, "keep") + if raw is None: + return "keep" + if not isinstance(raw, str): + raise ValueError("method_config.cfg_uncond values " + "must be strings, got " + f"{channel}={type(raw).__name__}") + policy = raw.strip().lower() + if policy not in {"keep", "zero", "drop"}: + raise ValueError("method_config.cfg_uncond values " + "must be one of " + "{keep, zero, drop}, got " + f"{channel}={raw!r}") + return policy + + image_policy = _get_policy("image") + if image_policy == "zero": + image_embeds = torch.zeros_like(image_embeds) + image_latents = torch.zeros_like(image_latents) + mask_lat_size = torch.zeros_like(mask_lat_size) + elif image_policy == "drop": + raise ValueError("cfg_uncond.image=drop is not supported " + "for WanGame I2V; use " + "cfg_uncond.image=zero or keep.") + + action_policy = _get_policy("action") + if action_policy == "zero": + if (viewmats is None or Ks is None or action is None): + if on_missing == "ignore": + pass + else: + raise 
@dataclass(slots=True)
class _StreamingCaches:
    """Streaming caches for one cache tag plus the settings they were built for.

    ``_get_or_init_streaming_caches`` compares the six non-cache fields with
    the values derived from the current call and builds fresh caches when any
    of them differs.
    """

    # One dict per transformer block, created by ``_initialize_kv_cache`` with
    # keys "k", "v", "global_end_index" and "local_end_index".
    kv_cache: list[dict[str, Any]]
    # One dict per transformer block ("is_init", "k", "v"), or None when the
    # transformer exposes no blocks (see ``_initialize_crossattn_cache``).
    crossattn_cache: list[dict[str, Any]] | None
    # Value returned by ``_compute_frame_seq_length`` for the latents in use.
    frame_seq_length: int
    # -1 makes ``_initialize_kv_cache`` size the cache from
    # ``sliding_window_num_frames`` instead of this value.
    local_attn_size: int
    sliding_window_num_frames: int
    # Batch size, dtype and device of the latents the caches were allocated for.
    batch_size: int
    dtype: torch.dtype
    device: torch.device
# --- CausalModelBase override: clear_caches ---
def clear_caches(self, *, cache_tag: str = "pos") -> None:
    """Forget the streaming caches stored under ``cache_tag``.

    Does nothing when no caches exist for that tag.
    """
    cache_key = (id(self), str(cache_tag))
    self._streaming_caches.pop(cache_key, None)
self._get_transformer(timestep_full) + caches = self._get_or_init_streaming_caches( + cache_tag=cache_tag, + transformer=transformer, + noisy_latents=noisy_latents, + ) + + frame_seq_length = int(caches.frame_seq_length) + kv_cache = caches.kv_cache + crossattn_cache = caches.crossattn_cache + + if (self._should_snapshot_streaming_cache() and torch.is_grad_enabled()): + kv_cache = self._snapshot_kv_cache_indices(kv_cache) + + model_kwargs: dict[str, Any] = { + "kv_cache": kv_cache, + "crossattn_cache": crossattn_cache, + "current_start": cur_start_frame * frame_seq_length, + "start_frame": cur_start_frame, + "is_cache": bool(store_kv), + } + + device_type = self.device.type + dtype = noisy_latents.dtype + with torch.autocast(device_type, dtype=dtype), set_forward_context( + current_timestep=batch.timesteps, + attn_metadata=attn_metadata, + ): + cond_inputs = self._select_cfg_condition_inputs( + batch, + conditional=conditional, + cfg_uncond=cfg_uncond, + ) + cond_inputs = self._slice_cond_inputs_for_streaming( + cond_inputs=cond_inputs, + cur_start_frame=cur_start_frame, + num_frames=num_frames, + ) + input_kwargs = self._build_distill_input_kwargs( + noisy_latents, + timestep_full, + image_embeds=cond_inputs["image_embeds"], + image_latents=cond_inputs["image_latents"], + mask_lat_size=cond_inputs["mask_lat_size"], + viewmats=cond_inputs["viewmats"], + Ks=cond_inputs["Ks"], + action=cond_inputs["action"], + mouse_cond=cond_inputs["mouse_cond"], + keyboard_cond=cond_inputs["keyboard_cond"], + ) + + input_kwargs["timestep"] = timestep_full.to(device=self.device, dtype=torch.long) + input_kwargs.update(model_kwargs) + + if store_kv: + with torch.no_grad(): + _ = transformer(**input_kwargs) + return None + + pred_noise = transformer(**input_kwargs).permute(0, 2, 1, 3, 4) + return pred_noise + + # --- CausalModelBase override: predict_x0_streaming --- + def predict_x0_streaming( + self, + noisy_latents: torch.Tensor, + timestep: torch.Tensor, + batch: Any, + *, + 
def _ensure_per_frame_timestep(
    self,
    *,
    timestep: torch.Tensor,
    batch_size: int,
    num_frames: int,
    device: torch.device,
) -> torch.Tensor:
    """Return ``timestep`` as a per-frame tensor on ``device``.

    A 0-D tensor or a 1-D tensor of length ``batch_size`` is expanded to
    ``(batch_size, num_frames)``. A 2-D tensor is only moved to ``device``;
    its shape is not checked.

    Raises:
        ValueError: for a 1-D tensor whose length is not ``batch_size`` or for
            any tensor with more than two dimensions.
    """
    rank = timestep.ndim
    if rank == 2:
        return timestep.to(device=device)

    if rank == 0:
        column = timestep.view(1, 1)
    elif rank == 1:
        if int(timestep.shape[0]) != batch_size:
            raise ValueError("streaming timestep must be scalar, [B], or "
                             f"[B, T]; got shape={tuple(timestep.shape)}")
        column = timestep.view(batch_size, 1)
    else:
        raise ValueError("streaming timestep must be scalar, [B], or [B, T]; "
                         f"got ndim={int(timestep.ndim)}")

    return column.expand(batch_size, num_frames).to(device=device)
def _slice_cond_inputs_for_streaming(
    self,
    *,
    cond_inputs: dict[str, Any],
    cur_start_frame: int,
    num_frames: int,
) -> dict[str, Any]:
    """Return a shallow copy of ``cond_inputs`` restricted to the current chunk.

    With ``end = cur_start_frame + num_frames``:

    * ``image_latents`` / ``mask_lat_size`` are sliced ``[start:end]`` on dim 2;
    * ``viewmats`` / ``Ks`` / ``action`` are sliced ``[start:end]`` on dim 1;
    * ``mouse_cond`` / ``keyboard_cond`` keep dim-1 entries
      ``[:1 + ratio * max(0, end - 1)]``, where ``ratio`` is the VAE
      ``temporal_compression_ratio`` from ``self.training_config``.

    Entries that are missing or not tensors are copied through unchanged, and
    the input dict is never mutated.

    Raises:
        ValueError: if ``num_frames <= 0`` or ``cur_start_frame < 0``.
    """
    start = int(cur_start_frame)
    num_frames = int(num_frames)
    if num_frames <= 0:
        raise ValueError("num_frames must be positive for streaming")
    if start < 0:
        raise ValueError("cur_start_frame must be >= 0 for streaming")
    end = start + num_frames
    window = slice(start, end)

    sliced: dict[str, Any] = dict(cond_inputs)

    for key in ("image_latents", "mask_lat_size"):
        tensor = cond_inputs.get(key)
        if isinstance(tensor, torch.Tensor):
            sliced[key] = tensor[:, :, window]

    for key in ("viewmats", "Ks", "action"):
        tensor = cond_inputs.get(key)
        if isinstance(tensor, torch.Tensor):
            sliced[key] = tensor[:, window]

    arch_config = self.training_config.pipeline_config.vae_config.arch_config
    ratio = int(arch_config.temporal_compression_ratio)
    raw_end_frame_idx = 1 + ratio * max(0, end - 1)

    # Raw controls are kept from index 0, not from the chunk start.
    for key in ("mouse_cond", "keyboard_cond"):
        tensor = cond_inputs.get(key)
        if isinstance(tensor, torch.Tensor):
            sliced[key] = tensor[:, :raw_end_frame_idx]

    return sliced
def _compute_frame_seq_length(
    self,
    transformer: torch.nn.Module,
    noisy_latents: torch.Tensor,
) -> int:
    """Return the per-frame token count for ``noisy_latents``.

    This is the product of the last two latent dims, floor-divided by the
    product of the last two ``patch_size`` entries. ``patch_size`` is read
    from ``transformer.patch_size`` and, failing that, from
    ``transformer.config.arch_config.patch_size``.

    Raises:
        ValueError: if no patch size can be found or its last two entries
            multiply to a non-positive number.
    """
    rows = int(noisy_latents.shape[-2])
    cols = int(noisy_latents.shape[-1])

    patch_size = getattr(transformer, "patch_size", None)
    if patch_size is None:
        arch_config = getattr(getattr(transformer, "config", None), "arch_config", None)
        patch_size = getattr(arch_config, "patch_size", None)
    if patch_size is None:
        raise ValueError("Unable to determine transformer.patch_size "
                         "for causal streaming")

    patch_area = int(patch_size[-1]) * int(patch_size[-2])
    if patch_area <= 0:
        raise ValueError("Invalid patch_size for causal streaming")
    return (cols * rows) // patch_area
def _get_local_attn_size(self, transformer: torch.nn.Module) -> int:
    """Return ``transformer.local_attn_size`` as an int, or -1 if unavailable.

    -1 is returned when the attribute is missing, is ``None``, or raises while
    being read.
    """
    try:
        raw = getattr(transformer, "local_attn_size", -1)
    except Exception:
        # Best-effort lookup: any failure while reading counts as "not set".
        return -1
    return -1 if raw is None else int(raw)
def _snapshot_kv_cache_indices(self, kv_cache: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return a copy of ``kv_cache`` whose index tensors are detached clones.

    Each per-block dict is shallow-copied, so every entry except
    ``global_end_index`` and ``local_end_index`` is shared with the input.
    Those two are replaced by ``detach().clone()`` copies.

    Raises:
        ValueError: if either index entry of any block is not a tensor.
    """
    index_keys = ("global_end_index", "local_end_index")
    frozen: list[dict[str, Any]] = []
    for entry in kv_cache:
        indices = {key: entry.get(key) for key in index_keys}
        if not all(isinstance(value, torch.Tensor) for value in indices.values()):
            raise ValueError("Unexpected kv_cache index tensors; expected "
                             "tensors at kv_cache[*].{global_end_index, "
                             "local_end_index}")
        clones = {key: value.detach().clone() for key, value in indices.items()}
        frozen.append({**entry, **clones})
    return frozen
def _iter_dataloader(self, dataloader: Any) -> Iterator[dict[str, Any]]:
    """Yield batches from ``dataloader`` forever, restarting it when exhausted.

    A private sentinel marks exhaustion, so a batch that happens to be
    ``None`` is yielded like any other instead of being mistaken for the end
    of an epoch.

    Raises:
        RuntimeError: if a fresh pass over ``dataloader`` yields nothing.
            Without this check the inner ``next()`` would raise
            ``StopIteration`` inside the generator, which Python reports as an
            unhelpful "generator raised StopIteration".
    """
    exhausted = object()
    data_iter = iter(dataloader)
    while True:
        batch = next(data_iter, exhausted)
        if batch is exhausted:
            # End of an epoch: start the next pass over the data.
            data_iter = iter(dataloader)
            batch = next(data_iter, exhausted)
            if batch is exhausted:
                raise RuntimeError("dataloader yielded no batches; "
                                   "cannot iterate an empty dataloader")
        yield batch
def _get_current_vsa_sparsity(self, step: int) -> float:
    """Return the VSA sparsity to use at ``step``.

    When ``vsa.decay_interval_steps`` is at most 1 the configured
    ``vsa.sparsity`` is returned unchanged. Otherwise the value is
    ``k * vsa.decay_rate``, where ``k`` grows by one every
    ``decay_interval_steps`` steps and is capped at
    ``floor(sparsity / decay_rate)``.

    Raises:
        ValueError: if a schedule is requested (interval > 1) with a
            non-positive ``vsa.decay_rate``.
    """
    import math

    vsa = self.training_config.vsa
    target_sparsity = vsa.sparsity
    increment = vsa.decay_rate
    interval = vsa.decay_interval_steps

    if interval <= 1:
        return target_sparsity
    if increment <= 0:
        raise ValueError("vsa.decay_rate must be > 0 when "
                         "vsa.decay_interval_steps > 1, got "
                         f"{increment!r}")

    # Float floor division undercounts exact multiples (0.9 // 0.1 == 8.0),
    # which would cap the schedule one increment short. A small tolerance on
    # the true quotient restores the intended count.
    max_increments = math.floor(target_sparsity / increment + 1e-9)
    applied = min(step // interval, max_increments)
    return applied * increment
torch.Tensor): + loss_sums[k] = loss_sums.get(k, 0.0) + float(v.detach().item()) + for k, v in step_metrics.items(): + if k in loss_sums: + raise ValueError(f"Metric key {k!r} collides " + "with loss key. Use a " + "different name (e.g. prefix " + "with 'train/').") + metric_sums[k] = metric_sums.get(k, 0.0) + _coerce_log_scalar( + v, + where=("method.single_train_step()" + f".metrics[{k!r}]"), + ) + + self.callbacks.on_before_optimizer_step( + method, iteration=step, + ) + method.optimizers_schedulers_step(step) + method.optimizers_zero_grad(step) + + metrics = {k: v / grad_accum for k, v in loss_sums.items()} + metrics.update({k: v / grad_accum for k, v in metric_sums.items()}) + metrics["step_time_sec"] = (time.perf_counter() - t0) + metrics["vsa_sparsity"] = float(current_vsa_sparsity) + if self.global_rank == 0 and metrics: + self.tracker.log(metrics, step) + + self.callbacks.on_training_step_end( + method, metrics, iteration=step, + ) + + if checkpoint_manager is not None: + checkpoint_manager.maybe_save(step) + + self.callbacks.on_validation_begin( + method, iteration=step, + ) + self.callbacks.on_validation_end( + method, iteration=step, + ) + + self.callbacks.on_train_end( + method, iteration=max_steps, + ) + + if checkpoint_manager is not None: + checkpoint_manager.save_final(max_steps) + + self.tracker.finish() diff --git a/fastvideo/train/utils/__init__.py b/fastvideo/train/utils/__init__.py new file mode 100644 index 000000000..d7eba4033 --- /dev/null +++ b/fastvideo/train/utils/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Distillation utilities shared across families/methods/entrypoints.""" diff --git a/fastvideo/train/utils/builder.py b/fastvideo/train/utils/builder.py new file mode 100644 index 000000000..d6d8d9976 --- /dev/null +++ b/fastvideo/train/utils/builder.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Assembly: build method + dataloader from a ``_target_``-based config.""" + +from __future__ import 
def build_from_config(cfg: RunConfig, ) -> tuple[TrainingConfig, TrainingMethod, Any, int]:
    """Instantiate the role models and training method described by ``cfg``.

    Steps:
        1. Every entry of ``cfg.models`` is built with ``instantiate`` and
           must be a ``ModelBase``.
        2. ``cfg.method["_target_"]`` is resolved to a class and called with
           ``cfg=cfg, role_models=...``.
        3. ``dataloader`` and ``start_step`` are read from the ``"student"``
           role model, falling back to ``None`` / ``0`` when there is no
           student or it lacks the attribute.

    Returns:
        ``(cfg.training, method, dataloader, start_step)``.

    Raises:
        TypeError: if a model target does not produce a ``ModelBase``.
    """
    from fastvideo.train.models.base import ModelBase

    # --- 1. Role models ---
    role_models: dict[str, ModelBase] = {}
    for role, model_cfg in cfg.models.items():
        built = instantiate(model_cfg, training_config=cfg.training)
        if not isinstance(built, ModelBase):
            raise TypeError(f"models.{role}._target_ must resolve to a "
                            f"ModelBase subclass, got {type(built).__name__}")
        role_models[role] = built

    # --- 2. Method ---
    method_options = dict(cfg.method)
    method_cls = resolve_target(str(method_options.pop("_target_")))

    # The student role, when present, owns the dataloader and resume step.
    student = role_models.get("student")

    method = method_cls(cfg=cfg, role_models=role_models)

    # --- 3. Dataloader and start step ---
    if student is None:
        dataloader = None
        start_step = 0
    else:
        dataloader = getattr(student, "dataloader", None)
        start_step = int(getattr(student, "start_step", 0))

    return cfg.training, method, dataloader, start_step
_CHECKPOINT_DIR_RE = re.compile(r"^checkpoint-(\d+)$")


def _is_stateful(obj: Any) -> bool:
    """Return True if ``obj`` has callable ``state_dict`` and ``load_state_dict``."""
    return callable(getattr(obj, "state_dict", None)) and callable(getattr(obj, "load_state_dict", None))


def _rank() -> int:
    """Return the torch.distributed rank, or 0 when it is not initialized."""
    if dist.is_available() and dist.is_initialized():
        return int(dist.get_rank())
    return 0


def _barrier() -> None:
    """Synchronize all ranks; a no-op when torch.distributed is not initialized."""
    if dist.is_available() and dist.is_initialized():
        dist.barrier()


def _parse_step_from_dir(checkpoint_dir: Path) -> int:
    """Extract the step number from a ``checkpoint-<step>`` directory name.

    Raises:
        ValueError: if the final path component does not match that pattern.
    """
    match = _CHECKPOINT_DIR_RE.match(checkpoint_dir.name)
    if not match:
        raise ValueError(f"Invalid checkpoint directory name {checkpoint_dir.name!r}; "
                         "expected 'checkpoint-<step>'")
    return int(match.group(1))


def _find_latest_checkpoint(output_dir: Path) -> Path | None:
    """Return the highest-step ``checkpoint-<step>`` child that contains ``dcp/``.

    Returns ``None`` when ``output_dir`` is not an existing directory or holds
    no such checkpoint.
    """
    # A regular file must also take the "nothing found" path; iterdir() on it
    # would raise NotADirectoryError.
    if not output_dir.is_dir():
        return None

    candidates: list[tuple[int, Path]] = []
    for child in output_dir.iterdir():
        match = _CHECKPOINT_DIR_RE.match(child.name)
        if match is None:
            continue
        # A "dcp" directory inside implies that ``child`` is a directory too.
        if not (child / "dcp").is_dir():
            continue
        candidates.append((int(match.group(1)), child))

    if not candidates:
        return None
    candidates.sort(key=lambda item: item[0])
    return candidates[-1][1]


def _resolve_resume_checkpoint(resume_from_checkpoint: str, *, output_dir: str) -> Path:
    """Resolve a user-provided resume path to a concrete checkpoint dir.

    Accepted values:
    - /path/to/output_dir/checkpoint-<step>
    - /path/to/output_dir/checkpoint-<step>/dcp
    - /path/to/output_dir  (auto-pick latest checkpoint-*/dcp)

    ``output_dir`` is only used to enrich the error message.

    Raises:
        FileNotFoundError: if the path does not exist, or it names a
            ``checkpoint-<step>`` directory without a ``dcp/`` subdirectory.
        ValueError: if the path exists but no checkpoint can be derived from it.
    """
    raw = os.path.expanduser(str(resume_from_checkpoint))
    path = Path(raw).resolve()
    if not path.exists():
        raise FileNotFoundError(f"resume_from_checkpoint not found: {path}")

    # Allow pointing directly at the inner "dcp" directory.
    if path.is_dir() and path.name == "dcp":
        path = path.parent

    if path.is_dir() and _CHECKPOINT_DIR_RE.match(path.name):
        if not (path / "dcp").is_dir():
            raise FileNotFoundError(f"Missing dcp dir under checkpoint: {path / 'dcp'}")
        return path

    # Treat as output_dir -> pick latest.
    latest = _find_latest_checkpoint(path)
    if latest is not None:
        return latest

    # Give a clearer error message.
    out = Path(os.path.expanduser(str(output_dir))).resolve()
    raise ValueError("Could not resolve resume checkpoint. Expected a checkpoint directory "
                     "named 'checkpoint-<step>' (with 'dcp/' inside), or an output_dir "
                     f"containing such checkpoints. Got: {path} (output_dir={out}).")
class _CallbackStateWrapper:
    """Adapter giving a callback collection the ``state_dict`` /
    ``load_state_dict`` pair used when saving and loading checkpoints.

    Both methods delegate directly to the wrapped object.
    """

    def __init__(self, callbacks: Any) -> None:
        self._callbacks = callbacks

    def state_dict(self) -> dict[str, Any]:
        """Return the wrapped callbacks' state."""
        wrapped = self._callbacks
        return wrapped.state_dict()

    def load_state_dict(
        self,
        state_dict: dict[str, Any],
    ) -> None:
        """Restore the wrapped callbacks from ``state_dict``."""
        wrapped = self._callbacks
        wrapped.load_state_dict(state_dict)
def save_final(self, step: int) -> None:
    """Write the end-of-training checkpoint for ``step``.

    Nothing is written when checkpointing is disabled
    (``config.save_steps`` unset or <= 0), or when ``step`` is the step most
    recently saved by this manager. The second check mirrors ``maybe_save``
    and avoids writing the same checkpoint twice when the last training step
    falls on a ``save_steps`` boundary.
    """
    save_steps = int(self.config.save_steps or 0)
    if save_steps <= 0:
        return
    if self._last_saved_step == step:
        return
    self.save(step)
resume_from_checkpoint: + return None + + resolved = _resolve_resume_checkpoint( + resume_from_checkpoint, + output_dir=self.output_dir, + ) + step = _parse_step_from_dir(resolved) + + states = self._build_states() + logger.info("Loading Phase 2 checkpoint from %s", resolved) + dcp.load(states, checkpoint_id=str(resolved / "dcp")) + _barrier() + logger.info("Checkpoint loaded; resuming from step=%s", step) + return step + + def _cleanup_old_checkpoints(self) -> None: + keep_last = int(self.config.keep_last or 0) + if keep_last <= 0: + return + + if _rank() != 0: + _barrier() + return + + output_dir = Path(self.output_dir) + candidates: list[tuple[int, Path]] = [] + for child in output_dir.iterdir(): + if not child.is_dir(): + continue + if not _CHECKPOINT_DIR_RE.match(child.name): + continue + try: + step = _parse_step_from_dir(child) + except Exception: + continue + candidates.append((step, child)) + + candidates.sort(key=lambda x: x[0]) + to_delete = candidates[:-keep_last] if len(candidates) > keep_last else [] + for step, path in to_delete: + logger.info("Removing old checkpoint (keep_last=%s): %s", keep_last, path) + shutil.rmtree(path, ignore_errors=True) + + _barrier() diff --git a/fastvideo/train/utils/config.py b/fastvideo/train/utils/config.py new file mode 100644 index 000000000..704362b98 --- /dev/null +++ b/fastvideo/train/utils/config.py @@ -0,0 +1,485 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Training run config (``_target_`` based YAML).""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import yaml + +from fastvideo.train.utils.training_config import ( + CheckpointConfig, + DataConfig, + TrainingConfig, + DistributedConfig, + ModelTrainingConfig, + OptimizerConfig, + TrackerConfig, + TrainingLoopConfig, + VSAConfig, +) + + +@dataclass(slots=True) +class RunConfig: + """Parsed run config loaded from YAML.""" + + models: dict[str, dict[str, Any]] + method: 
dict[str, Any] + training: TrainingConfig + callbacks: dict[str, dict[str, Any]] + raw: dict[str, Any] + + def resolved_config(self) -> dict[str, Any]: + """Return a fully-resolved config dict with defaults. + + Suitable for logging to W&B so that every parameter + (including defaults) is visible. + """ + import dataclasses + + def _safe_asdict(obj: Any) -> Any: + if dataclasses.is_dataclass(obj) and not isinstance(obj, type): + return { + f.name: _safe_asdict(getattr(obj, f.name)) + for f in dataclasses.fields(obj) + if not callable(getattr(obj, f.name)) + } + if isinstance(obj, dict): + return {k: _safe_asdict(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return type(obj)(_safe_asdict(v) for v in obj) + return obj + + resolved: dict[str, Any] = {} + resolved["models"] = dict(self.models) + resolved["method"] = dict(self.method) + resolved["training"] = _safe_asdict(self.training) + resolved["callbacks"] = dict(self.callbacks) + return resolved + + +# ---- parsing helpers (kept for use by methods) ---- + + +def _resolve_existing_file(path: str) -> str: + if not path: + return path + expanded = os.path.expanduser(path) + resolved = Path(expanded).resolve() + if not resolved.exists(): + raise FileNotFoundError(f"Config file not found: {resolved}") + if not resolved.is_file(): + raise ValueError(f"Expected a file path, got: {resolved}") + return str(resolved) + + +def _require_mapping(raw: Any, *, where: str) -> dict[str, Any]: + if not isinstance(raw, dict): + raise ValueError(f"Expected mapping at {where}, " + f"got {type(raw).__name__}") + return raw + + +def _require_str(raw: Any, *, where: str) -> str: + if not isinstance(raw, str) or not raw.strip(): + raise ValueError(f"Expected non-empty string at {where}") + return raw + + +def get_optional_int(mapping: dict[str, Any], key: str, *, where: str) -> int | None: + raw = mapping.get(key) + if raw is None: + return None + if isinstance(raw, bool): + raise ValueError(f"Expected int at {where}, 
got bool") + if isinstance(raw, int): + return int(raw) + if isinstance(raw, float) and raw.is_integer(): + return int(raw) + if isinstance(raw, str) and raw.strip(): + return int(raw) + raise ValueError(f"Expected int at {where}, " + f"got {type(raw).__name__}") + + +def get_optional_float(mapping: dict[str, Any], key: str, *, where: str) -> float | None: + raw = mapping.get(key) + if raw is None: + return None + if isinstance(raw, bool): + raise ValueError(f"Expected float at {where}, got bool") + if isinstance(raw, int | float): + return float(raw) + if isinstance(raw, str) and raw.strip(): + return float(raw) + raise ValueError(f"Expected float at {where}, " + f"got {type(raw).__name__}") + + +def parse_betas(raw: Any, *, where: str) -> tuple[float, float]: + if raw is None: + raise ValueError(f"Missing betas for {where}") + if isinstance(raw, tuple | list) and len(raw) == 2: + return float(raw[0]), float(raw[1]) + if isinstance(raw, str): + parts = [p.strip() for p in raw.split(",") if p.strip()] + if len(parts) != 2: + raise ValueError(f"Expected betas as 'b1,b2' at {where}, " + f"got {raw!r}") + return float(parts[0]), float(parts[1]) + raise ValueError(f"Expected betas as 'b1,b2' at {where}, " + f"got {type(raw).__name__}") + + +# ---- config convenience helpers ---- + + +def require_positive_int( + mapping: dict[str, Any], + key: str, + *, + default: int | None = None, + where: str | None = None, +) -> int: + """Read an int that must be > 0.""" + loc = where or key + raw = mapping.get(key) + if raw is None: + if default is not None: + return default + raise ValueError(f"Missing required key {loc!r}") + val = get_optional_int(mapping, key, where=loc) + if val is None or val <= 0: + raise ValueError(f"{loc} must be a positive integer, got {raw!r}") + return val + + +def require_non_negative_int( + mapping: dict[str, Any], + key: str, + *, + default: int | None = None, + where: str | None = None, +) -> int: + """Read an int that must be >= 0.""" + loc = where 
or key + raw = mapping.get(key) + if raw is None: + if default is not None: + return default + raise ValueError(f"Missing required key {loc!r}") + val = get_optional_int(mapping, key, where=loc) + if val is None or val < 0: + raise ValueError(f"{loc} must be a non-negative integer, " + f"got {raw!r}") + return val + + +def require_non_negative_float( + mapping: dict[str, Any], + key: str, + *, + default: float | None = None, + where: str | None = None, +) -> float: + """Read a float that must be >= 0.""" + loc = where or key + raw = mapping.get(key) + if raw is None: + if default is not None: + return default + raise ValueError(f"Missing required key {loc!r}") + val = get_optional_float(mapping, key, where=loc) + if val is None or val < 0.0: + raise ValueError(f"{loc} must be a non-negative float, " + f"got {raw!r}") + return val + + +def require_choice( + mapping: dict[str, Any], + key: str, + choices: set[str] | frozenset[str], + *, + default: str | None = None, + where: str | None = None, +) -> str: + """Read a string that must be one of *choices*.""" + loc = where or key + raw = mapping.get(key) + if raw is None: + if default is not None: + if default not in choices: + raise ValueError(f"Default {default!r} not in {choices}") + return default + raise ValueError(f"Missing required key {loc!r}") + if not isinstance(raw, str) or not raw.strip(): + raise ValueError(f"{loc} must be a non-empty string, " + f"got {type(raw).__name__}") + val = raw.strip().lower() + if val not in choices: + raise ValueError(f"{loc} must be one of {sorted(choices)}, " + f"got {raw!r}") + return val + + +def require_bool( + mapping: dict[str, Any], + key: str, + *, + default: bool | None = None, + where: str | None = None, +) -> bool: + """Read a bool value.""" + loc = where or key + raw = mapping.get(key) + if raw is None: + if default is not None: + return default + raise ValueError(f"Missing required key {loc!r}") + if not isinstance(raw, bool): + raise ValueError(f"{loc} must be a 
bool, " + f"got {type(raw).__name__}") + return raw + + +def _parse_pipeline_config( + cfg: dict[str, Any], + *, + models: dict[str, dict[str, Any]], +) -> Any: + """Resolve PipelineConfig from the ``pipeline:`` YAML key.""" + from fastvideo.configs.pipelines.base import PipelineConfig + + pipeline_raw = cfg.get("pipeline") + if pipeline_raw is None: + return None + + # Derive model_path from models.student.init_from — + # needed by PipelineConfig.from_kwargs. + model_path: str | None = None + student_cfg = models.get("student") + if student_cfg is not None: + init_from = student_cfg.get("init_from") + if init_from is not None: + model_path = str(init_from) + + kwargs: dict[str, Any] = {"pipeline_config": pipeline_raw} + if model_path is not None: + kwargs["model_path"] = model_path + + if isinstance(pipeline_raw, str): + kwargs["pipeline_config"] = _resolve_existing_file( + pipeline_raw) + + return PipelineConfig.from_kwargs(kwargs) + + +def _build_training_config( + t: dict[str, Any], + *, + models: dict[str, dict[str, Any]], + pipeline_config: Any, +) -> TrainingConfig: + """Build TrainingConfig from nested training: YAML.""" + d = dict(t.get("distributed", {}) or {}) + da = dict(t.get("data", {}) or {}) + o = dict(t.get("optimizer", {}) or {}) + lo = dict(t.get("loop", {}) or {}) + ck = dict(t.get("checkpoint", {}) or {}) + tr = dict(t.get("tracker", {}) or {}) + vs = dict(t.get("vsa", {}) or {}) + m = dict(t.get("model", {}) or {}) + + num_gpus = int(d.get("num_gpus", 1) or 1) + + betas_raw = o.get("betas", "0.9,0.999") + betas = parse_betas(betas_raw, + where="training.optimizer.betas") + + model_path = str(t.get("model_path", "") or "") + if not model_path: + student_cfg = models.get("student") + if student_cfg is not None: + init_from = student_cfg.get("init_from") + if init_from is not None: + model_path = str(init_from) + + return TrainingConfig( + distributed=DistributedConfig( + num_gpus=num_gpus, + tp_size=int(d.get("tp_size", 1) or 1), + sp_size=int( 
+ d.get("sp_size", num_gpus) or num_gpus), + hsdp_replicate_dim=int( + d.get("hsdp_replicate_dim", 1) or 1), + hsdp_shard_dim=int( + d.get("hsdp_shard_dim", num_gpus) + or num_gpus), + pin_cpu_memory=bool( + d.get("pin_cpu_memory", False)), + ), + data=DataConfig( + data_path=str(da.get("data_path", "") or ""), + train_batch_size=int( + da.get("train_batch_size", 1) or 1), + dataloader_num_workers=int( + da.get("dataloader_num_workers", 0) or 0), + training_cfg_rate=float( + da.get("training_cfg_rate", 0.0) or 0.0), + seed=int(da.get("seed", 0) or 0), + num_height=int( + da.get("num_height", 0) or 0), + num_width=int(da.get("num_width", 0) or 0), + num_latent_t=int( + da.get("num_latent_t", 0) or 0), + num_frames=int( + da.get("num_frames", 0) or 0), + ), + optimizer=OptimizerConfig( + learning_rate=float( + o.get("learning_rate", 0.0) or 0.0), + betas=betas, + weight_decay=float( + o.get("weight_decay", 0.0) or 0.0), + lr_scheduler=str( + o.get("lr_scheduler", "constant") + or "constant"), + lr_warmup_steps=int( + o.get("lr_warmup_steps", 0) or 0), + lr_num_cycles=int( + o.get("lr_num_cycles", 0) or 0), + lr_power=float( + o.get("lr_power", 0.0) or 0.0), + min_lr_ratio=float( + o.get("min_lr_ratio", 0.5) or 0.5), + ), + loop=TrainingLoopConfig( + max_train_steps=int( + lo.get("max_train_steps", 0) or 0), + gradient_accumulation_steps=int( + lo.get("gradient_accumulation_steps", 1) + or 1), + ), + checkpoint=CheckpointConfig( + output_dir=str( + ck.get("output_dir", "") or ""), + resume_from_checkpoint=str( + ck.get("resume_from_checkpoint", "") + or ""), + training_state_checkpointing_steps=int( + ck.get( + "training_state_checkpointing_steps", + 0) or 0), + checkpoints_total_limit=int( + ck.get("checkpoints_total_limit", 0) + or 0), + ), + tracker=TrackerConfig( + trackers=list( + tr.get("trackers", []) or []), + project_name=str( + tr.get("project_name", "fastvideo") + or "fastvideo"), + run_name=str(tr.get("run_name", "") or ""), + ), + vsa=VSAConfig( + 
sparsity=float( + vs.get("sparsity", 0.0) or 0.0), + decay_rate=float( + vs.get("decay_rate", 0.0) or 0.0), + decay_interval_steps=int( + vs.get("decay_interval_steps", 0) or 0), + ), + model=ModelTrainingConfig( + weighting_scheme=str( + m.get("weighting_scheme", "uniform") + or "uniform"), + logit_mean=float( + m.get("logit_mean", 0.0) or 0.0), + logit_std=float( + m.get("logit_std", 1.0) or 1.0), + mode_scale=float( + m.get("mode_scale", 1.0) or 1.0), + precondition_outputs=bool( + m.get("precondition_outputs", False)), + moba_config=dict( + m.get("moba_config", {}) or {}), + enable_gradient_checkpointing_type=( + m.get( + "enable_gradient_checkpointing_type" + )), + ), + pipeline_config=pipeline_config, + model_path=model_path, + dit_precision=str( + t.get("dit_precision", "fp32") or "fp32"), + ) + + +def load_run_config(path: str) -> RunConfig: + """Load a run config from YAML. + + Expected top-level keys: ``models``, ``method``, + ``training`` (nested), and optionally ``callbacks`` + and ``pipeline``. 
+ """ + path = _resolve_existing_file(path) + with open(path, encoding="utf-8") as f: + raw = yaml.safe_load(f) + cfg = _require_mapping(raw, where=path) + + # --- models --- + models_raw = _require_mapping( + cfg.get("models"), where="models") + models: dict[str, dict[str, Any]] = {} + for role, model_cfg_raw in models_raw.items(): + role_str = _require_str( + role, where="models.") + model_cfg = _require_mapping( + model_cfg_raw, where=f"models.{role_str}") + if "_target_" not in model_cfg: + raise ValueError( + f"models.{role_str} must have a " + "'_target_' key") + models[role_str] = dict(model_cfg) + + # --- method --- + method_raw = _require_mapping( + cfg.get("method"), where="method") + if "_target_" not in method_raw: + raise ValueError( + "method must have a '_target_' key") + method = dict(method_raw) + + # --- callbacks --- + callbacks_raw = cfg.get("callbacks", None) + if callbacks_raw is None: + callbacks: dict[str, dict[str, Any]] = {} + else: + callbacks = _require_mapping( + callbacks_raw, where="callbacks") + + # --- pipeline config --- + pipeline_config = _parse_pipeline_config( + cfg, models=models) + + # --- training config --- + training_raw = _require_mapping( + cfg.get("training"), where="training") + t = dict(training_raw) + training = _build_training_config( + t, models=models, + pipeline_config=pipeline_config) + + return RunConfig( + models=models, + method=method, + training=training, + callbacks=callbacks, + raw=cfg, + ) diff --git a/fastvideo/train/utils/dataloader.py b/fastvideo/train/utils/dataloader.py new file mode 100644 index 000000000..9771db6d2 --- /dev/null +++ b/fastvideo/train/utils/dataloader.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +if TYPE_CHECKING: + from fastvideo.train.utils.training_config import ( + DataConfig, ) + + +def build_parquet_t2v_train_dataloader( + data_config: DataConfig, + *, + text_len: int, + 
parquet_schema: Any, +) -> Any: + """Build a parquet dataloader for T2V-style datasets.""" + + from fastvideo.dataset import ( + build_parquet_map_style_dataloader, ) + + _dataset, dataloader = (build_parquet_map_style_dataloader( + data_config.data_path, + data_config.train_batch_size, + num_data_workers=(data_config.dataloader_num_workers), + parquet_schema=parquet_schema, + cfg_rate=data_config.training_cfg_rate, + drop_last=True, + text_padding_length=int(text_len), + seed=int(data_config.seed or 0), + )) + return dataloader + + +def build_parquet_wangame_train_dataloader( + data_config: DataConfig, + *, + parquet_schema: Any, +) -> Any: + """Build a parquet dataloader for WanGame datasets.""" + + from fastvideo.dataset import ( + build_parquet_map_style_dataloader, ) + + _dataset, dataloader = (build_parquet_map_style_dataloader( + data_config.data_path, + data_config.train_batch_size, + num_data_workers=(data_config.dataloader_num_workers), + parquet_schema=parquet_schema, + cfg_rate=float(data_config.training_cfg_rate or 0.0), + drop_last=True, + text_padding_length=512, + seed=int(data_config.seed or 0), + )) + return dataloader diff --git a/fastvideo/train/utils/instantiate.py b/fastvideo/train/utils/instantiate.py new file mode 100644 index 000000000..ed43122ee --- /dev/null +++ b/fastvideo/train/utils/instantiate.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 +"""``_target_``-based instantiation utilities. + +These helpers resolve a dotted Python path to a class and instantiate it, +filtering constructor kwargs through ``inspect.signature`` so that only +recognized parameters are forwarded. Unrecognized keys emit a warning +rather than raising — this keeps YAML configs forward-compatible when +a class drops a parameter in a later version. 
+""" + +from __future__ import annotations + +import importlib +import inspect +import warnings +from typing import Any + + +def resolve_target(target: str) -> type: + """Import and return the class (or callable) at *target*. + + *target* must be a fully-qualified dotted path, e.g. + ``"fastvideo.train.models.wangame.wangame.WanGameModel"``. + """ + if not isinstance(target, str) or not target.strip(): + raise ValueError(f"_target_ must be a non-empty dotted path string, " + f"got {target!r}") + target = target.strip() + parts = target.rsplit(".", 1) + if len(parts) != 2: + raise ValueError(f"_target_ must contain at least one dot " + f"(module.ClassName), got {target!r}") + module_path, attr_name = parts + try: + module = importlib.import_module(module_path) + except ModuleNotFoundError as exc: + raise ImportError(f"Cannot import module {module_path!r} " + f"(from _target_={target!r})") from exc + try: + cls = getattr(module, attr_name) + except AttributeError as exc: + raise ImportError(f"Module {module_path!r} has no attribute " + f"{attr_name!r} (from _target_={target!r})") from exc + return cls + + +def instantiate(cfg: dict[str, Any], **extra: Any) -> Any: + """Instantiate the class specified by ``cfg["_target_"]``. + + All remaining keys in *cfg* (minus ``_target_``) plus any *extra* + keyword arguments are forwarded to the constructor. Keys that do + not match an ``__init__`` parameter are silently warned about and + dropped, so callers can safely pass a superset. 
+ """ + if not isinstance(cfg, dict): + raise TypeError(f"instantiate() expects a dict with '_target_', " + f"got {type(cfg).__name__}") + target_str = cfg.get("_target_") + if target_str is None: + raise KeyError("Config dict is missing '_target_' key") + + cls = resolve_target(str(target_str)) + kwargs: dict[str, Any] = {k: v for k, v in cfg.items() if k != "_target_"} + kwargs.update(extra) + + sig = inspect.signature(cls.__init__) + params = sig.parameters + has_var_keyword = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()) + + if not has_var_keyword: + valid_names = { + name + for name, p in params.items() if p.kind in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + } + valid_names.discard("self") + unrecognized = set(kwargs) - valid_names + if unrecognized: + warnings.warn( + f"instantiate({target_str}): dropping unrecognized " + f"kwargs {sorted(unrecognized)}", + stacklevel=2, + ) + for key in unrecognized: + del kwargs[key] + + return cls(**kwargs) diff --git a/fastvideo/train/utils/module_state.py b/fastvideo/train/utils/module_state.py new file mode 100644 index 000000000..6d28a005f --- /dev/null +++ b/fastvideo/train/utils/module_state.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import torch + + +def apply_trainable(module: torch.nn.Module, *, trainable: bool) -> torch.nn.Module: + """Apply train/eval mode + requires_grad based on a role's trainable flag.""" + + module.requires_grad_(bool(trainable)) + if trainable: + module.train() + else: + module.eval() + return module diff --git a/fastvideo/train/utils/moduleloader.py b/fastvideo/train/utils/moduleloader.py new file mode 100644 index 000000000..7d18db197 --- /dev/null +++ b/fastvideo/train/utils/moduleloader.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import os +from typing import Any, TYPE_CHECKING + +import torch + +from 
fastvideo.configs.pipelines.base import PipelineConfig +from fastvideo.fastvideo_args import ExecutionMode, TrainingArgs +from fastvideo.models.loader.component_loader import ( + PipelineComponentLoader, ) +from fastvideo.utils import ( + maybe_download_model, + verify_model_config_and_directory, +) + +if TYPE_CHECKING: + from fastvideo.train.utils.training_config import ( + TrainingConfig, ) + +# ------------------------------------------------------------------ +# TrainingArgs builders (only place that creates FastVideoArgs) +# ------------------------------------------------------------------ + + +def _make_training_args( + tc: TrainingConfig, + *, + model_path: str, +) -> TrainingArgs: + """Build a TrainingArgs for PipelineComponentLoader.""" + pipeline_config = tc.pipeline_config or PipelineConfig() + # Propagate dit_precision from TrainingConfig to PipelineConfig + # so that TransformerLoader.load() picks up the correct + # default_dtype (e.g. fp32 master weights for training). + if tc.dit_precision and tc.dit_precision != pipeline_config.dit_precision: + pipeline_config.dit_precision = tc.dit_precision + return TrainingArgs( + model_path=model_path, + mode=ExecutionMode.DISTILLATION, + inference_mode=False, + pipeline_config=pipeline_config, + num_gpus=tc.distributed.num_gpus, + tp_size=tc.distributed.tp_size, + sp_size=tc.distributed.sp_size, + hsdp_replicate_dim=tc.distributed.hsdp_replicate_dim, + hsdp_shard_dim=tc.distributed.hsdp_shard_dim, + pin_cpu_memory=tc.distributed.pin_cpu_memory, + dit_cpu_offload=False, + dit_layerwise_offload=False, + vae_cpu_offload=False, + text_encoder_cpu_offload=False, + image_encoder_cpu_offload=False, + use_fsdp_inference=False, + enable_torch_compile=False, + ) + + +def make_inference_args( + tc: TrainingConfig, + *, + model_path: str, +) -> TrainingArgs: + """Build a TrainingArgs for inference (validation / pipelines).""" + args = _make_training_args(tc, model_path=model_path) + args.inference_mode = True + args.mode 
= ExecutionMode.INFERENCE + args.dit_cpu_offload = True + args.VSA_sparsity = tc.vsa.sparsity + return args + + +# ------------------------------------------------------------------ +# Module loading +# ------------------------------------------------------------------ + + +def load_module_from_path( + *, + model_path: str, + module_type: str, + training_config: TrainingConfig, + disable_custom_init_weights: bool = False, + override_transformer_cls_name: str | None = None, +) -> torch.nn.Module: + """Load a single pipeline component module. + + Accepts a ``TrainingConfig`` and internally builds the + ``TrainingArgs`` needed by ``PipelineComponentLoader``. + """ + fastvideo_args: Any = _make_training_args( + training_config, model_path=model_path) + + local_model_path = maybe_download_model(model_path) + config = verify_model_config_and_directory(local_model_path) + + if module_type not in config: + raise ValueError(f"Module {module_type!r} not found in " + f"config at {local_model_path}") + + module_info = config[module_type] + if module_info is None: + raise ValueError(f"Module {module_type!r} has null value in " + f"config at {local_model_path}") + + transformers_or_diffusers, _architecture = module_info + component_path = os.path.join(local_model_path, module_type) + + old_override: str | None = None + if override_transformer_cls_name is not None: + old_override = getattr( + fastvideo_args, + "override_transformer_cls_name", + None, + ) + fastvideo_args.override_transformer_cls_name = str(override_transformer_cls_name) + + if disable_custom_init_weights: + fastvideo_args._loading_teacher_critic_model = True + try: + module = PipelineComponentLoader.load_module( + module_name=module_type, + component_model_path=component_path, + transformers_or_diffusers=(transformers_or_diffusers), + fastvideo_args=fastvideo_args, + ) + finally: + if disable_custom_init_weights and hasattr(fastvideo_args, "_loading_teacher_critic_model"): + del 
fastvideo_args._loading_teacher_critic_model + if override_transformer_cls_name is not None: + if old_override is None: + if hasattr( + fastvideo_args, + "override_transformer_cls_name", + ): + fastvideo_args.override_transformer_cls_name = (None) + else: + fastvideo_args.override_transformer_cls_name = (old_override) + + if not isinstance(module, torch.nn.Module): + raise TypeError(f"Loaded {module_type!r} is not a " + f"torch.nn.Module: {type(module)}") + return module diff --git a/fastvideo/train/utils/optimizer.py b/fastvideo/train/utils/optimizer.py new file mode 100644 index 000000000..43a79d98d --- /dev/null +++ b/fastvideo/train/utils/optimizer.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from fastvideo.training.training_utils import ( + clip_grad_norm_while_handling_failing_dtensor_cases, + get_scheduler, +) + +if TYPE_CHECKING: + from fastvideo.train.utils.training_config import ( + OptimizerConfig, + TrainingLoopConfig, + ) + + +def build_optimizer_and_scheduler( + *, + params: list[torch.nn.Parameter], + optimizer_config: OptimizerConfig, + loop_config: TrainingLoopConfig, + learning_rate: float, + betas: tuple[float, float], + scheduler_name: str, +) -> tuple[torch.optim.Optimizer, object]: + """Build an AdamW optimizer and LR scheduler. + + Returns ``(optimizer, lr_scheduler)`` so the caller can store them + as method-level attributes. 
+ """ + if not params: + raise ValueError("No trainable parameters passed to " + "build_optimizer_and_scheduler") + + optimizer = torch.optim.AdamW( + params, + lr=float(learning_rate), + betas=betas, + weight_decay=float(optimizer_config.weight_decay), + eps=1e-8, + ) + + scheduler = get_scheduler( + str(scheduler_name), + optimizer=optimizer, + num_warmup_steps=int(optimizer_config.lr_warmup_steps), + num_training_steps=int(loop_config.max_train_steps), + num_cycles=int(optimizer_config.lr_num_cycles), + power=float(optimizer_config.lr_power), + min_lr_ratio=float(optimizer_config.min_lr_ratio), + last_epoch=-1, + ) + + return optimizer, scheduler + + +def clip_grad_norm_if_needed( + module: torch.nn.Module, + max_grad_norm: float, +) -> float: + if max_grad_norm <= 0.0: + return 0.0 + grad_norm = (clip_grad_norm_while_handling_failing_dtensor_cases( + [p for p in module.parameters()], + max_grad_norm, + foreach=None, + )) + return (float(grad_norm.item()) if grad_norm is not None else 0.0) diff --git a/fastvideo/train/utils/tracking.py b/fastvideo/train/utils/tracking.py new file mode 100644 index 000000000..7ad28a2e4 --- /dev/null +++ b/fastvideo/train/utils/tracking.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import os +from typing import Any, TYPE_CHECKING + +from fastvideo.distributed import get_world_group +from fastvideo.training.trackers import ( + initialize_trackers, + Trackers, +) + +if TYPE_CHECKING: + from fastvideo.train.utils.training_config import ( + CheckpointConfig, + TrackerConfig, + ) + + +def build_tracker( + tracker_config: TrackerConfig, + checkpoint_config: CheckpointConfig, + *, + config: dict[str, Any] | None, +) -> Any: + """Build a tracker instance for a distillation run.""" + + world_group = get_world_group() + + trackers = list(tracker_config.trackers) + if not trackers and str(tracker_config.project_name): + trackers.append(Trackers.WANDB.value) + if world_group.rank != 0: + 
trackers = [] + + tracker_log_dir = (checkpoint_config.output_dir or os.getcwd()) + if trackers: + tracker_log_dir = os.path.join(tracker_log_dir, "tracker") + + tracker_config_dict = config if trackers else None + tracker_run_name = tracker_config.run_name or None + project = (tracker_config.project_name or "fastvideo") + + return initialize_trackers( + trackers, + experiment_name=project, + config=tracker_config_dict, + log_dir=tracker_log_dir, + run_name=tracker_run_name, + ) diff --git a/fastvideo/train/utils/training_config.py b/fastvideo/train/utils/training_config.py new file mode 100644 index 000000000..0167751db --- /dev/null +++ b/fastvideo/train/utils/training_config.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Typed training config — replaces TrainingArgs.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from fastvideo.configs.pipelines.base import PipelineConfig + + +@dataclass(slots=True) +class DistributedConfig: + num_gpus: int = 1 + tp_size: int = 1 + sp_size: int = 1 + hsdp_replicate_dim: int = 1 + hsdp_shard_dim: int = -1 + pin_cpu_memory: bool = False + + +@dataclass(slots=True) +class DataConfig: + data_path: str = "" + train_batch_size: int = 1 + dataloader_num_workers: int = 0 + training_cfg_rate: float = 0.0 + seed: int = 0 + num_height: int = 0 + num_width: int = 0 + num_latent_t: int = 0 + num_frames: int = 0 + + +@dataclass(slots=True) +class OptimizerConfig: + learning_rate: float = 0.0 + betas: tuple[float, float] = (0.9, 0.999) + weight_decay: float = 0.0 + lr_scheduler: str = "constant" + lr_warmup_steps: int = 0 + lr_num_cycles: int = 0 + lr_power: float = 0.0 + min_lr_ratio: float = 0.5 + + +@dataclass(slots=True) +class TrainingLoopConfig: + max_train_steps: int = 0 + gradient_accumulation_steps: int = 1 + + +@dataclass(slots=True) +class CheckpointConfig: + output_dir: str = "" + resume_from_checkpoint: str = "" + 
training_state_checkpointing_steps: int = 0 + checkpoints_total_limit: int = 0 + + +@dataclass(slots=True) +class TrackerConfig: + trackers: list[str] = field(default_factory=list) + project_name: str = "fastvideo" + run_name: str = "" + + +@dataclass(slots=True) +class VSAConfig: + sparsity: float = 0.0 + decay_rate: float = 0.0 + decay_interval_steps: int = 0 + + +@dataclass(slots=True) +class ModelTrainingConfig: + weighting_scheme: str = "uniform" + logit_mean: float = 0.0 + logit_std: float = 1.0 + mode_scale: float = 1.0 + precondition_outputs: bool = False + moba_config: dict = field(default_factory=dict) + enable_gradient_checkpointing_type: str | None = None + + +@dataclass(slots=True) +class TrainingConfig: + distributed: DistributedConfig = field(default_factory=DistributedConfig) + data: DataConfig = field(default_factory=DataConfig) + optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) + loop: TrainingLoopConfig = field(default_factory=TrainingLoopConfig) + checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig) + tracker: TrackerConfig = field(default_factory=TrackerConfig) + vsa: VSAConfig = field(default_factory=VSAConfig) + model: ModelTrainingConfig = field(default_factory=ModelTrainingConfig) + pipeline_config: PipelineConfig | None = None + model_path: str = "" + dit_precision: str = "fp32" diff --git a/fastvideo/train/utils/validation.py b/fastvideo/train/utils/validation.py new file mode 100644 index 000000000..5d7722d97 --- /dev/null +++ b/fastvideo/train/utils/validation.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Any, Literal, cast + +from fastvideo.train.utils.config import get_optional_int + + +def is_validation_enabled(cfg: dict[str, Any]) -> bool: + if not cfg: + return False + enabled = cfg.get("enabled") + if enabled is None: + return True + if isinstance(enabled, bool): + return bool(enabled) + raise 
ValueError("training.validation.enabled must be a bool when set, got " + f"{type(enabled).__name__}") + + +def parse_validation_every_steps(cfg: dict[str, Any]) -> int: + raw = cfg.get("every_steps") + if raw is None: + raise ValueError("training.validation.every_steps must be set when validation is enabled") + if isinstance(raw, bool): + raise ValueError("training.validation.every_steps must be an int, got bool") + if isinstance(raw, int): + return int(raw) + if isinstance(raw, float) and raw.is_integer(): + return int(raw) + if isinstance(raw, str) and raw.strip(): + return int(raw) + raise ValueError("training.validation.every_steps must be an int, got " + f"{type(raw).__name__}") + + +def parse_validation_dataset_file(cfg: dict[str, Any]) -> str: + raw = cfg.get("dataset_file") + if not isinstance(raw, str) or not raw.strip(): + raise ValueError("training.validation.dataset_file must be set when validation is enabled") + return raw.strip() + + +def parse_validation_sampling_steps(cfg: dict[str, Any]) -> list[int]: + raw = cfg.get("sampling_steps") + steps: list[int] = [] + if raw is None or raw == "": + raise ValueError("training.validation.sampling_steps must be set for validation") + if isinstance(raw, bool): + raise ValueError("validation sampling_steps must be an int/list/str, got bool") + if isinstance(raw, int) or (isinstance(raw, float) and raw.is_integer()): + steps = [int(raw)] + elif isinstance(raw, str): + steps = [int(s) for s in raw.split(",") if str(s).strip()] + elif isinstance(raw, list): + steps = [int(s) for s in raw] + else: + raise ValueError("validation sampling_steps must be an int/list/str, got " + f"{type(raw).__name__}") + return [s for s in steps if int(s) > 0] + + +def parse_validation_guidance_scale(cfg: dict[str, Any]) -> float | None: + raw = cfg.get("guidance_scale") + if raw in (None, ""): + return None + if isinstance(raw, bool): + raise ValueError("validation guidance_scale must be a number/string, got bool") + if 
isinstance(raw, (int, float)): + return float(raw) + if isinstance(raw, str) and raw.strip(): + return float(raw) + raise ValueError("validation guidance_scale must be a number/string, got " + f"{type(raw).__name__}") + + +def parse_validation_sampler_kind( + cfg: dict[str, Any], + *, + default: Literal["ode", "sde"], +) -> Literal["ode", "sde"]: + raw = cfg.get("sampler_kind", default) + if raw is None: + raw = default + if not isinstance(raw, str): + raise ValueError("training.validation.sampler_kind must be a string when set, got " + f"{type(raw).__name__}") + kind = raw.strip().lower() + if kind not in {"ode", "sde"}: + raise ValueError("training.validation.sampler_kind must be one of {ode, sde}, got " + f"{raw!r}") + return cast(Literal["ode", "sde"], kind) + + +def parse_validation_rollout_mode( + cfg: dict[str, Any], + *, + default: Literal["parallel", "streaming"] = "parallel", +) -> Literal["parallel", "streaming"]: + raw = cfg.get("rollout_mode", default) + if raw is None: + raw = default + if not isinstance(raw, str): + raise ValueError("training.validation.rollout_mode must be a string when set, got " + f"{type(raw).__name__}") + mode = raw.strip().lower() + if mode not in {"parallel", "streaming"}: + raise ValueError("training.validation.rollout_mode must be one of {parallel, streaming}, " + f"got {raw!r}") + return cast(Literal["parallel", "streaming"], mode) + + +def parse_validation_ode_solver( + cfg: dict[str, Any], + *, + sampler_kind: Literal["ode", "sde"], +) -> str | None: + raw = cfg.get("ode_solver") + if raw in (None, ""): + return None + if sampler_kind != "ode": + raise ValueError("training.validation.ode_solver is only valid when " + "training.validation.sampler_kind='ode'") + if not isinstance(raw, str): + raise ValueError("training.validation.ode_solver must be a string when set, got " + f"{type(raw).__name__}") + solver = raw.strip().lower() + if solver in {"unipc", "unipc_multistep", "multistep"}: + return "unipc" + if solver in 
{"euler", "flowmatch", "flowmatch_euler"}: + return "euler" + raise ValueError("training.validation.ode_solver must be one of {unipc, euler}, got " + f"{raw!r}") + + +def parse_validation_output_dir(cfg: dict[str, Any]) -> str | None: + raw = cfg.get("output_dir") + if raw is None: + return None + if not isinstance(raw, str): + raise ValueError("training.validation.output_dir must be a string when set, got " + f"{type(raw).__name__}") + return raw + + +def parse_validation_num_frames(cfg: dict[str, Any]) -> int | None: + num_frames = get_optional_int(cfg, "num_frames", where="training.validation.num_frames") + if num_frames is not None and num_frames <= 0: + raise ValueError("training.validation.num_frames must be > 0 when set") + return num_frames diff --git a/fastvideo/training/checkpointing_utils.py b/fastvideo/training/checkpointing_utils.py index bc6aeed55..a3d4e84e2 100644 --- a/fastvideo/training/checkpointing_utils.py +++ b/fastvideo/training/checkpointing_utils.py @@ -21,10 +21,25 @@ def state_dict(self) -> dict[str, Any]: state_dict = get_model_state_dict( self.model) # type: ignore[no-any-return] # filter out non-trainable parameters - param_requires_grad = set([ - k for k, v in dict(self.model.named_parameters()).items() - if v.requires_grad - ]) + param_requires_grad: set[str] = set() + for name, param in self.model.named_parameters(): + if not bool(param.requires_grad): + continue + param_requires_grad.add(name) + + # Activation checkpointing wraps modules with an internal attribute + # `_checkpoint_wrapped_module`, which changes the *parameter name* + # observed via `named_parameters()`: + # + # named_parameters: blocks.0._checkpoint_wrapped_module.weight + # state_dict: blocks.0.weight + # + # `get_model_state_dict()` returns the unwrapped key names, so we + # also add the unwrapped form for filtering. + if "._checkpoint_wrapped_module." 
in name: + param_requires_grad.add( + name.replace("._checkpoint_wrapped_module.", ".") + ) state_dict = { k: v for k, v in state_dict.items() if k in param_requires_grad diff --git a/fastvideo/training/distillation_pipeline.py b/fastvideo/training/distillation_pipeline.py index 8abcfa955..e505186bd 100644 --- a/fastvideo/training/distillation_pipeline.py +++ b/fastvideo/training/distillation_pipeline.py @@ -112,7 +112,7 @@ def load_modules(self, if training_args.real_score_model_path: logger.info("Loading real score transformer from: %s", training_args.real_score_model_path) - training_args.override_transformer_cls_name = "WanTransformer3DModel" + # training_args.override_transformer_cls_name = "WanTransformer3DModel" # TODO(will): can use deepcopy instead if the model is the same self.real_score_transformer = self.load_module_from_path( training_args.real_score_model_path, "transformer", @@ -138,7 +138,7 @@ def load_modules(self, if training_args.fake_score_model_path: logger.info("Loading fake score transformer from: %s", training_args.fake_score_model_path) - training_args.override_transformer_cls_name = "WanTransformer3DModel" + # training_args.override_transformer_cls_name = "WanTransformer3DModel" self.fake_score_transformer = self.load_module_from_path( training_args.fake_score_model_path, "transformer", training_args) @@ -1208,7 +1208,8 @@ def _log_validation(self, transformer, training_args, global_step) -> None: training_args.validation_dataset_file, local_main_process_only=False) validation_dataset = ValidationDataset( - training_args.validation_dataset_file) + training_args.validation_dataset_file, + num_samples=training_args.validation_num_samples) validation_dataloader = DataLoader(validation_dataset, batch_size=None, num_workers=0) @@ -1277,10 +1278,14 @@ def run_validation_with_ema( prompt_embeds=[], prompt_attention_mask=[], ) - result_batch = self.validation_pipeline.prompt_encoding_stage( # type: ignore - batch_negative, training_args) - 
self.negative_prompt_embeds, self.negative_prompt_attention_mask = result_batch.prompt_embeds[ - 0], result_batch.prompt_attention_mask[0] + if hasattr(self.validation_pipeline, "prompt_encoding_stage"): + result_batch = self.validation_pipeline.prompt_encoding_stage( # type: ignore + batch_negative, training_args) + self.negative_prompt_embeds, self.negative_prompt_attention_mask = result_batch.prompt_embeds[ + 0], result_batch.prompt_attention_mask[0] + else: + self.negative_prompt_embeds = None + self.negative_prompt_attention_mask = None logger.info( "rank: %s: rank_in_sp_group: %s, batch.prompt: %s", @@ -1308,6 +1313,7 @@ def run_validation_with_ema( x = torchvision.utils.make_grid(x, nrow=6) x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) frames.append((x * 255).numpy().astype(np.uint8)) + frames = self._post_process_validation_frames(frames, batch) videos.append(frames) audios.append(output_batch.extra.get("audio")) audio_sample_rates.append( @@ -1441,6 +1447,8 @@ def _apply_vae_scale(latents: torch.Tensor) -> torch.Tensor: fake_score_log_keys = ['generator_pred_video'] dmd_log_keys = ['faker_score_pred_video', 'real_score_pred_video'] + os.makedirs(training_args.output_dir, exist_ok=True) + for latent_key in fake_score_log_keys: latents = fake_score_latents_vis_dict[latent_key] latents = _prepare_vae_latents(latents) @@ -1460,8 +1468,20 @@ def _apply_vae_scale(latents: torch.Tensor) -> torch.Tensor: video = video.cpu().float() video = video.permute(0, 2, 1, 3, 4) video = (video * 255).numpy().astype(np.uint8) + + video_filename = os.path.join(training_args.output_dir, + f"{latent_key}_step_{step}.mp4") + # [B, T, C, H, W] to [H, W, C] + video_frames = [ + np.transpose(video[0, t], (1, 2, 0)) + for t in range(video.shape[1]) + ] + video_frames = self._post_process_validation_frames( + video_frames, training_batch) + imageio.mimsave(video_filename, video_frames, fps=24) + video_artifact = self.tracker.video( - video, fps=24, format="mp4") # change to 16 for 
Wan2.1 + video, fps=24, format="mp4", caption=latent_key) # change to 16 for Wan2.1 if video_artifact is not None: tracker_loss_dict[latent_key] = video_artifact # Clean up references @@ -1489,8 +1509,20 @@ def _apply_vae_scale(latents: torch.Tensor) -> torch.Tensor: video = video.cpu().float() video = video.permute(0, 2, 1, 3, 4) video = (video * 255).numpy().astype(np.uint8) + + video_filename = os.path.join(training_args.output_dir, + f"{latent_key}_step_{step}.mp4") + # [B, T, C, H, W] to [H, W, C] + video_frames = [ + np.transpose(video[0, t], (1, 2, 0)) + for t in range(video.shape[1]) + ] + video_frames = self._post_process_validation_frames( + video_frames, training_batch) + imageio.mimsave(video_filename, video_frames, fps=24) + video_artifact = self.tracker.video( - video, fps=24, format="mp4") # change to 16 for Wan2.1 + video, fps=24, format="mp4", caption=latent_key) # change to 16 for Wan2.1 if video_artifact is not None: tracker_loss_dict[latent_key] = video_artifact # Clean up references diff --git a/fastvideo/training/trackers.py b/fastvideo/training/trackers.py index 281d79325..02578c58a 100644 --- a/fastvideo/training/trackers.py +++ b/fastvideo/training/trackers.py @@ -11,6 +11,7 @@ import copy import os import pathlib +import shutil import time from dataclasses import dataclass from enum import Enum @@ -92,6 +93,21 @@ def log_artifacts(self, artifacts: dict[str, Any], step: int) -> None: def finish(self) -> None: # pragma: no cover - interface """Finalize the tracker session.""" + def log_file( + self, + path: str, + *, + name: str | None = None, + ) -> None: + """Log a local file to the tracker run (best-effort). + + Useful for attaching the exact YAML config used for a run. + + Trackers that do not support files should treat this as a no-op. 
+ """ + + del path, name + def video( self, data: Any, @@ -134,12 +150,13 @@ def __init__( import wandb - pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True) + self._log_dir = os.path.abspath(str(log_dir)) + pathlib.Path(self._log_dir).mkdir(parents=True, exist_ok=True) self._wandb = wandb self._run = wandb.init( project=experiment_name, - dir=log_dir, + dir=self._log_dir, config=config, name=run_name, ) @@ -154,6 +171,45 @@ def log(self, metrics: dict[str, Any], step: int) -> None: def finish(self) -> None: self._run.finish() + def log_file(self, path: str, *, name: str | None = None) -> None: + resolved = os.path.abspath(os.path.expanduser(str(path))) + if not os.path.isfile(resolved): + logger.warning("W&B log_file skipped; file not found: %s", resolved) + return + + target_name = str(name).strip() if name is not None and str(name).strip() else None + if target_name is None: + target_name = os.path.basename(resolved) + + # Prefer placing files directly under the W&B run directory to avoid + # symlink-based saves (which may not sync reliably on some clusters). 
+ run_dir = getattr(self._run, "dir", None) + dest_root = self._log_dir if not isinstance(run_dir, str) else run_dir + dest_root = os.path.abspath(str(dest_root)) + + save_path = resolved + dest_path = os.path.join(dest_root, target_name) + try: + pathlib.Path(dest_root).mkdir(parents=True, exist_ok=True) + shutil.copyfile(resolved, dest_path) + except Exception: + logger.exception( + "Failed to copy file for W&B upload: %s -> %s", + resolved, + dest_path, + ) + else: + save_path = dest_path + + try: + self._run.save( + save_path, + base_path=os.path.dirname(save_path), + policy="now", + ) + except Exception: + logger.exception("Failed to upload file to W&B: %s", save_path) + def video( self, data: Any, @@ -201,6 +257,10 @@ def log_artifacts(self, artifacts: dict[str, Any], step: int) -> None: tracker.log_artifacts(artifacts, step) self._timed_metrics = {} + def log_file(self, path: str, *, name: str | None = None) -> None: + for tracker in self._trackers: + tracker.log_file(path, name=name) + def finish(self) -> None: for tracker in self._trackers: tracker.finish() diff --git a/fastvideo/training/training_pipeline.py b/fastvideo/training/training_pipeline.py index 417e900ad..ce5793393 100644 --- a/fastvideo/training/training_pipeline.py +++ b/fastvideo/training/training_pipeline.py @@ -15,6 +15,7 @@ import torch import torch.distributed as dist import torchvision +from diffusers import FlowMatchEulerDiscreteScheduler from einops import rearrange from torch.utils.data import DataLoader from torchdata.stateful_dataloader import StatefulDataLoader @@ -47,8 +48,9 @@ initialize_trackers, Trackers) from fastvideo.training.training_utils import ( clip_grad_norm_while_handling_failing_dtensor_cases, - compute_density_for_timestep_sampling, count_trainable, get_scheduler, - get_sigmas, load_checkpoint, normalize_dit_input, save_checkpoint) + compute_density_for_timestep_sampling, count_trainable, + count_trainable_total, get_scheduler, get_sigmas, load_checkpoint, + 
normalize_dit_input, save_checkpoint, shard_latents_across_sp) from fastvideo.utils import (is_vmoba_available, is_vsa_available, set_random_seed, shallow_asdict) @@ -116,7 +118,7 @@ def initialize_training_pipeline(self, training_args: TrainingArgs): # Set random seeds for deterministic training assert self.seed is not None, "seed must be set" - set_random_seed(self.seed + self.global_rank) + set_random_seed(self.seed) self.transformer.train() if training_args.enable_gradient_checkpointing_type is not None: self.transformer = apply_activation_checkpointing( @@ -192,7 +194,8 @@ def initialize_training_pipeline(self, training_args: TrainingArgs): text_padding_length=training_args.pipeline_config. text_encoder_configs[0].arch_config. text_len, # type: ignore[attr-defined] - seed=self.seed) + seed=self.seed, + reshuffle_each_epoch=training_args.reshuffle_each_epoch) self.noise_scheduler = noise_scheduler if self.training_args.boundary_ratio is not None: @@ -256,6 +259,8 @@ def _get_next_batch(self, training_batch: TrainingBatch) -> TrainingBatch: if batch is None: self.current_epoch += 1 logger.info("Starting epoch %s", self.current_epoch) + # Reshuffle dataset order each epoch + self.train_dataset.sampler.set_epoch(self.current_epoch) # Reset iterator for next epoch self.train_loader_iter = iter(self.train_dataloader) # Get first batch of new epoch @@ -568,26 +573,40 @@ def train(self) -> None: local_main_process_only=False) if not self.post_init_called: self.post_init() - num_trainable_params = count_trainable(self.transformer) - logger.info("Starting training with %s B trainable parameters", - round(num_trainable_params / 1e9, 3)) + local_trainable = count_trainable(self.transformer) + total_trainable = count_trainable_total( + self.transformer, + get_local_torch_device(), + ) + logger.info( + "Starting training with %s B trainable parameters (total); " + "this rank shard: %s B", + round(total_trainable / 1e9, 3), + round(local_trainable / 1e9, 3), + ) if 
getattr(self, "transformer_2", None) is not None: - num_trainable_params = count_trainable(self.transformer_2) + local_trainable_2 = count_trainable(self.transformer_2) + total_trainable_2 = count_trainable_total( + self.transformer_2, + get_local_torch_device(), + ) logger.info( - "Transformer 2: Starting training with %s B trainable parameters", - round(num_trainable_params / 1e9, 3)) + "Transformer 2: %s B trainable parameters (total); " + "this rank shard: %s B", + round(total_trainable_2 / 1e9, 3), + round(local_trainable_2 / 1e9, 3), + ) # Set random seeds for deterministic training - self.noise_random_generator = torch.Generator( - device="cpu").manual_seed(self.seed + self.global_rank) + self.noise_random_generator = torch.Generator(device="cpu").manual_seed( + self.seed) self.noise_gen_cuda = torch.Generator( - device=current_platform.device_name).manual_seed(self.seed + - self.global_rank) + device=current_platform.device_name).manual_seed(self.seed) self.validation_random_generator = torch.Generator( - device="cpu").manual_seed(self.seed + self.global_rank) - logger.info("Initialized random seeds with seed: %s", - self.seed + self.global_rank) + device="cpu").manual_seed(self.seed) + logger.info("Initialized random seeds with seed: %s", self.seed) + self.noise_scheduler = FlowMatchEulerDiscreteScheduler() if self.training_args.resume_from_checkpoint: @@ -599,6 +618,9 @@ def train(self) -> None: self._log_training_info() + self._best_mf_angle_err_mean = float('inf') + self._last_mf_angle_err_mean = float('inf') + self._log_validation(self.transformer, self.training_args, self.init_steps) @@ -708,6 +730,43 @@ def train(self) -> None: "GPU memory usage after validation: %s MB, trainable params: %sB", gpu_memory_usage, trainable_params) + best_start = self.training_args.best_checkpoint_start_step + if (best_start > 0 + and step >= best_start + and self._last_mf_angle_err_mean + < self._best_mf_angle_err_mean): + self._best_mf_angle_err_mean = ( + 
self._last_mf_angle_err_mean) + logger.info( + "New best mf_angle_err_mean=%.6f at step %d, " + "saving best checkpoint", + self._best_mf_angle_err_mean, step) + save_checkpoint( + self.transformer, self.global_rank, + self.training_args.output_dir, "best", + self.optimizer, self.train_dataloader, + self.lr_scheduler, + self.noise_random_generator) + if self.global_rank == 0: + import json + meta_path = os.path.join( + self.training_args.output_dir, + "checkpoint-best", + "best_metric.json") + with open(meta_path, "w") as f: + json.dump({ + "step": step, + "mf_angle_err_mean": + self._best_mf_angle_err_mean, + }, f, indent=2) + self.tracker.log({ + "best/mf_angle_err_mean": + self._best_mf_angle_err_mean, + "best/step": step, + }, step) + self.transformer.train() + self.sp_group.barrier() + self.tracker.finish() save_checkpoint(self.transformer, self.global_rank, self.training_args.output_dir, @@ -790,11 +849,42 @@ def _prepare_validation_batch(self, sampling_param: SamplingParam, return batch + def _post_process_validation_frames( + self, frames: list[np.ndarray], + batch: ForwardBatch) -> list[np.ndarray]: + """Post-process validation frames before saving. + + Override this method in subclasses to add custom processing, + e.g., overlay action indicators for action-conditioned models. 
+ + Args: + frames: List of numpy arrays (H, W, C) representing video frames + batch: The ForwardBatch containing input data (may include action data) + + Returns: + Processed frames (same format as input) + """ + return frames + + def _evaluate_validation_video( + self, + video_path: str, + caption: str, + action_path: str | None, + global_step: int, + num_inference_steps: int, + ) -> dict[str, float] | None: + """Optionally evaluate a saved validation video and return scalars.""" + del video_path, caption, action_path, global_step + del num_inference_steps + return None + @torch.no_grad() def _log_validation(self, transformer, training_args, global_step) -> None: """ Generate a validation video and log it to the configured tracker to check the quality during training. """ + self._last_mf_angle_err_mean = float('inf') training_args.inference_mode = True training_args.dit_cpu_offload = False if not training_args.log_validation: @@ -813,7 +903,8 @@ def _log_validation(self, transformer, training_args, global_step) -> None: training_args.validation_dataset_file, local_main_process_only=False) validation_dataset = ValidationDataset( - training_args.validation_dataset_file) + training_args.validation_dataset_file, + num_samples=training_args.validation_num_samples) validation_dataloader = DataLoader(validation_dataset, batch_size=None, num_workers=0) @@ -837,15 +928,16 @@ def _log_validation(self, transformer, training_args, global_step) -> None: local_main_process_only=False) step_videos: list[np.ndarray] = [] step_captions: list[str] = [] - - step_audio: list[np.ndarray | None] = [] - step_sample_rates: list[int | None] = [] + step_action_paths: list[str | None] = [] for validation_batch in validation_dataloader: batch = self._prepare_validation_batch(sampling_param, training_args, validation_batch, num_inference_steps) + action_path = validation_batch.get("action_path") + if not isinstance(action_path, str): + action_path = None logger.info("rank: %s: 
rank_in_sp_group: %s, batch.prompt: %s", self.global_rank, self.rank_in_sp_group, @@ -881,75 +973,126 @@ def _log_validation(self, transformer, training_args, global_step) -> None: x = torchvision.utils.make_grid(x, nrow=6) x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) frames.append((x * 255).numpy().astype(np.uint8)) + + # Apply optional post-processing (e.g., overlay for action-conditioned models) + frames = self._post_process_validation_frames(frames, batch) step_videos.append(frames) + step_action_paths.append(action_path) # Only sp_group leaders (rank_in_sp_group == 0) need to send their # results to global rank 0 - if self.rank_in_sp_group == 0 and self.global_rank == 0: - # Global rank 0 collects results from all sp_group leaders - all_videos = step_videos # Start with own results - all_captions = step_captions - all_audios = step_audio - all_sample_rates = step_sample_rates - - # Receive from other sp_group leaders - for sp_group_idx in range(1, num_sp_groups): - src_rank = sp_group_idx * self.sp_world_size # Global rank of other sp_group leaders - recv_videos = world_group.recv_object(src=src_rank) - recv_captions = world_group.recv_object(src=src_rank) - recv_audios = world_group.recv_object(src=src_rank) - recv_sample_rates = world_group.recv_object(src=src_rank) - - all_videos.extend(recv_videos) - all_captions.extend(recv_captions) - all_audios.extend(recv_audios) - all_sample_rates.extend(recv_sample_rates) - - video_filenames = [] - for i, (video, caption, audio, sample_rate) in enumerate( - zip(all_videos, - all_captions, - all_audios, - all_sample_rates, + if self.rank_in_sp_group == 0: + local_video_filenames: list[str] = [] + local_validation_metrics: list[dict[str, float]] = [] + local_eval_error: str | None = None + + for i, (video, caption, action_path) in enumerate( + zip(step_videos, + step_captions, + step_action_paths, strict=True)): os.makedirs(training_args.output_dir, exist_ok=True) filename = os.path.join( training_args.output_dir, 
- f"validation_step_{global_step}_inference_steps_{num_inference_steps}_video_{i}.mp4" + f"validation_step_{global_step}_inference_steps_{num_inference_steps}_rank_{self.global_rank}_video_{i}.mp4" ) imageio.mimsave(filename, video, fps=sampling_param.fps) - # Mux audio if available - if (audio is not None and sample_rate is not None - and not self._mux_audio( - filename, - audio, - sample_rate, - )): - logger.warning( - "Audio mux failed for validation video %s; saved video without audio.", - filename) - video_filenames.append(filename) - - artifacts = [] - for filename, caption in zip(video_filenames, - all_captions, - strict=True): - video_artifact = self.tracker.video(filename, - caption=caption) - if video_artifact is not None: - artifacts.append(video_artifact) - if artifacts: - logs = { - f"validation_videos_{num_inference_steps}_steps": - artifacts - } - self.tracker.log_artifacts(logs, global_step) - elif self.rank_in_sp_group == 0: - # Other sp_group leaders send their results to global rank 0 - world_group.send_object(step_videos, dst=0) - world_group.send_object(step_captions, dst=0) - world_group.send_object(step_audio, dst=0) - world_group.send_object(step_sample_rates, dst=0) + local_video_filenames.append(filename) + + try: + sample_metrics = self._evaluate_validation_video( + video_path=filename, + caption=caption, + action_path=action_path, + global_step=global_step, + num_inference_steps=num_inference_steps, + ) + if sample_metrics: + local_validation_metrics.append(sample_metrics) + except Exception as e: + local_eval_error = ( + f"rank {self.global_rank} validation eval failed " + f"for {filename}: {e}") + logger.exception(local_eval_error) + break + + if self.global_rank == 0: + all_video_filenames = local_video_filenames + all_captions = step_captions + validation_metrics = local_validation_metrics + eval_errors: list[str] = [] + if local_eval_error: + eval_errors.append(local_eval_error) + + # Receive from other sp_group leaders + for 
sp_group_idx in range(1, num_sp_groups): + src_rank = sp_group_idx * self.sp_world_size + recv_video_filenames = world_group.recv_object( + src=src_rank) + recv_captions = world_group.recv_object(src=src_rank) + recv_metrics = world_group.recv_object(src=src_rank) + recv_error = world_group.recv_object(src=src_rank) + + all_video_filenames.extend(recv_video_filenames) + all_captions.extend(recv_captions) + validation_metrics.extend(recv_metrics) + if recv_error: + eval_errors.append(str(recv_error)) + + if eval_errors: + raise RuntimeError( + "Validation flow evaluation failed:\n" + + "\n".join(eval_errors)) + + artifacts = [] + for filename, caption in zip(all_video_filenames, + all_captions, + strict=True): + video_artifact = self.tracker.video(filename, + caption=caption) + if video_artifact is not None: + artifacts.append(video_artifact) + if artifacts: + logs = { + f"validation_videos_{num_inference_steps}_steps": + artifacts + } + self.tracker.log_artifacts(logs, global_step) + + if validation_metrics: + metric_logs: dict[str, float] = {} + metric_keys = sorted( + {k for row in validation_metrics for k in row.keys()}) + for metric_key in metric_keys: + metric_vals = [ + row[metric_key] for row in validation_metrics + if metric_key in row + and np.isfinite(row[metric_key]) + ] + if not metric_vals: + continue + metric_logs[f"metrics/{metric_key}"] = float( + np.mean(metric_vals)) + self.tracker.log(metric_logs, global_step) + + mf_val = metric_logs.get( + "metrics/mf_angle_err_mean") + if mf_val is not None: + self._last_mf_angle_err_mean = mf_val + else: + # Other sp_group leaders send their local results to rank 0 + world_group.send_object(local_video_filenames, dst=0) + world_group.send_object(step_captions, dst=0) + world_group.send_object(local_validation_metrics, dst=0) + world_group.send_object(local_eval_error, dst=0) + if local_eval_error: + raise RuntimeError(local_eval_error) + + # Broadcast the latest mf_angle_err_mean from rank 0 to all ranks + 
_mf_tensor = torch.tensor( + [self._last_mf_angle_err_mean], device=self.device) + dist.broadcast(_mf_tensor, src=0) + self._last_mf_angle_err_mean = _mf_tensor.item() # Re-enable gradients for training training_args.inference_mode = False diff --git a/fastvideo/training/training_utils.py b/fastvideo/training/training_utils.py index d4fa4efb9..d6b310fba 100644 --- a/fastvideo/training/training_utils.py +++ b/fastvideo/training/training_utils.py @@ -1739,9 +1739,27 @@ def _local_numel(p: torch.Tensor) -> int: def count_trainable(model: torch.nn.Module) -> int: + """Return this rank's trainable parameter count (FSDP local shard).""" return sum(_local_numel(p) for p in model.parameters() if p.requires_grad) +def count_trainable_total( + model: torch.nn.Module, + device: torch.device | None = None, +) -> int: + """Return total trainable parameter count across all ranks (FSDP-safe). + + When device is provided and dist is initialized, torch.distributed.all_reduce(SUM) + with the default world group is used. Otherwise returns local count. 
+ """ + local = count_trainable(model) + if device is not None and dist.is_initialized(): + t = torch.tensor([local], dtype=torch.long, device=device) + dist.all_reduce(t, op=dist.ReduceOp.SUM) + return t.item() + return local + + class EMA_FSDP: """ FSDP2-friendly EMA with two modes: diff --git a/fastvideo/training/wangame_ar_diffusion_pipeline.py b/fastvideo/training/wangame_ar_diffusion_pipeline.py new file mode 100644 index 000000000..7adbbb6c8 --- /dev/null +++ b/fastvideo/training/wangame_ar_diffusion_pipeline.py @@ -0,0 +1,527 @@ +# SPDX-License-Identifier: Apache-2.0 + +import sys +from copy import deepcopy +from typing import Any, cast + +import numpy as np +import torch +import torch.nn.functional as F + +from fastvideo.configs.sample import SamplingParam +from fastvideo.dataset.dataloader.schema import pyarrow_schema_wangame +from fastvideo.distributed import get_local_torch_device +from fastvideo.fastvideo_args import FastVideoArgs, TrainingArgs +from fastvideo.forward_context import set_forward_context +from fastvideo.logger import init_logger +from fastvideo.models.dits.hyworld.pose import process_custom_actions +from fastvideo.models.schedulers.scheduling_self_forcing_flow_match import ( + SelfForcingFlowMatchScheduler) +from fastvideo.pipelines.basic.wan.wangame_causal_dmd_pipeline import ( + WanGameCausalDMDPipeline) +from fastvideo.pipelines.pipeline_batch_info import ForwardBatch, TrainingBatch +from fastvideo.training.training_pipeline import TrainingPipeline +from fastvideo.training.training_utils import ( + clip_grad_norm_while_handling_failing_dtensor_cases) +from fastvideo.utils import shallow_asdict + +logger = init_logger(__name__) + + +class WanGameARDiffusionPipeline(TrainingPipeline): + + _required_config_modules = ["scheduler", "transformer", "vae"] + + def initialize_pipeline(self, fastvideo_args: FastVideoArgs): + self.modules["scheduler"] = SelfForcingFlowMatchScheduler( + shift=fastvideo_args.pipeline_config.flow_shift, + 
sigma_min=0.0, + extra_one_step=True) + self.modules["scheduler"].set_timesteps(num_inference_steps=1000, + training=True) + + def set_schemas(self): + self.train_dataset_schema = pyarrow_schema_wangame + + def initialize_training_pipeline(self, training_args: TrainingArgs): + super().initialize_training_pipeline(training_args) + + self.vae = self.get_module("vae") + self.vae.requires_grad_(False) + + self.num_frame_per_block = getattr(training_args, 'num_frame_per_block', 3) + self.timestep_shift = training_args.pipeline_config.flow_shift + self.ar_noise_scheduler = SelfForcingFlowMatchScheduler( + shift=self.timestep_shift, sigma_min=0.0, extra_one_step=True) + self.ar_noise_scheduler.set_timesteps(num_inference_steps=1000, + training=True) + + logger.info("AR Diffusion pipeline initialized with " + "num_frame_per_block=%d, timestep_shift=%.1f", + self.num_frame_per_block, self.timestep_shift) + + def initialize_validation_pipeline(self, training_args: TrainingArgs): + logger.info("Initializing validation pipeline...") + args_copy = deepcopy(training_args) + args_copy.inference_mode = True + + validation_scheduler = SelfForcingFlowMatchScheduler( + shift=args_copy.pipeline_config.flow_shift, + sigma_min=0.0, + extra_one_step=True) + validation_scheduler.set_timesteps(num_inference_steps=1000, + training=True) + + num_val_steps = int( + training_args.validation_sampling_steps.split(",")[0]) + step_size = 1000 // num_val_steps + args_copy.pipeline_config.dmd_denoising_steps = list( + range(1000, 0, -step_size)) + args_copy.pipeline_config.warp_denoising_step = True + training_args.pipeline_config.dmd_denoising_steps = ( + args_copy.pipeline_config.dmd_denoising_steps) + training_args.pipeline_config.warp_denoising_step = True + + logger.info("Validation: %d-step causal denoising, " + "dmd_denoising_steps has %d entries", + num_val_steps, + len(args_copy.pipeline_config.dmd_denoising_steps)) + + self.validation_pipeline = WanGameCausalDMDPipeline.from_pretrained( + 
training_args.model_path, + args=args_copy, + inference_mode=True, + loaded_modules={ + "transformer": self.get_module("transformer"), + "vae": self.get_module("vae"), + "scheduler": validation_scheduler, + }, + tp_size=training_args.tp_size, + sp_size=training_args.sp_size, + num_gpus=training_args.num_gpus, + pin_cpu_memory=training_args.pin_cpu_memory, + dit_cpu_offload=True) + + def _get_timestep( + self, + min_timestep: int, + max_timestep: int, + batch_size: int, + num_frame: int, + num_frame_per_block: int, + uniform_timestep: bool = False, + ) -> torch.Tensor: + """ + Sample per-block timesteps. + """ + device = get_local_torch_device() + if uniform_timestep: + timestep = torch.randint( + min_timestep, max_timestep, [batch_size, 1], + device=device, dtype=torch.long + ).repeat(1, num_frame) + return timestep + else: + timestep = torch.randint( + min_timestep, max_timestep, [batch_size, num_frame], + device=device, dtype=torch.long + ) + # Make the noise level the same within every block + timestep = timestep.reshape( + timestep.shape[0], -1, num_frame_per_block) + timestep[:, :, 1:] = timestep[:, :, 0:1] + timestep = timestep.reshape(timestep.shape[0], -1) + return timestep + + def _get_next_batch(self, training_batch: TrainingBatch) -> TrainingBatch: + batch = next(self.train_loader_iter, None) # type: ignore + if batch is None: + self.current_epoch += 1 + logger.info("Starting epoch %s", self.current_epoch) + self.train_dataset.sampler.set_epoch(self.current_epoch) + self.train_loader_iter = iter(self.train_dataloader) + batch = next(self.train_loader_iter) + + latents = batch['vae_latent'] + latents = latents[:, :, :self.training_args.num_latent_t] + clip_features = batch['clip_feature'] + image_latents = batch['first_frame_latent'] + image_latents = image_latents[:, :, :self.training_args.num_latent_t] + pil_image = batch['pil_image'] + infos = batch['info_list'] + + training_batch.latents = latents.to(get_local_torch_device(), + dtype=torch.bfloat16) + 
def _prepare_dit_inputs(self,
                        training_batch: TrainingBatch) -> TrainingBatch:
    """Build the noisy transformer input and the flow-matching target.

    This override is self-contained (it does not delegate to the parent
    implementation) and performs, in order:

    1. Sample block-shared timestep indices with ``_get_timestep`` and map
       them to scheduler timesteps.
    2. Noise every latent frame with ``ar_noise_scheduler.add_noise``.
    3. Concatenate ``[noisy latents, first-frame mask, image latents]`` along
       the channel dimension. The mask contributes
       ``temporal_compression_ratio`` channels.
    4. Compute the training target with
       ``ar_noise_scheduler.training_target``.

    Args:
        training_batch: Batch with ``latents`` laid out as [B, C, T, H, W]
            and ``image_latents`` already populated.

    Returns:
        The same ``training_batch`` with ``noisy_model_input``,
        ``timesteps`` ([B, T]), ``noise`` ([B, C, T, H, W]),
        ``raw_latent_shape`` and ``_ar_training_target`` ([B, T, C, H, W])
        set.
    """
    assert self.training_args is not None
    device = get_local_torch_device()

    latents = training_batch.latents  # [B, C, T, H, W]
    batch_size = latents.shape[0]
    num_latent_t = latents.shape[2]

    # [B, T, C, H, W] so that frames can be flattened into the batch dim
    latents_btchw = latents.permute(0, 2, 1, 3, 4)

    # Per-block independent timestep indices: [B, T]
    timestep_indices = self._get_timestep(
        min_timestep=0,
        max_timestep=self.ar_noise_scheduler.num_train_timesteps,
        batch_size=batch_size,
        num_frame=num_latent_t,
        num_frame_per_block=self.num_frame_per_block,
        uniform_timestep=False)

    # The lookup table must live on the same device as the indices.
    self.ar_noise_scheduler.timesteps = self.ar_noise_scheduler.timesteps.to(
        device)
    timesteps = self.ar_noise_scheduler.timesteps[timestep_indices]  # [B, T]

    noise = torch.randn_like(latents_btchw)  # [B, T, C, H, W]

    flat_latents = latents_btchw.flatten(0, 1)  # [B*T, C, H, W]
    flat_noise = noise.flatten(0, 1)  # [B*T, C, H, W]
    flat_timesteps = timesteps.flatten(0, 1)  # [B*T]

    noisy_latents = self.ar_noise_scheduler.add_noise(
        flat_latents, flat_noise,
        flat_timesteps).unflatten(0, (batch_size, num_latent_t))

    # Back to [B, C, T, H, W] for the transformer
    noisy_model_input = noisy_latents.permute(0, 2, 1, 3, 4)

    assert isinstance(training_batch.image_latents, torch.Tensor)
    image_latents = training_batch.image_latents.to(device,
                                                    dtype=torch.bfloat16)

    temporal_compression_ratio = self.training_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio
    latent_height, latent_width = image_latents.shape[-2:]

    # First-frame mask, [B, ratio, T, H, W]: ones for latent frame 0, zeros
    # elsewhere. This equals the pixel-frame mask (first frame repeated
    # `ratio` times, then folded into channels) but is built directly on the
    # target device instead of on CPU.
    first_frame_mask = torch.zeros(batch_size,
                                   temporal_compression_ratio,
                                   num_latent_t,
                                   latent_height,
                                   latent_width,
                                   device=image_latents.device,
                                   dtype=torch.bfloat16)
    first_frame_mask[:, :, 0] = 1

    # I2V concatenation along channels:
    # [noisy latents, mask (ratio channels), image latents]
    noisy_model_input = torch.cat(
        [noisy_model_input, first_frame_mask, image_latents], dim=1)

    # Flow-matching training target: [B, T, C, H, W]
    training_target = self.ar_noise_scheduler.training_target(
        flat_latents, flat_noise,
        flat_timesteps).unflatten(0, (batch_size, num_latent_t))

    training_batch.noisy_model_input = noisy_model_input
    training_batch.timesteps = timesteps  # [B, T] per-frame timesteps
    training_batch.noise = noise.permute(0, 2, 1, 3, 4)  # [B, C, T, H, W]
    training_batch.raw_latent_shape = latents.shape
    # Consumed by the custom loss function
    training_batch._ar_training_target = training_target

    return training_batch
def _transformer_forward_and_compute_loss(
        self, training_batch: TrainingBatch) -> TrainingBatch:
    """Run one forward/backward pass and add the scaled loss to the batch.

    The transformer output is compared against
    ``training_batch._ar_training_target`` with an element-wise MSE that is
    averaged per frame, weighted by ``ar_noise_scheduler.training_weight``
    for that frame's timestep, averaged again and divided by
    ``gradient_accumulation_steps`` before ``backward()`` is called.

    Args:
        training_batch: Batch carrying ``input_kwargs``, ``timesteps``
            ([B, T]) and ``_ar_training_target``.

    Returns:
        The same batch with ``total_loss`` increased by the scaled loss.
    """
    per_frame_timesteps = training_batch.timesteps  # [B, T]

    with set_forward_context(current_timestep=per_frame_timesteps,
                             attn_metadata=None,
                             forward_batch=None):
        prediction = self.transformer(**training_batch.input_kwargs)

    # [B, C, T, H, W] -> [B, T, C, H, W] to line up with the stored target
    prediction = prediction.permute(0, 2, 1, 3, 4)
    target = training_batch._ar_training_target.to(prediction.device,
                                                   dtype=prediction.dtype)
    n_batch, n_frame = prediction.shape[:2]

    # Mean over C, H, W leaves one error value per frame: [B, T]
    frame_mse = F.mse_loss(prediction.float(),
                           target.float(),
                           reduction='none').mean(dim=(2, 3, 4))

    frame_weight = self.ar_noise_scheduler.training_weight(
        per_frame_timesteps.flatten(0, 1)).unflatten(0, (n_batch, n_frame))

    scaled_loss = (frame_mse * frame_weight).mean()
    scaled_loss = scaled_loss / self.training_args.gradient_accumulation_steps
    scaled_loss.backward()

    training_batch.total_loss += scaled_loss.detach().item()
    return training_batch
def _prepare_validation_batch(self, sampling_param: SamplingParam,
                              training_args: TrainingArgs,
                              validation_batch: dict[str, Any],
                              num_inference_steps: int) -> ForwardBatch:
    """Turn one validation sample into a ``ForwardBatch``.

    ``sampling_param`` is updated in place (prompt, resolution, image path,
    step count, seed and frame count) and then expanded into the batch.

    Args:
        sampling_param: Sampling parameters to fill in; mutated.
        training_args: Supplies resolution, ``num_latent_t``, the VAE
            temporal compression ratio and ``VSA_sparsity``.
        validation_batch: Sample dict. ``prompt`` is required;
            ``image_path``/``video_path``, ``image``, ``keyboard_cond`` and
            ``mouse_cond`` are optional.
        num_inference_steps: Number of denoising steps for this run.

    Returns:
        A ``ForwardBatch`` with optional ``pil_image``, ``keyboard_cond`` and
        ``mouse_cond`` attached (conditions get a leading batch dim of 1).
    """
    sampling_param.prompt = validation_batch['prompt']
    sampling_param.height = training_args.num_height
    sampling_param.width = training_args.num_width
    sampling_param.image_path = validation_batch.get(
        'image_path') or validation_batch.get('video_path')
    sampling_param.num_inference_steps = num_inference_steps
    sampling_param.data_type = "video"
    assert self.seed is not None
    sampling_param.seed = self.seed

    # Fix the frame count first so that n_tokens below describes the video
    # that is actually generated, not whatever num_frames held on entry.
    temporal_compression_factor = training_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio
    num_frames = (training_args.num_latent_t -
                  1) * temporal_compression_factor + 1
    sampling_param.num_frames = num_frames

    latents_size = [(sampling_param.num_frames - 1) // 4 + 1,
                    sampling_param.height // 8, sampling_param.width // 8]
    n_tokens = latents_size[0] * latents_size[1] * latents_size[2]

    batch = ForwardBatch(
        **shallow_asdict(sampling_param),
        latents=None,
        generator=torch.Generator(device="cpu").manual_seed(self.seed),
        n_tokens=n_tokens,
        eta=0.0,
        VSA_sparsity=training_args.VSA_sparsity,
    )
    if validation_batch.get("image") is not None:
        batch.pil_image = validation_batch["image"]

    def _as_batched_bf16(cond: Any) -> torch.Tensor:
        # torch.tensor() on an existing tensor warns and is discouraged;
        # copy tensors explicitly and only build new ones from raw data.
        if isinstance(cond, torch.Tensor):
            cond = cond.detach().clone().to(dtype=torch.bfloat16)
        else:
            cond = torch.tensor(cond, dtype=torch.bfloat16)
        return cond.unsqueeze(0)

    if validation_batch.get("keyboard_cond") is not None:
        batch.keyboard_cond = _as_batched_bf16(
            validation_batch["keyboard_cond"])

    if validation_batch.get("mouse_cond") is not None:
        batch.mouse_cond = _as_batched_bf16(validation_batch["mouse_cond"])

    return batch
def main(args) -> None:
    """Build the WanGame AR diffusion pipeline from ``args`` and train it.

    Args:
        args: Parsed command-line namespace; its
            ``pretrained_model_name_or_path`` selects the checkpoint and the
            whole namespace is forwarded to ``from_pretrained``.
    """
    logger.info("Starting WanGame AR diffusion training pipeline...")

    pipeline = WanGameARDiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path, args=args)
    # The pipeline keeps its own training_args; nothing here needs them.
    pipeline.train()
    logger.info("WanGame AR diffusion training pipeline done")
def _get_next_batch(self, training_batch: TrainingBatch) -> TrainingBatch:
    """Fetch the next dataloader batch and copy it onto ``training_batch``.

    When the iterator is exhausted the epoch counter is advanced and the
    dataloader is restarted. All tensors are moved to the local device as
    bfloat16.

    With ``simulate_generator_forward`` enabled no real video latents are
    needed, so a zero tensor of the expected latent shape is used instead.

    Args:
        training_batch: Container to populate.

    Returns:
        The same ``training_batch`` with ``latents``, ``image_embeds``,
        ``image_latents``, ``keyboard_cond``, ``mouse_cond`` and ``infos``
        set; the text-encoder fields are set to ``None``.

    Raises:
        ValueError: If real latents are required but ``vae_latent`` is
            missing from the batch.
    """
    batch = next(self.train_loader_iter, None)  # type: ignore
    if batch is None:
        self.current_epoch += 1
        logger.info("Starting epoch %s", self.current_epoch)
        self.train_loader_iter = iter(self.train_dataloader)
        batch = next(self.train_loader_iter)

    device = get_local_torch_device()
    dtype = torch.bfloat16
    num_latent_t = self.training_args.num_latent_t

    clip_feature = batch['clip_feature']
    # Clip the conditioning latents to the same temporal length as the video
    # latents; they are later concatenated channel-wise with a mask that has
    # exactly num_latent_t frames.
    first_frame_latent = batch['first_frame_latent'][:, :, :num_latent_t]
    keyboard_cond = batch.get('keyboard_cond', None)
    mouse_cond = batch.get('mouse_cond', None)
    infos = batch['info_list']

    if self.training_args.simulate_generator_forward:
        # When simulating, we don't need real VAE latents — just use zeros
        batch_size = clip_feature.shape[0]
        vae_config = self.training_args.pipeline_config.vae_config.arch_config
        spatial_compression_ratio = vae_config.spatial_compression_ratio

        latents = torch.zeros(
            batch_size,
            vae_config.z_dim,
            num_latent_t,
            self.training_args.num_height // spatial_compression_ratio,
            self.training_args.num_width // spatial_compression_ratio,
            device=device,
            dtype=dtype,
        )
    else:
        if 'vae_latent' not in batch:
            raise ValueError(
                "vae_latent not found in batch and simulate_generator_forward is False. "
                "Either preprocess data with VAE latents or set --simulate_generator_forward."
            )
        latents = batch['vae_latent'][:, :, :num_latent_t]
        latents = latents.to(device, dtype=dtype)

    training_batch.latents = latents
    training_batch.encoder_hidden_states = None
    training_batch.encoder_attention_mask = None
    training_batch.image_embeds = clip_feature.to(device, dtype=dtype)
    training_batch.image_latents = first_frame_latent.to(device, dtype=dtype)

    # Action conditioning: empty tensors are treated the same as absent ones
    if keyboard_cond is not None and keyboard_cond.numel() > 0:
        training_batch.keyboard_cond = keyboard_cond.to(device, dtype=dtype)
    else:
        training_batch.keyboard_cond = None
    if mouse_cond is not None and mouse_cond.numel() > 0:
        training_batch.mouse_cond = mouse_cond.to(device, dtype=dtype)
    else:
        training_batch.mouse_cond = None

    training_batch.infos = infos
    return training_batch
for video latents + training_batch = super()._prepare_dit_inputs(training_batch) + + training_batch.conditional_dict = { + "encoder_hidden_states": None, + "encoder_attention_mask": None, + } + training_batch.unconditional_dict = None + + assert isinstance(training_batch.image_latents, torch.Tensor) + image_latents = training_batch.image_latents.to( + get_local_torch_device(), dtype=torch.bfloat16) + + # Build mask + image_latent -> cond_concat (20 channels) + temporal_compression_ratio = self.training_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio + num_frames = (self.training_args.num_latent_t - + 1) * temporal_compression_ratio + 1 + batch_size, num_channels, _, latent_height, latent_width = image_latents.shape + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, + latent_width) + mask_lat_size[:, :, 1:] = 0 + + first_frame_mask = mask_lat_size[:, :, :1] + first_frame_mask = torch.repeat_interleave( + first_frame_mask, dim=2, repeats=temporal_compression_ratio) + mask_lat_size = torch.cat([first_frame_mask, mask_lat_size[:, :, 1:]], + dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, + temporal_compression_ratio, + latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to( + image_latents.device).to(dtype=torch.bfloat16) + + # cond_concat = [mask(4), image_latent(16)] = 20 channels + image_latents = torch.cat([mask_lat_size, image_latents], dim=1) + + if self.sp_world_size > 1: + total_frames = image_latents.shape[2] + # Split cond latents to local SP shard only when tensor is still global. 
+ if total_frames == self.training_args.num_latent_t: + if total_frames % self.sp_world_size != 0: + raise ValueError( + "image_latents temporal dim is not divisible by SP world size: " + f"frames={total_frames}, sp_world_size={self.sp_world_size}" + ) + image_latents = rearrange(image_latents, + "b c (n t) h w -> b c n t h w", + n=self.sp_world_size).contiguous() + image_latents = image_latents[:, :, self.rank_in_sp_group, :, :, + :] + + training_batch.image_latents = image_latents + + return training_batch + + def _build_distill_input_kwargs( + self, noise_input: torch.Tensor, timestep: torch.Tensor, + text_dict: dict[str, torch.Tensor] | None, + training_batch: TrainingBatch) -> TrainingBatch: + """Build model input with WanGame + """ + # Image embeds (CLIP features) for cross-attention conditioning + image_embeds = training_batch.image_embeds + assert torch.isnan(image_embeds).sum() == 0 + image_embeds = image_embeds.to(get_local_torch_device(), + dtype=torch.bfloat16) + + # already prepared in _prepare_dit_inputs + image_latents = training_batch.image_latents + + # Process action conditioning + keyboard_cond = training_batch.keyboard_cond + mouse_cond = training_batch.mouse_cond + + if keyboard_cond is not None and mouse_cond is not None: + viewmats_list = [] + intrinsics_list = [] + action_labels_list = [] + for b in range(noise_input.shape[0]): + viewmats, intrinsics, action_labels = process_custom_actions( + keyboard_cond[b], mouse_cond[b]) + viewmats_list.append(viewmats) + intrinsics_list.append(intrinsics) + action_labels_list.append(action_labels) + + viewmats = torch.stack(viewmats_list, dim=0).to( + device=get_local_torch_device(), dtype=torch.bfloat16) + intrinsics = torch.stack(intrinsics_list, dim=0).to( + device=get_local_torch_device(), dtype=torch.bfloat16) + action_labels = torch.stack(action_labels_list, dim=0).to( + device=get_local_torch_device(), dtype=torch.bfloat16) + else: + viewmats = None + intrinsics = None + action_labels = None + + 
def _dmd_forward(self, generator_pred_video: torch.Tensor,
                 training_batch: TrainingBatch) -> torch.Tensor:
    """Compute the DMD loss for WanGame.

    A single random timestep is drawn, the generator output is noised at
    that timestep, and both the fake-score (critic) and real-score (teacher)
    transformers are evaluated on the same inputs. The difference of their
    predicted clean videos, normalised by the mean absolute distance between
    the generator output and the teacher prediction, is used as the
    gradient direction; the loss pulls ``generator_pred_video`` towards
    ``generator_pred_video - grad``.

    Args:
        generator_pred_video: Generator output; flattened over its first two
            dims before noising, so it is presumably [B, T, C, H, W]
            (TODO confirm against the caller).
        training_batch: Batch providing ``conditional_dict`` and the
            conditioning consumed by ``_build_distill_input_kwargs``. Its
            ``input_kwargs`` and ``dmd_latent_vis_dict`` are updated.

    Returns:
        Scalar DMD loss tensor.
    """
    original_latent = generator_pred_video
    # NOTE(review): the extent of this no_grad block was reconstructed from
    # flattened source; it is assumed to end after `nan_to_num` so that the
    # loss below can still backpropagate into the generator — confirm.
    with torch.no_grad():
        # One timestep (shape [1]) shared by the whole batch
        timestep = torch.randint(0,
                                 self.num_train_timestep, [1],
                                 device=self.device,
                                 dtype=torch.long)

        timestep = shift_timestep(timestep, self.timestep_shift,
                                  self.num_train_timestep)

        timestep = timestep.clamp(self.min_timestep, self.max_timestep)

        noise = torch.randn(self.video_latent_shape,
                            device=self.device,
                            dtype=generator_pred_video.dtype)

        # The unflatten to (1, T) hard-codes a leading batch size of 1.
        noisy_latent = self.noise_scheduler.add_noise(
            generator_pred_video.flatten(0, 1), noise.flatten(0, 1),
            timestep).detach().unflatten(0, (1, generator_pred_video.shape[1]))

        # Build input kwargs for critic/teacher
        training_batch = self._build_distill_input_kwargs(
            noisy_latent, timestep, training_batch.conditional_dict,
            training_batch)

        # fake_score_transformer forward
        current_fake_score_transformer = self._get_fake_score_transformer(
            timestep)
        fake_score_pred_noise = current_fake_score_transformer(
            **training_batch.input_kwargs).permute(0, 2, 1, 3, 4)

        # Convert the critic's prediction into a clean-video estimate
        faker_score_pred_video = pred_noise_to_pred_video(
            pred_noise=fake_score_pred_noise.flatten(0, 1),
            noise_input_latent=noisy_latent.flatten(0, 1),
            timestep=timestep,
            scheduler=self.noise_scheduler).unflatten(
                0, fake_score_pred_noise.shape[:2])

        # real_score_transformer forward (same input kwargs as the critic)
        current_real_score_transformer = self._get_real_score_transformer(
            timestep)
        real_score_pred_noise = current_real_score_transformer(
            **training_batch.input_kwargs).permute(0, 2, 1, 3, 4)

        real_score_pred_video = pred_noise_to_pred_video(
            pred_noise=real_score_pred_noise.flatten(0, 1),
            noise_input_latent=noisy_latent.flatten(0, 1),
            timestep=timestep,
            scheduler=self.noise_scheduler).unflatten(
                0, real_score_pred_noise.shape[:2])

        # No CFG for WanGame - use real_score_pred_video directly
        grad = (faker_score_pred_video - real_score_pred_video) / torch.abs(
            original_latent - real_score_pred_video).mean()
        # A zero normaliser yields NaN/inf; replace those with finite values
        grad = torch.nan_to_num(grad)

    # The target is detached, so gradients only flow through the first
    # argument (the generator output).
    dmd_loss = 0.5 * F.mse_loss(
        original_latent.float(),
        (original_latent.float() - grad.float()).detach())

    # Keep intermediate tensors around for visualisation/logging
    training_batch.dmd_latent_vis_dict.update({
        "training_batch_dmd_fwd_clean_latent":
        training_batch.latents,
        "generator_pred_video":
        original_latent.detach(),
        "real_score_pred_video":
        real_score_pred_video.detach(),
        "faker_score_pred_video":
        faker_score_pred_video.detach(),
        "dmd_timestep":
        timestep.detach(),
    })

    return dmd_loss
def _prepare_validation_batch(self, sampling_param: SamplingParam,
                              training_args: TrainingArgs,
                              validation_batch: dict[str, Any],
                              num_inference_steps: int) -> ForwardBatch:
    """Turn one validation sample into a ``ForwardBatch``.

    ``sampling_param`` is updated in place (prompt, resolution, image path,
    step count, seed and frame count) and then expanded into the batch.
    Action conditions are cut to the number of generated frames and get a
    leading batch dimension of 1.

    Args:
        sampling_param: Sampling parameters to fill in; mutated.
        training_args: Supplies resolution, ``num_latent_t``, the VAE
            temporal compression ratio and ``VSA_sparsity``.
        validation_batch: Sample dict. ``prompt`` is required;
            ``image_path``/``video_path``, ``image``, ``keyboard_cond`` and
            ``mouse_cond`` are optional.
        num_inference_steps: Number of denoising steps for this run.

    Returns:
        The populated ``ForwardBatch``.
    """
    sampling_param.prompt = validation_batch['prompt']
    sampling_param.height = training_args.num_height
    sampling_param.width = training_args.num_width
    sampling_param.image_path = validation_batch.get(
        'image_path') or validation_batch.get('video_path')
    sampling_param.num_inference_steps = num_inference_steps
    sampling_param.data_type = "video"
    assert self.seed is not None
    sampling_param.seed = self.seed

    # Fix the frame count first so that n_tokens below describes the video
    # that is actually generated, not whatever num_frames held on entry.
    temporal_compression_factor = training_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio
    num_frames = (training_args.num_latent_t -
                  1) * temporal_compression_factor + 1
    sampling_param.num_frames = num_frames

    latents_size = [(sampling_param.num_frames - 1) // 4 + 1,
                    sampling_param.height // 8, sampling_param.width // 8]
    n_tokens = latents_size[0] * latents_size[1] * latents_size[2]

    batch = ForwardBatch(
        **shallow_asdict(sampling_param),
        latents=None,
        generator=torch.Generator(device="cpu").manual_seed(self.seed),
        n_tokens=n_tokens,
        eta=0.0,
        VSA_sparsity=training_args.VSA_sparsity,
    )
    if validation_batch.get("image") is not None:
        batch.pil_image = validation_batch["image"]

    def _as_batched_bf16(cond: Any) -> torch.Tensor:
        if isinstance(cond, torch.Tensor):
            cond = cond.detach().clone().to(dtype=torch.bfloat16)
        else:
            cond = torch.tensor(cond, dtype=torch.bfloat16)
        # Truncate along the frame axis BEFORE adding the batch axis;
        # slicing after unsqueeze(0) would only index the size-1 batch dim
        # and leave the frame count untouched.
        return cond[:num_frames].unsqueeze(0)

    if validation_batch.get("keyboard_cond") is not None:
        batch.keyboard_cond = _as_batched_bf16(
            validation_batch["keyboard_cond"])

    if validation_batch.get("mouse_cond") is not None:
        batch.mouse_cond = _as_batched_bf16(validation_batch["mouse_cond"])

    return batch
args=args) + + args = pipeline.training_args + pipeline.train() + logger.info("WanGame DMD distillation pipeline completed") + + +if __name__ == "__main__": + argv = sys.argv + from fastvideo.fastvideo_args import TrainingArgs + from fastvideo.utils import FlexibleArgumentParser + parser = FlexibleArgumentParser() + parser = TrainingArgs.add_cli_args(parser) + parser = FastVideoArgs.add_cli_args(parser) + args = parser.parse_args() + main(args) diff --git a/fastvideo/training/wangame_lingbot_training_pipeline.py b/fastvideo/training/wangame_lingbot_training_pipeline.py new file mode 100644 index 000000000..cb3eb8478 --- /dev/null +++ b/fastvideo/training/wangame_lingbot_training_pipeline.py @@ -0,0 +1,418 @@ +# SPDX-License-Identifier: Apache-2.0 +import sys +from typing import Any + +import numpy as np +import torch + +from fastvideo.configs.sample import SamplingParam +from fastvideo.dataset.dataloader.schema import pyarrow_schema_wangame_lingbot +from fastvideo.distributed import get_local_torch_device +from fastvideo.fastvideo_args import FastVideoArgs, TrainingArgs +from fastvideo.logger import init_logger +from fastvideo.models.schedulers.scheduling_flow_unipc_multistep import ( + FlowUniPCMultistepScheduler) +from fastvideo.pipelines.basic.wan.wangame_i2v_pipeline import WanLingBotImageToVideoPipeline +from fastvideo.pipelines.pipeline_batch_info import ForwardBatch, TrainingBatch +from fastvideo.training.training_pipeline import TrainingPipeline +from fastvideo.utils import is_vsa_available, shallow_asdict + +vsa_available = is_vsa_available() + +logger = init_logger(__name__) + + +class WanLingBotTrainingPipeline(TrainingPipeline): + """ + A training pipeline for WanGame-2.1-Fun-1.3B-InP. 
+ """ + _required_config_modules = ["scheduler", "transformer", "vae"] + + def initialize_pipeline(self, fastvideo_args: FastVideoArgs): + self.modules["scheduler"] = FlowUniPCMultistepScheduler( + shift=fastvideo_args.pipeline_config.flow_shift) + + def create_training_stages(self, training_args: TrainingArgs): + """ + May be used in future refactors. + """ + pass + + def set_schemas(self): + self.train_dataset_schema = pyarrow_schema_wangame_lingbot + + def set_trainable(self) -> None: + """ + Override to only train newly added action-related parameters for Lingbot: + - patch_embedding_wancamctrl: embeds camera Plucker coordinates + - blocks.*.cam_conditioner: injects camera conditioning into transformer blocks + """ + train_action_only = getattr(self.fastvideo_args, "train_action_only", + False) + + if not train_action_only: + # Default behavior: train all parameters + super().set_trainable() + return + + # Freeze all transformer parameters first + transformer = self.get_module("transformer") + transformer.train() + transformer.requires_grad_(False) + + # Define which parameter name patterns to train + action_param_patterns = [ + "patch_embedding_wancamctrl", + "cam_conditioner", + ] + + # Enable gradients for action-related parameters only + trainable_count = 0 + frozen_count = 0 + for name, param in transformer.named_parameters(): + should_train = any(pattern in name + for pattern in action_param_patterns) + if should_train: + param.requires_grad_(True) + trainable_count += 1 + logger.info(f"Trainable: {name} ({param.numel()} params)") + else: + frozen_count += 1 + + logger.info( + f"Action-only training: {trainable_count} trainable param groups, " + f"{frozen_count} frozen param groups") + + # ── Action module warmup ────────────────────────────────────────────── + # For the first `action_warmup_steps`, action modules (action_embedder, + # to_out_prope) have requires_grad=False so the base model stabilizes + # first. After warmup the gradients are re-enabled. 
+ + _ACTION_PARAM_PATTERNS = [ + "patch_embedding_wancamctrl", + "cam_conditioner", + ] + + def _set_action_params_grad(self, requires_grad: bool) -> None: + """Toggle requires_grad for action-related parameters.""" + transformer = self.get_module("transformer") + count = 0 + for name, param in transformer.named_parameters(): + if any(p in name for p in self._ACTION_PARAM_PATTERNS): + param.requires_grad_(requires_grad) + count += 1 + state = "enabled" if requires_grad else "disabled" + logger.info("Gradients %s for %d action parameter groups", state, count) + + def train_one_step(self, training_batch: TrainingBatch) -> TrainingBatch: + step = training_batch.current_timestep + warmup_steps = self.training_args.action_warmup_steps + + if warmup_steps > 0: + if step == 1: + # Freeze action params at the very first step + self._set_action_params_grad(False) + logger.info( + "Action warmup: freezing action modules for the first " + "%d steps to stabilize base model", warmup_steps) + elif step == warmup_steps + 1: + # Unfreeze action params once warmup is done + self._set_action_params_grad(True) + logger.info( + "Action warmup complete — action modules unfrozen at " + "step %d", step) + + return super().train_one_step(training_batch) + + def initialize_validation_pipeline(self, training_args: TrainingArgs): + logger.info("Initializing validation pipeline...") + # args_copy.pipeline_config.vae_config.load_encoder = False + # validation_pipeline = WanImageToVideoValidationPipeline.from_pretrained( + self.validation_pipeline = WanLingBotImageToVideoPipeline.from_pretrained( + training_args.model_path, + args=None, + inference_mode=True, + loaded_modules={ + "transformer": self.get_module("transformer"), + }, + tp_size=training_args.tp_size, + sp_size=training_args.sp_size, + num_gpus=training_args.num_gpus, + dit_cpu_offload=False) + + def _get_next_batch(self, training_batch: TrainingBatch) -> TrainingBatch: + batch = next(self.train_loader_iter, None) # type: ignore + 
if batch is None: + self.current_epoch += 1 + logger.info("Starting epoch %s", self.current_epoch) + # Reset iterator for next epoch + self.train_loader_iter = iter(self.train_dataloader) + # Get first batch of new epoch + batch = next(self.train_loader_iter) + + latents = batch['vae_latent'] + latents = latents[:, :, :self.training_args.num_latent_t] + # encoder_hidden_states = batch['text_embedding'] + # encoder_attention_mask = batch['text_attention_mask'] + clip_features = batch['clip_feature'] + image_latents = batch['first_frame_latent'] + image_latents = image_latents[:, :, :self.training_args.num_latent_t] + pil_image = batch['pil_image'] + infos = batch['info_list'] + + training_batch.latents = latents.to(get_local_torch_device(), + dtype=torch.bfloat16) + training_batch.encoder_hidden_states = None + training_batch.encoder_attention_mask = None + training_batch.preprocessed_image = pil_image.to( + get_local_torch_device()) + training_batch.image_embeds = clip_features.to(get_local_torch_device()) + training_batch.image_latents = image_latents.to( + get_local_torch_device()) + training_batch.infos = infos + + # Action conditioning + if 'mouse_cond' in batch and batch['mouse_cond'].numel() > 0: + training_batch.mouse_cond = batch['mouse_cond'].to( + get_local_torch_device(), dtype=torch.bfloat16) + else: + training_batch.mouse_cond = None + + if 'keyboard_cond' in batch and batch['keyboard_cond'].numel() > 0: + training_batch.keyboard_cond = batch['keyboard_cond'].to( + get_local_torch_device(), dtype=torch.bfloat16) + else: + training_batch.keyboard_cond = None + + # Validate action temporal dimensions match video num_frames + expected_num_frames = (self.training_args.num_latent_t - 1) * 4 + 1 + if training_batch.keyboard_cond is not None: + assert training_batch.keyboard_cond.shape[1] == expected_num_frames, ( + f"keyboard_cond temporal dim {training_batch.keyboard_cond.shape[1]} " + f"!= expected {expected_num_frames} " + 
f"(num_latent_t={self.training_args.num_latent_t})") + if training_batch.mouse_cond is not None: + assert training_batch.mouse_cond.shape[1] == expected_num_frames, ( + f"mouse_cond temporal dim {training_batch.mouse_cond.shape[1]} " + f"!= expected {expected_num_frames} " + f"(num_latent_t={self.training_args.num_latent_t})") + + return training_batch + + def _prepare_dit_inputs(self, + training_batch: TrainingBatch) -> TrainingBatch: + """Override to properly handle I2V concatenation - call parent first, then concatenate image conditioning.""" + + # First, call parent method to prepare noise, timesteps, etc. for video latents + training_batch = super()._prepare_dit_inputs(training_batch) + + assert isinstance(training_batch.image_latents, torch.Tensor) + image_latents = training_batch.image_latents.to( + get_local_torch_device(), dtype=torch.bfloat16) + + temporal_compression_ratio = self.training_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio + num_frames = (self.training_args.num_latent_t - + 1) * temporal_compression_ratio + 1 + batch_size, num_channels, _, latent_height, latent_width = image_latents.shape + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, + latent_width) + mask_lat_size[:, :, 1:] = 0 + + first_frame_mask = mask_lat_size[:, :, :1] + first_frame_mask = torch.repeat_interleave( + first_frame_mask, dim=2, repeats=temporal_compression_ratio) + mask_lat_size = torch.cat([first_frame_mask, mask_lat_size[:, :, 1:]], + dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, + temporal_compression_ratio, + latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to( + image_latents.device).to(dtype=torch.bfloat16) + + training_batch.noisy_model_input = torch.cat( + [training_batch.noisy_model_input, mask_lat_size, image_latents], + dim=1) + + return training_batch + + def _build_input_kwargs(self, + training_batch: TrainingBatch) -> TrainingBatch: + + # 
Image Embeds for conditioning + image_embeds = training_batch.image_embeds + assert torch.isnan(image_embeds).sum() == 0 + image_embeds = image_embeds.to(get_local_torch_device(), + dtype=torch.bfloat16) + encoder_hidden_states_image = image_embeds + + from fastvideo.models.dits.wangame_lingbot.cam_utils import process_custom_actions + + # Process actions for each batch sample + batch_size = training_batch.noisy_model_input.shape[0] + num_latent_t = training_batch.noisy_model_input.shape[2] + latent_height = training_batch.noisy_model_input.shape[3] + latent_width = training_batch.noisy_model_input.shape[4] + + temporal_compression_ratio = self.training_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio + num_frames = (num_latent_t - 1) * temporal_compression_ratio + 1 + + c2ws_plucker_emb_list = [] + for b in range(batch_size): + # Lingbot's process_custom_actions returns [1, 6*spatial_scale^2, lat_f, H_lat, W_lat] + c2ws_plucker_emb = process_custom_actions( + num_frames=num_frames, + keyboard_cond=training_batch.keyboard_cond[b], + mouse_cond=training_batch.mouse_cond[b], + latent_height=latent_height, + latent_width=latent_width) + c2ws_plucker_emb_list.append(c2ws_plucker_emb) + + c2ws_plucker_emb = torch.cat(c2ws_plucker_emb_list, + dim=0).to(get_local_torch_device(), + dtype=torch.bfloat16) + + # c2ws_plucker_emb: [B, C, lat_f, H_lat, W_lat] + assert c2ws_plucker_emb.shape[2] == num_latent_t, ( + f"c2ws_plucker_emb temporal dim {c2ws_plucker_emb.shape[2]} != " + f"video latent temporal dim {num_latent_t}") + + training_batch.input_kwargs = { + "hidden_states": + training_batch.noisy_model_input, + "encoder_hidden_states": + training_batch.encoder_hidden_states, # None (no text conditioning) + "timestep": + training_batch.timesteps.to(get_local_torch_device(), + dtype=torch.bfloat16), + # "encoder_attention_mask": + # training_batch.encoder_attention_mask, + "encoder_hidden_states_image": + encoder_hidden_states_image, + # Action 
conditioning + "c2ws_plucker_emb": + c2ws_plucker_emb, + "return_dict": + False, + } + return training_batch + + def _prepare_validation_batch(self, sampling_param: SamplingParam, + training_args: TrainingArgs, + validation_batch: dict[str, Any], + num_inference_steps: int) -> ForwardBatch: + sampling_param.prompt = validation_batch['prompt'] + sampling_param.height = training_args.num_height + sampling_param.width = training_args.num_width + sampling_param.image_path = validation_batch.get( + 'image_path') or validation_batch.get('video_path') + sampling_param.num_inference_steps = num_inference_steps + sampling_param.data_type = "video" + assert self.seed is not None + sampling_param.seed = self.seed + + latents_size = [(sampling_param.num_frames - 1) // 4 + 1, + sampling_param.height // 8, sampling_param.width // 8] + n_tokens = latents_size[0] * latents_size[1] * latents_size[2] + temporal_compression_factor = training_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio + num_frames = (training_args.num_latent_t - + 1) * temporal_compression_factor + 1 + sampling_param.num_frames = num_frames + batch = ForwardBatch( + **shallow_asdict(sampling_param), + latents=None, + generator=torch.Generator(device="cpu").manual_seed(self.seed), + n_tokens=n_tokens, + eta=0.0, + VSA_sparsity=training_args.VSA_sparsity, + ) + if "image" in validation_batch and validation_batch["image"] is not None: + batch.pil_image = validation_batch["image"] + + if "keyboard_cond" in validation_batch and validation_batch[ + "keyboard_cond"] is not None: + keyboard_cond = validation_batch["keyboard_cond"] + keyboard_cond = torch.tensor(keyboard_cond, dtype=torch.bfloat16) + keyboard_cond = keyboard_cond.unsqueeze(0) + batch.keyboard_cond = keyboard_cond + + if "mouse_cond" in validation_batch and validation_batch[ + "mouse_cond"] is not None: + mouse_cond = validation_batch["mouse_cond"] + mouse_cond = torch.tensor(mouse_cond, dtype=torch.bfloat16) + mouse_cond = 
mouse_cond.unsqueeze(0) + batch.mouse_cond = mouse_cond + + return batch + + def _post_process_validation_frames( + self, frames: list[np.ndarray], + batch: ForwardBatch) -> list[np.ndarray]: + """Apply action overlay to validation frames for WanGame. + + Draws keyboard (WASD) and mouse (pitch/yaw) indicators on the video frames. + """ + # Check if action data is available + keyboard_cond = getattr(batch, 'keyboard_cond', None) + mouse_cond = getattr(batch, 'mouse_cond', None) + + if keyboard_cond is None and mouse_cond is None: + return frames + + # Import overlay functions + from fastvideo.models.dits.matrixgame.utils import (draw_keys_on_frame, + draw_mouse_on_frame) + + # Convert tensors to numpy if needed (bfloat16 -> float32 -> numpy) + if keyboard_cond is not None: + keyboard_cond = keyboard_cond.squeeze( + 0).cpu().float().numpy() # (T, 6) + if mouse_cond is not None: + mouse_cond = mouse_cond.squeeze(0).cpu().float().numpy() # (T, 2) + + # MatrixGame convention: keyboard [W, S, A, D, left, right], mouse [Pitch, Yaw] + key_names = ["W", "S", "A", "D", "left", "right"] + + processed_frames = [] + for frame_idx, frame in enumerate(frames): + frame = np.ascontiguousarray(frame.copy()) + + # Draw keyboard overlay + if keyboard_cond is not None and frame_idx < len(keyboard_cond): + keys = { + key_names[i]: bool(keyboard_cond[frame_idx, i]) + for i in range(min(len(key_names), keyboard_cond.shape[1])) + } + draw_keys_on_frame(frame, keys, mode='universal') + + # Draw mouse overlay + if mouse_cond is not None and frame_idx < len(mouse_cond): + pitch = float(mouse_cond[frame_idx, 0]) + yaw = float(mouse_cond[frame_idx, 1]) + draw_mouse_on_frame(frame, pitch, yaw) + + processed_frames.append(frame) + + return processed_frames + + +def main(args) -> None: + logger.info("Starting training pipeline...") + + pipeline = WanLingBotTrainingPipeline.from_pretrained( + args.pretrained_model_name_or_path, args=args) + args = pipeline.training_args + pipeline.train() + 
logger.info("Training pipeline done") + + +if __name__ == "__main__": + argv = sys.argv + from fastvideo.fastvideo_args import TrainingArgs + from fastvideo.utils import FlexibleArgumentParser + parser = FlexibleArgumentParser() + parser = TrainingArgs.add_cli_args(parser) + parser = FastVideoArgs.add_cli_args(parser) + args = parser.parse_args() + args.dit_cpu_offload = False + main(args) diff --git a/fastvideo/training/wangame_ode_causal_pipeline.py b/fastvideo/training/wangame_ode_causal_pipeline.py new file mode 100644 index 000000000..ae408282e --- /dev/null +++ b/fastvideo/training/wangame_ode_causal_pipeline.py @@ -0,0 +1,659 @@ +# SPDX-License-Identifier: Apache-2.0 +import sys +from copy import deepcopy +from typing import Any, cast + +import numpy as np +import torch +import torch.nn.functional as F + +from fastvideo.configs.sample import SamplingParam +from fastvideo.dataset.dataloader.schema import ( + pyarrow_schema_ode_trajectory_wangame) +from fastvideo.distributed import get_local_torch_device +from fastvideo.fastvideo_args import FastVideoArgs, TrainingArgs +from fastvideo.forward_context import set_forward_context +from fastvideo.logger import init_logger +from fastvideo.models.dits.hyworld.pose import process_custom_actions +from fastvideo.models.schedulers.scheduling_self_forcing_flow_match import ( + SelfForcingFlowMatchScheduler) +from fastvideo.pipelines.basic.wan.wangame_causal_dmd_pipeline import ( + WanGameCausalDMDPipeline) +from fastvideo.pipelines.stages.decoding import DecodingStage +from fastvideo.pipelines.pipeline_batch_info import ForwardBatch, TrainingBatch +from fastvideo.training.training_pipeline import TrainingPipeline +from fastvideo.training.training_utils import ( + clip_grad_norm_while_handling_failing_dtensor_cases) +from fastvideo.utils import shallow_asdict + +logger = init_logger(__name__) + + +class WanGameODEInitTrainingPipeline(TrainingPipeline): + """ + Training pipeline for ODE-init using precomputed denoising 
trajectories. + + Supervision: predict the next latent in the stored trajectory by + - feeding current latent at timestep t into the transformer to predict noise + - stepping the scheduler with the predicted noise + - minimizing MSE to the stored next latent at timestep t_next + """ + + _required_config_modules = ["scheduler", "transformer", "vae"] + + def initialize_pipeline(self, fastvideo_args: FastVideoArgs): + # Match the preprocess/generation scheduler for consistent stepping + self.modules["scheduler"] = SelfForcingFlowMatchScheduler( + shift=fastvideo_args.pipeline_config.flow_shift, + sigma_min=0.0, + extra_one_step=True) + self.modules["scheduler"].set_timesteps(num_inference_steps=1000, + training=True) + + def set_schemas(self): + self.train_dataset_schema = pyarrow_schema_ode_trajectory_wangame + + def initialize_training_pipeline(self, training_args: TrainingArgs): + super().initialize_training_pipeline(training_args) + + self.noise_scheduler = self.get_module("scheduler") + self.vae = self.get_module("vae") + self.vae.requires_grad_(False) + + self.timestep_shift = self.training_args.pipeline_config.flow_shift + self.noise_scheduler = SelfForcingFlowMatchScheduler( + shift=self.timestep_shift, sigma_min=0.0, extra_one_step=True) + self.noise_scheduler.set_timesteps(num_inference_steps=1000, + training=True) + + self.training_args.pipeline_config.dmd_denoising_steps = [1000, 750, 500, 250, 0] + self.add_stage(stage_name="decoding_stage", + stage=DecodingStage(vae=self.get_module("vae"))) + + logger.info("dmd_denoising_steps: %s", + self.training_args.pipeline_config.dmd_denoising_steps) + self.dmd_denoising_steps = torch.tensor([1000, 750, 500, 250, 0], + dtype=torch.long, + device=get_local_torch_device()) + if training_args.warp_denoising_step: # Warp the denoising step according to the scheduler time shift + timesteps = torch.cat((self.noise_scheduler.timesteps.cpu(), + torch.tensor([0], + dtype=torch.float32))).cuda() + logger.info("timesteps: 
%s", timesteps) + self.dmd_denoising_steps = timesteps[1000 - + self.dmd_denoising_steps] + logger.info("warped self.dmd_denoising_steps: %s", + self.dmd_denoising_steps) + else: + raise ValueError("warp_denoising_step must be true") + + self.dmd_denoising_steps = self.dmd_denoising_steps.to( + get_local_torch_device()) + + logger.info("denoising_step_list: %s", self.dmd_denoising_steps) + + logger.info( + "Initialized ODE-init training pipeline with %s denoising steps", + len(self.dmd_denoising_steps)) + # Cache for nearest trajectory index per DMD step (computed lazily on first batch) + self._cached_closest_idx_per_dmd = None + self.num_train_timestep = self.noise_scheduler.num_train_timesteps + self.manual_idx = 0 + + def initialize_validation_pipeline(self, training_args: TrainingArgs): + logger.info("Initializing validation pipeline...") + args_copy = deepcopy(training_args) + args_copy.inference_mode = True + # Use the same flow-matching scheduler as training for consistent validation. 
+ validation_scheduler = SelfForcingFlowMatchScheduler( + shift=args_copy.pipeline_config.flow_shift, + sigma_min=0.0, + extra_one_step=True) + validation_scheduler.set_timesteps(num_inference_steps=1000, + training=True) + # Warm start validation with current transformer + self.validation_pipeline = WanGameCausalDMDPipeline.from_pretrained( + training_args.model_path, + args=args_copy, # type: ignore + inference_mode=True, + loaded_modules={ + "transformer": self.get_module("transformer"), + "vae": self.get_module("vae"), + "scheduler": validation_scheduler, + }, + tp_size=training_args.tp_size, + sp_size=training_args.sp_size, + num_gpus=training_args.num_gpus, + pin_cpu_memory=training_args.pin_cpu_memory, + dit_cpu_offload=True) + + def _get_next_batch( + self, + training_batch) -> tuple[TrainingBatch, torch.Tensor, torch.Tensor]: + batch = next(self.train_loader_iter, None) # type: ignore + if batch is None: + self.current_epoch += 1 + logger.info("Starting epoch %s", self.current_epoch) + self.train_loader_iter = iter(self.train_dataloader) + batch = next(self.train_loader_iter) + + # Required fields from parquet (ODE trajectory schema) + clip_feature = batch['clip_feature'] + first_frame_latent = batch['first_frame_latent'] + keyboard_cond = batch.get('keyboard_cond', None) + # keyboard_cond = keyboard_cond[:, :, :3] # TODO: remove hardcode + mouse_cond = batch.get('mouse_cond', None) + infos = batch['info_list'] + + # Trajectory tensors may include a leading singleton batch dim per row + trajectory_latents = batch['trajectory_latents'] + if trajectory_latents.dim() == 7: + # [B, 1, S, C, T, H, W] -> [B, S, C, T, H, W] + trajectory_latents = trajectory_latents[:, 0] + elif trajectory_latents.dim() == 6: + # already [B, S, C, T, H, W] + pass + else: + raise ValueError( + f"Unexpected trajectory_latents dim: {trajectory_latents.dim()}" + ) + + trajectory_timesteps = batch['trajectory_timesteps'] + if trajectory_timesteps.dim() == 3: + # [B, 1, S] -> [B, S] + 
trajectory_timesteps = trajectory_timesteps[:, 0] + elif trajectory_timesteps.dim() == 2: + # [B, S] + pass + else: + raise ValueError( + f"Unexpected trajectory_timesteps dim: {trajectory_timesteps.dim()}" + ) + # [B, S, C, T, H, W] -> [B, S, T, C, H, W] to match self-forcing + trajectory_latents = trajectory_latents.permute(0, 1, 3, 2, 4, 5) + + # Move to device + device = get_local_torch_device() + training_batch.image_embeds = clip_feature.to(device, + dtype=torch.bfloat16) + training_batch.image_latents = first_frame_latent.to( + device, dtype=torch.bfloat16) + if keyboard_cond is not None and keyboard_cond.numel() > 0: + training_batch.keyboard_cond = keyboard_cond.to( + device, dtype=torch.bfloat16) + else: + training_batch.keyboard_cond = None + if mouse_cond is not None and mouse_cond.numel() > 0: + training_batch.mouse_cond = mouse_cond.to(device, + dtype=torch.bfloat16) + else: + training_batch.mouse_cond = None + training_batch.infos = infos + + # Validate action temporal dimensions match expected video frame count. + expected_num_frames = (self.training_args.num_latent_t - 1) * 4 + 1 + if training_batch.keyboard_cond is not None: + assert training_batch.keyboard_cond.shape[1] == expected_num_frames, ( + f"keyboard_cond temporal dim {training_batch.keyboard_cond.shape[1]} " + f"!= expected {expected_num_frames} " + f"(num_latent_t={self.training_args.num_latent_t})") + if training_batch.mouse_cond is not None: + assert training_batch.mouse_cond.shape[1] == expected_num_frames, ( + f"mouse_cond temporal dim {training_batch.mouse_cond.shape[1]} " + f"!= expected {expected_num_frames} " + f"(num_latent_t={self.training_args.num_latent_t})") + + return training_batch, trajectory_latents[:, :, :self.training_args. 
+ num_latent_t].to( + device, + dtype=torch.bfloat16 + ), trajectory_timesteps.to( + device) + + def _get_timestep(self, + min_timestep: int, + max_timestep: int, + batch_size: int, + num_frame: int, + num_frame_per_block: int, + uniform_timestep: bool = False) -> torch.Tensor: + if uniform_timestep: + timestep = torch.randint(min_timestep, + max_timestep, [batch_size, 1], + device=self.device, + dtype=torch.long).repeat(1, num_frame) + return timestep + else: + timestep = torch.randint(min_timestep, + max_timestep, [batch_size, num_frame], + device=self.device, + dtype=torch.long) + # logger.info(f"individual timestep: {timestep}") + # make the noise level the same within every block + timestep = timestep.reshape(timestep.shape[0], -1, + num_frame_per_block) + timestep[:, :, 1:] = timestep[:, :, 0:1] + timestep = timestep.reshape(timestep.shape[0], -1) + return timestep + + def _prepare_dit_inputs(self, + training_batch: TrainingBatch) -> TrainingBatch: + """Override to properly handle I2V concatenation - call parent first, then concatenate image conditioning.""" + + # First, call parent method to prepare noise, timesteps, etc. 
for video latents + training_batch = super()._prepare_dit_inputs(training_batch) + + assert isinstance(training_batch.image_latents, torch.Tensor) + image_latents = training_batch.image_latents.to( + get_local_torch_device(), dtype=torch.bfloat16) + + temporal_compression_ratio = 4 + num_frames = (self.training_args.num_latent_t - + 1) * temporal_compression_ratio + 1 + batch_size, num_channels, _, latent_height, latent_width = image_latents.shape + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, + latent_width) + mask_lat_size[:, :, 1:] = 0 + + first_frame_mask = mask_lat_size[:, :, :1] + first_frame_mask = torch.repeat_interleave( + first_frame_mask, dim=2, repeats=temporal_compression_ratio) + mask_lat_size = torch.cat([first_frame_mask, mask_lat_size[:, :, 1:]], + dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, + temporal_compression_ratio, + latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to( + image_latents.device).to(dtype=torch.bfloat16) + + training_batch.noisy_model_input = torch.cat( + [training_batch.noisy_model_input, mask_lat_size, image_latents], + dim=1) + + return training_batch + + def _step_predict_next_latent( + self, traj_latents: torch.Tensor, traj_timesteps: torch.Tensor, + image_embeds: torch.Tensor, image_latents: torch.Tensor, + keyboard_cond: torch.Tensor | None, mouse_cond: torch.Tensor | None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, + torch.Tensor]]: + latent_vis_dict: dict[str, torch.Tensor] = {} + device = get_local_torch_device() + target_latent = traj_latents[:, -1] + + # Shapes: traj_latents [B, S, C, T, H, W], traj_timesteps [B, S] + B, S, num_frames, num_channels, height, width = traj_latents.shape + + # Lazily cache nearest trajectory index per DMD step based on the (fixed) S timesteps + if self._cached_closest_idx_per_dmd is None: + self._cached_closest_idx_per_dmd = torch.tensor( + [0, 12, 24, 36, S - 1], 
dtype=torch.long).cpu() + # [0, 1, 2, 3], dtype=torch.long).cpu() + logger.info("self._cached_closest_idx_per_dmd: %s", + self._cached_closest_idx_per_dmd) + logger.info( + "corresponding timesteps: %s", self.noise_scheduler.timesteps[ + self._cached_closest_idx_per_dmd]) + + # Select the K indexes from traj_latents using self._cached_closest_idx_per_dmd + # traj_latents: [B, S, C, T, H, W], self._cached_closest_idx_per_dmd: [K] + # Output: [B, K, C, T, H, W] + assert self._cached_closest_idx_per_dmd is not None + relevant_traj_latents = torch.index_select( + traj_latents, + dim=1, + index=self._cached_closest_idx_per_dmd.to(traj_latents.device)) + logger.info("relevant_traj_latents: %s", relevant_traj_latents.shape) + # assert relevant_traj_latents.shape[0] == 1 + + indexes = self._get_timestep( # [B, num_frames] + 0, + len(self.dmd_denoising_steps), + B, + num_frames, + 3, + uniform_timestep=False) + logger.info("indexes: %s", indexes.shape) + logger.info("indexes: %s", indexes) + # noisy_input = relevant_traj_latents[indexes] + noisy_input = torch.gather( + relevant_traj_latents, + dim=1, + index=indexes.reshape(B, 1, num_frames, 1, 1, + 1).expand(-1, -1, -1, num_channels, height, + width).to(self.device)).squeeze(1) + latent_model_input = noisy_input.permute(0, 2, 1, 3, 4) + if image_latents is not None: + latent_model_input = torch.cat([ + latent_model_input, + image_latents.to(latent_model_input.device, + latent_model_input.dtype), + ], + dim=1) + timestep = self.dmd_denoising_steps[indexes] + logger.info("selected timestep for rank %s: %s", + self.global_rank, + timestep, + local_main_process_only=False) + + # Prepare inputs for transformer + latent_vis_dict["noisy_input"] = noisy_input.permute( + 0, 2, 1, 3, 4).detach().clone().cpu() + latent_vis_dict["x0"] = target_latent.permute(0, 2, 1, 3, + 4).detach().clone().cpu() + + latent_model_input = latent_model_input.to(device, dtype=torch.bfloat16) + timestep = timestep.to(device, dtype=torch.bfloat16) + + 
logger.info("========== Transformer Input ==========") + logger.info("hidden_states (latent_model_input) shape: %s, dtype: %s", + latent_model_input.shape, latent_model_input.dtype) + logger.info("hidden_states min/max/mean: %.4f / %.4f / %.4f", + latent_model_input.min().item(), + latent_model_input.max().item(), + latent_model_input.mean().item()) + logger.info("encoder_hidden_states_image (image_embeds) shape: %s", + image_embeds.shape if image_embeds is not None else None) + logger.info("timestep shape: %s, dtype: %s", timestep.shape, + timestep.dtype) + logger.info("keyboard_cond: %s", + keyboard_cond.shape if keyboard_cond is not None else None) + logger.info("mouse_cond: %s", + mouse_cond.shape if mouse_cond is not None else None) + + if keyboard_cond is not None and mouse_cond is not None: + viewmats_list = [] + intrinsics_list = [] + action_labels_list = [] + for b in range(latent_model_input.shape[0]): + viewmats, intrinsics, action_labels = process_custom_actions( + keyboard_cond[b], mouse_cond[b]) + viewmats_list.append(viewmats) + intrinsics_list.append(intrinsics) + action_labels_list.append(action_labels) + viewmats = torch.stack(viewmats_list, + dim=0).to(device=device, + dtype=torch.bfloat16) + intrinsics = torch.stack(intrinsics_list, + dim=0).to(device=device, + dtype=torch.bfloat16) + action_labels = torch.stack(action_labels_list, + dim=0).to(device=device, + dtype=torch.bfloat16) + else: + viewmats = None + intrinsics = None + action_labels = None + + empty_text = torch.zeros( + (latent_model_input.shape[0], 0, self.transformer.hidden_size), + device=device, + dtype=torch.bfloat16) + + input_kwargs = { + "hidden_states": latent_model_input, + "encoder_hidden_states": empty_text, + "encoder_hidden_states_image": image_embeds, + "timestep": timestep, + "viewmats": viewmats, + "Ks": intrinsics, + "action": action_labels, + "return_dict": False, + } + # Predict noise and step the scheduler to obtain next latent + with 
set_forward_context(current_timestep=timestep, + attn_metadata=None, + forward_batch=None): + noise_pred = self.transformer(**input_kwargs).permute(0, 2, 1, 3, 4) + + logger.info("========== Transformer Output ==========") + logger.info("noise_pred shape: %s", noise_pred.shape) + logger.info("noise_pred min/max/mean: %.4f / %.4f / %.4f", + noise_pred.min().item(), + noise_pred.max().item(), + noise_pred.mean().item()) + + from fastvideo.models.utils import pred_noise_to_pred_video + pred_video = pred_noise_to_pred_video( + pred_noise=noise_pred.flatten(0, 1), + noise_input_latent=noisy_input.flatten(0, 1), + timestep=timestep.to(dtype=torch.bfloat16).flatten(0, 1), + scheduler=self.modules["scheduler"]).unflatten( + 0, noise_pred.shape[:2]) + latent_vis_dict["pred_video"] = pred_video.permute( + 0, 2, 1, 3, 4).detach().clone().cpu() + + return pred_video, target_latent, timestep, latent_vis_dict + + def train_one_step(self, training_batch): # type: ignore[override] + self.transformer.train() + self.optimizer.zero_grad() + training_batch.total_loss = 0.0 + args = cast(TrainingArgs, self.training_args) + + # Using cached nearest index per DMD step; computation happens in _step_predict_next_latent + + for _ in range(args.gradient_accumulation_steps): + training_batch, traj_latents, traj_timesteps = self._get_next_batch( + training_batch) + image_embeds = training_batch.image_embeds + image_latents = training_batch.image_latents + keyboard_cond = training_batch.keyboard_cond + mouse_cond = training_batch.mouse_cond + assert traj_latents.shape[0] == 1 + + # Shapes: traj_latents [B, S, C, T, H, W], traj_timesteps [B, S] + _, S = traj_latents.shape[0], traj_latents.shape[1] + if S < 2: + raise ValueError("Trajectory must contain at least 2 steps") + + # Forward to predict next latent by stepping scheduler with predicted noise + noise_pred, target_latent, t, latent_vis_dict = self._step_predict_next_latent( + traj_latents, traj_timesteps, image_embeds, image_latents, + 
keyboard_cond, mouse_cond) + + training_batch.latent_vis_dict.update(latent_vis_dict) + + mask = t != 0 + + # Compute loss + loss = F.mse_loss(noise_pred[mask], + target_latent[mask], + reduction="mean") + loss = loss / args.gradient_accumulation_steps + + with set_forward_context(current_timestep=t, + attn_metadata=None, + forward_batch=None): + loss.backward() + avg_loss = loss.detach().clone() + training_batch.total_loss += avg_loss.item() + + # Clip grad and step optimizers + grad_norm = clip_grad_norm_while_handling_failing_dtensor_cases( + [p for p in self.transformer.parameters() if p.requires_grad], + args.max_grad_norm if args.max_grad_norm is not None else 0.0) + + self.optimizer.step() + self.lr_scheduler.step() + + if grad_norm is None: + grad_value = 0.0 + else: + try: + if isinstance(grad_norm, torch.Tensor): + grad_value = float(grad_norm.detach().float().item()) + else: + grad_value = float(grad_norm) + except Exception: + grad_value = 0.0 + training_batch.grad_norm = grad_value + B, S, T, C, H, W = traj_latents.shape + training_batch.raw_latent_shape = (B, C, T, H, W) + return training_batch + + def _prepare_validation_batch(self, sampling_param: SamplingParam, + training_args: TrainingArgs, + validation_batch: dict[str, Any], + num_inference_steps: int) -> ForwardBatch: + sampling_param.prompt = validation_batch['prompt'] + sampling_param.height = training_args.num_height + sampling_param.width = training_args.num_width + sampling_param.image_path = validation_batch.get( + 'image_path') or validation_batch.get('video_path') + sampling_param.num_inference_steps = num_inference_steps + sampling_param.data_type = "video" + assert self.seed is not None + sampling_param.seed = self.seed + + latents_size = [(sampling_param.num_frames - 1) // 4 + 1, + sampling_param.height // 8, sampling_param.width // 8] + n_tokens = latents_size[0] * latents_size[1] * latents_size[2] + temporal_compression_factor = 
training_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio + num_frames = (training_args.num_latent_t - + 1) * temporal_compression_factor + 1 + sampling_param.num_frames = num_frames + batch = ForwardBatch( + **shallow_asdict(sampling_param), + latents=None, + generator=torch.Generator(device="cpu").manual_seed(self.seed), + n_tokens=n_tokens, + eta=0.0, + VSA_sparsity=training_args.VSA_sparsity, + ) + if "image" in validation_batch and validation_batch["image"] is not None: + batch.pil_image = validation_batch["image"] + + if "keyboard_cond" in validation_batch and validation_batch[ + "keyboard_cond"] is not None: + keyboard_cond = validation_batch["keyboard_cond"] + keyboard_cond = torch.tensor(keyboard_cond, dtype=torch.bfloat16) + keyboard_cond = keyboard_cond.unsqueeze(0) + batch.keyboard_cond = keyboard_cond + + if "mouse_cond" in validation_batch and validation_batch[ + "mouse_cond"] is not None: + mouse_cond = validation_batch["mouse_cond"] + mouse_cond = torch.tensor(mouse_cond, dtype=torch.bfloat16) + mouse_cond = mouse_cond.unsqueeze(0) + batch.mouse_cond = mouse_cond + + return batch + + def _post_process_validation_frames( + self, frames: list[np.ndarray], + batch: ForwardBatch) -> list[np.ndarray]: + """Apply action overlay to validation frames for WanGame. + + Draws keyboard (WASD) and mouse (pitch/yaw) indicators on the video frames. 
+ """ + # Check if action data is available + keyboard_cond = getattr(batch, 'keyboard_cond', None) + mouse_cond = getattr(batch, 'mouse_cond', None) + + if keyboard_cond is None and mouse_cond is None: + return frames + + # Import overlay functions + from fastvideo.models.dits.matrixgame.utils import (draw_keys_on_frame, + draw_mouse_on_frame) + + # Convert tensors to numpy if needed (bfloat16 -> float32 -> numpy) + if keyboard_cond is not None: + keyboard_cond = keyboard_cond.squeeze( + 0).cpu().float().numpy() # (T, 6) + if mouse_cond is not None: + mouse_cond = mouse_cond.squeeze(0).cpu().float().numpy() # (T, 2) + + # MatrixGame convention: keyboard [W, S, A, D, left, right], mouse [Pitch, Yaw] + key_names = ["W", "S", "A", "D", "left", "right"] + + processed_frames = [] + for frame_idx, frame in enumerate(frames): + frame = np.ascontiguousarray(frame.copy()) + + # Draw keyboard overlay + if keyboard_cond is not None and frame_idx < len(keyboard_cond): + keys = { + key_names[i]: bool(keyboard_cond[frame_idx, i]) + for i in range(min(len(key_names), keyboard_cond.shape[1])) + } + draw_keys_on_frame(frame, keys, mode='universal') + + # Draw mouse overlay + if mouse_cond is not None and frame_idx < len(mouse_cond): + pitch = float(mouse_cond[frame_idx, 0]) + yaw = float(mouse_cond[frame_idx, 1]) + draw_mouse_on_frame(frame, pitch, yaw) + + processed_frames.append(frame) + + return processed_frames + + + def visualize_intermediate_latents(self, training_batch: TrainingBatch, + training_args: TrainingArgs, step: int): + tracker_loss_dict: dict[str, Any] = {} + latents_vis_dict = training_batch.latent_vis_dict + latent_log_keys = ['noisy_input', 'x0', 'pred_video'] + for latent_key in latent_log_keys: + assert latent_key in latents_vis_dict and latents_vis_dict[ + latent_key] is not None + latent = latents_vis_dict[latent_key] + pixel_latent = self.decoding_stage.decode(latent, training_args) + + video = pixel_latent.cpu().float() + video = video.permute(0, 2, 1, 3, 
4) + video = (video * 255).numpy().astype(np.uint8) + + keyboard_cond = getattr(training_batch, "keyboard_cond", None) + mouse_cond = getattr(training_batch, "mouse_cond", None) + for batch_idx in range(video.shape[0]): + sample_batch = type("ValidationBatch", (), {})() + if keyboard_cond is not None and batch_idx < keyboard_cond.shape[0]: + sample_batch.keyboard_cond = keyboard_cond[batch_idx:batch_idx + 1] + if mouse_cond is not None and batch_idx < mouse_cond.shape[0]: + sample_batch.mouse_cond = mouse_cond[batch_idx:batch_idx + 1] + + video_frames = [ + np.transpose(video[batch_idx, frame_idx], (1, 2, 0)) + for frame_idx in range(video.shape[1]) + ] + video_frames = self._post_process_validation_frames( + video_frames, cast(ForwardBatch, sample_batch)) + video[batch_idx] = np.stack([ + np.transpose(frame, (2, 0, 1)) for frame in video_frames + ], axis=0) + + video_artifact = self.tracker.video( + video, fps=16, format="mp4") # change to 16 for Wan2.1 + if video_artifact is not None: + tracker_loss_dict[latent_key] = video_artifact + # Clean up references + del video, pixel_latent, latent + + if self.global_rank == 0 and tracker_loss_dict: + self.tracker.log_artifacts(tracker_loss_dict, step) + + +def main(args) -> None: + logger.info("Starting ODE-init training pipeline...") + pipeline = WanGameODEInitTrainingPipeline.from_pretrained( + args.pretrained_model_name_or_path, args=args) + args = pipeline.training_args + pipeline.train() + logger.info("ODE-init training pipeline done") + + +if __name__ == "__main__": + argv = sys.argv + from fastvideo.fastvideo_args import TrainingArgs + from fastvideo.utils import FlexibleArgumentParser + parser = FlexibleArgumentParser() + parser = TrainingArgs.add_cli_args(parser) + parser = FastVideoArgs.add_cli_args(parser) + args = parser.parse_args() + args.dit_cpu_offload = False + main(args) diff --git a/fastvideo/training/wangame_self_forcing_distillation_pipeline.py 
b/fastvideo/training/wangame_self_forcing_distillation_pipeline.py new file mode 100644 index 000000000..0325826ac --- /dev/null +++ b/fastvideo/training/wangame_self_forcing_distillation_pipeline.py @@ -0,0 +1,952 @@ +# SPDX-License-Identifier: Apache-2.0 +import sys +from copy import deepcopy +from typing import Any + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +from einops import rearrange + +from fastvideo.configs.sample import SamplingParam +from fastvideo.dataset.dataloader.schema import ( + pyarrow_schema_ode_trajectory_wangame) +from fastvideo.distributed import get_local_torch_device +from fastvideo.fastvideo_args import FastVideoArgs, TrainingArgs +from fastvideo.forward_context import set_forward_context +from fastvideo.logger import init_logger +from fastvideo.models.dits.hyworld.pose import process_custom_actions +from fastvideo.models.schedulers.scheduling_self_forcing_flow_match import ( + SelfForcingFlowMatchScheduler) +from fastvideo.models.utils import pred_noise_to_pred_video +from fastvideo.pipelines.basic.wan.wangame_causal_dmd_pipeline import ( + WanGameCausalDMDPipeline) +from fastvideo.pipelines.pipeline_batch_info import ForwardBatch, TrainingBatch +from fastvideo.training.self_forcing_distillation_pipeline import ( + SelfForcingDistillationPipeline) +from fastvideo.training.training_utils import shift_timestep +from fastvideo.utils import is_vsa_available, shallow_asdict + +vsa_available = is_vsa_available() + +logger = init_logger(__name__) + + +class WanGameSelfForcingDistillationPipeline(SelfForcingDistillationPipeline): + """ + A self-forcing distillation pipeline for WanGame that uses the self-forcing methodology + with DMD for video generation. 
+ """ + _required_config_modules = [ + "scheduler", + "transformer", + "vae", + ] + + def set_schemas(self): + self.train_dataset_schema = pyarrow_schema_ode_trajectory_wangame + + def _initialize_simulation_caches( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + *, + max_num_frames: int | None = None, + ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + """Initialize KV cache and cross-attention cache for multi-step simulation.""" + num_transformer_blocks = len(self.transformer.blocks) + latent_shape = self.video_latent_shape_sp + _, num_frames, _, height, width = latent_shape + + _, p_h, p_w = self.transformer.patch_size + post_patch_height = height // p_h + post_patch_width = width // p_w + + frame_seq_length = post_patch_height * post_patch_width + self.frame_seq_length = frame_seq_length + + # Get model configuration parameters - handle FSDP wrapping + num_attention_heads = getattr(self.transformer, 'num_attention_heads', + None) + attention_head_dim = getattr(self.transformer, 'attention_head_dim', + None) + + # 1 CLS token + 256 patch tokens = 257 + text_len = 257 + + if max_num_frames is None: + max_num_frames = num_frames + num_max_frames = max(max_num_frames, num_frames) + kv_cache_size = num_max_frames * frame_seq_length + # WanGame causal attention stores both RoPE and PRoPE branches in cache. 
+ cache_head_dim = attention_head_dim * 2 + + kv_cache = [] + for _ in range(num_transformer_blocks): + kv_cache.append({ + "k": + torch.zeros([ + batch_size, kv_cache_size, num_attention_heads, + cache_head_dim + ], + dtype=dtype, + device=device), + "v": + torch.zeros([ + batch_size, kv_cache_size, num_attention_heads, + cache_head_dim + ], + dtype=dtype, + device=device), + "global_end_index": + torch.tensor([0], dtype=torch.long, device=device), + "local_end_index": + torch.tensor([0], dtype=torch.long, device=device) + }) + + # Initialize cross-attention cache + crossattn_cache = [] + for _ in range(num_transformer_blocks): + crossattn_cache.append({"is_init": False}) + + return kv_cache, crossattn_cache + + def _reset_simulation_caches( + self, kv_cache: list[dict[str, + Any]], crossattn_cache: list[dict[str, + Any]] + ) -> None: + """Reset KV cache and cross-attention cache to clean state.""" + if kv_cache is not None: + for cache_dict in kv_cache: + cache_dict["global_end_index"].fill_(0) + cache_dict["local_end_index"].fill_(0) + cache_dict["k"].zero_() + cache_dict["v"].zero_() + + if crossattn_cache is not None: + for cache_dict in crossattn_cache: + cache_dict["is_init"] = False + if "k" in cache_dict: + cache_dict["k"].zero_() + if "v" in cache_dict: + cache_dict["v"].zero_() + + def _generator_multi_step_simulation_forward( + self, + training_batch: TrainingBatch, + return_sim_steps: bool = False) -> torch.Tensor: + """Forward pass through student transformer matching inference procedure with KV cache management. + + This function is adapted from the reference self-forcing implementation's inference_with_trajectory + and includes gradient masking logic for dynamic frame generation. 
+ """ + latents = training_batch.latents + dtype = latents.dtype + batch_size = latents.shape[0] + initial_latent = getattr(training_batch, 'image_latent', None) + + # Dynamic frame generation logic (adapted from _run_generator) + num_training_frames = getattr(self.training_args, 'num_latent_t', 21) + + # During training, the number of generated frames should be uniformly sampled from + # [21, self.num_training_frames], but still being a multiple of self.num_frame_per_block + min_num_frames = 20 if self.independent_first_frame else 21 + max_num_frames = num_training_frames - 1 if self.independent_first_frame else num_training_frames + assert max_num_frames % self.num_frame_per_block == 0 + assert min_num_frames % self.num_frame_per_block == 0 + max_num_blocks = max_num_frames // self.num_frame_per_block + min_num_blocks = min_num_frames // self.num_frame_per_block + + # Sample number of blocks and sync across processes + num_generated_blocks = torch.randint(min_num_blocks, + max_num_blocks + 1, (1, ), + device=self.device) + if dist.is_initialized(): + dist.broadcast(num_generated_blocks, src=0) + num_generated_blocks = num_generated_blocks.item() + num_generated_frames = num_generated_blocks * self.num_frame_per_block + if self.independent_first_frame and initial_latent is None: + num_generated_frames += 1 + min_num_frames += 1 + + # Create noise with dynamic shape + if initial_latent is not None: + noise_shape = [ + batch_size, num_generated_frames - 1, + *self.video_latent_shape[2:] + ] + else: + noise_shape = [ + batch_size, num_generated_frames, *self.video_latent_shape[2:] + ] + + noise = torch.randn(noise_shape, device=self.device, dtype=dtype) + if self.sp_world_size > 1: + noise = rearrange(noise, + "b (n t) c h w -> b n t c h w", + n=self.sp_world_size).contiguous() + noise = noise[:, self.rank_in_sp_group, :, :, :, :] + + batch_size, num_frames, num_channels, height, width = noise.shape + + # Block size calculation + if not self.independent_first_frame 
or (self.independent_first_frame + and initial_latent is not None): + assert num_frames % self.num_frame_per_block == 0 + num_blocks = num_frames // self.num_frame_per_block + else: + assert (num_frames - 1) % self.num_frame_per_block == 0 + num_blocks = (num_frames - 1) // self.num_frame_per_block + + num_input_frames = initial_latent.shape[ + 1] if initial_latent is not None else 0 + num_output_frames = num_frames + num_input_frames + output = torch.zeros( + [batch_size, num_output_frames, num_channels, height, width], + device=noise.device, + dtype=noise.dtype) + + def get_model_device(model): + if model is None: + return "None" + try: + return next(model.parameters()).device + except (StopIteration, AttributeError): + return "Unknown" + + # Step 1: Initialize KV cache to all zeros + cache_frames = num_generated_frames + num_input_frames + (self.kv_cache1, + self.crossattn_cache) = self._initialize_simulation_caches( + batch_size, dtype, self.device, max_num_frames=cache_frames) + + # Step 2: Cache context feature + current_start_frame = 0 + if initial_latent is not None: + timestep = torch.ones( + [batch_size, 1], device=noise.device, dtype=torch.int64) * 0 + output[:, :1] = initial_latent + with torch.no_grad(): + # Build input kwargs for initial latent + training_batch_temp = self._build_distill_input_kwargs( + initial_latent, + timestep * 0, + training_batch.conditional_dict, + training_batch, + frame_start=0, + frame_end=1, + num_frame_per_block=1) + + # we process the image latent with self.transformer_2 (low-noise expert) + current_model = self.transformer_2 if self.transformer_2 is not None else self.transformer + current_model( + **training_batch_temp.input_kwargs, + kv_cache=self.kv_cache1, + crossattn_cache=self.crossattn_cache, + current_start=current_start_frame * self.frame_seq_length, + start_frame=current_start_frame) + current_start_frame += 1 + + # Step 3: Temporal denoising loop + all_num_frames = [self.num_frame_per_block] * num_blocks + if 
self.independent_first_frame and initial_latent is None: + all_num_frames = [1] + all_num_frames + num_denoising_steps = len(self.denoising_step_list) + exit_flags = self.generate_and_sync_list(len(all_num_frames), + num_denoising_steps, + device=noise.device) + start_gradient_frame_index = max(0, num_output_frames - 21) + + for block_index, current_num_frames in enumerate(all_num_frames): + noisy_input = noise[:, current_start_frame - + num_input_frames:current_start_frame + + current_num_frames - num_input_frames] + + # Step 3.1: Spatial denoising loop + for index, current_timestep in enumerate(self.denoising_step_list): + if self.same_step_across_blocks: + exit_flag = (index == exit_flags[0]) + else: + exit_flag = (index == exit_flags[block_index]) + + timestep = torch.ones([batch_size, current_num_frames], + device=noise.device, + dtype=torch.int64) * current_timestep + + if self.boundary_timestep is not None and current_timestep < self.boundary_timestep and self.transformer_2 is not None: + current_model = self.transformer_2 + else: + current_model = self.transformer + + if not exit_flag: + with torch.no_grad(): + # Build input kwargs + training_batch_temp = self._build_distill_input_kwargs( + noisy_input, + timestep, + training_batch.conditional_dict, + training_batch, + frame_start=current_start_frame, + frame_end=current_start_frame + current_num_frames, + num_frame_per_block=current_num_frames) + + pred_flow = current_model( + **training_batch_temp.input_kwargs, + kv_cache=self.kv_cache1, + crossattn_cache=self.crossattn_cache, + current_start=current_start_frame * + self.frame_seq_length, + start_frame=current_start_frame).permute( + 0, 2, 1, 3, 4) + + denoised_pred = pred_noise_to_pred_video( + pred_noise=pred_flow.flatten(0, 1), + noise_input_latent=noisy_input.flatten(0, 1), + timestep=timestep, + scheduler=self.noise_scheduler).unflatten( + 0, pred_flow.shape[:2]) + + next_timestep = self.denoising_step_list[index + 1] + noisy_input = 
self.noise_scheduler.add_noise( + denoised_pred.flatten(0, 1), + torch.randn_like(denoised_pred.flatten(0, 1)), + next_timestep * + torch.ones([batch_size * current_num_frames], + device=noise.device, + dtype=torch.long)).unflatten( + 0, denoised_pred.shape[:2]) + else: + # Final prediction with gradient control + if current_start_frame < start_gradient_frame_index: + with torch.no_grad(): + training_batch_temp = self._build_distill_input_kwargs( + noisy_input, + timestep, + training_batch.conditional_dict, + training_batch, + frame_start=current_start_frame, + frame_end=current_start_frame + + current_num_frames, + num_frame_per_block=current_num_frames) + + pred_flow = current_model( + **training_batch_temp.input_kwargs, + kv_cache=self.kv_cache1, + crossattn_cache=self.crossattn_cache, + current_start=current_start_frame * + self.frame_seq_length, + start_frame=current_start_frame).permute( + 0, 2, 1, 3, 4) + else: + training_batch_temp = self._build_distill_input_kwargs( + noisy_input, + timestep, + training_batch.conditional_dict, + training_batch, + frame_start=current_start_frame, + frame_end=current_start_frame + current_num_frames, + num_frame_per_block=current_num_frames) + + pred_flow = current_model( + **training_batch_temp.input_kwargs, + kv_cache=self.kv_cache1, + crossattn_cache=self.crossattn_cache, + current_start=current_start_frame * + self.frame_seq_length, + start_frame=current_start_frame).permute( + 0, 2, 1, 3, 4) + + denoised_pred = pred_noise_to_pred_video( + pred_noise=pred_flow.flatten(0, 1), + noise_input_latent=noisy_input.flatten(0, 1), + timestep=timestep, + scheduler=self.noise_scheduler).unflatten( + 0, pred_flow.shape[:2]) + break + + # Step 3.2: record the model's output + output[:, current_start_frame:current_start_frame + + current_num_frames] = denoised_pred + + # Step 3.3: rerun with timestep zero to update the cache + context_timestep = torch.ones_like(timestep) * self.context_noise + denoised_pred = 
self.noise_scheduler.add_noise( + denoised_pred.flatten(0, 1), + torch.randn_like(denoised_pred.flatten(0, 1)), + context_timestep).unflatten(0, denoised_pred.shape[:2]) + + with torch.no_grad(): + training_batch_temp = self._build_distill_input_kwargs( + denoised_pred, + context_timestep, + training_batch.conditional_dict, + training_batch, + frame_start=current_start_frame, + frame_end=current_start_frame + current_num_frames, + num_frame_per_block=current_num_frames) + + # context_timestep is 0 so we use transformer_2 + current_model = self.transformer_2 if self.transformer_2 is not None else self.transformer + current_model( + **training_batch_temp.input_kwargs, + kv_cache=self.kv_cache1, + crossattn_cache=self.crossattn_cache, + current_start=current_start_frame * self.frame_seq_length, + start_frame=current_start_frame) + + # Step 3.4: update the start and end frame indices + current_start_frame += current_num_frames + + # Handle last 21 frames logic + pred_image_or_video = output + if num_input_frames > 0: + pred_image_or_video = output[:, num_input_frames:] + + # Slice last 21 frames if we generated more + gradient_mask = None + if pred_image_or_video.shape[1] > 21: + with torch.no_grad(): + # Re-encode to get image latent + latent_to_decode = pred_image_or_video[:, :-20, ...] 
+ # Decode to video + latent_to_decode = latent_to_decode.permute( + 0, 2, 1, 3, 4) # [B, C, F, H, W] + + # Apply VAE scaling and shift factors + if isinstance(self.vae.scaling_factor, torch.Tensor): + latent_to_decode = latent_to_decode / self.vae.scaling_factor.to( + latent_to_decode.device, latent_to_decode.dtype) + else: + latent_to_decode = latent_to_decode / self.vae.scaling_factor + + if hasattr( + self.vae, + "shift_factor") and self.vae.shift_factor is not None: + if isinstance(self.vae.shift_factor, torch.Tensor): + latent_to_decode += self.vae.shift_factor.to( + latent_to_decode.device, latent_to_decode.dtype) + else: + latent_to_decode += self.vae.shift_factor + + # Decode to pixels + pixels = self.vae.decode(latent_to_decode) + frame = pixels[:, :, -1:, :, :].to( + dtype) # Last frame [B, C, 1, H, W] + + # Encode frame back to get image latent + image_latent = self.vae.encode(frame).to(dtype) + image_latent = image_latent.permute(0, 2, 1, 3, + 4) # [B, F, C, H, W] + + pred_image_or_video_last_21 = torch.cat( + [image_latent, pred_image_or_video[:, -20:, ...]], dim=1) + else: + pred_image_or_video_last_21 = pred_image_or_video + + # Set up gradient mask if we generated more than minimum frames + if num_generated_frames != min_num_frames: + # Currently, we do not use gradient for the first chunk, since it contains image latents + gradient_mask = torch.ones_like(pred_image_or_video_last_21, + dtype=torch.bool) + if self.independent_first_frame: + gradient_mask[:, :1] = False + else: + gradient_mask[:, :self.num_frame_per_block] = False + + # Apply gradient masking if needed + final_output = pred_image_or_video_last_21.to(dtype) + if gradient_mask is not None: + # Apply gradient masking: detach frames that shouldn't contribute gradients + final_output = torch.where( + gradient_mask, + pred_image_or_video_last_21, # Keep original values where gradient_mask is True + pred_image_or_video_last_21.detach( + ) # Detach where gradient_mask is False + ) + + # 
Store visualization data + training_batch.dmd_latent_vis_dict["generator_timestep"] = torch.tensor( + self.denoising_step_list[exit_flags[0]], + dtype=torch.float32, + device=self.device) + + # Store gradient mask information for debugging + if gradient_mask is not None: + training_batch.dmd_latent_vis_dict[ + "gradient_mask"] = gradient_mask.float() + training_batch.dmd_latent_vis_dict[ + "num_generated_frames"] = torch.tensor(num_generated_frames, + dtype=torch.float32, + device=self.device) + training_batch.dmd_latent_vis_dict["min_num_frames"] = torch.tensor( + min_num_frames, dtype=torch.float32, device=self.device) + + # Clean up caches + assert self.kv_cache1 is not None + assert self.crossattn_cache is not None + self._reset_simulation_caches(self.kv_cache1, self.crossattn_cache) + + return final_output if gradient_mask is not None else pred_image_or_video + + def initialize_validation_pipeline(self, training_args: TrainingArgs): + logger.info("Initializing validation pipeline...") + args_copy = deepcopy(training_args) + args_copy.inference_mode = True + # Use the same flow-matching scheduler as training for consistent validation. 
+ validation_scheduler = SelfForcingFlowMatchScheduler( + shift=args_copy.pipeline_config.flow_shift, + sigma_min=0.0, + extra_one_step=True) + validation_scheduler.set_timesteps(num_inference_steps=1000, + training=True) + # Warm start validation with current transformer + self.validation_pipeline = WanGameCausalDMDPipeline.from_pretrained( + training_args.model_path, + args=args_copy, # type: ignore + inference_mode=True, + loaded_modules={ + "transformer": self.get_module("transformer"), + "vae": self.get_module("vae"), + "scheduler": validation_scheduler, + }, + tp_size=training_args.tp_size, + sp_size=training_args.sp_size, + num_gpus=training_args.num_gpus, + pin_cpu_memory=training_args.pin_cpu_memory, + dit_cpu_offload=True) + + def _get_next_batch(self, training_batch: TrainingBatch) -> TrainingBatch: + batch = next(self.train_loader_iter, None) # type: ignore + if batch is None: + self.current_epoch += 1 + # Reset iterator for next epoch + self.train_loader_iter = iter(self.train_dataloader) + # Get first batch of new epoch + batch = next(self.train_loader_iter) + + clip_feature = batch['clip_feature'] + first_frame_latent = batch['first_frame_latent'] + keyboard_cond = batch.get('keyboard_cond', None) + mouse_cond = batch.get('mouse_cond', None) + infos = batch['info_list'] + + batch_size = clip_feature.shape[0] + vae_config = self.training_args.pipeline_config.vae_config.arch_config + num_channels = vae_config.z_dim + spatial_compression_ratio = vae_config.spatial_compression_ratio + + latent_height = self.training_args.num_height // spatial_compression_ratio + latent_width = self.training_args.num_width // spatial_compression_ratio + + latents = torch.randn(batch_size, num_channels, + self.training_args.num_latent_t, latent_height, + latent_width).to(get_local_torch_device(), + dtype=torch.bfloat16) + + training_batch.latents = latents.to(get_local_torch_device(), + dtype=torch.bfloat16) + training_batch.encoder_hidden_states = None + 
training_batch.encoder_attention_mask = None + training_batch.image_embeds = clip_feature.to(get_local_torch_device(), + dtype=torch.bfloat16) + training_batch.image_latents = first_frame_latent.to( + get_local_torch_device(), dtype=torch.bfloat16) + # Action conditioning + if keyboard_cond is not None and keyboard_cond.numel() > 0: + keyboard_cond_full = keyboard_cond.to(get_local_torch_device(), + dtype=torch.bfloat16) + training_batch.keyboard_cond = keyboard_cond_full # For Teacher/Critic (dim=6) + else: + training_batch.keyboard_cond = None + if mouse_cond is not None and mouse_cond.numel() > 0: + training_batch.mouse_cond = mouse_cond.to(get_local_torch_device(), + dtype=torch.bfloat16) + else: + training_batch.mouse_cond = None + training_batch.infos = infos + return training_batch + + def _prepare_dit_inputs(self, + training_batch: TrainingBatch) -> TrainingBatch: + """Override to properly handle I2V concatenation - call parent first, then concatenate image conditioning.""" + # First, call parent method to prepare noise, timesteps, etc. for video latents + training_batch = super()._prepare_dit_inputs(training_batch) + + assert isinstance(training_batch.image_latents, torch.Tensor) + image_latents = training_batch.image_latents.to( + get_local_torch_device(), dtype=torch.bfloat16) + + # cond_concat = [mask(4), image_latent(16)] with 20 channels. + expected_cond_channels = 20 + if image_latents.shape[1] != expected_cond_channels: + raise ValueError( + "Unexpected first_frame_latent channels, " + "Expected {expected_cond_channels} (cond_concat), got {image_latents.shape[1]}." + ) + + if self.sp_world_size > 1: + total_frames = image_latents.shape[2] + # Split cond latents to local SP shard only when tensor is still global. 
+ if total_frames == self.training_args.num_latent_t: + if total_frames % self.sp_world_size != 0: + raise ValueError( + "image_latents temporal dim is not divisible by SP world size: " + f"frames={total_frames}, sp_world_size={self.sp_world_size}" + ) + image_latents = rearrange(image_latents, + "b c (n t) h w -> b c n t h w", + n=self.sp_world_size).contiguous() + image_latents = image_latents[:, :, self.rank_in_sp_group, :, :, + :] + + training_batch.image_latents = image_latents + + return training_batch + + def _build_distill_input_kwargs( + self, + noise_input: torch.Tensor, + timestep: torch.Tensor, + text_dict: dict[str, torch.Tensor] | None, + training_batch: TrainingBatch, + frame_start: int | None = None, + frame_end: int | None = None, + num_frame_per_block: int | None = None) -> TrainingBatch: + # Image Embeds for conditioning + image_embeds = training_batch.image_embeds + assert torch.isnan(image_embeds).sum() == 0 + image_embeds = image_embeds.to(get_local_torch_device(), + dtype=torch.bfloat16) + + image_latents = training_batch.image_latents + if frame_start is not None and frame_end is not None: + image_latents = image_latents[:, :, frame_start:frame_end, :, :] + + vae_temporal_compression_ratio = 4 + if frame_start is not None and frame_end is not None: + action_frame_start = frame_start * vae_temporal_compression_ratio + action_frame_end = (frame_end - + 1) * vae_temporal_compression_ratio + 1 + if frame_start == 0: + action_frame_start = 0 + keyboard_cond_sliced = training_batch.keyboard_cond[:, + action_frame_start:action_frame_end, :] if training_batch.keyboard_cond is not None else None + mouse_cond_sliced = training_batch.mouse_cond[:, + action_frame_start:action_frame_end, :] if training_batch.mouse_cond is not None else None + else: + keyboard_cond_sliced = training_batch.keyboard_cond + mouse_cond_sliced = training_batch.mouse_cond + + if keyboard_cond_sliced is not None and mouse_cond_sliced is not None: + viewmats_list = [] + 
intrinsics_list = [] + action_labels_list = [] + for b in range(noise_input.shape[0]): + viewmats, intrinsics, action_labels = process_custom_actions( + keyboard_cond_sliced[b], mouse_cond_sliced[b]) + viewmats_list.append(viewmats) + intrinsics_list.append(intrinsics) + action_labels_list.append(action_labels) + + viewmats = torch.stack(viewmats_list, dim=0).to( + device=get_local_torch_device(), dtype=torch.bfloat16) + intrinsics = torch.stack(intrinsics_list, dim=0).to( + device=get_local_torch_device(), dtype=torch.bfloat16) + action_labels = torch.stack(action_labels_list, dim=0).to( + device=get_local_torch_device(), dtype=torch.bfloat16) + else: + viewmats = None + intrinsics = None + action_labels = None + + noisy_model_input = torch.cat( + [noise_input, image_latents.permute(0, 2, 1, 3, 4)], dim=2) + + training_batch.input_kwargs = { + "hidden_states": noisy_model_input.permute(0, 2, 1, 3, + 4), # bs, c, t, h, w + "encoder_hidden_states": None, + "timestep": timestep, + "encoder_hidden_states_image": image_embeds, + "viewmats": viewmats, + "Ks": intrinsics, + "action": action_labels, + "num_frame_per_block": num_frame_per_block if num_frame_per_block is not None else self.num_frame_per_block, + } + training_batch.noise_latents = noise_input + + return training_batch + + def _dmd_forward(self, generator_pred_video: torch.Tensor, + training_batch: TrainingBatch) -> torch.Tensor: + """Compute DMD (Diffusion Model Distillation) loss for WanGame.""" + original_latent = generator_pred_video + with torch.no_grad(): + timestep = torch.randint(0, + self.num_train_timestep, [1], + device=self.device, + dtype=torch.long) + + timestep = shift_timestep(timestep, self.timestep_shift, + self.num_train_timestep) + + timestep = timestep.clamp(self.min_timestep, self.max_timestep) + + noise = torch.randn(self.video_latent_shape, + device=self.device, + dtype=generator_pred_video.dtype) + + noisy_latent = self.noise_scheduler.add_noise( + generator_pred_video.flatten(0, 1), 
noise.flatten(0, 1), + timestep).detach().unflatten(0, (generator_pred_video.shape[0], + generator_pred_video.shape[1])) + + # Non-causal models expect 1D timestep (batch_size,) + critic_timestep = timestep.expand(noisy_latent.shape[0]) + + self._build_distill_input_kwargs( + noisy_latent, critic_timestep, None, training_batch + ) + + # fake_score_transformer forward + current_fake_score_transformer = self._get_fake_score_transformer( + timestep) + fake_score_pred_noise = current_fake_score_transformer( + **training_batch.input_kwargs + ).permute(0, 2, 1, 3, 4) + + faker_score_pred_video = pred_noise_to_pred_video( + pred_noise=fake_score_pred_noise.flatten(0, 1), + noise_input_latent=noisy_latent.flatten(0, 1), + timestep=timestep, + scheduler=self.noise_scheduler).unflatten( + 0, fake_score_pred_noise.shape[:2]) + + # real_score_transformer forward + current_real_score_transformer = self._get_real_score_transformer( + timestep) + real_score_pred_noise = current_real_score_transformer( + **training_batch.input_kwargs + ).permute(0, 2, 1, 3, 4) + + real_score_pred_video = pred_noise_to_pred_video( + pred_noise=real_score_pred_noise.flatten(0, 1), + noise_input_latent=noisy_latent.flatten(0, 1), + timestep=timestep, + scheduler=self.noise_scheduler).unflatten( + 0, real_score_pred_noise.shape[:2]) + + # No CFG for WanGame - use real_score_pred_video directly + grad = (faker_score_pred_video - real_score_pred_video) / torch.abs( + original_latent - real_score_pred_video).mean() + grad = torch.nan_to_num(grad) + + dmd_loss = 0.5 * F.mse_loss( + original_latent.float(), + (original_latent.float() - grad.float()).detach()) + + training_batch.dmd_latent_vis_dict.update({ + "training_batch_dmd_fwd_clean_latent": + training_batch.latents, + "generator_pred_video": + original_latent.detach(), + "real_score_pred_video": + real_score_pred_video.detach(), + "faker_score_pred_video": + faker_score_pred_video.detach(), + "dmd_timestep": + timestep.detach(), + }) + + return 
dmd_loss + + def faker_score_forward( + self, training_batch: TrainingBatch + ) -> tuple[TrainingBatch, torch.Tensor]: + """Forward pass for critic training with WanGame action conditioning.""" + with torch.no_grad(), set_forward_context( + current_timestep=training_batch.timesteps, + attn_metadata=training_batch.attn_metadata_vsa): + if self.training_args.simulate_generator_forward: + generator_pred_video = self._generator_multi_step_simulation_forward( + training_batch) + else: + generator_pred_video = self._generator_forward(training_batch) + + fake_score_timestep = torch.randint(0, + self.num_train_timestep, [1], + device=self.device, + dtype=torch.long) + + fake_score_timestep = shift_timestep(fake_score_timestep, + self.timestep_shift, + self.num_train_timestep) + + fake_score_timestep = fake_score_timestep.clamp(self.min_timestep, + self.max_timestep) + + fake_score_noise = torch.randn(self.video_latent_shape, + device=self.device, + dtype=generator_pred_video.dtype) + + noisy_generator_pred_video = self.noise_scheduler.add_noise( + generator_pred_video.flatten(0, 1), + fake_score_noise.flatten(0, 1), fake_score_timestep).unflatten( + 0, + (generator_pred_video.shape[0], generator_pred_video.shape[1])) + + # Non-causal critic expects 1D timestep (batch_size,), not 2D (batch_size, num_frames). 
def _prepare_validation_batch(self, sampling_param: SamplingParam,
                              training_args: TrainingArgs,
                              validation_batch: dict[str, Any],
                              num_inference_steps: int) -> ForwardBatch:
    """Build the ForwardBatch used to run one validation sample.

    ``sampling_param`` is mutated in place (prompt, size, image path, step
    count, seed, frame count) and then expanded into the batch. Optional
    ``image``, ``keyboard_cond`` and ``mouse_cond`` entries of
    ``validation_batch`` are attached when present and not None; the two
    action entries are converted to bfloat16 and given a leading dim of 1.
    """
    sampling_param.prompt = validation_batch['prompt']
    sampling_param.height = training_args.num_height
    sampling_param.width = training_args.num_width
    sampling_param.image_path = validation_batch.get(
        'image_path') or validation_batch.get('video_path')
    sampling_param.num_inference_steps = num_inference_steps
    sampling_param.data_type = "video"
    assert self.seed is not None
    sampling_param.seed = self.seed

    # NOTE(review): the token count is derived from the frame count that
    # sampling_param carried on entry, i.e. before it is overwritten just
    # below -- confirm this is intended.
    latent_t = (sampling_param.num_frames - 1) // 4 + 1
    latent_h = sampling_param.height // 8
    latent_w = sampling_param.width // 8
    n_tokens = latent_t * latent_h * latent_w

    vae_arch = training_args.pipeline_config.vae_config.arch_config
    sampling_param.num_frames = (
        (training_args.num_latent_t - 1) * vae_arch.temporal_compression_ratio
        + 1)

    batch = ForwardBatch(
        **shallow_asdict(sampling_param),
        latents=None,
        generator=torch.Generator(device="cpu").manual_seed(self.seed),
        n_tokens=n_tokens,
        eta=0.0,
        VSA_sparsity=training_args.VSA_sparsity,
    )

    reference_image = validation_batch.get("image")
    if reference_image is not None:
        batch.pil_image = reference_image

    # Keyboard first, then mouse; both follow the same conversion.
    for cond_key in ("keyboard_cond", "mouse_cond"):
        raw_cond = validation_batch.get(cond_key)
        if raw_cond is None:
            continue
        if isinstance(raw_cond, torch.Tensor):
            cond = raw_cond.detach().clone().to(dtype=torch.bfloat16)
        else:
            cond = torch.tensor(raw_cond, dtype=torch.bfloat16)
        setattr(batch, cond_key, cond.unsqueeze(0))

    return batch
+ """ + # Check if action data is available + keyboard_cond = getattr(batch, 'keyboard_cond', None) + mouse_cond = getattr(batch, 'mouse_cond', None) + + if keyboard_cond is None and mouse_cond is None: + return frames + + # Import overlay functions + from fastvideo.models.dits.matrixgame.utils import (draw_keys_on_frame, + draw_mouse_on_frame) + + # Convert tensors to numpy if needed (bfloat16 -> float32 -> numpy) + if keyboard_cond is not None: + keyboard_cond = keyboard_cond.squeeze( + 0).cpu().float().numpy() # (T, 6) + if mouse_cond is not None: + mouse_cond = mouse_cond.squeeze(0).cpu().float().numpy() # (T, 2) + + # WanGame convention: keyboard [W, S, A, D, left, right], mouse [Pitch, Yaw] + key_names = ["W", "S", "A", "D", "left", "right"] + + processed_frames = [] + for frame_idx, frame in enumerate(frames): + frame = np.ascontiguousarray(frame.copy()) + + # Draw keyboard overlay + if keyboard_cond is not None and frame_idx < len(keyboard_cond): + keys = { + key_names[i]: bool(keyboard_cond[frame_idx, i]) + for i in range(min(len(key_names), keyboard_cond.shape[1])) + } + draw_keys_on_frame(frame, keys, mode='universal') + + # Draw mouse overlay + if mouse_cond is not None and frame_idx < len(mouse_cond): + pitch = float(mouse_cond[frame_idx, 0]) + yaw = float(mouse_cond[frame_idx, 1]) + draw_mouse_on_frame(frame, pitch, yaw) + + processed_frames.append(frame) + + return processed_frames + + +def main(args) -> None: + logger.info("Starting WanGame self-forcing distillation pipeline...") + + pipeline = WanGameSelfForcingDistillationPipeline.from_pretrained( + args.pretrained_model_name_or_path, args=args) + + args = pipeline.training_args + pipeline.train() + logger.info("WanGame self-forcing distillation pipeline completed") + + +if __name__ == "__main__": + argv = sys.argv + from fastvideo.fastvideo_args import TrainingArgs + from fastvideo.utils import FlexibleArgumentParser + parser = FlexibleArgumentParser() + parser = TrainingArgs.add_cli_args(parser) 
+ parser = FastVideoArgs.add_cli_args(parser) + args = parser.parse_args() + main(args) diff --git a/fastvideo/training/wangame_training_pipeline.py b/fastvideo/training/wangame_training_pipeline.py new file mode 100644 index 000000000..68c626692 --- /dev/null +++ b/fastvideo/training/wangame_training_pipeline.py @@ -0,0 +1,542 @@ +# SPDX-License-Identifier: Apache-2.0 +import sys +from typing import Any + +import numpy as np +import torch + +from fastvideo.configs.sample import SamplingParam +from fastvideo.dataset.dataloader.schema import pyarrow_schema_wangame +from fastvideo.distributed import get_local_torch_device +from fastvideo.fastvideo_args import FastVideoArgs, TrainingArgs +from fastvideo.logger import init_logger +from fastvideo.models.schedulers.scheduling_flow_unipc_multistep import ( + FlowUniPCMultistepScheduler) +from fastvideo.pipelines.basic.wan.wangame_i2v_pipeline import WanGameActionImageToVideoPipeline +from fastvideo.pipelines.pipeline_batch_info import ForwardBatch, TrainingBatch +from fastvideo.training.training_pipeline import TrainingPipeline +from fastvideo.training.training_utils import count_trainable, count_trainable_total +from fastvideo.utils import is_vsa_available, shallow_asdict + +vsa_available = is_vsa_available() + +logger = init_logger(__name__) + + +class WanGameTrainingPipeline(TrainingPipeline): + """ + A training pipeline for WanGame-2.1-Fun-1.3B-InP. + """ + _required_config_modules = ["scheduler", "transformer", "vae"] + + _FLOW_EVAL_SCALAR_KEYS = ( + "mf_epe_mean", + "mf_angle_err_mean", + "mf_cosine_mean", + "mf_mag_ratio_mean", + "pixel_epe_mean_mean", + "px_angle_rmse_mean", + "fl_all_mean", + "foe_dist_mean", + "flow_kl_2d_mean", + ) + + def initialize_pipeline(self, fastvideo_args: FastVideoArgs): + self.modules["scheduler"] = FlowUniPCMultistepScheduler( + shift=fastvideo_args.pipeline_config.flow_shift) + + def create_training_stages(self, training_args: TrainingArgs): + """ + May be used in future refactors. 
def set_trainable(self) -> None:
    """
    Override to only train newly added action-related parameters:
    - condition_embedder.action_embedder: embeds action into timestep
    - blocks.*.to_out_prope: projects PRoPE attention output

    This freezes the base model (q/k/v projections, FFN, etc.) while
    allowing the action-conditioning path to be trained.

    Action-only mode is selected by ``fastvideo_args.train_action_only``
    (missing attribute counts as False); otherwise the parent
    implementation decides what is trainable.
    """
    train_action_only = getattr(self.fastvideo_args, "train_action_only",
                                False)

    if not train_action_only:
        # Default behavior: train all parameters
        super().set_trainable()
        return

    # Freeze all transformer parameters first
    transformer = self.get_module("transformer")
    transformer.train()
    transformer.requires_grad_(False)

    # Substrings that mark a parameter name as part of the action path
    action_param_patterns = (
        "condition_embedder.action_embedder",  # Action embedding MLP
        "to_out_prope",  # PRoPE output projections in each block
    )

    # Enable gradients for action-related parameters only
    trainable_count = 0
    frozen_count = 0
    for name, param in transformer.named_parameters():
        if any(pattern in name for pattern in action_param_patterns):
            param.requires_grad_(True)
            trainable_count += 1
            logger.info("Trainable: %s (%d params)", name, param.numel())
        else:
            frozen_count += 1

    if trainable_count == 0:
        # Everything was frozen above; without a match nothing can learn.
        logger.warning(
            "Action-only training requested but no parameter name matched "
            "%s; the whole transformer is frozen", action_param_patterns)

    logger.info(
        "Action-only training: %d trainable param groups, "
        "%d frozen param groups", trainable_count, frozen_count)
# Name substrings identifying the action-conditioning parameters that the
# warmup logic below toggles.
_ACTION_PARAM_PATTERNS = [
    "condition_embedder.action_embedder",
    "to_out_prope",
]

def _set_action_params_grad(self, requires_grad: bool) -> None:
    """Toggle requires_grad for action-related parameters."""
    transformer = self.get_module("transformer")
    count = 0
    for name, param in transformer.named_parameters():
        if any(p in name for p in self._ACTION_PARAM_PATTERNS):
            param.requires_grad_(requires_grad)
            count += 1
    state = "enabled" if requires_grad else "disabled"
    logger.info("Gradients %s for %d action parameter groups", state, count)

def train_one_step(self, training_batch: TrainingBatch) -> TrainingBatch:
    """Run one training step, applying the action-module warmup schedule.

    While ``current_timestep <= action_warmup_steps`` the action parameters
    are kept frozen; they are unfrozen on the first step after warmup.
    The frozen state is tracked on the instance so that a run whose first
    executed step is inside the warmup window (e.g. after resuming) still
    freezes the action modules, instead of only reacting to step == 1.
    """
    step = training_batch.current_timestep
    warmup_steps = getattr(self.training_args, "action_warmup_steps", 0) or 0
    frozen = getattr(self, "_action_params_frozen", False)

    if warmup_steps > 0:
        if step <= warmup_steps:
            if not frozen:
                # Freeze action params on the first warmup step we execute
                self._set_action_params_grad(False)
                self._action_params_frozen = True
                local_trainable = count_trainable(self.transformer)
                total_trainable = count_trainable_total(
                    self.transformer, get_local_torch_device())
                logger.info(
                    "Action warmup: freezing action modules for the first "
                    "%d steps to stabilize base model", warmup_steps)
                logger.info(
                    "Trainable during warmup: %s B (total); this rank shard: %s B",
                    round(total_trainable / 1e9, 3),
                    round(local_trainable / 1e9, 3),
                )
        elif frozen or step == warmup_steps + 1:
            # Unfreeze action params once warmup is done
            self._set_action_params_grad(True)
            self._action_params_frozen = False
            logger.info(
                "Action warmup complete — action modules unfrozen at "
                "step %d", step)

    return super().train_one_step(training_batch)
def _get_next_batch(self, training_batch: TrainingBatch) -> TrainingBatch:
    """Pull the next dataloader batch and copy it onto ``training_batch``.

    When the current iterator is exhausted the epoch counter is advanced,
    the sampler is re-seeded through ``set_epoch`` and a fresh iterator is
    started. Latents and first-frame latents are truncated to
    ``training_args.num_latent_t`` along dim 2. Action tensors that are
    missing or empty are stored as None.

    Raises:
        AssertionError: if a present action tensor's dim-1 length differs
            from ``(num_latent_t - 1) * temporal_compression_ratio + 1``.
    """
    batch = next(self.train_loader_iter, None)  # type: ignore
    if batch is None:
        self.current_epoch += 1
        logger.info("Starting epoch %s", self.current_epoch)
        # Reshuffle dataset order each epoch
        self.train_dataset.sampler.set_epoch(self.current_epoch)
        # Reset iterator for next epoch
        self.train_loader_iter = iter(self.train_dataloader)
        # Get first batch of new epoch
        batch = next(self.train_loader_iter)

    device = get_local_torch_device()
    num_latent_t = self.training_args.num_latent_t

    latents = batch['vae_latent'][:, :, :num_latent_t]
    clip_features = batch['clip_feature']
    image_latents = batch['first_frame_latent'][:, :, :num_latent_t]
    pil_image = batch['pil_image']
    infos = batch['info_list']

    training_batch.latents = latents.to(device, dtype=torch.bfloat16)
    # No text conditioning is loaded for this pipeline.
    training_batch.encoder_hidden_states = None
    training_batch.encoder_attention_mask = None
    training_batch.preprocessed_image = pil_image.to(device)
    training_batch.image_embeds = clip_features.to(device)
    training_batch.image_latents = image_latents.to(device)
    training_batch.infos = infos

    # Action conditioning
    for cond_name in ("mouse_cond", "keyboard_cond"):
        if cond_name in batch and batch[cond_name].numel() > 0:
            cond = batch[cond_name].to(device, dtype=torch.bfloat16)
        else:
            cond = None
        setattr(training_batch, cond_name, cond)

    # Validate action temporal dimensions match video num_frames. The ratio
    # comes from the VAE config (as in _prepare_dit_inputs) rather than a
    # hard-coded 4.
    temporal_compression_ratio = (
        self.training_args.pipeline_config.vae_config.arch_config.
        temporal_compression_ratio)
    expected_num_frames = (num_latent_t - 1) * temporal_compression_ratio + 1
    for cond_name in ("keyboard_cond", "mouse_cond"):
        cond = getattr(training_batch, cond_name)
        assert cond is None or cond.shape[1] == expected_num_frames, (
            f"{cond_name} temporal dim {cond.shape[1]} "
            f"!= expected {expected_num_frames} "
            f"(num_latent_t={num_latent_t})")

    return training_batch
def _build_input_kwargs(self,
                        training_batch: TrainingBatch) -> TrainingBatch:
    """Assemble ``training_batch.input_kwargs`` for the transformer forward.

    Image embeddings are moved to the local device as bfloat16. When the
    batch carries both keyboard and mouse conditions, each sample is run
    through ``process_custom_actions`` and the stacked results are passed
    as ``viewmats`` / ``Ks`` / ``action``. When the batch carries no action
    data (both are None), those three entries are passed as None instead
    of crashing on ``None[b]``.

    Raises:
        ValueError: if exactly one of keyboard_cond / mouse_cond is None.
        AssertionError: if image embeddings contain NaN, or if the
            processed action tensors' dim-1 length differs from
            ``noisy_model_input.shape[2]``.
    """
    device = get_local_torch_device()

    # Image Embeds for conditioning
    image_embeds = training_batch.image_embeds
    assert torch.isnan(image_embeds).sum() == 0
    encoder_hidden_states_image = image_embeds.to(device,
                                                  dtype=torch.bfloat16)

    keyboard_cond = training_batch.keyboard_cond
    mouse_cond = training_batch.mouse_cond
    if (keyboard_cond is None) != (mouse_cond is None):
        raise ValueError(
            "keyboard_cond and mouse_cond must both be provided or both be "
            "None to build action conditioning")

    if keyboard_cond is None:
        # No action data in this batch
        viewmats = None
        intrinsics = None
        action_labels = None
    else:
        from fastvideo.models.dits.hyworld.pose import process_custom_actions

        # Process actions for each batch sample
        batch_size = training_batch.noisy_model_input.shape[0]
        viewmats_list, intrinsics_list, action_labels_list = [], [], []
        for b in range(batch_size):
            v, i, a = process_custom_actions(keyboard_cond[b], mouse_cond[b])
            viewmats_list.append(v)
            intrinsics_list.append(i)
            action_labels_list.append(a)
        viewmats = torch.stack(viewmats_list, dim=0).to(device,
                                                        dtype=torch.bfloat16)
        intrinsics = torch.stack(intrinsics_list,
                                 dim=0).to(device, dtype=torch.bfloat16)
        action_labels = torch.stack(action_labels_list,
                                    dim=0).to(device, dtype=torch.bfloat16)

        # Validate processed action latent dim matches video latent dim
        num_latent_t = training_batch.noisy_model_input.shape[2]
        assert action_labels.shape[1] == num_latent_t, (
            f"action_labels temporal dim {action_labels.shape[1]} != "
            f"video latent temporal dim {num_latent_t}")
        assert viewmats.shape[1] == num_latent_t, (
            f"viewmats temporal dim {viewmats.shape[1]} != "
            f"video latent temporal dim {num_latent_t}")

    # NOTE: noisy_model_input already contains concatenated image_latents from _prepare_dit_inputs
    training_batch.input_kwargs = {
        "hidden_states":
        training_batch.noisy_model_input,
        "encoder_hidden_states":
        training_batch.encoder_hidden_states,  # None (no text conditioning)
        "timestep":
        training_batch.timesteps.to(device, dtype=torch.bfloat16),
        "encoder_hidden_states_image":
        encoder_hidden_states_image,
        # Action conditioning
        "viewmats":
        viewmats,
        "Ks":
        intrinsics,
        "action":
        action_labels,
        "return_dict":
        False,
    }
    return training_batch
validation_batch["keyboard_cond"] + keyboard_cond = torch.tensor(keyboard_cond, dtype=torch.bfloat16) + keyboard_cond = keyboard_cond.unsqueeze(0) + batch.keyboard_cond = keyboard_cond + + if "mouse_cond" in validation_batch and validation_batch[ + "mouse_cond"] is not None: + mouse_cond = validation_batch["mouse_cond"] + mouse_cond = torch.tensor(mouse_cond, dtype=torch.bfloat16) + mouse_cond = mouse_cond.unsqueeze(0) + batch.mouse_cond = mouse_cond + + return batch + + def _post_process_validation_frames( + self, frames: list[np.ndarray], + batch: ForwardBatch) -> list[np.ndarray]: + """Apply action overlay to validation frames for WanGame. + + Draws keyboard (WASD) and mouse (pitch/yaw) indicators on the video frames. + """ + # Check if action data is available + keyboard_cond = getattr(batch, 'keyboard_cond', None) + mouse_cond = getattr(batch, 'mouse_cond', None) + + if keyboard_cond is None and mouse_cond is None: + return frames + + # Import overlay functions + from fastvideo.models.dits.matrixgame.utils import (draw_keys_on_frame, + draw_mouse_on_frame) + + # Convert tensors to numpy if needed (bfloat16 -> float32 -> numpy) + if keyboard_cond is not None: + keyboard_cond = keyboard_cond.squeeze( + 0).cpu().float().numpy() # (T, 6) + if mouse_cond is not None: + mouse_cond = mouse_cond.squeeze(0).cpu().float().numpy() # (T, 2) + + # MatrixGame convention: keyboard [W, S, A, D, left, right], mouse [Pitch, Yaw] + key_names = ["W", "S", "A", "D", "left", "right"] + + processed_frames = [] + for frame_idx, frame in enumerate(frames): + frame = np.ascontiguousarray(frame.copy()) + + # Draw keyboard overlay + if keyboard_cond is not None and frame_idx < len(keyboard_cond): + keys = { + key_names[i]: bool(keyboard_cond[frame_idx, i]) + for i in range(min(len(key_names), keyboard_cond.shape[1])) + } + draw_keys_on_frame(frame, keys, mode='universal') + + # Draw mouse overlay + if mouse_cond is not None and frame_idx < len(mouse_cond): + pitch = 
float(mouse_cond[frame_idx, 0]) + yaw = float(mouse_cond[frame_idx, 1]) + draw_mouse_on_frame(frame, pitch, yaw) + + processed_frames.append(frame) + + return processed_frames + + def _init_flow_eval_module(self) -> None: + if getattr(self, "_flow_eval_init_done", False): + return + self._flow_eval_init_done = True + self._flow_eval_ready = False + + ptlflow_dir = Path("/mnt/weka/home/hao.zhang/mhuo/FastVideo/benchmarks/ptlflow") + + try: + ptlflow_dir_str = str(ptlflow_dir.resolve()) + if ptlflow_dir_str not in sys.path: + sys.path.insert(0, ptlflow_dir_str) + + from eval_flow_divergence import evaluate_pair_synthetic # type: ignore + + self._flow_eval_fn = evaluate_pair_synthetic + self._flow_eval_ckpt = str(ptlflow_dir / "dpflow-things-2012b5d6.ckpt") + self._flow_eval_calibration_path = str(ptlflow_dir / + "calibration.json") + self._flow_eval_ready = True + logger.info("Initialized flow divergence evaluator: %s", + ptlflow_dir) + except Exception as e: + logger.warning("Failed to initialize flow divergence evaluator: %s", + e) + + def _evaluate_validation_video( + self, + video_path: str, + caption: str, + action_path: str | None, + global_step: int, + num_inference_steps: int, + ) -> dict[str, float]: + del caption + self._init_flow_eval_module() + if not getattr(self, "_flow_eval_ready", False): + raise RuntimeError( + "ptlflow evaluator is not initialized; cannot compute flow metrics." 
+ ) + + if not isinstance(action_path, str) or not os.path.isfile(action_path): + raise FileNotFoundError( + f"Validation sample is missing a valid action_path: {action_path}" + ) + + eval_output_dir = os.path.join( + self.training_args.output_dir, + "flow_eval", + f"step_{global_step}", + f"inference_steps_{num_inference_steps}", + Path(video_path).stem, + ) + + try: + summary = self._flow_eval_fn( + gen_video=video_path, + action_path=action_path, + calibration_path=self._flow_eval_calibration_path, + output_dir=eval_output_dir, + model_name="dpflow", + ckpt=self._flow_eval_ckpt, + no_viz=True, + use_depth=True, + ) + except Exception as e: + raise RuntimeError( + f"ptlflow synthetic evaluation failed for {video_path}") from e + + if not isinstance(summary, dict): + raise RuntimeError( + f"ptlflow returned invalid summary type: {type(summary)}" + ) + + metrics: dict[str, float] = {} + missing_or_invalid: list[str] = [] + for key in self._FLOW_EVAL_SCALAR_KEYS: + val = summary.get(key) + if not isinstance(val, (float, int, np.floating, np.integer)): + missing_or_invalid.append(key) + continue + val_float = float(val) + if not np.isfinite(val_float): + missing_or_invalid.append(key) + continue + metrics[key] = val_float + + if missing_or_invalid: + raise RuntimeError( + "ptlflow summary missing/invalid metrics: " + f"{', '.join(missing_or_invalid)}") + + return metrics + + +def main(args) -> None: + logger.info("Starting training pipeline...") + + pipeline = WanGameTrainingPipeline.from_pretrained( + args.pretrained_model_name_or_path, args=args) + args = pipeline.training_args + pipeline.train() + logger.info("Training pipeline done") + + +if __name__ == "__main__": + argv = sys.argv + from fastvideo.fastvideo_args import TrainingArgs + from fastvideo.utils import FlexibleArgumentParser + parser = FlexibleArgumentParser() + parser = TrainingArgs.add_cli_args(parser) + parser = FastVideoArgs.add_cli_args(parser) + args = parser.parse_args() + args.dit_cpu_offload = 
def _torch_dtype_from_precision(precision: str) -> torch.dtype:
    """Map a precision name ("fp32", "fp16", "bf16"; case-insensitive) to a torch dtype.

    Raises:
        ValueError: if the lower-cased name is not one of the three above.
    """
    normalized = precision.lower()
    dtype_by_name = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    resolved = dtype_by_name.get(normalized)
    if resolved is None:
        raise ValueError(f"Unsupported precision: {normalized}")
    return resolved
@torch.inference_mode()
def _decode_with_vae(vae, latents: torch.Tensor, *, device: torch.device,
                     precision: str) -> torch.Tensor:
    """Decode latents with ``vae`` and return the result mapped into [0, 1].

    The latents are moved to ``device``, cast to the dtype named by
    ``precision``, passed through ``_denormalize_latents_for_vae`` and then
    decoded. Autocast is enabled only on CUDA devices with a non-fp32
    target dtype. The decoded output is rescaled with ``x / 2 + 0.5`` and
    clamped to [0, 1].
    """
    compute_dtype = _torch_dtype_from_precision(precision)
    prepared = latents.to(device=device).to(dtype=compute_dtype)
    prepared = _denormalize_latents_for_vae(vae, prepared)

    autocast_enabled = device.type == "cuda" and compute_dtype != torch.float32
    with torch.autocast(device_type=device.type,
                        dtype=compute_dtype,
                        enabled=autocast_enabled):
        decoded = vae.decode(prepared)

    rescaled = decoded / 2 + 0.5
    return rescaled.clamp(0, 1)
weights/config") + parser.add_argument("--fps", type=int, default=25, help="Output video FPS") + parser.add_argument( + "--decode_steps", + type=str, + default="last", + help= + "Which trajectory steps to decode: 'last', 'all', or comma-separated indices (e.g. '0,10,20')", + ) + + args = parser.parse_args() + + device = torch.device(args.device) + print(f"Using device: {device}, vae_precision: {args.vae_precision}") + + os.makedirs(args.output_dir, exist_ok=True) + + # Load VAE (must load weights; creating AutoencoderKLWan(config) alone leaves random weights) + print(f"Loading model from {args.model_path}...") + model_root = maybe_download_model(args.model_path) + pipeline_config = PipelineConfig.from_pretrained(model_root) + pipeline_config.update_config_from_dict({ + "vae_precision": + args.vae_precision, + "vae_config": + WanVAEConfig(load_encoder=False, load_decoder=True), + }) + fastvideo_args = FastVideoArgs( + model_path=model_root, + num_gpus=1, + dit_cpu_offload=False, + vae_cpu_offload=False, + text_encoder_cpu_offload=True, + pipeline_config=pipeline_config, + ) + + vae_path = os.path.join(model_root, args.vae_subfolder) + vae = VAELoader().load(vae_path, fastvideo_args) + vae.to(device) + + # Read Parquet + print(f"Reading parquet file: {args.parquet_path}") + table = pq.read_table(args.parquet_path) + + # Iterate over rows + num_visualized = 0 + + pbar = tqdm(total=min(args.num_samples, table.num_rows)) + + for i in range(table.num_rows): + if num_visualized >= args.num_samples: + break + + row = table.slice(i, length=1) + record = row.to_pydict() + + video_id = record["id"][0] + + # Parse Latents + shape = record["trajectory_latents_shape"][0] + dtype = record["trajectory_latents_dtype"][0] + dtype = np.dtype(dtype) + + latents_bytes = record["trajectory_latents_bytes"][0] + # Copy to avoid read-only warning + latents_np = np.copy( + np.frombuffer(latents_bytes, dtype=dtype).reshape(shape)) + + latents_tensor = torch.from_numpy(latents_np) + if 
latents_tensor.ndim == 6 and latents_tensor.shape[0] == 1: + latents_tensor = latents_tensor.squeeze(0) + + print(f"Decoding video {video_id} with shape {latents_tensor.shape}...") + + # create subfolder + vid_output_dir = os.path.join(args.output_dir, str(video_id)) + os.makedirs(vid_output_dir, exist_ok=True) + + # Pick steps to decode + steps = latents_tensor.shape[0] + if args.decode_steps == "last": + indices_to_decode = [steps - 1] + elif args.decode_steps == "all": + indices_to_decode = list(range(steps)) + else: + indices_to_decode = [ + int(x) for x in args.decode_steps.split(",") if x.strip() != "" + ] + indices_to_decode = [i for i in indices_to_decode if 0 <= i < steps] + if not indices_to_decode: + raise ValueError( + f"No valid indices selected for decode_steps='{args.decode_steps}' with steps={steps}" + ) + + for step in tqdm(indices_to_decode, + desc=f"Decoding {video_id}", + leave=False): + latent_step = latents_tensor[step].unsqueeze(0) # [1, C, T, H, W] + + decoded_video = _decode_with_vae(vae, + latent_step, + device=device, + precision=args.vae_precision) + + save_path = os.path.join(vid_output_dir, f"step_{step:03d}.mp4") + save_decoded_latents_as_video(decoded_video.float(), + save_path, + fps=args.fps) + + print(f"Saved {len(indices_to_decode)} step(s) to {vid_output_dir}") + num_visualized += 1 + pbar.update(1) + + pbar.close() + + +if __name__ == "__main__": + main()