|
| 1 | +""" |
| 2 | +Copyright (c) 2025 by FlashInfer team. |
| 3 | +
|
| 4 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +you may not use this file except in compliance with the License. |
| 6 | +You may obtain a copy of the License at |
| 7 | +
|
| 8 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | +Unless required by applicable law or agreed to in writing, software |
| 11 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +See the License for the specific language governing permissions and |
| 14 | +limitations under the License. |
| 15 | +""" |
| 16 | + |
| 17 | +""" |
| 18 | +Performance benchmark for Blackwell GDN (Gated Delta Network) prefill kernel. |
| 19 | +
|
| 20 | +Compares FlashInfer's SM100 GDN prefill against FLA baseline. |
| 21 | +
|
| 22 | +Usage: |
| 23 | + python bench_blackwell_gdn_prefill.py --sweep |
| 24 | + python bench_blackwell_gdn_prefill.py --batch-size 8 --seq-len 1024 |
| 25 | + python bench_blackwell_gdn_prefill.py --varlen --cu-seqlens 0 512 1024 2048 |
| 26 | +""" |
| 27 | + |
| 28 | +import argparse |
| 29 | +import sys |
| 30 | +from typing import List |
| 31 | + |
| 32 | +import numpy as np |
| 33 | +import torch |
| 34 | +import torch.nn.functional as F |
| 35 | + |
| 36 | +from flashinfer.gdn_prefill import chunk_gated_delta_rule |
| 37 | +from flashinfer.testing import bench_gpu_time |
| 38 | +from flashinfer.utils import is_sm100a_supported |
| 39 | + |
| 40 | +try: |
| 41 | + from fla.ops.gated_delta_rule.chunk import chunk_gated_delta_rule_fwd as fla_base |
| 42 | + |
| 43 | + _has_fla = True |
| 44 | +except ImportError: |
| 45 | + _has_fla = False |
| 46 | + |
| 47 | + |
| 48 | +def _make_inputs( |
| 49 | + total_len: int, |
| 50 | + num_seqs: int, |
| 51 | + num_qk_heads: int, |
| 52 | + num_v_heads: int, |
| 53 | + head_dim: int, |
| 54 | + dtype: torch.dtype, |
| 55 | + device: str = "cuda", |
| 56 | + use_initial_state: bool = True, |
| 57 | +): |
| 58 | + """Create input tensors in FlashInfer 3D format (total_len, H, D).""" |
| 59 | + num_o_heads = max(num_qk_heads, num_v_heads) |
| 60 | + |
| 61 | + q = torch.randn(total_len, num_qk_heads, head_dim, dtype=dtype, device=device) |
| 62 | + k = F.normalize( |
| 63 | + torch.randn(total_len, num_qk_heads, head_dim, dtype=torch.float32, device=device), |
| 64 | + p=2, |
| 65 | + dim=-1, |
| 66 | + ).to(dtype) |
| 67 | + v = torch.randn(total_len, num_v_heads, head_dim, dtype=dtype, device=device) |
| 68 | + g = F.logsigmoid( |
| 69 | + torch.rand(total_len, num_o_heads, dtype=torch.float32, device=device) |
| 70 | + ) |
| 71 | + beta = torch.rand( |
| 72 | + total_len, num_o_heads, dtype=torch.float32, device=device |
| 73 | + ).sigmoid() |
| 74 | + |
| 75 | + h0 = None |
| 76 | + if use_initial_state: |
| 77 | + h0 = torch.randn( |
| 78 | + num_seqs, num_o_heads, head_dim, head_dim, |
| 79 | + dtype=torch.float32, device=device, |
| 80 | + ) |
| 81 | + |
| 82 | + o = torch.empty(total_len, num_o_heads, head_dim, dtype=dtype, device=device) |
| 83 | + s_out = torch.empty( |
| 84 | + num_seqs, num_o_heads, head_dim, head_dim, |
| 85 | + dtype=torch.float32, device=device, |
| 86 | + ) |
| 87 | + |
| 88 | + return q, k, v, g, beta, h0, o, s_out |
| 89 | + |
| 90 | + |
def benchmark_fixlen(
    batch_size: int,
    seq_len: int,
    num_qk_heads: int,
    num_v_heads: int,
    head_dim: int,
    dtype: torch.dtype = torch.bfloat16,
    warmup_iters: int = 10,
    benchmark_iters: int = 100,
    use_initial_state: bool = True,
) -> dict:
    """Benchmark GDN with fixed-length sequences.

    Also times the FLA baseline when installed and the qk/v head counts
    match (FLA has no grouped-value-attention support), adding "fla_ms"
    and "speedup" to the returned dict in that case.
    """
    device = "cuda"
    total_len = batch_size * seq_len

    q, k, v, g, beta, h0, o, s_out = _make_inputs(
        total_len, batch_size, num_qk_heads, num_v_heads, head_dim, dtype,
        use_initial_state=use_initial_state,
    )
    # Uniform boundaries: [0, seq_len, 2*seq_len, ..., total_len].
    cu_seqlens = torch.arange(
        0, total_len + 1, seq_len, dtype=torch.int64, device=device
    )

    def run_gdn():
        chunk_gated_delta_rule(
            q, k, v, g, beta, None, h0, True, cu_seqlens, False, o, s_out,
        )

    # Median over repeats is robust to occasional clock/scheduling outliers.
    gdn_ms = float(
        np.median(
            bench_gpu_time(
                run_gdn, enable_cupti=True,
                dry_run_iters=warmup_iters, repeat_iters=benchmark_iters,
            )
        )
    )

    result = {
        "batch_size": batch_size,
        "seq_len": seq_len,
        "num_qk_heads": num_qk_heads,
        "num_v_heads": num_v_heads,
        "head_dim": head_dim,
        "gdn_ms": gdn_ms,
    }

    # FLA baseline (only when qk_heads == v_heads, FLA doesn't support GVA)
    if _has_fla and num_qk_heads == num_v_heads:
        # FLA expects 4D (B, T, H, D); these are zero-copy views of the flat tensors.
        q4 = q.view(batch_size, seq_len, num_qk_heads, head_dim)
        k4 = k.view(batch_size, seq_len, num_qk_heads, head_dim)
        v4 = v.view(batch_size, seq_len, num_v_heads, head_dim)
        g4 = g.view(batch_size, seq_len, num_v_heads)
        beta4 = beta.view(batch_size, seq_len, num_v_heads)

        def run_fla():
            fla_base(q4, k4, v4, g4, beta4, None, initial_state=h0, output_final_state=True)

        fla_ms = float(
            np.median(
                bench_gpu_time(
                    run_fla, enable_cupti=True,
                    dry_run_iters=warmup_iters, repeat_iters=benchmark_iters,
                )
            )
        )
        result["fla_ms"] = fla_ms
        result["speedup"] = float("nan") if gdn_ms <= 0 else fla_ms / gdn_ms

    return result
| 155 | + |
| 156 | + |
def benchmark_varlen(
    cu_seqlens: List[int],
    num_qk_heads: int,
    num_v_heads: int,
    head_dim: int,
    dtype: torch.dtype = torch.bfloat16,
    warmup_iters: int = 10,
    benchmark_iters: int = 100,
) -> dict:
    """Benchmark GDN with variable-length sequences.

    cu_seqlens is the cumulative-length prefix list, e.g. [0, 512, 2048];
    the number of sequences is len(cu_seqlens) - 1.
    """
    device = "cuda"
    num_seqs = len(cu_seqlens) - 1
    total_len = cu_seqlens[-1]

    q, k, v, g, beta, h0, o, s_out = _make_inputs(
        total_len, num_seqs, num_qk_heads, num_v_heads, head_dim, dtype,
    )
    cu_seqlens_t = torch.tensor(cu_seqlens, dtype=torch.int64, device=device)

    def run_gdn():
        chunk_gated_delta_rule(
            q, k, v, g, beta, None, h0, True, cu_seqlens_t, False, o, s_out,
        )

    timings = bench_gpu_time(
        run_gdn, enable_cupti=True,
        dry_run_iters=warmup_iters, repeat_iters=benchmark_iters,
    )

    return {
        "num_seqs": num_seqs,
        "total_len": total_len,
        "avg_seq_len": total_len // num_seqs,
        "num_qk_heads": num_qk_heads,
        "num_v_heads": num_v_heads,
        "head_dim": head_dim,
        "gdn_ms": float(np.median(timings)),
    }
| 196 | + |
| 197 | + |
def _format_cell(val) -> str:
    """Render one table value: floats get 3 decimals (1 decimal from 100 up)."""
    if isinstance(val, float):
        return f"{val:.3f}" if val < 100 else f"{val:.1f}"
    return str(val)


def print_results_table(results: List[dict], title: str = "Benchmark Results"):
    """Print benchmark results in a formatted table.

    Columns are the union of keys across all result dicts (first-seen
    order); rows missing a key render as an empty cell. No-op when
    results is empty.
    """
    if not results:
        return

    print(f"\n{'=' * 90}")
    print(f" {title}")
    print(f"{'=' * 90}")

    # Collect all keys across all results for consistent columns
    keys = []
    seen = set()
    for r in results:
        for key in r:
            if key not in seen:
                keys.append(key)
                seen.add(key)

    # Column width = widest of header and any rendered cell, plus padding.
    # Formatting is centralized in _format_cell so widths and rows can
    # never disagree on how a value renders.
    widths = {
        key: max(len(key), *(len(_format_cell(r.get(key, ""))) for r in results)) + 2
        for key in keys
    }

    header = " | ".join(f"{key:^{widths[key]}}" for key in keys)
    print(header)
    print("-" * len(header))

    for r in results:
        cells = [f"{_format_cell(r.get(key, '')):^{widths[key]}}" for key in keys]
        print(" | ".join(cells))

    print(f"{'=' * 90}\n")
| 244 | + |
| 245 | + |
def run_sweep(warmup_iters: int = 10, benchmark_iters: int = 100):
    """Run the standard sweep matching PR #2742 configurations."""
    head_dim = 128

    fixlen_configs = [
        # (batch_size, seq_len, num_qk_heads, num_v_heads)
        (1, 512, 96, 96),
        (1, 1024, 96, 96),
        (1, 4096, 96, 96),
        (1, 8192, 96, 96),
        (9, 512, 32, 32),
        (9, 1024, 32, 32),
        (9, 4096, 32, 32),
        (9, 8192, 32, 32),
        (33, 512, 32, 32),
        (33, 1024, 32, 32),
        (33, 4096, 32, 32),
        (33, 8192, 32, 32),
        (1, 512, 148, 148),
        (1, 1024, 148, 148),
        (1, 4096, 148, 148),
        (1, 8192, 148, 148),
    ]
    total = len(fixlen_configs)

    banner = "#" * 80
    print(f"\n{banner}")
    print(" BLACKWELL GDN PREFILL BENCHMARK SWEEP")
    print(f" GPU: {torch.cuda.get_device_name(0)}")
    fla_status = (
        " FLA baseline: available"
        if _has_fla
        else " FLA baseline: not installed (pip install flash-linear-attention)"
    )
    print(fla_status)
    print(banner)

    # Fixed-length benchmarks
    fixlen_results = []
    for idx, (bs, sl, nqk, nv) in enumerate(fixlen_configs, start=1):
        label = f"[{idx}/{total}] bs={bs}, sl={sl}, nqk={nqk}, nv={nv}"
        print(f" Running fixlen {label} ...", end="", flush=True)
        try:
            result = benchmark_fixlen(
                bs, sl, nqk, nv, head_dim,
                warmup_iters=warmup_iters, benchmark_iters=benchmark_iters,
            )
        except Exception as e:
            print(f" FAILED: {e}")
        else:
            fixlen_results.append(result)
            msg = f" GDN: {result['gdn_ms']:.3f} ms"
            if "fla_ms" in result:
                msg += f" FLA: {result['fla_ms']:.3f} ms ({result['speedup']:.2f}x)"
            print(msg)
        # Free benchmark tensors between configs so a large one can't OOM the next.
        torch.cuda.empty_cache()

    print_results_table(fixlen_results, "Fixed-Length Results")

    # Variable-length benchmarks (same configs, uniform seqlens)
    varlen_results = []
    for idx, (bs, sl, nqk, nv) in enumerate(fixlen_configs, start=1):
        cu_seqlens = [sl * j for j in range(bs + 1)]
        label = (
            f"[{idx}/{total}] seqs={bs}, total={cu_seqlens[-1]}, "
            f"nqk={nqk}, nv={nv}"
        )
        print(f" Running varlen {label} ...", end="", flush=True)
        try:
            result = benchmark_varlen(
                cu_seqlens, nqk, nv, head_dim,
                warmup_iters=warmup_iters, benchmark_iters=benchmark_iters,
            )
        except Exception as e:
            print(f" FAILED: {e}")
        else:
            varlen_results.append(result)
            print(f" GDN: {result['gdn_ms']:.3f} ms")
        torch.cuda.empty_cache()

    print_results_table(varlen_results, "Variable-Length Results")
| 320 | + |
| 321 | + |
def main():
    """Parse CLI arguments, validate the environment, and dispatch a run.

    Exits with status 1 when the GPU is not SM100+ (Blackwell), when
    head_dim != 128 (the kernel's only supported head dim), or when a
    user-supplied --cu-seqlens list is malformed.
    """
    parser = argparse.ArgumentParser(
        description="Blackwell GDN Prefill Benchmark (SM100+)"
    )
    parser.add_argument("--batch-size", "-b", type=int, default=4)
    parser.add_argument("--seq-len", "-t", type=int, default=4096)
    parser.add_argument("--num-qk-heads", "-nqk", type=int, default=32)
    parser.add_argument("--num-v-heads", "-nv", type=int, default=32)
    parser.add_argument("--head-dim", "-d", type=int, default=128)
    parser.add_argument("--warmup", type=int, default=10)
    parser.add_argument("--iters", type=int, default=100)
    parser.add_argument("--varlen", action="store_true", help="Variable-length mode")
    parser.add_argument(
        "--cu-seqlens", type=int, nargs="+", default=None,
        help="Cumulative sequence lengths for varlen (e.g. 0 512 1024 2048)",
    )
    parser.add_argument("--sweep", action="store_true", help="Run full sweep")
    args = parser.parse_args()

    device = torch.device("cuda")
    if not is_sm100a_supported(device):
        print("Error: This benchmark requires SM100+ (Blackwell) GPU.")
        sys.exit(1)

    # The SM100 GDN kernel only supports head_dim == 128.
    if args.head_dim != 128:
        print(f"Error: head_dim must be 128, got {args.head_dim}")
        sys.exit(1)

    print(f"\n{'=' * 60}")
    print(" Blackwell GDN Prefill Benchmark")
    print(f" GPU: {torch.cuda.get_device_name(0)}")
    print(f"{'=' * 60}")

    if args.sweep:
        run_sweep(warmup_iters=args.warmup, benchmark_iters=args.iters)
        return

    if args.varlen:
        if args.cu_seqlens is not None:
            cu_seqlens = args.cu_seqlens
            # Validate the user-supplied prefix sums early with a clear
            # message instead of an opaque kernel-side failure.
            if (
                len(cu_seqlens) < 2
                or cu_seqlens[0] != 0
                or any(b < a for a, b in zip(cu_seqlens, cu_seqlens[1:]))
            ):
                print(
                    "Error: --cu-seqlens must have at least 2 entries, "
                    "start at 0, and be non-decreasing."
                )
                sys.exit(1)
        else:
            # Default: uniform sequences, matching the fixlen configuration.
            cu_seqlens = [args.seq_len * i for i in range(args.batch_size + 1)]
        result = benchmark_varlen(
            cu_seqlens, args.num_qk_heads, args.num_v_heads, args.head_dim,
            warmup_iters=args.warmup, benchmark_iters=args.iters,
        )
        print_results_table([result], "Variable-Length Benchmark")
    else:
        result = benchmark_fixlen(
            args.batch_size, args.seq_len, args.num_qk_heads, args.num_v_heads,
            args.head_dim, warmup_iters=args.warmup, benchmark_iters=args.iters,
        )
        print_results_table([result], "Fixed-Length Benchmark")


if __name__ == "__main__":
    main()
0 commit comments