
[Feat] support SP for FLUX.2-klein#1250

Open
RuixiangMa wants to merge 1 commit into vllm-project:main from RuixiangMa:spforflux2klein

Conversation

@RuixiangMa (Contributor) commented Feb 6, 2026

Purpose

Support sequence parallelism (SP), both Ulysses and Ring attention, for FLUX.2-klein.

Test Plan

Test Result

  • Target image:
  • tp = 1, on 4 × NVIDIA RTX 4090 (24 GB)

curl -s -X POST "http://localhost:8004/v1/images/edits" -F "image=@test.jpg" -F "prompt=Change the sky to orange sunset." -F "guidance_scale=1.0" -F "num_inference_steps=50" -F "n=1" -F "size=1024x1024" -F "output_format=png" | jq -r '.data[0].b64_json' | base64 --decode > output.png

| Configuration | Ulysses degree | Ring degree | Generation Time | Speedup |
| --- | --- | --- | --- | --- |
| Baseline | 1 | 1 | 25.503s | 1.00x |
| Ulysses | 4 | 1 | 13.173s | 1.94x |
| Ring | 1 | 4 | 16.866s | 1.51x |
| Hybrid (Ulysses + Ring) | 2 | 2 | 14.812s | 1.72x |
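As a sanity check, the speedup column is just the baseline time divided by each configuration's time, and every row's Ulysses and Ring degrees multiply to the 4-GPU SP world size. A quick script (values copied from the table above) confirming both:

```python
# Values copied verbatim from the benchmark table.
baseline = 25.503
configs = {
    "Ulysses":               (4, 1, 13.173),
    "Ring":                  (1, 4, 16.866),
    "Hybrid Ulysses + Ring": (2, 2, 14.812),
}

for name, (ulysses, ring, seconds) in configs.items():
    # Hybrid SP composes: total SP world size = ulysses_degree * ring_degree,
    # which matches the 4x RTX 4090 setup for every benchmarked row.
    assert ulysses * ring == 4
    print(f"{name}: {baseline / seconds:.2f}x")
# Ulysses: 1.94x / Ring: 1.51x / Hybrid Ulysses + Ring: 1.72x
```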

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f3436b8532


Comment on lines +709 to +716
if sp_size > 1:
    sp_world_size = get_sequence_parallel_world_size()
    sp_rank = get_sequence_parallel_rank()
    original_shape = hidden_states.shape
    hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_rank]
    get_forward_context().split_text_embed_in_sp = False
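For context, a torch-free, purely illustrative sketch of the round trip this snippet sets up: each SP rank keeps chunk `sp_rank` of the sequence dimension, runs the transformer blocks on its slice, and the per-rank outputs are reassembled by the later `all_gather` along `dim=1`:

```python
def shard(tokens, sp_world_size, sp_rank):
    """Even split along the sequence dim, as torch.chunk(..., dim=1)[sp_rank]
    yields when the sequence length is divisible by sp_world_size."""
    n = len(tokens) // sp_world_size
    return tokens[sp_rank * n:(sp_rank + 1) * n]

# 8 "tokens" across 4 SP ranks: each rank holds 2; concatenating the
# shards in rank order (what all_gather does) restores the sequence.
tokens = list(range(8))
shards = [shard(tokens, 4, r) for r in range(4)]
assert [t for s in shards for t in s] == tokens
```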


P2: Pad or validate SP splits for non-divisible seq_len

When sequence_parallel_size > 1, this code shards hidden_states with torch.chunk(...) without padding or validation. If the image token length is not divisible by the SP world size, torch.chunk yields uneven shapes across ranks, and the later get_sp_group().all_gather(output, dim=1) will fail because the group coordinator uses torch.distributed.all_gather_into_tensor, which requires equal-sized tensors. This makes SP mode crash for any input image size where the latent sequence length isn’t divisible by the SP degree; the existing SP auto-padding logic in diffusion/hooks/sequence_parallel.py is bypassed here.
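To make the failure mode concrete, here is a pure-Python sketch (no torch; helper names are illustrative, not from the codebase) of the shard sizes `torch.chunk` produces and the right-padding that would make them uniform before the split:

```python
import math

def chunk_sizes(seq_len: int, world_size: int) -> list[int]:
    """Shard lengths as torch.chunk(x, world_size, dim=1) produces them:
    every chunk is ceil(seq_len / world_size) long except a shorter tail,
    and there may even be fewer chunks than ranks."""
    step = math.ceil(seq_len / world_size)
    return [min(step, seq_len - i * step)
            for i in range(world_size) if i * step < seq_len]

def sp_pad_len(seq_len: int, world_size: int) -> int:
    """Tokens to right-pad (e.g. via torch.nn.functional.pad along the
    sequence dim) so the split is even; trim them after the all_gather."""
    return (-seq_len) % world_size

assert chunk_sizes(4096, 4) == [1024, 1024, 1024, 1024]  # divisible: fine
assert chunk_sizes(4098, 4) == [1025, 1025, 1025, 1023]  # uneven: all_gather_into_tensor fails
assert chunk_sizes(9, 4) == [3, 3, 3]  # only 3 chunks: rank 3 would IndexError
assert sp_pad_len(4098, 4) == 2        # pad 2 tokens, split evenly, trim after gather
```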


Signed-off-by: Lancer <maruixiang6688@gmail.com>

image_rotary_emb = self.pos_embed(img_ids)
text_rotary_emb = self.pos_embed(txt_ids)
if current_omni_platform.is_npu():
Collaborator

@gcanlin do we have better ways to handle this difference? this is so awkward

Collaborator

@wtomin PTAL
