diff --git a/paddle/phi/kernels/funcs/top_k_cuda_kernel.h b/paddle/phi/kernels/funcs/top_k_cuda_kernel.h new file mode 100644 index 00000000000000..368cb21c217b17 --- /dev/null +++ b/paddle/phi/kernels/funcs/top_k_cuda_kernel.h @@ -0,0 +1,2532 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef PADDLE_PHI_KERNELS_FUNCS_TOP_K_CUDA_KERNEL_H_ +#define PADDLE_PHI_KERNELS_FUNCS_TOP_K_CUDA_KERNEL_H_ + +// GPU TopK kernel implementation using radix-select and multi-tier sorting. + +#include +#include +#include + +// Include top_k_function_cuda.h to get CUB NumericTraits for float16/bfloat16. +// This header includes cub/cub.cuh and defines the required traits. +#include "paddle/phi/kernels/funcs/top_k_function_cuda.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" + +#include "paddle/phi/kernels/argsort_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/take_along_axis_kernel.h" + +// ============================================================================ +// Helper definitions +// All helpers are placed in an anonymous namespace to avoid ODR conflicts +// with Paddle's existing implementations. +// ============================================================================ + +namespace topk_detail { + +// Stream type alias: gpuStream_t is in phi:: namespace, bring it into scope +using phi::gpuStream_t; + +// --- Constants --- +constexpr int MAX_TENSORINFO_DIMS = 25; +constexpr int64_t MAX_GRID_SIZE = 65535LL; + +// --- ceil_div and round_up --- +template +__host__ __device__ __forceinline__ T topk_ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +__host__ __device__ __forceinline__ T topk_round_up(T a, T b) { + return topk_ceil_div(a, b) * b; +} + +// --- getGridFromTiles --- +inline bool getGridFromTiles(int64_t gridTiles, dim3* grid) { + if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) { + return false; + } + int64_t gridX = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + int64_t gridY = 1; + int64_t gridZ = 1; + if (gridTiles > MAX_GRID_SIZE) { + gridTiles = topk_ceil_div(gridTiles, (int64_t)MAX_GRID_SIZE); + gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + if (gridTiles > MAX_GRID_SIZE) { + gridTiles = topk_ceil_div(gridTiles, (int64_t)MAX_GRID_SIZE); + gridZ = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + } + } + *grid = dim3(gridX, gridY, gridZ); + return true; +} + +// --- getLinearBlockId --- +template +__device__ __forceinline__ index_t getLinearBlockId() { + return static_cast(blockIdx.z) * gridDim.y * gridDim.x + + static_cast(blockIdx.y) * gridDim.x + blockIdx.x; +} + +// --- doLdg --- +// Generic fallback for custom types (phi::float16, phi::bfloat16, etc.) 
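+// Note: __ldg() issues the load through the read-only data cache, which suits
+// the scattered, read-only access pattern of the radix passes below; it is
+// only available for built-in types (and only on sm_35+), so custom types
+// such as phi::float16/bfloat16 fall back to a plain dereference.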
+template +__device__ __forceinline__ T doLdg(const T* p) { + return *p; +} + +// Specializations for built-in types that support __ldg +#if !defined(__HIPCC__) +template <> +__device__ __forceinline__ float doLdg(const float* p) { +#if __CUDA_ARCH__ >= 350 + return __ldg(p); +#else + return *p; +#endif +} +template <> +__device__ __forceinline__ double doLdg(const double* p) { +#if __CUDA_ARCH__ >= 350 + return __ldg(p); +#else + return *p; +#endif +} +template <> +__device__ __forceinline__ int doLdg(const int* p) { +#if __CUDA_ARCH__ >= 350 + return __ldg(p); +#else + return *p; +#endif +} +template <> +__device__ __forceinline__ unsigned int doLdg(const unsigned int* p) { +#if __CUDA_ARCH__ >= 350 + return __ldg(p); +#else + return *p; +#endif +} +template <> +__device__ __forceinline__ long long doLdg( // NOLINT + const long long* p) { // NOLINT +#if __CUDA_ARCH__ >= 350 + return __ldg(p); +#else + return *p; +#endif +} +template <> +__device__ __forceinline__ unsigned long long doLdg( // NOLINT + const unsigned long long* p) { // NOLINT +#if __CUDA_ARCH__ >= 350 + return __ldg(p); +#else + return *p; +#endif +} +template <> +__device__ __forceinline__ int16_t doLdg(const int16_t* p) { +#if __CUDA_ARCH__ >= 350 + return __ldg(p); +#else + return *p; +#endif +} +#endif // !__HIPCC__ + +// --- Bitfield --- +template +struct Bitfield {}; + +template <> +struct Bitfield { + static __device__ __forceinline__ unsigned int getBitfield(unsigned int val, + int pos, + int len) { + unsigned int ret; +#if defined(__HIPCC__) + ret = (val >> pos) & ((1u << len) - 1u); +#else + asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len)); +#endif + return ret; + } + + static __device__ __forceinline__ unsigned int setBitfield( + unsigned int val, unsigned int to_insert, int pos, int len) { + unsigned int ret; +#if defined(__HIPCC__) + unsigned int mask = ((1u << len) - 1u) << pos; + ret = (val & ~mask) | ((to_insert << pos) & mask); +#else + asm("bfi.b32 %0, %1, %2, %3, %4;" + : "=r"(ret) + : "r"(to_insert), "r"(val), "r"(pos), "r"(len)); +#endif + return ret; + } +}; + +template <> +struct Bitfield { + static __device__ __forceinline__ uint64_t getBitfield(uint64_t val, + int pos, + int len) { + uint64_t ret; +#if defined(__HIPCC__) + ret = (val >> pos) & ((1ULL << len) - 1ULL); +#else + asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len)); +#endif + return ret; + } + + static __device__ __forceinline__ uint64_t setBitfield(uint64_t val, + uint64_t to_insert, + int pos, + int len) { + uint64_t ret; +#if defined(__HIPCC__) + uint64_t mask = ((1ULL << len) - 1ULL) << pos; + ret = (val & ~mask) | ((to_insert << pos) & mask); +#else + asm("bfi.b64 %0, %1, %2, %3, %4;" + : "=l"(ret) + : "l"(to_insert), "l"(val), "r"(pos), "r"(len)); +#endif + return ret; + } +}; + +// --- getLaneId / getLaneMaskLe --- +__device__ __forceinline__ int getLaneId() { +#if defined(__HIPCC__) + return __lane_id(); +#else + int laneId; + asm("mov.s32 %0, %%laneid;" : "=r"(laneId)); + return laneId; +#endif +} + +__device__ __forceinline__ unsigned getLaneMaskLe() { +#if defined(__HIPCC__) + // HIP warp size is 64, construct mask for lanes <= current lane + return (getLaneId() == 63) ? 0xFFFFFFFFFFFFFFFFULL + : (1ULL << (getLaneId() + 1)) - 1ULL; +#else + unsigned mask; + asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask)); + return mask; +#endif +} + +__device__ __forceinline__ unsigned getLaneMaskLt() { +#if defined(__HIPCC__) + return (getLaneId() == 0) ? 
0ULL : (1ULL << getLaneId()) - 1ULL; +#else + unsigned mask; + asm("mov.u32 %0, %%lanemask_lt;" : "=r"(mask)); + return mask; +#endif +} + +// --- WARP macros --- +#ifdef __HIPCC__ +#define TOPK_WARP_SIZE 64 +#define TOPK_WARP_BALLOT(PREDICATE) __ballot((PREDICATE)) +#define TOPK_WARP_BALLOT_MASK(PREDICATE, MASK) __ballot((PREDICATE)) +#define TOPK_WARP_SHFL_DOWN(VAL, DELTA) \ + __shfl_down((VAL), static_cast(DELTA)) +#else +#define TOPK_WARP_SIZE 32 +#define TOPK_WARP_BALLOT(PREDICATE) __ballot_sync(0xffffffff, (PREDICATE)) +#define TOPK_WARP_BALLOT_MASK(PREDICATE, MASK) \ + __ballot_sync((MASK), (PREDICATE)) +#define TOPK_WARP_SHFL_DOWN(VAL, DELTA) \ + __shfl_down_sync(0xffffffff, (VAL), static_cast(DELTA)) +#endif + +// --- TopKTypeConfig --- +template +struct TopKTypeConfig {}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(float v) { + RadixType x = __float_as_int(v); + RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000; + return (v == v) ? (x ^ mask) : 0xffffffff; + } + + static inline __device__ float deconvert(RadixType v) { + RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff; + return __int_as_float(v ^ mask); + } +}; + +template <> +struct TopKTypeConfig { + typedef uint64_t RadixType; + + static inline __device__ RadixType convert(double v) { + RadixType x = __double_as_longlong(v); + RadixType mask = -((x >> 63)) | 0x8000000000000000; + return (v == v) ? (x ^ mask) : 0xffffffffffffffff; + } + + static inline __device__ double deconvert(RadixType v) { + RadixType mask = ((v >> 63) - 1) | 0x8000000000000000; + return __longlong_as_double(v ^ mask); + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(int32_t v) { + static_assert(sizeof(int) == 4, ""); + return 2147483648u + v; + } + + static inline __device__ int32_t deconvert(RadixType v) { + return v - 2147483648u; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint64_t RadixType; + + static inline __device__ RadixType convert(int64_t v) { + static_assert(sizeof(int64_t) == 8, ""); + return 9223372036854775808ull + v; + } + + static inline __device__ int64_t deconvert(RadixType v) { + return v - 9223372036854775808ull; + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(phi::dtype::float16 v) { + RadixType x = __half_as_ushort(v.to_half()); + RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000; + half v_h = v.to_half(); + return (v_h == v_h) ? (x ^ mask) : 0xffff; + } + + static inline __device__ phi::dtype::float16 deconvert(RadixType v) { + RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff; + return static_cast(__ushort_as_half(v ^ mask)); + } +}; + +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(phi::dtype::bfloat16 v) { + RadixType x = v.x; + RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000; + return (v == v) ? (x ^ mask) : 0xffff; + } + + static inline __device__ phi::dtype::bfloat16 deconvert(RadixType v) { + RadixType mask = (v & 0x00008000) ? 
0x00008000 : 0x0000ffff; + phi::dtype::bfloat16 r; + r.x = (v ^ mask); + return r; + } +}; + +// uint8_t is needed by the radix select +template <> +struct TopKTypeConfig { + typedef uint32_t RadixType; + + static inline __device__ RadixType convert(uint8_t v) { return v; } + + static inline __device__ uint8_t deconvert(RadixType v) { return v; } +}; + +// --- TensorInfo --- +template +struct TensorInfo { + T* data; + IndexType sizes[MAX_TENSORINFO_DIMS]; + IndexType strides[MAX_TENSORINFO_DIMS]; + int dims; + + // collapse_dims: merges contiguous dimensions for efficient indexing + // See note on [collapse dims]. + int collapseDims(const int excludeDim = -1) { + int stopDim = (excludeDim == -1) ? dims : excludeDim; + int newIndex = -1; + int oldIndex = 0; + int remappedExcludedDim = -1; + + while (oldIndex < dims) { + // Finds a dimension to collapse into + for (; oldIndex < stopDim; ++oldIndex) { + if (sizes[oldIndex] == 1) { + continue; + } + + ++newIndex; + sizes[newIndex] = sizes[oldIndex]; + strides[newIndex] = strides[oldIndex]; + ++oldIndex; + break; + } + + // Collapses dims + for (; oldIndex < stopDim; ++oldIndex) { + if (sizes[oldIndex] == 1) { + continue; + } + + if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) { + sizes[newIndex] *= sizes[oldIndex]; + strides[newIndex] = strides[oldIndex]; + } else { + ++newIndex; + sizes[newIndex] = sizes[oldIndex]; + strides[newIndex] = strides[oldIndex]; + } + } + + // Handles excludeDim being set (oldIndex == excludeDim) + if (oldIndex != dims) { + // Preserves excluded dimension + ++newIndex; + sizes[newIndex] = sizes[oldIndex]; + strides[newIndex] = strides[oldIndex]; + remappedExcludedDim = newIndex; + + // Restarts iteration after excludeDim + ++oldIndex; + stopDim = dims; + } + } + + // Handles special case of all dims size 1 + if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) { + dims = 1; + sizes[0] = 1; + strides[0] = 1; + return 0; + } + + dims = newIndex + 1; + return remappedExcludedDim; + } +}; + +// --- IndexToOffset --- +template +struct IndexToOffset { + static __host__ __device__ IndexType + get(IndexType linearId, const TensorInfo& info) { + IndexType offset = 0; + for (int i = Dims - 1; i > 0; --i) { + IndexType curDimIndex = linearId % info.sizes[i]; + offset += curDimIndex * info.strides[i]; + linearId /= info.sizes[i]; + } + return offset + linearId * info.strides[0]; + } +}; + +// Specialization for Dim == -1 (runtime dims) +template +struct IndexToOffset { + static __host__ __device__ IndexType + get(IndexType linearId, const TensorInfo& info) { + IndexType offset = 0; + for (int i = info.dims - 1; i > 0; --i) { + IndexType curDimIndex = linearId % info.sizes[i]; + offset += curDimIndex * info.strides[i]; + linearId /= info.sizes[i]; + } + return offset + linearId * info.strides[0]; + } +}; + +// Specialization for Dim == 1 +template +struct IndexToOffset { + static __host__ __device__ IndexType + get(IndexType linearId, const TensorInfo& info) { + return linearId * info.strides[0]; + } +}; + +// Specialization for Dim == 2 +template +struct IndexToOffset { + static __host__ __device__ IndexType + get(IndexType linearId, const TensorInfo& info) { + IndexType curDimIndex = linearId % info.sizes[1]; + IndexType offset = curDimIndex * info.strides[1]; + linearId /= info.sizes[1]; + return offset + linearId * info.strides[0]; + } +}; + +// Specialization for Dim == 3 +template +struct IndexToOffset { + static __host__ __device__ IndexType + get(IndexType linearId, const TensorInfo& info) { + IndexType 
curDimIndex = linearId % info.sizes[2]; + IndexType offset = curDimIndex * info.strides[2]; + linearId /= info.sizes[2]; + curDimIndex = linearId % info.sizes[1]; + offset += curDimIndex * info.strides[1]; + linearId /= info.sizes[1]; + return offset + linearId * info.strides[0]; + } +}; + +// --- inclusiveBinaryPrefixScan / exclusiveBinaryPrefixScan --- +// Prefix scan utilities + +template +__device__ inline void swapVars(T* t1, T* t2) { + T tmp = *t1; + *t1 = *t2; + *t2 = tmp; +} + +template +__device__ inline void bitonicSwap(K* kA, + V* vA, + bool* validA, + K* kB, + V* vB, + bool* validB, + bool dir, + const Comparator& comp) { + // Invalid entries always sort to the end + bool swap = (comp(*kA, *kB) && *validA) || !*validB; + if (swap == dir) { + swapVars(kA, kB); + swapVars(vA, vB); + swapVars(validA, validB); + } +} + +template +__device__ inline void bitonicSort(K* keys, + V* values, + bool* valid, + const Comparator& comp) { +#pragma unroll + for (unsigned int size = 2; size < Power2SortSize; size *= 2) { + bool flag = ((threadIdx.x & (size / 2)) != 0); + +#pragma unroll + for (unsigned int stride = size / 2; stride > 0; stride /= 2) { + __syncthreads(); + + unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + bitonicSwap(&keys[pos], + &values[pos], + &valid[pos], + &keys[pos + stride], + &values[pos + stride], + &valid[pos + stride], + flag, + comp); + } + } + +#pragma unroll + for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) { + __syncthreads(); + + unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + bitonicSwap(&keys[pos], + &values[pos], + &valid[pos], + &keys[pos + stride], + &values[pos + stride], + &valid[pos + stride], + false, + comp); + } + + __syncthreads(); +} + +template +__global__ void __launch_bounds__(block_dim_x* max_block_dim_y) + bitonicSortKVInPlace(TensorInfo keys, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + TensorInfo values, + IndexType valueSliceStride, + Comparator comp) { + const IndexType blockIndex = getLinearBlockId(); + const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y; + + if (blockIndex * blockDim.y >= keySlices) { + return; + } + const bool row_valid = linearIndex < keySlices; + + constexpr int items_per_thread = 2; + constexpr int Power2SortSize = block_dim_x * items_per_thread; + + __shared__ K blockSharedKeys[max_block_dim_y][Power2SortSize]; + __shared__ V blockSharedValues[max_block_dim_y][Power2SortSize]; + __shared__ bool blockSharedValid[max_block_dim_y][Power2SortSize]; + + auto sharedKeys = blockSharedKeys[threadIdx.y]; + auto sharedValues = blockSharedValues[threadIdx.y]; + auto sharedValid = blockSharedValid[threadIdx.y]; + + const IndexType keyStartOffset = + IndexToOffset::get(linearIndex, keys); + const IndexType valueStartOffset = + IndexToOffset::get(linearIndex, values); + +#pragma unroll + for (int k = 0; k < items_per_thread; ++k) { + auto idx = threadIdx.x + k * blockDim.x; + bool valid = row_valid && idx < keySliceSize; + + sharedKeys[idx] = + valid ? keys.data[idx * keySliceStride + keyStartOffset] : K{}; + sharedValues[idx] = + valid ? 
values.data[idx * valueSliceStride + valueStartOffset] : V{}; + sharedValid[idx] = valid; + } + + bitonicSort( + sharedKeys, sharedValues, sharedValid, comp); + + if (!row_valid) { + return; + } + +#pragma unroll + for (int k = 0; k < items_per_thread; ++k) { + auto idx = threadIdx.x + k * blockDim.x; + if (idx < keySliceSize) { + keys.data[idx * keySliceStride + keyStartOffset] = sharedKeys[idx]; + values.data[idx * valueSliceStride + valueStartOffset] = + sharedValues[idx]; + } + } +} + +template +struct GTOp { + __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const { + return (handleNaN && (lhs != lhs) && !(rhs != rhs)) || + (static_cast(lhs) > static_cast(rhs)); + } +}; + +template +struct LTOp { + __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const { + return (handleNaN && !(lhs != lhs) && (rhs != rhs)) || + (static_cast(lhs) < static_cast(rhs)); + } +}; + +template +void launch_bitonic_sort(TensorInfo keyInfo, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + TensorInfo valueInfo, + IndexType valueSliceStride, + bool largest, + gpuStream_t stream) { + constexpr int sort_size = 32; + constexpr int max_block_y = 16; + constexpr int items_per_thread = 2; + constexpr int block_x = sort_size / items_per_thread; + + const int block_y = std::min( + static_cast(max_block_y), + static_cast(std::max(static_cast(1), keySlices))); + dim3 block(block_x, block_y); + + dim3 grid; + const int grid_count = (keySlices + block_y - 1) / block_y; + getGridFromTiles(grid_count, &grid); + + if (largest) { + bitonicSortKVInPlace + <<>>(keyInfo, + keySlices, + keySliceSize, + keySliceStride, + valueInfo, + valueSliceStride, + GTOp()); + } else { + bitonicSortKVInPlace + <<>>(keyInfo, + keySlices, + keySliceSize, + keySliceStride, + valueInfo, + valueSliceStride, + LTOp()); + } +} + +// ============================================================================ +// StridedRandomAccessor +// Required by CUB WarpLoad/WarpStore and BlockLoad/BlockStore for strided +// tensor access. 
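+//
+// Usage sketch (illustrative only, not part of this header): for a column
+// slice of a row-major [N, C] tensor (i.e. sorting along dim 0), consecutive
+// slice elements are C values apart, so
+//
+//   StridedRandomAccessor<float, int> it(slice_ptr, /*stride=*/C);
+//   float v   = it[i];    // reads slice_ptr[i * C]
+//   auto  it2 = it + k;   // advances by k logical elements, i.e. k * C values
+//
+// which satisfies the random-access iterator contract that cub::WarpLoad,
+// cub::WarpStore, cub::BlockLoad and cub::BlockStore expect.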
+// ============================================================================ + +template +class ConstStridedRandomAccessor { + public: + using difference_type = index_t; + using value_type = const T; + using pointer = const T*; + using reference = const T&; + using iterator_category = std::random_access_iterator_tag; + using PtrType = T*; + using index_type = index_t; + + __host__ __device__ ConstStridedRandomAccessor(PtrType ptr, index_t stride) + : ptr_{ptr}, stride_{stride} {} + __host__ __device__ explicit ConstStridedRandomAccessor(PtrType ptr) + : ptr_{ptr}, stride_{1} {} + __host__ __device__ ConstStridedRandomAccessor() + : ptr_{nullptr}, stride_{1} {} + + __host__ __device__ reference operator*() const { return *ptr_; } + __host__ __device__ const T* operator->() const { + return reinterpret_cast(ptr_); + } + __host__ __device__ reference operator[](index_t idx) const { + return ptr_[idx * stride_]; + } + + __host__ __device__ ConstStridedRandomAccessor& operator++() { + ptr_ += stride_; + return *this; + } + __host__ __device__ ConstStridedRandomAccessor operator++(int) { + ConstStridedRandomAccessor copy(*this); + ++*this; + return copy; + } + __host__ __device__ ConstStridedRandomAccessor& operator--() { + ptr_ -= stride_; + return *this; + } + __host__ __device__ ConstStridedRandomAccessor operator--(int) { + ConstStridedRandomAccessor copy(*this); + --*this; + return copy; + } + __host__ __device__ ConstStridedRandomAccessor& operator+=(index_t offset) { + ptr_ += offset * stride_; + return *this; + } + __host__ __device__ ConstStridedRandomAccessor + operator+(index_t offset) const { + return ConstStridedRandomAccessor(ptr_ + offset * stride_, stride_); + } + __host__ __device__ friend ConstStridedRandomAccessor operator+( + index_t offset, const ConstStridedRandomAccessor& accessor) { + return accessor + offset; + } + __host__ __device__ ConstStridedRandomAccessor& operator-=(index_t offset) { + ptr_ -= offset * stride_; + return *this; + } + __host__ __device__ ConstStridedRandomAccessor + operator-(index_t offset) const { + return ConstStridedRandomAccessor(ptr_ - offset * stride_, stride_); + } + __host__ __device__ difference_type + operator-(const ConstStridedRandomAccessor& other) const { + return (ptr_ - other.ptr_) / stride_; + } + __host__ __device__ bool operator==( + const ConstStridedRandomAccessor& other) const { + return (ptr_ == other.ptr_) && (stride_ == other.stride_); + } + __host__ __device__ bool operator!=( + const ConstStridedRandomAccessor& other) const { + return !(*this == other); + } + __host__ __device__ bool operator<( + const ConstStridedRandomAccessor& other) const { + return ptr_ < other.ptr_; + } + __host__ __device__ bool operator<=( + const ConstStridedRandomAccessor& other) const { + return (*this < other) || (*this == other); + } + __host__ __device__ bool operator>( + const ConstStridedRandomAccessor& other) const { + return !(*this <= other); + } + __host__ __device__ bool operator>=( + const ConstStridedRandomAccessor& other) const { + return !(*this < other); + } + + protected: + PtrType ptr_; + index_t stride_; +}; + +template +class StridedRandomAccessor : public ConstStridedRandomAccessor { + public: + using difference_type = index_t; + using value_type = T; + using pointer = T*; + using reference = T&; + using BaseType = ConstStridedRandomAccessor; + using PtrType = T*; + + __host__ __device__ StridedRandomAccessor(PtrType ptr, index_t stride) + : BaseType(ptr, stride) {} + __host__ __device__ explicit 
StridedRandomAccessor(PtrType ptr) + : BaseType(ptr) {} + __host__ __device__ StridedRandomAccessor() : BaseType() {} + + __host__ __device__ reference operator*() const { return *this->ptr_; } + __host__ __device__ T* operator->() const { + return reinterpret_cast(this->ptr_); + } + __host__ __device__ reference operator[](index_t idx) const { + return this->ptr_[idx * this->stride_]; + } + + __host__ __device__ StridedRandomAccessor& operator++() { + this->ptr_ += this->stride_; + return *this; + } + __host__ __device__ StridedRandomAccessor operator++(int) { + StridedRandomAccessor copy(*this); + ++*this; + return copy; + } + __host__ __device__ StridedRandomAccessor& operator--() { + this->ptr_ -= this->stride_; + return *this; + } + __host__ __device__ StridedRandomAccessor operator--(int) { + StridedRandomAccessor copy(*this); + --*this; + return copy; + } + __host__ __device__ StridedRandomAccessor& operator+=(index_t offset) { + this->ptr_ += offset * this->stride_; + return *this; + } + __host__ __device__ StridedRandomAccessor operator+(index_t offset) const { + return StridedRandomAccessor(this->ptr_ + offset * this->stride_, + this->stride_); + } + __host__ __device__ friend StridedRandomAccessor operator+( + index_t offset, const StridedRandomAccessor& accessor) { + return accessor + offset; + } + __host__ __device__ StridedRandomAccessor& operator-=(index_t offset) { + this->ptr_ -= offset * this->stride_; + return *this; + } + __host__ __device__ StridedRandomAccessor operator-(index_t offset) const { + return StridedRandomAccessor(this->ptr_ - offset * this->stride_, + this->stride_); + } + __host__ __device__ difference_type operator-(const BaseType& other) const { + return (static_cast(*this) - other); + } +}; + +// ============================================================================ +// CubKeyType mapping - maps Paddle types to CUB-compatible CUDA types +// For BlockRadixSort, CUB needs __half / __nv_bfloat16 instead of +// phi::float16 / phi::bfloat16. +// ============================================================================ +template +struct CubKeyType { + using type = T; +}; + +template <> +struct CubKeyType { + using type = __half; +}; + +template <> +struct CubKeyType { +#if defined(__HIPCC__) + using type = hip_bfloat16; +#else + using type = __nv_bfloat16; +#endif +}; + +// ============================================================================ +// Utility functions +// ============================================================================ + +inline int64_t nextHighestPowerOf2(int64_t n) { + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + n |= n >> 32; + n++; + return n; +} + +template +static int minimum_grid_for_occupancy(T kernel, int max_block_size) { + int minGridSize = 0; + int blockSize = 0; + cudaOccupancyMaxPotentialBlockSize( + &minGridSize, &blockSize, kernel, /*dynamicSMemSize=*/0, max_block_size); + return minGridSize; +} + +template +constexpr bool type_has_nan() { + if constexpr (std::numeric_limits::is_specialized) { + return std::numeric_limits::has_quiet_NaN; + } else if constexpr (std::is_same_v || // NOLINT + std::is_same_v) { + return true; + } else { + return false; + } +} + +// ============================================================================ +// warpMergeSortKVInPlace kernel +// For sort sizes 33..128, uses CUB WarpMergeSort (one warp per slice, +// multiple slices per block via blockDim.y). 
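+//
+// Worked example (illustrative): with sort_size == 128 each of the 32 lanes
+// keeps kItemsPerThread == 128 / 32 == 4 key/value pairs in registers; a
+// slice of, say, 100 valid elements is padded up to 128 with `invalid_key`
+// during cub::WarpLoad, stably sorted with cub::WarpMergeSort, and only the
+// first keySliceSize elements are written back by cub::WarpStore, so the
+// padding never reaches global memory.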
+// ============================================================================ + +template +__global__ void __launch_bounds__(32 * max_block_dim_y) + warpMergeSortKVInPlace(TensorInfo keys, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + TensorInfo values, + IndexType valueSliceStride, + Comparator comp, + K invalid_key) { + const IndexType blockIndex = getLinearBlockId(); + const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y; + + if (linearIndex >= keySlices) { + return; + } + + const IndexType keyStartOffset = + IndexToOffset::get(linearIndex, keys); + const IndexType valueStartOffset = + IndexToOffset::get(linearIndex, values); + + K* keys_slice = &keys.data[keyStartOffset]; + V* values_slice = &values.data[valueStartOffset]; + + StridedRandomAccessor keys_iter(keys_slice, keySliceStride); + StridedRandomAccessor values_iter(values_slice, + valueSliceStride); + + constexpr int warp_size = 32; + constexpr int kItemsPerThread = sort_size / warp_size; + static_assert(kItemsPerThread * warp_size == sort_size, + "sort_size must be a multiple of warp_size (32)"); + + using LoadKeys = cub::WarpLoad; + using LoadValues = + cub::WarpLoad; + using Sort = cub::WarpMergeSort; + using StoreKeys = + cub::WarpStore; + using StoreValues = + cub::WarpStore; + + __shared__ union { + typename LoadKeys::TempStorage load_keys; + typename LoadValues::TempStorage load_values; + typename Sort::TempStorage sort; + typename StoreKeys::TempStorage store_keys; + typename StoreValues::TempStorage store_values; + } tmp_storage[max_block_dim_y]; + + auto& warp_storage = tmp_storage[threadIdx.y]; + + K local_keys[kItemsPerThread]; + V local_values[kItemsPerThread]; + + const auto invalid_value = V{}; + LoadKeys(warp_storage.load_keys) + .Load(keys_iter, local_keys, keySliceSize, invalid_key); +#if !defined(__HIPCC__) + __syncwarp(); +#endif + LoadValues(warp_storage.load_values) + .Load(values_iter, local_values, keySliceSize, invalid_value); +#if !defined(__HIPCC__) + __syncwarp(); +#endif + + Sort(warp_storage.sort) + .StableSort(local_keys, local_values, comp, keySliceSize, invalid_key); +#if !defined(__HIPCC__) + __syncwarp(); +#endif + + StoreKeys(warp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize); +#if !defined(__HIPCC__) + __syncwarp(); +#endif + StoreValues(warp_storage.store_values) + .Store(values_iter, local_values, keySliceSize); +} + +// ============================================================================ +// radixSortKVInPlace kernel +// For sort sizes 129..4096, uses CUB BlockRadixSort (one block per slice). 
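+//
+// Worked example (illustrative): launch_medium_radix_sort below rounds the
+// slice length up to the next power of two and uses 32 items per thread, so
+// a 3000-element slice is dispatched with sort_size == 4096 and
+// block_size == 4096 / 32 == 128 threads; out-of-range positions are padded
+// with cub::Traits<key_t>::MAX_KEY (LOWEST_KEY when descending) so the
+// padding always sorts behind the valid elements.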
+// ============================================================================ + +template +__global__ void __launch_bounds__(block_size) + radixSortKVInPlace(TensorInfo keys, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + TensorInfo values, + IndexType valueSliceStride, + bool descending) { + static_assert(block_size > 0, ""); + + const IndexType linearIndex = getLinearBlockId(); + if (linearIndex >= keySlices) { + return; + } + + const IndexType keyStartOffset = + IndexToOffset::get(linearIndex, keys); + const IndexType valueStartOffset = + IndexToOffset::get(linearIndex, values); + + K* keys_slice = &keys.data[keyStartOffset]; + V* values_slice = &values.data[valueStartOffset]; + + StridedRandomAccessor keys_iter(keys_slice, keySliceStride); + StridedRandomAccessor values_iter(values_slice, + valueSliceStride); + + using key_t = typename CubKeyType::type; + using LoadKeys = + cub::BlockLoad; + using LoadValues = + cub::BlockLoad; + using Sort = cub::BlockRadixSort; + using StoreKeys = cub:: + BlockStore; + using StoreValues = cub:: + BlockStore; + + __shared__ union { + typename LoadKeys::TempStorage load_keys; + typename LoadValues::TempStorage load_values; + typename Sort::TempStorage sort; + typename StoreKeys::TempStorage store_keys; + typename StoreValues::TempStorage store_values; + } tmp_storage; + + // Compute invalid key: always sorts higher than any valid key + const K invalid_key = [descending] { + using radix_t = typename cub::Traits::UnsignedBits; + union { + K key; + radix_t radix; + } tmp; + tmp.radix = descending ? cub::Traits::LOWEST_KEY + : cub::Traits::MAX_KEY; + return tmp.key; + }(); + const V invalid_value = static_cast(0); + + K local_keys[kItemsPerThread]; + V local_values[kItemsPerThread]; + + LoadKeys(tmp_storage.load_keys) + .Load(keys_iter, local_keys, keySliceSize, invalid_key); + __syncthreads(); + LoadValues(tmp_storage.load_values) + .Load(values_iter, local_values, keySliceSize, invalid_value); + __syncthreads(); + + if (descending) { + Sort(tmp_storage.sort) + .SortDescending(reinterpret_cast(local_keys), + local_values); + } else { + Sort(tmp_storage.sort) + .Sort(reinterpret_cast(local_keys), + local_values); + } + __syncthreads(); + + StoreKeys(tmp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize); + __syncthreads(); + StoreValues(tmp_storage.store_values) + .Store(values_iter, local_values, keySliceSize); +} + +// ============================================================================ +// launch_warp_merge_sort - wrapper for CUB WarpMergeSort<128> +// ============================================================================ + +template +void launch_warp_merge_sort(TensorInfo keyInfo, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + TensorInfo valueInfo, + IndexType valueSliceStride, + bool largest, + gpuStream_t stream) { + constexpr int sort_size = 128; + constexpr int max_block_dim_y = 16; + constexpr int warp_size = 32; + + // Scale batch size down if the grid would be too small + const auto min_grid = + minimum_grid_for_occupancy(warpMergeSortKVInPlace, + IndexType>, + warp_size * max_block_dim_y); + const auto max_batch = + std::max(IndexType{1}, keySlices / (IndexType)min_grid); + const int block_y = std::min((IndexType)max_block_dim_y, max_batch); + dim3 block(warp_size, block_y); + + dim3 grid; + const int grid_count = (keySlices + block_y - 1) / block_y; + getGridFromTiles(grid_count, &grid); + + if (largest) { + // Use numeric limits for invalid_key: 
lower_bound for descending + const T invalid_key = std::numeric_limits::lowest(); + warpMergeSortKVInPlace + <<>>(keyInfo, + keySlices, + keySliceSize, + keySliceStride, + valueInfo, + valueSliceStride, + GTOp(), + invalid_key); + } else { + // For ascending: NAN sorts after inf, otherwise use upper_bound + const T invalid_key = [] { + if constexpr (type_has_nan()) { + return T(NAN); + } + return std::numeric_limits::max(); + }(); + warpMergeSortKVInPlace + <<>>(keyInfo, + keySlices, + keySliceSize, + keySliceStride, + valueInfo, + valueSliceStride, + LTOp(), + invalid_key); + } +} + +// ============================================================================ +// launch_medium_radix_sort - wrapper for CUB BlockRadixSort +// ============================================================================ + +template +void fixed_size_radix_sort(TensorInfo keyInfo, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + TensorInfo valueInfo, + IndexType valueSliceStride, + bool descending, + gpuStream_t stream) { + static_assert(sort_size % items_per_thread == 0, ""); + constexpr int block = sort_size / items_per_thread; + dim3 grid; + getGridFromTiles(keySlices, &grid); + + radixSortKVInPlace + <<>>(keyInfo, + keySlices, + keySliceSize, + keySliceStride, + valueInfo, + valueSliceStride, + descending); +} + +template +void launch_medium_radix_sort(TensorInfo keyInfo, + IndexType keySlices, + IndexType keySliceSize, + IndexType keySliceStride, + TensorInfo valueInfo, + IndexType valueSliceStride, + bool descending, + gpuStream_t stream) { + int64_t ceilPowerOf2 = nextHighestPowerOf2(keySliceSize); + constexpr int default_ipt = 32; + +#define HANDLE_RADIX_CASE(SIZE, IPT) \ + fixed_size_radix_sort(keyInfo, \ + keySlices, \ + keySliceSize, \ + keySliceStride, \ + valueInfo, \ + valueSliceStride, \ + descending, \ + stream) + + switch (ceilPowerOf2) { + case 4096: + HANDLE_RADIX_CASE(4096, default_ipt); + break; + case 2048: + HANDLE_RADIX_CASE(2048, default_ipt); + break; + case 1024: + case 512: + case 256: + HANDLE_RADIX_CASE(1024, default_ipt); + break; + // sizes <= 128 should have been handled by WarpMergeSort + default: + break; + } +#undef HANDLE_RADIX_CASE +} + +template +__device__ void inclusiveBinaryPrefixScan(T* smem, + bool in, + T* out, + BinaryFunction binop) { + T vote = TOPK_WARP_BALLOT(in); + T index = __popc(getLaneMaskLe() & vote); + T carry = __popc(vote); + + int warp = threadIdx.x / TOPK_WARP_SIZE; + + if (getLaneId() == 0) { + smem[warp] = carry; + } + + __syncthreads(); + + if (threadIdx.x == 0) { + int current = 0; + for (int i = 0; + i < topk_ceil_div(static_cast(blockDim.x), TOPK_WARP_SIZE); + ++i) { + T v = smem[i]; + smem[i] = binop(smem[i], current); + current = binop(current, v); + } + } + + __syncthreads(); + + if (warp >= 1) { + index = binop(index, smem[warp - 1]); + } + + *out = index; + + if (KillWARDependency) { + __syncthreads(); + } +} + +template +__device__ void exclusiveBinaryPrefixScan( + T* smem, bool in, T* out, T* carry, BinaryFunction binop) { + inclusiveBinaryPrefixScan(smem, in, out, binop); + *out -= static_cast(in); + *carry = + smem[topk_ceil_div(static_cast(blockDim.x), TOPK_WARP_SIZE) - 1]; + if (KillWARDependency) { + __syncthreads(); + } +} + +// --- AddOp --- +template +struct AddOp { + __device__ __forceinline__ T operator()(T const& lhs, T const& rhs) { + return (lhs + rhs); + } +}; + +// ============================================================================ +// SortingRadixSelect.cuh ported content +// 
============================================================================ + +namespace radix_select { + +// Over what radix we are selecting values (single-block variant) +constexpr int RADIX_BITS = 2; +constexpr int RADIX_SIZE = 4; // 2 ^ RADIX_BITS +constexpr int RADIX_MASK = (RADIX_SIZE - 1); + +// CountType is separate from IndexType — counts always fit in int32 +// because indices are limited to integer fp precision. +template +__device__ void countRadixUsingMask(const T* data, + CountType counts[RadixSize], + CountType* smem, + RadixType desired, + RadixType desiredMask, + int radixDigitPos, + IndexType sliceSize, + IndexType withinSliceStride) { +#pragma unroll + for (int i = 0; i < RadixSize; ++i) { + counts[i] = 0; + } + if (threadIdx.x < RadixSize) { + smem[threadIdx.x] = 0; + } + __syncthreads(); + + // Must be called outside of loop to ensure all threads participate. + // This creates a dynamic mask of which threads will enter the loop. + // When sliceSize < blockDim.x, only threads with threadIdx.x < sliceSize + // will enter the loop body, so we need a mask to avoid deadlock in + // __ballot_sync. +#if !defined(__HIPCC__) + unsigned mask = TOPK_WARP_BALLOT(threadIdx.x < sliceSize); +#endif + for (IndexType i = threadIdx.x; i < sliceSize;) { + RadixType val = + TopKTypeConfig::convert(doLdg(&data[i * withinSliceStride])); + bool hasVal = ((val & desiredMask) == desired); + RadixType digitInRadix = + Bitfield::getBitfield(val, radixDigitPos, RadixBits); +#pragma unroll + for (uint32_t j = 0; j < RadixSize; ++j) { + bool vote = hasVal && (digitInRadix == j); +#if defined(__HIPCC__) + counts[j] += __popcll(TOPK_WARP_BALLOT(vote)); +#else + counts[j] += __popc(TOPK_WARP_BALLOT_MASK(vote, mask)); +#endif + } + i += blockDim.x; +#if !defined(__HIPCC__) + mask = TOPK_WARP_BALLOT_MASK(i < sliceSize, mask); +#endif + } + + if (getLaneId() == 0) { +#pragma unroll + for (uint32_t i = 0; i < RadixSize; ++i) { + atomicAdd(&smem[i], counts[i]); + } + } + __syncthreads(); + +#pragma unroll + for (uint32_t i = 0; i < RadixSize; ++i) { + counts[i] = smem[i]; + } + __syncthreads(); +} + +template +__device__ T findPattern(const T* data, + T* smem, + IndexType sliceSize, + IndexType withinSliceStride, + RadixType desired, + RadixType desiredMask) { + if (threadIdx.x < 2) { + smem[threadIdx.x] = static_cast(0); + } + __syncthreads(); + + IndexType numIterations = topk_round_up(sliceSize, (IndexType)blockDim.x); + for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) { + bool inRange = (i < sliceSize); + T v = inRange ? 
doLdg(&data[i * withinSliceStride]) : static_cast(0); + if (inRange && ((TopKTypeConfig::convert(v) & desiredMask) == desired)) { + smem[0] = static_cast(1); + smem[1] = v; + } + __syncthreads(); + T found = smem[0]; + T val = smem[1]; + __syncthreads(); + if (found != static_cast(0)) { + return val; + } + } + // should not get here + assert(false); + return static_cast(0); +} + +template +__device__ void radixSelect(const T* data, + IndexType k, + bool largest, + IndexType sliceSize, + IndexType withinSliceStride, + int* smem, + T* topKValue) { + // Indices are limited to integer fp precision, so counts can fit in + // int32, regardless of IndexType + int counts[RADIX_SIZE]; + RadixType desired = 0; + RadixType desiredMask = 0; + + IndexType kToFind = k; + +#pragma unroll + for (int digitPos = sizeof(T) * 8 - RADIX_BITS; digitPos >= 0; + digitPos -= RADIX_BITS) { + countRadixUsingMask( + data, + counts, + smem, + desired, + desiredMask, + digitPos, + sliceSize, + withinSliceStride); + + auto found_unique = [&](int i, int count) -> bool { + if (count == 1 && kToFind == 1) { + desired = + Bitfield::setBitfield(desired, i, digitPos, RADIX_BITS); + desiredMask = Bitfield::setBitfield( + desiredMask, RADIX_MASK, digitPos, RADIX_BITS); + *topKValue = + findPattern(data, + reinterpret_cast(smem), + sliceSize, + withinSliceStride, + desired, + desiredMask); + return true; + } + return false; + }; + auto found_non_unique = [&](int i, int count) -> bool { + if (count >= kToFind) { + desired = + Bitfield::setBitfield(desired, i, digitPos, RADIX_BITS); + desiredMask = Bitfield::setBitfield( + desiredMask, RADIX_MASK, digitPos, RADIX_BITS); + return true; + } + kToFind -= count; + return false; + }; + + if (largest) { +#pragma unroll + for (int i = RADIX_SIZE - 1; i >= 0; --i) { + int count = counts[i]; + if (found_unique(i, count)) return; + if (found_non_unique(i, count)) break; + } + } else { +#pragma unroll + for (int i = 0; i < RADIX_SIZE; ++i) { + int count = counts[i]; + if (found_unique(i, count)) return; + if (found_non_unique(i, count)) break; + } + } + } + *topKValue = TopKTypeConfig::deconvert(desired); +} + +} // namespace radix_select + +// ============================================================================ +// CUDA_KERNEL_LOOP_TYPE macro +// ============================================================================ +#define TOPK_CUDA_KERNEL_LOOP_TYPE(i, n, index_type) \ + for (index_type i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +// ============================================================================ +// CUB_SUPPORTS_SCAN_BY_KEY check +// CUB >= 1.15 supports DeviceScan::InclusiveSumByKey +// ============================================================================ +#ifndef __HIPCC__ +// CUDA path: check CUB version +#if defined(CUB_VERSION) && CUB_VERSION >= 101500 +#define TOPK_CUB_SUPPORTS_SCAN_BY_KEY() 1 +#else +// Try to detect based on CUDA version (CUDA 11.6+ bundles CUB >= 1.15) +#if CUDART_VERSION >= 11060 +#define TOPK_CUB_SUPPORTS_SCAN_BY_KEY() 1 +#else +#define TOPK_CUB_SUPPORTS_SCAN_BY_KEY() 0 +#endif +#endif +#else +// HIP/ROCm path +#define TOPK_CUB_SUPPORTS_SCAN_BY_KEY() 0 +#endif + +} // namespace topk_detail + +// ============================================================================ +// Main TopK implementation +// ============================================================================ + +namespace topk_impl { + +using namespace topk_detail; // NOLINT + +// getTensorInfo: builds TensorInfo 
from DenseTensor +template +TensorInfo getTensorInfo(const phi::DenseTensor& tensor) { + TensorInfo info; + info.data = reinterpret_cast(const_cast(tensor.data())); + info.dims = tensor.dims().size(); + for (int i = 0; i < info.dims; i++) { + info.sizes[i] = tensor.dims()[i]; + info.strides[i] = tensor.strides()[i]; + } + return info; +} + +// SegmentOffsetIter for sorted output - must be at namespace scope for CUDA +struct SegmentOffsetIter { + int64_t k; + __host__ __device__ __forceinline__ int64_t operator()(int64_t idx) const { + return idx * k; + } +}; + +template +void sortKeyValueInplace(const Context& dev_ctx, + phi::DenseTensor* out, + phi::DenseTensor* indices, + int axis, + bool largest) { + const auto& out_dims = out->dims(); + int dim = axis; + int64_t sliceSize = out_dims[dim]; + int64_t numSlices = out->numel() / sliceSize; + auto stream = dev_ctx.stream(); + + if (sliceSize <= 1) return; + + auto keyInfo = getTensorInfo(*out); + auto valueInfo = getTensorInfo(*indices); + + auto strideKey = keyInfo.strides[dim]; + keyInfo.sizes[dim] = 1; + int collapseKeyDim = keyInfo.collapseDims(dim); + keyInfo.strides[collapseKeyDim] = strideKey; + + auto strideValue = valueInfo.strides[dim]; + valueInfo.sizes[dim] = 1; + int collapseValueDim = valueInfo.collapseDims(dim); + valueInfo.strides[collapseValueDim] = strideValue; + + // Three-tier sort dispatch: + // 1. sliceSize <= 32: Bitonic Sort (unstable, fast, no extra memory) + // 2. sliceSize <= 128: WarpMergeSort (CUB, one slice per warp) + // 3. sliceSize <= 4096: BlockRadixSort (CUB, one slice per block) + // Dispatch on the actual number of collapsed dims (keyInfo.dims), + // NOT on collapseKeyDim (the remapped excluded-dim index). + // When the excluded dim is in the middle (e.g. dim=1 of a 3-D tensor), + // collapseKeyDim==1 but keyInfo.dims==3; using DIM=1 would make + // IndexToOffset ignore the trailing dimensions, producing wrong offsets. 
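+  // Concrete illustration: for an `out` of shape [8, 100, 16] sorted along
+  // axis 1 (strides [1600, 16, 1]), sizes[1] is set to 1 and the outer and
+  // inner dims cannot merge across the excluded axis, so keyInfo.dims stays 3
+  // and the DIM == 3 path is taken; IndexToOffset then maps each of the
+  // 8 * 16 = 128 slice ids through both the stride-1 inner dim and the
+  // stride-1600 outer dim to find that slice's start offset.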
+ +#define TOPK_SORT_DIM_DISPATCH(LAUNCH_FUNC) \ + if (keyInfo.dims == 1) { \ + LAUNCH_FUNC(1); \ + } else if (keyInfo.dims == 2) { \ + LAUNCH_FUNC(2); \ + } else if (keyInfo.dims == 3) { \ + LAUNCH_FUNC(3); \ + } else { \ + LAUNCH_FUNC(-1); \ + } + + if (sliceSize <= 32) { + // Bitonic sort (unstable) +#define LAUNCH_BITONIC(DIM) \ + launch_bitonic_sort(keyInfo, \ + numSlices, \ + sliceSize, \ + strideKey, \ + valueInfo, \ + strideValue, \ + largest, \ + stream) + TOPK_SORT_DIM_DISPATCH(LAUNCH_BITONIC); +#undef LAUNCH_BITONIC + } else if (sliceSize <= 128) { + // WarpMergeSort (stable, uses CUB WarpMergeSort) +#define LAUNCH_WARP(DIM) \ + launch_warp_merge_sort(keyInfo, \ + numSlices, \ + sliceSize, \ + strideKey, \ + valueInfo, \ + strideValue, \ + largest, \ + stream) + TOPK_SORT_DIM_DISPATCH(LAUNCH_WARP); +#undef LAUNCH_WARP + } else { + // BlockRadixSort (for sizes up to 4096) + bool descending = largest; +#define LAUNCH_RADIX(DIM) \ + launch_medium_radix_sort(keyInfo, \ + numSlices, \ + sliceSize, \ + strideKey, \ + valueInfo, \ + strideValue, \ + descending, \ + stream) + TOPK_SORT_DIM_DISPATCH(LAUNCH_RADIX); +#undef LAUNCH_RADIX + } + +#undef TOPK_SORT_DIM_DISPATCH +} + +namespace sbtopk { // single_block_topk + +template +__global__ void __launch_bounds__(1024) + gatherTopK(TensorInfo input, + IndexType inputSliceSize, + IndexType outputSliceSize, // aka `k` + bool largest, + IndexType numInputSlices, + IndexType inputWithinSliceStride, + TensorInfo topK, + IndexType topKWithinSliceStride, + TensorInfo indices, + IndexType indicesWithinSliceStride, + T* kthValues) { + // Indices are limited to integer fp precision, so counts can fit in + // int32, regardless of IndexType +#if defined(__HIPCC__) + __shared__ int smem[64]; +#else + __shared__ int smem[32]; // one per each warp, up to warp limit +#endif + IndexType slice = getLinearBlockId(); + if (slice >= numInputSlices) { + return; + } + + // Find the start offset for our slice + IndexType sliceStartIndex = + IndexToOffset::get(slice, input); + IndexType topKSliceStartIndex = + IndexToOffset::get(slice, topK); + IndexType indicesSliceStartIndex = + IndexToOffset::get(slice, indices); + + const T* inputSliceStart = &input.data[sliceStartIndex]; + T* topKSliceStart = &topK.data[topKSliceStartIndex]; + int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex]; + + // Find the k-th highest element in our input + T topKValue; + if (WithKthValues) { + topKValue = kthValues[slice]; + } else { + topKValue = static_cast(0); + radix_select::radixSelect::RadixType, + IndexType>(inputSliceStart, + outputSliceSize, + largest, + inputSliceSize, + inputWithinSliceStride, + smem, + &topKValue); + } + const auto topKConverted = TopKTypeConfig::convert(topKValue); + + IndexType numIterations = + topk_round_up(inputSliceSize, (IndexType)blockDim.x); + IndexType writeIndexStart = 0; + + for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) { + bool inRange = (i < inputSliceSize); + T v = inRange ? 
doLdg(&inputSliceStart[i * inputWithinSliceStride]) + : static_cast(0); + const auto convertedV = TopKTypeConfig::convert(v); + bool hasTopK; + if (largest) { + hasTopK = inRange && (convertedV > topKConverted); + } else { + hasTopK = inRange && (convertedV < topKConverted); + } + + int index; + int carry; + exclusiveBinaryPrefixScan( + smem, hasTopK, &index, &carry, AddOp()); + + if (hasTopK) { + int writeIndex = writeIndexStart + index; + assert(writeIndex < outputSliceSize); + IndexType topKOffset = writeIndex * topKWithinSliceStride; + IndexType indexOffset = writeIndex * indicesWithinSliceStride; + topKSliceStart[topKOffset] = v; + indicesSliceStart[indexOffset] = i; + } + writeIndexStart += carry; + } + + // Fill in the rest with actual == top-K values. + assert(outputSliceSize >= writeIndexStart); + IndexType topKRemaining = (outputSliceSize - writeIndexStart); + + for (IndexType i = threadIdx.x; i < numIterations; i += blockDim.x) { + bool inRange = (i < inputSliceSize); + T v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) + : static_cast(0); + const auto convertedV = TopKTypeConfig::convert(v); + bool hasTopK = inRange && (convertedV == topKConverted); + + int index; + int carry; + exclusiveBinaryPrefixScan( + smem, hasTopK, &index, &carry, AddOp()); + + if (hasTopK && index < topKRemaining) { + int writeIndex = writeIndexStart + index; + assert(writeIndex < outputSliceSize); + IndexType topKOffset = writeIndex * topKWithinSliceStride; + IndexType indexOffset = writeIndex * indicesWithinSliceStride; + topKSliceStart[topKOffset] = v; + indicesSliceStart[indexOffset] = i; + } + + if (carry >= topKRemaining) { + break; + } + topKRemaining -= carry; + writeIndexStart += carry; + } +} + +template +void launch(TensorInfo input, + IndexType inputSliceSize, + IndexType outputSliceSize, + bool largest, + IndexType numInputSlices, + IndexType inputWithinSliceStride, + TensorInfo topK, + IndexType topKWithinSliceStride, + TensorInfo indices, + IndexType indicesWithinSliceStride, + gpuStream_t stream) { + dim3 grid; + bool ok = getGridFromTiles(numInputSlices, &grid); + assert(ok); + (void)ok; + int warp_size = TOPK_WARP_SIZE; + dim3 block( + std::min(topk_ceil_div((int64_t)inputSliceSize, (int64_t)warp_size) * + (int64_t)warp_size, + (int64_t)1024)); + gatherTopK + <<>>(input, + inputSliceSize, + outputSliceSize, + largest, + numInputSlices, + inputWithinSliceStride, + topK, + topKWithinSliceStride, + indices, + indicesWithinSliceStride, + nullptr); +} + +} // namespace sbtopk + +namespace mbtopk { // multi_block_topk + +constexpr int BLOCK_THREADS = 256; +constexpr int RADIX_BITS = 8; +constexpr int RADIX_DIGITS = 1 << RADIX_BITS; // 256 +constexpr int RADIX_MASK = (RADIX_DIGITS - 1); +static_assert( + RADIX_DIGITS <= BLOCK_THREADS, + "radixFindKthValues kernel requires RADIX_DIGITS <= BLOCK_THREADS"); +constexpr int MIN_ITEMS_PER_THREAD = 4; +constexpr int MAX_ITEMS_PER_THREAD = 64; + +template +__global__ void fill(T* x, T value, IndexType size) { + IndexType idx = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); + for (IndexType i = idx; i < size; i += static_cast(gridDim.x) * + static_cast(blockDim.x)) { + x[i] = value; + } +} + +template +__global__ void __launch_bounds__(BLOCK_THREADS) + radixFindKthValues(TensorInfo input, + uint32_t slice_size, + uint32_t* ks_to_find, + uint32_t num_slices, + IndexType withinSliceStride, + int current_bit, + int items_per_thread, + uint32_t blocks_per_slice, + Bitwise desiredMask, + Bitwise* desires, 
+ int16_t* counts) { + int items_per_block = items_per_thread * BLOCK_THREADS; + int tidx = threadIdx.x; + uint32_t block_idx = getLinearBlockId(); + uint32_t slice_idx = block_idx / blocks_per_slice; + uint32_t blk_idx_in_slice = block_idx % blocks_per_slice; + if (slice_idx >= num_slices) { + return; + } + + Bitwise desired = desires[slice_idx]; + IndexType slice_start_index = + IndexToOffset::get(slice_idx, input); + const T* data = &input.data[slice_start_index]; + + static_assert(MAX_ITEMS_PER_THREAD * BLOCK_THREADS < + std::numeric_limits::max(), + "blockwise counter too large"); + union __align__(16) TempStorage { + uint32_t digit_counters[RADIX_DIGITS]; + }; + __shared__ TempStorage temp_storage; + + if (tidx < RADIX_DIGITS) { + temp_storage.digit_counters[tidx] = 0; + } + __syncthreads(); + + items_per_thread = + (blk_idx_in_slice + 1 < blocks_per_slice) + ? items_per_thread + : topk_ceil_div( + (int64_t)(slice_size - blk_idx_in_slice * items_per_block), + (int64_t)BLOCK_THREADS); + + for (int i = 0; i < items_per_thread; ++i) { + IndexType idx = + blk_idx_in_slice * items_per_block + i * BLOCK_THREADS + tidx; + if (idx < slice_size) { + idx *= withinSliceStride; + Bitwise val = TopKTypeConfig::convert(doLdg(&data[idx])); + bool has_val = ((val & desiredMask) == (desired & desiredMask)); + Bitwise digit = + Bitfield::getBitfield(val, current_bit, RADIX_BITS); + if (has_val) { + atomicAdd(&temp_storage.digit_counters[digit], 1); + } + } + } + + __syncthreads(); + + static_assert(RADIX_DIGITS <= BLOCK_THREADS, + "this kernel requires RADIX_DIGITS <= BLOCK_THREADS"); + uint32_t digit_count = 0; + if (tidx < RADIX_DIGITS) { + digit_count = temp_storage.digit_counters[tidx]; + } + + if (tidx < RADIX_DIGITS) { + counts[block_idx * RADIX_DIGITS + tidx] = digit_count; + } +} + +template +__global__ void __launch_bounds__(RADIX_DIGITS) + computeBlockwiseWithinKCounts(Bitwise* desires_in, + int16_t* counts, + uint32_t* ks_to_find_in, + uint32_t blocks_per_slice, + int current_bit, + bool largest, + uint32_t* withinKCounts, + T* kthValues, + uint32_t* ks_to_find_out, + Bitwise* desires_out, + uint32_t num_blocks) { + int tidx = threadIdx.x; + uint32_t block_idx = getLinearBlockId(); + uint32_t slice_idx = block_idx / blocks_per_slice; + + if (block_idx >= num_blocks) { + return; + } + + typedef cub::BlockScan BlockScan; + union __align__(16) TempStorage { + uint32_t digit_count_cumsum[RADIX_DIGITS]; + typename BlockScan::TempStorage scan_storage; + }; + __shared__ TempStorage temp_storage; + + uint32_t digit_count = 0; + if (tidx < RADIX_DIGITS) { + for (uint32_t blk = 0; blk < blocks_per_slice; ++blk) { + digit_count += + counts[(slice_idx * blocks_per_slice + blk) * RADIX_DIGITS + tidx]; + } + } + + uint32_t digit_count_cumsum; + BlockScan(temp_storage.scan_storage) + .InclusiveSum(digit_count, digit_count_cumsum); + __syncthreads(); + if (tidx < RADIX_DIGITS) { + temp_storage.digit_count_cumsum[tidx] = digit_count_cumsum; + } + __syncthreads(); + + __shared__ Bitwise desired; + uint32_t k_to_find = ks_to_find_in[slice_idx]; + + if (tidx < RADIX_DIGITS) { + uint32_t digit_count_cumsum_left = + (tidx == 0) ? 
0 : temp_storage.digit_count_cumsum[tidx - 1]; + + if (digit_count_cumsum_left < k_to_find && + k_to_find <= digit_count_cumsum) { + desired = desires_in[slice_idx]; + desired = Bitfield::setBitfield( + desired, tidx, current_bit, RADIX_BITS); + if (block_idx == slice_idx * blocks_per_slice) { + desires_out[slice_idx] = desired; + if (current_bit > 0) { + ks_to_find_out[slice_idx] = k_to_find - digit_count_cumsum_left; + } else { + kthValues[slice_idx] = TopKTypeConfig::deconvert(desired); + } + } + } + } + __syncthreads(); + +#if !TOPK_CUB_SUPPORTS_SCAN_BY_KEY() + return; +#endif + + Bitwise desired_digit = + Bitfield::getBitfield(desired, current_bit, RADIX_BITS); + + bool warp_is_active, thread_is_active; + int warp = tidx / TOPK_WARP_SIZE; + if (largest) { + int end_of_warp = warp * TOPK_WARP_SIZE + TOPK_WARP_SIZE - 1; + warp_is_active = end_of_warp > static_cast(desired_digit); + thread_is_active = tidx > static_cast(desired_digit); + } else { + int start_of_warp = warp * TOPK_WARP_SIZE; + warp_is_active = start_of_warp < static_cast(desired_digit); + thread_is_active = tidx < static_cast(desired_digit); + } + uint32_t count = 0; + if (warp_is_active) { + if (thread_is_active) { + count = doLdg(counts + block_idx * RADIX_DIGITS + tidx); + } + for (int offset = TOPK_WARP_SIZE / 2; offset > 0; offset /= 2) { + count += TOPK_WARP_SHFL_DOWN(count, offset); + } + } + + constexpr int num_warps = RADIX_DIGITS / TOPK_WARP_SIZE; + __shared__ uint32_t warp_counts[num_warps]; + if (tidx % TOPK_WARP_SIZE == 0) { + warp_counts[warp] = count; + } + __syncthreads(); +#ifdef __HIPCC__ + assert(RADIX_DIGITS < TOPK_WARP_SIZE * TOPK_WARP_SIZE); +#else + static_assert(RADIX_DIGITS < TOPK_WARP_SIZE * TOPK_WARP_SIZE, + "Assuming only 1 warp is needed for final reduction"); +#endif + if (warp != 0) { + return; + } + count = 0; + if (tidx < num_warps) { + count = warp_counts[tidx]; + } + for (int offset = num_warps / 2; offset > 0; offset /= 2) { + count += TOPK_WARP_SHFL_DOWN(count, offset); + } + if (tidx == 0) { + withinKCounts[block_idx] += count; + } +} + +#if TOPK_CUB_SUPPORTS_SCAN_BY_KEY() +template +__global__ void computeBlockwiseKthCounts(Bitwise* desires, + int16_t* counts, + uint32_t num_blocks, + uint32_t blocks_per_slice, + uint32_t* kthCounts) { + TOPK_CUDA_KERNEL_LOOP_TYPE(idx, num_blocks, uint32_t) { + uint32_t slice_idx = idx / blocks_per_slice; + Bitwise desired = doLdg(desires + slice_idx); + Bitwise desired_digit = + Bitfield::getBitfield(desired, 0, RADIX_BITS); + kthCounts[idx] = doLdg(counts + idx * RADIX_DIGITS + desired_digit); + } +} + +template +__global__ void __launch_bounds__(BLOCK_THREADS) + gatherTopK(TensorInfo input, + IndexType inputSliceSize, + IndexType outputSliceSize, + bool largest, + uint32_t numInputSlices, + IndexType inputWithinSliceStride, + TensorInfo topK, + IndexType topKWithinSliceStride, + TensorInfo indices, + IndexType indicesWithinSliceStride, + uint32_t items_per_thread, + uint32_t blocks_per_slice, + T* kthValues, + uint32_t* withinKCounts, + uint32_t* kthCounts, + uint32_t num_blocks) { + uint32_t items_per_block = items_per_thread * BLOCK_THREADS; + uint32_t tidx = threadIdx.x; + uint32_t block_idx = getLinearBlockId(); + + if (block_idx >= num_blocks) { + return; + } + + uint32_t slice_idx = block_idx / blocks_per_slice; + uint32_t blk_idx_in_slice = block_idx % blocks_per_slice; + + items_per_thread = + (blk_idx_in_slice + 1 < blocks_per_slice) + ? 
items_per_thread + : topk_ceil_div( + (int64_t)(inputSliceSize - blk_idx_in_slice * items_per_block), + (int64_t)BLOCK_THREADS); + + IndexType sliceStartIndex = + IndexToOffset::get(slice_idx, input); + IndexType topKSliceStartIndex = + IndexToOffset::get(slice_idx, topK); + IndexType indicesSliceStartIndex = + IndexToOffset::get(slice_idx, indices); + + const T* inputSliceStart = &input.data[sliceStartIndex]; + T* topKSliceStart = &topK.data[topKSliceStartIndex]; + int64_t* indicesSliceStart = &indices.data[indicesSliceStartIndex]; + + T kthValue = kthValues[slice_idx]; + const auto kthValueConverted = TopKTypeConfig::convert(kthValue); + + uint32_t startWithinK = 0; + if (blk_idx_in_slice > 0) { + startWithinK = withinKCounts[block_idx - 1]; + } + uint32_t startKth = + withinKCounts[slice_idx * blocks_per_slice + blocks_per_slice - 1]; + if (blk_idx_in_slice > 0) { + startKth += kthCounts[block_idx - 1]; + } + + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + for (uint32_t i = 0; i < items_per_thread; ++i) { + IndexType idx = + blk_idx_in_slice * items_per_block + i * BLOCK_THREADS + tidx; + T val; + int withinK = 0; + int kth = 0; + if (idx < inputSliceSize) { + val = doLdg(inputSliceStart + idx * inputWithinSliceStride); + const auto valConverted = TopKTypeConfig::convert(val); + withinK = (largest ? valConverted > kthValueConverted + : valConverted < kthValueConverted); + kth = (valConverted == kthValueConverted); + } + + uint32_t withinKIndex; + uint32_t numWithinK; + BlockScan(temp_storage).ExclusiveSum(withinK, withinKIndex, numWithinK); + __syncthreads(); + if (withinK) { + uint32_t offset = withinKIndex + startWithinK; + topKSliceStart[offset * topKWithinSliceStride] = val; + indicesSliceStart[offset * indicesWithinSliceStride] = idx; + } + startWithinK += numWithinK; + + if (startKth < outputSliceSize) { + uint32_t kthIndex; + uint32_t numKth; + BlockScan(temp_storage).ExclusiveSum(kth, kthIndex, numKth); + __syncthreads(); + if (kth) { + uint32_t offset = kthIndex + startKth; + if (offset < outputSliceSize) { + topKSliceStart[offset * topKWithinSliceStride] = val; + indicesSliceStart[offset * indicesWithinSliceStride] = idx; + } + } + startKth += numKth; + } + } +} +#endif // TOPK_CUB_SUPPORTS_SCAN_BY_KEY + +// get_items_per_thread: compute optimal items per thread based on GPU occupancy +int get_items_per_thread(uint64_t num_slices, + uint64_t slice_size, + int device_id) { + constexpr int REGS_PER_THREAD = 40; + constexpr int REGS_PER_BLOCK = REGS_PER_THREAD * BLOCK_THREADS; + const auto& prop = phi::backends::gpu::GetDeviceProperties(device_id); + int mpc = prop.multiProcessorCount; +#ifdef PADDLE_WITH_HIP + // HIP/DCU: hipDeviceProp_t lacks regsPerMultiprocessor and + // maxBlocksPerMultiProcessor. Use conservative defaults: + // 65536 registers per CU is typical for AMD GCN/CDNA architectures. + // maxThreadsPerMultiProcessor / BLOCK_THREADS as blocks_per_mp estimate. 
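+  //
+  // Worked example with assumed numbers (illustrative only): on a device with
+  // 65536 registers per CU, 2048 max threads per CU and, say, BLOCK_THREADS =
+  // 256, REGS_PER_BLOCK = 40 * 256 = 10240, so blocks_per_mp =
+  // min(65536 / 10240, 2048 / 256) = min(6, 8) = 6; with 64 slices of 100000
+  // elements on 60 CUs this yields ceil(6400000 / (60 * 6 * 256)) = 70 items
+  // per thread before clamping to [MIN_ITEMS_PER_THREAD, MAX_ITEMS_PER_THREAD].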
+ int regs_per_mp = 65536; + int max_blocks_per_mp = prop.maxThreadsPerMultiProcessor / BLOCK_THREADS; +#else + int regs_per_mp = prop.regsPerMultiprocessor; + int max_blocks_per_mp = prop.maxBlocksPerMultiProcessor; +#endif + int blocks_per_mp = std::min(regs_per_mp / REGS_PER_BLOCK, max_blocks_per_mp); + int64_t items_per_thread = + topk_ceil_div((int64_t)(slice_size * num_slices), + (int64_t)(mpc * blocks_per_mp * BLOCK_THREADS)); + items_per_thread = std::max( + MIN_ITEMS_PER_THREAD, + std::min(static_cast(items_per_thread), MAX_ITEMS_PER_THREAD)); + return items_per_thread; +} + +class BlockIdxToKey { + uint32_t blocks_per_slice; + + public: + explicit BlockIdxToKey(uint32_t blocks_per_slice) + : blocks_per_slice(blocks_per_slice) {} + __device__ __forceinline__ uint32_t operator()(uint32_t blk) const { + return blk / blocks_per_slice; + } +}; + +template +void launch(TensorInfo input, + IndexType inputSliceSize, + IndexType outputSliceSize, + bool largest, + uint32_t numInputSlices, + IndexType inputWithinSliceStride, + TensorInfo topK, + IndexType topKWithinSliceStride, + TensorInfo indices, + IndexType indicesWithinSliceStride, + gpuStream_t stream, + int device_id, + const phi::Place& place) { + int items_per_thread = + get_items_per_thread(numInputSlices, inputSliceSize, device_id); + int items_per_block = items_per_thread * BLOCK_THREADS; + + using Bitwise = typename TopKTypeConfig::RadixType; + uint32_t blocks_per_slice = + topk_ceil_div((int64_t)inputSliceSize, (int64_t)items_per_block); + uint32_t num_blocks = numInputSlices * blocks_per_slice; + + // Temporary storage allocation using phi::memory_utils + auto phi_stream = phi::Stream(reinterpret_cast(stream)); + + auto kthValues_buffer = + phi::memory_utils::Alloc(place, numInputSlices * sizeof(T), phi_stream); + T* kthValues = reinterpret_cast(kthValues_buffer->ptr()); + + auto semaphores_buffer = phi::memory_utils::Alloc( + place, numInputSlices * sizeof(uint32_t), phi_stream); + uint32_t* semaphores = reinterpret_cast(semaphores_buffer->ptr()); +#ifdef PADDLE_WITH_HIP + hipMemsetAsync(semaphores, 0, numInputSlices * sizeof(uint32_t), stream); +#else + cudaMemsetAsync(semaphores, 0, numInputSlices * sizeof(uint32_t), stream); +#endif + + auto ks_to_find_buffer = phi::memory_utils::Alloc( + place, 2 * numInputSlices * sizeof(uint32_t), phi_stream); + uint32_t* ks_to_find = reinterpret_cast(ks_to_find_buffer->ptr()); + uint32_t k_to_find = + largest ? 
inputSliceSize - outputSliceSize + 1 : outputSliceSize; + fill + <<>>(ks_to_find, k_to_find, numInputSlices); + + auto desired_buffer = phi::memory_utils::Alloc( + place, 2 * numInputSlices * sizeof(Bitwise), phi_stream); + Bitwise* desired = reinterpret_cast(desired_buffer->ptr()); + + auto counts_buffer = phi::memory_utils::Alloc( + place, num_blocks * RADIX_DIGITS * sizeof(int16_t), phi_stream); + int16_t* counts = reinterpret_cast(counts_buffer->ptr()); + static_assert(MAX_ITEMS_PER_THREAD * BLOCK_THREADS < + std::numeric_limits::max(), + "blockwise counter too large"); + +#if TOPK_CUB_SUPPORTS_SCAN_BY_KEY() + auto withinKCounts_buffer = phi::memory_utils::Alloc( + place, num_blocks * sizeof(uint32_t), phi_stream); + uint32_t* withinKCounts = + reinterpret_cast(withinKCounts_buffer->ptr()); +#ifdef PADDLE_WITH_HIP + hipMemsetAsync(withinKCounts, 0, num_blocks * sizeof(uint32_t), stream); +#else + cudaMemsetAsync(withinKCounts, 0, num_blocks * sizeof(uint32_t), stream); +#endif + + auto kthCounts_buffer = phi::memory_utils::Alloc( + place, num_blocks * sizeof(uint32_t), phi_stream); + uint32_t* kthCounts = reinterpret_cast(kthCounts_buffer->ptr()); +#else + uint32_t* withinKCounts = nullptr; +#endif + + Bitwise desiredMask = 0; + dim3 grid; + bool ok = getGridFromTiles(num_blocks, &grid); + assert(ok); + (void)ok; + dim3 block(BLOCK_THREADS); + + uint32_t* ks_to_find_in = ks_to_find; + uint32_t* ks_to_find_out = ks_to_find + numInputSlices; + Bitwise* desired_in = desired; + Bitwise* desired_out = desired + numInputSlices; + + for (int current_bit = sizeof(T) * 8 - RADIX_BITS; current_bit >= 0; + current_bit -= RADIX_BITS) { + radixFindKthValues + <<>>(input, + inputSliceSize, + ks_to_find_in, + numInputSlices, + inputWithinSliceStride, + current_bit, + items_per_thread, + blocks_per_slice, + desiredMask, + desired_in, + counts); + + computeBlockwiseWithinKCounts + <<>>(desired_in, + counts, + ks_to_find_in, + blocks_per_slice, + current_bit, + largest, + withinKCounts, + kthValues, + ks_to_find_out, + desired_out, + num_blocks); + + auto tmp_desired = desired_in; + desired_in = desired_out; + desired_out = tmp_desired; + auto tmp_ks = ks_to_find_in; + ks_to_find_in = ks_to_find_out; + ks_to_find_out = tmp_ks; + // Host-side equivalent of Bitfield::setBitfield(desiredMask, + // RADIX_MASK, current_bit, RADIX_BITS) Cannot use Bitfield::setBitfield + // here because it's __device__-only (uses PTX asm) + { + Bitwise mask = ((Bitwise(1) << RADIX_BITS) - 1) << current_bit; + desiredMask = + (desiredMask & ~mask) | ((Bitwise(RADIX_MASK) << current_bit) & mask); + } + } + desired = desired_in; + +#if TOPK_CUB_SUPPORTS_SCAN_BY_KEY() + computeBlockwiseKthCounts + <<>>(desired, counts, num_blocks, blocks_per_slice, kthCounts); + + // Use cub::DeviceScan::InclusiveSumByKey + using counting_iter_t = cub::CountingInputIterator; + using slice_idx_iter_t = + cub::TransformInputIterator; + slice_idx_iter_t slice_idx_iter(counting_iter_t(0), + BlockIdxToKey(blocks_per_slice)); + + // InclusiveSumByKey for withinKCounts + { + size_t temp_storage_bytes = 0; + cub::DeviceScan::InclusiveSumByKey(nullptr, + temp_storage_bytes, + slice_idx_iter, + withinKCounts, + withinKCounts, + num_blocks, + cub::Equality(), + stream); + auto temp_buf = + phi::memory_utils::Alloc(place, temp_storage_bytes, phi_stream); + cub::DeviceScan::InclusiveSumByKey(temp_buf->ptr(), + temp_storage_bytes, + slice_idx_iter, + withinKCounts, + withinKCounts, + num_blocks, + cub::Equality(), + stream); + } + // InclusiveSumByKey for 
kthCounts + { + size_t temp_storage_bytes = 0; + cub::DeviceScan::InclusiveSumByKey(nullptr, + temp_storage_bytes, + slice_idx_iter, + kthCounts, + kthCounts, + num_blocks, + cub::Equality(), + stream); + auto temp_buf = + phi::memory_utils::Alloc(place, temp_storage_bytes, phi_stream); + cub::DeviceScan::InclusiveSumByKey(temp_buf->ptr(), + temp_storage_bytes, + slice_idx_iter, + kthCounts, + kthCounts, + num_blocks, + cub::Equality(), + stream); + } + + gatherTopK + <<>>(input, + inputSliceSize, + outputSliceSize, + largest, + numInputSlices, + inputWithinSliceStride, + topK, + topKWithinSliceStride, + indices, + indicesWithinSliceStride, + items_per_thread, + blocks_per_slice, + kthValues, + withinKCounts, + kthCounts, + num_blocks); +#else + // Fallback: use single-block gatherTopK with kthValues + { + dim3 grid2; + bool ok2 = getGridFromTiles(numInputSlices, &grid2); + assert(ok2); + (void)ok2; + int warp_size = TOPK_WARP_SIZE; + dim3 block2( + std::min(topk_ceil_div((int64_t)inputSliceSize, (int64_t)warp_size) * + (int64_t)warp_size, + (int64_t)1024)); + sbtopk::gatherTopK + <<>>(input, + inputSliceSize, + outputSliceSize, + largest, + numInputSlices, + inputWithinSliceStride, + topK, + topKWithinSliceStride, + indices, + indicesWithinSliceStride, + kthValues); + } +#endif +} + +} // namespace mbtopk + +bool should_use_multiblock(int64_t num_slices, int64_t slice_size) { + if (num_slices > std::numeric_limits::max() || + slice_size > std::numeric_limits::max()) + return false; +#if TOPK_CUB_SUPPORTS_SCAN_BY_KEY() + return (num_slices <= 20 && slice_size >= 20000) || + (num_slices > 20 && num_slices <= 40 && slice_size >= 10000) || + (num_slices > 40 && num_slices <= 80 && slice_size >= 8000) || + (num_slices > 80 && num_slices < 200 && slice_size >= 5000) || + (num_slices >= 200 && num_slices < 800 && slice_size >= 3000) || + (num_slices >= 800 && num_slices <= 4000 && slice_size >= 800) || + (num_slices > 4000 && slice_size >= 400); +#else + return (num_slices <= 400 && slice_size >= 5000) || + (num_slices > 400 && num_slices < 4000 && slice_size >= 1000) || + (num_slices >= 4000 && slice_size >= 300); +#endif +} + +// canUse32BitIndexMath: check if tensor indexing fits in 32-bit integers +bool canUse32BitIndexMath( + const phi::DenseTensor& t, + int64_t max_elem = std::numeric_limits::max()) { + int64_t elements = t.numel(); + if (elements >= max_elem) { + return false; + } + if (elements == 0) { + return max_elem > 0; + } + + int64_t offset = 0; + int64_t linearId = elements - 1; + + for (int i = t.dims().size() - 1; i >= 0; --i) { + int64_t curDimIndex = linearId % t.dims()[i]; + int64_t curDimOffset = curDimIndex * t.strides()[i]; + offset += curDimOffset; + linearId /= t.dims()[i]; + } + + if (offset >= max_elem) { + return false; + } + + return true; +} + +} // namespace topk_impl +#endif // PADDLE_PHI_KERNELS_FUNCS_TOP_K_CUDA_KERNEL_H_ diff --git a/paddle/phi/kernels/gpu/top_k_kernel.cu b/paddle/phi/kernels/gpu/top_k_kernel.cu index 0f352315bdbe84..48b1b2bdff671f 100644 --- a/paddle/phi/kernels/gpu/top_k_kernel.cu +++ b/paddle/phi/kernels/gpu/top_k_kernel.cu @@ -23,6 +23,9 @@ #include "paddle/phi/kernels/funcs/gather.cu.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/top_k_function_cuda.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/phi/kernels/funcs/top_k_cuda_kernel.h" +#endif namespace phi { #define FIXED_BLOCK_DIM_BASE(dim, ...) 
\
@@ -368,8 +371,259 @@ void TopkV1Kernel(const Context& dev_ctx,
                    DenseTensor* indices) {
   TopkKernel<T, Context>(dev_ctx, x, k_scalar, -1, true, true, out, indices);
 }
+
+#ifdef PADDLE_WITH_CUDA
+template <typename T, typename Context>
+void TopkKernelCuda(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const Scalar& k_scalar,
+                    int axis,
+                    bool largest,
+                    bool sorted,
+                    DenseTensor* out,
+                    DenseTensor* indices) {
+  const auto& in_dims = x.dims();
+
+  // Handle empty output (e.g. when k comes from a tensor, dims may contain -1)
+  if (out && out->numel() == 0) {
+    dev_ctx.template Alloc<T>(out);
+    dev_ctx.template Alloc<int64_t>(indices);
+    return;
+  }
+
+  // 0-d input tensor
+  if (in_dims.size() == 0) {
+    Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
+    dev_ctx.template Alloc<int64_t>(indices);
+    funcs::set_constant(dev_ctx, indices, static_cast<int64_t>(0));
+    return;
+  }
+
+  if (axis < 0) axis += in_dims.size();
+  int k = k_scalar.to<int>();
+
+  // For k = 1, delegate to the existing TopkKernel and return.
+  if (k == 1) {
+    TopkKernel<T, Context>(
+        dev_ctx, x, k_scalar, axis, largest, sorted, out, indices);
+    return;
+  }
+
+  // Handle k from tensor: output dims may contain -1, resize before Alloc
+  if (k_scalar.FromTensor()) {
+    DDim out_dims = out->dims();
+    out_dims[axis] = k;
+    out->Resize(out_dims);
+    indices->Resize(out_dims);
+  }
+
+  // Handle empty input
+  if (x.numel() == 0) {
+    phi::Full<T, Context>(
+        dev_ctx, phi::vectorize(out->dims()), static_cast<T>(NAN), out);
+    phi::Full<int64_t, Context>(dev_ctx,
+                                phi::vectorize(indices->dims()),
+                                static_cast<int64_t>(0),
+                                indices);
+    return;
+  }
+
+  // Now safe to allocate output memory
+  T* output_data = dev_ctx.template Alloc<T>(out);
+  int64_t* indices_data = dev_ctx.template Alloc<int64_t>(indices);
+
+  phi::DenseTensor input_contiguous;
+
+  if (x.meta().is_contiguous()) {
+    input_contiguous = x;
+  } else {
+    input_contiguous.Resize(x.dims());
+    dev_ctx.template Alloc<T>(&input_contiguous);
+    phi::Copy(
+        dev_ctx, x, dev_ctx.GetPlace(), false, &input_contiguous);
+  }
+
+  int64_t sliceSize = in_dims.size() == 0 ? 1 : in_dims[axis];
+  int dim = axis;
+
+  auto stream = dev_ctx.stream();
+  int device_id = dev_ctx.GetPlace().GetDeviceId();
+  auto place = dev_ctx.GetPlace();
+
+  // Inner kernel-launch helper macros (single-block and multi-block paths).
+#define TOPK_RUN_K(INDEX_T, DIM, LAUNCH_FUNCTION_NAME)               \
+  LAUNCH_FUNCTION_NAME<T, INDEX_T, DIM>(                             \
+      inputInfo,                                                     \
+      static_cast<INDEX_T>(sliceSize),                               \
+      static_cast<INDEX_T>(k),                                       \
+      largest,                                                       \
+      static_cast<uint32_t>(numInputSlices),                         \
+      static_cast<INDEX_T>(inputInfo.strides[collapseInputDim]),     \
+      topKInfo,                                                      \
+      static_cast<INDEX_T>(topKInfo.strides[collapseTopKDim]),       \
+      indicesInfo,                                                   \
+      static_cast<INDEX_T>(indicesInfo.strides[collapseIndicesDim]), \
+      stream)
+
+#define TOPK_RUN_K_MB(INDEX_T, DIM)                                  \
+  topk_impl::mbtopk::launch<T, INDEX_T, DIM>(                        \
+      inputInfo,                                                     \
+      static_cast<INDEX_T>(sliceSize),                               \
+      static_cast<INDEX_T>(k),                                       \
+      largest,                                                       \
+      static_cast<uint32_t>(numInputSlices),                         \
+      static_cast<INDEX_T>(inputInfo.strides[collapseInputDim]),     \
+      topKInfo,                                                      \
+      static_cast<INDEX_T>(topKInfo.strides[collapseTopKDim]),       \
+      indicesInfo,                                                   \
+      static_cast<INDEX_T>(indicesInfo.strides[collapseIndicesDim]), \
+      stream,                                                        \
+      device_id,                                                     \
+      place)
+
+#define TOPK_RUN_MB(INDEX_T, DIM)                                    \
+  if (topk_impl::should_use_multiblock(numInputSlices, sliceSize)) { \
+    TOPK_RUN_K_MB(INDEX_T, DIM);                                     \
+  } else {                                                           \
+    TOPK_RUN_K(INDEX_T, DIM, topk_impl::sbtopk::launch);             \
+  }
+
+#define TOPK_RUN_DIM(INDEX_T) \
+  if (allDims == 1) {         \
+    TOPK_RUN_MB(INDEX_T, 1);  \
+  } else if (allDims == 2) {  \
+    TOPK_RUN_MB(INDEX_T, 2);  \
+  } else if (allDims == 3) {  \
+    TOPK_RUN_MB(INDEX_T, 3);  \
+  } else {                    \
+    TOPK_RUN_MB(INDEX_T, -1); \
+  }
+
+  // RUN_T: Build TensorInfo, collapse dims, and launch — all parameterized
+  // by INDEX_T.
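As a hedged aside before the macro itself: the sketch below is illustrative only; the [4, 1000] shape, the axis choice, and the variable names are assumptions, and it uses no Paddle APIs. It walks through the same slice bookkeeping with plain arrays, which is easier to read than the macro body.

#include <cstdint>
#include <cstdio>

int main() {
  // Assumed example: contiguous input of shape [4, 1000], top-k along axis 1.
  int64_t sizes[2] = {4, 1000};
  int64_t strides[2] = {1000, 1};  // row-major strides of a contiguous tensor
  const int dim = 1;               // reduced (top-k) axis

  // Stash the within-slice stride before the reduced axis is flattened,
  // mirroring what TOPK_RUN_T does before collapseDims().
  const int64_t within_slice_stride = strides[dim];
  const int64_t slice_size = sizes[dim];
  sizes[dim] = 1;  // the reduced axis contributes a single "slice position"

  // The remaining extents enumerate the independent slices.
  int64_t num_input_slices = 1;
  for (int i = 0; i < 2; ++i) num_input_slices *= sizes[i];

  // Slice s starts at offset s * strides[0] and visits slice_size elements
  // spaced within_slice_stride apart; these are the values handed to launch().
  for (int64_t s = 0; s < num_input_slices; ++s) {
    std::printf("slice %lld: start offset %lld, %lld elems, stride %lld\n",
                (long long)s, (long long)(s * strides[0]),
                (long long)slice_size, (long long)within_slice_stride);
  }
  return 0;
}

In the kernel the same quantities come from getTensorInfo plus collapseDims, and the stride of the reduced axis is stashed and restored because collapsing can clobber it.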
+#define TOPK_RUN_T(INDEX_T) \ + do { \ + auto inputInfo = \ + topk_impl::getTensorInfo(input_contiguous); \ + auto topKInfo = topk_impl::getTensorInfo(*out); \ + auto indicesInfo = topk_impl::getTensorInfo(*indices); \ + \ + /* Handle 0-d tensor: expand to 1-d */ \ + if (!in_dims.size()) { \ + inputInfo.dims = 1; \ + inputInfo.sizes[0] = 1; \ + inputInfo.strides[0] = 1; \ + topKInfo.dims = 1; \ + topKInfo.sizes[0] = 1; \ + topKInfo.strides[0] = 1; \ + indicesInfo.dims = 1; \ + indicesInfo.sizes[0] = 1; \ + indicesInfo.strides[0] = 1; \ + } \ + \ + /* Set sizes[dim] = 1 to calculate slice offsets */ \ + inputInfo.sizes[dim] = 1; \ + topKInfo.sizes[dim] = 1; \ + indicesInfo.sizes[dim] = 1; \ + \ + /* Stash stride of dim because it can be accidentally collapsed */ \ + auto strideTopK = topKInfo.strides[dim]; \ + auto strideIndices = indicesInfo.strides[dim]; \ + \ + /* Collapse dims */ \ + int collapseInputDim = inputInfo.collapseDims(dim); \ + int collapseTopKDim = topKInfo.collapseDims(dim); \ + int collapseIndicesDim = indicesInfo.collapseDims(dim); \ + \ + /* Restore stride in case it was collapsed */ \ + topKInfo.strides[collapseTopKDim] = strideTopK; \ + indicesInfo.strides[collapseIndicesDim] = strideIndices; \ + \ + int64_t numInputSlices = 1; \ + for (int i = 0; i < inputInfo.dims; ++i) { \ + numInputSlices *= inputInfo.sizes[i]; \ + } \ + \ + int allDims = inputInfo.dims; \ + if (topKInfo.dims != allDims || indicesInfo.dims != allDims) { \ + allDims = -1; \ + } \ + \ + TOPK_RUN_DIM(INDEX_T); \ + } while (0) + + // Dispatch: use 32-bit indexing when all tensors qualify, else 64-bit + if (input_contiguous.numel() > 0) { + if (topk_impl::canUse32BitIndexMath(input_contiguous) && + topk_impl::canUse32BitIndexMath(*out) && + topk_impl::canUse32BitIndexMath(*indices)) { + TOPK_RUN_T(uint32_t); + } else { + TOPK_RUN_T(uint64_t); + } + } + +#undef TOPK_RUN_K +#undef TOPK_RUN_K_MB +#undef TOPK_RUN_MB +#undef TOPK_RUN_DIM +#undef TOPK_RUN_T + + // Sort the results if needed + if (sorted && k > 1 && out->numel() > 0) { + // Three-tier sort dispatch: + // k <= 32: Bitonic Sort + // k <= 128: WarpMergeSort (CUB) + // k <= 4096: BlockRadixSort (CUB) + // k > 4096: Fall back to ArgsortKernel + TakeAlongAxisKernel + if (k <= 4096) { + topk_impl::sortKeyValueInplace( + dev_ctx, out, indices, axis, largest); + } else { + phi::DenseTensor sorted_indices; + phi::DenseTensor sorted_values; + sorted_indices.Resize(indices->dims()); + sorted_values.Resize(out->dims()); + dev_ctx.template Alloc(&sorted_indices); + dev_ctx.template Alloc(&sorted_values); + + phi::ArgsortKernel(dev_ctx, + *out, + axis, + largest, + /*stable=*/true, + &sorted_values, + &sorted_indices); + + phi::DenseTensor new_indices; + new_indices.Resize(indices->dims()); + dev_ctx.template Alloc(&new_indices); + phi::TakeAlongAxisKernel( + dev_ctx, *indices, sorted_indices, axis, &new_indices); + + phi::Copy( + dev_ctx, sorted_values, dev_ctx.GetPlace(), false, out); + phi::Copy( + dev_ctx, new_indices, dev_ctx.GetPlace(), false, indices); + } + } +} +#endif } // namespace phi +#ifdef PADDLE_WITH_CUDA +PD_REGISTER_KERNEL(topk, + GPU, + ALL_LAYOUT, + phi::TopkKernelCuda, + float, + double, + int, + int64_t, + phi::float16, + phi::bfloat16) { + kernel->OutputAt(1).SetDataType(phi::DataType::INT64); +} +#else PD_REGISTER_KERNEL(topk, GPU, ALL_LAYOUT, @@ -382,6 +636,7 @@ PD_REGISTER_KERNEL(topk, phi::bfloat16) { kernel->OutputAt(1).SetDataType(phi::DataType::INT64); } +#endif PD_REGISTER_KERNEL(topk_v1, GPU,