[Performance Optimization] Rewrite GPU TopK kernel with radix-select and multi-tier sorting #78703

Merged
sneaxiy merged 1 commit into PaddlePaddle:release/3.4 from zhengshengning:cp_acc_opt_topk_3_4 on Apr 21, 2026

Conversation

@zhengshengning (Contributor)

PR Category

Operator Mechanism

PR Types

Improvements

Description

Dev PR: #78409

Does this introduce precision changes?

… and multi-tier sorting (PaddlePaddle#78409)

* [TopK] Rewrite GPU TopK kernel with radix-select and multi-tier sorting

Replace the existing GPU TopK implementation with a new radix-select-based
algorithm and a multi-tier sorting strategy for improved performance:

- Radix-select for efficient top-k selection
- Multi-block top-k (mbtopk) for large slices
- Single-block top-k (sbtopk) for smaller slices
- Three-tier sort dispatch: Bitonic Sort (k<=32), WarpMergeSort (k<=128),
  BlockRadixSort (k<=4096), ArgsortKernel fallback (k>4096); see the
  dispatch sketch after this list
- Rename old TopkKernel to TopkKernelOld for reference
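
For orientation, a minimal sketch of how the two dispatch decisions could be expressed. The thresholds are the ones listed above; the enum, function names, and the slice-size cut-over are illustrative assumptions, not symbols from this PR:

```cpp
#include <cstdint>

// Tier selection for sorting the k selected elements; thresholds match
// the list above, names are illustrative placeholders.
enum class SortTier { kBitonic, kWarpMerge, kBlockRadix, kArgsortFallback };

inline SortTier PickSortTier(int k) {
  if (k <= 32) return SortTier::kBitonic;       // Bitonic Sort
  if (k <= 128) return SortTier::kWarpMerge;    // WarpMergeSort
  if (k <= 4096) return SortTier::kBlockRadix;  // BlockRadixSort
  return SortTier::kArgsortFallback;            // ArgsortKernel fallback
}

// Split between multi-block (mbtopk) and single-block (sbtopk) top-k;
// the cut-over constant is an assumed placeholder, not the PR's value.
inline bool UseMultiBlockTopK(int64_t slice_size) {
  constexpr int64_t kLargeSlice = 1 << 20;
  return slice_size >= kLargeSlice;
}
```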

* Fix doLdg duplicate definition: restore long long types with NOLINT

On LP64 Linux, int64_t is a typedef of long, not long long, so using
int64_t caused a duplicate specialization. Restore the original long
long / unsigned long long types with NOLINT to suppress cpplint, and
remove the duplicate int64_t specialization.
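
To make the clash concrete, a stand-alone illustration (DoLdg is a stand-in name for the kernel's load helper):

```cpp
#include <cstdint>

template <typename T>
struct DoLdg;  // stand-in for the kernel's load helper

template <>
struct DoLdg<long> {};  // suppose a 'long' specialization already exists

// On LP64 Linux int64_t is a typedef of long, so this would redefine
// the specialization above and fail to compile:
// template <> struct DoLdg<int64_t> {};

// 'long long' is a distinct 64-bit type, so the original
// specializations stay well-formed; NOLINT silences cpplint's
// runtime/int warning about spelling the type out:
template <>
struct DoLdg<long long> {};           // NOLINT
template <>
struct DoLdg<unsigned long long> {};  // NOLINT
```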

* Fix TopkKernel crash: defer Alloc until after FromTensor resize

When k comes from a tensor, InferMeta may set the output dims to -1,
leaving the metadata invalid. Calling Alloc before resolving the actual
k value then triggers PreconditionNotMetError.

Fix: move Alloc after FromTensor() resize, add empty-output guard and
empty-input handling to match the old kernel behavior.
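
A simplified sketch of the reordering against phi's TopkKernel signature; the body is illustrative, not the PR's actual code:

```cpp
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void TopkKernel(const Context& dev_ctx, const DenseTensor& x,
                const Scalar& k_scalar, int axis, bool largest,
                bool sorted, DenseTensor* out, DenseTensor* indices) {
  int k = k_scalar.to<int>();
  if (axis < 0) axis += x.dims().size();
  if (k_scalar.FromTensor()) {
    // InferMeta could not see the runtime k and may have left -1 in
    // the output dims; resolve them before touching output memory.
    DDim out_dims = out->dims();
    out_dims[axis] = k;
    out->Resize(out_dims);
    indices->Resize(out_dims);
  }
  if (out->numel() == 0) return;     // empty-output guard
  dev_ctx.template Alloc<T>(out);    // safe now: dims are final
  dev_ctx.template Alloc<int64_t>(indices);
  if (x.numel() == 0) return;        // empty-input handling
  // ... radix-select / sort launches elided (largest/sorted used there) ...
}

}  // namespace phi
```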

* Fix HIP/ROCm compilation errors in top_k_cuda_kernel.cu

- Bitfield: add HIP fallback using bit shifts instead of PTX asm
  (bfe.u32/u64, bfi.b32/b64 are NVIDIA PTX only); a sketch follows this list
- getLaneId/getLaneMaskLe/getLaneMaskLt: use HIP intrinsics on __HIPCC__
- CubKeyType<bfloat16>: use hip_bfloat16 instead of __nv_bfloat16
- Replace cudaStream_t with gpuStream_t (Paddle's unified type alias)
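
A sketch of the bit-shift fallback for the 32-bit extract case; the function name is illustrative:

```cpp
// CUDA keeps the bfe.u32 PTX instruction; HIP computes the same
// result with plain shifts and a mask.
__device__ __forceinline__ unsigned int GetBitfield(unsigned int val,
                                                    int pos, int len) {
#if defined(__HIPCC__)
  if (len == 0) return 0;
  if (len >= 32) return val >> pos;  // avoid UB in the mask shift below
  return (val >> pos) & ((1u << len) - 1u);
#else
  unsigned int ret;
  asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
  return ret;
#endif
}
```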

* Fix Windows build: bring gpuStream_t into anonymous namespace

gpuStream_t is defined in the phi:: namespace (via gpu_decls.h), so the
helper functions in the anonymous namespace cannot access it without
qualification. Add 'using phi::gpuStream_t;' at the top of the
anonymous namespace.
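
The pattern, in miniature (the helper below is an illustrative placeholder):

```cpp
#include "paddle/phi/backends/gpu/gpu_decls.h"  // defines phi::gpuStream_t

namespace {  // file-local helpers

using phi::gpuStream_t;  // make the unified alias visible unqualified

// Unqualified uses of gpuStream_t in this namespace now resolve on
// Windows as well.
void RecordLaunch(gpuStream_t stream) { (void)stream; }

}  // namespace
```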

* Fix DCU/HIP compilation errors in top_k_cuda_kernel.cu

- Guard __syncwarp() with #if !defined(__HIPCC__) since HIP/DCU
  does not provide this intrinsic (AMD wavefronts are lockstep)
- Replace cudaMemsetAsync with hipMemsetAsync under PADDLE_WITH_HIP
- Use conservative defaults for regsPerMultiprocessor (65536) and
  maxBlocksPerMultiProcessor on HIP since hipDeviceProp_t lacks
  these members
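
A sketch of the first two guards; WarpBarrier/ZeroAsync are illustrative wrapper names, not the PR's:

```cpp
__device__ __forceinline__ void WarpBarrier() {
#if !defined(__HIPCC__)
  __syncwarp();  // CUDA only; AMD wavefronts execute in lockstep
#endif
}

inline void ZeroAsync(void* ptr, size_t bytes, phi::gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
  hipMemsetAsync(ptr, 0, bytes, stream);
#else
  cudaMemsetAsync(ptr, 0, bytes, stream);
#endif
}
```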

* rename top_k_cuda_kernel

* fix

* fix

* fix2

* fix

* fix2

* fix
@paddle-bot (Bot) commented on Apr 17, 2026

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@wanghuancoder (Contributor) left a comment


LGTM

@sneaxiy (Collaborator) left a comment


LGTM

@sneaxiy merged commit 5966f85 into PaddlePaddle:release/3.4 on Apr 21, 2026
275 of 285 checks passed