@@ -64,6 +64,7 @@ def bench_median_ms(fn) -> float:
6464 enable_cupti = True ,
6565 dry_run_iters = 10 ,
6666 repeat_iters = 100 ,
67+ use_cuda_graph = True ,
6768 )
6869 return float (np .median (measurements ))
6970
@@ -88,6 +89,7 @@ def bench_top_k_from_scores(
8889 """Benchmark top-k on a pre-generated score tensor."""
8990 batch_size , seq_len = scores .shape
9091
92+ set_topk_algo ("default" )
9193 fi_ms , fi_nondeterministic_ms = bench_flashinfer_modes (
9294 lambda deterministic_mode : flashinfer .top_k (
9395 scores ,
@@ -121,6 +123,12 @@ def bench_top_k_from_scores(
121123 result ["torch_deterministic_us" ] = torch_det_ms * 1e3
122124 result ["speedup_vs_torch_deterministic" ] = torch_det_ms / fi_ms
123125
126+ set_topk_algo ("clusters" )
127+ fast_topk_ms = bench_median_ms (lambda : flashinfer .top_k (scores , k ))
128+ result ["fast_topk_us" ] = fast_topk_ms * 1e3
129+ result ["speedup_vs_flashinfer" ] = fi_ms / fast_topk_ms
130+ set_topk_algo ("auto" )
131+
124132 # SGLang comparison (only supports k=2048 and float32)
125133 if (
126134 compare_sglang
@@ -130,7 +138,7 @@ def bench_top_k_from_scores(
130138 ):
131139 lengths = torch .full ((batch_size ,), seq_len , dtype = torch .int32 , device = "cuda" )
132140 sg_ms = bench_median_ms (
133- lambda : sgl_kernel .fast_topk_v2 (scores , lengths , k , row_starts = None ),
141+ lambda : sgl_kernel .fast_topk_v2 (scores , lengths , k , row_starts = None )
134142 )
135143 result ["sglang_us" ] = sg_ms * 1e3
136144 result ["speedup_vs_sglang" ] = sg_ms / fi_ms
@@ -345,7 +353,10 @@ def bench_page_table_transform(
345353 .expand (batch_size , - 1 )
346354 .contiguous ()
347355 )
356+ use_cuda_graph = True
357+ enable_cupti = True
348358
359+ set_topk_algo ("default" )
349360 fi_ms , fi_nondeterministic_ms = bench_flashinfer_modes (
350361 lambda deterministic_mode : flashinfer .top_k_page_table_transform (
351362 scores ,
@@ -370,13 +381,29 @@ def bench_page_table_transform(
370381 fi_ms / fi_nondeterministic_ms
371382 )
372383
384+ # FlashInfer clusters
385+ set_topk_algo ("clusters" )
386+ measurements = bench_gpu_time (
387+ lambda : flashinfer .top_k_page_table_transform (
388+ scores , src_page_table , lengths , k
389+ ),
390+ enable_cupti = enable_cupti ,
391+ dry_run_iters = 10 ,
392+ repeat_iters = 100 ,
393+ use_cuda_graph = use_cuda_graph ,
394+ )
395+ fast_topk_ms = np .median (measurements )
396+ result ["fast_topk_us" ] = fast_topk_ms * 1e3
397+ result ["speedup_vs_flashinfer" ] = fi_ms / fast_topk_ms
398+ set_topk_algo ("auto" )
399+
373400 # SGLang comparison (only supports k=2048 and float32)
374401 if compare_sglang and HAS_SGL_KERNEL and k == 2048 and dtype == torch .float32 :
375402 cu_seqlens_q = torch .arange (0 , batch_size + 1 , dtype = torch .int32 , device = "cuda" )
376403 sg_ms = bench_median_ms (
377404 lambda : sgl_kernel .fast_topk_transform_fused (
378405 scores , lengths , src_page_table , cu_seqlens_q , k
379- ),
406+ )
380407 )
381408 result ["sglang_us" ] = sg_ms * 1e3
382409 result ["speedup_vs_sglang" ] = sg_ms / fi_ms
@@ -399,7 +426,10 @@ def bench_ragged_transform(
399426 offsets = torch .arange (
400427 0 , batch_size * seq_len , seq_len , device = "cuda" , dtype = torch .int32
401428 )
429+ use_cuda_graph = True
430+ enable_cupti = True
402431
432+ set_topk_algo ("default" )
403433 fi_ms , fi_nondeterministic_ms = bench_flashinfer_modes (
404434 lambda deterministic_mode : flashinfer .top_k_ragged_transform (
405435 scores ,
@@ -424,12 +454,26 @@ def bench_ragged_transform(
424454 fi_ms / fi_nondeterministic_ms
425455 )
426456
457+ # FlashInfer clusters
458+ set_topk_algo ("clusters" )
459+ measurements = bench_gpu_time (
460+ lambda : flashinfer .top_k_ragged_transform (scores , offsets , lengths , k ),
461+ enable_cupti = enable_cupti ,
462+ dry_run_iters = 10 ,
463+ repeat_iters = 100 ,
464+ use_cuda_graph = use_cuda_graph ,
465+ )
466+ fast_topk_ms = np .median (measurements )
467+ result ["fast_topk_us" ] = fast_topk_ms * 1e3
468+ result ["speedup_vs_flashinfer" ] = fi_ms / fast_topk_ms
469+ set_topk_algo ("auto" )
470+
427471 # SGLang comparison (only supports k=2048 and float32)
428472 if compare_sglang and HAS_SGL_KERNEL and k == 2048 and dtype == torch .float32 :
429473 sg_ms = bench_median_ms (
430474 lambda : sgl_kernel .fast_topk_transform_ragged_fused (
431475 scores , lengths , offsets , k
432- ),
476+ )
433477 )
434478 result ["sglang_us" ] = sg_ms * 1e3
435479 result ["speedup_vs_sglang" ] = sg_ms / fi_ms
@@ -637,13 +681,14 @@ def main():
637681 header = (
638682 f"{ 'batch' :>6} { 'seq_len' :>10} { 'k' :>6} | "
639683 f"{ 'FlashInfer' :>12} { 'torch.topk' :>12} { 'Speedup' :>10} "
684+ f" { 'Clusters' :>12} { 'Speedup Clusters vs. Default' :>29} "
640685 )
641686 if args .compare_torch_deterministic and not args .deterministic :
642687 header += f" { 'torch.det' :>12} { 'Speedup' :>10} "
643688 if args .compare_sglang :
644689 header += f" { 'SGLang' :>12} { 'Speedup' :>10} "
645690 print (header )
646- divider_len = 96 if args .deterministic else 72
691+ divider_len = 96 if args .deterministic else 115
647692 if args .compare_torch_deterministic and not args .deterministic :
648693 divider_len += 24
649694 if args .compare_sglang :
@@ -677,6 +722,8 @@ def main():
677722 f"{ result ['flashinfer_us' ]:>12.2f} us { result ['torch_us' ]:>10.2f} us "
678723 f"{ result ['speedup_vs_torch' ]:>9.2f} x"
679724 )
725+ if "fast_topk_us" in result :
726+ line += f" { result ['fast_topk_us' ]:>10.2f} us { result ['speedup_vs_flashinfer' ]:>28.2f} x"
680727 if "torch_deterministic_us" in result :
681728 line += (
682729 f" { result ['torch_deterministic_us' ]:>10.2f} us "
@@ -733,11 +780,12 @@ def main():
733780 header = (
734781 f"{ 'case' :>24} { 'rows' :>8} { 'seq_len' :>10} { 'k' :>6} | "
735782 f"{ 'FlashInfer' :>12} { 'torch.topk' :>12} { 'Speedup' :>10} "
783+ f" { 'Clusters' :>12} { 'Speedup Clusters vs. Default' :>29} "
736784 )
737785 if args .compare_torch_deterministic and not args .deterministic :
738786 header += f" { 'torch.det' :>12} { 'Speedup' :>10} "
739787 print (header )
740- divider_len = 110 if args .deterministic else 86
788+ divider_len = 110 if args .deterministic else 129
741789 if args .compare_torch_deterministic and not args .deterministic :
742790 divider_len += 24
743791 print ("-" * divider_len )
@@ -788,6 +836,8 @@ def main():
788836 f"{ result ['flashinfer_us' ]:>10.2f} us { result ['torch_us' ]:>10.2f} us "
789837 f"{ result ['speedup_vs_torch' ]:>9.2f} x"
790838 )
839+ if "fast_topk_us" in result :
840+ line += f" { result ['fast_topk_us' ]:>10.2f} us { result ['speedup_vs_flashinfer' ]:>28.2f} x"
791841 if "torch_deterministic_us" in result :
792842 line += (
793843 f" { result ['torch_deterministic_us' ]:>10.2f} us "
@@ -826,13 +876,13 @@ def main():
826876 f"{ 'FlashInfer' :>12} { 'FlashInfer(det)' :>14} { 'DetSlowdown' :>11} "
827877 )
828878 else :
829- header = f"{ 'batch' :>6} { 'seq_len' :>10} { 'k' :>6} | { 'FlashInfer' :>12} "
879+ header = f"{ 'batch' :>6} { 'seq_len' :>10} { 'k' :>6} | { 'FlashInfer' :>12} { 'Clusters' :>12 } { 'Speedup Clusters vs. Default' :>29 } "
830880 if args .compare_sglang :
831881 header += f" { 'SGLang' :>12} { 'Speedup' :>10} "
832882 print (header )
833- divider_len = 87 if args .deterministic else 70
883+ divider_len = 87 if args .deterministic else 109
834884 if args .compare_sglang :
835- divider_len += 20
885+ divider_len += 24
836886 print ("-" * divider_len )
837887
838888 for batch_size in batch_sizes :
@@ -862,6 +912,8 @@ def main():
862912 f"{ result ['batch_size' ]:>6} { result ['seq_len' ]:>10} { result ['k' ]:>6} | "
863913 f"{ result ['flashinfer_us' ]:>10.2f} us"
864914 )
915+ if "fast_topk_us" in result :
916+ line += f" { result ['fast_topk_us' ]:>10.2f} us { result ['speedup_vs_flashinfer' ]:>28.2f} x"
865917 if "sglang_us" in result :
866918 line += (
867919 f" { result ['sglang_us' ]:>10.2f} us "
@@ -901,13 +953,13 @@ def main():
901953 f"{ 'FlashInfer' :>12} { 'FlashInfer(det)' :>14} { 'DetSlowdown' :>11} "
902954 )
903955 else :
904- header = f"{ 'batch' :>6} { 'seq_len' :>10} { 'k' :>6} | { 'FlashInfer' :>12} "
956+ header = f"{ 'batch' :>6} { 'seq_len' :>10} { 'k' :>6} | { 'FlashInfer' :>12} { 'Clusters' :>12 } { 'Speedup Clusters vs. Default' :>29 } "
905957 if args .compare_sglang :
906958 header += f" { 'SGLang' :>12} { 'Speedup' :>10} "
907959 print (header )
908- divider_len = 87 if args .deterministic else 70
960+ divider_len = 87 if args .deterministic else 109
909961 if args .compare_sglang :
910- divider_len += 20
962+ divider_len += 24
911963 print ("-" * divider_len )
912964
913965 for batch_size in batch_sizes :
@@ -937,6 +989,8 @@ def main():
937989 f"{ result ['batch_size' ]:>6} { result ['seq_len' ]:>10} { result ['k' ]:>6} | "
938990 f"{ result ['flashinfer_us' ]:>10.2f} us"
939991 )
992+ if "fast_topk_us" in result :
993+ line += f" { result ['fast_topk_us' ]:>10.2f} us { result ['speedup_vs_flashinfer' ]:>28.2f} x"
940994 if "sglang_us" in result :
941995 line += (
942996 f" { result ['sglang_us' ]:>10.2f} us "
0 commit comments