[NPU]GLM-4.7-Flash optimize with fused kernels#29509
Code Review
This pull request optimizes the DeepSeek-V2 MLA attention preparation on NPU by introducing a fused split and normalization kernel (fused_split_qk_norm) for smaller sequence lengths. However, the review identifies two critical issues: first, using the fused kernel when context parallel is enabled leads to a NameError because latent_cache is not defined; second, removing the definition of k_pe from the outer scope causes a NameError when m.q_lora_rank is None.
    k_nope = m.kv_a_layernorm(k_nope).unsqueeze(1)
    k_pe = latent_cache[..., m.kv_lora_rank :].unsqueeze(1)
else:
    if qkv_latent.shape[0] < 65536:
When context parallel is enabled (dsa_use_prefill_cp(forward_batch) is True), latent_cache is required later in m.rebuild_cp_kv_cache (line 249). However, if qkv_latent.shape[0] < 65536 is True, the fused kernel fused_split_qk_norm is called, which does not define latent_cache, leading to a NameError at runtime.
We should add a check to ensure we do not use the fused kernel when context parallel is enabled, similar to the check in forward_dsa_prepare_npu.
Suggested change
- if qkv_latent.shape[0] < 65536:
+ if qkv_latent.shape[0] < 65536 and not dsa_use_prefill_cp(forward_batch):
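The failure mode described above can be reproduced in isolation. The following is a toy sketch, not the code under review: `prepare`, its parameters, and the string placeholders are all hypothetical stand-ins for the branch structure of `forward_mla_prepare_npu`. It shows that a name bound in only one branch raises `UnboundLocalError` (a subclass of `NameError`) when read after the other branch ran, and that the suggested guard avoids it.

```python
def prepare(num_tokens, use_cp, guard_cp):
    """Toy model of the branch structure; every name here is hypothetical."""
    # With guard_cp=True the fused path is skipped whenever context
    # parallel is on, mirroring the suggested `and not dsa_use_prefill_cp(...)`.
    use_fused = num_tokens < 65536 and not (guard_cp and use_cp)
    if use_fused:
        # Fused path: produces the split outputs but never binds latent_cache.
        k_nope = "k_nope"
    else:
        # Unfused path: latent_cache is bound before the split.
        latent_cache = "latent_cache"
        k_nope = "k_nope"
    if use_cp:
        # Stands in for the later call that needs latent_cache.
        return latent_cache
    return k_nope


try:
    prepare(num_tokens=8, use_cp=True, guard_cp=False)
except NameError as exc:  # UnboundLocalError is a subclass of NameError
    print(type(exc).__name__)

print(prepare(num_tokens=8, use_cp=True, guard_cp=True))
```

The guard trades the fused kernel's speedup for correctness on the context-parallel path; an alternative would be to have the fused path also produce `latent_cache`, but whether that is feasible depends on the kernel, which is not shown here.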
@@ -217,7 +237,6 @@ def forward_mla_prepare_npu(
k_nope = m.kv_a_layernorm(k_nope).unsqueeze(1)
By removing the line k_pe = latent_cache[..., m.kv_lora_rank :].unsqueeze(1) from the outer scope (previously line 220 on the LEFT side), k_pe is no longer defined when m.q_lora_rank is None (the else branch of the outer conditional). This will cause a NameError when attempting to use k_pe in m.rotary_emb (line 245).
We should define k_pe inside the else block to ensure it is available when m.q_lora_rank is None.
Suggested change
- k_nope = m.kv_a_layernorm(k_nope).unsqueeze(1)
+ k_nope = m.kv_a_layernorm(k_nope).unsqueeze(1)
+ k_pe = latent_cache[..., m.kv_lora_rank :].unsqueeze(1)
Motivation
Introduce a fused Triton kernel (fused_split_qk_norm) to speed up MLA attention preparation for GLM-4.7-Flash on NPU.
Modifications
Replace the original split + RMSNorm pipeline with a fused Triton kernel.
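For readers unfamiliar with the pattern, here is a pure-Python reference sketch of what "split + RMSNorm" computes and how a fused version can produce the same result in one pass over the row. It is an illustration under assumptions, not the Triton kernel from this PR: the function names, the argument layout (`q_rank`, `kv_rank`, a trailing rope slice), and the `eps` default are all hypothetical.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """RMSNorm over one vector: x / sqrt(mean(x^2) + eps) * weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale * w for v, w in zip(x, weight)]


def split_then_norm(row, q_rank, kv_rank, q_weight, kv_weight, eps=1e-6):
    """Unfused reference: slice first, then normalize each slice separately."""
    q = row[:q_rank]
    k_nope = row[q_rank:q_rank + kv_rank]
    k_pe = row[q_rank + kv_rank:]  # rope part is passed through unnormalized
    return rms_norm(q, q_weight, eps), rms_norm(k_nope, kv_weight, eps), k_pe


def fused_split_norm(row, q_rank, kv_rank, q_weight, kv_weight, eps=1e-6):
    """Fused-style sketch: one read pass gathers both sums of squares."""
    q_sq = 0.0
    kv_sq = 0.0
    for i in range(q_rank + kv_rank):
        if i < q_rank:
            q_sq += row[i] * row[i]
        else:
            kv_sq += row[i] * row[i]
    q_scale = 1.0 / math.sqrt(q_sq / q_rank + eps)
    kv_scale = 1.0 / math.sqrt(kv_sq / kv_rank + eps)
    # One write pass emits the normalized slices without intermediate copies.
    q = [row[i] * q_scale * q_weight[i] for i in range(q_rank)]
    k_nope = [row[q_rank + j] * kv_scale * kv_weight[j] for j in range(kv_rank)]
    k_pe = row[q_rank + kv_rank:]
    return q, k_nope, k_pe
```

On an accelerator the benefit of fusing comes from avoiding separate kernel launches and intermediate tensors for the slices; the arithmetic itself is unchanged, which is why the two functions above should agree up to floating-point rounding.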
Accuracy Tests
Before:


After:
Speed Tests and Profiling
Checklist
Review and Merge Process
/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
CI States
Latest PR Test (Base): ❌ Run #28286980781
Latest PR Test (Extra): ❌ Run #28286980724