|
| 1 | +// SPDX-License-Identifier: MIT |
| 2 | +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. |
| 3 | +#include "aiter_hip_common.h" |
| 4 | +#include "py_itfs_common.h" |
| 5 | +#include <ATen/hip/HIPContext.h> |
| 6 | +#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h> |
| 7 | +#include <torch/all.h> |
| 8 | + |
// Argument buffer handed by value to the hand-written assembly top-k kernel.
// `packed` guarantees the byte offsets of these fields match what the asm
// kernel expects — do NOT reorder, remove, or insert members.
struct __attribute__((packed)) TopKDecodeKernelArgs
{
    void* ptr_workspace;  // scratch buffer for per-row candidate values/indices
    void* ptr_logits;     // input logits (float), addressed via stride0/stride1
    void* ptr_rowStarts;  // per-row start column (int32)
    void* ptr_rowEnds;    // per-row end column (int32)
    void* ptr_indices;    // output top-k indices (int32)
    void* ptr_values;     // output top-k values (float); nullptr when not requested
    int32_t stride0;      // logits stride, dim 0 — presumably in elements; confirm vs kernel
    int32_t stride1;      // logits stride, dim 1 — presumably in elements; confirm vs kernel
};
| 20 | + |
/// Returns the scratch-workspace size in bytes required by the prefill top-k
/// kernel: for each of `numRows` rows, `topkValue` candidate slots holding one
/// value (T) and one index (IdxT) apiece.
///
/// @tparam T                   value/logit element type (e.g. float)
/// @tparam IdxT                index element type (e.g. int32_t)
/// @tparam kNumThreadsPerBlock retained for signature compatibility; it does
///                             not participate in the size computation
/// @param numRows    number of rows the kernel will process
/// @param topkValue  number of top-k candidates buffered per row
/// @return workspace size in bytes
template <typename T, typename IdxT, int kNumThreadsPerBlock = 1024>
int64_t invokePrefillTopKLastDimWorkspaceSize(int32_t numRows, int32_t topkValue)
{
    // Multiply in int64_t explicitly: the original expression promoted through
    // unsigned size_t (via sizeof) and then narrowed back to signed int64_t,
    // with the non-sizeof factors still 32-bit. Keeping every factor signed
    // 64-bit removes the overflow/signedness hazard for large inputs.
    return static_cast<int64_t>(topkValue) *
           static_cast<int64_t>(sizeof(T) + sizeof(IdxT)) *
           static_cast<int64_t>(numRows);
}
| 26 | + |
| 27 | +void top_k_per_row_prefill_fast(const torch::Tensor& logits, |
| 28 | + const torch::Tensor& rowStarts, |
| 29 | + const torch::Tensor& rowEnds, |
| 30 | + torch::Tensor& indices, |
| 31 | + std::optional<torch::Tensor> values, |
| 32 | + int64_t numRows, |
| 33 | + int64_t stride0, |
| 34 | + int64_t stride1) |
| 35 | +{ |
| 36 | + // Compute workspace size and allocate workspace tensor |
| 37 | + const auto numColumns = logits.size(1); |
| 38 | + int64_t workspace_size = invokePrefillTopKLastDimWorkspaceSize<float, int32_t>(numRows, 2048); |
| 39 | + auto options = torch::TensorOptions().dtype(torch::kUInt8).device(logits.device()); |
| 40 | + torch::Tensor workspace = torch::empty({workspace_size}, options); |
| 41 | + |
| 42 | + TopKDecodeKernelArgs args; |
| 43 | + size_t arg_size = sizeof(args); |
| 44 | + |
| 45 | + args.ptr_workspace = static_cast<void*>(workspace.data_ptr<uint8_t>()); |
| 46 | + args.ptr_logits = logits.data_ptr<float>(); |
| 47 | + args.ptr_rowStarts = rowStarts.data_ptr<int>(); |
| 48 | + args.ptr_rowEnds = rowEnds.data_ptr<int>(); |
| 49 | + args.ptr_indices = indices.data_ptr<int>(); |
| 50 | + args.ptr_values = nullptr; |
| 51 | + args.stride0 = static_cast<int32_t>(stride0); |
| 52 | + args.stride1 = static_cast<int32_t>(stride1); |
| 53 | + |
| 54 | + // Load the compiled assembly kernel |
| 55 | + static AiterAsmKernel impl_topk_decode( |
| 56 | + "_ZN5aiter11PrefillTopKL10topKPerRowILi1024ELi2048ELi2048ELi512EEEvPvPKfPKiS6_PiPfii", |
| 57 | + "/topk_per_row_prefill/asm_top_k_per_row_prefill.co"); |
| 58 | + |
| 59 | + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(logits)); |
| 60 | + const hipStream_t stream = at::hip::getCurrentHIPStream(); |
| 61 | + |
| 62 | + // Launch kernel configuration |
| 63 | + constexpr int kNumThreadsPerBlock = 1024; |
| 64 | + uint64_t gdx = numRows; |
| 65 | + |
| 66 | + TORCH_CHECK(gdx >> 31 == 0, "numRows too large: ", numRows); |
| 67 | + |
| 68 | + impl_topk_decode.launch_kernel({&args, |
| 69 | + &arg_size, |
| 70 | + static_cast<int>(gdx), // gdx: one block per row |
| 71 | + 1, // gdy |
| 72 | + 1, // gdz |
| 73 | + kNumThreadsPerBlock, // bdx: 1024 threads |
| 74 | + 1, // bdy |
| 75 | + 1, // bdz |
| 76 | + stream}); |
| 77 | +} |
| 78 | + |
0 commit comments