[feat] Sana-WM Triton Optimizations #29513

Open
sjmshsh wants to merge 12 commits into sgl-project:main from sjmshsh:sana-wm

Conversation

@sjmshsh sjmshsh commented Jun 27, 2026

Contributor

Sana-WM Triton Optimizations

Summary

This PR improves the Sana-WM Triton integration in multimodal_gen while keeping the current single-process execution model. It prepares the runtime and kernels for future tensor-parallel RMSNorm handling, but does not introduce TP sharding in this change.

What Changed

  • Added CUDA unit coverage for Sana-WM Triton QKV preprocessing, fused Q/K inverse RMS, main GDN, camera scan, and stateful chunkwise paths.
  • Routed the runtime through TP-ready *_with_inv_rms Triton APIs for main GDN and camera preprocessing.
  • Removed the camera scan Q/K/V repack from (B, H, D, N) to (B, N, 3, H, D) by adding direct camera Phase A/C Triton paths.
  • Enabled forward_long to use stateful chunkwise Triton for main GDN and camera GDN, with torch fallbacks preserved.
  • Split camera fast-path fallback state into separate scan, preprocess, softmax-preprocess, and output paths.

Why

  • Future TP support needs Q/K inverse RMS to be computed outside the local kernel, where cross-rank reductions can be inserted later.
  • Camera scan was paying avoidable memory traffic to repack Q/K/V into the generic GDN layout.
  • Streaming inference can reuse the existing stateful chunkwise Triton kernels instead of always falling back to torch scans.
  • A failure in one camera Triton subpath should not disable all camera fast paths.
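The first point above can be sketched in plain Python. This is a minimal illustration, not the PR's kernel code: it assumes the inverse RMS is taken over a single feature vector, that a tensor-parallel split would shard that vector across ranks, and that `eps` is added to the mean square. The helper names are hypothetical, and the all-reduce is simulated with an ordinary sum.

```python
import math


def local_sum_of_squares(shard):
    """Partial reduction a rank can compute on its own slice of the vector."""
    return sum(x * x for x in shard)


def inv_rms_from_partials(partials, full_dim, eps=1e-6):
    """Combine per-rank partial sums into one inverse RMS.

    Under real tensor parallelism the sum over `partials` would be an
    all-reduce; here it is a plain Python sum over simulated ranks.
    """
    total = sum(partials)
    return 1.0 / math.sqrt(total / full_dim + eps)


q = [0.5, -1.0, 2.0, 0.25, -0.75, 1.5, -2.0, 1.0]

# Single-process case: one "rank" owns the whole vector.
inv_rms_single = inv_rms_from_partials([local_sum_of_squares(q)], len(q))

# Hypothetical 2-way split: each rank reduces its half, then the partials are summed.
shards = [q[:4], q[4:]]
inv_rms_sharded = inv_rms_from_partials(
    [local_sum_of_squares(s) for s in shards], len(q)
)

# A kernel that receives the inverse RMS as an input only has to scale.
q_normed = [x * inv_rms_sharded for x in q]
```

Because the kernel only consumes the precomputed scale, the same kernel could serve both cases; only the code that produces the scale would need to know about ranks.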

Compatibility

  • No TP partitioning is enabled in this PR.
  • Existing non-TP behavior is preserved.
  • Existing public Sana-WM Triton entry points remain available; the new state and cached-RoPE arguments are optional.
  • Runtime fallback to torch remains available when Triton guards reject an input or a kernel path fails.
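The per-path fallback behavior described in this PR can be pictured with a small sketch. It is an assumption-laden illustration, not the runtime's actual code: `FastPathState`, the path names, and the stand-in functions are all hypothetical, and real code would likely catch narrower exception types and log the failure.

```python
class FastPathState:
    """Tracks which fast paths are still enabled, one flag per path."""

    def __init__(self, paths):
        self._enabled = {name: True for name in paths}

    def run(self, name, fast_fn, fallback_fn, *args):
        """Try the fast path for `name`; on failure disable only that path."""
        if self._enabled[name]:
            try:
                return fast_fn(*args)
            except Exception:
                # Only this path is switched off; the others keep their flags.
                self._enabled[name] = False
        return fallback_fn(*args)

    def enabled(self, name):
        return self._enabled[name]


def broken_fast_scan(x):
    raise RuntimeError("simulated kernel failure")


def fast_preprocess(x):
    return [v * 2 for v in x]


def reference_scan(x):
    return list(x)


def reference_preprocess(x):
    return [v * 2 for v in x]


state = FastPathState(["scan", "preprocess", "softmax_preprocess", "output"])
scan_out = state.run("scan", broken_fast_scan, reference_scan, [1, 2, 3])
prep_out = state.run("preprocess", fast_preprocess, reference_preprocess, [1, 2, 3])
```

With a single shared flag, the failed scan would have pushed preprocessing onto the slow path as well; with one flag per path it stays on the fast path.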

Tests

Added:

  • python/sglang/jit_kernel/tests/diffusion/test_sana_wm_qkv_preprocess.py
  • python/sglang/jit_kernel/tests/diffusion/test_sana_wm_camera_scan.py
  • python/sglang/jit_kernel/tests/diffusion/test_sana_wm_main_gdn.py

Local checks run:

  • python3 -m py_compile ...
  • python3 -m compileall ...
  • git diff --check

CI States

Latest PR Test (Base): ❌ Run #28292163517
Latest PR Test (Extra): ❌ Run #28292163508

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces Triton kernels and helper functions to support Sana-WM fused QKV and camera-branch preprocessing, bidirectional GDN entry points, and camera scan chunkwise operations. The review feedback highlights three areas for improvement. First, prepare_sana_wm_rope_tables can hit a device mismatch when rotary_emb is on a different device. Second, the chunkwise GDN functions nest the init_state_z handling inside the init_state_kv block, which can raise an AttributeError or skip padding. Third, several Triton kernel launches use num_warps=1, which may limit GPU occupancy.


torch.ones(N, D, device=device, dtype=torch.float32),
torch.zeros(N, D, device=device, dtype=torch.float32),
)
freqs = rotary_emb.squeeze(0).squeeze(0)

high

If rotary_emb is not None, freqs is squeezed but not moved to the target device. If rotary_emb is on CPU or a different GPU, this will cause a device mismatch runtime error when the returned rope_cos and rope_sin are used in Triton kernels. We should explicitly move freqs to device.

Suggested change
- freqs = rotary_emb.squeeze(0).squeeze(0)
+ freqs = rotary_emb.squeeze(0).squeeze(0).to(device)

Comment on lines +1843 to +1861
init_kv_padded, init_z_padded = init_state_kv, init_state_z
if init_state_kv is not None:
    B_, H_, D_in, D_out = init_state_kv.shape
    BLOCK_D_ = I_P_kv.shape[-1]
    if D_in != BLOCK_D_ or D_out != BLOCK_D_:
        pad_in = BLOCK_D_ - D_in
        pad_out = BLOCK_D_ - D_out
        init_kv_padded = torch.nn.functional.pad(
            init_state_kv.transpose(-1, -2).reshape(B_ * H_, D_out, D_in), (0, pad_in, 0, pad_out)
        ).contiguous()
    else:
        init_kv_padded = init_state_kv.transpose(-1, -2).reshape(B_ * H_, BLOCK_D_, BLOCK_D_).contiguous()
    # z: (B, H, D) or (B, H, D, 1) → (BH, BLOCK_D)
    z_ = init_state_z.squeeze(-1) if init_state_z.dim() == 4 else init_state_z
    Bz_, Hz_, Dz_ = z_.shape
    if Dz_ != BLOCK_D_:
        init_z_padded = torch.nn.functional.pad(z_.reshape(Bz_ * Hz_, Dz_), (0, BLOCK_D_ - Dz_)).contiguous()
    else:
        init_z_padded = z_.reshape(Bz_ * Hz_, Dz_).contiguous()

high

The padding and reshaping of init_state_z is nested inside the if init_state_kv is not None: block. If init_state_kv is None but init_state_z is not None, init_z_padded will not be padded or reshaped, leading to a shape mismatch or runtime error in phase_b_triton. Additionally, if init_state_kv is not None but init_state_z is None, calling init_state_z.dim() will raise an AttributeError. We should separate the checks for init_state_kv and init_state_z and default both padded variables to None.

Suggested change
- init_kv_padded, init_z_padded = init_state_kv, init_state_z
- if init_state_kv is not None:
-     B_, H_, D_in, D_out = init_state_kv.shape
-     BLOCK_D_ = I_P_kv.shape[-1]
-     if D_in != BLOCK_D_ or D_out != BLOCK_D_:
-         pad_in = BLOCK_D_ - D_in
-         pad_out = BLOCK_D_ - D_out
-         init_kv_padded = torch.nn.functional.pad(
-             init_state_kv.transpose(-1, -2).reshape(B_ * H_, D_out, D_in), (0, pad_in, 0, pad_out)
-         ).contiguous()
-     else:
-         init_kv_padded = init_state_kv.transpose(-1, -2).reshape(B_ * H_, BLOCK_D_, BLOCK_D_).contiguous()
-     # z: (B, H, D) or (B, H, D, 1) → (BH, BLOCK_D)
-     z_ = init_state_z.squeeze(-1) if init_state_z.dim() == 4 else init_state_z
-     Bz_, Hz_, Dz_ = z_.shape
-     if Dz_ != BLOCK_D_:
-         init_z_padded = torch.nn.functional.pad(z_.reshape(Bz_ * Hz_, Dz_), (0, BLOCK_D_ - Dz_)).contiguous()
-     else:
-         init_z_padded = z_.reshape(Bz_ * Hz_, Dz_).contiguous()
+ init_kv_padded, init_z_padded = None, None
+ BLOCK_D_ = I_P_kv.shape[-1]
+ if init_state_kv is not None:
+     B_, H_, D_in, D_out = init_state_kv.shape
+     if D_in != BLOCK_D_ or D_out != BLOCK_D_:
+         pad_in = BLOCK_D_ - D_in
+         pad_out = BLOCK_D_ - D_out
+         init_kv_padded = torch.nn.functional.pad(
+             init_state_kv.transpose(-1, -2).reshape(B_ * H_, D_out, D_in), (0, pad_in, 0, pad_out)
+         ).contiguous()
+     else:
+         init_kv_padded = init_state_kv.transpose(-1, -2).reshape(B_ * H_, BLOCK_D_, BLOCK_D_).contiguous()
+ if init_state_z is not None:
+     z_ = init_state_z.squeeze(-1) if init_state_z.dim() == 4 else init_state_z
+     Bz_, Hz_, Dz_ = z_.shape
+     if Dz_ != BLOCK_D_:
+         init_z_padded = torch.nn.functional.pad(z_.reshape(Bz_ * Hz_, Dz_), (0, BLOCK_D_ - Dz_)).contiguous()
+     else:
+         init_z_padded = z_.reshape(Bz_ * Hz_, Dz_).contiguous()
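The independent handling suggested in this comment can be exercised on CPU with a small standalone helper. This is a sketch under assumptions: `pad_initial_states` and its `block_d` parameter are hypothetical names, not functions from the PR, and the helper always calls `pad` (a zero-width pad leaves the values unchanged) rather than branching on whether padding is needed.

```python
import torch
import torch.nn.functional as F


def pad_initial_states(init_state_kv, init_state_z, block_d):
    """Pad optional initial states to `block_d`, handling each one independently.

    init_state_kv: (B, H, D_in, D_out) or None -> (B*H, block_d, block_d) or None
    init_state_z:  (B, H, D) or (B, H, D, 1) or None -> (B*H, block_d) or None
    """
    kv_padded, z_padded = None, None

    if init_state_kv is not None:
        b, h, d_in, d_out = init_state_kv.shape
        kv_t = init_state_kv.transpose(-1, -2).reshape(b * h, d_out, d_in)
        # F.pad takes (left, right) pairs starting from the last dim: D_in, then D_out.
        kv_padded = F.pad(kv_t, (0, block_d - d_in, 0, block_d - d_out)).contiguous()

    if init_state_z is not None:
        z = init_state_z.squeeze(-1) if init_state_z.dim() == 4 else init_state_z
        b, h, d = z.shape
        z_padded = F.pad(z.reshape(b * h, d), (0, block_d - d)).contiguous()

    return kv_padded, z_padded


kv = torch.randn(2, 3, 5, 5)
z = torch.randn(2, 3, 5, 1)

only_z = pad_initial_states(None, z, block_d=8)    # previously left z unpadded
only_kv = pad_initial_states(kv, None, block_d=8)  # previously raised AttributeError
both = pad_initial_states(kv, z, block_d=8)
```

The two single-state calls are the cases the review flags: with the nested version, the first would skip the z padding and the second would fail on `init_state_z.dim()`.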

Comment on lines +1971 to +1989
init_kv_padded, init_z_padded = init_state_kv, init_state_z
if init_state_kv is not None:
    B_, H_, D_in, D_out = init_state_kv.shape
    BLOCK_D_ = I_P_kv.shape[-1]
    if D_in != BLOCK_D_ or D_out != BLOCK_D_:
        pad_in = BLOCK_D_ - D_in
        pad_out = BLOCK_D_ - D_out
        init_kv_padded = torch.nn.functional.pad(
            init_state_kv.transpose(-1, -2).reshape(B_ * H_, D_out, D_in), (0, pad_in, 0, pad_out)
        ).contiguous()
    else:
        init_kv_padded = init_state_kv.transpose(-1, -2).reshape(B_ * H_, BLOCK_D_, BLOCK_D_).contiguous()
    z_ = init_state_z.squeeze(-1) if init_state_z.dim() == 4 else init_state_z
    Bz_, Hz_, Dz_ = z_.shape
    if Dz_ != BLOCK_D_:
        init_z_padded = torch.nn.functional.pad(z_.reshape(Bz_ * Hz_, Dz_), (0, BLOCK_D_ - Dz_)).contiguous()
    else:
        init_z_padded = z_.reshape(Bz_ * Hz_, Dz_).contiguous()

high

The padding and reshaping of init_state_z is nested inside the if init_state_kv is not None: block. If init_state_kv is None but init_state_z is not None, init_z_padded will not be padded or reshaped, leading to a shape mismatch or runtime error in phase_b_triton. Additionally, if init_state_kv is not None but init_state_z is None, calling init_state_z.dim() will raise an AttributeError. We should separate the checks for init_state_kv and init_state_z and default both padded variables to None.

Suggested change
- init_kv_padded, init_z_padded = init_state_kv, init_state_z
- if init_state_kv is not None:
-     B_, H_, D_in, D_out = init_state_kv.shape
-     BLOCK_D_ = I_P_kv.shape[-1]
-     if D_in != BLOCK_D_ or D_out != BLOCK_D_:
-         pad_in = BLOCK_D_ - D_in
-         pad_out = BLOCK_D_ - D_out
-         init_kv_padded = torch.nn.functional.pad(
-             init_state_kv.transpose(-1, -2).reshape(B_ * H_, D_out, D_in), (0, pad_in, 0, pad_out)
-         ).contiguous()
-     else:
-         init_kv_padded = init_state_kv.transpose(-1, -2).reshape(B_ * H_, BLOCK_D_, BLOCK_D_).contiguous()
-     z_ = init_state_z.squeeze(-1) if init_state_z.dim() == 4 else init_state_z
-     Bz_, Hz_, Dz_ = z_.shape
-     if Dz_ != BLOCK_D_:
-         init_z_padded = torch.nn.functional.pad(z_.reshape(Bz_ * Hz_, Dz_), (0, BLOCK_D_ - Dz_)).contiguous()
-     else:
-         init_z_padded = z_.reshape(Bz_ * Hz_, Dz_).contiguous()
+ init_kv_padded, init_z_padded = None, None
+ BLOCK_D_ = I_P_kv.shape[-1]
+ if init_state_kv is not None:
+     B_, H_, D_in, D_out = init_state_kv.shape
+     if D_in != BLOCK_D_ or D_out != BLOCK_D_:
+         pad_in = BLOCK_D_ - D_in
+         pad_out = BLOCK_D_ - D_out
+         init_kv_padded = torch.nn.functional.pad(
+             init_state_kv.transpose(-1, -2).reshape(B_ * H_, D_out, D_in), (0, pad_in, 0, pad_out)
+         ).contiguous()
+     else:
+         init_kv_padded = init_state_kv.transpose(-1, -2).reshape(B_ * H_, BLOCK_D_, BLOCK_D_).contiguous()
+ if init_state_z is not None:
+     z_ = init_state_z.squeeze(-1) if init_state_z.dim() == 4 else init_state_z
+     Bz_, Hz_, Dz_ = z_.shape
+     if Dz_ != BLOCK_D_:
+         init_z_padded = torch.nn.functional.pad(z_.reshape(Bz_ * Hz_, Dz_), (0, BLOCK_D_ - Dz_)).contiguous()
+     else:
+         init_z_padded = z_.reshape(Bz_ * Hz_, Dz_).contiguous()

OUT_STRIDE_D=out_stride_d,
BLOCK_D_ROPE=block_d_rope,
BLOCK_GROUPS=block_groups,
num_warps=1,

medium

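Using num_warps=1 (32 threads per program) is low and can lead to poor GPU occupancy and under-utilization of execution resources when each program processes a large tile. num_warps=4 (128 threads) is usually a safer default on modern GPUs, but the best value depends on the tile size, so it is worth benchmarking or autotuning before settling on one.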

Suggested change
- num_warps=1,
+ num_warps=4,
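Rather than hard-coding either value, the warp count can be derived from the tile size. The sketch below is modeled on the kind of size-based heuristic that appears in Triton's tutorials (`min(max(BLOCK_SIZE // 256, 1), 8)`). The helper names, the 256-elements-per-warp ratio, and the tile shapes are illustrative assumptions, not values taken from this PR.

```python
def pick_num_warps(tile_elems, elems_per_warp=256, max_warps=8):
    """Scale the warp count with tile size, clamped to [1, max_warps].

    Assumes power-of-two tile sizes, which keeps the result a power of two.
    """
    return min(max(tile_elems // elems_per_warp, 1), max_warps)


def threads_per_program(num_warps, warp_size=32):
    """Threads launched per Triton program for a given warp count."""
    return num_warps * warp_size


# Hypothetical tile shapes, read as BLOCK_D_ROPE x BLOCK_GROUPS.
small = pick_num_warps(32 * 4)     # 128 elements
medium = pick_num_warps(128 * 8)   # 1024 elements
large = pick_num_warps(256 * 16)   # 4096 elements
```

Under this heuristic a small tile still gets a single warp, so whether num_warps=1 is actually a problem for these launches depends on the tile sizes they use. Measuring on the target GPU would settle it.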

K_SCALE=k_scale,
BLOCK_D_ROPE=block_d_rope,
BLOCK_GROUPS=block_groups,
num_warps=1,

medium

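Using num_warps=1 (32 threads per program) is low and can lead to poor GPU occupancy and under-utilization of execution resources when each program processes a large tile. num_warps=4 (128 threads) is usually a safer default on modern GPUs, but the best value depends on the tile size, so it is worth benchmarking or autotuning before settling on one.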

Suggested change
- num_warps=1,
+ num_warps=4,

EPS=downscale_eps,
BLOCK_D_ROPE=block_d_rope,
BLOCK_GROUPS=block_groups,
num_warps=1,

medium

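Using num_warps=1 (32 threads per program) is low and can lead to poor GPU occupancy and under-utilization of execution resources when each program processes a large tile. num_warps=4 (128 threads) is usually a safer default on modern GPUs, but the best value depends on the tile size, so it is worth benchmarking or autotuning before settling on one.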

Suggested change
- num_warps=1,
+ num_warps=4,

@github-actions github-actions Bot added the diffusion SGLang Diffusion label Jun 27, 2026
@sjmshsh sjmshsh changed the title from "[feat] SANA-WM fused-gdn: TP-aware QKV/CAM preprocessing & optimized chunkwise execution" to "[feat] Sana-WM Triton Optimizations" Jun 27, 2026

Labels

diffusion SGLang Diffusion jit-kernel
