diff --git a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py index 0863320ab6..7f9638f99f 100644 --- a/benchmarks/bench_topk.py +++ b/benchmarks/bench_topk.py @@ -16,6 +16,7 @@ import torch import flashinfer +from flashinfer.topk import TopKTieBreak from flashinfer.testing.utils import bench_gpu_time @@ -79,10 +80,60 @@ def bench_flashinfer_modes( return selected_ms, nondeterministic_ms +TIE_BREAK_VARIANTS: tuple[tuple[str, TopKTieBreak], ...] = ( + ("small", TopKTieBreak.SMALL), + ("large", TopKTieBreak.LARGE), +) + + +def bench_tie_break_variants( + run_flashinfer_with_tie_break, baseline_ms: float +) -> dict[str, float]: + metrics: dict[str, float] = {} + for suffix, tie_break in TIE_BREAK_VARIANTS: + try: + tie_ms = bench_median_ms(lambda: run_flashinfer_with_tie_break(tie_break)) + metrics[f"flashinfer_tie_{suffix}_us"] = tie_ms * 1e3 + metrics[f"tie_{suffix}_slowdown_vs_baseline"] = tie_ms / baseline_ms + except RuntimeError as exc: + error_label = classify_benchmark_runtime_error(exc) + if error_label is None: + raise + metrics[f"flashinfer_tie_{suffix}_error"] = error_label + return metrics + + +def append_tie_break_header(header: str, enabled: bool) -> str: + if not enabled: + return header + return ( + header + + f" {'FlashInfer(tie-small)':>21} {'TieSmallSlowdown':>17}" + + f" {'FlashInfer(tie-large)':>21} {'TieLargeSlowdown':>17}" + ) + + +def append_tie_break_columns(line: str, result: dict, enabled: bool) -> str: + if not enabled: + return line + + def format_variant(suffix: str) -> str: + us_key = f"flashinfer_tie_{suffix}_us" + slowdown_key = f"tie_{suffix}_slowdown_vs_baseline" + error_key = f"flashinfer_tie_{suffix}_error" + if us_key in result: + return f" {result[us_key]:>19.2f}us {result[slowdown_key]:>16.2f}x" + error_label = result.get(error_key, "n/a") + return f" {error_label:>21} {'n/a':>16}" + + return line + format_variant("small") + format_variant("large") + + def bench_top_k_from_scores( scores: torch.Tensor, k: int, deterministic: bool = False, + compare_tie_break: bool = False, compare_torch_deterministic: bool = False, compare_sglang: bool = False, ) -> dict: @@ -95,6 +146,7 @@ def bench_top_k_from_scores( scores, k, deterministic=deterministic_mode, + tie_break=TopKTieBreak.NONE, ), deterministic, ) @@ -116,6 +168,22 @@ def bench_top_k_from_scores( torch_ms = bench_median_ms(lambda: torch.topk(scores, k, dim=-1)) result["torch_us"] = torch_ms * 1e3 result["speedup_vs_torch"] = torch_ms / fi_ms + if compare_tie_break: + # Align tie-break slowdowns with the DetSlowdown baseline when present. + baseline_ms = ( + fi_nondeterministic_ms if fi_nondeterministic_ms is not None else fi_ms + ) + result.update( + bench_tie_break_variants( + lambda tie_break: flashinfer.top_k( + scores, + k, + deterministic=True, + tie_break=tie_break, + ), + baseline_ms, + ) + ) if compare_torch_deterministic and not deterministic: with torch_deterministic_algorithms(True): @@ -289,6 +357,7 @@ def bench_dsa_top_k( dtype: torch.dtype = torch.bfloat16, input_pattern: str = "dsa_relu", deterministic: bool = False, + compare_tie_break: bool = False, compare_torch_deterministic: bool = False, compare_sglang: bool = False, causal_chunk: bool = False, @@ -305,6 +374,7 @@ def bench_dsa_top_k( scores=scores, k=k, deterministic=deterministic, + compare_tie_break=compare_tie_break, compare_torch_deterministic=compare_torch_deterministic, compare_sglang=compare_sglang, ) @@ -321,6 +391,7 @@ def bench_top_k( dtype: torch.dtype = torch.float32, input_pattern: str = "random", deterministic: bool = False, + compare_tie_break: bool = False, compare_torch_deterministic: bool = False, compare_sglang: bool = False, ) -> dict: @@ -330,6 +401,7 @@ def bench_top_k( scores=scores, k=k, deterministic=deterministic, + compare_tie_break=compare_tie_break, compare_torch_deterministic=compare_torch_deterministic, compare_sglang=compare_sglang, ) @@ -342,6 +414,7 @@ def bench_page_table_transform( dtype: torch.dtype = torch.float32, input_pattern: str = "random", deterministic: bool = False, + compare_tie_break: bool = False, compare_sglang: bool = False, ) -> dict: """Benchmark fused top_k + page table transform.""" @@ -364,6 +437,7 @@ def bench_page_table_transform( lengths, k, deterministic=deterministic_mode, + tie_break=TopKTieBreak.NONE, ), deterministic, ) @@ -408,6 +482,25 @@ def bench_page_table_transform( result["sglang_us"] = sg_ms * 1e3 result["speedup_vs_sglang"] = sg_ms / fi_ms + if compare_tie_break: + # Align tie-break slowdowns with the DetSlowdown baseline when present. + baseline_ms = ( + fi_nondeterministic_ms if fi_nondeterministic_ms is not None else fi_ms + ) + result.update( + bench_tie_break_variants( + lambda tie_break: flashinfer.top_k_page_table_transform( + scores, + src_page_table, + lengths, + k, + deterministic=True, + tie_break=tie_break, + ), + baseline_ms, + ) + ) + return result @@ -418,6 +511,7 @@ def bench_ragged_transform( dtype: torch.dtype = torch.float32, input_pattern: str = "random", deterministic: bool = False, + compare_tie_break: bool = False, compare_sglang: bool = False, ) -> dict: """Benchmark fused top_k + ragged index transform.""" @@ -437,6 +531,7 @@ def bench_ragged_transform( lengths, k, deterministic=deterministic_mode, + tie_break=TopKTieBreak.NONE, ), deterministic, ) @@ -477,6 +572,23 @@ def bench_ragged_transform( ) result["sglang_us"] = sg_ms * 1e3 result["speedup_vs_sglang"] = sg_ms / fi_ms + if compare_tie_break: + baseline_ms = ( + fi_nondeterministic_ms if fi_nondeterministic_ms is not None else fi_ms + ) + result.update( + bench_tie_break_variants( + lambda tie_break: flashinfer.top_k_ragged_transform( + scores, + offsets, + lengths, + k, + deterministic=True, + tie_break=tie_break, + ), + baseline_ms, + ) + ) return result @@ -525,6 +637,14 @@ def main(): action="store_true", help="Enable deterministic mode for FlashInfer top-k kernels", ) + parser.add_argument( + "--tie-break", + action="store_true", + help=( + "Also benchmark deterministic tie-break variants and report " + "FlashInfer(tie-small/tie-large) columns with slowdown aligned to DetSlowdown baseline" + ), + ) parser.add_argument( "--compare-torch-deterministic", action="store_true", @@ -567,6 +687,12 @@ def main(): dtype = parse_dtype(args.dtype) + if args.tie_break and not args.deterministic: + print( + "NOTE: --tie-break requires deterministic kernels; enabling --deterministic." + ) + args.deterministic = True + if args.compare_sglang and not HAS_SGL_KERNEL: print("WARNING: sgl_kernel not found, skipping SGLang comparison") args.compare_sglang = False @@ -590,6 +716,9 @@ def main(): "ERROR: --compare-algorithms is only meaningful with non-deterministic mode" ) return + if args.tie_break: + print("ERROR: --compare-algorithms does not support --tie-break") + return print("=" * 100) print( "Algorithm comparison: Multi-CTA vs Filtered " @@ -643,11 +772,12 @@ def main(): return if args.op in ["all", "top_k"]: + show_det_or_tie = args.deterministic or args.tie_break print("=" * 100) print( "top_k: Basic radix-based top-k selection " f"(dtype={dtype_str}, deterministic={args.deterministic}, " - f"pattern={args.input_pattern})" + f"pattern={args.input_pattern}, tie_break={args.tie_break})" ) if args.compare_sglang: print("NOTE: SGLang only supports k=2048 and float32") @@ -665,35 +795,37 @@ def main(): "NOTE: torch column uses torch.topk with " "torch.use_deterministic_algorithms(True)" ) + if args.tie_break: + print( + "NOTE: tie-break columns benchmark deterministic tie-small/tie-large; " + "slowdowns align with the same baseline as DetSlowdown" + ) print( "NOTE: default top-k sweep includes two extra large-batch/long-vocab " "stress cases beyond the original grid" ) print("=" * 100) - if args.deterministic: + if show_det_or_tie: + torch_label = "torch.det" if args.deterministic else "torch.topk" header = ( f"{'batch':>6} {'seq_len':>10} {'k':>6} | " - f"{'FlashInfer':>12} {'FlashInfer(det)':>14} {'DetSlowdown':>11} " - f"{'torch.det':>12} {'Speedup':>10}" + f"{'FlashInfer':>12} {'FlashInfer(det)':>14} {'DetSlowdown':>11}" ) + header = append_tie_break_header(header, args.tie_break) + header += f" {torch_label:>12} {'Speedup':>10}" else: header = ( f"{'batch':>6} {'seq_len':>10} {'k':>6} | " f"{'FlashInfer':>12} {'torch.topk':>12} {'Speedup':>10}" f" {'Clusters':>12} {'Speedup Clusters vs. Default':>29}" ) - if args.compare_torch_deterministic and not args.deterministic: + if args.compare_torch_deterministic and not show_det_or_tie: header += f" {'torch.det':>12} {'Speedup':>10}" if args.compare_sglang: header += f" {'SGLang':>12} {'Speedup':>10}" print(header) - divider_len = 96 if args.deterministic else 115 - if args.compare_torch_deterministic and not args.deterministic: - divider_len += 24 - if args.compare_sglang: - divider_len += 24 - print("-" * divider_len) + print("-" * len(header)) for case in top_k_cases: try: @@ -704,16 +836,35 @@ def main(): dtype, input_pattern=args.input_pattern, deterministic=args.deterministic, + compare_tie_break=args.tie_break, compare_torch_deterministic=args.compare_torch_deterministic, compare_sglang=args.compare_sglang, ) - if args.deterministic: + if show_det_or_tie: + nondet_us = result.get("flashinfer_nondeterministic_us") + if nondet_us is None: + nondet_us = result["flashinfer_us"] + det_us = result["flashinfer_us"] if args.deterministic else None + det_us_str = ( + f"{det_us:>12.2f}us" if det_us is not None else f"{'n/a':>14}" + ) + det_slowdown = result.get( + "deterministic_slowdown_vs_nondeterministic" + ) + det_slowdown_str = ( + f"{det_slowdown:>10.2f}x" + if det_slowdown is not None + else f"{'n/a':>11}" + ) line = ( f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | " - f"{result['flashinfer_nondeterministic_us']:>10.2f}us " - f"{result['flashinfer_us']:>12.2f}us " - f"{result['deterministic_slowdown_vs_nondeterministic']:>10.2f}x " - f"{result['torch_us']:>10.2f}us " + f"{nondet_us:>10.2f}us " + f"{det_us_str} " + f"{det_slowdown_str}" + ) + line = append_tie_break_columns(line, result, args.tie_break) + line += ( + f" {result['torch_us']:>10.2f}us " f"{result['speedup_vs_torch']:>9.2f}x" ) else: @@ -724,7 +875,7 @@ def main(): ) if "fast_topk_us" in result: line += f" {result['fast_topk_us']:>10.2f}us {result['speedup_vs_flashinfer']:>28.2f}x" - if "torch_deterministic_us" in result: + if "torch_deterministic_us" in result and not show_det_or_tie: line += ( f" {result['torch_deterministic_us']:>10.2f}us " f"{result['speedup_vs_torch_deterministic']:>9.2f}x" @@ -748,11 +899,13 @@ def main(): raise if args.op in ["all", "dsa_topk"]: + show_det_or_tie = args.deterministic or args.tie_break print("\n" + "=" * 100) print( "dsa_topk: DeepSeek DSA-like indexer top-k workload " f"(dtype={dtype_str}, deterministic={args.deterministic}, " - f"dsa_pattern={args.dsa_input_pattern}, k={args.dsa_topk})" + f"dsa_pattern={args.dsa_input_pattern}, k={args.dsa_topk}, " + f"tie_break={args.tie_break})" ) if args.deterministic: print( @@ -768,27 +921,31 @@ def main(): "NOTE: torch column uses torch.topk with " "torch.use_deterministic_algorithms(True)" ) + if args.tie_break: + print( + "NOTE: tie-break columns benchmark deterministic tie-small/tie-large; " + "slowdowns align with the same baseline as DetSlowdown" + ) print("=" * 100) - if args.deterministic: + if show_det_or_tie: + torch_label = "torch.det" if args.deterministic else "torch.topk" header = ( f"{'case':>24} {'rows':>8} {'seq_len':>10} {'k':>6} | " - f"{'FlashInfer':>12} {'FlashInfer(det)':>14} {'DetSlowdown':>11} " - f"{'torch.det':>12} {'Speedup':>10}" + f"{'FlashInfer':>12} {'FlashInfer(det)':>14} {'DetSlowdown':>11}" ) + header = append_tie_break_header(header, args.tie_break) + header += f" {torch_label:>12} {'Speedup':>10}" else: header = ( f"{'case':>24} {'rows':>8} {'seq_len':>10} {'k':>6} | " f"{'FlashInfer':>12} {'torch.topk':>12} {'Speedup':>10}" f" {'Clusters':>12} {'Speedup Clusters vs. Default':>29}" ) - if args.compare_torch_deterministic and not args.deterministic: + if args.compare_torch_deterministic and not show_det_or_tie: header += f" {'torch.det':>12} {'Speedup':>10}" print(header) - divider_len = 110 if args.deterministic else 129 - if args.compare_torch_deterministic and not args.deterministic: - divider_len += 24 - print("-" * divider_len) + print("-" * len(header)) dsa_cases = [ # DeepSeek Sparse Attention proxy cases: @@ -817,17 +974,36 @@ def main(): dtype=dtype, input_pattern=args.dsa_input_pattern, deterministic=args.deterministic, + compare_tie_break=args.tie_break, compare_torch_deterministic=args.compare_torch_deterministic, compare_sglang=False, causal_chunk=case.causal_chunk, ) - if args.deterministic: + if show_det_or_tie: + nondet_us = result.get("flashinfer_nondeterministic_us") + if nondet_us is None: + nondet_us = result["flashinfer_us"] + det_us = result["flashinfer_us"] if args.deterministic else None + det_us_str = ( + f"{det_us:>12.2f}us" if det_us is not None else f"{'n/a':>14}" + ) + det_slowdown = result.get( + "deterministic_slowdown_vs_nondeterministic" + ) + det_slowdown_str = ( + f"{det_slowdown:>10.2f}x" + if det_slowdown is not None + else f"{'n/a':>11}" + ) line = ( f"{case.name:>24} {result['rows']:>8} {result['seq_len']:>10} {result['k']:>6} | " - f"{result['flashinfer_nondeterministic_us']:>10.2f}us " - f"{result['flashinfer_us']:>12.2f}us " - f"{result['deterministic_slowdown_vs_nondeterministic']:>10.2f}x " - f"{result['torch_us']:>10.2f}us " + f"{nondet_us:>10.2f}us " + f"{det_us_str} " + f"{det_slowdown_str}" + ) + line = append_tie_break_columns(line, result, args.tie_break) + line += ( + f" {result['torch_us']:>10.2f}us " f"{result['speedup_vs_torch']:>9.2f}x" ) else: @@ -838,7 +1014,7 @@ def main(): ) if "fast_topk_us" in result: line += f" {result['fast_topk_us']:>10.2f}us {result['speedup_vs_flashinfer']:>28.2f}x" - if "torch_deterministic_us" in result: + if "torch_deterministic_us" in result and not show_det_or_tie: line += ( f" {result['torch_deterministic_us']:>10.2f}us " f"{result['speedup_vs_torch_deterministic']:>9.2f}x" @@ -856,10 +1032,12 @@ def main(): raise if args.op in ["all", "page_table"]: + show_det_or_tie = args.deterministic or args.tie_break print("\n" + "=" * 100) print( "top_k_page_table_transform: Fused top-k + page table gather " - f"(dtype={dtype_str}, deterministic={args.deterministic}, pattern={args.input_pattern})" + f"(dtype={dtype_str}, deterministic={args.deterministic}, " + f"pattern={args.input_pattern}, tie_break={args.tie_break})" ) if args.compare_sglang: print("NOTE: SGLang only supports k=2048 and float32") @@ -868,22 +1046,25 @@ def main(): "NOTE: deterministic mode also benchmarks FlashInfer(non-det) " "for direct comparison" ) + if args.tie_break: + print( + "NOTE: tie-break columns benchmark deterministic tie-small/tie-large; " + "slowdowns align with the same baseline as DetSlowdown" + ) print("=" * 100) - if args.deterministic: + if show_det_or_tie: header = ( f"{'batch':>6} {'seq_len':>10} {'k':>6} | " f"{'FlashInfer':>12} {'FlashInfer(det)':>14} {'DetSlowdown':>11}" ) + header = append_tie_break_header(header, args.tie_break) else: header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12} {'Clusters':>12} {'Speedup Clusters vs. Default':>29}" if args.compare_sglang: header += f" {'SGLang':>12} {'Speedup':>10}" print(header) - divider_len = 87 if args.deterministic else 109 - if args.compare_sglang: - divider_len += 24 - print("-" * divider_len) + print("-" * len(header)) for batch_size in batch_sizes: for seq_len in seq_lens: @@ -898,14 +1079,37 @@ def main(): dtype, input_pattern=args.input_pattern, deterministic=args.deterministic, + compare_tie_break=args.tie_break, compare_sglang=args.compare_sglang, ) - if args.deterministic: + if show_det_or_tie: + nondet_us = result.get("flashinfer_nondeterministic_us") + if nondet_us is None: + nondet_us = result["flashinfer_us"] + det_us = ( + result["flashinfer_us"] if args.deterministic else None + ) + det_us_str = ( + f"{det_us:>12.2f}us" + if det_us is not None + else f"{'n/a':>14}" + ) + det_slowdown = result.get( + "deterministic_slowdown_vs_nondeterministic" + ) + det_slowdown_str = ( + f"{det_slowdown:>10.2f}x" + if det_slowdown is not None + else f"{'n/a':>11}" + ) line = ( f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | " - f"{result['flashinfer_nondeterministic_us']:>10.2f}us " - f"{result['flashinfer_us']:>12.2f}us " - f"{result['deterministic_slowdown_vs_nondeterministic']:>10.2f}x" + f"{nondet_us:>10.2f}us " + f"{det_us_str} " + f"{det_slowdown_str}" + ) + line = append_tie_break_columns( + line, result, args.tie_break ) else: line = ( @@ -933,10 +1137,12 @@ def main(): raise if args.op in ["all", "ragged"]: + show_det_or_tie = args.deterministic or args.tie_break print("\n" + "=" * 100) print( "top_k_ragged_transform: Fused top-k + ragged index transform " - f"(dtype={dtype_str}, deterministic={args.deterministic}, pattern={args.input_pattern})" + f"(dtype={dtype_str}, deterministic={args.deterministic}, " + f"pattern={args.input_pattern}, tie_break={args.tie_break})" ) if args.compare_sglang: print("NOTE: SGLang only supports k=2048 and float32") @@ -945,22 +1151,25 @@ def main(): "NOTE: deterministic mode also benchmarks FlashInfer(non-det) " "for direct comparison" ) + if args.tie_break: + print( + "NOTE: tie-break columns benchmark deterministic tie-small/tie-large; " + "slowdowns align with the same baseline as DetSlowdown" + ) print("=" * 100) - if args.deterministic: + if show_det_or_tie: header = ( f"{'batch':>6} {'seq_len':>10} {'k':>6} | " f"{'FlashInfer':>12} {'FlashInfer(det)':>14} {'DetSlowdown':>11}" ) + header = append_tie_break_header(header, args.tie_break) else: header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12} {'Clusters':>12} {'Speedup Clusters vs. Default':>29}" if args.compare_sglang: header += f" {'SGLang':>12} {'Speedup':>10}" print(header) - divider_len = 87 if args.deterministic else 109 - if args.compare_sglang: - divider_len += 24 - print("-" * divider_len) + print("-" * len(header)) for batch_size in batch_sizes: for seq_len in seq_lens: @@ -975,14 +1184,37 @@ def main(): dtype, input_pattern=args.input_pattern, deterministic=args.deterministic, + compare_tie_break=args.tie_break, compare_sglang=args.compare_sglang, ) - if args.deterministic: + if show_det_or_tie: + nondet_us = result.get("flashinfer_nondeterministic_us") + if nondet_us is None: + nondet_us = result["flashinfer_us"] + det_us = ( + result["flashinfer_us"] if args.deterministic else None + ) + det_us_str = ( + f"{det_us:>12.2f}us" + if det_us is not None + else f"{'n/a':>14}" + ) + det_slowdown = result.get( + "deterministic_slowdown_vs_nondeterministic" + ) + det_slowdown_str = ( + f"{det_slowdown:>10.2f}x" + if det_slowdown is not None + else f"{'n/a':>11}" + ) line = ( f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | " - f"{result['flashinfer_nondeterministic_us']:>10.2f}us " - f"{result['flashinfer_us']:>12.2f}us " - f"{result['deterministic_slowdown_vs_nondeterministic']:>10.2f}x" + f"{nondet_us:>10.2f}us " + f"{det_us_str} " + f"{det_slowdown_str}" + ) + line = append_tie_break_columns( + line, result, args.tie_break ) else: line = ( diff --git a/csrc/flashinfer_topk_binding.cu b/csrc/flashinfer_topk_binding.cu index 44ce7b5349..36c23ec386 100644 --- a/csrc/flashinfer_topk_binding.cu +++ b/csrc/flashinfer_topk_binding.cu @@ -19,17 +19,17 @@ using tvm::ffi::Optional; void radix_topk(TensorView input, TensorView output_indices, TensorView output_values, Optional maybe_row_states_buffer, int64_t top_k, bool sorted_output, - bool deterministic); + bool deterministic, int64_t tie_break); void radix_topk_page_table_transform(TensorView input, TensorView output_page_table, TensorView src_page_table, Optional maybe_row_to_batch, TensorView lengths, Optional maybe_row_states_buffer, int64_t top_k, - bool deterministic); + bool deterministic, int64_t tie_break); void radix_topk_ragged_transform(TensorView input, TensorView output_indices, TensorView offsets, TensorView lengths, Optional maybe_row_states_buffer, - int64_t top_k, bool deterministic); + int64_t top_k, bool deterministic, int64_t tie_break); bool can_implement_filtered_topk(); diff --git a/csrc/topk.cu b/csrc/topk.cu index 64f661ac3c..45ef16c906 100644 --- a/csrc/topk.cu +++ b/csrc/topk.cu @@ -22,9 +22,25 @@ using namespace flashinfer; using tvm::ffi::Optional; +inline sampling::TopKTieBreak ParseTopKTieBreak(int64_t tie_break) { + switch (tie_break) { + case 0: + return sampling::TopKTieBreak::None; + case 1: + return sampling::TopKTieBreak::Small; + case 2: + return sampling::TopKTieBreak::Large; + default: + TVM_FFI_ICHECK(false) + << "Invalid tie_break mode " << tie_break + << ", expected 0 (none), 1 (prefer small indices), or 2 (prefer large indices)"; + return sampling::TopKTieBreak::None; + } +} + void radix_topk(TensorView input, TensorView output_indices, TensorView output_values, Optional maybe_row_states_buffer, int64_t top_k, bool sorted_output, - bool deterministic) { + bool deterministic, int64_t tie_break) { CHECK_INPUT(input); CHECK_INPUT(output_indices); CHECK_INPUT(output_values); @@ -40,7 +56,10 @@ void radix_topk(TensorView input, TensorView output_indices, TensorView output_v cudaError_t status; auto dtype = input.dtype(); - + sampling::TopKTieBreak tie_break_mode = ParseTopKTieBreak(tie_break); + if (tie_break_mode != sampling::TopKTieBreak::None) { + deterministic = true; + } // Get row_states_buffer if provided (for multi-CTA path) sampling::RadixRowState* row_states_ptr = nullptr; if (maybe_row_states_buffer.has_value()) { @@ -53,7 +72,7 @@ void radix_topk(TensorView input, TensorView output_indices, TensorView output_v status = sampling::TopKDispatch( static_cast(input.data_ptr()), static_cast(output_indices.data_ptr()), static_cast(output_values.data_ptr()), batch_size, static_cast(top_k), d, - row_states_ptr, sorted_output, deterministic, stream); + row_states_ptr, sorted_output, deterministic, tie_break_mode, stream); return true; }); @@ -65,7 +84,7 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta TensorView src_page_table, Optional maybe_row_to_batch, TensorView lengths, Optional maybe_row_states_buffer, int64_t top_k, - bool deterministic) { + bool deterministic, int64_t tie_break) { CHECK_INPUT(input); CHECK_INPUT(output_page_table); CHECK_INPUT(src_page_table); @@ -84,6 +103,10 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta cudaError_t status; auto dtype = input.dtype(); + sampling::TopKTieBreak tie_break_mode = ParseTopKTieBreak(tie_break); + if (tie_break_mode != sampling::TopKTieBreak::None) { + deterministic = true; + } sampling::RadixRowState* row_states_ptr = nullptr; if (maybe_row_states_buffer.has_value()) { @@ -102,7 +125,7 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta static_cast(input.data_ptr()), static_cast(output_page_table.data_ptr()), static_cast(src_page_table.data_ptr()), src_stride, row_to_batch_ptr, static_cast(lengths.data_ptr()), num_rows, static_cast(top_k), max_len, - row_states_ptr, deterministic, stream); + row_states_ptr, deterministic, tie_break_mode, stream); return true; }); @@ -112,7 +135,7 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta void radix_topk_ragged_transform(TensorView input, TensorView output_indices, TensorView offsets, TensorView lengths, Optional maybe_row_states_buffer, - int64_t top_k, bool deterministic) { + int64_t top_k, bool deterministic, int64_t tie_break) { CHECK_INPUT(input); CHECK_INPUT(output_indices); CHECK_INPUT(offsets); @@ -130,6 +153,10 @@ void radix_topk_ragged_transform(TensorView input, TensorView output_indices, Te cudaError_t status; auto dtype = input.dtype(); + sampling::TopKTieBreak tie_break_mode = ParseTopKTieBreak(tie_break); + if (tie_break_mode != sampling::TopKTieBreak::None) { + deterministic = true; + } sampling::RadixRowState* row_states_ptr = nullptr; if (maybe_row_states_buffer.has_value()) { @@ -142,7 +169,8 @@ void radix_topk_ragged_transform(TensorView input, TensorView output_indices, Te status = sampling::TopKRaggedTransformDispatch( static_cast(input.data_ptr()), static_cast(output_indices.data_ptr()), static_cast(offsets.data_ptr()), static_cast(lengths.data_ptr()), - num_rows, static_cast(top_k), max_len, row_states_ptr, deterministic, stream); + num_rows, static_cast(top_k), max_len, row_states_ptr, deterministic, + tie_break_mode, stream); return true; }); diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index a07fed2a27..58187ca85c 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -175,6 +175,7 @@ from .topk import top_k as top_k from .topk import top_k_page_table_transform as top_k_page_table_transform from .topk import top_k_ragged_transform as top_k_ragged_transform +from .topk import TopKTieBreak as TopKTieBreak from .sparse import BlockSparseAttentionWrapper as BlockSparseAttentionWrapper from .sparse import ( VariableBlockSparseAttentionWrapper as VariableBlockSparseAttentionWrapper, diff --git a/flashinfer/topk.py b/flashinfer/topk.py index 15b0473505..eef1998d03 100644 --- a/flashinfer/topk.py +++ b/flashinfer/topk.py @@ -16,6 +16,7 @@ import functools import os +from enum import IntEnum from types import SimpleNamespace from typing import Optional, Tuple @@ -32,6 +33,30 @@ ) +class TopKTieBreak(IntEnum): + """Top-k tie-break mode. + + This mirrors an enum-class style API while keeping int-compatible values + for FFI dispatch: + - NONE = 0 (legacy behavior) + - SMALL = 1 (prefer smaller indices) + - LARGE = 2 (prefer larger indices) + """ + + NONE = 0 + SMALL = 1 + LARGE = 2 + + def __str__(self) -> str: + return self.name.lower() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __format__(self, format_spec: str) -> str: + return format(str(self), format_spec) + + @functools.cache def get_topk_module(): module = gen_topk_module().build_and_load() @@ -44,6 +69,7 @@ def radix_topk( top_k: int, sorted_output: bool, deterministic: bool, + tie_break: int, row_states_buffer: Optional[torch.Tensor], output_values: torch.Tensor, ) -> torch.Tensor: @@ -64,6 +90,7 @@ def radix_topk( top_k, sorted_output, deterministic, + tie_break, ) return output_indices @@ -73,6 +100,7 @@ def _fake_radix_topk( top_k: int, sorted_output: bool, deterministic: bool, + tie_break: int, row_states_buffer: Optional[torch.Tensor], output_values: torch.Tensor, ) -> torch.Tensor: @@ -221,6 +249,7 @@ def radix_topk_page_table_transform( row_states_buffer: Optional[torch.Tensor], top_k: int, deterministic: bool, + tie_break: int, ) -> None: assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], ( f"Unsupported dtype {input.dtype}, expected float32, float16, or bfloat16" @@ -234,6 +263,7 @@ def radix_topk_page_table_transform( row_states_buffer, top_k, deterministic, + tie_break, ) @register_fake_op("flashinfer::radix_topk_page_table_transform") @@ -246,6 +276,7 @@ def _fake_radix_topk_page_table_transform( row_states_buffer: Optional[torch.Tensor], top_k: int, deterministic: bool, + tie_break: int, ) -> None: pass @@ -261,6 +292,7 @@ def radix_topk_ragged_transform( row_states_buffer: Optional[torch.Tensor], top_k: int, deterministic: bool, + tie_break: int, ) -> None: assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], ( f"Unsupported dtype {input.dtype}, expected float32, float16, or bfloat16" @@ -273,6 +305,7 @@ def radix_topk_ragged_transform( row_states_buffer, top_k, deterministic, + tie_break, ) @register_fake_op("flashinfer::radix_topk_ragged_transform") @@ -284,6 +317,7 @@ def _fake_radix_topk_ragged_transform( row_states_buffer: Optional[torch.Tensor], top_k: int, deterministic: bool, + tie_break: int, ) -> None: pass @@ -455,6 +489,7 @@ def top_k( k: int, sorted: bool = False, deterministic: bool = False, + tie_break: int = TopKTieBreak.NONE, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Radix-based Top-K selection. @@ -481,6 +516,15 @@ def top_k( Deterministic mode guarantees repeatable FlashInfer output ordering for the selected top-k set on a fixed input and system. + tie_break : int, optional + Tie-breaking mode for equal values at the selection boundary. + Supported modes are (or use ``TopKTieBreak`` enum values): + + - ``0``: no explicit index tie-break + - ``1``: prefer smaller indices + - ``2``: prefer larger indices + + Default is ``0``. Returns ------- @@ -530,6 +574,10 @@ def top_k( batch_size = input.size(0) device = input.device + # tie_break modes 1/2 imply deterministic mode. + if tie_break != TopKTieBreak.NONE: + deterministic = True + if can_use_clusters_topk(input.device, deterministic): indices, output_values = topk_clusters_exact( input, k, output_values=True, out_dtype=torch.int64 @@ -558,7 +606,13 @@ def top_k( # For deterministic + sorted + k <= 2048: CUDA handles the stable value sort on device. sorted_cuda = sorted and deterministic and k <= 2048 indices_int32 = get_topk_module().radix_topk( - input, k, sorted_cuda, deterministic, row_states_buffer, output_values + input, + k, + sorted_cuda, + deterministic, + tie_break, + row_states_buffer, + output_values, ) # Convert to int64 for compatibility @@ -587,6 +641,7 @@ def top_k_page_table_transform( k: int, row_to_batch: Optional[torch.Tensor] = None, deterministic: bool = False, + tie_break: int = TopKTieBreak.NONE, ) -> torch.Tensor: r"""Fused Top-K selection + Page Table Transform for sparse attention. @@ -618,6 +673,16 @@ def top_k_page_table_transform( deterministic : bool, optional If True, uses deterministic mode. Default is False (non-deterministic, which is faster). + tie_break : int, optional + Tie-breaking mode for equal values at the selection boundary. + Supported modes are (or use ``TopKTieBreak`` enum values): + + - ``0``: no explicit index tie-break + - ``1``: prefer smaller indices + - ``2``: prefer larger indices + + Default is ``0``. + Returns ------- @@ -649,6 +714,10 @@ def top_k_page_table_transform( device = input.device num_rows = input.size(0) + # tie_break modes 1/2 imply deterministic mode. + if tie_break != TopKTieBreak.NONE: + deterministic = True + if can_use_clusters_topk(input.device, deterministic) and row_to_batch is None: return topk_clusters_page_table_transform(input, lengths, src_page_table, k) @@ -672,6 +741,7 @@ def top_k_page_table_transform( row_states_buffer, k, deterministic, + tie_break, ) return output_page_table @@ -684,6 +754,7 @@ def top_k_ragged_transform( lengths: torch.Tensor, k: int, deterministic: bool = False, + tie_break: int = TopKTieBreak.NONE, ) -> torch.Tensor: r"""Fused Top-K selection + Ragged Index Transform for sparse attention. @@ -708,6 +779,16 @@ def top_k_ragged_transform( deterministic : bool, optional If True, uses deterministic mode. Default is False (non-deterministic, which is faster). + tie_break : int, optional + Tie-breaking mode for equal values at the selection boundary. + Supported modes are (or use ``TopKTieBreak`` enum values): + + - ``0``: no explicit index tie-break + - ``1``: prefer smaller indices + - ``2``: prefer larger indices + + Default is ``0``. + Returns ------- @@ -740,6 +821,10 @@ def top_k_ragged_transform( device = input.device num_rows = input.size(0) + # tie_break modes 1/2 imply deterministic mode. + if tie_break != TopKTieBreak.NONE: + deterministic = True + if can_use_clusters_topk(input.device, deterministic): return topk_clusters_ragged_transform(input, lengths, offsets, k) @@ -762,6 +847,7 @@ def top_k_ragged_transform( row_states_buffer, k, deterministic, + tie_break, ) return output_indices diff --git a/include/flashinfer/topk.cuh b/include/flashinfer/topk.cuh index 45efb6792e..4fa5108f9e 100644 --- a/include/flashinfer/topk.cuh +++ b/include/flashinfer/topk.cuh @@ -33,6 +33,12 @@ namespace flashinfer { namespace sampling { +enum class TopKTieBreak : uint32_t { + None = 0, + Small = 1, + Large = 2, +}; + template inline size_t GetRadixTopKAvailableOrderedSmemBytes(size_t max_smem_per_block, size_t fixed_smem_aligned, @@ -203,6 +209,97 @@ __device__ __forceinline__ void DeterministicThreadStridedCollect(uint32_t tx, u __syncthreads(); } +/*! + * \brief Deterministically collect contiguous-order matches with a full CTA scan. + * + * Unlike DeterministicThreadStridedCollect, this helper traverses the row in contiguous index + * order across the CTA. This is used for row-global tie-breaking where we must prefer either + * smaller indices first or larger indices first for equal pivot values. + * + * \tparam BLOCK_THREADS Number of threads in the CTA + * \tparam REVERSE If true, traverse indices in reverse order (length-1 ... 0) + */ +template +__device__ __forceinline__ void DeterministicContiguousCollect(uint32_t tx, uint32_t length, + TempStorage& scan_temp_storage, + Predicate is_selected, + uint32_t emit_limit, + EmitFn emit_selected) { + if (emit_limit == 0 || length == 0) { + __syncthreads(); + return; + } + using BlockScan = cub::BlockScan; + // TODO: maybe tune ITEMS_PER_THREAD and vectorize + constexpr uint32_t ITEMS_PER_THREAD = 4; + constexpr uint32_t CHUNK_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD; + __shared__ uint32_t s_emitted; + __shared__ uint32_t s_chunk_base; + __shared__ uint32_t s_chunk_take; + if (tx == 0) { + s_emitted = 0; + s_chunk_base = 0; + s_chunk_take = 0; + } + __syncthreads(); + + const uint32_t num_chunks = ceil_div(length, CHUNK_ITEMS); + for (uint32_t chunk = 0; chunk < num_chunks; ++chunk) { + uint32_t row_idx_per_item[ITEMS_PER_THREAD]; + uint32_t selected_per_item[ITEMS_PER_THREAD]; + uint32_t thread_local_selected_count = 0; + +#pragma unroll + for (uint32_t item = 0; item < ITEMS_PER_THREAD; ++item) { + const uint32_t linear_idx = chunk * CHUNK_ITEMS + tx * ITEMS_PER_THREAD + item; + const bool in_range = linear_idx < length; + uint32_t row_idx = 0; + if (in_range) { + row_idx = REVERSE ? (length - 1u - linear_idx) : linear_idx; + } + row_idx_per_item[item] = row_idx; + const uint32_t selected = (in_range && is_selected(row_idx)) ? 1u : 0u; + selected_per_item[item] = selected; + thread_local_selected_count += selected; + } + + uint32_t selected_prefix = 0; + uint32_t block_selected = 0; + BlockScan(scan_temp_storage) + .ExclusiveSum(thread_local_selected_count, selected_prefix, block_selected); + + if (tx == 0) { + s_chunk_base = s_emitted; + const uint32_t remaining = (s_emitted < emit_limit) ? (emit_limit - s_emitted) : 0u; + s_chunk_take = min(remaining, block_selected); + s_emitted += s_chunk_take; + } + __syncthreads(); + + if (thread_local_selected_count > 0 && selected_prefix < s_chunk_take) { + uint32_t thread_emit_pos = selected_prefix; + const uint32_t thread_emit_end = + min(selected_prefix + thread_local_selected_count, s_chunk_take); +#pragma unroll + for (uint32_t item = 0; item < ITEMS_PER_THREAD; ++item) { + if (selected_per_item[item]) { + emit_selected(row_idx_per_item[item], s_chunk_base + thread_emit_pos); + if (++thread_emit_pos == thread_emit_end) { + break; + } + } + } + } + __syncthreads(); + + if (s_emitted >= emit_limit) { + break; + } + } + __syncthreads(); +} + /*! * \brief Compute suffix sum in shared memory using parallel reduction. * @@ -2179,7 +2276,8 @@ enum class FilteredTopKMode { Plain, PageTable, Ragged }; * - PageTable: output = dst_page_table, aux_input = src_page_table, aux_stride = src_stride * - Ragged: output = indices, aux_input = offsets, aux_output/aux_stride/row_to_batch unused */ -template +template __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) FilteredTopKUnifiedKernel(const DType* __restrict__ input, IdType* __restrict__ output, DType* __restrict__ aux_output, // values for Plain mode @@ -2395,13 +2493,26 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) using DetCollectBlockScan = cub::BlockScan; __shared__ typename DetCollectBlockScan::TempStorage temp_storage; - DeterministicThreadStridedCollect( - tx, length, temp_storage, - [&](uint32_t idx) { return Traits::ToOrdered(score[idx]) == pivot; }, eq_needed, - [&](uint32_t idx, uint32_t local_pos) { - s_indices[static_cast(top_k) - eq_needed + static_cast(local_pos)] = - static_cast(idx); - }); + auto emit_pivot_eq = [&](uint32_t idx, uint32_t local_pos) { + s_indices[static_cast(top_k) - eq_needed + static_cast(local_pos)] = + static_cast(idx); + }; + if constexpr (TIE_BREAK == TopKTieBreak::Small) { + DeterministicContiguousCollect( + tx, length, temp_storage, + [&](uint32_t idx) { return Traits::ToOrdered(score[idx]) == pivot; }, eq_needed, + emit_pivot_eq); + } else if constexpr (TIE_BREAK == TopKTieBreak::Large) { + DeterministicContiguousCollect( + tx, length, temp_storage, + [&](uint32_t idx) { return Traits::ToOrdered(score[idx]) == pivot; }, eq_needed, + emit_pivot_eq); + } else { + DeterministicThreadStridedCollect( + tx, length, temp_storage, + [&](uint32_t idx) { return Traits::ToOrdered(score[idx]) == pivot; }, eq_needed, + emit_pivot_eq); + } } }; @@ -2953,7 +3064,9 @@ cudaError_t LaunchFilteredTopKUnified(DType* input, IdType* output, DType* aux_o const IdType* aux_input, int64_t aux_stride, const IdType* row_to_batch, const IdType* lengths, uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, - bool deterministic = false, cudaStream_t stream = 0) { + bool deterministic = false, + TopKTieBreak tie_break = TopKTieBreak::None, + cudaStream_t stream = 0) { constexpr size_t smem_size = FILTERED_TOPK_SMEM_DYNAMIC; constexpr int MAX_VEC = 16 / sizeof(DType); @@ -2964,20 +3077,28 @@ cudaError_t LaunchFilteredTopKUnified(DType* input, IdType* output, DType* aux_o const int vec_size = ComputeFilteredTopKVecSize(max_len); -#define DISPATCH_VEC_SIZE(VS) \ - if (vec_size == VS) { \ - if (!deterministic) { \ - auto kernel = FilteredTopKUnifiedKernel; \ - FLASHINFER_CUDA_CALL( \ - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); \ - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, grid, block, args, smem_size, stream)); \ - } else { \ - auto kernel = FilteredTopKUnifiedKernel; \ - FLASHINFER_CUDA_CALL( \ - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); \ - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, grid, block, args, smem_size, stream)); \ - } \ - return cudaSuccess; \ +#define LAUNCH_FILTERED_KERNEL(VS, DET, TIE) \ + do { \ + auto kernel = FilteredTopKUnifiedKernel; \ + FLASHINFER_CUDA_CALL( \ + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); \ + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, grid, block, args, smem_size, stream)); \ + } while (0) + +#define DISPATCH_VEC_SIZE(VS) \ + if (vec_size == VS) { \ + if (!deterministic) { \ + LAUNCH_FILTERED_KERNEL(VS, false, TopKTieBreak::None); \ + } else { \ + if (tie_break == TopKTieBreak::Small) { \ + LAUNCH_FILTERED_KERNEL(VS, true, TopKTieBreak::Small); \ + } else if (tie_break == TopKTieBreak::Large) { \ + LAUNCH_FILTERED_KERNEL(VS, true, TopKTieBreak::Large); \ + } else { \ + LAUNCH_FILTERED_KERNEL(VS, true, TopKTieBreak::None); \ + } \ + } \ + return cudaSuccess; \ } DISPATCH_VEC_SIZE(1) @@ -2987,6 +3108,7 @@ cudaError_t LaunchFilteredTopKUnified(DType* input, IdType* output, DType* aux_o DISPATCH_VEC_SIZE(8) } #undef DISPATCH_VEC_SIZE +#undef LAUNCH_FILTERED_KERNEL return cudaSuccess; } @@ -2997,36 +3119,40 @@ cudaError_t FilteredTopKPageTableTransform(DType* input, IdType* output_page_tab const IdType* src_page_table, int64_t src_stride, const IdType* row_to_batch, IdType* lengths, uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, - bool deterministic = false, cudaStream_t stream = 0) { + bool deterministic = false, + TopKTieBreak tie_break = TopKTieBreak::None, + cudaStream_t stream = 0) { DType* aux_output = nullptr; // Not used for PageTable mode return LaunchFilteredTopKUnified( input, output_page_table, aux_output, src_page_table, src_stride, row_to_batch, lengths, - num_rows, top_k_val, max_len, deterministic, stream); + num_rows, top_k_val, max_len, deterministic, tie_break, stream); } template cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices, const IdType* offsets, IdType* lengths, uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, bool deterministic = false, + TopKTieBreak tie_break = TopKTieBreak::None, cudaStream_t stream = 0) { DType* aux_output = nullptr; // Not used for Ragged mode int64_t aux_stride = 0; // Not used for Ragged mode const IdType* row_to_batch = nullptr; // Not used for Ragged mode return LaunchFilteredTopKUnified( input, output_indices, aux_output, offsets, aux_stride, row_to_batch, lengths, num_rows, - top_k_val, max_len, deterministic, stream); + top_k_val, max_len, deterministic, tie_break, stream); } template cudaError_t FilteredTopK(DType* input, IdType* output_indices, DType* output_values, const IdType* lengths, uint32_t num_rows, uint32_t top_k_val, - uint32_t max_len, bool deterministic = false, cudaStream_t stream = 0) { + uint32_t max_len, bool deterministic = false, + TopKTieBreak tie_break = TopKTieBreak::None, cudaStream_t stream = 0) { const IdType* aux_input = nullptr; // Not used for Plain mode int64_t aux_stride = 0; // Not used for Plain mode const IdType* row_to_batch = nullptr; // Not used for Plain mode return LaunchFilteredTopKUnified( input, output_indices, output_values, aux_input, aux_stride, row_to_batch, lengths, num_rows, - top_k_val, max_len, deterministic, stream); + top_k_val, max_len, deterministic, tie_break, stream); } /*! @@ -3067,11 +3193,17 @@ inline TopKAlgoOverride GetTopKAlgoOverride() { * \param top_k_val Number of top elements to select * \param max_len Maximum sequence length * \param deterministic Whether deterministic top-k path is requested + * \param tie_break Mode of tie-break * \return true if FilteredTopK should be used, false for Multi-CTA RadixTopK */ template inline bool ShouldUseFilteredTopK(uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, - bool deterministic) { + bool deterministic, TopKTieBreak tie_break) { + // Tie-break modes are only supported by FilteredTopK + if (tie_break != TopKTieBreak::None) { + return true; + } + // Check if GPU supports enough shared memory for FilteredTopK const bool gpu_supports_filtered = CanImplementFilteredTopK(); const bool k_fits_filtered = (top_k_val <= FILTERED_TOPK_MAX_K) && (max_len > top_k_val); @@ -3119,11 +3251,18 @@ cudaError_t TopKPageTableTransformDispatch(DType* input, IdType* output_page_tab const IdType* row_to_batch, IdType* lengths, uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, RadixRowState* row_states_buffer, bool deterministic, + TopKTieBreak tie_break = TopKTieBreak::None, cudaStream_t stream = 0) { - if (ShouldUseFilteredTopK(num_rows, top_k_val, max_len, deterministic)) { + if (tie_break != TopKTieBreak::None) { + deterministic = true; + if (top_k_val > FILTERED_TOPK_MAX_K || !CanImplementFilteredTopK()) { + return cudaErrorNotSupported; + } + } + if (ShouldUseFilteredTopK(num_rows, top_k_val, max_len, deterministic, tie_break)) { FLASHINFER_CUDA_CALL((FilteredTopKPageTableTransform( input, output_page_table, src_page_table, src_stride, row_to_batch, lengths, num_rows, - top_k_val, max_len, deterministic, stream))); + top_k_val, max_len, deterministic, tie_break, stream))); if (deterministic) { FLASHINFER_CUDA_CALL((LaunchSortTopKByIndex( output_page_table, static_cast(nullptr), src_page_table, src_stride, @@ -3140,11 +3279,19 @@ template cudaError_t TopKRaggedTransformDispatch(DType* input, IdType* output_indices, const IdType* offsets, IdType* lengths, uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, RadixRowState* row_states_buffer, - bool deterministic, cudaStream_t stream = 0) { - if (ShouldUseFilteredTopK(num_rows, top_k_val, max_len, deterministic)) { + bool deterministic, + TopKTieBreak tie_break = TopKTieBreak::None, + cudaStream_t stream = 0) { + if (tie_break != TopKTieBreak::None) { + deterministic = true; + if (top_k_val > FILTERED_TOPK_MAX_K || !CanImplementFilteredTopK()) { + return cudaErrorNotSupported; + } + } + if (ShouldUseFilteredTopK(num_rows, top_k_val, max_len, deterministic, tie_break)) { FLASHINFER_CUDA_CALL((FilteredTopKRaggedTransform( input, output_indices, offsets, lengths, num_rows, top_k_val, max_len, deterministic, - stream))); + tie_break, stream))); if (deterministic) { FLASHINFER_CUDA_CALL((LaunchSortTopKByIndex( output_indices, static_cast(nullptr), offsets, 0, nullptr, num_rows, top_k_val, @@ -3161,11 +3308,18 @@ template cudaError_t TopKDispatch(DType* input, IdType* output_indices, DType* output_values, uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, RadixRowState* row_states_buffer, bool sorted_output = false, - bool deterministic = false, cudaStream_t stream = 0) { - if (ShouldUseFilteredTopK(num_rows, top_k_val, max_len, deterministic)) { + bool deterministic = false, TopKTieBreak tie_break = TopKTieBreak::None, + cudaStream_t stream = 0) { + if (tie_break != TopKTieBreak::None) { + deterministic = true; + if (top_k_val > FILTERED_TOPK_MAX_K || !CanImplementFilteredTopK()) { + return cudaErrorNotSupported; + } + } + if (ShouldUseFilteredTopK(num_rows, top_k_val, max_len, deterministic, tie_break)) { FLASHINFER_CUDA_CALL( (FilteredTopK(input, output_indices, output_values, nullptr, num_rows, - top_k_val, max_len, deterministic, stream))); + top_k_val, max_len, deterministic, tie_break, stream))); if (deterministic) { FLASHINFER_CUDA_CALL((LaunchSortTopKByIndex( output_indices, output_values, nullptr, 0, nullptr, num_rows, top_k_val, max_len, diff --git a/tests/utils/test_topk.py b/tests/utils/test_topk.py index 2665921384..235e908574 100644 --- a/tests/utils/test_topk.py +++ b/tests/utils/test_topk.py @@ -104,8 +104,18 @@ def _build_strictly_descending_logits( @pytest.mark.parametrize("vocab_size", [32000, 65536, 128512]) @pytest.mark.parametrize("k", [256, 512, 1024]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) -def test_top_k(batch_size, vocab_size, k, dtype): +@pytest.mark.parametrize( + "tie_break", + [ + flashinfer.TopKTieBreak.NONE, + flashinfer.TopKTieBreak.SMALL, + flashinfer.TopKTieBreak.LARGE, + ], +) +def test_top_k(batch_size, vocab_size, k, dtype, tie_break): """Test top_k returns correct values and indices.""" + if tie_break != flashinfer.TopKTieBreak.NONE and not can_implement_filtered_topk(): + pytest.skip("Tie-break modes require filtered top-k support on this device") if k > vocab_size: pytest.skip("k should be less than vocab_size") @@ -113,7 +123,7 @@ def test_top_k(batch_size, vocab_size, k, dtype): logits = torch.randn(batch_size, vocab_size, device="cuda", dtype=dtype) # flashinfer top_k - values, indices = flashinfer.top_k(logits, k) + values, indices = flashinfer.top_k(logits, k, tie_break=tie_break) # Reference: torch.topk ref_values, ref_indices = torch.topk(logits, k, dim=-1) @@ -142,8 +152,18 @@ def test_top_k(batch_size, vocab_size, k, dtype): @pytest.mark.parametrize("vocab_size", [32000, 65536]) @pytest.mark.parametrize("k", [256, 512]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) -def test_top_k_sorted(batch_size, vocab_size, k, dtype): +@pytest.mark.parametrize( + "tie_break", + [ + flashinfer.TopKTieBreak.NONE, + flashinfer.TopKTieBreak.SMALL, + flashinfer.TopKTieBreak.LARGE, + ], +) +def test_top_k_sorted(batch_size, vocab_size, k, dtype, tie_break): """Test top_k with sorted=True returns sorted values.""" + if tie_break != flashinfer.TopKTieBreak.NONE and not can_implement_filtered_topk(): + pytest.skip("Tie-break modes require filtered top-k support on this device") if k > vocab_size: pytest.skip("k should be less than vocab_size") @@ -151,7 +171,7 @@ def test_top_k_sorted(batch_size, vocab_size, k, dtype): logits = torch.randn(batch_size, vocab_size, device="cuda", dtype=dtype) # flashinfer top_k with sorted=True - values, indices = flashinfer.top_k(logits, k, sorted=True) + values, indices = flashinfer.top_k(logits, k, sorted=True, tie_break=tie_break) # Reference: torch.topk with sorted=True ref_values, ref_indices = torch.topk(logits, k, dim=-1, sorted=True) @@ -179,13 +199,23 @@ def test_top_k_sorted(batch_size, vocab_size, k, dtype): @pytest.mark.parametrize("vocab_size", [32000, 65536]) @pytest.mark.parametrize("k", [256]) -def test_top_k_single_batch(vocab_size, k): +@pytest.mark.parametrize( + "tie_break", + [ + flashinfer.TopKTieBreak.NONE, + flashinfer.TopKTieBreak.SMALL, + flashinfer.TopKTieBreak.LARGE, + ], +) +def test_top_k_single_batch(vocab_size, k, tie_break): """Test top_k with batch_size=1 (common inference case).""" + if tie_break != flashinfer.TopKTieBreak.NONE and not can_implement_filtered_topk(): + pytest.skip("Tie-break modes require filtered top-k support on this device") torch.manual_seed(42) logits = torch.randn(1, vocab_size, device="cuda", dtype=torch.float32) # flashinfer top_k - values, indices = flashinfer.top_k(logits, k) + values, indices = flashinfer.top_k(logits, k, tie_break=tie_break) # Reference: torch.topk ref_values, ref_indices = torch.topk(logits, k, dim=-1) @@ -203,13 +233,25 @@ def test_top_k_single_batch(vocab_size, k): @pytest.mark.parametrize("vocab_size", [65536, 128512]) @pytest.mark.parametrize("k", [256]) @pytest.mark.parametrize("det", [True, False]) -def test_top_k_large_batch(batch_size, vocab_size, k, det): +@pytest.mark.parametrize( + "tie_break", + [ + flashinfer.TopKTieBreak.NONE, + flashinfer.TopKTieBreak.SMALL, + flashinfer.TopKTieBreak.LARGE, + ], +) +def test_top_k_large_batch(batch_size, vocab_size, k, det, tie_break): """Test top_k with large batch sizes (multi-CTA path).""" + if tie_break != flashinfer.TopKTieBreak.NONE and not can_implement_filtered_topk(): + pytest.skip("Tie-break modes require filtered top-k support on this device") torch.manual_seed(42) logits = torch.randn(batch_size, vocab_size, device="cuda", dtype=torch.float32) # flashinfer top_k (should use multi-CTA path for large vocab) - values, indices = flashinfer.top_k(logits, k, deterministic=det) + values, indices = flashinfer.top_k( + logits, k, deterministic=det, tie_break=tie_break + ) # Reference: torch.topk ref_values, ref_indices = torch.topk(logits, k, dim=-1) @@ -1894,6 +1936,128 @@ def test_top_k_deterministic_sorted_repeatable_valid_selection_under_ties( ) +@pytest.mark.parametrize( + ("algo", "batch_size", "vocab_size", "k"), + [ + ("filtered", 2, 128 * 1024, 2048), + ("filtered", 1, 1024 * 1024, 1024), + ("filtered", 74, 16 * 1024, 512), + ], + ids=[ + "filtered_b2_l128k_k2048", + "filtered_b1_l1m_k1024", + "filtered_b74_l16k_k512", + ], +) +def test_top_k_tie_break_modes(algo, batch_size, vocab_size, k, set_topk_algo): + """tie_break=1|2 should select row-global smallest/largest pivot indices.""" + if algo == "filtered" and not can_implement_filtered_topk(): + pytest.skip("Filtered top-k not supported on this device") + + set_topk_algo(algo) + device = "cuda" + generator = torch.Generator(device=device) + generator.manual_seed(0) + logits = ( + torch.randn( + (batch_size, 1), device=device, dtype=torch.float32, generator=generator + ) + .expand(batch_size, vocab_size) + .contiguous() + ) + + values_small, indices_small = flashinfer.top_k(logits, k, tie_break=1) + values_large, indices_large = flashinfer.top_k(logits, k, tie_break=2) + + expected_small = ( + torch.arange(k, device=device, dtype=torch.int64) + .unsqueeze(0) + .expand(batch_size, -1) + ) + expected_large = ( + torch.arange(vocab_size - k, vocab_size, device=device, dtype=torch.int64) + .unsqueeze(0) + .expand(batch_size, -1) + ) + expected_values = logits[:, :1].expand(batch_size, k).contiguous() + + torch.testing.assert_close(values_small, expected_values) + torch.testing.assert_close(values_large, expected_values) + _assert_unordered_indices_match(indices_small, expected_small) + _assert_unordered_indices_match(indices_large, expected_large) + + +@pytest.mark.parametrize( + ("algo", "num_rows", "max_len", "k"), + [ + ("filtered", 2, 128 * 1024, 2048), + ("filtered", 1, 1024 * 1024, 1024), + ("filtered", 74, 16 * 1024, 512), + ], + ids=[ + "filtered_rows2_l128k_k2048", + "filtered_rows1_l1m_k1024", + "filtered_rows74_l16k_k512", + ], +) +def test_top_k_tie_break_modes_transform_apis( + algo, num_rows, max_len, k, set_topk_algo +): + """Transform APIs should honor tie_break selection before remapping outputs.""" + if algo == "filtered" and not can_implement_filtered_topk(): + pytest.skip("Filtered top-k not supported on this device") + + set_topk_algo(algo) + device = "cuda" + + generator = torch.Generator(device=device) + generator.manual_seed(0) + scores = ( + torch.randn( + (num_rows, 1), device=device, dtype=torch.float32, generator=generator + ) + .expand(num_rows, max_len) + .contiguous() + ) + lengths = torch.full((num_rows,), max_len, device=device, dtype=torch.int32) + src_page_table = ( + torch.arange(max_len, device=device, dtype=torch.int32) + .unsqueeze(0) + .expand(num_rows, -1) + .contiguous() + ) + offsets = torch.zeros((num_rows,), device=device, dtype=torch.int32) + + page_small = flashinfer.top_k_page_table_transform( + scores, src_page_table, lengths, k, tie_break=1 + ) + page_large = flashinfer.top_k_page_table_transform( + scores, src_page_table, lengths, k, tie_break=2 + ) + ragged_small = flashinfer.top_k_ragged_transform( + scores, offsets, lengths, k, tie_break=1 + ) + ragged_large = flashinfer.top_k_ragged_transform( + scores, offsets, lengths, k, tie_break=2 + ) + + expected_small = ( + torch.arange(k, device=device, dtype=torch.int32) + .unsqueeze(0) + .expand(num_rows, -1) + ) + expected_large = ( + torch.arange(max_len - k, max_len, device=device, dtype=torch.int32) + .unsqueeze(0) + .expand(num_rows, -1) + ) + + assert torch.equal(page_small, expected_small) + assert torch.equal(page_large, expected_large) + assert torch.equal(ragged_small, expected_small) + assert torch.equal(ragged_large, expected_large) + + @pytest.mark.parametrize( ("algo", "vocab_size", "k"), [ @@ -2436,9 +2600,9 @@ def test_topk_clusters_ragged_transform(num_rows, seq_len, k, dtype): if __name__ == "__main__": # Basic tests - test_top_k(4, 32000, 256, torch.float32) - test_top_k_sorted(4, 32000, 256, torch.float32) - test_top_k_large_batch(64, 128512, 256, False) + test_top_k(4, 32000, 256, torch.float32, flashinfer.TopKTieBreak.NONE) + test_top_k_sorted(4, 32000, 256, torch.float32, flashinfer.TopKTieBreak.NONE) + test_top_k_large_batch(64, 128512, 256, False, flashinfer.TopKTieBreak.NONE) # Fused transform tests print("Testing page table transform...")