diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py
index ee10d2e0e4..8e1bf3c7af 100644
--- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py
+++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py
@@ -28,6 +28,7 @@
 )
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.normalization import AdaLayerNormContinuous
+from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -39,7 +40,17 @@
 from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
 from vllm_omni.diffusion.attention.layer import Attention
+from vllm_omni.diffusion.data import OmniDiffusionConfig
+from vllm_omni.diffusion.distributed.parallel_state import (
+    get_sequence_parallel_rank,
+    get_sequence_parallel_world_size,
+    get_sp_group,
+)
+from vllm_omni.diffusion.forward_context import get_forward_context
 from vllm_omni.diffusion.layers.rope import RotaryEmbedding
+from vllm_omni.platforms import current_omni_platform
+
+logger = init_logger(__name__)


 class Flux2SwiGLU(nn.Module):
@@ -334,6 +345,12 @@ def forward(


 class Flux2SingleTransformerBlock(nn.Module):
+    """
+    Single-stream Transformer block for Flux 2 with SP (Sequence Parallelism) support.
+
+    SP handling is delegated to Flux2Attention via the forward context.
+    """
+
     def __init__(
         self,
         dim: int,
@@ -367,6 +384,13 @@ def forward(
         split_hidden_states: bool = False,
         text_seq_len: int | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        """
+        Forward pass for Flux2SingleTransformerBlock with SP support.
+
+        In SP mode the image hidden_states arrive pre-chunked as (B, img_len / SP, D),
+        while the text encoder_hidden_states remain full at (B, txt_len, D);
+        the block concatenates them for joint attention.
+        """
         if encoder_hidden_states is not None:
             text_seq_len = encoder_hidden_states.shape[1]
             hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -556,6 +580,8 @@ def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor,
 class Flux2Transformer2DModel(nn.Module):
     """
     The Transformer model introduced in Flux 2.
+
+    Supports Sequence Parallelism (Ulysses and Ring) when configured via OmniDiffusionConfig.
""" _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] @@ -580,6 +606,7 @@ def __init__( rope_theta: int = 2000, eps: float = 1e-6, guidance_embeds: bool = True, + od_config: OmniDiffusionConfig = None, ): super().__init__() self.out_channels = out_channels or in_channels @@ -601,6 +628,13 @@ def __init__( guidance_embeds=guidance_embeds, ) + if od_config is not None: + self.parallel_config = od_config.parallel_config + else: + from vllm_omni.diffusion.data import DiffusionParallelConfig + + self.parallel_config = DiffusionParallelConfig() + self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=list(axes_dims_rope)) self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings( in_channels=timestep_guidance_channels, @@ -672,6 +706,25 @@ def forward( num_txt_tokens = encoder_hidden_states.shape[1] + sp_size = self.parallel_config.sequence_parallel_size + get_forward_context().sequence_parallel_size = sp_size + if sp_size > 1: + sp_world_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + original_shape = hidden_states.shape + hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_rank] + get_forward_context().split_text_embed_in_sp = False + if not hasattr(self, "_sp_forward_logged"): + self._sp_forward_logged = True + logger.info( + f"[Flux2 Transformer] SP enabled: sp_size={sp_size}, world_size={sp_world_size}, " + f"rank={sp_rank}, original_shape={original_shape}, chunked_shape={hidden_states.shape}" + ) + else: + if not hasattr(self, "_sp_forward_logged"): + self._sp_forward_logged = True + logger.info(f"[Flux2 Transformer] SP disabled: sp_size={sp_size}") + timestep = timestep.to(hidden_states.dtype) * 1000 if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 @@ -690,11 +743,27 @@ def forward( if txt_ids.ndim == 3: txt_ids = txt_ids[0] - image_rotary_emb = self.pos_embed(img_ids) - text_rotary_emb = self.pos_embed(txt_ids) + if current_omni_platform.is_npu(): + img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids.cpu()) + img_freqs_cos, img_freqs_sin = img_freqs_cos.npu(), img_freqs_sin.npu() + txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids.cpu()) + txt_freqs_cos, txt_freqs_sin = txt_freqs_cos.npu(), txt_freqs_sin.npu() + else: + img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids) + txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids) + + if sp_size > 1: + sp_world_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + img_freqs_cos = torch.chunk(img_freqs_cos, sp_world_size, dim=0)[sp_rank] + img_freqs_sin = torch.chunk(img_freqs_sin, sp_world_size, dim=0)[sp_rank] + if get_forward_context().split_text_embed_in_sp: + txt_freqs_cos = torch.chunk(txt_freqs_cos, sp_world_size, dim=0)[sp_rank] + txt_freqs_sin = torch.chunk(txt_freqs_sin, sp_world_size, dim=0)[sp_rank] + concat_rotary_emb = ( - torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), - torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), + torch.cat([txt_freqs_cos, img_freqs_cos], dim=0), + torch.cat([txt_freqs_sin, img_freqs_sin], dim=0), ) for block in self.transformer_blocks: @@ -722,6 +791,9 @@ def forward( hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) + if self.parallel_config.sequence_parallel_size > 1: + output = get_sp_group().all_gather(output, dim=1) + if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py 
diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py
index e1ef706c3f..0de6e11a65 100644
--- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py
+++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py
@@ -230,7 +230,7 @@ def __init__(
         ).to(self._execution_device)

         transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, Flux2Transformer2DModel)
-        self.transformer = Flux2Transformer2DModel(**transformer_kwargs)
+        self.transformer = Flux2Transformer2DModel(od_config=od_config, **transformer_kwargs)

         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
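With the pipeline now passing `od_config` through, enabling SP comes down to setting `sequence_parallel_size` on the parallel config before the pipeline is constructed. A minimal sketch, assuming `DiffusionParallelConfig` accepts `sequence_parallel_size` and `OmniDiffusionConfig` accepts `parallel_config` as constructor arguments (the attribute names come from the diff; the exact constructor signatures are assumptions):

```python
from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig

# Hypothetical construction; the real config objects may be built elsewhere
# (e.g. from engine args) rather than instantiated directly like this.
od_config = OmniDiffusionConfig(
    parallel_config=DiffusionParallelConfig(sequence_parallel_size=2),
)

# Flux2Transformer2DModel(od_config=od_config, ...) then reads
# parallel_config.sequence_parallel_size; with the fallback
# DiffusionParallelConfig() the SP branch is skipped, since
# sequence_parallel_size presumably defaults to 1.
```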