Skip to content

Commit 2c32015

Browse files
authored
Radix-selection based BlockTopK specialization (#7384)
* initial air top-k version
* fixes failures with parallel solution
* introduces a pre-filter op
* fixes issues for potential selection of padded items
* fixes dependent name function calls
* maybe unused
* re-establish shortcut
* preparation for small variable-size segment tests
* test case short-circuit path
* extract find splitter prefix
* reuses registers for twiddled keys
* switches to finding the tightest policy for a given seg upper bound
* unrolled histo init
* improves writes
* refactors find splitter and comments
* cleanup for review
* addresses review comments
* unifies branches
* carve out for follow-up pr
* revert needed changes
* style improvements
* addresses review comments
* adds algorithm description
* fix +/-0.0 float handling
* drops unused headers
* fixes implicit conversion warning
* switch to new device macro
1 parent d5b60e6 commit 2c32015

File tree

5 files changed

+639
-61
lines changed

5 files changed

+639
-61
lines changed

cub/cub/agent/agent_batched_topk.cuh

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ struct agent_batched_topk_worker_per_segment
112112
// -------------------------------------------------------------------------
113113
// Constructor
114114
// -------------------------------------------------------------------------
115-
_CCCL_DEVICE _CCCL_FORCEINLINE agent_batched_topk_worker_per_segment(
115+
_CCCL_DEVICE_API _CCCL_FORCEINLINE agent_batched_topk_worker_per_segment(
116116
TempStorage& temp_storage,
117117
KeyInputItItT d_key_segments_it,
118118
KeyOutputItItT d_key_segments_out_it,
@@ -133,7 +133,7 @@ struct agent_batched_topk_worker_per_segment
133133
, num_segments(num_segments)
134134
{}
135135

136-
_CCCL_DEVICE _CCCL_FORCEINLINE void Process()
136+
_CCCL_DEVICE_API _CCCL_FORCEINLINE void Process()
137137
{
138138
// Identify Segment
139139
const int segment_id = static_cast<int>(blockIdx.x);
@@ -145,6 +145,9 @@ struct agent_batched_topk_worker_per_segment
145145
return;
146146
}
147147

148+
constexpr bool is_full_tile = params::has_single_static_value_v<SegmentSizeParameterT>
149+
&& params::static_min_value_v<SegmentSizeParameterT> == tile_size;
150+
148151
// Resolve Segment Parameters
149152
const auto segment_size = segment_sizes.get_param(segment_id);
150153
const auto k = k_param.get_param(segment_id);
@@ -161,8 +164,7 @@ struct agent_batched_topk_worker_per_segment
161164

162165
// Load Keys
163166
key_t thread_keys[items_per_thread];
164-
if constexpr (params::has_single_static_value_v<SegmentSizeParameterT>
165-
&& params::static_min_value_v<SegmentSizeParameterT> == tile_size)
167+
if constexpr (is_full_tile)
166168
{
167169
// No padding needed
168170
block_load_keys_t(temp_storage.load_keys).Load(block_keys_in, thread_keys);
@@ -171,7 +173,7 @@ struct agent_batched_topk_worker_per_segment
171173
{
172174
// Potentially partial final load with padding
173175
// TODO (elstehle): explore whether a runtime check for segment_size == tile_size improves performance
174-
block_load_keys_t(temp_storage.load_keys).Load(block_keys_in, thread_keys, segment_size, padding_key);
176+
block_load_keys_t(temp_storage.load_keys).Load(block_keys_in, thread_keys, segment_size);
175177
}
176178

177179
// Load Values (if applicable)
@@ -182,8 +184,7 @@ struct agent_batched_topk_worker_per_segment
182184
__syncthreads();
183185
auto block_vals_in = d_value_segments_it[segment_id];
184186

185-
if constexpr (params::has_single_static_value_v<SegmentSizeParameterT>
186-
&& params::static_min_value_v<SegmentSizeParameterT> == tile_size)
187+
if constexpr (is_full_tile)
187188
{
188189
// No padding needed
189190
block_load_vals_t(temp_storage.load_vals).Load(block_vals_in, thread_values);
@@ -201,31 +202,33 @@ struct agent_batched_topk_worker_per_segment
201202
// Perform Block Top-K
202203
if constexpr (is_keys_only)
203204
{
204-
const bool is_successful_dispatch =
205-
detail::params::dispatch_discrete(select_directions, segment_id, [this, &thread_keys, k](auto direction_tag) {
205+
const bool is_successful_dispatch = cub::detail::params::dispatch_discrete(
206+
select_directions, segment_id, [this, &thread_keys, k, segment_size](auto direction_tag) {
206207
if constexpr (decltype(direction_tag)::value == detail::topk::select::max)
207208
{
208-
block_topk_t(temp_storage.topk).max_keys(thread_keys, k);
209+
block_topk_t(temp_storage.topk).template max_keys<is_full_tile>(thread_keys, k, segment_size);
209210
}
210211
else
211212
{
212-
block_topk_t(temp_storage.topk).min_keys(thread_keys, k);
213+
block_topk_t(temp_storage.topk).template min_keys<is_full_tile>(thread_keys, k, segment_size);
213214
}
214215
});
215216
_CCCL_ASSERT(is_successful_dispatch, "Error: Unsupported select direction");
216217
}
217218
else
218219
{
219220
// Pass both keys and values
220-
const bool is_successful_dispatch = detail::params::dispatch_discrete(
221-
select_directions, segment_id, [this, &thread_keys, &thread_values, k](auto direction_tag) {
221+
const bool is_successful_dispatch = cub::detail::params::dispatch_discrete(
222+
select_directions, segment_id, [this, &thread_keys, &thread_values, k, segment_size](auto direction_tag) {
222223
if constexpr (decltype(direction_tag)::value == detail::topk::select::max)
223224
{
224-
block_topk_t(temp_storage.topk).max_pairs(thread_keys, thread_values, k);
225+
block_topk_t(temp_storage.topk)
226+
.template max_pairs<is_full_tile>(thread_keys, thread_values, k, segment_size);
225227
}
226228
else
227229
{
228-
block_topk_t(temp_storage.topk).min_pairs(thread_keys, thread_values, k);
230+
block_topk_t(temp_storage.topk)
231+
.template min_pairs<is_full_tile>(thread_keys, thread_values, k, segment_size);
229232
}
230233
});
231234
_CCCL_ASSERT(is_successful_dispatch, "Error: Unsupported select direction");

cub/cub/block/block_topk.cuh

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
1-
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
33

4-
//! @file
5-
//! The @c cub::detail::block_topk class provides a :ref:`collective <collective-primitives>` method for selecting the
6-
//! top-k elements from a set of items within a CUDA thread block.
7-
84
#pragma once
95

106
#include <cub/config.cuh>
@@ -17,63 +13,75 @@
1713
# pragma system_header
1814
#endif // no system header
1915

20-
#include <cub/block/block_radix_sort.cuh>
21-
#include <cub/util_ptx.cuh>
16+
#include <cub/block/specializations/block_topk_air.cuh>
17+
#include <cub/device/dispatch/dispatch_common.cuh>
2218
#include <cub/util_type.cuh>
2319

2420
CUB_NAMESPACE_BEGIN
2521

2622
namespace detail
2723
{
24+
// TODO (elstehle): Add documentation
2825
template <typename KeyT, int BlockDimX, int ItemsPerThread, typename ValueT = NullType>
2926
class block_topk
3027
{
3128
private:
32-
using BlockRadixSortT = BlockRadixSort<KeyT, BlockDimX, ItemsPerThread, ValueT>;
29+
using internal_block_topk_t = block_topk_air<KeyT, BlockDimX, ItemsPerThread, ValueT>;
3330

3431
public:
3532
struct TempStorage
3633
{
37-
typename BlockRadixSortT::TempStorage sort_storage;
34+
typename internal_block_topk_t::TempStorage topk_storage;
3835
};
3936

4037
private:
41-
TempStorage& temp_storage;
38+
TempStorage& storage;
4239

4340
public:
44-
_CCCL_DEVICE _CCCL_FORCEINLINE block_topk(TempStorage& temp_storage)
45-
: temp_storage(temp_storage)
41+
_CCCL_DEVICE_API _CCCL_FORCEINLINE block_topk(TempStorage& storage)
42+
: storage(storage)
4643
{}
4744

48-
_CCCL_DEVICE _CCCL_FORCEINLINE void max_pairs(
45+
template <bool IsFullTile>
46+
_CCCL_DEVICE_API _CCCL_FORCEINLINE void max_pairs(
4947
KeyT (&keys)[ItemsPerThread],
5048
ValueT (&values)[ItemsPerThread],
51-
int /*k*/,
49+
int k,
50+
int num_valid,
5251
int begin_bit = 0,
5352
int end_bit = sizeof(KeyT) * 8)
5453
{
55-
BlockRadixSortT(temp_storage.sort_storage).SortDescending(keys, values, begin_bit, end_bit);
54+
internal_block_topk_t(storage.topk_storage)
55+
.template select_pairs<detail::topk::select::max, IsFullTile>(keys, values, k, num_valid, begin_bit, end_bit);
5656
}
5757

58-
_CCCL_DEVICE _CCCL_FORCEINLINE void
59-
max_keys(KeyT (&keys)[ItemsPerThread], int /*k*/, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8)
58+
template <bool IsFullTile>
59+
_CCCL_DEVICE_API _CCCL_FORCEINLINE void
60+
max_keys(KeyT (&keys)[ItemsPerThread], int k, int num_valid, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8)
6061
{
61-
BlockRadixSortT(temp_storage.sort_storage).SortDescending(keys, begin_bit, end_bit);
62+
internal_block_topk_t(storage.topk_storage)
63+
.template select_keys<detail::topk::select::max, IsFullTile>(keys, k, num_valid, begin_bit, end_bit);
6264
}
63-
_CCCL_DEVICE _CCCL_FORCEINLINE void min_pairs(
65+
66+
template <bool IsFullTile>
67+
_CCCL_DEVICE_API _CCCL_FORCEINLINE void min_pairs(
6468
KeyT (&keys)[ItemsPerThread],
6569
ValueT (&values)[ItemsPerThread],
66-
int /*k*/,
70+
int k,
71+
int num_valid,
6772
int begin_bit = 0,
6873
int end_bit = sizeof(KeyT) * 8)
6974
{
70-
BlockRadixSortT(temp_storage.sort_storage).Sort(keys, values, begin_bit, end_bit);
75+
internal_block_topk_t(storage.topk_storage)
76+
.template select_pairs<detail::topk::select::min, IsFullTile>(keys, values, k, num_valid, begin_bit, end_bit);
7177
}
7278

73-
_CCCL_DEVICE _CCCL_FORCEINLINE void
74-
min_keys(KeyT (&keys)[ItemsPerThread], int /*k*/, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8)
79+
template <bool IsFullTile>
80+
_CCCL_DEVICE_API _CCCL_FORCEINLINE void
81+
min_keys(KeyT (&keys)[ItemsPerThread], int k, int num_valid, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8)
7582
{
76-
BlockRadixSortT(temp_storage.sort_storage).Sort(keys, begin_bit, end_bit);
83+
internal_block_topk_t(storage.topk_storage)
84+
.template select_keys<detail::topk::select::min, IsFullTile>(keys, k, num_valid, begin_bit, end_bit);
7785
}
7886
};
7987
} // namespace detail

0 commit comments

Comments
 (0)