Skip to content

Commit e3fee40

Browse files
author
luoyuan.luo
committed
[KDA] Add SM100/Blackwell CuteDSL prefill kernels + backend wiring
CuteDSL chunk prefill pipeline for KDA on SM100, dispatched behind a cutedsl prefill backend that falls back to Triton on pre-SM100 GPUs (inert by default).
1 parent dd17638 commit e3fee40

9 files changed

Lines changed: 2992 additions & 5 deletions

File tree

Lines changed: 357 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,357 @@
1+
"""
2+
Benchmark & Correctness: Triton KDA vs CuTeDSL KDA (prefill, SM100 Blackwell).
3+
4+
Compares:
5+
- Triton: sglang's chunk_kda (FLA chunkwise gated delta rule, per-channel gate)
6+
- CuTeDSL: kda_blackwell pipeline (fused Triton prologue -> kkt_inv_uw -> h -> o)
7+
8+
KDA differs from GDN by a PER-CHANNEL decay gate (g is [T, H, K], not scalar).
9+
The cutedsl pipeline externalizes the per-channel decay into five pre-scaled
10+
key/query tensors computed by a fused Triton prologue; the chunk metadata is
11+
computed once and shared across layers in a real forward, so the benchmarked
12+
cutedsl path precomputes it outside the timed region (the realistic ceiling).
13+
14+
Correctness is checked against the token-by-token fused_recurrent_kda ground
15+
truth. Reports performance (ms, approx TFLOPS, TB/s, speedup).
16+
17+
Usage:
18+
python bench_kda_prefill_cutedsl.py # default sweep
19+
python bench_kda_prefill_cutedsl.py --mode bench # benchmark only
20+
python bench_kda_prefill_cutedsl.py --mode correctness # correctness only
21+
"""
22+
23+
import argparse
24+
import os
25+
import sys
26+
27+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "python"))
28+
29+
import torch
30+
import torch.nn.functional as F
31+
32+
from sglang.srt.layers.attention.fla.kda import chunk_kda, fused_recurrent_kda
33+
from sglang.srt.layers.attention.linear.kernels.kda_blackwell import prepare_metadata
34+
from sglang.srt.layers.attention.linear.kernels.kda_blackwell.kernel_h import (
35+
kda_h_cutedsl,
36+
)
37+
from sglang.srt.layers.attention.linear.kernels.kda_blackwell.kernel_kkt_inv_uw import (
38+
kkt_inv_uw_cutedsl,
39+
)
40+
from sglang.srt.layers.attention.linear.kernels.kda_blackwell.kernel_o import (
41+
kda_o_cutedsl,
42+
)
43+
from sglang.srt.layers.attention.linear.kernels.kda_blackwell.prologue import (
44+
kda_prologue,
45+
)
46+
47+
BT = 64 # chunk size
48+
49+
# ---------------------------------------------------------------------------
50+
# Helpers
51+
# ---------------------------------------------------------------------------
52+
53+
54+
def _l2norm(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` upcast to fp32 and scaled to unit L2 norm along the last dim."""
    as_fp32 = x.to(torch.float32)
    return F.normalize(as_fp32, p=2, dim=-1)
56+
57+
58+
def kda_flops(total_seq_len, num_heads, head_k, head_v):
    """Approximate FLOP count for one KDA prefill pass.

    Per token and per head the estimate charges:
      * the k@v outer product and the q@state read (2*K*V each), and
      * the intra-chunk KKT product (2*K*K, averaged over the chunk).
    The cost of the triangular inverse is ignored.
    """
    state_flops = 4 * head_k * head_v
    kkt_flops = 2 * head_k * head_k
    return total_seq_len * num_heads * (state_flops + kkt_flops)
62+
63+
64+
def kda_bytes(total_seq_len, num_heads, head_k, head_v, num_seqs, dtype):
65+
elem = dtype.itemsize
66+
q_b = total_seq_len * num_heads * head_k * elem
67+
k_b = total_seq_len * num_heads * head_k * elem
68+
v_b = total_seq_len * num_heads * head_v * elem
69+
o_b = total_seq_len * num_heads * head_v * elem
70+
g_b = total_seq_len * num_heads * head_k * 4 # per-channel gate, fp32
71+
beta_b = total_seq_len * num_heads * 4
72+
state_b = 2 * num_seqs * num_heads * head_k * head_v * 4 # fp32 r/w
73+
return q_b + k_b + v_b + o_b + g_b + beta_b + state_b
74+
75+
76+
# ---------------------------------------------------------------------------
77+
# Input factory (single sequence per benchmark point, B=1)
78+
# ---------------------------------------------------------------------------
79+
80+
81+
def make_inputs(T, H, K, V, device, dtype, seed=42):
    """Build one seeded, single-sequence (batch dim = 1) set of KDA inputs.

    q and k are L2-normalised over the last dim before the cast to ``dtype``;
    the activated gate ``g_act`` and ``beta`` are kept in fp32. The random
    draws happen in a fixed order so a given seed always yields the same
    tensors.
    """
    torch.manual_seed(seed)

    def _randn(*shape):
        return torch.randn(*shape, device=device)

    q = _l2norm(_randn(1, T, H, K)).to(dtype)
    k = _l2norm(_randn(1, T, H, K)).to(dtype)
    v = _randn(1, T, H, V).to(dtype)

    # Mild per-channel gate (real Kimi-Linear regime; keeps exp() in fp32 range).
    A_log = _randn(H) * 0.5 - 1.5
    dt_bias = _randn(H, K) * 0.1
    g_raw = _randn(1, T, H, K)
    neg_rate = -A_log.exp().view(1, 1, H, 1)
    gate_pre = g_raw + dt_bias.view(1, 1, H, K)
    g_act = (neg_rate * F.softplus(gate_pre)).float()

    beta = torch.sigmoid(_randn(1, T, H)).float()
    return {
        "q": q,
        "k": k,
        "v": v,
        "g_act": g_act,
        "beta": beta,
        "T": T,
        "H": H,
        "K": K,
        "V": V,
    }
95+
96+
97+
# ---------------------------------------------------------------------------
98+
# Runners
99+
# ---------------------------------------------------------------------------
100+
101+
102+
def run_recurrent(inp, scale):
    """Run the token-by-token ``fused_recurrent_kda`` reference.

    Starts from an all-zero initial state and treats the input as a single
    sequence of length ``inp["T"]``. Returns ``(o[0], state)``, i.e. the output
    with the batch dim dropped (o [T,H,V]) and the final state (state [1,H,V,K]).
    """
    dev = inp["q"].device
    seq_bounds = torch.tensor([0, inp["T"]], dtype=torch.int64, device=dev)
    zero_state = torch.zeros(1, inp["H"], inp["V"], inp["K"], device=dev)

    reference_kwargs = dict(
        q=inp["q"],
        k=inp["k"],
        v=inp["v"],
        g=inp["g_act"],
        beta=inp["beta"],
        scale=scale,
        initial_state=zero_state,
        inplace_final_state=False,
        use_qk_l2norm_in_kernel=False,
        cu_seqlens=seq_bounds,
    )
    out, final_state = fused_recurrent_kda(**reference_kwargs)
    return out[0], final_state
119+
120+
121+
def cutedsl_buffers(inp, num_sms, device):
    """Precompute chunk metadata and preallocate every cutedsl work buffer.

    Doing this once, outside the timed region, mirrors a real forward where
    the metadata is shared across layers. The intermediate buffers are sized
    ``total * BT`` rows, where ``total`` is the last value returned by
    ``prepare_metadata``.
    """
    T, H, K, V = (inp[name] for name in ("T", "H", "K", "V"))
    cu = torch.tensor([0, T], dtype=torch.int32, device=device)
    ci, co, tc, total = prepare_metadata(cu)
    pad_t = total * BT

    # NOTE(review): work buffers are always bf16, independent of the input
    # dtype chosen on the command line -- confirm the kernels accept fp16 inputs.
    def _bf16(*shape):
        return torch.empty(*shape, device=device, dtype=torch.bfloat16)

    return {
        "cu": cu,
        "ci": ci,
        "co": co,
        "tc": tc,
        "total": total,
        "num_sms": num_sms,
        "h0": torch.zeros(1, H, V, K, device=device, dtype=torch.float32),
        "U": _bf16(pad_t, H, V),
        "W": _bf16(pad_t, H, K),
        "V_new": _bf16(pad_t, H, V),
        "h_chunks": _bf16(total, H, V, K),
        "ht": torch.empty(1, H, V, K, device=device, dtype=torch.float32),
        "o": _bf16(T, H, V),
    }
142+
143+
144+
def run_cutedsl_pipeline(inp, buf, scale):
    """Run the fused prologue followed by the three cutedsl kernels.

    ``buf`` must come from ``cutedsl_buffers`` (metadata already computed).
    Results are written into the preallocated buffers; returns
    ``(buf["o"], buf["ht"])``.
    """
    # Drop the leading batch dim (always 1 in this benchmark).
    q3, k3, v3, g3 = (inp[key][0] for key in ("q", "k", "v", "g_act"))
    beta3 = inp["beta"][0].contiguous()

    cu, ci, co, tc = buf["cu"], buf["ci"], buf["co"], buf["tc"]
    U, W, V_new = buf["U"], buf["W"], buf["V_new"]
    h_chunks, out, final_state = buf["h_chunks"], buf["o"], buf["ht"]
    sm_count = buf["num_sms"]

    # Stage 0: fused Triton prologue producing the pre-scaled key/query tensors.
    KL, KR, KG, qg, qg2, g_cu = kda_prologue(q3, k3, g3, scale, cu, ci, buf["total"])

    # Stage 1: kkt_inv_uw kernel, filling U and W.
    kkt_inv_uw_cutedsl(KL, KR, KG, v3, U, W, beta3, cu, ci, tc, num_sms=sm_count)

    # Stage 2: h kernel, filling V_new, h_chunks and the final state.
    kda_h_cutedsl(KR, U, W, V_new, g_cu, h_chunks, buf["h0"], final_state, cu, co)

    # Stage 3: output kernel, filling o.
    kda_o_cutedsl(qg, qg2, KR, V_new, h_chunks, out, cu, ci, tc, num_sms=sm_count)

    return out, final_state
189+
190+
191+
# ---------------------------------------------------------------------------
192+
# Correctness
193+
# ---------------------------------------------------------------------------
194+
195+
196+
def check_shape(T, H, K, V, device, dtype, num_sms):
    """Compare the cutedsl pipeline against the recurrent reference for one shape.

    Returns True when the shape passes or is legitimately skipped (the cutedsl
    path only handles K == V == 128), and False when the outputs are
    non-finite, exceed tolerance, or the pipeline raises.

    An exception from the pipeline is reported as a failure rather than a
    skip: the pipeline is the thing under test, so a crash must not let the
    run report "ALL PASSED".
    """
    tag = f"T={T:>5} H={H:>2} K={K:>3} V={V:>3}"
    if K != 128 or V != 128:
        print(f" [SKIP] {tag} (cutedsl requires K=V=128)")
        return True

    scale = K**-0.5
    inp = make_inputs(T, H, K, V, device, dtype)
    o_ref, state_ref = run_recurrent(inp, scale)

    try:
        buf = cutedsl_buffers(inp, num_sms, device)
        o, ht = run_cutedsl_pipeline(inp, buf, scale)
        # Surface asynchronous kernel errors here, inside the try.
        torch.cuda.synchronize()
    except Exception as e:  # noqa: BLE001
        print(f" [FAIL] {tag} (cutedsl error: {e})")
        return False

    finite = bool(torch.isfinite(o).all() and torch.isfinite(ht).all())
    o_err = (o.float() - o_ref.float()).abs().max().item()
    s_err = (ht.float() - state_ref.float()).abs().max().item()
    # NaN errors compare False against the thresholds, so they also fail.
    ok = finite and o_err < 1e-2 and s_err < 5e-2
    status = "PASS" if ok else "FAIL"
    print(
        f" [{status}] {tag} | o_err {o_err:.2e} state_err {s_err:.2e} finite={finite}"
    )
    return ok
223+
224+
225+
# ---------------------------------------------------------------------------
226+
# Benchmark
227+
# ---------------------------------------------------------------------------
228+
229+
230+
def bench_shape(T, H, K, V, device, dtype, num_sms):
    """Time Triton ``chunk_kda`` against the cutedsl pipeline for one shape.

    Both paths are warmed up once, then timed with
    ``triton.testing.do_bench_cudagraph``; the median (first requested
    quantile) is used. Prints one table row and returns None. Shapes with
    K != 128 or V != 128 are skipped.
    """
    import triton.testing

    if K != 128 or V != 128:
        print(f" [SKIP] T={T} H={H} K={K} V={V} (cutedsl K=V=128 only)")
        return

    scale = K**-0.5
    inp = make_inputs(T, H, K, V, device, dtype)

    # NOTE(review): the Triton initial state is allocated as (1, H, K, V) while
    # the cutedsl buffers use (1, H, V, K); the shapes coincide since K == V here.
    triton_kwargs = dict(
        q=inp["q"],
        k=inp["k"],
        v=inp["v"],
        g=inp["g_act"],
        beta=inp["beta"],
        scale=scale,
        initial_state=torch.zeros(1, H, K, V, device=device, dtype=torch.float32),
        initial_state_indices=torch.zeros(1, dtype=torch.int32, device=device),
        use_qk_l2norm_in_kernel=False,
        cu_seqlens=None,
        A_log=None,
        dt_bias=None,
        lower_bound=None,
    )

    def fn_triton():
        chunk_kda(**triton_kwargs)

    # Metadata and work buffers are built once, outside the timed region.
    buf = cutedsl_buffers(inp, num_sms, device)

    def fn_cutedsl():
        run_cutedsl_pipeline(inp, buf, scale)

    # Warm-up: run each path once before timing.
    fn_triton()
    fn_cutedsl()
    torch.cuda.synchronize()

    quantiles = [0.5, 0.2, 0.8]
    ms_triton, _lo_t, _hi_t = triton.testing.do_bench_cudagraph(
        fn_triton, quantiles=quantiles
    )
    ms_cutedsl, _lo_c, _hi_c = triton.testing.do_bench_cudagraph(
        fn_cutedsl, quantiles=quantiles
    )

    flops = kda_flops(T, H, K, V)
    mem_bytes = kda_bytes(T, H, K, V, 1, dtype)

    def _per_ms_to_tera(amount, ms):
        # amount / ms / 1e9 == amount per second / 1e12
        return amount / ms / 1e9

    tri_tflops = _per_ms_to_tera(flops, ms_triton)
    tri_tbps = _per_ms_to_tera(mem_bytes, ms_triton)
    cute_tflops = _per_ms_to_tera(flops, ms_cutedsl)
    cute_tbps = _per_ms_to_tera(mem_bytes, ms_cutedsl)

    if ms_cutedsl > 0:
        speedup = ms_triton / ms_cutedsl
    else:
        speedup = float("inf")

    row = (
        f" {H:>3} {T:>7} | "
        f"{ms_triton:>8.3f} {tri_tflops:>7.2f} {tri_tbps:>7.2f} | "
        f"{ms_cutedsl:>8.3f} {cute_tflops:>7.2f} {cute_tbps:>7.2f} | "
        f"{speedup:>7.2f}x"
    )
    print(row)
286+
287+
288+
# ---------------------------------------------------------------------------
289+
# Main
290+
# ---------------------------------------------------------------------------
291+
292+
293+
def run_correctness(device, dtype, H, num_sms):
    """Run ``check_shape`` over a fixed sweep of sequence lengths at K=V=128.

    Every length is checked even after a failure. Prints a summary line and
    returns True only if every shape passed.
    """
    banner = "=" * 72
    print(banner)
    print("Correctness: cutedsl pipeline vs fused_recurrent_kda (ground truth)")
    print(banner)

    # A list (not a generator) so all() cannot short-circuit the sweep.
    outcomes = [
        check_shape(T, H, 128, 128, device, dtype, num_sms)
        for T in (128, 192, 256, 512, 1024)
    ]
    all_pass = all(outcomes)

    if all_pass:
        print("\nALL PASSED.")
    else:
        print("\nSOME FAILED.")
    return all_pass
303+
304+
305+
def run_benchmark(device, dtype, args, num_sms):
    """Print the benchmark table header, then one row per (H, T) combination.

    Heads come from ``args.num_heads`` and sequence lengths from
    ``args.seq_lens``; K and V are fixed at 128.
    """
    rule = "=" * 92
    column_header = (
        f" {'H':>3} {'T':>7} | "
        f"{'tri(ms)':>8} {'TFLOP':>7} {'TB/s':>7} | "
        f"{'cute(ms)':>8} {'TFLOP':>7} {'TB/s':>7} | {'speedup':>8}"
    )

    print()
    print(rule)
    print("Benchmark: Triton chunk_kda vs CuTeDSL pipeline (do_bench_cudagraph)")
    print(rule)
    print(f" Device SMs={num_sms}, K=V=128, dtype={dtype}, metadata precomputed")
    print(column_header)
    print(" " + "-" * 84)

    # Heads vary in the outer position, sequence lengths in the inner one.
    sweep = [(H, T) for H in args.num_heads for T in args.seq_lens]
    for H, T in sweep:
        bench_shape(T, H, 128, 128, device, dtype, num_sms)
320+
321+
322+
def main():
    """Parse CLI arguments and run the correctness and/or benchmark stages.

    Returns the process exit status:
      * 0 when every requested stage succeeded,
      * 1 when CUDA is unavailable, the GPU is older than SM100, or any
        correctness check failed. This includes ``--mode correctness``, so the
        script can gate CI on its exit code.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark & Correctness: Triton KDA vs CuTeDSL KDA (SM100)"
    )
    parser.add_argument(
        "--mode", choices=["all", "correctness", "bench"], default="all"
    )
    parser.add_argument("--dtype", choices=["float16", "bfloat16"], default="bfloat16")
    parser.add_argument("--num-heads", type=int, nargs="+", default=[32])
    parser.add_argument(
        "--seq-lens", type=int, nargs="+", default=[512, 1024, 2048, 4096, 8192]
    )
    args = parser.parse_args()

    # Fail with a clear message instead of an exception from the capability query.
    if not torch.cuda.is_available():
        print("ERROR: CUDA device not available. Exiting.")
        return 1

    device = "cuda"
    dtype = getattr(torch, args.dtype)
    cap = torch.cuda.get_device_capability()
    print(f"Device: {torch.cuda.get_device_name()} (SM {cap[0]}{cap[1]})")
    if cap[0] < 10:
        print("ERROR: CuTeDSL KDA prefill requires SM100+ (Blackwell). Exiting.")
        return 1
    num_sms = torch.cuda.get_device_properties(0).multi_processor_count

    if args.mode in ("all", "correctness"):
        # Correctness is checked with the first requested head count only.
        all_pass = run_correctness(device, dtype, args.num_heads[0], num_sms)
        if not all_pass:
            if args.mode == "all":
                print("\nSkipping benchmark due to correctness failures.")
            # Non-zero exit in both modes so failures are visible to callers.
            return 1

    if args.mode in ("all", "bench"):
        run_benchmark(device, dtype, args, num_sms)
    return 0
354+
355+
356+
if __name__ == "__main__":
357+
sys.exit(main())

python/sglang/srt/layers/attention/linear/kda_backend.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,28 @@ def __init__(
6060

6161
if prefill_backend.is_triton():
6262
self.extend_kernel = triton_kernel
63+
elif prefill_backend.is_cutedsl():
64+
if not is_cuda():
65+
raise ValueError("KDA CuTe DSL backend requires CUDA")
66+
from sglang.srt.layers.attention.linear.kernels.kda_cutedsl import (
67+
CuteDSLKDAKernel,
68+
)
69+
70+
cutedsl_kernel = CuteDSLKDAKernel()
71+
if getattr(cutedsl_kernel, "supports_prefill", False):
72+
# SM100 chunk prefill pipeline.
73+
self.extend_kernel = cutedsl_kernel
74+
else:
75+
# CuTe DSL prefill kernels need SM100 (Blackwell); on older GPUs
76+
# fall back to the Triton chunk kernel.
77+
self.extend_kernel = triton_kernel
78+
rank0_log(
79+
"KDA cutedsl prefill needs SM100; falling back to Triton extend."
80+
)
6381
else:
6482
raise ValueError(
6583
f"Unsupported KDA prefill backend: {prefill_backend}. "
66-
"KDA currently only supports 'triton'."
84+
"KDA supports 'triton' or 'cutedsl' (cutedsl prefill needs SM100)."
6785
)
6886

6987
self.supports_packed_decode = getattr(

0 commit comments

Comments
 (0)