Skip to content

Commit 84e0305

Browse files
committed
Support CFG (classifier-free guidance) parallelism in the LTX2 pipeline
Signed-off-by: David Chen <530634352@qq.com>
1 parent bccd723 commit 84e0305

File tree

2 files changed

+448
-117
lines changed

2 files changed

+448
-117
lines changed

vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py

Lines changed: 267 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@
2323
from vllm.model_executor.models.utils import AutoWeightsLoader
2424

2525
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
26+
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
27+
from vllm_omni.diffusion.distributed.parallel_state import (
28+
get_cfg_group,
29+
get_classifier_free_guidance_rank,
30+
get_classifier_free_guidance_world_size,
31+
)
2632
from vllm_omni.diffusion.distributed.utils import get_local_device
2733
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
2834
from vllm_omni.diffusion.request import OmniDiffusionRequest
@@ -107,7 +113,7 @@ def calculate_shift(
107113
return mu
108114

109115

110-
class LTX2Pipeline(nn.Module):
116+
class LTX2Pipeline(nn.Module, CFGParallelMixin):
111117
def __init__(
112118
self,
113119
*,
@@ -605,6 +611,142 @@ def attention_kwargs(self):
605611
def interrupt(self):
606612
return self._interrupt
607613

614+
def _is_cfg_parallel_enabled(self, do_true_cfg: bool) -> bool:
    """Return True when true CFG is requested and more than one CFG rank exists."""
    if not do_true_cfg:
        return False
    return get_classifier_free_guidance_world_size() > 1
616+
617+
def _predict_noise_av(self, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the transformer once and return the (video, audio) noise predictions.

    ``kwargs`` are forwarded verbatim to ``self.transformer``.
    """
    # The cache context tag mirrors the non-parallel path so transformer
    # caching behaves identically in both code paths.
    with self.transformer.cache_context("cond_uncond"):
        video_pred, audio_pred = self.transformer(**kwargs)
    return video_pred, audio_pred
621+
622+
def _combine_av_cfg_noise(
    self,
    noise_pred_video_text: torch.Tensor,
    noise_pred_video_uncond: torch.Tensor,
    noise_pred_audio_text: torch.Tensor,
    noise_pred_audio_uncond: torch.Tensor,
    true_cfg_scale: float,
    cfg_normalize: bool,
    guidance_rescale: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Combine conditional/unconditional noise for both modalities.

    Applies true CFG via ``self.combine_cfg_noise`` and, when
    ``guidance_rescale`` is positive, ``rescale_noise_cfg`` (the fix from
    "Common Diffusion Noise Schedules and Sample Steps are Flawed").
    """
    noise_pred_video = self.combine_cfg_noise(
        noise_pred_video_text,
        noise_pred_video_uncond,
        true_cfg_scale,
        cfg_normalize,
    )
    noise_pred_audio = self.combine_cfg_noise(
        noise_pred_audio_text,
        noise_pred_audio_uncond,
        true_cfg_scale,
        cfg_normalize,
    )
    if guidance_rescale > 0:
        noise_pred_video = rescale_noise_cfg(
            noise_pred_video,
            noise_pred_video_text,
            guidance_rescale=guidance_rescale,
        )
        noise_pred_audio = rescale_noise_cfg(
            noise_pred_audio,
            noise_pred_audio_text,
            guidance_rescale=guidance_rescale,
        )
    return noise_pred_video, noise_pred_audio

def predict_noise_av_maybe_with_cfg(
    self,
    do_true_cfg: bool,
    true_cfg_scale: float,
    positive_kwargs: dict[str, Any],
    negative_kwargs: dict[str, Any] | None,
    guidance_rescale: float = 0.0,
    cfg_normalize: bool = False,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
    """Predict (video, audio) noise, applying true CFG when requested.

    Args:
        do_true_cfg: Whether classifier-free guidance is active.
        true_cfg_scale: Guidance scale used to combine cond/uncond noise.
        positive_kwargs: Transformer kwargs for the conditional pass.
        negative_kwargs: Transformer kwargs for the unconditional pass;
            required when ``do_true_cfg`` is True.
        guidance_rescale: If > 0, rescale the guided noise against the
            conditional prediction.
        cfg_normalize: Forwarded to ``combine_cfg_noise``.

    Returns:
        The float32 (video, audio) noise predictions. Under CFG parallelism,
        non-zero ranks return ``(None, None)``: only rank 0 owns the combined
        result, and the stepped latents are broadcast back afterwards (see
        ``_scheduler_step_video_audio_maybe_with_cfg``).

    Raises:
        ValueError: If ``do_true_cfg`` is True but ``negative_kwargs`` is None.
    """
    if not do_true_cfg:
        # No guidance: a single conditional pass is sufficient.
        noise_pred_video, noise_pred_audio = self._predict_noise_av(**positive_kwargs)
        return noise_pred_video.float(), noise_pred_audio.float()

    if negative_kwargs is None:
        # Fail fast on every rank with a clear message instead of an opaque
        # ``**None`` TypeError (which in the parallel path would crash only
        # some ranks and could hang the collective).
        raise ValueError("negative_kwargs must be provided when do_true_cfg is True")

    if self._is_cfg_parallel_enabled(do_true_cfg):
        cfg_group = get_cfg_group()
        cfg_rank = get_classifier_free_guidance_rank()

        # Rank 0 runs the conditional pass; the other rank the unconditional one.
        branch_kwargs = positive_kwargs if cfg_rank == 0 else negative_kwargs
        noise_pred_video, noise_pred_audio = self._predict_noise_av(**branch_kwargs)

        # Cast before the collective so all ranks gather matching fp32 tensors.
        noise_pred_video = noise_pred_video.float()
        noise_pred_audio = noise_pred_audio.float()

        gathered_video = cfg_group.all_gather(noise_pred_video, separate_tensors=True)
        gathered_audio = cfg_group.all_gather(noise_pred_audio, separate_tensors=True)

        if cfg_rank != 0:
            # Non-zero ranks do not own the combined result.
            return None, None

        # Gather order follows rank order: index 0 = conditional (rank 0),
        # index 1 = unconditional (rank 1).
        return self._combine_av_cfg_noise(
            gathered_video[0],
            gathered_video[1],
            gathered_audio[0],
            gathered_audio[1],
            true_cfg_scale,
            cfg_normalize,
            guidance_rescale,
        )

    # Single-rank true CFG: two sequential transformer passes.
    noise_pred_video_text, noise_pred_audio_text = self._predict_noise_av(**positive_kwargs)
    noise_pred_video_uncond, noise_pred_audio_uncond = self._predict_noise_av(**negative_kwargs)
    return self._combine_av_cfg_noise(
        noise_pred_video_text.float(),
        noise_pred_video_uncond.float(),
        noise_pred_audio_text.float(),
        noise_pred_audio_uncond.float(),
        true_cfg_scale,
        cfg_normalize,
        guidance_rescale,
    )
719+
720+
def _scheduler_step_video_audio_maybe_with_cfg(
    self,
    noise_pred_video: torch.Tensor | None,
    noise_pred_audio: torch.Tensor | None,
    t: torch.Tensor,
    latents: torch.Tensor,
    audio_latents: torch.Tensor,
    audio_scheduler: FlowMatchEulerDiscreteScheduler,
    do_true_cfg: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Advance the video and audio schedulers by one step.

    Under CFG parallelism only rank 0 holds the combined noise predictions,
    so only rank 0 performs the scheduler step; the resulting latents are
    then broadcast to every other CFG rank.
    """
    if not self._is_cfg_parallel_enabled(do_true_cfg):
        # Single-rank path: both schedulers step locally.
        latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
        audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
        return latents, audio_latents

    group = get_cfg_group()
    if get_classifier_free_guidance_rank() == 0:
        latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
        audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]

    # Broadcast needs contiguous tensors; non-zero ranks receive rank 0's
    # stepped latents into their (stale) buffers.
    latents = latents.contiguous()
    audio_latents = audio_latents.contiguous()
    group.broadcast(latents, src=0)
    group.broadcast(audio_latents, src=0)
    return latents, audio_latents
749+
608750
@torch.no_grad()
609751
def forward(
610752
self,
@@ -750,7 +892,8 @@ def forward(
750892
max_sequence_length=max_sequence_length,
751893
device=device,
752894
)
753-
if self.do_classifier_free_guidance:
895+
cfg_parallel_ready = self._is_cfg_parallel_enabled(self.do_classifier_free_guidance)
896+
if self.do_classifier_free_guidance and not cfg_parallel_ready:
754897
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
755898
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
756899

@@ -759,6 +902,23 @@ def forward(
759902
prompt_embeds, additive_attention_mask, additive_mask=True
760903
)
761904

905+
negative_connector_prompt_embeds = None
906+
negative_connector_audio_prompt_embeds = None
907+
negative_connector_attention_mask = None
908+
if cfg_parallel_ready:
909+
negative_additive_attention_mask = (
910+
1 - negative_prompt_attention_mask.to(negative_prompt_embeds.dtype)
911+
) * -1000000.0
912+
(
913+
negative_connector_prompt_embeds,
914+
negative_connector_audio_prompt_embeds,
915+
negative_connector_attention_mask,
916+
) = self.connectors(
917+
negative_prompt_embeds,
918+
negative_additive_attention_mask,
919+
additive_mask=True,
920+
)
921+
762922
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
763923
latent_height = height // self.vae_spatial_compression_ratio
764924
latent_width = width // self.vae_spatial_compression_ratio
@@ -838,58 +998,119 @@ def forward(
838998

839999
self._current_timestep = t
8401000

841-
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
842-
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
843-
audio_latent_model_input = (
844-
torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents
845-
)
846-
audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype)
847-
848-
timestep = t.expand(latent_model_input.shape[0])
849-
850-
with self.transformer.cache_context("cond_uncond"):
851-
noise_pred_video, noise_pred_audio = self.transformer(
852-
hidden_states=latent_model_input,
853-
audio_hidden_states=audio_latent_model_input,
854-
encoder_hidden_states=connector_prompt_embeds,
855-
audio_encoder_hidden_states=connector_audio_prompt_embeds,
856-
timestep=timestep,
857-
encoder_attention_mask=connector_attention_mask,
858-
audio_encoder_attention_mask=connector_attention_mask,
859-
num_frames=latent_num_frames,
860-
height=latent_height,
861-
width=latent_width,
862-
fps=frame_rate,
863-
audio_num_frames=audio_num_frames,
864-
video_coords=video_coords,
865-
audio_coords=audio_coords,
866-
attention_kwargs=attention_kwargs,
867-
return_dict=False,
1001+
if cfg_parallel_ready:
1002+
latent_model_input = latents.to(prompt_embeds.dtype)
1003+
audio_latent_model_input = audio_latents.to(prompt_embeds.dtype)
1004+
timestep = t.expand(latent_model_input.shape[0])
1005+
1006+
positive_kwargs = {
1007+
"hidden_states": latent_model_input,
1008+
"audio_hidden_states": audio_latent_model_input,
1009+
"encoder_hidden_states": connector_prompt_embeds,
1010+
"audio_encoder_hidden_states": connector_audio_prompt_embeds,
1011+
"timestep": timestep,
1012+
"encoder_attention_mask": connector_attention_mask,
1013+
"audio_encoder_attention_mask": connector_attention_mask,
1014+
"num_frames": latent_num_frames,
1015+
"height": latent_height,
1016+
"width": latent_width,
1017+
"fps": frame_rate,
1018+
"audio_num_frames": audio_num_frames,
1019+
"video_coords": video_coords,
1020+
"audio_coords": audio_coords,
1021+
"attention_kwargs": attention_kwargs,
1022+
"return_dict": False,
1023+
}
1024+
negative_kwargs = {
1025+
"hidden_states": latent_model_input,
1026+
"audio_hidden_states": audio_latent_model_input,
1027+
"encoder_hidden_states": negative_connector_prompt_embeds,
1028+
"audio_encoder_hidden_states": negative_connector_audio_prompt_embeds,
1029+
"timestep": timestep,
1030+
"encoder_attention_mask": negative_connector_attention_mask,
1031+
"audio_encoder_attention_mask": negative_connector_attention_mask,
1032+
"num_frames": latent_num_frames,
1033+
"height": latent_height,
1034+
"width": latent_width,
1035+
"fps": frame_rate,
1036+
"audio_num_frames": audio_num_frames,
1037+
"video_coords": video_coords,
1038+
"audio_coords": audio_coords,
1039+
"attention_kwargs": attention_kwargs,
1040+
"return_dict": False,
1041+
}
1042+
1043+
noise_pred_video, noise_pred_audio = self.predict_noise_av_maybe_with_cfg(
1044+
do_true_cfg=True,
1045+
true_cfg_scale=guidance_scale,
1046+
positive_kwargs=positive_kwargs,
1047+
negative_kwargs=negative_kwargs,
1048+
guidance_rescale=guidance_rescale,
1049+
cfg_normalize=False,
8681050
)
869-
noise_pred_video = noise_pred_video.float()
870-
noise_pred_audio = noise_pred_audio.float()
8711051

872-
if self.do_classifier_free_guidance:
873-
noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
874-
noise_pred_video = noise_pred_video_uncond + guidance_scale * (
875-
noise_pred_video_text - noise_pred_video_uncond
1052+
latents, audio_latents = self._scheduler_step_video_audio_maybe_with_cfg(
1053+
noise_pred_video,
1054+
noise_pred_audio,
1055+
t,
1056+
latents,
1057+
audio_latents,
1058+
audio_scheduler,
1059+
do_true_cfg=True,
8761060
)
877-
878-
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
879-
noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (
880-
noise_pred_audio_text - noise_pred_audio_uncond
1061+
else:
1062+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1063+
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
1064+
audio_latent_model_input = (
1065+
torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents
8811066
)
1067+
audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype)
1068+
1069+
timestep = t.expand(latent_model_input.shape[0])
1070+
1071+
with self.transformer.cache_context("cond_uncond"):
1072+
noise_pred_video, noise_pred_audio = self.transformer(
1073+
hidden_states=latent_model_input,
1074+
audio_hidden_states=audio_latent_model_input,
1075+
encoder_hidden_states=connector_prompt_embeds,
1076+
audio_encoder_hidden_states=connector_audio_prompt_embeds,
1077+
timestep=timestep,
1078+
encoder_attention_mask=connector_attention_mask,
1079+
audio_encoder_attention_mask=connector_attention_mask,
1080+
num_frames=latent_num_frames,
1081+
height=latent_height,
1082+
width=latent_width,
1083+
fps=frame_rate,
1084+
audio_num_frames=audio_num_frames,
1085+
video_coords=video_coords,
1086+
audio_coords=audio_coords,
1087+
attention_kwargs=attention_kwargs,
1088+
return_dict=False,
1089+
)
1090+
noise_pred_video = noise_pred_video.float()
1091+
noise_pred_audio = noise_pred_audio.float()
8821092

883-
if guidance_rescale > 0:
884-
noise_pred_video = rescale_noise_cfg(
885-
noise_pred_video, noise_pred_video_text, guidance_rescale=guidance_rescale
1093+
if self.do_classifier_free_guidance:
1094+
noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
1095+
noise_pred_video = noise_pred_video_uncond + guidance_scale * (
1096+
noise_pred_video_text - noise_pred_video_uncond
8861097
)
887-
noise_pred_audio = rescale_noise_cfg(
888-
noise_pred_audio, noise_pred_audio_text, guidance_rescale=guidance_rescale
1098+
1099+
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
1100+
noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (
1101+
noise_pred_audio_text - noise_pred_audio_uncond
8891102
)
8901103

891-
latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
892-
audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
1104+
if guidance_rescale > 0:
1105+
noise_pred_video = rescale_noise_cfg(
1106+
noise_pred_video, noise_pred_video_text, guidance_rescale=guidance_rescale
1107+
)
1108+
noise_pred_audio = rescale_noise_cfg(
1109+
noise_pred_audio, noise_pred_audio_text, guidance_rescale=guidance_rescale
1110+
)
1111+
1112+
latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
1113+
audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
8931114

8941115
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
8951116
pass

0 commit comments

Comments
 (0)