Skip to content

Commit d63b7b9

Browse files
steamedMantou and yongshuai authored
topk prefill uplift v1.0 (ROCm#1755)
Co-authored-by: yongshuai <yongshuai@amd.com>
1 parent 7c7edd5 commit d63b7b9

7 files changed

Lines changed: 784 additions & 0 deletions

File tree

aiter/jit/optCompilerConfig.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,7 @@
10651065
"srcs": [
10661066
"f'{AITER_CSRC_DIR}/kernels/topk_per_row_kernels.cu'",
10671067
"f'{AITER_CSRC_DIR}/py_itfs_cu/asm_topk_per_row_decode.cu'",
1068+
"f'{AITER_CSRC_DIR}/py_itfs_cu/asm_topk_per_row_prefill.cu'",
10681069
"f'{AITER_CSRC_DIR}/pybind/topk_per_row_pybind.cu'"
10691070
],
10701071
"flags_extra_cc": [],

aiter/ops/topk.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,19 @@ def top_k_per_row_prefill(
209209
) -> None: ...
210210

211211

212+
@compile_ops("module_top_k_per_row")
def top_k_per_row_prefill_fast(
    logits: torch.Tensor,
    rowStarts: torch.Tensor,
    rowEnds: torch.Tensor,
    indices: torch.Tensor,
    values: Optional[torch.Tensor],
    numRows: int,
    stride0: int,
    stride1: int,
) -> None:
    """Fast asm-backed per-row top-k for the prefill path.

    Stub registered via ``compile_ops``; the actual implementation lives in
    the compiled ``module_top_k_per_row`` extension. Presumably selects the
    top-k columns of each of ``numRows`` rows of ``logits`` restricted to
    ``[rowStarts[i], rowEnds[i])`` and writes them into ``indices``.
    NOTE(review): the C++ binding currently passes a null ``values`` pointer
    to the kernel, so ``values`` appears to be ignored -- confirm before
    relying on it.
    """
    ...
223+
224+
212225
@compile_ops("module_top_k_per_row")
213226
def top_k_per_row_decode(
214227
logits: torch.Tensor,

csrc/include/rocm_ops.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1565,6 +1565,16 @@ namespace py = pybind11;
15651565
py::arg("numRows"), \
15661566
py::arg("stride0"), \
15671567
py::arg("stride1")); \
1568+
m.def("top_k_per_row_prefill_fast", \
1569+
&top_k_per_row_prefill_fast, \
1570+
py::arg("logits"), \
1571+
py::arg("rowStarts"), \
1572+
py::arg("rowEnds"), \
1573+
py::arg("indices"), \
1574+
py::arg("values"), \
1575+
py::arg("numRows"), \
1576+
py::arg("stride0"), \
1577+
py::arg("stride1")); \
15681578
m.def("top_k_per_row_decode", \
15691579
&top_k_per_row_decode, \
15701580
py::arg("logits"), \

csrc/include/topk_per_row.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@ void top_k_per_row_decode(const torch::Tensor& logits,
1919
int64_t stride0,
2020
int64_t stride1);
2121

22+
void top_k_per_row_prefill_fast(const torch::Tensor& logits,
23+
const torch::Tensor& rowStarts,
24+
const torch::Tensor& rowEnds,
25+
torch::Tensor& indices,
26+
std::optional<torch::Tensor> values,
27+
int64_t numRows,
28+
int64_t stride0,
29+
int64_t stride1);
30+
2231
void top_k_per_row_decode_fast(const torch::Tensor& logits,
2332
int64_t next_n,
2433
const torch::Tensor& seqLens,
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// SPDX-License-Identifier: MIT
2+
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3+
#include "aiter_hip_common.h"
4+
#include "py_itfs_common.h"
5+
#include <ATen/hip/HIPContext.h>
6+
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
7+
#include <torch/all.h>
8+
9+
// Argument block passed by pointer to the hand-written assembly top-k
// kernel. Packed so the in-memory layout matches the kernel's expected
// ABI exactly (no compiler-inserted padding between fields) -- do not
// reorder or resize members without updating the .co kernel.
// NOTE(review): the name says "Decode" but this file is the *prefill*
// entry point; presumably copied from the decode variant -- confirm the
// layout matches the prefill kernel's argument block.
struct __attribute__((packed)) TopKDecodeKernelArgs
{
    void* ptr_workspace; // scratch buffer for intermediate top-k results
    void* ptr_logits;    // float scores read by the kernel
    void* ptr_rowStarts; // int32 per-row start column
    void* ptr_rowEnds;   // int32 per-row end column
    void* ptr_indices;   // int32 output buffer for selected indices
    void* ptr_values;    // optional float output for selected scores; nullptr when unused
    int32_t stride0;     // logits stride, dim 0 (units per kernel ABI -- TODO confirm elements vs bytes)
    int32_t stride1;     // logits stride, dim 1
};
20+
21+
// Byte size of the scratch workspace the prefill top-k kernel needs:
// one (value, index) pair per top-k slot, per row.
// kNumThreadsPerBlock is kept for signature parity with the decode
// variant; it does not affect the computed size.
template <typename T, typename IdxT, int kNumThreadsPerBlock = 1024>
int64_t invokePrefillTopKLastDimWorkspaceSize(int32_t numRows, int32_t topkValue)
{
    constexpr int64_t bytesPerSlot = sizeof(T) + sizeof(IdxT);
    return static_cast<int64_t>(numRows) * topkValue * bytesPerSlot;
}
27+
// Launch the hand-written assembly prefill top-k kernel: for each of
// `numRows` rows of `logits`, select the top entries within
// [rowStarts[i], rowEnds[i]) and write their column indices to `indices`.
//
// Parameters:
//   logits    - float score tensor read by the kernel.
//   rowStarts - int32 per-row start column.
//   rowEnds   - int32 per-row end column (presumably exclusive -- TODO confirm).
//   indices   - int32 output buffer for the selected indices.
//   values    - optional float output for the selected scores.
//               NOTE(review): currently ignored -- ptr_values is always
//               nullptr; confirm whether the asm kernel writes values.
//   numRows   - number of rows to process (one thread block per row).
//   stride0/stride1 - logits strides forwarded to the kernel as int32.
void top_k_per_row_prefill_fast(const torch::Tensor& logits,
                                const torch::Tensor& rowStarts,
                                const torch::Tensor& rowEnds,
                                torch::Tensor& indices,
                                std::optional<torch::Tensor> values,
                                int64_t numRows,
                                int64_t stride0,
                                int64_t stride1)
{
    // Scratch space: one (float, int32) pair per candidate slot per row.
    // 2048 matches the kernel's compiled top-k capacity (Li2048 in the
    // mangled symbol below).
    const int64_t workspace_size = invokePrefillTopKLastDimWorkspaceSize<float, int32_t>(
        static_cast<int32_t>(numRows), 2048);
    auto options = torch::TensorOptions().dtype(torch::kUInt8).device(logits.device());
    torch::Tensor workspace = torch::empty({workspace_size}, options);

    TopKDecodeKernelArgs args;
    size_t arg_size = sizeof(args);

    args.ptr_workspace = static_cast<void*>(workspace.data_ptr<uint8_t>());
    args.ptr_logits    = logits.data_ptr<float>();
    args.ptr_rowStarts = rowStarts.data_ptr<int>();
    args.ptr_rowEnds   = rowEnds.data_ptr<int>();
    args.ptr_indices   = indices.data_ptr<int>();
    // TODO(review): wire up `values` once the kernel path is confirmed;
    // the current implementation discards it.
    args.ptr_values = nullptr;
    args.stride0    = static_cast<int32_t>(stride0);
    args.stride1    = static_cast<int32_t>(stride1);

    // Load the compiled assembly kernel once; `static` caches the module
    // across calls.
    static AiterAsmKernel impl_topk_prefill(
        "_ZN5aiter11PrefillTopKL10topKPerRowILi1024ELi2048ELi2048ELi512EEEvPvPKfPKiS6_PiPfii",
        "/topk_per_row_prefill/asm_top_k_per_row_prefill.co");

    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(logits));
    const hipStream_t stream = at::hip::getCurrentHIPStream();

    // Launch configuration: one 1024-thread block per row.
    constexpr int kNumThreadsPerBlock = 1024;
    const uint64_t gdx = static_cast<uint64_t>(numRows);

    // Grid x-dimension must fit in a signed 32-bit launch field.
    TORCH_CHECK(gdx >> 31 == 0, "numRows too large: ", numRows);

    impl_topk_prefill.launch_kernel({&args,
                                     &arg_size,
                                     static_cast<int>(gdx), // gdx: one block per row
                                     1,                     // gdy
                                     1,                     // gdz
                                     kNumThreadsPerBlock,   // bdx: 1024 threads
                                     1,                     // bdy
                                     1,                     // bdz
                                     stream});
}
78+
30.3 KB
Binary file not shown.

0 commit comments

Comments
 (0)