|
| 1 | +""" |
| 2 | +Benchmark & Correctness: Triton KDA vs CuTeDSL KDA (prefill, SM100 Blackwell). |
| 3 | +
|
| 4 | +Compares: |
| 5 | + - Triton: sglang's chunk_kda (FLA chunkwise gated delta rule, per-channel gate) |
| 6 | + - CuteDSL: kda_blackwell pipeline (fused Triton prologue -> kkt_inv_uw -> h -> o) |
| 7 | +
|
| 8 | +KDA differs from GDN by a PER-CHANNEL decay gate (g is [T, H, K], not scalar). |
| 9 | +The cutedsl pipeline externalizes the per-channel decay into five pre-scaled |
| 10 | +key/query tensors computed by a fused Triton prologue; the chunk metadata is |
| 11 | +computed once and shared across layers in a real forward, so the benchmarked |
| 12 | +cutedsl path precomputes it outside the timed region (the realistic ceiling). |
| 13 | +
|
| 14 | +Correctness is checked against the token-by-token fused_recurrent_kda ground |
| 15 | +truth. Reports performance (ms, approx TFLOPS, TB/s, speedup). |
| 16 | +
|
| 17 | +Usage: |
| 18 | + python bench_kda_prefill_cutedsl.py # default sweep |
| 19 | + python bench_kda_prefill_cutedsl.py --mode bench # benchmark only |
| 20 | + python bench_kda_prefill_cutedsl.py --mode correctness # correctness only |
| 21 | +""" |
| 22 | + |
| 23 | +import argparse |
| 24 | +import os |
| 25 | +import sys |
| 26 | + |
| 27 | +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "python")) |
| 28 | + |
| 29 | +import torch |
| 30 | +import torch.nn.functional as F |
| 31 | + |
| 32 | +from sglang.srt.layers.attention.fla.kda import chunk_kda, fused_recurrent_kda |
| 33 | +from sglang.srt.layers.attention.linear.kernels.kda_blackwell import prepare_metadata |
| 34 | +from sglang.srt.layers.attention.linear.kernels.kda_blackwell.kernel_h import ( |
| 35 | + kda_h_cutedsl, |
| 36 | +) |
| 37 | +from sglang.srt.layers.attention.linear.kernels.kda_blackwell.kernel_kkt_inv_uw import ( |
| 38 | + kkt_inv_uw_cutedsl, |
| 39 | +) |
| 40 | +from sglang.srt.layers.attention.linear.kernels.kda_blackwell.kernel_o import ( |
| 41 | + kda_o_cutedsl, |
| 42 | +) |
| 43 | +from sglang.srt.layers.attention.linear.kernels.kda_blackwell.prologue import ( |
| 44 | + kda_prologue, |
| 45 | +) |
| 46 | + |
| 47 | +BT = 64 # chunk size |
| 48 | + |
| 49 | +# --------------------------------------------------------------------------- |
| 50 | +# Helpers |
| 51 | +# --------------------------------------------------------------------------- |
| 52 | + |
| 53 | + |
| 54 | +def _l2norm(x: torch.Tensor) -> torch.Tensor: |
| 55 | + return F.normalize(x.float(), p=2, dim=-1) |
| 56 | + |
| 57 | + |
| 58 | +def kda_flops(total_seq_len, num_heads, head_k, head_v): |
| 59 | + """Per-token-per-head: k@v outer (2*K*V) + q@state (2*K*V), plus the intra-chunk |
| 60 | + KKT (2*K*K averaged over the chunk). Approximate (ignores the inverse).""" |
| 61 | + return total_seq_len * num_heads * (4 * head_k * head_v + 2 * head_k * head_k) |
| 62 | + |
| 63 | + |
| 64 | +def kda_bytes(total_seq_len, num_heads, head_k, head_v, num_seqs, dtype): |
| 65 | + elem = dtype.itemsize |
| 66 | + q_b = total_seq_len * num_heads * head_k * elem |
| 67 | + k_b = total_seq_len * num_heads * head_k * elem |
| 68 | + v_b = total_seq_len * num_heads * head_v * elem |
| 69 | + o_b = total_seq_len * num_heads * head_v * elem |
| 70 | + g_b = total_seq_len * num_heads * head_k * 4 # per-channel gate, fp32 |
| 71 | + beta_b = total_seq_len * num_heads * 4 |
| 72 | + state_b = 2 * num_seqs * num_heads * head_k * head_v * 4 # fp32 r/w |
| 73 | + return q_b + k_b + v_b + o_b + g_b + beta_b + state_b |
| 74 | + |
| 75 | + |
| 76 | +# --------------------------------------------------------------------------- |
| 77 | +# Input factory (single sequence per benchmark point, B=1) |
| 78 | +# --------------------------------------------------------------------------- |
| 79 | + |
| 80 | + |
| 81 | +def make_inputs(T, H, K, V, device, dtype, seed=42): |
| 82 | + torch.manual_seed(seed) |
| 83 | + q = _l2norm(torch.randn(1, T, H, K, device=device)).to(dtype) |
| 84 | + k = _l2norm(torch.randn(1, T, H, K, device=device)).to(dtype) |
| 85 | + v = torch.randn(1, T, H, V, device=device).to(dtype) |
| 86 | + # Mild per-channel gate (real Kimi-Linear regime; keeps exp() in fp32 range). |
| 87 | + A_log = torch.randn(H, device=device) * 0.5 - 1.5 |
| 88 | + dt_bias = torch.randn(H, K, device=device) * 0.1 |
| 89 | + g_raw = torch.randn(1, T, H, K, device=device) |
| 90 | + g_act = ( |
| 91 | + -A_log.exp().view(1, 1, H, 1) * F.softplus(g_raw + dt_bias.view(1, 1, H, K)) |
| 92 | + ).float() |
| 93 | + beta = torch.sigmoid(torch.randn(1, T, H, device=device)).float() |
| 94 | + return dict(q=q, k=k, v=v, g_act=g_act, beta=beta, T=T, H=H, K=K, V=V) |
| 95 | + |
| 96 | + |
| 97 | +# --------------------------------------------------------------------------- |
| 98 | +# Runners |
| 99 | +# --------------------------------------------------------------------------- |
| 100 | + |
| 101 | + |
| 102 | +def run_recurrent(inp, scale): |
| 103 | + """Token-by-token ground truth. Returns (o [T,H,V], state [1,H,V,K]).""" |
| 104 | + cu = torch.tensor([0, inp["T"]], dtype=torch.int64, device=inp["q"].device) |
| 105 | + h0 = torch.zeros(1, inp["H"], inp["V"], inp["K"], device=inp["q"].device) |
| 106 | + o, state = fused_recurrent_kda( |
| 107 | + q=inp["q"], |
| 108 | + k=inp["k"], |
| 109 | + v=inp["v"], |
| 110 | + g=inp["g_act"], |
| 111 | + beta=inp["beta"], |
| 112 | + scale=scale, |
| 113 | + initial_state=h0, |
| 114 | + inplace_final_state=False, |
| 115 | + use_qk_l2norm_in_kernel=False, |
| 116 | + cu_seqlens=cu, |
| 117 | + ) |
| 118 | + return o[0], state |
| 119 | + |
| 120 | + |
| 121 | +def cutedsl_buffers(inp, num_sms, device): |
| 122 | + """Precompute metadata + preallocate (shared across layers in a real forward).""" |
| 123 | + T, H, K, V = inp["T"], inp["H"], inp["K"], inp["V"] |
| 124 | + cu = torch.tensor([0, T], dtype=torch.int32, device=device) |
| 125 | + ci, co, tc, total = prepare_metadata(cu) |
| 126 | + pad_t = total * BT |
| 127 | + return dict( |
| 128 | + cu=cu, |
| 129 | + ci=ci, |
| 130 | + co=co, |
| 131 | + tc=tc, |
| 132 | + total=total, |
| 133 | + num_sms=num_sms, |
| 134 | + h0=torch.zeros(1, H, V, K, device=device, dtype=torch.float32), |
| 135 | + U=torch.empty(pad_t, H, V, device=device, dtype=torch.bfloat16), |
| 136 | + W=torch.empty(pad_t, H, K, device=device, dtype=torch.bfloat16), |
| 137 | + V_new=torch.empty(pad_t, H, V, device=device, dtype=torch.bfloat16), |
| 138 | + h_chunks=torch.empty(total, H, V, K, device=device, dtype=torch.bfloat16), |
| 139 | + ht=torch.empty(1, H, V, K, device=device, dtype=torch.float32), |
| 140 | + o=torch.empty(T, H, V, device=device, dtype=torch.bfloat16), |
| 141 | + ) |
| 142 | + |
| 143 | + |
| 144 | +def run_cutedsl_pipeline(inp, buf, scale): |
| 145 | + """The fused-prologue + 3 cutedsl kernels (metadata precomputed in buf).""" |
| 146 | + q3, k3, v3 = inp["q"][0], inp["k"][0], inp["v"][0] |
| 147 | + g3, beta3 = inp["g_act"][0], inp["beta"][0].contiguous() |
| 148 | + KL, KR, KG, qg, qg2, g_cu = kda_prologue( |
| 149 | + q3, k3, g3, scale, buf["cu"], buf["ci"], buf["total"] |
| 150 | + ) |
| 151 | + kkt_inv_uw_cutedsl( |
| 152 | + KL, |
| 153 | + KR, |
| 154 | + KG, |
| 155 | + v3, |
| 156 | + buf["U"], |
| 157 | + buf["W"], |
| 158 | + beta3, |
| 159 | + buf["cu"], |
| 160 | + buf["ci"], |
| 161 | + buf["tc"], |
| 162 | + num_sms=buf["num_sms"], |
| 163 | + ) |
| 164 | + kda_h_cutedsl( |
| 165 | + KR, |
| 166 | + buf["U"], |
| 167 | + buf["W"], |
| 168 | + buf["V_new"], |
| 169 | + g_cu, |
| 170 | + buf["h_chunks"], |
| 171 | + buf["h0"], |
| 172 | + buf["ht"], |
| 173 | + buf["cu"], |
| 174 | + buf["co"], |
| 175 | + ) |
| 176 | + kda_o_cutedsl( |
| 177 | + qg, |
| 178 | + qg2, |
| 179 | + KR, |
| 180 | + buf["V_new"], |
| 181 | + buf["h_chunks"], |
| 182 | + buf["o"], |
| 183 | + buf["cu"], |
| 184 | + buf["ci"], |
| 185 | + buf["tc"], |
| 186 | + num_sms=buf["num_sms"], |
| 187 | + ) |
| 188 | + return buf["o"], buf["ht"] |
| 189 | + |
| 190 | + |
| 191 | +# --------------------------------------------------------------------------- |
| 192 | +# Correctness |
| 193 | +# --------------------------------------------------------------------------- |
| 194 | + |
| 195 | + |
| 196 | +def check_shape(T, H, K, V, device, dtype, num_sms): |
| 197 | + tag = f"T={T:>5} H={H:>2} K={K:>3} V={V:>3}" |
| 198 | + if K != 128 or V != 128: |
| 199 | + print(f" [SKIP] {tag} (cutedsl requires K=V=128)") |
| 200 | + return True |
| 201 | + |
| 202 | + scale = K**-0.5 |
| 203 | + inp = make_inputs(T, H, K, V, device, dtype) |
| 204 | + o_ref, state_ref = run_recurrent(inp, scale) |
| 205 | + |
| 206 | + try: |
| 207 | + buf = cutedsl_buffers(inp, num_sms, device) |
| 208 | + o, ht = run_cutedsl_pipeline(inp, buf, scale) |
| 209 | + torch.cuda.synchronize() |
| 210 | + except Exception as e: # noqa: BLE001 |
| 211 | + print(f" [SKIP] {tag} (cutedsl error: {e})") |
| 212 | + return True |
| 213 | + |
| 214 | + finite = bool(torch.isfinite(o).all() and torch.isfinite(ht).all()) |
| 215 | + o_err = (o.float() - o_ref.float()).abs().max().item() |
| 216 | + s_err = (ht.float() - state_ref.float()).abs().max().item() |
| 217 | + ok = finite and o_err < 1e-2 and s_err < 5e-2 |
| 218 | + status = "PASS" if ok else "FAIL" |
| 219 | + print( |
| 220 | + f" [{status}] {tag} | o_err {o_err:.2e} state_err {s_err:.2e} finite={finite}" |
| 221 | + ) |
| 222 | + return ok |
| 223 | + |
| 224 | + |
| 225 | +# --------------------------------------------------------------------------- |
| 226 | +# Benchmark |
| 227 | +# --------------------------------------------------------------------------- |
| 228 | + |
| 229 | + |
| 230 | +def bench_shape(T, H, K, V, device, dtype, num_sms): |
| 231 | + import triton.testing |
| 232 | + |
| 233 | + if K != 128 or V != 128: |
| 234 | + print(f" [SKIP] T={T} H={H} K={K} V={V} (cutedsl K=V=128 only)") |
| 235 | + return |
| 236 | + |
| 237 | + scale = K**-0.5 |
| 238 | + inp = make_inputs(T, H, K, V, device, dtype) |
| 239 | + q, k, v = inp["q"], inp["k"], inp["v"] |
| 240 | + g_act, beta = inp["g_act"], inp["beta"] |
| 241 | + |
| 242 | + h0f = torch.zeros(1, H, K, V, device=device, dtype=torch.float32) |
| 243 | + idx = torch.zeros(1, dtype=torch.int32, device=device) |
| 244 | + |
| 245 | + def fn_triton(): |
| 246 | + chunk_kda( |
| 247 | + q=q, |
| 248 | + k=k, |
| 249 | + v=v, |
| 250 | + g=g_act, |
| 251 | + beta=beta, |
| 252 | + scale=scale, |
| 253 | + initial_state=h0f, |
| 254 | + initial_state_indices=idx, |
| 255 | + use_qk_l2norm_in_kernel=False, |
| 256 | + cu_seqlens=None, |
| 257 | + A_log=None, |
| 258 | + dt_bias=None, |
| 259 | + lower_bound=None, |
| 260 | + ) |
| 261 | + |
| 262 | + buf = cutedsl_buffers(inp, num_sms, device) |
| 263 | + |
| 264 | + def fn_cutedsl(): |
| 265 | + run_cutedsl_pipeline(inp, buf, scale) |
| 266 | + |
| 267 | + quantiles = [0.5, 0.2, 0.8] |
| 268 | + fn_triton() |
| 269 | + fn_cutedsl() |
| 270 | + torch.cuda.synchronize() |
| 271 | + |
| 272 | + ms_triton, _, _ = triton.testing.do_bench_cudagraph(fn_triton, quantiles=quantiles) |
| 273 | + ms_cutedsl, _, _ = triton.testing.do_bench_cudagraph( |
| 274 | + fn_cutedsl, quantiles=quantiles |
| 275 | + ) |
| 276 | + |
| 277 | + flops = kda_flops(T, H, K, V) |
| 278 | + mem_bytes = kda_bytes(T, H, K, V, 1, dtype) |
| 279 | + speedup = ms_triton / ms_cutedsl if ms_cutedsl > 0 else float("inf") |
| 280 | + print( |
| 281 | + f" {H:>3} {T:>7} | " |
| 282 | + f"{ms_triton:>8.3f} {flops / ms_triton / 1e9:>7.2f} {mem_bytes / ms_triton / 1e9:>7.2f} | " |
| 283 | + f"{ms_cutedsl:>8.3f} {flops / ms_cutedsl / 1e9:>7.2f} {mem_bytes / ms_cutedsl / 1e9:>7.2f} | " |
| 284 | + f"{speedup:>7.2f}x" |
| 285 | + ) |
| 286 | + |
| 287 | + |
| 288 | +# --------------------------------------------------------------------------- |
| 289 | +# Main |
| 290 | +# --------------------------------------------------------------------------- |
| 291 | + |
| 292 | + |
| 293 | +def run_correctness(device, dtype, H, num_sms): |
| 294 | + print("=" * 72) |
| 295 | + print("Correctness: cutedsl pipeline vs fused_recurrent_kda (ground truth)") |
| 296 | + print("=" * 72) |
| 297 | + all_pass = True |
| 298 | + for T in (128, 192, 256, 512, 1024): |
| 299 | + if not check_shape(T, H, 128, 128, device, dtype, num_sms): |
| 300 | + all_pass = False |
| 301 | + print("\nALL PASSED." if all_pass else "\nSOME FAILED.") |
| 302 | + return all_pass |
| 303 | + |
| 304 | + |
| 305 | +def run_benchmark(device, dtype, args, num_sms): |
| 306 | + print() |
| 307 | + print("=" * 92) |
| 308 | + print("Benchmark: Triton chunk_kda vs CuTeDSL pipeline (do_bench_cudagraph)") |
| 309 | + print("=" * 92) |
| 310 | + print(f" Device SMs={num_sms}, K=V=128, dtype={dtype}, metadata precomputed") |
| 311 | + print( |
| 312 | + f" {'H':>3} {'T':>7} | " |
| 313 | + f"{'tri(ms)':>8} {'TFLOP':>7} {'TB/s':>7} | " |
| 314 | + f"{'cute(ms)':>8} {'TFLOP':>7} {'TB/s':>7} | {'speedup':>8}" |
| 315 | + ) |
| 316 | + print(" " + "-" * 84) |
| 317 | + for H in args.num_heads: |
| 318 | + for T in args.seq_lens: |
| 319 | + bench_shape(T, H, 128, 128, device, dtype, num_sms) |
| 320 | + |
| 321 | + |
| 322 | +def main(): |
| 323 | + parser = argparse.ArgumentParser( |
| 324 | + description="Benchmark & Correctness: Triton KDA vs CuTeDSL KDA (SM100)" |
| 325 | + ) |
| 326 | + parser.add_argument( |
| 327 | + "--mode", choices=["all", "correctness", "bench"], default="all" |
| 328 | + ) |
| 329 | + parser.add_argument("--dtype", choices=["float16", "bfloat16"], default="bfloat16") |
| 330 | + parser.add_argument("--num-heads", type=int, nargs="+", default=[32]) |
| 331 | + parser.add_argument( |
| 332 | + "--seq-lens", type=int, nargs="+", default=[512, 1024, 2048, 4096, 8192] |
| 333 | + ) |
| 334 | + args = parser.parse_args() |
| 335 | + |
| 336 | + device = "cuda" |
| 337 | + dtype = getattr(torch, args.dtype) |
| 338 | + cap = torch.cuda.get_device_capability() |
| 339 | + print(f"Device: {torch.cuda.get_device_name()} (SM {cap[0]}{cap[1]})") |
| 340 | + if cap[0] < 10: |
| 341 | + print("ERROR: CuTeDSL KDA prefill requires SM100+ (Blackwell). Exiting.") |
| 342 | + return 1 |
| 343 | + num_sms = torch.cuda.get_device_properties(0).multi_processor_count |
| 344 | + |
| 345 | + if args.mode in ("all", "correctness"): |
| 346 | + all_pass = run_correctness(device, dtype, args.num_heads[0], num_sms) |
| 347 | + if not all_pass and args.mode == "all": |
| 348 | + print("\nSkipping benchmark due to correctness failures.") |
| 349 | + return 1 |
| 350 | + |
| 351 | + if args.mode in ("all", "bench"): |
| 352 | + run_benchmark(device, dtype, args, num_sms) |
| 353 | + return 0 |
| 354 | + |
| 355 | + |
| 356 | +if __name__ == "__main__": |
| 357 | + sys.exit(main()) |
0 commit comments