[KDA] Add CuteDSL Prefill Kernel on SM100 #27488
Code Review
This pull request introduces a high-performance SM100/Blackwell CuTeDSL prefill pipeline for Kimi Delta Attention (KDA), featuring a fused Triton prologue and three specialized CuTeDSL kernels (KKT-inverse, recurrent-state update, and output). It also adds comprehensive correctness tests and benchmarking scripts. The review feedback highlights a critical concurrency issue where the global scratch workspace could suffer from data corruption if multiple CUDA streams execute concurrently, suggesting the inclusion of the stream ID in the workspace key. Additionally, a performance optimization is suggested to avoid unnecessary type casting by initializing the token index tensor directly as an int32.
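The stream-keyed workspace suggested in the review can be sketched as follows. This is a minimal illustration, not the PR's `_KDA_WS` code: `ScratchWorkspace`, `workspace_key`, and `stream_id` are hypothetical names, and plain `bytearray` buffers stand in for device tensors. In PyTorch the stream handle could plausibly come from `torch.cuda.current_stream().cuda_stream`, but whether the PR adopts that is not stated here.

```python
from typing import Dict, Hashable, Tuple


class ScratchWorkspace:
    """Grow-only scratch buffers, one per key (hypothetical stand-in for device scratch)."""

    def __init__(self) -> None:
        self._buffers: Dict[Tuple[Hashable, ...], bytearray] = {}

    def get(self, key: Tuple[Hashable, ...], nbytes: int) -> bytearray:
        buf = self._buffers.get(key)
        if buf is None or len(buf) < nbytes:
            # Allocate (zero-filled) only on first use or when the request grows.
            # Old contents are not copied: this is scratch space, not a cache.
            buf = bytearray(nbytes)
            self._buffers[key] = buf
        return buf


def workspace_key(hv, k, v, device, dtype, stream_id):
    # Including stream_id (an assumption following the review suggestion) gives
    # each stream its own scratch, so concurrent streams cannot overwrite each other.
    return (hv, k, v, device, dtype, stream_id)


ws = ScratchWorkspace()
key_s0 = workspace_key(32, 128, 128, "cuda:0", "bf16", stream_id=0)
key_s1 = workspace_key(32, 128, 128, "cuda:0", "bf16", stream_id=1)

first = ws.get(key_s0, 1024)
reused = ws.get(key_s0, 512)      # fits in the existing buffer: no new allocation
other = ws.get(key_s1, 1024)      # different stream: separate buffer
grown = ws.get(key_s0, 4096)      # larger request: buffer is replaced by a bigger one
```

The trade-off is memory: every distinct stream that touches the kernel keeps its own scratch set alive, which is the price of removing the single-stream assumption.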
CuteDSL chunk prefill pipeline for KDA on SM100, dispatched behind a cutedsl prefill backend that falls back to Triton on pre-SM100 GPUs (inert by default).
Summary
This PR adds a KDA (Kimi Delta Attention) SM100/Blackwell CuteDSL prefill kernel for Kimi-Linear models (`moonshotai/Kimi-Linear-48B-A3B-Instruct`) and a regression test that exercises real-magnitude gates (the existing unit test only used near-minimum gates and could not catch the bug). It also (a) makes the corrected prefill faster than Triton (1.08–1.52×) by removing per-call host overhead with a reusable scratch workspace, and (b) fixes a separate cuda-graph-padding crash in the CuteDSL extend wrapper.
Motivation / Root cause
The CuteDSL prefill folds the per-channel decay gate into pre-scaled key/query operands inside a fused Triton prologue: `g_cu` is the chunk-local cumsum of the (negative) gate and `g_last` is its value at the chunk's last token, so `g_cu - g_last >= 0` has no upper bound. For real Kimi-Linear retention gates this is large: from the checkpoint, `exp(A_log)` has mean 8.3 / max 201, per-token `|g_act|` reaches 2661, and the per-chunk `g_cu - g_last` span reaches ~6e4. `exp()` overflows fp32 at argument ~88, so `kL`/`qg2` become `+inf`; their paired operands (`exp(g_last - g_cu)`) underflow to `0`, and the KKT / Aqk MMAs produce `inf * 0 = NaN`.

Because KDA's gate is per-channel, the decay lives inside the MMA contraction and must be baked into the operands; a single chunk-global reference cannot keep both operands in fp32 range. (GDN does not hit this: its gate is scalar and applied as a post-MMA factor.)
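The overflow described above can be reproduced with scalar arithmetic. The sketch below is an illustration under stated assumptions, not the kernel code: `exp_fp32` is a coarse hypothetical model of fp32 `exp()` (it only models the overflow and underflow cut-offs), the chunk length of 64 and the per-token gate of -1000 are made-up values chosen to land near the ~6e4 span quoted above, and `factored`/`direct` are illustrative names for "scale both operands against a chunk-global `g_last`" versus "exponentiate the pairwise difference".

```python
import math

FLT_MAX = 3.4028234663852886e38             # largest finite fp32 value
FLT_MIN_SUBNORMAL = 1.401298464324817e-45   # smallest positive fp32 subnormal

OVERFLOW_ARG = math.log(FLT_MAX)                   # about 88.72
UNDERFLOW_ARG = math.log(FLT_MIN_SUBNORMAL / 2)    # about -103.97


def exp_fp32(x: float) -> float:
    """Coarse model of fp32 exp(): +inf past overflow, 0 past underflow."""
    if x > OVERFLOW_ARG:
        return math.inf
    if x < UNDERFLOW_ARG:
        return 0.0
    return math.exp(x)


def factored(g_i: float, g_j: float, g_last: float) -> float:
    """Two operands scaled against the chunk-global reference g_last."""
    return exp_fp32(g_i - g_last) * exp_fp32(g_last - g_j)


def direct(g_i: float, g_j: float) -> float:
    """Pairwise form: for i >= j the exponent g_i - g_j is <= 0."""
    return exp_fp32(g_i - g_j)


def chunk_cumsum(gate: float, length: int = 64) -> list:
    """Chunk-local cumsum of a constant negative per-token gate (assumed chunk length)."""
    return [gate * (t + 1) for t in range(length)]


g_big = chunk_cumsum(-1000.0)     # span g_big[0] - g_big[-1] = 63000
g_small = chunk_cumsum(-0.1)      # span 6.3, comfortably inside fp32 range

nan_result = factored(g_big[1], g_big[0], g_big[-1])    # inf * 0
safe_result = direct(g_big[1], g_big[0])                # underflows to 0, stays finite
small_factored = factored(g_small[1], g_small[0], g_small[-1])
small_direct = direct(g_small[1], g_small[0])
```

With small gates both forms agree, which is why a test that only uses near-minimum gates cannot distinguish them; with large gates only the form whose exponents are all non-positive stays finite.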
The existing unit test used `exp(A_log) = 0.22`, which is the model's minimum gate, so the per-chunk span stayed small and the overflow never triggered. `safe_gate` (`lower_bound * sigmoid(...)`, bounded) does not apply: Kimi-Linear's config has no `lower_bound` and is trained with the standard `-exp(A_log) * softplus` gate.

Modifications
- `.../kernels/kda_blackwell/__init__.py` (`chunk_kda_cutedsl`): numerical fix. The two overflowing operands (`kL`, `qg2`) are dropped; instead the intra-chunk gated matrices are computed by the sub-chunk-normalized FLA kernel `chunk_kda_scaled_dot_kkt_fwd` (every `exp` exponent is `<= 0` → underflow-safe, never overflow) and injected through the unchanged CuteDSL KKT/Aqk MMAs via an identity right operand: with `kL' = M` (the gated matrix packed into the first 64 K-slots) and `kR' = onehot(chunk-position)`, the MMA computes `kL' @ kR'^T == M`. The downstream `kkt_inv_uw` (beta + mask + Newton-Schulz inverse + U/W), `kernel_h`, and `kernel_o` are unchanged and operate on the now-correct matrices. No CuteDSL kernel code is modified.
- `.../kernels/kda_blackwell/__init__.py` (`_KDA_WS` workspace): performance. The CuteDSL kernels were already fast; the full function was slower than Triton purely because of per-call host overhead (re-allocating and re-zeroing `eye` and the two pack buffers, ~200 MB/call, plus a `total = chunk_offsets[-1].item()` device sync). A module-level grow-only workspace keyed by `(Hv, K, V, device, dtype)` now serves the scratch tensors (`kL`/`qg2`/`eye`/`U`/`W`/`V_new`/`h_chunks`); the pack writes only the first 64 K-cols (the rest stay zero from the one-time zeroed alloc), and `eye`/metadata are recomputed only when the `cu_seqlens` object identity changes (sync-free, so all KDA layers in a forward share one computation). Only the returned `o`/`ht` are freshly allocated. Safe because KDA layers run sequentially on one CUDA stream; the numerics are unchanged (`o_err` 4.88e-4).
- `.../kernels/kda_cutedsl.py` (`CuteDSLKDAKernel.extend`): padding fix. `unified_linear_attention_with_output` narrows `g`/`beta` with `[:real_num_tokens]`, but those tensors carry a leading batch dim, so the slice trims the batch (a no-op) rather than the token dim. Under cuda-graph padding this left `g`/`beta` longer than `q`, tripping the kernel's strict shape check (`Mismatched beta.shape[0]`). They are now trimmed to `q`'s real token count.
- `.../linear/kda_backend.py`: removed a temporary "numerically unstable" guard (no longer applicable).
- `test/registered/attention/test_kda_prefill_cutedsl.py`: added `test_kda_chunk_cutedsl_realistic_gate` (large gates, `exp(A_log) ~ 4.5`) asserting finite output matching the recurrent reference.

Accuracy
`moonshotai/Kimi-Linear-48B-A3B-Instruct`, TP=2, gsm8k (200 examples, completion API):

Unit tests on B200 (`test_kda_prefill_cutedsl.py`): 5 passed (3 varlen correctness cases, internal gate activation, and the new realistic-gate test).

Benchmark
Full prefill function (`chunk_kda_cutedsl` vs `chunk_kda`), single sequence, H=32, K=V=128, bf16, realistic Kimi-Linear gates (B200), steady state:

`o_err` vs the token-by-token recurrent reference is 4.88e-4 (FINITE) at every size, numerically identical to the correctness baseline.

Performance note. The CuteDSL KKT/inverse/state/output kernels were always fast; the bottleneck was per-call host overhead: re-allocating and re-zeroing the `eye` and two pack buffers (~200 MB/call), plus a metadata `.item()` sync. The intra-chunk matrices are still computed by Triton and injected (the sub-chunk normalization that fixes the real-gate overflow), but the scratch tensors are now served from a per-`(Hv,K,V,device)` grow-only workspace and the `eye`/metadata are recomputed only when `cu_seqlens` changes (sync-free, so all KDA layers in a forward share one computation). Net: the prefill now beats Triton at all sizes (1.08–1.52x) with no change to the numerics. An earlier attempt to move the off-diagonal KKT into the kernel via tensor-core MMA was abandoned: the row-/col-masked operands `[NB,T,Hv,K]` inflate operand memory ~4x, making it net-negative; the diagonal block is per-channel-gated and has no efficient in-kernel (non-serial) form.

Test plan
- `python -m pytest test/registered/attention/test_kda_prefill_cutedsl.py` on a Blackwell (SM100) GPU.
- `python -m sglang.launch_server --model-path moonshotai/Kimi-Linear-48B-A3B-Instruct --tp-size 2 --trust-remote-code --linear-attn-prefill-backend cutedsl --linear-attn-decode-backend triton`, then gsm8k.

Checklist
CI States
Latest PR Test (Base): ❌ Run #27093943363
Latest PR Test (Extra): ✅ Run #27093943326