 )
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.normalization import AdaLayerNormContinuous
+from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
⋮
 
 from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
 from vllm_omni.diffusion.attention.layer import Attention
+from vllm_omni.diffusion.distributed.parallel_state import (
+    get_sequence_parallel_rank,
+    get_sequence_parallel_world_size,
+    get_sp_group,
+)
+from vllm_omni.diffusion.forward_context import get_forward_context
 from vllm_omni.diffusion.layers.rope import RotaryEmbedding
+from vllm_omni.platforms import current_omni_platform
+
+logger = init_logger(__name__)
 
 
 class Flux2SwiGLU(nn.Module):
@@ -334,6 +344,12 @@ def forward(
 
 
 class Flux2SingleTransformerBlock(nn.Module):
+    """
+    Single-stream Transformer block for Flux 2 with SP (Sequence Parallelism) support.
+
+    SP handling is delegated to Flux2Attention via the forward context.
+    """
+
     def __init__(
         self,
         dim: int,
@@ -367,6 +383,13 @@ def forward(
         split_hidden_states: bool = False,
         text_seq_len: int | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        """
+        Forward pass for Flux2SingleTransformerBlock with SP support.
+
+        In SP mode, the image hidden_states arrive pre-chunked to (B, img_len / SP, D)
+        while the text encoder_hidden_states stay full at (B, txt_len, D); the block
+        concatenates them for joint attention.
+        """
         if encoder_hidden_states is not None:
             text_seq_len = encoder_hidden_states.shape[1]
             hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -556,6 +579,8 @@ def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor,
 class Flux2Transformer2DModel(nn.Module):
     """
     The Transformer model introduced in Flux 2.
+
+    Supports Sequence Parallelism (Ulysses and Ring) when configured via OmniDiffusionConfig.
     """
 
     _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
@@ -580,6 +605,7 @@ def __init__(
         rope_theta: int = 2000,
         eps: float = 1e-6,
         guidance_embeds: bool = True,
+        od_config: "OmniDiffusionConfig" = None,
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
@@ -601,6 +627,12 @@ def __init__(
             guidance_embeds=guidance_embeds,
         )
 
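+        # Resolve the parallel configuration; fall back to a default
+        # DiffusionParallelConfig when no od_config is given.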
+        if od_config is not None:
+            self.parallel_config = od_config.parallel_config
+        else:
+            from vllm_omni.diffusion.data import DiffusionParallelConfig
+            self.parallel_config = DiffusionParallelConfig()
+
         self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=list(axes_dims_rope))
         self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
             in_channels=timestep_guidance_channels,
@@ -672,6 +704,25 @@ def forward(
 
         num_txt_tokens = encoder_hidden_states.shape[1]
 
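+        # Sequence parallelism: each rank keeps only its slice of the image tokens along
+        # the sequence dimension; the text embedding stays replicated on every rank
+        # (split_text_embed_in_sp = False), and the split is logged only once.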
+        sp_size = self.parallel_config.sequence_parallel_size
+        get_forward_context().sequence_parallel_size = sp_size
+        if sp_size > 1:
+            sp_world_size = get_sequence_parallel_world_size()
+            sp_rank = get_sequence_parallel_rank()
+            original_shape = hidden_states.shape
+            hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_rank]
+            get_forward_context().split_text_embed_in_sp = False
+            if not hasattr(self, "_sp_forward_logged"):
+                self._sp_forward_logged = True
+                logger.info(
+                    f"[Flux2 Transformer] SP enabled: sp_size={sp_size}, world_size={sp_world_size}, "
+                    f"rank={sp_rank}, original_shape={original_shape}, chunked_shape={hidden_states.shape}"
+                )
+        else:
+            if not hasattr(self, "_sp_forward_logged"):
+                self._sp_forward_logged = True
+                logger.info(f"[Flux2 Transformer] SP disabled: sp_size={sp_size}")
+
         timestep = timestep.to(hidden_states.dtype) * 1000
         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000
@@ -690,11 +741,28 @@ def forward(
         if txt_ids.ndim == 3:
             txt_ids = txt_ids[0]
 
-        image_rotary_emb = self.pos_embed(img_ids)
-        text_rotary_emb = self.pos_embed(txt_ids)
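+        # Rotary frequencies: on NPU the position embedding is evaluated on CPU and the
+        # resulting cos/sin tables are moved back to the device; elsewhere it runs on the
+        # current device directly.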
+        if current_omni_platform.is_npu():
+            img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids.cpu())
+            img_freqs_cos, img_freqs_sin = img_freqs_cos.npu(), img_freqs_sin.npu()
+            txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids.cpu())
+            txt_freqs_cos, txt_freqs_sin = txt_freqs_cos.npu(), txt_freqs_sin.npu()
+        else:
+            img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids)
+            txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids)
+
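+        # Under SP, chunk the image rotary tables to match this rank's image-token slice;
+        # the text tables are only chunked when the text embedding is also split across ranks.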
+        if sp_size > 1:
+            sp_world_size = get_sequence_parallel_world_size()
+            sp_rank = get_sequence_parallel_rank()
+            img_len = img_freqs_cos.shape[0]
+            img_freqs_cos = torch.chunk(img_freqs_cos, sp_world_size, dim=0)[sp_rank]
+            img_freqs_sin = torch.chunk(img_freqs_sin, sp_world_size, dim=0)[sp_rank]
+            if get_forward_context().split_text_embed_in_sp:
+                txt_freqs_cos = torch.chunk(txt_freqs_cos, sp_world_size, dim=0)[sp_rank]
+                txt_freqs_sin = torch.chunk(txt_freqs_sin, sp_world_size, dim=0)[sp_rank]
+
         concat_rotary_emb = (
-            torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
-            torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
+            torch.cat([txt_freqs_cos, img_freqs_cos], dim=0),
+            torch.cat([txt_freqs_sin, img_freqs_sin], dim=0),
         )
 
         for block in self.transformer_blocks:
@@ -722,6 +790,9 @@ def forward(
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)
 
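+        # Under SP, all-gather the per-rank output chunks along the sequence dimension
+        # so every rank returns the full-length output.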
+        if self.parallel_config.sequence_parallel_size > 1:
+            output = get_sp_group().all_gather(output, dim=1)
+
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
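
For reference, a minimal single-process sketch (not part of this diff) of the chunk / all-gather round trip the SP path relies on; the batch size, token count, hidden size, world size of 2, and the use of torch.cat in place of get_sp_group().all_gather(output, dim=1) are illustrative assumptions.

import torch

# (B, img_len, D) image tokens; each SP rank keeps chunks[sp_rank].
hidden_states = torch.randn(1, 8, 16)
sp_world_size = 2
chunks = torch.chunk(hidden_states, sp_world_size, dim=1)
# ... each rank runs the transformer blocks on its (B, img_len / SP, D) slice ...
gathered = torch.cat(chunks, dim=1)  # what the final all_gather along dim=1 restores
assert torch.equal(gathered, hidden_states)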