Skip to content

Commit 033500f

Browse files
authored
Merge branch 'main' into ameyn/bf16_baseline_fix
2 parents 185a1a7 + 9e3d8b9 commit 033500f

29 files changed

Lines changed: 7103 additions & 918 deletions

benchmarks/bench_topk.py

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def bench_median_ms(fn) -> float:
6464
enable_cupti=True,
6565
dry_run_iters=10,
6666
repeat_iters=100,
67+
use_cuda_graph=True,
6768
)
6869
return float(np.median(measurements))
6970

@@ -88,6 +89,7 @@ def bench_top_k_from_scores(
8889
"""Benchmark top-k on a pre-generated score tensor."""
8990
batch_size, seq_len = scores.shape
9091

92+
set_topk_algo("default")
9193
fi_ms, fi_nondeterministic_ms = bench_flashinfer_modes(
9294
lambda deterministic_mode: flashinfer.top_k(
9395
scores,
@@ -121,6 +123,12 @@ def bench_top_k_from_scores(
121123
result["torch_deterministic_us"] = torch_det_ms * 1e3
122124
result["speedup_vs_torch_deterministic"] = torch_det_ms / fi_ms
123125

126+
set_topk_algo("clusters")
127+
fast_topk_ms = bench_median_ms(lambda: flashinfer.top_k(scores, k))
128+
result["fast_topk_us"] = fast_topk_ms * 1e3
129+
result["speedup_vs_flashinfer"] = fi_ms / fast_topk_ms
130+
set_topk_algo("auto")
131+
124132
# SGLang comparison (only supports k=2048 and float32)
125133
if (
126134
compare_sglang
@@ -130,7 +138,7 @@ def bench_top_k_from_scores(
130138
):
131139
lengths = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
132140
sg_ms = bench_median_ms(
133-
lambda: sgl_kernel.fast_topk_v2(scores, lengths, k, row_starts=None),
141+
lambda: sgl_kernel.fast_topk_v2(scores, lengths, k, row_starts=None)
134142
)
135143
result["sglang_us"] = sg_ms * 1e3
136144
result["speedup_vs_sglang"] = sg_ms / fi_ms
@@ -345,7 +353,10 @@ def bench_page_table_transform(
345353
.expand(batch_size, -1)
346354
.contiguous()
347355
)
356+
use_cuda_graph = True
357+
enable_cupti = True
348358

359+
set_topk_algo("default")
349360
fi_ms, fi_nondeterministic_ms = bench_flashinfer_modes(
350361
lambda deterministic_mode: flashinfer.top_k_page_table_transform(
351362
scores,
@@ -370,13 +381,29 @@ def bench_page_table_transform(
370381
fi_ms / fi_nondeterministic_ms
371382
)
372383

384+
# FlashInfer clusters
385+
set_topk_algo("clusters")
386+
measurements = bench_gpu_time(
387+
lambda: flashinfer.top_k_page_table_transform(
388+
scores, src_page_table, lengths, k
389+
),
390+
enable_cupti=enable_cupti,
391+
dry_run_iters=10,
392+
repeat_iters=100,
393+
use_cuda_graph=use_cuda_graph,
394+
)
395+
fast_topk_ms = np.median(measurements)
396+
result["fast_topk_us"] = fast_topk_ms * 1e3
397+
result["speedup_vs_flashinfer"] = fi_ms / fast_topk_ms
398+
set_topk_algo("auto")
399+
373400
# SGLang comparison (only supports k=2048 and float32)
374401
if compare_sglang and HAS_SGL_KERNEL and k == 2048 and dtype == torch.float32:
375402
cu_seqlens_q = torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda")
376403
sg_ms = bench_median_ms(
377404
lambda: sgl_kernel.fast_topk_transform_fused(
378405
scores, lengths, src_page_table, cu_seqlens_q, k
379-
),
406+
)
380407
)
381408
result["sglang_us"] = sg_ms * 1e3
382409
result["speedup_vs_sglang"] = sg_ms / fi_ms
@@ -399,7 +426,10 @@ def bench_ragged_transform(
399426
offsets = torch.arange(
400427
0, batch_size * seq_len, seq_len, device="cuda", dtype=torch.int32
401428
)
429+
use_cuda_graph = True
430+
enable_cupti = True
402431

432+
set_topk_algo("default")
403433
fi_ms, fi_nondeterministic_ms = bench_flashinfer_modes(
404434
lambda deterministic_mode: flashinfer.top_k_ragged_transform(
405435
scores,
@@ -424,12 +454,26 @@ def bench_ragged_transform(
424454
fi_ms / fi_nondeterministic_ms
425455
)
426456

457+
# FlashInfer clusters
458+
set_topk_algo("clusters")
459+
measurements = bench_gpu_time(
460+
lambda: flashinfer.top_k_ragged_transform(scores, offsets, lengths, k),
461+
enable_cupti=enable_cupti,
462+
dry_run_iters=10,
463+
repeat_iters=100,
464+
use_cuda_graph=use_cuda_graph,
465+
)
466+
fast_topk_ms = np.median(measurements)
467+
result["fast_topk_us"] = fast_topk_ms * 1e3
468+
result["speedup_vs_flashinfer"] = fi_ms / fast_topk_ms
469+
set_topk_algo("auto")
470+
427471
# SGLang comparison (only supports k=2048 and float32)
428472
if compare_sglang and HAS_SGL_KERNEL and k == 2048 and dtype == torch.float32:
429473
sg_ms = bench_median_ms(
430474
lambda: sgl_kernel.fast_topk_transform_ragged_fused(
431475
scores, lengths, offsets, k
432-
),
476+
)
433477
)
434478
result["sglang_us"] = sg_ms * 1e3
435479
result["speedup_vs_sglang"] = sg_ms / fi_ms
@@ -637,13 +681,14 @@ def main():
637681
header = (
638682
f"{'batch':>6} {'seq_len':>10} {'k':>6} | "
639683
f"{'FlashInfer':>12} {'torch.topk':>12} {'Speedup':>10}"
684+
f" {'Clusters':>12} {'Speedup Clusters vs. Default':>29}"
640685
)
641686
if args.compare_torch_deterministic and not args.deterministic:
642687
header += f" {'torch.det':>12} {'Speedup':>10}"
643688
if args.compare_sglang:
644689
header += f" {'SGLang':>12} {'Speedup':>10}"
645690
print(header)
646-
divider_len = 96 if args.deterministic else 72
691+
divider_len = 96 if args.deterministic else 115
647692
if args.compare_torch_deterministic and not args.deterministic:
648693
divider_len += 24
649694
if args.compare_sglang:
@@ -677,6 +722,8 @@ def main():
677722
f"{result['flashinfer_us']:>12.2f}us {result['torch_us']:>10.2f}us "
678723
f"{result['speedup_vs_torch']:>9.2f}x"
679724
)
725+
if "fast_topk_us" in result:
726+
line += f" {result['fast_topk_us']:>10.2f}us {result['speedup_vs_flashinfer']:>28.2f}x"
680727
if "torch_deterministic_us" in result:
681728
line += (
682729
f" {result['torch_deterministic_us']:>10.2f}us "
@@ -733,11 +780,12 @@ def main():
733780
header = (
734781
f"{'case':>24} {'rows':>8} {'seq_len':>10} {'k':>6} | "
735782
f"{'FlashInfer':>12} {'torch.topk':>12} {'Speedup':>10}"
783+
f" {'Clusters':>12} {'Speedup Clusters vs. Default':>29}"
736784
)
737785
if args.compare_torch_deterministic and not args.deterministic:
738786
header += f" {'torch.det':>12} {'Speedup':>10}"
739787
print(header)
740-
divider_len = 110 if args.deterministic else 86
788+
divider_len = 110 if args.deterministic else 129
741789
if args.compare_torch_deterministic and not args.deterministic:
742790
divider_len += 24
743791
print("-" * divider_len)
@@ -788,6 +836,8 @@ def main():
788836
f"{result['flashinfer_us']:>10.2f}us {result['torch_us']:>10.2f}us "
789837
f"{result['speedup_vs_torch']:>9.2f}x"
790838
)
839+
if "fast_topk_us" in result:
840+
line += f" {result['fast_topk_us']:>10.2f}us {result['speedup_vs_flashinfer']:>28.2f}x"
791841
if "torch_deterministic_us" in result:
792842
line += (
793843
f" {result['torch_deterministic_us']:>10.2f}us "
@@ -826,13 +876,13 @@ def main():
826876
f"{'FlashInfer':>12} {'FlashInfer(det)':>14} {'DetSlowdown':>11}"
827877
)
828878
else:
829-
header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12}"
879+
header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12} {'Clusters':>12} {'Speedup Clusters vs. Default':>29}"
830880
if args.compare_sglang:
831881
header += f" {'SGLang':>12} {'Speedup':>10}"
832882
print(header)
833-
divider_len = 87 if args.deterministic else 70
883+
divider_len = 87 if args.deterministic else 109
834884
if args.compare_sglang:
835-
divider_len += 20
885+
divider_len += 24
836886
print("-" * divider_len)
837887

838888
for batch_size in batch_sizes:
@@ -862,6 +912,8 @@ def main():
862912
f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | "
863913
f"{result['flashinfer_us']:>10.2f}us"
864914
)
915+
if "fast_topk_us" in result:
916+
line += f" {result['fast_topk_us']:>10.2f}us {result['speedup_vs_flashinfer']:>28.2f}x"
865917
if "sglang_us" in result:
866918
line += (
867919
f" {result['sglang_us']:>10.2f}us "
@@ -901,13 +953,13 @@ def main():
901953
f"{'FlashInfer':>12} {'FlashInfer(det)':>14} {'DetSlowdown':>11}"
902954
)
903955
else:
904-
header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12}"
956+
header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12} {'Clusters':>12} {'Speedup Clusters vs. Default':>29}"
905957
if args.compare_sglang:
906958
header += f" {'SGLang':>12} {'Speedup':>10}"
907959
print(header)
908-
divider_len = 87 if args.deterministic else 70
960+
divider_len = 87 if args.deterministic else 109
909961
if args.compare_sglang:
910-
divider_len += 20
962+
divider_len += 24
911963
print("-" * divider_len)
912964

913965
for batch_size in batch_sizes:
@@ -937,6 +989,8 @@ def main():
937989
f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | "
938990
f"{result['flashinfer_us']:>10.2f}us"
939991
)
992+
if "fast_topk_us" in result:
993+
line += f" {result['fast_topk_us']:>10.2f}us {result['speedup_vs_flashinfer']:>28.2f}x"
940994
if "sglang_us" in result:
941995
line += (
942996
f" {result['sglang_us']:>10.2f}us "

0 commit comments

Comments
 (0)