Skip to content

Commit b3ecbb6

Browse files
committed
support mova train
1 parent 4a9391d commit b3ecbb6

13 files changed

Lines changed: 599 additions & 30 deletions

File tree

diffsynth/diffusion/base_pipeline.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,6 @@ def output_audio_format_check(self, audio_output):
152152
# remove batch dim
153153
if audio_output.ndim == 3:
154154
audio_output = audio_output.squeeze(0)
155-
# Transform to stereo
156-
if audio_output.shape[0] == 1:
157-
audio_output = audio_output.repeat(2, 1)
158-
elif audio_output.shape[0] == 2:
159-
pass
160-
else:
161-
raise ValueError("The output audio should be [C, T] or [1, C, T] or [2, C, T].")
162155
return audio_output.float()
163156

164157
def load_models_to_device(self, model_names):

diffsynth/pipelines/mova_audio_video.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from ..models.mova_audio_dit import MovaAudioDit
2020
from ..models.mova_audio_vae import DacVAE
2121
from ..models.mova_dual_tower_bridge import DualTowerConditionalBridge
22+
from ..utils.data.audio import convert_to_mono, resample_waveform
2223

2324

2425
class MovaAudioVideoPipeline(BasePipeline):
@@ -81,12 +82,16 @@ def from_pretrained(
8182

8283
# Fetch models
8384
pipe.text_encoder = model_pool.fetch_model("wan_video_text_encoder")
84-
pipe.video_dit, pipe.video_dit2 = model_pool.fetch_model("wan_video_dit", index=2)
85+
dit = model_pool.fetch_model("wan_video_dit", index=2)
86+
if isinstance(dit, list):
87+
pipe.video_dit, pipe.video_dit2 = dit
88+
else:
89+
pipe.video_dit = dit
8590
pipe.audio_dit = model_pool.fetch_model("mova_audio_dit")
8691
pipe.dual_tower_bridge = model_pool.fetch_model("mova_dual_tower_bridge")
8792
pipe.video_vae = model_pool.fetch_model("wan_video_vae")
8893
pipe.audio_vae = model_pool.fetch_model("mova_audio_vae")
89-
set_to_torch_norm([pipe.video_dit, pipe.video_dit2, pipe.audio_dit, pipe.dual_tower_bridge])
94+
set_to_torch_norm([pipe.video_dit, pipe.audio_dit, pipe.dual_tower_bridge] + ([pipe.video_dit2] if pipe.video_dit2 is not None else []))
9095

9196
# Size division factor
9297
if pipe.video_vae is not None:
@@ -185,7 +190,8 @@ def __call__(
185190
video = self.video_vae.decode(inputs_shared["video_latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
186191
video = self.vae_output_to_video(video)
187192
self.load_models_to_device(["audio_vae"])
188-
audio = self.audio_vae.decode(inputs_shared["audio_latents"]).to(dtype=torch.float32, device='cpu').squeeze()
193+
audio = self.audio_vae.decode(inputs_shared["audio_latents"])
194+
audio = self.output_audio_format_check(audio)
189195
self.load_models_to_device([])
190196
return video, audio
191197

@@ -229,36 +235,33 @@ def __init__(self):
229235
)
230236

231237
def process(self, pipe: MovaAudioVideoPipeline, input_video, video_noise, tiled, tile_size, tile_stride):
232-
if input_video is None:
238+
if input_video is None or not pipe.scheduler.training:
233239
return {"video_latents": video_noise}
234-
# TODO: check for train
235-
pipe.load_models_to_device(self.onload_model_names)
236-
input_video = pipe.preprocess_video(input_video)
237-
input_latents = pipe.video_vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
238-
if pipe.scheduler.training:
239-
return {"latents": video_noise, "input_latents": input_latents}
240240
else:
241-
latents = pipe.scheduler.add_noise(input_latents, video_noise, timestep=pipe.scheduler.timesteps[0])
242-
return {"latents": latents}
241+
pipe.load_models_to_device(self.onload_model_names)
242+
input_video = pipe.preprocess_video(input_video)
243+
input_latents = pipe.video_vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
244+
return {"input_latents": input_latents}
243245

244246

245247
class MovaAudioVideoUnit_InputAudioEmbedder(PipelineUnit):
246248
def __init__(self):
247249
super().__init__(
248250
input_params=("input_audio", "audio_noise"),
249251
output_params=("audio_latents", "audio_input_latents"),
250-
onload_model_names=("audio_vae_encoder",)
252+
onload_model_names=("audio_vae",)
251253
)
252254

253255
def process(self, pipe: MovaAudioVideoPipeline, input_audio, audio_noise):
254-
if input_audio is None:
256+
if input_audio is None or not pipe.scheduler.training:
255257
return {"audio_latents": audio_noise}
256258
else:
257-
# TODO: support audio training
258-
if pipe.scheduler.training:
259-
return {"audio_latents": audio_noise, "audio_input_latents": audio_noise}
260-
else:
261-
raise NotImplementedError("Audio-to-video not supported.")
259+
input_audio, sample_rate = input_audio
260+
input_audio = convert_to_mono(input_audio)
261+
input_audio = resample_waveform(input_audio, sample_rate, pipe.audio_vae.sample_rate)
262+
input_audio = pipe.audio_vae.preprocess(input_audio.unsqueeze(0), pipe.audio_vae.sample_rate)
263+
z, _, _, _, _ = pipe.audio_vae.encode(input_audio)
264+
return {"audio_input_latents": z.mode()}
262265

263266

264267
class MovaAudioVideoUnit_PromptEmbedder(PipelineUnit):
@@ -329,15 +332,16 @@ def process(self, pipe: MovaAudioVideoPipeline, input_image, end_image, num_fram
329332
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
330333
return {"y": y}
331334

335+
332336
class MovaAudioVideoUnit_UnifiedSequenceParallel(PipelineUnit):
333337
def __init__(self):
334338
super().__init__(input_params=(), output_params=("use_unified_sequence_parallel",))
335339

336340
def process(self, pipe: MovaAudioVideoPipeline):
337-
if hasattr(pipe, "use_unified_sequence_parallel"):
338-
if pipe.use_unified_sequence_parallel:
339-
return {"use_unified_sequence_parallel": True}
340-
return {}
341+
if hasattr(pipe, "use_unified_sequence_parallel") and pipe.use_unified_sequence_parallel:
342+
return {"use_unified_sequence_parallel": True}
343+
return {"use_unified_sequence_parallel": False}
344+
341345

342346
def model_fn_mova_audio_video(
343347
video_dit: WanModel,
File renamed without changes.
File renamed without changes.
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \
2+
--dataset_base_path data/example_video_dataset/ltx2 \
3+
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
4+
--data_file_keys "video,input_audio" \
5+
--extra_inputs "input_audio,input_image" \
6+
--height 352 \
7+
--width 640 \
8+
--num_frames 121 \
9+
--dataset_repeat 100 \
10+
--model_id_with_origin_paths "openmoss/MOVA-360p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
11+
--learning_rate 1e-4 \
12+
--num_epochs 5 \
13+
--remove_prefix_in_ckpt "pipe.video_dit." \
14+
--output_path "./models/train/MOVA-360p-I2AV_high_noise_full" \
15+
--trainable_models "dit" \
16+
--max_timestep_boundary 0.358 \
17+
--min_timestep_boundary 0 \
18+
--use_gradient_checkpointing
19+
# boundary corresponds to timesteps [900, 1000]
20+
21+
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \
22+
--dataset_base_path data/example_video_dataset/ltx2 \
23+
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
24+
--data_file_keys "video,input_audio" \
25+
--extra_inputs "input_audio,input_image" \
26+
--height 352 \
27+
--width 640 \
28+
--num_frames 121 \
29+
--dataset_repeat 100 \
30+
--model_id_with_origin_paths "openmoss/MOVA-360p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
31+
--learning_rate 1e-4 \
32+
--num_epochs 5 \
33+
--remove_prefix_in_ckpt "pipe.video_dit." \
34+
--output_path "./models/train/MOVA-360p-I2AV_low_noise_full" \
35+
--trainable_models "dit" \
36+
--max_timestep_boundary 1 \
37+
--min_timestep_boundary 0.358 \
38+
--use_gradient_checkpointing
39+
# boundary corresponds to timesteps [0, 900)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \
2+
--dataset_base_path data/example_video_dataset/ltx2 \
3+
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
4+
--data_file_keys "video,input_audio" \
5+
--extra_inputs "input_audio,input_image" \
6+
--height 720 \
7+
--width 1280 \
8+
--num_frames 121 \
9+
--dataset_repeat 100 \
10+
--model_id_with_origin_paths "openmoss/MOVA-720p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
11+
--learning_rate 1e-4 \
12+
--num_epochs 5 \
13+
--remove_prefix_in_ckpt "pipe.video_dit." \
14+
--output_path "./models/train/MOVA-720p-I2AV_high_noise_full" \
15+
--trainable_models "dit" \
16+
--max_timestep_boundary 0.358 \
17+
--min_timestep_boundary 0 \
18+
--use_gradient_checkpointing
19+
# boundary corresponds to timesteps [900, 1000]
20+
21+
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/mova/model_training/train.py \
22+
--dataset_base_path data/example_video_dataset/ltx2 \
23+
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
24+
--data_file_keys "video,input_audio" \
25+
--extra_inputs "input_audio,input_image" \
26+
--height 720 \
27+
--width 1280 \
28+
--num_frames 121 \
29+
--dataset_repeat 100 \
30+
--model_id_with_origin_paths "openmoss/MOVA-720p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
31+
--learning_rate 1e-4 \
32+
--num_epochs 5 \
33+
--remove_prefix_in_ckpt "pipe.video_dit." \
34+
--output_path "./models/train/MOVA-720p-I2AV_low_noise_full" \
35+
--trainable_models "dit" \
36+
--max_timestep_boundary 1 \
37+
--min_timestep_boundary 0.358 \
38+
--use_gradient_checkpointing
39+
# boundary corresponds to timesteps [0, 900)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
accelerate launch examples/mova/model_training/train.py \
2+
--dataset_base_path data/example_video_dataset/ltx2 \
3+
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
4+
--data_file_keys "video,input_audio" \
5+
--extra_inputs "input_audio,input_image" \
6+
--height 352 \
7+
--width 640 \
8+
--num_frames 121 \
9+
--dataset_repeat 100 \
10+
--model_id_with_origin_paths "openmoss/MOVA-360p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
11+
--learning_rate 1e-4 \
12+
--num_epochs 5 \
13+
--remove_prefix_in_ckpt "pipe.video_dit." \
14+
--output_path "./models/train/MOVA-360p-I2AV_high_noise_lora" \
15+
--lora_base_model "video_dit" \
16+
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
17+
--lora_rank 32 \
18+
--max_timestep_boundary 0.358 \
19+
--min_timestep_boundary 0 \
20+
--use_gradient_checkpointing
21+
# boundary corresponds to timesteps [900, 1000]
22+
23+
# accelerate launch examples/mova/model_training/train.py \
24+
# --dataset_base_path data/example_video_dataset/ltx2 \
25+
# --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
26+
# --data_file_keys "video,input_audio" \
27+
# --extra_inputs "input_audio,input_image" \
28+
# --height 352 \
29+
# --width 640 \
30+
# --num_frames 121 \
31+
# --dataset_repeat 100 \
32+
# --model_id_with_origin_paths "openmoss/MOVA-360p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-360p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-360p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
33+
# --learning_rate 1e-4 \
34+
# --num_epochs 5 \
35+
# --remove_prefix_in_ckpt "pipe.video_dit." \
36+
# --output_path "./models/train/MOVA-360p-I2AV_low_noise_lora" \
37+
# --lora_base_model "video_dit" \
38+
# --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
39+
# --lora_rank 32 \
40+
# --max_timestep_boundary 1 \
41+
# --min_timestep_boundary 0.358 \
42+
# --use_gradient_checkpointing
43+
# boundary corresponds to timesteps [0, 900)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
accelerate launch examples/mova/model_training/train.py \
2+
--dataset_base_path data/example_video_dataset/ltx2 \
3+
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
4+
--data_file_keys "video,input_audio" \
5+
--extra_inputs "input_audio,input_image" \
6+
--height 720 \
7+
--width 1280 \
8+
--num_frames 121 \
9+
--dataset_repeat 100 \
10+
--model_id_with_origin_paths "openmoss/MOVA-720p:video_dit/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
11+
--learning_rate 1e-4 \
12+
--num_epochs 5 \
13+
--remove_prefix_in_ckpt "pipe.video_dit." \
14+
--output_path "./models/train/MOVA-720p-I2AV_high_noise_lora" \
15+
--lora_base_model "video_dit" \
16+
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
17+
--lora_rank 32 \
18+
--max_timestep_boundary 0.358 \
19+
--min_timestep_boundary 0 \
20+
--use_gradient_checkpointing
21+
# boundary corresponds to timesteps [900, 1000]
22+
23+
accelerate launch examples/mova/model_training/train.py \
24+
--dataset_base_path data/example_video_dataset/ltx2 \
25+
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
26+
--data_file_keys "video,input_audio" \
27+
--extra_inputs "input_audio,input_image" \
28+
--height 720 \
29+
--width 1280 \
30+
--num_frames 121 \
31+
--dataset_repeat 100 \
32+
--model_id_with_origin_paths "openmoss/MOVA-720p:video_dit_2/diffusion_pytorch_model-*.safetensors,openmoss/MOVA-720p:audio_dit/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:dual_tower_bridge/diffusion_pytorch_model.safetensors,openmoss/MOVA-720p:audio_vae/diffusion_pytorch_model.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:Wan2.1_VAE.safetensors,DiffSynth-Studio/Wan-Series-Converted-Safetensors:models_t5_umt5-xxl-enc-bf16.safetensors" \
33+
--learning_rate 1e-4 \
34+
--num_epochs 5 \
35+
--remove_prefix_in_ckpt "pipe.video_dit." \
36+
--output_path "./models/train/MOVA-720p-I2AV_low_noise_lora" \
37+
--lora_base_model "video_dit" \
38+
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
39+
--lora_rank 32 \
40+
--max_timestep_boundary 1 \
41+
--min_timestep_boundary 0.358 \
42+
--use_gradient_checkpointing
43+
# boundary corresponds to timesteps [0, 900)

0 commit comments

Comments
 (0)