@@ -28,6 +28,7 @@
)
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.normalization import AdaLayerNormContinuous
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -39,7 +40,17 @@

from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
from vllm_omni.diffusion.attention.layer import Attention
from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.distributed.parallel_state import (
get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group,
)
from vllm_omni.diffusion.forward_context import get_forward_context
from vllm_omni.diffusion.layers.rope import RotaryEmbedding
from vllm_omni.platforms import current_omni_platform

logger = init_logger(__name__)


class Flux2SwiGLU(nn.Module):
@@ -334,6 +345,12 @@ def forward(


class Flux2SingleTransformerBlock(nn.Module):
"""
Single-stream Transformer block for Flux 2 with SP (Sequence Parallelism) support.

SP handling is delegated to Flux2Attention via the forward context.
"""

def __init__(
self,
dim: int,
@@ -367,6 +384,13 @@ def forward(
split_hidden_states: bool = False,
text_seq_len: int | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for Flux2SingleTransformerBlock with SP support.

In SP mode: image hidden_states is chunked (B, img_len/SP, D),
text encoder_hidden_states is full (B, txt_len, D).
The block concatenates them for joint attention.
"""
if encoder_hidden_states is not None:
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -556,6 +580,8 @@ def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor,
class Flux2Transformer2DModel(nn.Module):
"""
The Transformer model introduced in Flux 2.

Supports Sequence Parallelism (Ulysses and Ring) when configured via OmniDiffusionConfig.
"""

_repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
@@ -580,6 +606,7 @@ def __init__(
rope_theta: int = 2000,
eps: float = 1e-6,
guidance_embeds: bool = True,
od_config: OmniDiffusionConfig | None = None,
):
super().__init__()
self.out_channels = out_channels or in_channels
@@ -601,6 +628,13 @@ def __init__(
guidance_embeds=guidance_embeds,
)

if od_config is not None:
self.parallel_config = od_config.parallel_config
else:
from vllm_omni.diffusion.data import DiffusionParallelConfig

self.parallel_config = DiffusionParallelConfig()

self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=list(axes_dims_rope))
self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
in_channels=timestep_guidance_channels,
@@ -672,6 +706,25 @@ def forward(

num_txt_tokens = encoder_hidden_states.shape[1]

sp_size = self.parallel_config.sequence_parallel_size
get_forward_context().sequence_parallel_size = sp_size
if sp_size > 1:
sp_world_size = get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank()
original_shape = hidden_states.shape
hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_rank]
get_forward_context().split_text_embed_in_sp = False
Review comment on lines +711 to +716 (P2): Pad or validate SP splits for non-divisible seq_len

When sequence_parallel_size > 1, this code shards hidden_states with torch.chunk(...) without padding or validation. If the image token length is not divisible by the SP world size, torch.chunk yields uneven shapes across ranks, and the later get_sp_group().all_gather(output, dim=1) will fail because the group coordinator uses torch.distributed.all_gather_into_tensor, which requires equal-sized tensors. This makes SP mode crash for any input image size where the latent sequence length isn't divisible by the SP degree; the existing SP auto-padding logic in diffusion/hooks/sequence_parallel.py is bypassed here.
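One possible direction, sketched below purely as an illustration (the pad_len computation and the final slice are assumptions, not code from this PR; reusing the existing padding helper in diffusion/hooks/sequence_parallel.py may be the better fix): pad the image sequence to a multiple of the SP world size before chunking, then trim after the final all_gather.

    # Hedged sketch only: pad so seq_len divides evenly across SP ranks.
    # Note: the image rotary embeddings (and any attention masks) would need
    # matching padding, which this sketch does not show.
    import torch.nn.functional as F

    seq_len = hidden_states.shape[1]
    pad_len = (-seq_len) % sp_world_size  # 0 when already divisible
    if pad_len > 0:
        # F.pad pads the last dims first: (0, 0) leaves dim -1 alone,
        # (0, pad_len) appends pad_len zero tokens along the sequence dim.
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_len))
    hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_rank]
    # ...and after get_sp_group().all_gather(output, dim=1):
    # output = output[:, :seq_len]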

if not hasattr(self, "_sp_forward_logged"):
self._sp_forward_logged = True
logger.info(
f"[Flux2 Transformer] SP enabled: sp_size={sp_size}, world_size={sp_world_size}, "
f"rank={sp_rank}, original_shape={original_shape}, chunked_shape={hidden_states.shape}"
)
else:
if not hasattr(self, "_sp_forward_logged"):
self._sp_forward_logged = True
logger.info(f"[Flux2 Transformer] SP disabled: sp_size={sp_size}")

timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
@@ -690,11 +743,27 @@
if txt_ids.ndim == 3:
txt_ids = txt_ids[0]

image_rotary_emb = self.pos_embed(img_ids)
text_rotary_emb = self.pos_embed(txt_ids)
if current_omni_platform.is_npu():

Review comment (Collaborator): @gcanlin do we have better ways to handle this difference? this is so awkward

Review comment (Collaborator): @wtomin PTAL

img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids.cpu())
img_freqs_cos, img_freqs_sin = img_freqs_cos.npu(), img_freqs_sin.npu()
txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids.cpu())
txt_freqs_cos, txt_freqs_sin = txt_freqs_cos.npu(), txt_freqs_sin.npu()
else:
img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids)
txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids)
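Following up on the review thread above, one way to make the platform branch less awkward might be a small local helper, sketched here as an assumption (the name _rope_on_platform is hypothetical and not part of this PR):

    # Hedged sketch: centralize the NPU CPU-roundtrip so call sites
    # do not need to branch on the platform individually.
    def _rope_on_platform(pos_embed, ids):
        if current_omni_platform.is_npu():
            # Same workaround as above: compute RoPE on CPU, move results back to NPU.
            cos, sin = pos_embed(ids.cpu())
            return cos.npu(), sin.npu()
        return pos_embed(ids)

    # usage:
    # img_freqs_cos, img_freqs_sin = _rope_on_platform(self.pos_embed, img_ids)
    # txt_freqs_cos, txt_freqs_sin = _rope_on_platform(self.pos_embed, txt_ids)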

if sp_size > 1:
sp_world_size = get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank()
img_freqs_cos = torch.chunk(img_freqs_cos, sp_world_size, dim=0)[sp_rank]
img_freqs_sin = torch.chunk(img_freqs_sin, sp_world_size, dim=0)[sp_rank]
if get_forward_context().split_text_embed_in_sp:
txt_freqs_cos = torch.chunk(txt_freqs_cos, sp_world_size, dim=0)[sp_rank]
txt_freqs_sin = torch.chunk(txt_freqs_sin, sp_world_size, dim=0)[sp_rank]

concat_rotary_emb = (
torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
torch.cat([txt_freqs_cos, img_freqs_cos], dim=0),
torch.cat([txt_freqs_sin, img_freqs_sin], dim=0),
)

for block in self.transformer_blocks:
@@ -722,6 +791,9 @@ def forward(
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)

if self.parallel_config.sequence_parallel_size > 1:
output = get_sp_group().all_gather(output, dim=1)

if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@@ -230,7 +230,7 @@ def __init__(
).to(self._execution_device)

transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, Flux2Transformer2DModel)
self.transformer = Flux2Transformer2DModel(**transformer_kwargs)
self.transformer = Flux2Transformer2DModel(od_config=od_config, **transformer_kwargs)

self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
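For context, a minimal sketch of how sequence parallelism would be switched on through the config path this PR adds. Only the attribute names (parallel_config, sequence_parallel_size, od_config) are confirmed by the diff; the constructor keyword arguments below are assumptions:

    from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig

    # Assumed constructor kwargs; the diff only shows these attributes being read.
    parallel_config = DiffusionParallelConfig(sequence_parallel_size=2)
    od_config = OmniDiffusionConfig(parallel_config=parallel_config)

    # The pipeline change above then threads od_config into the transformer, which
    # reads od_config.parallel_config.sequence_parallel_size in forward().
    transformer = Flux2Transformer2DModel(od_config=od_config, **transformer_kwargs)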