
Commit 5b75aec

q10 authored and facebook-github-bot committed
Limit the grid size for the TBE forward kernel (#4208)
Summary:
Pull Request resolved: #4208
X-link: facebookresearch/FBGEMM#1283

Limit the grid size for the TBE forward kernel, as we are observing thread counts over 2^32 for AMD runs.

Reviewed By: yoyoyocmu, r-barnes, sryap

Differential Revision: D75543767

fbshipit-source-id: de60e075095a94922ca935e681b805d31d2e6486
1 parent 45b2a8f commit 5b75aec

6 files changed: +131 −24 lines changed


fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu

Lines changed: 9 additions & 4 deletions
@@ -645,11 +645,12 @@ batch_index_select_dim0_codegen_forward_kernel(
     {%- endif %}

     // Determine the linearized warp ID, and exit early if needed
+    {%- if is_index_select %}
     auto b_t = blockIdx.x * blockDim.y + threadIdx.y;
-    {%- if not is_index_select %}
-    if (b_t >= offsets.size(0) - 1) {
-      return;
-    }
+    {%- else %}
+    const auto total_B = offsets.size(0) - 1;
+    // Since we place a limit on the grid size, we need to perform grid-striding
+    for (auto b_t = blockIdx.x * blockDim.y + threadIdx.y; b_t < total_B; b_t += blockDim.y * gridDim.x) {
     {%- endif %}

     // Determine the Table and Training Example IDs

@@ -832,6 +833,10 @@ batch_index_select_dim0_codegen_forward_kernel(

     }
     {%- endif %}
+
+    {%- if not is_index_select %}
+    } // for b_t
+    {%- endif %}
 }
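The key pattern here is the switch from a per-warp early exit to a grid-stride loop, so the kernel stays correct even when the launch uses fewer blocks than there are rows. A minimal standalone sketch of that pattern, assuming a hypothetical kernel and data layout (this is not the generated TBE kernel), looks like:

// Grid-stride sketch (illustrative only): each (blockIdx.x, threadIdx.y) pair
// owns one logical row b_t, then strides by the total number of such rows in
// the whole grid, so any grid size >= 1 eventually covers all total_B rows.
__global__ void copy_rows_grid_stride(
    const float* __restrict__ in,
    float* __restrict__ out,
    int64_t total_B,
    int64_t D) {
  for (int64_t b_t = blockIdx.x * blockDim.y + threadIdx.y; b_t < total_B;
       b_t += static_cast<int64_t>(blockDim.y) * gridDim.x) {
    // Lanes of the warp (threadIdx.x) cooperate across the D columns of the row.
    for (int64_t d = threadIdx.x; d < D; d += blockDim.x) {
      out[b_t * D + d] = in[b_t * D + d];
    }
  }
}

Launched as, e.g., copy_rows_grid_stride<<<grid, dim3(32, 16)>>>(in, out, total_B, D), the same kernel body works whether grid is div_round_up(total_B, 16) or a much smaller capped value.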

fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu

Lines changed: 6 additions & 1 deletion
@@ -37,6 +37,7 @@
 ////////////////////////////////////////////////////////////////////////////////
 #include "fbgemm_gpu/utils/ops_utils.h"
 {%- endif %}
+#include "fbgemm_gpu/utils/device_properties.cuh"
 #include "fbgemm_gpu/utils/kernel_launcher.cuh"
 #include "fbgemm_gpu/embedding_forward_template_helpers.cuh"
 #include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"

@@ -708,6 +709,10 @@ batch_index_select_dim0_codegen_forward_cuda(
     constexpr auto kMaxVecsPerThread = kFixedMaxVecsPerThread;
     {%- endif %}

+    const auto grid = min(
+        div_round_up(total_B, kForwardMaxThreads / kThreadGroupSize),
+        utils::cuda::get_max_thread_blocks(at::cuda::getCurrentCUDAStream()));
+
     FBGEMM_LAUNCH_KERNEL(
       ({{ mdesc }}_embedding_codegen_forward_{{ desc_suffix }}_kernel
         <emb_t,

@@ -719,7 +724,7 @@ batch_index_select_dim0_codegen_forward_cuda(
         index_t,
         kMaxVecsPerThread,
         kThreadGroupSize>),
-      div_round_up(total_B, kForwardMaxThreads / kThreadGroupSize),
+      grid,
       dim3(kThreadGroupSize, kForwardMaxThreads / kThreadGroupSize),
       0,
       at::cuda::getCurrentCUDAStream(),
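To make the cap concrete with hypothetical numbers (kThreadGroupSize = 32 and a 108-SM device are assumptions, not values taken from this diff): each block of kForwardMaxThreads = 512 threads covers 512 / 32 = 16 rows, so total_B = 1,048,576 rows would previously launch 65,536 blocks. With the cap of 64 × 108 = 6,912 blocks, grid = min(65,536, 6,912) = 6,912, and the grid-stride loop added in the kernel template processes the remaining rows.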

fbgemm_gpu/include/fbgemm_gpu/utils/device_properties.cuh

Lines changed: 14 additions & 2 deletions
@@ -8,10 +8,22 @@

 #pragma once

+#include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAException.h>
+#include <c10/cuda/CUDAStream.h>
 #include <cuda.h>

-namespace fbgemm_gpu::utils {
+namespace fbgemm_gpu::utils::cuda {
+
+// Based on the empirical study, max grid size that is 64x larger than the
+// number of SMs gives good performance across the board
+constexpr int32_t MAX_THREAD_BLOCKS_FACTOR = 64;
+
+inline auto get_max_thread_blocks(const c10::cuda::CUDAStream& stream) {
+  const auto device = stream.device_index();
+  return MAX_THREAD_BLOCKS_FACTOR *
+      at::cuda::getDeviceProperties(device)->multiProcessorCount;
+}

 inline auto get_compute_versions() {
   static const auto versions = [] {

@@ -27,4 +39,4 @@ inline auto get_compute_versions() {
   return versions;
 }

-} // namespace fbgemm_gpu::utils
+} // namespace fbgemm_gpu::utils::cuda
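For readers outside the FBGEMM codebase, roughly the same cap can be computed with the plain CUDA runtime API. A hedged sketch (not the helper used here, which goes through at::cuda::getDeviceProperties; error handling omitted for brevity):

#include <cstdint>
#include <cuda_runtime.h>

// Same empirical 64x-per-SM factor as above.
constexpr int32_t kMaxThreadBlocksFactor = 64;

// Returns an upper bound on the 1-D grid size for the given device.
inline int32_t max_thread_blocks_for_device(int device) {
  cudaDeviceProp prop{};
  cudaGetDeviceProperties(&prop, device);
  return kMaxThreadBlocksFactor * prop.multiProcessorCount;
}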

fbgemm_gpu/include/fbgemm_gpu/utils/kernel_launcher.cuh

Lines changed: 22 additions & 3 deletions
@@ -180,7 +180,13 @@ struct KernelLauncher {
     TORCH_CHECK(
         threads_per_block <= properties.maxThreadsPerBlock,
         context.description(),
-        ": Threads per block ",
+        ": [block dim ",
+        block.x,
+        " x ",
+        block.y,
+        " x ",
+        block.z,
+        "] Threads per block ",
         threads_per_block,
         " is greater than the limit of ",
         properties.maxThreadsPerBlock);

@@ -190,15 +196,28 @@ struct KernelLauncher {
     // automatically work around problem like CUDA does (V100 or newer
     // architectures), see:
     // https://github.com/ROCm/hip/issues/2253
+    // https://rocm.docs.amd.com/projects/HIP/en/docs-develop/reference/hip_runtime_api/modules/occupancy.html
     const uint64_t total_threads = U64(grid.x) * U64(grid.y) * U64(grid.z) *
         U64(block.x) * U64(block.y) * U64(block.z);

     TORCH_CHECK(
         total_threads < U64(std::numeric_limits<uint32_t>::max()),
         context.description(),
-        ": Total number of threads ",
+        " [grid dim ",
+        grid.x,
+        " x ",
+        grid.y,
+        " x ",
+        grid.z,
+        "] [block dim ",
+        block.x,
+        " x ",
+        block.y,
+        " x ",
+        block.z,
+        "]: Total number of threads ",
        total_threads,
-        " is greater than the limit of 2^32");
+        " is greater than the HIP limit of 2^32");
 #endif
   }
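As a quick sanity check on the limit this message reports: with 512-thread blocks (the test constant FORWARD_MAX_THREADS mirrors kForwardMaxThreads = 512), a one-dimensional grid of 2^23 ≈ 8.4 million blocks already reaches 512 × 2^23 = 2^32 total threads, so sufficiently large embedding batches could trip this HIP-only check before the grid cap above was introduced.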

fbgemm_gpu/test/tbe/common.py

Lines changed: 13 additions & 0 deletions
@@ -45,6 +45,8 @@
 # For long running tests reduce the number of iterations to reduce timeout errors.
 MAX_EXAMPLES_LONG_RUNNING = 15

+FORWARD_MAX_THREADS = 512
+
 VERBOSITY: Verbosity = Verbosity.verbose

@@ -83,3 +85,14 @@ def format_ref_tensors_in_mixed_B_layout(

 def assert_torch_equal(tensor_a: torch.Tensor, tensor_b: torch.Tensor) -> None:
     assert torch.equal(tensor_a, tensor_b)
+
+
+def get_max_thread_blocks(stream: torch.cuda.streams.Stream) -> int:
+    # Based on the empirical studies, having a max grid size that is 64x larger than
+    # the number of SMs gives good performance across the board
+    MAX_THREAD_BLOCKS_FACTOR = 64
+    device = stream.device_index
+    return (
+        MAX_THREAD_BLOCKS_FACTOR
+        * torch.cuda.get_device_properties(device).multi_processor_count
+    )
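As an illustration, on a hypothetical device with 108 SMs this helper returns 64 × 108 = 6,912, mirroring the cap applied on the C++ side by utils::cuda::get_max_thread_blocks.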

fbgemm_gpu/test/tbe/training/forward_test.py

Lines changed: 67 additions & 14 deletions
@@ -9,6 +9,7 @@

 # pyre-ignore-all-errors[56]

+import math
 import random
 import unittest

@@ -37,7 +38,9 @@
 from .. import common # noqa E402
 from ..common import (
     format_ref_tensors_in_mixed_B_layout,
+    FORWARD_MAX_THREADS,
     gen_mixed_B_batch_sizes,
+    get_max_thread_blocks,
     MAX_EXAMPLES_LONG_RUNNING,
     open_source,
 )

@@ -472,26 +475,16 @@ def test_forward_gpu_no_cache_int8(
             False,  # use_experimental_tbe
         )

-    @unittest.skipIf(*gpu_unavailable)
-    @given(
-        use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False),
-    )
-    @settings(
-        verbosity=VERBOSITY,
-        max_examples=MAX_EXAMPLES_LONG_RUNNING,
-        deadline=None,
-        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
-    )
-    def test_forward_gpu_no_cache_fp16(
+    def _test_forward_gpu_no_cache_fp16_impl(
         self,
+        T: int,
+        B: int,
+        L: int,
         use_experimental_tbe: bool,
     ) -> None:
         weights_precision = SparseType.FP16
         use_cpu = False
-        T = random.randint(1, 10)
         D = random.randint(2, 256)
-        B = random.randint(1, 128)
-        L = random.randint(0, 20)
         log_E = random.randint(3, 5)

         use_cache = False

@@ -535,6 +528,66 @@ def test_forward_gpu_no_cache_fp16(
             use_experimental_tbe,
         )

+    @unittest.skipIf(*gpu_unavailable)
+    @given(
+        use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False),
+    )
+    @settings(
+        verbosity=VERBOSITY,
+        max_examples=MAX_EXAMPLES_LONG_RUNNING,
+        deadline=None,
+        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
+    )
+    def test_forward_gpu_no_cache_fp16(
+        self,
+        use_experimental_tbe: bool,
+    ) -> None:
+        return self._test_forward_gpu_no_cache_fp16_impl(
+            random.randint(1, 10),
+            random.randint(1, 128),
+            random.randint(0, 20),
+            use_experimental_tbe,
+        )
+
+    @unittest.skipIf(*gpu_unavailable)
+    @given(
+        use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False),
+    )
+    @settings(
+        verbosity=VERBOSITY,
+        max_examples=MAX_EXAMPLES_LONG_RUNNING,
+        deadline=None,
+        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
+    )
+    def test_forward_gpu_no_cache_fp16_large(
+        self,
+        use_experimental_tbe: bool,
+    ) -> None:
+        torch.cuda.empty_cache()
+
+        max_num_threads = FORWARD_MAX_THREADS * get_max_thread_blocks(
+            torch.cuda.current_stream()
+        )
+        # NOTE: L is arbitrarily chosen here
+        L = 10
+        # NOTE: Fix to the smallest value B such that (B x L) = (number of
+        # indices) > (allowed grid size x block size)
+        B = 2 ** (math.ceil(math.log2(max_num_threads / L)))
+        # NOTE: T is chosen to be small enough to avoid OOM errors given that
+        # B x L must be large enough
+        T = 3
+
+        assert (
+            B * L > max_num_threads
+        ), "Should be testing the case where B * L is larger than max_num_threads"
+
+        return self._test_forward_gpu_no_cache_fp16_impl(
+            T,
+            B,
+            L,
+            use_experimental_tbe,
+        )
+
     @unittest.skipIf(*gpu_unavailable)
     @given(
         use_experimental_tbe=st.booleans(),
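Working through the sizing with the same hypothetical 108-SM device: max_num_threads = 512 × 6,912 = 3,538,944; with L = 10, B = 2^ceil(log2(3,538,944 / 10)) = 2^19 = 524,288, so B × L = 5,242,880 > 3,538,944 and the assertion holds. In other words, the number of indices exceeds what the capped launch can cover in a single pass, so the test exercises the new grid-stride path.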
