 from vllm.model_executor.models.utils import AutoWeightsLoader
 
 from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
+from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
+from vllm_omni.diffusion.distributed.parallel_state import (
+    get_cfg_group,
+    get_classifier_free_guidance_rank,
+    get_classifier_free_guidance_world_size,
+)
 from vllm_omni.diffusion.distributed.utils import get_local_device
 from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
 from vllm_omni.diffusion.request import OmniDiffusionRequest
@@ -107,7 +113,7 @@ def calculate_shift(
     return mu
 
 
-class LTX2Pipeline(nn.Module):
+class LTX2Pipeline(nn.Module, CFGParallelMixin):
     def __init__(
         self,
         *,
@@ -605,6 +611,142 @@ def attention_kwargs(self):
     def interrupt(self):
         return self._interrupt
 
+    def _is_cfg_parallel_enabled(self, do_true_cfg: bool) -> bool:
+        return do_true_cfg and get_classifier_free_guidance_world_size() > 1
+
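+    # Single transformer pass; under CFG parallelism each rank runs only its own branch (cond or uncond).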
+    def _predict_noise_av(self, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
+        with self.transformer.cache_context("cond_uncond"):
+            noise_pred_video, noise_pred_audio = self.transformer(**kwargs)
+        return noise_pred_video, noise_pred_audio
+
+    def predict_noise_av_maybe_with_cfg(
+        self,
+        do_true_cfg: bool,
+        true_cfg_scale: float,
+        positive_kwargs: dict[str, Any],
+        negative_kwargs: dict[str, Any] | None,
+        guidance_rescale: float = 0.0,
+        cfg_normalize: bool = False,
+    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+        if do_true_cfg:
+            cfg_parallel_ready = get_classifier_free_guidance_world_size() > 1
+
+            if cfg_parallel_ready:
+                cfg_group = get_cfg_group()
+                cfg_rank = get_classifier_free_guidance_rank()
+
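+                # Rank 0 runs the conditional (positive) branch; the other CFG rank runs the unconditional one.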
+                if cfg_rank == 0:
+                    noise_pred_video, noise_pred_audio = self._predict_noise_av(**positive_kwargs)
+                else:
+                    noise_pred_video, noise_pred_audio = self._predict_noise_av(**negative_kwargs)
+
+                noise_pred_video = noise_pred_video.float()
+                noise_pred_audio = noise_pred_audio.float()
+
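+                # Exchange the per-rank predictions so rank 0 sees both guidance branches.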
+                gathered_video = cfg_group.all_gather(noise_pred_video, separate_tensors=True)
+                gathered_audio = cfg_group.all_gather(noise_pred_audio, separate_tensors=True)
+
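+                # Only rank 0 combines the branches; the other rank returns None and later receives the broadcast latents.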
+                if cfg_rank == 0:
+                    noise_pred_video_text = gathered_video[0]
+                    noise_pred_video_uncond = gathered_video[1]
+                    noise_pred_audio_text = gathered_audio[0]
+                    noise_pred_audio_uncond = gathered_audio[1]
+
+                    noise_pred_video = self.combine_cfg_noise(
+                        noise_pred_video_text,
+                        noise_pred_video_uncond,
+                        true_cfg_scale,
+                        cfg_normalize,
+                    )
+                    noise_pred_audio = self.combine_cfg_noise(
+                        noise_pred_audio_text,
+                        noise_pred_audio_uncond,
+                        true_cfg_scale,
+                        cfg_normalize,
+                    )
+
+                    if guidance_rescale > 0:
+                        noise_pred_video = rescale_noise_cfg(
+                            noise_pred_video,
+                            noise_pred_video_text,
+                            guidance_rescale=guidance_rescale,
+                        )
+                        noise_pred_audio = rescale_noise_cfg(
+                            noise_pred_audio,
+                            noise_pred_audio_text,
+                            guidance_rescale=guidance_rescale,
+                        )
+                    return noise_pred_video, noise_pred_audio
+                return None, None
+
+            noise_pred_video_text, noise_pred_audio_text = self._predict_noise_av(**positive_kwargs)
+            noise_pred_video_uncond, noise_pred_audio_uncond = self._predict_noise_av(**negative_kwargs)
+
+            noise_pred_video_text = noise_pred_video_text.float()
+            noise_pred_audio_text = noise_pred_audio_text.float()
+            noise_pred_video_uncond = noise_pred_video_uncond.float()
+            noise_pred_audio_uncond = noise_pred_audio_uncond.float()
+
+            noise_pred_video = self.combine_cfg_noise(
+                noise_pred_video_text,
+                noise_pred_video_uncond,
+                true_cfg_scale,
+                cfg_normalize,
+            )
+            noise_pred_audio = self.combine_cfg_noise(
+                noise_pred_audio_text,
+                noise_pred_audio_uncond,
+                true_cfg_scale,
+                cfg_normalize,
+            )
+
+            if guidance_rescale > 0:
+                noise_pred_video = rescale_noise_cfg(
+                    noise_pred_video,
+                    noise_pred_video_text,
+                    guidance_rescale=guidance_rescale,
+                )
+                noise_pred_audio = rescale_noise_cfg(
+                    noise_pred_audio,
+                    noise_pred_audio_text,
+                    guidance_rescale=guidance_rescale,
+                )
+
+            return noise_pred_video, noise_pred_audio
+
+        noise_pred_video, noise_pred_audio = self._predict_noise_av(**positive_kwargs)
+        return noise_pred_video.float(), noise_pred_audio.float()
+
+    def _scheduler_step_video_audio_maybe_with_cfg(
+        self,
+        noise_pred_video: torch.Tensor | None,
+        noise_pred_audio: torch.Tensor | None,
+        t: torch.Tensor,
+        latents: torch.Tensor,
+        audio_latents: torch.Tensor,
+        audio_scheduler: FlowMatchEulerDiscreteScheduler,
+        do_true_cfg: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        cfg_parallel_ready = self._is_cfg_parallel_enabled(do_true_cfg)
+
+        if cfg_parallel_ready:
+            cfg_group = get_cfg_group()
+            cfg_rank = get_classifier_free_guidance_rank()
+
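+            # Only rank 0 holds real noise predictions; it steps both schedulers and then broadcasts the results.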
+            if cfg_rank == 0:
+                latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
+                audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
+
+            latents = latents.contiguous()
+            audio_latents = audio_latents.contiguous()
+            cfg_group.broadcast(latents, src=0)
+            cfg_group.broadcast(audio_latents, src=0)
+            return latents, audio_latents
+
+        latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
+        audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
+        return latents, audio_latents
+
     @torch.no_grad()
     def forward(
         self,
@@ -750,7 +892,8 @@ def forward(
             max_sequence_length=max_sequence_length,
             device=device,
         )
-        if self.do_classifier_free_guidance:
+        cfg_parallel_ready = self._is_cfg_parallel_enabled(self.do_classifier_free_guidance)
+        if self.do_classifier_free_guidance and not cfg_parallel_ready:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
 
@@ -759,6 +902,23 @@ def forward(
             prompt_embeds, additive_attention_mask, additive_mask=True
         )
 
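+        # Under CFG parallelism the negative prompt is not folded into the batch above, so run the connectors on it separately.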
+        negative_connector_prompt_embeds = None
+        negative_connector_audio_prompt_embeds = None
+        negative_connector_attention_mask = None
+        if cfg_parallel_ready:
+            negative_additive_attention_mask = (
+                1 - negative_prompt_attention_mask.to(negative_prompt_embeds.dtype)
+            ) * -1000000.0
+            (
+                negative_connector_prompt_embeds,
+                negative_connector_audio_prompt_embeds,
+                negative_connector_attention_mask,
+            ) = self.connectors(
+                negative_prompt_embeds,
+                negative_additive_attention_mask,
+                additive_mask=True,
+            )
+
         latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
         latent_height = height // self.vae_spatial_compression_ratio
         latent_width = width // self.vae_spatial_compression_ratio
@@ -838,58 +998,119 @@ def forward(
 
             self._current_timestep = t
 
-            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-            latent_model_input = latent_model_input.to(prompt_embeds.dtype)
-            audio_latent_model_input = (
-                torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents
-            )
-            audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype)
-
-            timestep = t.expand(latent_model_input.shape[0])
-
-            with self.transformer.cache_context("cond_uncond"):
-                noise_pred_video, noise_pred_audio = self.transformer(
-                    hidden_states=latent_model_input,
-                    audio_hidden_states=audio_latent_model_input,
-                    encoder_hidden_states=connector_prompt_embeds,
-                    audio_encoder_hidden_states=connector_audio_prompt_embeds,
-                    timestep=timestep,
-                    encoder_attention_mask=connector_attention_mask,
-                    audio_encoder_attention_mask=connector_attention_mask,
-                    num_frames=latent_num_frames,
-                    height=latent_height,
-                    width=latent_width,
-                    fps=frame_rate,
-                    audio_num_frames=audio_num_frames,
-                    video_coords=video_coords,
-                    audio_coords=audio_coords,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
+            if cfg_parallel_ready:
+                latent_model_input = latents.to(prompt_embeds.dtype)
+                audio_latent_model_input = audio_latents.to(prompt_embeds.dtype)
+                timestep = t.expand(latent_model_input.shape[0])
+
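+                # Build one kwargs set per guidance branch; each CFG rank consumes exactly one of them.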
+                positive_kwargs = {
+                    "hidden_states": latent_model_input,
+                    "audio_hidden_states": audio_latent_model_input,
+                    "encoder_hidden_states": connector_prompt_embeds,
+                    "audio_encoder_hidden_states": connector_audio_prompt_embeds,
+                    "timestep": timestep,
+                    "encoder_attention_mask": connector_attention_mask,
+                    "audio_encoder_attention_mask": connector_attention_mask,
+                    "num_frames": latent_num_frames,
+                    "height": latent_height,
+                    "width": latent_width,
+                    "fps": frame_rate,
+                    "audio_num_frames": audio_num_frames,
+                    "video_coords": video_coords,
+                    "audio_coords": audio_coords,
+                    "attention_kwargs": attention_kwargs,
+                    "return_dict": False,
+                }
+                negative_kwargs = {
+                    "hidden_states": latent_model_input,
+                    "audio_hidden_states": audio_latent_model_input,
+                    "encoder_hidden_states": negative_connector_prompt_embeds,
+                    "audio_encoder_hidden_states": negative_connector_audio_prompt_embeds,
+                    "timestep": timestep,
+                    "encoder_attention_mask": negative_connector_attention_mask,
+                    "audio_encoder_attention_mask": negative_connector_attention_mask,
+                    "num_frames": latent_num_frames,
+                    "height": latent_height,
+                    "width": latent_width,
+                    "fps": frame_rate,
+                    "audio_num_frames": audio_num_frames,
+                    "video_coords": video_coords,
+                    "audio_coords": audio_coords,
+                    "attention_kwargs": attention_kwargs,
+                    "return_dict": False,
+                }
+
+                noise_pred_video, noise_pred_audio = self.predict_noise_av_maybe_with_cfg(
+                    do_true_cfg=True,
+                    true_cfg_scale=guidance_scale,
+                    positive_kwargs=positive_kwargs,
+                    negative_kwargs=negative_kwargs,
+                    guidance_rescale=guidance_rescale,
+                    cfg_normalize=False,
                 )
-                noise_pred_video = noise_pred_video.float()
-                noise_pred_audio = noise_pred_audio.float()
 
-            if self.do_classifier_free_guidance:
-                noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
-                noise_pred_video = noise_pred_video_uncond + guidance_scale * (
-                    noise_pred_video_text - noise_pred_video_uncond
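+                # Rank 0 advances both schedulers and broadcasts the stepped latents to the other CFG rank.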
+                latents, audio_latents = self._scheduler_step_video_audio_maybe_with_cfg(
+                    noise_pred_video,
+                    noise_pred_audio,
+                    t,
+                    latents,
+                    audio_latents,
+                    audio_scheduler,
+                    do_true_cfg=True,
                 )
-
-                noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
-                noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (
-                    noise_pred_audio_text - noise_pred_audio_uncond
+            else:
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+                audio_latent_model_input = (
+                    torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents
                 )
+                audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype)
+
+                timestep = t.expand(latent_model_input.shape[0])
+
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred_video, noise_pred_audio = self.transformer(
+                        hidden_states=latent_model_input,
+                        audio_hidden_states=audio_latent_model_input,
+                        encoder_hidden_states=connector_prompt_embeds,
+                        audio_encoder_hidden_states=connector_audio_prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=connector_attention_mask,
+                        audio_encoder_attention_mask=connector_attention_mask,
+                        num_frames=latent_num_frames,
+                        height=latent_height,
+                        width=latent_width,
+                        fps=frame_rate,
+                        audio_num_frames=audio_num_frames,
+                        video_coords=video_coords,
+                        audio_coords=audio_coords,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )
+                    noise_pred_video = noise_pred_video.float()
+                    noise_pred_audio = noise_pred_audio.float()
 
-                if guidance_rescale > 0:
-                    noise_pred_video = rescale_noise_cfg(
-                        noise_pred_video, noise_pred_video_text, guidance_rescale=guidance_rescale
+                if self.do_classifier_free_guidance:
+                    noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
+                    noise_pred_video = noise_pred_video_uncond + guidance_scale * (
+                        noise_pred_video_text - noise_pred_video_uncond
                     )
-                    noise_pred_audio = rescale_noise_cfg(
-                        noise_pred_audio, noise_pred_audio_text, guidance_rescale=guidance_rescale
+
+                    noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
+                    noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (
+                        noise_pred_audio_text - noise_pred_audio_uncond
                     )
 
-            latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
-            audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
+                    if guidance_rescale > 0:
+                        noise_pred_video = rescale_noise_cfg(
+                            noise_pred_video, noise_pred_video_text, guidance_rescale=guidance_rescale
+                        )
+                        noise_pred_audio = rescale_noise_cfg(
+                            noise_pred_audio, noise_pred_audio_text, guidance_rescale=guidance_rescale
+                        )
+
+                latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
+                audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
 
             if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                 pass