|
19 | 19 | from ..models.mova_audio_dit import MovaAudioDit |
20 | 20 | from ..models.mova_audio_vae import DacVAE |
21 | 21 | from ..models.mova_dual_tower_bridge import DualTowerConditionalBridge |
| 22 | +from ..utils.data.audio import convert_to_mono, resample_waveform |
22 | 23 |
|
23 | 24 |
|
24 | 25 | class MovaAudioVideoPipeline(BasePipeline): |
@@ -81,12 +82,16 @@ def from_pretrained( |
81 | 82 |
|
82 | 83 | # Fetch models |
83 | 84 | pipe.text_encoder = model_pool.fetch_model("wan_video_text_encoder") |
84 | | - pipe.video_dit, pipe.video_dit2 = model_pool.fetch_model("wan_video_dit", index=2) |
| 85 | + dit = model_pool.fetch_model("wan_video_dit", index=2) |
| 86 | + if isinstance(dit, list): |
| 87 | + pipe.video_dit, pipe.video_dit2 = dit |
| 88 | + else: |
| 89 | + pipe.video_dit = dit |
85 | 90 | pipe.audio_dit = model_pool.fetch_model("mova_audio_dit") |
86 | 91 | pipe.dual_tower_bridge = model_pool.fetch_model("mova_dual_tower_bridge") |
87 | 92 | pipe.video_vae = model_pool.fetch_model("wan_video_vae") |
88 | 93 | pipe.audio_vae = model_pool.fetch_model("mova_audio_vae") |
89 | | - set_to_torch_norm([pipe.video_dit, pipe.video_dit2, pipe.audio_dit, pipe.dual_tower_bridge]) |
| 94 | + set_to_torch_norm([pipe.video_dit, pipe.audio_dit, pipe.dual_tower_bridge] + ([pipe.video_dit2] if pipe.video_dit2 is not None else [])) |
90 | 95 |
|
91 | 96 | # Size division factor |
92 | 97 | if pipe.video_vae is not None: |
@@ -185,7 +190,8 @@ def __call__( |
185 | 190 | video = self.video_vae.decode(inputs_shared["video_latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) |
186 | 191 | video = self.vae_output_to_video(video) |
187 | 192 | self.load_models_to_device(["audio_vae"]) |
188 | | - audio = self.audio_vae.decode(inputs_shared["audio_latents"]).to(dtype=torch.float32, device='cpu').squeeze() |
| 193 | + audio = self.audio_vae.decode(inputs_shared["audio_latents"]) |
| 194 | + audio = self.output_audio_format_check(audio) |
189 | 195 | self.load_models_to_device([]) |
190 | 196 | return video, audio |
191 | 197 |
|
@@ -229,36 +235,33 @@ def __init__(self): |
229 | 235 | ) |
230 | 236 |
|
231 | 237 | def process(self, pipe: MovaAudioVideoPipeline, input_video, video_noise, tiled, tile_size, tile_stride): |
232 | | - if input_video is None: |
| 238 | + if input_video is None or not pipe.scheduler.training: |
233 | 239 | return {"video_latents": video_noise} |
234 | | - # TODO: check for train |
235 | | - pipe.load_models_to_device(self.onload_model_names) |
236 | | - input_video = pipe.preprocess_video(input_video) |
237 | | - input_latents = pipe.video_vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) |
238 | | - if pipe.scheduler.training: |
239 | | - return {"latents": video_noise, "input_latents": input_latents} |
240 | 240 | else: |
241 | | - latents = pipe.scheduler.add_noise(input_latents, video_noise, timestep=pipe.scheduler.timesteps[0]) |
242 | | - return {"latents": latents} |
| 241 | + pipe.load_models_to_device(self.onload_model_names) |
| 242 | + input_video = pipe.preprocess_video(input_video) |
| 243 | + input_latents = pipe.video_vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) |
| 244 | + return {"input_latents": input_latents} |
243 | 245 |
|
244 | 246 |
|
245 | 247 | class MovaAudioVideoUnit_InputAudioEmbedder(PipelineUnit): |
246 | 248 | def __init__(self): |
247 | 249 | super().__init__( |
248 | 250 | input_params=("input_audio", "audio_noise"), |
249 | 251 | output_params=("audio_latents", "audio_input_latents"), |
250 | | - onload_model_names=("audio_vae_encoder",) |
| 252 | + onload_model_names=("audio_vae",) |
251 | 253 | ) |
252 | 254 |
|
253 | 255 | def process(self, pipe: MovaAudioVideoPipeline, input_audio, audio_noise): |
254 | | - if input_audio is None: |
| 256 | + if input_audio is None or not pipe.scheduler.training: |
255 | 257 | return {"audio_latents": audio_noise} |
256 | 258 | else: |
257 | | - # TODO: support audio training |
258 | | - if pipe.scheduler.training: |
259 | | - return {"audio_latents": audio_noise, "audio_input_latents": audio_noise} |
260 | | - else: |
261 | | - raise NotImplementedError("Audio-to-video not supported.") |
| 259 | + input_audio, sample_rate = input_audio |
| 260 | + input_audio = convert_to_mono(input_audio) |
| 261 | + input_audio = resample_waveform(input_audio, sample_rate, pipe.audio_vae.sample_rate) |
| 262 | + input_audio = pipe.audio_vae.preprocess(input_audio.unsqueeze(0), pipe.audio_vae.sample_rate) |
| 263 | + z, _, _, _, _ = pipe.audio_vae.encode(input_audio) |
| 264 | + return {"audio_input_latents": z.mode()} |
262 | 265 |
|
263 | 266 |
|
264 | 267 | class MovaAudioVideoUnit_PromptEmbedder(PipelineUnit): |
@@ -329,15 +332,16 @@ def process(self, pipe: MovaAudioVideoPipeline, input_image, end_image, num_fram |
329 | 332 | y = y.to(dtype=pipe.torch_dtype, device=pipe.device) |
330 | 333 | return {"y": y} |
331 | 334 |
|
| 335 | + |
332 | 336 | class MovaAudioVideoUnit_UnifiedSequenceParallel(PipelineUnit): |
333 | 337 | def __init__(self): |
334 | 338 | super().__init__(input_params=(), output_params=("use_unified_sequence_parallel",)) |
335 | 339 |
|
336 | 340 | def process(self, pipe: MovaAudioVideoPipeline): |
337 | | - if hasattr(pipe, "use_unified_sequence_parallel"): |
338 | | - if pipe.use_unified_sequence_parallel: |
339 | | - return {"use_unified_sequence_parallel": True} |
340 | | - return {} |
| 341 | + if hasattr(pipe, "use_unified_sequence_parallel") and pipe.use_unified_sequence_parallel: |
| 342 | + return {"use_unified_sequence_parallel": True} |
| 343 | + return {"use_unified_sequence_parallel": False} |
| 344 | + |
341 | 345 |
|
342 | 346 | def model_fn_mova_audio_video( |
343 | 347 | video_dit: WanModel, |
|
0 commit comments