Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit c1323b7

Browse files
q10 authored and facebook-github-bot committed May 29, 2025
Limit the grid size for the TBE forward kernel
Summary: Limit the grid size for the TBE forward kernel, as we are observing thread counts over 2^32 for AMD runs. Reviewed By: yoyoyocmu, r-barnes. Differential Revision: D75543767
1 parent 3ba2c3a commit c1323b7

File tree

4 files changed

+51
-10
lines changed

4 files changed

+51
-10
lines changed
 

‎fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -645,11 +645,12 @@ batch_index_select_dim0_codegen_forward_kernel(
645645
{%- endif %}
646646

647647
// Determine the linearized warp ID, and exit early if needed
648+
{%- if is_index_select %}
648649
auto b_t = blockIdx.x * blockDim.y + threadIdx.y;
649-
{%- if not is_index_select %}
650-
if (b_t >= offsets.size(0) - 1) {
651-
return;
652-
}
650+
{%- else %}
651+
const auto total_B = offsets.size(0) - 1;
652+
// Since we place a limit on the grid size, we need to perform grid-striding
653+
for (auto b_t = blockIdx.x * blockDim.y + threadIdx.y; b_t < total_B; b_t += blockDim.y * gridDim.x) {
653654
{%- endif %}
654655

655656
// Determine the Table and Training Example IDs
@@ -832,6 +833,10 @@ batch_index_select_dim0_codegen_forward_kernel(
832833

833834
}
834835
{%- endif %}
836+
837+
{%- if not is_index_select %}
838+
} // for b_t
839+
{%- endif %}
835840
}
836841

837842

‎fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
////////////////////////////////////////////////////////////////////////////////
3838
#include "fbgemm_gpu/utils/ops_utils.h"
3939
{%- endif %}
40+
#include "fbgemm_gpu/utils/device_properties.cuh"
4041
#include "fbgemm_gpu/utils/kernel_launcher.cuh"
4142
#include "fbgemm_gpu/embedding_forward_template_helpers.cuh"
4243
#include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"
@@ -708,6 +709,10 @@ batch_index_select_dim0_codegen_forward_cuda(
708709
constexpr auto kMaxVecsPerThread = kFixedMaxVecsPerThread;
709710
{%- endif %}
710711

712+
const auto grid = min(
713+
div_round_up(total_B, kForwardMaxThreads / kThreadGroupSize),
714+
utils::cuda::get_max_thread_blocks(at::cuda::getCurrentCUDAStream()));
715+
711716
FBGEMM_LAUNCH_KERNEL(
712717
({{ mdesc }}_embedding_codegen_forward_{{ desc_suffix }}_kernel
713718
<emb_t,
@@ -719,7 +724,7 @@ batch_index_select_dim0_codegen_forward_cuda(
719724
index_t,
720725
kMaxVecsPerThread,
721726
kThreadGroupSize>),
722-
div_round_up(total_B, kForwardMaxThreads / kThreadGroupSize),
727+
grid,
723728
dim3(kThreadGroupSize, kForwardMaxThreads / kThreadGroupSize),
724729
0,
725730
at::cuda::getCurrentCUDAStream(),

‎fbgemm_gpu/include/fbgemm_gpu/utils/device_properties.cuh

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,22 @@
88

99
#pragma once
1010

11+
#include <ATen/cuda/CUDAContext.h>
1112
#include <c10/cuda/CUDAException.h>
13+
#include <c10/cuda/CUDAStream.h>
1214
#include <cuda.h>
1315

14-
namespace fbgemm_gpu::utils {
16+
namespace fbgemm_gpu::utils::cuda {
17+
18+
// Based on the empirical study, max grid size that is 64x larger than the
19+
// number of SMs gives good performance across the board
20+
constexpr int32_t MAX_THREAD_BLOCKS_FACTOR = 64;
21+
22+
inline auto get_max_thread_blocks(const c10::cuda::CUDAStream& stream) {
23+
const auto device = stream.device_index();
24+
return MAX_THREAD_BLOCKS_FACTOR *
25+
at::cuda::getDeviceProperties(device)->multiProcessorCount;
26+
}
1527

1628
inline auto get_compute_versions() {
1729
static const auto versions = [] {
@@ -27,4 +39,4 @@ inline auto get_compute_versions() {
2739
return versions;
2840
}
2941

30-
} // namespace fbgemm_gpu::utils
42+
} // namespace fbgemm_gpu::utils::cuda

‎fbgemm_gpu/include/fbgemm_gpu/utils/kernel_launcher.cuh

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,13 @@ struct KernelLauncher {
180180
TORCH_CHECK(
181181
threads_per_block <= properties.maxThreadsPerBlock,
182182
context.description(),
183-
": Threads per block ",
183+
": [block dim ",
184+
block.x,
185+
" x ",
186+
block.y,
187+
" x ",
188+
block.z,
189+
"] Threads per block ",
184190
threads_per_block,
185191
" is greater than the limit of ",
186192
properties.maxThreadsPerBlock);
@@ -190,15 +196,28 @@ struct KernelLauncher {
190196
// automatically work around problem like CUDA does (V100 or newer
191197
// architectures), see:
192198
// https://github.com/ROCm/hip/issues/2253
199+
// https://rocm.docs.amd.com/projects/HIP/en/docs-develop/reference/hip_runtime_api/modules/occupancy.html
193200
const uint64_t total_threads = U64(grid.x) * U64(grid.y) * U64(grid.z) *
194201
U64(block.x) * U64(block.y) * U64(block.z);
195202

196203
TORCH_CHECK(
197204
total_threads < U64(std::numeric_limits<uint32_t>::max()),
198205
context.description(),
199-
": Total number of threads ",
206+
" [grid dim ",
207+
grid.x,
208+
" x ",
209+
grid.y,
210+
" x ",
211+
grid.z,
212+
"] [block dim ",
213+
block.x,
214+
" x ",
215+
block.y,
216+
" x ",
217+
block.z,
218+
"]: Total number of threads ",
200219
total_threads,
201-
" is greater than the limit of 2^32");
220+
" is greater than the HIP limit of 2^32");
202221
#endif
203222
}
204223

0 commit comments

Comments (0)
Please sign in to comment.