Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 65 additions & 11 deletions benchmarks/bench_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def bench_median_ms(fn) -> float:
enable_cupti=True,
dry_run_iters=10,
repeat_iters=100,
use_cuda_graph=True,
)
return float(np.median(measurements))

Expand All @@ -88,6 +89,7 @@ def bench_top_k_from_scores(
"""Benchmark top-k on a pre-generated score tensor."""
batch_size, seq_len = scores.shape

set_topk_algo("default")
fi_ms, fi_nondeterministic_ms = bench_flashinfer_modes(
lambda deterministic_mode: flashinfer.top_k(
scores,
Expand Down Expand Up @@ -121,6 +123,12 @@ def bench_top_k_from_scores(
result["torch_deterministic_us"] = torch_det_ms * 1e3
result["speedup_vs_torch_deterministic"] = torch_det_ms / fi_ms

set_topk_algo("clusters")
fast_topk_ms = bench_median_ms(lambda: flashinfer.top_k(scores, k))
result["fast_topk_us"] = fast_topk_ms * 1e3
result["speedup_vs_flashinfer"] = fi_ms / fast_topk_ms
set_topk_algo("auto")

# SGLang comparison (only supports k=2048 and float32)
if (
compare_sglang
Expand All @@ -130,7 +138,7 @@ def bench_top_k_from_scores(
):
lengths = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
sg_ms = bench_median_ms(
lambda: sgl_kernel.fast_topk_v2(scores, lengths, k, row_starts=None),
lambda: sgl_kernel.fast_topk_v2(scores, lengths, k, row_starts=None)
)
result["sglang_us"] = sg_ms * 1e3
result["speedup_vs_sglang"] = sg_ms / fi_ms
Expand Down Expand Up @@ -345,7 +353,10 @@ def bench_page_table_transform(
.expand(batch_size, -1)
.contiguous()
)
use_cuda_graph = True
enable_cupti = True

set_topk_algo("default")
fi_ms, fi_nondeterministic_ms = bench_flashinfer_modes(
lambda deterministic_mode: flashinfer.top_k_page_table_transform(
scores,
Expand All @@ -370,13 +381,29 @@ def bench_page_table_transform(
fi_ms / fi_nondeterministic_ms
)

# FlashInfer clusters
set_topk_algo("clusters")
measurements = bench_gpu_time(
lambda: flashinfer.top_k_page_table_transform(
scores, src_page_table, lengths, k
),
enable_cupti=enable_cupti,
dry_run_iters=10,
repeat_iters=100,
use_cuda_graph=use_cuda_graph,
)
fast_topk_ms = np.median(measurements)
result["fast_topk_us"] = fast_topk_ms * 1e3
result["speedup_vs_flashinfer"] = fi_ms / fast_topk_ms
set_topk_algo("auto")

# SGLang comparison (only supports k=2048 and float32)
if compare_sglang and HAS_SGL_KERNEL and k == 2048 and dtype == torch.float32:
cu_seqlens_q = torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda")
sg_ms = bench_median_ms(
lambda: sgl_kernel.fast_topk_transform_fused(
scores, lengths, src_page_table, cu_seqlens_q, k
),
)
)
result["sglang_us"] = sg_ms * 1e3
result["speedup_vs_sglang"] = sg_ms / fi_ms
Expand All @@ -399,7 +426,10 @@ def bench_ragged_transform(
offsets = torch.arange(
0, batch_size * seq_len, seq_len, device="cuda", dtype=torch.int32
)
use_cuda_graph = True
enable_cupti = True

set_topk_algo("default")
fi_ms, fi_nondeterministic_ms = bench_flashinfer_modes(
lambda deterministic_mode: flashinfer.top_k_ragged_transform(
scores,
Expand All @@ -424,12 +454,26 @@ def bench_ragged_transform(
fi_ms / fi_nondeterministic_ms
)

# FlashInfer clusters
set_topk_algo("clusters")
measurements = bench_gpu_time(
lambda: flashinfer.top_k_ragged_transform(scores, offsets, lengths, k),
enable_cupti=enable_cupti,
dry_run_iters=10,
repeat_iters=100,
use_cuda_graph=use_cuda_graph,
)
fast_topk_ms = np.median(measurements)
result["fast_topk_us"] = fast_topk_ms * 1e3
result["speedup_vs_flashinfer"] = fi_ms / fast_topk_ms
set_topk_algo("auto")

# SGLang comparison (only supports k=2048 and float32)
if compare_sglang and HAS_SGL_KERNEL and k == 2048 and dtype == torch.float32:
sg_ms = bench_median_ms(
lambda: sgl_kernel.fast_topk_transform_ragged_fused(
scores, lengths, offsets, k
),
)
)
result["sglang_us"] = sg_ms * 1e3
result["speedup_vs_sglang"] = sg_ms / fi_ms
Expand Down Expand Up @@ -637,13 +681,14 @@ def main():
header = (
f"{'batch':>6} {'seq_len':>10} {'k':>6} | "
f"{'FlashInfer':>12} {'torch.topk':>12} {'Speedup':>10}"
f" {'Clusters':>12} {'Speedup Clusters vs. Default':>29}"
)
if args.compare_torch_deterministic and not args.deterministic:
header += f" {'torch.det':>12} {'Speedup':>10}"
if args.compare_sglang:
header += f" {'SGLang':>12} {'Speedup':>10}"
print(header)
divider_len = 96 if args.deterministic else 72
divider_len = 96 if args.deterministic else 115
if args.compare_torch_deterministic and not args.deterministic:
divider_len += 24
if args.compare_sglang:
Expand Down Expand Up @@ -677,6 +722,8 @@ def main():
f"{result['flashinfer_us']:>12.2f}us {result['torch_us']:>10.2f}us "
f"{result['speedup_vs_torch']:>9.2f}x"
)
if "fast_topk_us" in result:
line += f" {result['fast_topk_us']:>10.2f}us {result['speedup_vs_flashinfer']:>28.2f}x"
if "torch_deterministic_us" in result:
line += (
f" {result['torch_deterministic_us']:>10.2f}us "
Expand Down Expand Up @@ -733,11 +780,12 @@ def main():
header = (
f"{'case':>24} {'rows':>8} {'seq_len':>10} {'k':>6} | "
f"{'FlashInfer':>12} {'torch.topk':>12} {'Speedup':>10}"
f" {'Clusters':>12} {'Speedup Clusters vs. Default':>29}"
)
if args.compare_torch_deterministic and not args.deterministic:
header += f" {'torch.det':>12} {'Speedup':>10}"
print(header)
divider_len = 110 if args.deterministic else 86
divider_len = 110 if args.deterministic else 129
if args.compare_torch_deterministic and not args.deterministic:
divider_len += 24
print("-" * divider_len)
Expand Down Expand Up @@ -788,6 +836,8 @@ def main():
f"{result['flashinfer_us']:>10.2f}us {result['torch_us']:>10.2f}us "
f"{result['speedup_vs_torch']:>9.2f}x"
)
if "fast_topk_us" in result:
line += f" {result['fast_topk_us']:>10.2f}us {result['speedup_vs_flashinfer']:>28.2f}x"
if "torch_deterministic_us" in result:
line += (
f" {result['torch_deterministic_us']:>10.2f}us "
Expand Down Expand Up @@ -826,13 +876,13 @@ def main():
f"{'FlashInfer':>12} {'FlashInfer(det)':>14} {'DetSlowdown':>11}"
)
else:
header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12}"
header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12} {'Clusters':>12} {'Speedup Clusters vs. Default':>29}"
if args.compare_sglang:
header += f" {'SGLang':>12} {'Speedup':>10}"
print(header)
divider_len = 87 if args.deterministic else 70
divider_len = 87 if args.deterministic else 109
if args.compare_sglang:
divider_len += 20
divider_len += 24
print("-" * divider_len)

for batch_size in batch_sizes:
Expand Down Expand Up @@ -862,6 +912,8 @@ def main():
f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | "
f"{result['flashinfer_us']:>10.2f}us"
)
if "fast_topk_us" in result:
line += f" {result['fast_topk_us']:>10.2f}us {result['speedup_vs_flashinfer']:>28.2f}x"
if "sglang_us" in result:
line += (
f" {result['sglang_us']:>10.2f}us "
Expand Down Expand Up @@ -901,13 +953,13 @@ def main():
f"{'FlashInfer':>12} {'FlashInfer(det)':>14} {'DetSlowdown':>11}"
)
else:
header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12}"
header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12} {'Clusters':>12} {'Speedup Clusters vs. Default':>29}"
if args.compare_sglang:
header += f" {'SGLang':>12} {'Speedup':>10}"
print(header)
divider_len = 87 if args.deterministic else 70
divider_len = 87 if args.deterministic else 109
if args.compare_sglang:
divider_len += 20
divider_len += 24
print("-" * divider_len)

for batch_size in batch_sizes:
Expand Down Expand Up @@ -937,6 +989,8 @@ def main():
f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | "
f"{result['flashinfer_us']:>10.2f}us"
)
if "fast_topk_us" in result:
line += f" {result['fast_topk_us']:>10.2f}us {result['speedup_vs_flashinfer']:>28.2f}x"
if "sglang_us" in result:
line += (
f" {result['sglang_us']:>10.2f}us "
Expand Down
159 changes: 159 additions & 0 deletions csrc/flashinfer_fast_topk_clusters_binding.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdint>
#include <flashinfer/fast_topk_clusters_exact.cuh>

#include "tvm_ffi_utils.h"

using tvm::ffi::Optional;
using namespace flashinfer::sampling;

// Exact top-k selection over dense 2D logits using the "clusters" algorithm.
//
// logits:          (batch_size, seq_len) input scores; fp32 or fp16 (dispatched below).
// indices:         (batch_size, top_k) output indices; int32 or int64 (dispatched below).
// output_values:   optional output buffer for the selected values (same dtype as logits);
//                  the kernel receives nullptr when absent.
// histogram:       optional int workspace; nullptr when absent.
//                  NOTE(review): exact histogram/num_cached semantics are defined by the
//                  kernel header (fast_topk_clusters_exact.cuh) — confirm there.
// cached_overflow: int overflow-cache tensor; its row stride is converted below into
//                  units of 4 * num_clusters ints before being passed to the launcher.
// TopK:            number of elements to select per row.
// num_cached:      cached-entry count forwarded to the launcher.
// num_clusters:    cluster count forwarded to the launcher; must be positive.
// pdl_enabled:     whether programmatic dependent launch is enabled.
void fast_topk_clusters_exact(TensorView logits, TensorView indices,
                              Optional<TensorView> output_values, Optional<TensorView> histogram,
                              TensorView cached_overflow, int64_t TopK, int64_t num_cached,
                              int64_t num_clusters, bool pdl_enabled) {
  CHECK_DIM(2, logits);   // input: (batch_size, seq_len)
  CHECK_DIM(2, indices);  // output_indices: (batch_size, top_k)
  const int batch_size = static_cast<int>(logits.size(0));
  const int seq_len = static_cast<int>(logits.size(1));

  // Optional histogram workspace: forward nullptr when the caller supplied none.
  int* hist_ptr = nullptr;
  if (histogram.has_value()) {
    hist_ptr = static_cast<int*>(histogram.value().data_ptr());
  }

  // Optional values output: nullptr means the caller only wants indices.
  void* values_ptr = nullptr;
  if (output_values.has_value()) {
    values_ptr = output_values.value().data_ptr();
  }

  const int logit_stride = static_cast<int>(logits.stride(0));
  const int indices_stride = static_cast<int>(indices.stride(0));
  const int n_clusters = static_cast<int>(num_clusters);
  // Guard the division below against a zero/negative cluster count.
  TVM_FFI_ICHECK(n_clusters > 0) << "num_clusters must be positive, got " << num_clusters;
  cudaStream_t stream = get_current_stream();
  // cached_overflow rows hold 4 ints per cluster; convert the element stride
  // into a per-row stride in those (4 * n_clusters)-int units.
  const int ovf_stride = static_cast<int>(cached_overflow.stride(0)) / (4 * n_clusters);

  auto dtype = logits.dtype();

  // Indices may be int32 or int64; validate, then dispatch on the index width.
  auto idx_dtype = indices.dtype();
  TVM_FFI_ICHECK(idx_dtype.code == kDLInt && (idx_dtype.bits == 32 || idx_dtype.bits == 64))
      << "indices must be int32 or int64, got code=" << idx_dtype.code
      << " bits=" << idx_dtype.bits;
  const bool idx_int64 = (idx_dtype.bits == 64);

  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
    if (idx_int64) {
      launch_fast_topk_clusters_exact<c_type, int64_t>(
          static_cast<const c_type*>(logits.data_ptr()), static_cast<int64_t*>(indices.data_ptr()),
          static_cast<c_type*>(values_ptr), seq_len, hist_ptr,
          static_cast<int*>(cached_overflow.data_ptr()), ovf_stride, batch_size, logit_stride,
          indices_stride, static_cast<int>(num_cached), n_clusters, pdl_enabled,
          static_cast<int>(TopK), stream);
    } else {
      launch_fast_topk_clusters_exact<c_type, int>(
          static_cast<const c_type*>(logits.data_ptr()), static_cast<int*>(indices.data_ptr()),
          static_cast<c_type*>(values_ptr), seq_len, hist_ptr,
          static_cast<int*>(cached_overflow.data_ptr()), ovf_stride, batch_size, logit_stride,
          indices_stride, static_cast<int>(num_cached), n_clusters, pdl_enabled,
          static_cast<int>(TopK), stream);
    }
    return true;
  });
  // Kernel launches are asynchronous; surface launch-configuration errors here.
  auto err = cudaGetLastError();
  TVM_FFI_ICHECK(err == cudaSuccess)
      << "launch_fast_topk_clusters_exact failed: " << cudaGetErrorString(err);
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(fast_topk_clusters_exact, fast_topk_clusters_exact);

// Clusters-exact top-k fused with a page-table gather: selected positions are
// used to rewrite page_table entries in the same launch.
// NOTE(review): exact transform semantics are defined by the kernel header
// (fast_topk_clusters_exact.cuh) — confirm there.
//
// logits:          (batch_size, seq_len) scores; fp32 or fp16.
// indices:         (batch_size, top_k) int32 output indices.
// seq_lens:        (batch_size,) int32 per-row valid lengths.
// page_table:      (batch_size, *) int32 table gathered/rewritten by the kernel.
// histogram:       optional int workspace; nullptr when absent.
// cached_overflow: int overflow cache; row stride converted below into
//                  (4 * num_clusters)-int units.
// TopK/num_cached/num_clusters/pdl_enabled: forwarded to the launcher;
//                  num_clusters must be positive.
void fast_topk_clusters_exact_page_table_transform(TensorView logits, TensorView indices,
                                                   TensorView seq_lens, TensorView page_table,
                                                   Optional<TensorView> histogram,
                                                   TensorView cached_overflow, int64_t TopK,
                                                   int64_t num_cached, int64_t num_clusters,
                                                   bool pdl_enabled) {
  CHECK_DIM(2, logits);
  CHECK_DIM(2, indices);
  CHECK_DIM(1, seq_lens);
  CHECK_DIM(2, page_table);
  const int batch_size = static_cast<int>(logits.size(0));

  // The launcher takes a mutable pointer, so keep hist_ptr non-const instead
  // of declaring it const and const_cast-ing at the call site.
  int* hist_ptr = nullptr;
  if (histogram.has_value()) {
    hist_ptr = static_cast<int*>(histogram.value().data_ptr());
  }

  const int logit_stride = static_cast<int>(logits.stride(0));
  const int indices_stride = static_cast<int>(indices.stride(0));
  const int page_table_stride = static_cast<int>(page_table.stride(0));
  const int n_clusters = static_cast<int>(num_clusters);
  // Guard the division below against a zero/negative cluster count.
  TVM_FFI_ICHECK(n_clusters > 0) << "num_clusters must be positive, got " << num_clusters;
  // cached_overflow rows hold 4 ints per cluster; convert to that unit.
  const int ovf_stride = static_cast<int>(cached_overflow.stride(0)) / (4 * n_clusters);
  cudaStream_t stream = get_current_stream();

  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(logits.dtype(), c_type, [&] {
    launch_fast_topk_clusters_exact_page_table_transform<c_type>(
        static_cast<const c_type*>(logits.data_ptr()), static_cast<int*>(indices.data_ptr()),
        static_cast<int*>(seq_lens.data_ptr()), static_cast<int*>(page_table.data_ptr()),
        hist_ptr, static_cast<int*>(cached_overflow.data_ptr()), ovf_stride, batch_size,
        logit_stride, indices_stride, page_table_stride, static_cast<int>(num_cached), n_clusters,
        pdl_enabled, static_cast<int>(TopK), stream);
    return true;
  });
  // Surface asynchronous launch errors immediately.
  auto err = cudaGetLastError();
  TVM_FFI_ICHECK(err == cudaSuccess)
      << "launch_fast_topk_clusters_exact_page_table_transform failed: " << cudaGetErrorString(err);
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(fast_topk_clusters_exact_page_table_transform,
                              fast_topk_clusters_exact_page_table_transform);

// Clusters-exact top-k fused with a ragged-layout transform: per-row offsets
// locate each row's data, and selected positions are rewritten in the same launch.
// NOTE(review): exact transform semantics are defined by the kernel header
// (fast_topk_clusters_exact.cuh) — confirm there.
//
// logits:          (batch_size, seq_len) scores; fp32 or fp16.
// indices:         (batch_size, top_k) int32 output indices.
// seq_lens:        (batch_size,) int32 per-row valid lengths.
// offsets:         (batch_size,) int32 per-row start offsets into the ragged buffer.
// histogram:       optional int workspace; nullptr when absent.
// cached_overflow: int overflow cache; row stride converted below into
//                  (4 * num_clusters)-int units.
// TopK/num_cached/num_clusters/pdl_enabled: forwarded to the launcher;
//                  num_clusters must be positive.
void fast_topk_clusters_exact_ragged_transform(TensorView logits, TensorView indices,
                                               TensorView seq_lens, TensorView offsets,
                                               Optional<TensorView> histogram,
                                               TensorView cached_overflow, int64_t TopK,
                                               int64_t num_cached, int64_t num_clusters,
                                               bool pdl_enabled) {
  CHECK_DIM(2, logits);
  CHECK_DIM(2, indices);
  CHECK_DIM(1, seq_lens);
  CHECK_DIM(1, offsets);
  const int batch_size = static_cast<int>(logits.size(0));

  // The launcher takes a mutable pointer, so keep hist_ptr non-const instead
  // of declaring it const and const_cast-ing at the call site.
  int* hist_ptr = nullptr;
  if (histogram.has_value()) {
    hist_ptr = static_cast<int*>(histogram.value().data_ptr());
  }

  const int logit_stride = static_cast<int>(logits.stride(0));
  const int indices_stride = static_cast<int>(indices.stride(0));
  const int n_clusters = static_cast<int>(num_clusters);
  // Guard the division below against a zero/negative cluster count.
  TVM_FFI_ICHECK(n_clusters > 0) << "num_clusters must be positive, got " << num_clusters;
  // cached_overflow rows hold 4 ints per cluster; convert to that unit.
  const int ovf_stride = static_cast<int>(cached_overflow.stride(0)) / (4 * n_clusters);
  cudaStream_t stream = get_current_stream();

  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(logits.dtype(), c_type, [&] {
    launch_fast_topk_clusters_exact_ragged_transform<c_type>(
        static_cast<const c_type*>(logits.data_ptr()), static_cast<int*>(indices.data_ptr()),
        static_cast<int*>(seq_lens.data_ptr()), static_cast<int*>(offsets.data_ptr()),
        hist_ptr, static_cast<int*>(cached_overflow.data_ptr()), ovf_stride, batch_size,
        logit_stride, indices_stride, static_cast<int>(num_cached), n_clusters, pdl_enabled,
        static_cast<int>(TopK), stream);
    return true;
  });
  // Surface asynchronous launch errors immediately.
  auto err = cudaGetLastError();
  TVM_FFI_ICHECK(err == cudaSuccess)
      << "launch_fast_topk_clusters_exact_ragged_transform failed: " << cudaGetErrorString(err);
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(fast_topk_clusters_exact_ragged_transform,
                              fast_topk_clusters_exact_ragged_transform);
2 changes: 2 additions & 0 deletions flashinfer/jit/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,7 @@ def gen_topk_module() -> JitSpec:
[
jit_env.FLASHINFER_CSRC_DIR / "topk.cu",
jit_env.FLASHINFER_CSRC_DIR / "flashinfer_topk_binding.cu",
jit_env.FLASHINFER_CSRC_DIR / "flashinfer_fast_topk_clusters_binding.cu",
],
extra_cuda_cflags=["-lineinfo"],
)
1 change: 1 addition & 0 deletions flashinfer/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,6 +1210,7 @@ def call_fn():
for _ in range(repeat_iters):
if _do_l2_flush:
buffer.zero_()
torch.cuda.synchronize()
start_cpu = cupti.get_timestamp()
runner()
end_cpu = cupti.get_timestamp()
Expand Down
Loading
Loading