Commit f3436b8

[Feat] support SP for FLUX.2-klein
Signed-off-by: Lancer <maruixiang6688@gmail.com>
1 parent 6c19f3e commit f3436b8

File tree

2 files changed: +76 -5 lines changed

vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py

Lines changed: 75 additions & 4 deletions

@@ -28,6 +28,7 @@
 )
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.normalization import AdaLayerNormContinuous
+from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -39,7 +40,16 @@
 
 from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
 from vllm_omni.diffusion.attention.layer import Attention
+from vllm_omni.diffusion.distributed.parallel_state import (
+    get_sequence_parallel_rank,
+    get_sequence_parallel_world_size,
+    get_sp_group,
+)
+from vllm_omni.diffusion.forward_context import get_forward_context
 from vllm_omni.diffusion.layers.rope import RotaryEmbedding
+from vllm_omni.platforms import current_omni_platform
+
+logger = init_logger(__name__)
 
 
 class Flux2SwiGLU(nn.Module):
@@ -334,6 +344,12 @@ def forward(
 
 
 class Flux2SingleTransformerBlock(nn.Module):
+    """
+    Single-stream Transformer block for Flux 2 with SP (Sequence Parallelism) support.
+
+    SP handling is delegated to Flux2Attention via the forward context.
+    """
+
     def __init__(
         self,
         dim: int,
@@ -367,6 +383,13 @@ def forward(
         split_hidden_states: bool = False,
         text_seq_len: int | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        """
+        Forward pass for Flux2SingleTransformerBlock with SP support.
+
+        In SP mode: image hidden_states is chunked (B, img_len/SP, D),
+        text encoder_hidden_states is full (B, txt_len, D).
+        The block concatenates them for joint attention.
+        """
         if encoder_hidden_states is not None:
             text_seq_len = encoder_hidden_states.shape[1]
             hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -556,6 +579,8 @@ def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor,
 class Flux2Transformer2DModel(nn.Module):
     """
     The Transformer model introduced in Flux 2.
+
+    Supports Sequence Parallelism (Ulysses and Ring) when configured via OmniDiffusionConfig.
     """
 
     _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
@@ -580,6 +605,7 @@ def __init__(
         rope_theta: int = 2000,
         eps: float = 1e-6,
         guidance_embeds: bool = True,
+        od_config: "OmniDiffusionConfig" = None,
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
@@ -601,6 +627,12 @@ def __init__(
             guidance_embeds=guidance_embeds,
         )
 
+        if od_config is not None:
+            self.parallel_config = od_config.parallel_config
+        else:
+            from vllm_omni.diffusion.data import DiffusionParallelConfig
+            self.parallel_config = DiffusionParallelConfig()
+
         self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=list(axes_dims_rope))
         self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
             in_channels=timestep_guidance_channels,
@@ -672,6 +704,25 @@ def forward(
 
         num_txt_tokens = encoder_hidden_states.shape[1]
 
+        sp_size = self.parallel_config.sequence_parallel_size
+        get_forward_context().sequence_parallel_size = sp_size
+        if sp_size > 1:
+            sp_world_size = get_sequence_parallel_world_size()
+            sp_rank = get_sequence_parallel_rank()
+            original_shape = hidden_states.shape
+            hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_rank]
+            get_forward_context().split_text_embed_in_sp = False
+            if not hasattr(self, "_sp_forward_logged"):
+                self._sp_forward_logged = True
+                logger.info(
+                    f"[Flux2 Transformer] SP enabled: sp_size={sp_size}, world_size={sp_world_size}, "
+                    f"rank={sp_rank}, original_shape={original_shape}, chunked_shape={hidden_states.shape}"
+                )
+        else:
+            if not hasattr(self, "_sp_forward_logged"):
+                self._sp_forward_logged = True
+                logger.info(f"[Flux2 Transformer] SP disabled: sp_size={sp_size}")
+
         timestep = timestep.to(hidden_states.dtype) * 1000
         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000
@@ -690,11 +741,28 @@ def forward(
         if txt_ids.ndim == 3:
             txt_ids = txt_ids[0]
 
-        image_rotary_emb = self.pos_embed(img_ids)
-        text_rotary_emb = self.pos_embed(txt_ids)
+        if current_omni_platform.is_npu():
+            img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids.cpu())
+            img_freqs_cos, img_freqs_sin = img_freqs_cos.npu(), img_freqs_sin.npu()
+            txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids.cpu())
+            txt_freqs_cos, txt_freqs_sin = txt_freqs_cos.npu(), txt_freqs_sin.npu()
+        else:
+            img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids)
+            txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids)
+
+        if sp_size > 1:
+            sp_world_size = get_sequence_parallel_world_size()
+            sp_rank = get_sequence_parallel_rank()
+            img_len = img_freqs_cos.shape[0]
+            img_freqs_cos = torch.chunk(img_freqs_cos, sp_world_size, dim=0)[sp_rank]
+            img_freqs_sin = torch.chunk(img_freqs_sin, sp_world_size, dim=0)[sp_rank]
+            if get_forward_context().split_text_embed_in_sp:
+                txt_freqs_cos = torch.chunk(txt_freqs_cos, sp_world_size, dim=0)[sp_rank]
+                txt_freqs_sin = torch.chunk(txt_freqs_sin, sp_world_size, dim=0)[sp_rank]
+
         concat_rotary_emb = (
-            torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
-            torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
+            torch.cat([txt_freqs_cos, img_freqs_cos], dim=0),
+            torch.cat([txt_freqs_sin, img_freqs_sin], dim=0),
        )
 
         for block in self.transformer_blocks:
@@ -722,6 +790,9 @@ def forward(
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)
 
+        if self.parallel_config.sequence_parallel_size > 1:
+            output = get_sp_group().all_gather(output, dim=1)
+
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
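
For context, the SP flow added above is a chunk-then-all-gather round trip along the sequence dimension: each rank keeps an even slice of the image tokens (plus matching slices of the rotary frequencies), runs the blocks on that slice, and the final all_gather reassembles the full sequence. The snippet below is a minimal single-process sketch of that round trip only; the rank loop and fake_block are stand-ins for real SP ranks and the transformer blocks, and the actual code uses get_sequence_parallel_rank(), get_sequence_parallel_world_size(), and get_sp_group().all_gather(output, dim=1) rather than the plain torch.cat shown here.

import torch


def shard_image_tokens(hidden_states: torch.Tensor, sp_world_size: int, sp_rank: int) -> torch.Tensor:
    # Same call forward() uses: even chunks along the sequence dimension (dim=1).
    return torch.chunk(hidden_states, sp_world_size, dim=1)[sp_rank]


def fake_block(chunk: torch.Tensor) -> torch.Tensor:
    # Stand-in for the transformer blocks: shape-preserving per-rank compute.
    return chunk * 2.0


sp_world_size = 4
hidden_states = torch.randn(1, 4096, 64)  # (B, img_len, D); img_len divisible by the SP size

# Each "rank" processes only its slice of the image sequence; text tokens stay replicated.
per_rank_outputs = [
    fake_block(shard_image_tokens(hidden_states, sp_world_size, rank))
    for rank in range(sp_world_size)
]

# The all-gather at the end of forward() concatenates the per-rank slices back along dim=1.
output = torch.cat(per_rank_outputs, dim=1)
assert output.shape == hidden_states.shape

Note that the image rotary frequencies are chunked with the same boundaries (torch.chunk over the sequence axis), so each rank's positional embeddings stay aligned with its hidden-state slice; the text frequencies are only split when split_text_embed_in_sp is set, which this commit leaves False.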

vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py

Lines changed: 1 addition & 1 deletion

@@ -230,7 +230,7 @@ def __init__(
         ).to(self._execution_device)
 
         transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, Flux2Transformer2DModel)
-        self.transformer = Flux2Transformer2DModel(**transformer_kwargs)
+        self.transformer = Flux2Transformer2DModel(od_config=od_config, **transformer_kwargs)
 
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
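
This one-line pipeline change is what lets the transformer see the parallel configuration at all: when od_config is not forwarded, the transformer's __init__ falls back to a default DiffusionParallelConfig(). The sketch below illustrates that fallback with stand-in dataclasses; the real OmniDiffusionConfig and DiffusionParallelConfig live in vllm_omni and carry more fields, and only the parallel_config and sequence_parallel_size names are taken from the diff (the default of 1 is an assumption for illustration).

from dataclasses import dataclass, field


@dataclass
class DiffusionParallelConfig:
    # Stand-in: assumed default leaves sequence parallelism off.
    sequence_parallel_size: int = 1


@dataclass
class OmniDiffusionConfig:
    # Stand-in: only the field the transformer reads in __init__.
    parallel_config: DiffusionParallelConfig = field(default_factory=DiffusionParallelConfig)


def resolve_parallel_config(od_config):
    # Mirrors Flux2Transformer2DModel.__init__: use the pipeline's parallel_config when
    # od_config is forwarded (the pipeline change above), otherwise fall back to defaults.
    if od_config is not None:
        return od_config.parallel_config
    return DiffusionParallelConfig()


sp_cfg = resolve_parallel_config(OmniDiffusionConfig(DiffusionParallelConfig(sequence_parallel_size=2)))
assert sp_cfg.sequence_parallel_size == 2
assert resolve_parallel_config(None).sequence_parallel_size == 1  # od_config omitted: SP stays off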

0 commit comments
