Skip to content

Commit 2ec880b

Browse files
committed
add blackwell bench
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
1 parent 8b88000 commit 2ec880b

1 file changed

Lines changed: 378 additions & 0 deletions

File tree

Lines changed: 378 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,378 @@
1+
"""
2+
Copyright (c) 2025 by FlashInfer team.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
"""
18+
Performance benchmark for Blackwell GDN (Gated Delta Network) prefill kernel.
19+
20+
Compares FlashInfer's SM100 GDN prefill against FLA baseline.
21+
22+
Usage:
23+
python bench_blackwell_gdn_prefill.py --sweep
24+
python bench_blackwell_gdn_prefill.py --batch-size 8 --seq-len 1024
25+
python bench_blackwell_gdn_prefill.py --varlen --cu-seqlens 0 512 1024 2048
26+
"""
27+
28+
import argparse
29+
import sys
30+
from typing import List
31+
32+
import numpy as np
33+
import torch
34+
import torch.nn.functional as F
35+
36+
from flashinfer.gdn_prefill import chunk_gated_delta_rule
37+
from flashinfer.testing import bench_gpu_time
38+
from flashinfer.utils import is_sm100a_supported
39+
40+
try:
41+
from fla.ops.gated_delta_rule.chunk import chunk_gated_delta_rule_fwd as fla_base
42+
43+
_has_fla = True
44+
except ImportError:
45+
_has_fla = False
46+
47+
48+
def _make_inputs(
49+
total_len: int,
50+
num_seqs: int,
51+
num_qk_heads: int,
52+
num_v_heads: int,
53+
head_dim: int,
54+
dtype: torch.dtype,
55+
device: str = "cuda",
56+
use_initial_state: bool = True,
57+
):
58+
"""Create input tensors in FlashInfer 3D format (total_len, H, D)."""
59+
num_o_heads = max(num_qk_heads, num_v_heads)
60+
61+
q = torch.randn(total_len, num_qk_heads, head_dim, dtype=dtype, device=device)
62+
k = F.normalize(
63+
torch.randn(total_len, num_qk_heads, head_dim, dtype=torch.float32, device=device),
64+
p=2,
65+
dim=-1,
66+
).to(dtype)
67+
v = torch.randn(total_len, num_v_heads, head_dim, dtype=dtype, device=device)
68+
g = F.logsigmoid(
69+
torch.rand(total_len, num_o_heads, dtype=torch.float32, device=device)
70+
)
71+
beta = torch.rand(
72+
total_len, num_o_heads, dtype=torch.float32, device=device
73+
).sigmoid()
74+
75+
h0 = None
76+
if use_initial_state:
77+
h0 = torch.randn(
78+
num_seqs, num_o_heads, head_dim, head_dim,
79+
dtype=torch.float32, device=device,
80+
)
81+
82+
o = torch.empty(total_len, num_o_heads, head_dim, dtype=dtype, device=device)
83+
s_out = torch.empty(
84+
num_seqs, num_o_heads, head_dim, head_dim,
85+
dtype=torch.float32, device=device,
86+
)
87+
88+
return q, k, v, g, beta, h0, o, s_out
89+
90+
91+
def benchmark_fixlen(
    batch_size: int,
    seq_len: int,
    num_qk_heads: int,
    num_v_heads: int,
    head_dim: int,
    dtype: torch.dtype = torch.bfloat16,
    warmup_iters: int = 10,
    benchmark_iters: int = 100,
    use_initial_state: bool = True,
) -> dict:
    """Benchmark GDN with fixed-length sequences.

    Returns a dict with the configuration and the median GDN latency in
    milliseconds; when FLA is installed and the qk/v head counts match,
    the dict also carries ``fla_ms`` and ``speedup``.
    """
    device = "cuda"
    total_len = batch_size * seq_len

    q, k, v, g, beta, h0, o, s_out = _make_inputs(
        total_len, batch_size, num_qk_heads, num_v_heads, head_dim, dtype,
        use_initial_state=use_initial_state,
    )
    # Uniform boundaries: [0, seq_len, 2*seq_len, ..., total_len].
    cu_seqlens = torch.arange(
        0, total_len + 1, seq_len, dtype=torch.int64, device=device
    )

    def run_gdn():
        chunk_gated_delta_rule(
            q, k, v, g, beta, None, h0, True, cu_seqlens, False, o, s_out,
        )

    gdn_samples = bench_gpu_time(
        run_gdn, enable_cupti=True,
        dry_run_iters=warmup_iters, repeat_iters=benchmark_iters,
    )
    gdn_ms = float(np.median(gdn_samples))

    result = {
        "batch_size": batch_size,
        "seq_len": seq_len,
        "num_qk_heads": num_qk_heads,
        "num_v_heads": num_v_heads,
        "head_dim": head_dim,
        "gdn_ms": gdn_ms,
    }

    # FLA baseline (only when qk_heads == v_heads, FLA doesn't support GVA)
    if _has_fla and num_qk_heads == num_v_heads:
        # FLA expects 4D (B, T, H, D)
        q4 = q.view(batch_size, seq_len, num_qk_heads, head_dim)
        k4 = k.view(batch_size, seq_len, num_qk_heads, head_dim)
        v4 = v.view(batch_size, seq_len, num_v_heads, head_dim)
        g4 = g.view(batch_size, seq_len, num_v_heads)
        beta4 = beta.view(batch_size, seq_len, num_v_heads)

        def run_fla():
            fla_base(q4, k4, v4, g4, beta4, None, initial_state=h0, output_final_state=True)

        fla_samples = bench_gpu_time(
            run_fla, enable_cupti=True,
            dry_run_iters=warmup_iters, repeat_iters=benchmark_iters,
        )
        fla_ms = float(np.median(fla_samples))
        result["fla_ms"] = fla_ms
        # Guard against a degenerate zero median in the denominator.
        result["speedup"] = fla_ms / gdn_ms if gdn_ms > 0 else float("nan")

    return result
155+
156+
157+
def benchmark_varlen(
    cu_seqlens: List[int],
    num_qk_heads: int,
    num_v_heads: int,
    head_dim: int,
    dtype: torch.dtype = torch.bfloat16,
    warmup_iters: int = 10,
    benchmark_iters: int = 100,
) -> dict:
    """Benchmark GDN with variable-length sequences.

    ``cu_seqlens`` is a cumulative-length list starting at 0 (e.g.
    ``[0, 512, 1024, 2048]`` describes 3 sequences). Returns a dict with
    the configuration and the median GDN latency in milliseconds.
    """
    device = "cuda"
    num_seqs = len(cu_seqlens) - 1
    total_len = cu_seqlens[-1]

    q, k, v, g, beta, h0, o, s_out = _make_inputs(
        total_len, num_seqs, num_qk_heads, num_v_heads, head_dim, dtype,
    )
    cu_seqlens_t = torch.tensor(cu_seqlens, dtype=torch.int64, device=device)

    def run_gdn():
        chunk_gated_delta_rule(
            q, k, v, g, beta, None, h0, True, cu_seqlens_t, False, o, s_out,
        )

    samples = bench_gpu_time(
        run_gdn, enable_cupti=True,
        dry_run_iters=warmup_iters, repeat_iters=benchmark_iters,
    )

    return {
        "num_seqs": num_seqs,
        "total_len": total_len,
        "avg_seq_len": total_len // num_seqs,
        "num_qk_heads": num_qk_heads,
        "num_v_heads": num_v_heads,
        "head_dim": head_dim,
        "gdn_ms": float(np.median(samples)),
    }
196+
197+
198+
def print_results_table(results: List[dict], title: str = "Benchmark Results"):
    """Print benchmark results in a formatted table.

    Columns are the union of keys across all result dicts, in first-seen
    order; missing values render as empty cells. No-op for empty input.
    """
    if not results:
        return

    bar = "=" * 90
    print(f"\n{bar}")
    print(f" {title}")
    print(bar)

    # Union of keys across all rows, preserving first-seen order
    # (dicts iterate in insertion order, so this is deterministic).
    keys: List[str] = []
    for r in results:
        keys.extend(key for key in r if key not in keys)

    def fmt(val) -> str:
        # Floats: 3 decimals while < 100, 1 decimal once they get wide.
        if isinstance(val, float):
            return f"{val:.3f}" if val < 100 else f"{val:.1f}"
        return str(val)

    # Column width = widest of header and all formatted cells, plus padding.
    widths = {
        key: 2 + max(len(key), *(len(fmt(r.get(key, ""))) for r in results))
        for key in keys
    }

    header = " | ".join(f"{key:^{widths[key]}}" for key in keys)
    print(header)
    print("-" * len(header))

    for r in results:
        cells = [f"{fmt(r.get(key, '')):^{widths[key]}}" for key in keys]
        print(" | ".join(cells))

    print(f"{bar}\n")
244+
245+
246+
def run_sweep(warmup_iters: int = 10, benchmark_iters: int = 100):
    """Run the standard sweep matching PR #2742 configurations.

    Runs every config both as fixed-length and as uniform variable-length
    batches, printing per-config timings and two summary tables. Failures
    are reported and skipped rather than aborting the sweep.
    """
    head_dim = 128

    # (batch_size, seq_len, num_qk_heads, num_v_heads)
    configs = [
        (1, 512, 96, 96),
        (1, 1024, 96, 96),
        (1, 4096, 96, 96),
        (1, 8192, 96, 96),
        (9, 512, 32, 32),
        (9, 1024, 32, 32),
        (9, 4096, 32, 32),
        (9, 8192, 32, 32),
        (33, 512, 32, 32),
        (33, 1024, 32, 32),
        (33, 4096, 32, 32),
        (33, 8192, 32, 32),
        (1, 512, 148, 148),
        (1, 1024, 148, 148),
        (1, 4096, 148, 148),
        (1, 8192, 148, 148),
    ]

    banner = "#" * 80
    print(f"\n{banner}")
    print(" BLACKWELL GDN PREFILL BENCHMARK SWEEP")
    print(f" GPU: {torch.cuda.get_device_name(0)}")
    if _has_fla:
        print(" FLA baseline: available")
    else:
        print(" FLA baseline: not installed (pip install flash-linear-attention)")
    print(banner)

    # Fixed-length benchmarks
    fixlen_results = []
    for idx, (bs, sl, nqk, nv) in enumerate(configs, start=1):
        label = f"[{idx}/{len(configs)}] bs={bs}, sl={sl}, nqk={nqk}, nv={nv}"
        print(f" Running fixlen {label} ...", end="", flush=True)
        try:
            res = benchmark_fixlen(
                bs, sl, nqk, nv, head_dim,
                warmup_iters=warmup_iters, benchmark_iters=benchmark_iters,
            )
            fixlen_results.append(res)
            line = f" GDN: {res['gdn_ms']:.3f} ms"
            if "fla_ms" in res:
                line += f" FLA: {res['fla_ms']:.3f} ms ({res['speedup']:.2f}x)"
            print(line)
        except Exception as e:
            print(f" FAILED: {e}")
        # Free benchmark tensors before the next (possibly larger) config.
        torch.cuda.empty_cache()

    print_results_table(fixlen_results, "Fixed-Length Results")

    # Variable-length benchmarks (same configs, uniform seqlens)
    varlen_results = []
    for idx, (bs, sl, nqk, nv) in enumerate(configs, start=1):
        cu_seqlens = [sl * j for j in range(bs + 1)]
        total_len = cu_seqlens[-1]
        label = f"[{idx}/{len(configs)}] seqs={bs}, total={total_len}, nqk={nqk}, nv={nv}"
        print(f" Running varlen {label} ...", end="", flush=True)
        try:
            res = benchmark_varlen(
                cu_seqlens, nqk, nv, head_dim,
                warmup_iters=warmup_iters, benchmark_iters=benchmark_iters,
            )
            varlen_results.append(res)
            print(f" GDN: {res['gdn_ms']:.3f} ms")
        except Exception as e:
            print(f" FAILED: {e}")
        torch.cuda.empty_cache()

    print_results_table(varlen_results, "Variable-Length Results")
320+
321+
322+
def main():
    """CLI entry point: parse args, validate the GPU/config, run benchmarks."""
    parser = argparse.ArgumentParser(
        description="Blackwell GDN Prefill Benchmark (SM100+)"
    )
    parser.add_argument("--batch-size", "-b", type=int, default=4)
    parser.add_argument("--seq-len", "-t", type=int, default=4096)
    parser.add_argument("--num-qk-heads", "-nqk", type=int, default=32)
    parser.add_argument("--num-v-heads", "-nv", type=int, default=32)
    parser.add_argument("--head-dim", "-d", type=int, default=128)
    parser.add_argument("--warmup", type=int, default=10)
    parser.add_argument("--iters", type=int, default=100)
    parser.add_argument("--varlen", action="store_true", help="Variable-length mode")
    parser.add_argument(
        "--cu-seqlens", type=int, nargs="+", default=None,
        help="Cumulative sequence lengths for varlen (e.g. 0 512 1024 2048)",
    )
    parser.add_argument("--sweep", action="store_true", help="Run full sweep")
    args = parser.parse_args()

    # Guard clauses: hardware and kernel constraints.
    device = torch.device("cuda")
    if not is_sm100a_supported(device):
        print("Error: This benchmark requires SM100+ (Blackwell) GPU.")
        sys.exit(1)

    if args.head_dim != 128:
        print(f"Error: head_dim must be 128, got {args.head_dim}")
        sys.exit(1)

    bar = "=" * 60
    print(f"\n{bar}")
    print(" Blackwell GDN Prefill Benchmark")
    print(f" GPU: {torch.cuda.get_device_name(0)}")
    print(bar)

    if args.sweep:
        run_sweep(warmup_iters=args.warmup, benchmark_iters=args.iters)
        return

    if args.varlen:
        # Explicit --cu-seqlens wins; otherwise build uniform boundaries.
        boundaries = (
            args.cu_seqlens
            if args.cu_seqlens is not None
            else [args.seq_len * i for i in range(args.batch_size + 1)]
        )
        result = benchmark_varlen(
            boundaries, args.num_qk_heads, args.num_v_heads, args.head_dim,
            warmup_iters=args.warmup, benchmark_iters=args.iters,
        )
        print_results_table([result], "Variable-Length Benchmark")
    else:
        result = benchmark_fixlen(
            args.batch_size, args.seq_len, args.num_qk_heads, args.num_v_heads,
            args.head_dim, warmup_iters=args.warmup, benchmark_iters=args.iters,
        )
        print_results_table([result], "Fixed-Length Benchmark")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)