-
Notifications
You must be signed in to change notification settings - Fork 283
Expand file tree
/
Copy pathv1_distill_dmd_wan.sh
More file actions
59 lines (58 loc) · 2.02 KB
/
v1_distill_dmd_wan.sh
File metadata and controls
59 lines (58 loc) · 2.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
export WANDB_BASE_URL="https://api.wandb.ai"
export WANDB_MODE=offline
export WANDB_API_KEY=
export TRITON_CACHE_DIR=/tmp/triton_cache
DATA_DIR=mini_i2v_dataset/crush-smol_preprocessed/combined_parquet_dataset/
VALIDATION_DIR=mini_i2v_dataset/crush-smol_raw/validation.json
NUM_GPUS=8
export FASTVIDEO_ATTENTION_BACKEND=FLASH_ATTN
export TOKENIZERS_PARALLELISM=false
# make sure that num_latent_t is a multiple of sp_size
torchrun --nnodes 1 --nproc_per_node $NUM_GPUS \
fastvideo/training/wan_distillation_pipeline.py \
--model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
--real_score_model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
--fake_score_model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
--inference_mode False\
--pretrained_model_name_or_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
--cache_dir "/home/ray/.cache" \
--data_path "$DATA_DIR" \
--validation_dataset_file "$VALIDATION_DIR" \
--train_batch_size 1 \
--num_latent_t 20 \
--sp_size 1 \
--tp_size 1 \
--num_gpus $NUM_GPUS \
--hsdp_replicate_dim $NUM_GPUS \
--hsdp-shard-dim 1 \
--train_sp_batch_size 1 \
--dataloader_num_workers 0 \
--gradient_accumulation_steps 8 \
--max_train_steps 30000 \
--learning_rate 2e-6 \
--mixed_precision "bf16" \
--training_state_checkpointing_steps 400 \
--validation_steps 100 \
--validation_sampling_steps "3" \
--log_validation \
--checkpoints_total_limit 3 \
--ema_start_step 0 \
--training_cfg_rate 0.0 \
--output_dir "outputs_dmd/wan_finetune" \
--tracker_project_name Wan_distillation \
--num_height 448 \
--num_width 832 \
--num_frames 77 \
--flow_shift 8 \
--validation_guidance_scale "6.0" \
--master_weight_type "fp32" \
--dit_precision "fp32" \
--vae_precision "bf16" \
--weight_decay 0.01 \
--max_grad_norm 1.0 \
--generator_update_interval 5 \
--dmd_denoising_steps '1000,757,522' \
--min_timestep_ratio 0.02 \
--max_timestep_ratio 0.98 \
--real_score_guidance_scale 3.5 \
--seed 1024