[feat] Sana-WM Triton Optimizations#29513
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces Triton kernels and helper functions to support SANA-WM fused QKV and camera-branch preprocessing, bidirectional GDN entry points, and camera scan chunkwise operations. The review feedback highlights several critical areas for improvement: a potential device mismatch in prepare_sana_wm_rope_tables when rotary_emb is on a different device, buggy nested checks for init_state_z inside init_state_kv blocks in chunkwise GDN functions that could lead to AttributeError or missing padding, and suboptimal GPU occupancy due to using num_warps=1 in multiple Triton kernel launches.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| torch.ones(N, D, device=device, dtype=torch.float32), | ||
| torch.zeros(N, D, device=device, dtype=torch.float32), | ||
| ) | ||
| freqs = rotary_emb.squeeze(0).squeeze(0) |
There was a problem hiding this comment.
If rotary_emb is not None, freqs is squeezed but not moved to the target device. If rotary_emb is on CPU or a different GPU, this will cause a device mismatch runtime error when the returned rope_cos and rope_sin are used in Triton kernels. We should explicitly move freqs to device.
| freqs = rotary_emb.squeeze(0).squeeze(0) | |
| freqs = rotary_emb.squeeze(0).squeeze(0).to(device) |
| init_kv_padded, init_z_padded = init_state_kv, init_state_z | ||
| if init_state_kv is not None: | ||
| B_, H_, D_in, D_out = init_state_kv.shape | ||
| BLOCK_D_ = I_P_kv.shape[-1] | ||
| if D_in != BLOCK_D_ or D_out != BLOCK_D_: | ||
| pad_in = BLOCK_D_ - D_in | ||
| pad_out = BLOCK_D_ - D_out | ||
| init_kv_padded = torch.nn.functional.pad( | ||
| init_state_kv.transpose(-1, -2).reshape(B_ * H_, D_out, D_in), (0, pad_in, 0, pad_out) | ||
| ).contiguous() | ||
| else: | ||
| init_kv_padded = init_state_kv.transpose(-1, -2).reshape(B_ * H_, BLOCK_D_, BLOCK_D_).contiguous() | ||
| # z: (B, H, D) or (B, H, D, 1) → (BH, BLOCK_D) | ||
| z_ = init_state_z.squeeze(-1) if init_state_z.dim() == 4 else init_state_z | ||
| Bz_, Hz_, Dz_ = z_.shape | ||
| if Dz_ != BLOCK_D_: | ||
| init_z_padded = torch.nn.functional.pad(z_.reshape(Bz_ * Hz_, Dz_), (0, BLOCK_D_ - Dz_)).contiguous() | ||
| else: | ||
| init_z_padded = z_.reshape(Bz_ * Hz_, Dz_).contiguous() |
There was a problem hiding this comment.
The padding and reshaping of init_state_z is nested inside the if init_state_kv is not None: block. If init_state_kv is None but init_state_z is not None, init_z_padded will not be padded or reshaped, leading to a shape mismatch or runtime error in phase_b_triton. Additionally, if init_state_kv is not None but init_state_z is None, calling init_state_z.dim() will raise an AttributeError. We should separate the checks for init_state_kv and init_state_z and default both padded variables to None.
| init_kv_padded, init_z_padded = init_state_kv, init_state_z | |
| if init_state_kv is not None: | |
| B_, H_, D_in, D_out = init_state_kv.shape | |
| BLOCK_D_ = I_P_kv.shape[-1] | |
| if D_in != BLOCK_D_ or D_out != BLOCK_D_: | |
| pad_in = BLOCK_D_ - D_in | |
| pad_out = BLOCK_D_ - D_out | |
| init_kv_padded = torch.nn.functional.pad( | |
| init_state_kv.transpose(-1, -2).reshape(B_ * H_, D_out, D_in), (0, pad_in, 0, pad_out) | |
| ).contiguous() | |
| else: | |
| init_kv_padded = init_state_kv.transpose(-1, -2).reshape(B_ * H_, BLOCK_D_, BLOCK_D_).contiguous() | |
| # z: (B, H, D) or (B, H, D, 1) → (BH, BLOCK_D) | |
| z_ = init_state_z.squeeze(-1) if init_state_z.dim() == 4 else init_state_z | |
| Bz_, Hz_, Dz_ = z_.shape | |
| if Dz_ != BLOCK_D_: | |
| init_z_padded = torch.nn.functional.pad(z_.reshape(Bz_ * Hz_, Dz_), (0, BLOCK_D_ - Dz_)).contiguous() | |
| else: | |
| init_z_padded = z_.reshape(Bz_ * Hz_, Dz_).contiguous() | |
| init_kv_padded, init_z_padded = None, None | |
| BLOCK_D_ = I_P_kv.shape[-1] | |
| if init_state_kv is not None: | |
| B_, H_, D_in, D_out = init_state_kv.shape | |
| if D_in != BLOCK_D_ or D_out != BLOCK_D_: | |
| pad_in = BLOCK_D_ - D_in | |
| pad_out = BLOCK_D_ - D_out | |
| init_kv_padded = torch.nn.functional.pad( | |
| init_state_kv.transpose(-1, -2).reshape(B_ * H_, D_out, D_in), (0, pad_in, 0, pad_out) | |
| ).contiguous() | |
| else: | |
| init_kv_padded = init_state_kv.transpose(-1, -2).reshape(B_ * H_, BLOCK_D_, BLOCK_D_).contiguous() | |
| if init_state_z is not None: | |
| z_ = init_state_z.squeeze(-1) if init_state_z.dim() == 4 else init_state_z | |
| Bz_, Hz_, Dz_ = z_.shape | |
| if Dz_ != BLOCK_D_: | |
| init_z_padded = torch.nn.functional.pad(z_.reshape(Bz_ * Hz_, Dz_), (0, BLOCK_D_ - Dz_)).contiguous() | |
| else: | |
| init_z_padded = z_.reshape(Bz_ * Hz_, Dz_).contiguous() |
| init_kv_padded, init_z_padded = init_state_kv, init_state_z | ||
| if init_state_kv is not None: | ||
| B_, H_, D_in, D_out = init_state_kv.shape | ||
| BLOCK_D_ = I_P_kv.shape[-1] | ||
| if D_in != BLOCK_D_ or D_out != BLOCK_D_: | ||
| pad_in = BLOCK_D_ - D_in | ||
| pad_out = BLOCK_D_ - D_out | ||
| init_kv_padded = torch.nn.functional.pad( | ||
| init_state_kv.transpose(-1, -2).reshape(B_ * H_, D_out, D_in), (0, pad_in, 0, pad_out) | ||
| ).contiguous() | ||
| else: | ||
| init_kv_padded = init_state_kv.transpose(-1, -2).reshape(B_ * H_, BLOCK_D_, BLOCK_D_).contiguous() | ||
| z_ = init_state_z.squeeze(-1) if init_state_z.dim() == 4 else init_state_z | ||
| Bz_, Hz_, Dz_ = z_.shape | ||
| if Dz_ != BLOCK_D_: | ||
| init_z_padded = torch.nn.functional.pad(z_.reshape(Bz_ * Hz_, Dz_), (0, BLOCK_D_ - Dz_)).contiguous() | ||
| else: | ||
| init_z_padded = z_.reshape(Bz_ * Hz_, Dz_).contiguous() | ||
|
|
There was a problem hiding this comment.
The padding and reshaping of init_state_z is nested inside the if init_state_kv is not None: block. If init_state_kv is None but init_state_z is not None, init_z_padded will not be padded or reshaped, leading to a shape mismatch or runtime error in phase_b_triton. Additionally, if init_state_kv is not None but init_state_z is None, calling init_state_z.dim() will raise an AttributeError. We should separate the checks for init_state_kv and init_state_z and default both padded variables to None.
| init_kv_padded, init_z_padded = init_state_kv, init_state_z | |
| if init_state_kv is not None: | |
| B_, H_, D_in, D_out = init_state_kv.shape | |
| BLOCK_D_ = I_P_kv.shape[-1] | |
| if D_in != BLOCK_D_ or D_out != BLOCK_D_: | |
| pad_in = BLOCK_D_ - D_in | |
| pad_out = BLOCK_D_ - D_out | |
| init_kv_padded = torch.nn.functional.pad( | |
| init_state_kv.transpose(-1, -2).reshape(B_ * H_, D_out, D_in), (0, pad_in, 0, pad_out) | |
| ).contiguous() | |
| else: | |
| init_kv_padded = init_state_kv.transpose(-1, -2).reshape(B_ * H_, BLOCK_D_, BLOCK_D_).contiguous() | |
| z_ = init_state_z.squeeze(-1) if init_state_z.dim() == 4 else init_state_z | |
| Bz_, Hz_, Dz_ = z_.shape | |
| if Dz_ != BLOCK_D_: | |
| init_z_padded = torch.nn.functional.pad(z_.reshape(Bz_ * Hz_, Dz_), (0, BLOCK_D_ - Dz_)).contiguous() | |
| else: | |
| init_z_padded = z_.reshape(Bz_ * Hz_, Dz_).contiguous() | |
| init_kv_padded, init_z_padded = None, None | |
| BLOCK_D_ = I_P_kv.shape[-1] | |
| if init_state_kv is not None: | |
| B_, H_, D_in, D_out = init_state_kv.shape | |
| if D_in != BLOCK_D_ or D_out != BLOCK_D_: | |
| pad_in = BLOCK_D_ - D_in | |
| pad_out = BLOCK_D_ - D_out | |
| init_kv_padded = torch.nn.functional.pad( | |
| init_state_kv.transpose(-1, -2).reshape(B_ * H_, D_out, D_in), (0, pad_in, 0, pad_out) | |
| ).contiguous() | |
| else: | |
| init_kv_padded = init_state_kv.transpose(-1, -2).reshape(B_ * H_, BLOCK_D_, BLOCK_D_).contiguous() | |
| if init_state_z is not None: | |
| z_ = init_state_z.squeeze(-1) if init_state_z.dim() == 4 else init_state_z | |
| Bz_, Hz_, Dz_ = z_.shape | |
| if Dz_ != BLOCK_D_: | |
| init_z_padded = torch.nn.functional.pad(z_.reshape(Bz_ * Hz_, Dz_), (0, BLOCK_D_ - Dz_)).contiguous() | |
| else: | |
| init_z_padded = z_.reshape(Bz_ * Hz_, Dz_).contiguous() |
| OUT_STRIDE_D=out_stride_d, | ||
| BLOCK_D_ROPE=block_d_rope, | ||
| BLOCK_GROUPS=block_groups, | ||
| num_warps=1, |
There was a problem hiding this comment.
| K_SCALE=k_scale, | ||
| BLOCK_D_ROPE=block_d_rope, | ||
| BLOCK_GROUPS=block_groups, | ||
| num_warps=1, |
There was a problem hiding this comment.
| EPS=downscale_eps, | ||
| BLOCK_D_ROPE=block_d_rope, | ||
| BLOCK_GROUPS=block_groups, | ||
| num_warps=1, |
There was a problem hiding this comment.
Sana-WM Triton Optimizations
Summary
This PR improves the Sana-WM Triton integration in
multimodal_genwhile keeping the current single-process execution model. It prepares the runtime and kernels for future tensor-parallel RMSNorm handling, but does not introduce TP sharding in this change.What Changed
*_with_inv_rmsTriton APIs for main GDN and camera preprocessing.(B, H, D, N)to(B, N, 3, H, D)by adding direct camera Phase A/C Triton paths.forward_longto use stateful chunkwise Triton for main GDN and camera GDN, with torch fallbacks preserved.Why
Compatibility
Tests
Added:
python/sglang/jit_kernel/tests/diffusion/test_sana_wm_qkv_preprocess.pypython/sglang/jit_kernel/tests/diffusion/test_sana_wm_camera_scan.pypython/sglang/jit_kernel/tests/diffusion/test_sana_wm_main_gdn.pyLocal checks run:
python3 -m py_compile ...python3 -m compileall ...git diff --checkCI States
Latest PR Test (Base): ❌ Run #28292163517
Latest PR Test (Extra): ❌ Run #28292163508