
[KDA] Add CuteDSL Prefill Kernel on SM100 #27488

Open
yuan-luo wants to merge 2 commits into sgl-project:main from yuan-luo:support_kda_sm100

@yuan-luo (Collaborator) commented Jun 7, 2026

Summary

This PR provides a KDA (Kimi Delta Attention) CuteDSL prefill kernel for SM100/Blackwell, targeting Kimi-Linear models (moonshotai/Kimi-Linear-48B-A3B-Instruct), and adds a regression test that exercises real-magnitude gates. The existing unit test used only near-minimum gates and could not catch the overflow bug described under Motivation / Root cause.

It also (a) makes the corrected prefill faster than Triton (1.08–1.52×) by removing per-call host overhead with a reusable scratch workspace, and (b) fixes a separate cuda-graph-padding crash in the CuteDSL extend wrapper.

Motivation / Root cause

The CuteDSL prefill folds the per-channel decay gate into pre-scaled key/query operands inside a fused Triton prologue:

kL  = k * exp(g_cu - g_last)     # exponent g_cu - g_last >= 0, UNBOUNDED
qg2 = scale * q * exp(g_cu - g_last)

g_cu is the chunk-local cumsum of the (negative) gate and g_last is its value at the chunk's last token, so g_cu - g_last >= 0 has no upper bound. For real Kimi-Linear retention gates this is large: from the checkpoint, exp(A_log) has mean 8.3 / max 201, per-token |g_act| reaches 2661, and the per-chunk g_cu - g_last span reaches ~6e4. exp() overflows fp32 at argument ~88, so kL/qg2 become +inf; their paired operands (exp(g_last - g_cu)) underflow to 0, and the KKT / Aqk MMAs produce inf * 0 = NaN.
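The failure chain can be reproduced without a GPU. The sketch below is illustrative only: it emulates fp32 rounding with `ctypes.c_float`, uses a made-up span of 120 (far smaller than the ~6e4 reported above), and the helper names `f32` and `exp_f32` are hypothetical, not part of the PR.

```python
import ctypes
import math


def f32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return ctypes.c_float(x).value


def exp_f32(x: float) -> float:
    """exp() with the result rounded to fp32; overflow yields +inf."""
    try:
        return f32(math.exp(x))
    except OverflowError:  # math.exp raises once even the fp64 range is exceeded
        return math.inf


span = 120.0  # illustrative g_cu - g_last; an assumption, not a measured value

scaled = exp_f32(span)     # plays the role of exp(g_cu - g_last)
paired = exp_f32(-span)    # plays the role of exp(g_last - g_cu)
product = scaled * paired  # the term an MMA would accumulate

# Pairwise form: the exponent is a difference of two cumsums and is <= 0,
# so the factor stays in (0, 1] and can underflow but never overflow.
pairwise = exp_f32(-0.5)

print(scaled, paired, product, pairwise)
```

The exact product of the two scaled operands would be 1, but each factor leaves fp32 range on its own, which is why the split form fails while the pairwise form does not.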

Because KDA's gate is per-channel, the decay lives inside the MMA contraction and must be baked into the operands; a single chunk-global reference cannot keep both operands in fp32 range. (GDN does not hit this — its gate is scalar and applied as a post-MMA factor.)
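The difference between the two gate types can be written out explicitly. This is a sketch in notation that does not appear in the PR: $i \ge j$ index tokens within a chunk, $d$ indexes channels, $G$ is the chunk-local cumulative gate (g_cu), $r$ is the reference token (the chunk's last token), and the softmax scale is omitted.

```latex
% KDA: per-channel gate, the decay sits inside the contraction over d
A^{\mathrm{KDA}}_{ij} = \sum_{d} q_{id}\, k_{jd}\, \exp\!\left(G_{id} - G_{jd}\right), \qquad i \ge j

% Factored around a single reference token r so that an MMA can compute it:
A^{\mathrm{KDA}}_{ij} = \sum_{d}
  \left[ q_{id}\, e^{\,G_{id} - G_{rd}} \right]
  \left[ k_{jd}\, e^{\,G_{rd} - G_{jd}} \right]

% GDN: scalar gate, the decay factors out and is applied after the MMA
A^{\mathrm{GDN}}_{ij} = \exp\!\left(G_{i} - G_{j}\right) \sum_{d} q_{id}\, k_{jd}
```

In the factored KDA form the combined exponent $G_{id} - G_{jd}$ is non-positive, but the first bracket has a non-negative exponent and the second a non-positive one, so each can leave fp32 range independently.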

The existing unit test used exp(A_log) = 0.22, which is the model's minimum gate, so the per-chunk span stayed small and the overflow never triggered. safe_gate (lower_bound * sigmoid(...), bounded) does not apply: Kimi-Linear's config has no lower_bound and is trained with the standard -exp(A_log) * softplus gate.
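A back-of-the-envelope estimate shows why the old test gate never reached the overflow threshold. The sketch assumes a chunk length of 64, a softplus pre-activation of 0 for every token, and identical gates across the chunk; all three are simplifying assumptions, and `chunk_span` is a hypothetical helper.

```python
import math

# Largest exp() argument that still fits in fp32 (about 88.72).
FP32_EXP_LIMIT = math.log(3.4028234663852886e38)
CHUNK = 64  # assumed chunk length


def softplus(x: float) -> float:
    return math.log1p(math.exp(x))


def chunk_span(exp_a_log: float, pre_act: float = 0.0, chunk: int = CHUNK) -> float:
    """Estimate g_cu - g_last at the chunk's first token for a constant gate."""
    per_token = exp_a_log * softplus(pre_act)  # |g| contributed by one token
    return (chunk - 1) * per_token             # tokens between first and last


for exp_a_log in (0.22, 4.5, 8.3):
    span = chunk_span(exp_a_log)
    print(f"exp(A_log)={exp_a_log}: span={span:.1f}, overflows={span > FP32_EXP_LIMIT}")
```

Under these assumptions only the 0.22 case stays below the fp32 limit; the 4.5 value used by the new regression test and the checkpoint mean of 8.3 both exceed it.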

Modifications

  • .../kernels/kda_blackwell/__init__.py (chunk_kda_cutedsl): numerical fix. The two overflowing operands (kL, qg2) are dropped. Instead, the intra-chunk gated matrices are computed by the sub-chunk-normalized FLA kernel chunk_kda_scaled_dot_kkt_fwd, in which every exp exponent is <= 0 and can therefore underflow but never overflow. The result is injected through the unchanged CuteDSL KKT/Aqk MMAs via an identity right operand: with kL' = M (the gated matrix packed into the first 64 K-slots) and kR' = onehot(chunk-position), the MMA computes kL' @ kR'^T == M. The downstream kkt_inv_uw (beta + mask + Newton-Schulz inverse + U/W), kernel_h, and kernel_o are unchanged and operate on the now-correct matrices. No CuteDSL kernel code is modified.
  • .../kernels/kda_blackwell/__init__.py (_KDA_WS workspace): performance. The CuteDSL kernels were already fast; the full function was slower than Triton purely because of per-call host overhead (re-allocating and re-zeroing eye and the two pack buffers, ~200 MB/call, plus a device sync from total = chunk_offsets[-1].item()). A module-level grow-only workspace keyed by (Hv, K, V, device, dtype) now serves the scratch tensors (kL/qg2/eye/U/W/V_new/h_chunks). The pack writes only the first 64 K-cols (the rest stay zero from the one-time zeroed allocation), and eye/metadata are recomputed only when the cu_seqlens object identity changes (sync-free, so all KDA layers in a forward pass share one computation). Only the returned o/ht are freshly allocated. This is safe because KDA layers run sequentially on one CUDA stream; the numerics are unchanged (o_err 4.88e-4).
  • .../kernels/kda_cutedsl.py (CuteDSLKDAKernel.extend): padding fix. unified_linear_attention_with_output narrows g/beta with [:real_num_tokens], but those tensors carry a leading batch dim, so the slice acts on the batch dim (a no-op) rather than the token dim. Under CUDA graph padding this left g/beta longer than q, tripping the kernel's strict shape check (Mismatched beta.shape[0]). They are now trimmed to q's real token count.
  • .../linear/kda_backend.py: removed a temporary "numerically unstable" guard (no longer applicable).
  • test/registered/attention/test_kda_prefill_cutedsl.py: added test_kda_chunk_cutedsl_realistic_gate (large gates, exp(A_log) ~ 4.5), which asserts finite output matching the recurrent reference.
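The identity-right-operand injection from the first bullet can be checked on a toy problem. This is a minimal pure-Python sketch with made-up sizes (chunk length 4, head dim 8) and hypothetical names; it only demonstrates the algebra kL' @ kR'^T == M, not the CuteDSL kernel.

```python
def matmul_abt(a, b):
    """Return a @ b^T for matrices stored as nested row lists."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b] for row_a in a]


C, K = 4, 8  # toy chunk length and head dim; requires C <= K

# Stand-in for the precomputed gated intra-chunk matrix (lower triangular).
M = [[float(i * C + j + 1) if j <= i else 0.0 for j in range(C)] for i in range(C)]

# Left operand: M packed into the first C of the K slots, zeros elsewhere.
k_left = [row + [0.0] * (K - C) for row in M]

# Right operand: row j is the one-hot vector for chunk position j.
k_right = [[1.0 if k == j else 0.0 for k in range(K)] for j in range(C)]

out = matmul_abt(k_left, k_right)
print(out == M)
```

Because the right operand only selects columns, the contraction reproduces M exactly, which is what lets the existing MMA path carry a matrix that was computed elsewhere.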

Accuracy

moonshotai/Kimi-Linear-48B-A3B-Instruct, TP=2, gsm8k (200 examples, completion API):

| prefill backend | decode backend | gsm8k |
| --- | --- | --- |
| triton (baseline) | triton | 0.915 |
| cutedsl (before this PR) | triton | 0.000 |
| cutedsl (this PR) | triton | 0.920 |

Unit tests on B200 (test_kda_prefill_cutedsl.py): 5 passed (3 varlen correctness cases, internal gate activation, and the new realistic-gate test).

Benchmark

Full prefill function (chunk_kda_cutedsl vs chunk_kda), single sequence, H=32, K=V=128, bf16, realistic Kimi-Linear gates (B200), steady state:

| T | Triton (ms) | CuteDSL this PR (ms) | speedup |
| --- | --- | --- | --- |
| 2048 | 0.525 | 0.346 | 1.52x |
| 4096 | 0.768 | 0.628 | 1.22x |
| 8192 | 1.277 | 1.182 | 1.08x |

o_err vs the token-by-token recurrent reference is 4.88e-4, with all outputs finite, at every size. This matches the correctness baseline.
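The speedup column is the ratio of the Triton time to the CuteDSL time. A short check, with the timings copied from the table above:

```python
# (T, Triton ms, CuteDSL ms) as reported in the benchmark table.
rows = [(2048, 0.525, 0.346), (4096, 0.768, 0.628), (8192, 1.277, 1.182)]

speedups = {t: round(triton_ms / cutedsl_ms, 2) for t, triton_ms, cutedsl_ms in rows}
print(speedups)
```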

Performance note. The CuteDSL KKT/inverse/state/output kernels were always fast; the bottleneck was per-call host overhead: re-allocating and re-zeroing the eye and two pack buffers (~200 MB/call), plus a metadata .item() sync. The intra-chunk matrices are still computed by Triton and injected (this is the sub-chunk normalization that fixes the real-gate overflow), but the scratch tensors are now served from a per-(Hv,K,V,device) grow-only workspace, and the eye/metadata are recomputed only when cu_seqlens changes (sync-free, so all KDA layers in a forward pass share one computation). Net: the prefill now beats Triton at all measured sizes (1.08–1.52x) with no change to the numerics. An earlier attempt to move the off-diagonal KKT into the kernel via tensor-core MMA was abandoned: the row-/col-masked operands [NB,T,Hv,K] inflate operand memory ~4x, making it net-negative, and the diagonal block is per-channel-gated and has no efficient in-kernel (non-serial) form.
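The workspace pattern described here can be sketched in a few lines. This is not the PR's `_KDA_WS` code: the class and method names are hypothetical, `bytearray` stands in for device tensors, and the metadata builder is a placeholder.

```python
class GrowOnlyWorkspace:
    """Hypothetical grow-only scratch cache; not the PR's implementation."""

    def __init__(self):
        self._buffers = {}      # key -> bytearray, zeroed once at allocation
        self._meta_src = None   # the cu_seqlens object the metadata came from
        self._meta = None
        self.meta_builds = 0    # counts how often metadata was recomputed

    def get(self, key, nbytes):
        """Return a buffer of at least nbytes, reallocating only to grow."""
        buf = self._buffers.get(key)
        if buf is None or len(buf) < nbytes:
            buf = bytearray(nbytes)
            self._buffers[key] = buf
        return buf

    def metadata(self, cu_seqlens, build):
        """Rebuild metadata only when a different cu_seqlens object arrives."""
        if self._meta_src is not cu_seqlens:
            self._meta = build(cu_seqlens)
            self._meta_src = cu_seqlens
            self.meta_builds += 1
        return self._meta


ws = GrowOnlyWorkspace()
key = (32, 128, 128, "cuda:0", "bf16")  # (Hv, K, V, device, dtype)

first = ws.get(key, 1024)
reused = ws.get(key, 512)    # smaller request reuses the same buffer
grown = ws.get(key, 4096)    # larger request triggers one reallocation

cu_seqlens = [0, 2048]
for _layer in range(3):      # several layers share one metadata computation
    ws.metadata(cu_seqlens, build=lambda cu: {"total": cu[-1]})
```

Keying on object identity avoids reading values back from the device, at the cost of recomputing when an equal but distinct `cu_seqlens` object is passed. A shared buffer like this is only safe when callers run one at a time, which is the concern the review below raises about multiple CUDA streams.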

Test plan

  • python -m pytest test/registered/attention/test_kda_prefill_cutedsl.py on a Blackwell (SM100) GPU.
  • e2e: python -m sglang.launch_server --model-path moonshotai/Kimi-Linear-48B-A3B-Instruct --tp-size 2 --trust-remote-code --linear-attn-prefill-backend cutedsl --linear-attn-decode-backend triton, then gsm8k.

Checklist

  • Correctness validated on real model (gsm8k 0.920) and unit tests (5 passed).
  • Regression test added for the failure mode (realistic gate magnitudes).
  • Prefill faster than Triton (1.08–1.52x) via host-side workspace reuse.

CI States

Latest PR Test (Base): ❌ Run #27093943363
Latest PR Test (Extra): ✅ Run #27093943326

@gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request introduces a high-performance SM100/Blackwell CuTeDSL prefill pipeline for Kimi Delta Attention (KDA), featuring a fused Triton prologue and three specialized CuTeDSL kernels (KKT-inverse, recurrent-state update, and output). It also adds comprehensive correctness tests and benchmarking scripts. The review feedback highlights a critical concurrency issue where the global scratch workspace could suffer from data corruption if multiple CUDA streams execute concurrently, suggesting the inclusion of the stream ID in the workspace key. Additionally, a performance optimization is suggested to avoid unnecessary type casting by initializing the token index tensor directly as an int32.


Comment thread python/sglang/srt/layers/attention/linear/kernels/kda_blackwell/__init__.py Outdated
CuteDSL chunk prefill pipeline for KDA on SM100, dispatched behind a cutedsl prefill backend that
falls back to Triton on pre-SM100 GPUs (inert by default).
@yuan-luo yuan-luo force-pushed the support_kda_sm100 branch from e3fee40 to 1840775 Compare June 7, 2026 09:25
@yuan-luo yuan-luo force-pushed the support_kda_sm100 branch from 7793956 to 9ce98db Compare June 7, 2026 13:29