Limit the grid size for the TBE forward kernel #4208

Closed
wants to merge 1 commit
@@ -645,11 +645,12 @@ batch_index_select_dim0_codegen_forward_kernel(
{%- endif %}

// Determine the linearized warp ID, and exit early if needed
{%- if is_index_select %}
auto b_t = blockIdx.x * blockDim.y + threadIdx.y;
{%- if not is_index_select %}
if (b_t >= offsets.size(0) - 1) {
return;
}
{%- else %}
const auto total_B = offsets.size(0) - 1;
// Since we place a limit on the grid size, we need to perform grid-striding
for (auto b_t = blockIdx.x * blockDim.y + threadIdx.y; b_t < total_B; b_t += blockDim.y * gridDim.x) {
{%- endif %}

// Determine the Table and Training Example IDs
@@ -832,6 +833,10 @@ batch_index_select_dim0_codegen_forward_kernel(

}
{%- endif %}

{%- if not is_index_select %}
} // for b_t
{%- endif %}
}
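For context on the change above: instead of assuming the grid has one warp-sized group per (table, batch) entry and returning early when b_t is out of range, the non-index-select instantiation now walks b_t with a grid-stride loop, so the capped grid introduced by this PR still covers all of total_B. A minimal standalone sketch of the same pattern follows (illustrative only, not part of this PR's diff; the kernel name and arguments are made up):

__global__ void grid_stride_rows(float* out, const float* in, int total_B, int D) {
  // Each thread-group "row" of the block (indexed by threadIdx.y) owns one b_t
  // per iteration; when gridDim.x is capped, groups wrap around to the next chunk.
  for (auto b_t = blockIdx.x * blockDim.y + threadIdx.y;
       b_t < total_B;
       b_t += blockDim.y * gridDim.x) {
    // Threads along x cooperate over the embedding dimension of row b_t.
    for (auto d = threadIdx.x; d < D; d += blockDim.x) {
      out[b_t * D + d] = in[b_t * D + d];
    }
  }
}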


@@ -37,6 +37,7 @@
////////////////////////////////////////////////////////////////////////////////
#include "fbgemm_gpu/utils/ops_utils.h"
{%- endif %}
#include "fbgemm_gpu/utils/device_properties.cuh"
#include "fbgemm_gpu/utils/kernel_launcher.cuh"
#include "fbgemm_gpu/embedding_forward_template_helpers.cuh"
#include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"
@@ -708,6 +709,10 @@ batch_index_select_dim0_codegen_forward_cuda(
constexpr auto kMaxVecsPerThread = kFixedMaxVecsPerThread;
{%- endif %}

const auto grid = min(
div_round_up(total_B, kForwardMaxThreads / kThreadGroupSize),
utils::cuda::get_max_thread_blocks(at::cuda::getCurrentCUDAStream()));

FBGEMM_LAUNCH_KERNEL(
({{ mdesc }}_embedding_codegen_forward_{{ desc_suffix }}_kernel
<emb_t,
@@ -719,7 +724,7 @@
index_t,
kMaxVecsPerThread,
kThreadGroupSize>),
div_round_up(total_B, kForwardMaxThreads / kThreadGroupSize),
grid,
dim3(kThreadGroupSize, kForwardMaxThreads / kThreadGroupSize),
0,
at::cuda::getCurrentCUDAStream(),
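To make the new grid computation concrete, here is a standalone host-side sketch of the same arithmetic (not FBGEMM code; the SM count assumes an A100-class GPU with 108 SMs, kForwardMaxThreads = 512 as in the tests below, and a kThreadGroupSize of 32, i.e. one warp, which is an assumption):

#include <algorithm>
#include <cstdio>

int main() {
  const int sm_count = 108;              // multiProcessorCount (A100 assumption)
  const int max_blocks = 64 * sm_count;  // MAX_THREAD_BLOCKS_FACTOR * SMs = 6912
  const int groups_per_block = 512 / 32; // kForwardMaxThreads / kThreadGroupSize = 16
  const int total_B = 1000000;           // (number of tables) x (batch size)
  // div_round_up(total_B, groups_per_block) = 62500 blocks if left unclamped
  const int ungated = (total_B + groups_per_block - 1) / groups_per_block;
  const int grid = std::min(ungated, max_blocks);  // clamped to 6912 blocks
  std::printf("grid = %d blocks\n", grid);
  return 0;
}

Under these assumptions the cap only engages once total_B exceeds 6912 x 16 = 110,592 warp groups; smaller workloads launch exactly as before, and larger ones rely on the kernel's grid-stride loop shown above.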
16 changes: 14 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/utils/device_properties.cuh
@@ -8,10 +8,22 @@

#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>

namespace fbgemm_gpu::utils {
namespace fbgemm_gpu::utils::cuda {

// Based on empirical study, a max grid size that is 64x larger than the
// number of SMs gives good performance across the board
constexpr int32_t MAX_THREAD_BLOCKS_FACTOR = 64;

inline auto get_max_thread_blocks(const c10::cuda::CUDAStream& stream) {
const auto device = stream.device_index();
return MAX_THREAD_BLOCKS_FACTOR *
at::cuda::getDeviceProperties(device)->multiProcessorCount;
}

inline auto get_compute_versions() {
static const auto versions = [] {
@@ -27,4 +39,4 @@ inline auto get_compute_versions() {
return versions;
}

} // namespace fbgemm_gpu::utils
} // namespace fbgemm_gpu::utils::cuda
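A hedged usage sketch for the helper added above (the include path, namespace, and stream accessor are as in this PR; the wrapping function is illustrative):

#include "fbgemm_gpu/utils/device_properties.cuh"
#include <ATen/cuda/CUDAContext.h>

int query_grid_cap() {
  // Returns 64 * multiProcessorCount for the device backing the current stream,
  // e.g. 64 * 108 = 6912 blocks on a 108-SM A100.
  return fbgemm_gpu::utils::cuda::get_max_thread_blocks(
      at::cuda::getCurrentCUDAStream());
}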
25 changes: 22 additions & 3 deletions fbgemm_gpu/include/fbgemm_gpu/utils/kernel_launcher.cuh
@@ -180,7 +180,13 @@ struct KernelLauncher {
TORCH_CHECK(
threads_per_block <= properties.maxThreadsPerBlock,
context.description(),
": Threads per block ",
": [block dim ",
block.x,
" x ",
block.y,
" x ",
block.z,
"] Threads per block ",
threads_per_block,
" is greater than the limit of ",
properties.maxThreadsPerBlock);
@@ -190,15 +196,28 @@
// automatically work around problem like CUDA does (V100 or newer
// architectures), see:
// https://github.com/ROCm/hip/issues/2253
// https://rocm.docs.amd.com/projects/HIP/en/docs-develop/reference/hip_runtime_api/modules/occupancy.html
const uint64_t total_threads = U64(grid.x) * U64(grid.y) * U64(grid.z) *
U64(block.x) * U64(block.y) * U64(block.z);

TORCH_CHECK(
total_threads < U64(std::numeric_limits<uint32_t>::max()),
context.description(),
": Total number of threads ",
" [grid dim ",
grid.x,
" x ",
grid.y,
" x ",
grid.z,
"] [block dim ",
block.x,
" x ",
block.y,
" x ",
block.z,
"]: Total number of threads ",
total_threads,
" is greater than the limit of 2^32");
" is greater than the HIP limit of 2^32");
#endif
}
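To illustrate when the ROCm guard above fires, note that the six grid/block extents are multiplied in 64-bit so the comparison itself cannot overflow. A small standalone sketch (not FBGEMM code; the dimensions are chosen to land exactly on the limit):

#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  const uint32_t grid[3]  = {4194304u, 1u, 1u};  // 2^22 blocks
  const uint32_t block[3] = {1024u, 1u, 1u};     // 2^10 threads per block
  const uint64_t total = uint64_t(grid[0]) * grid[1] * grid[2] *
                         block[0] * block[1] * block[2];  // 2^32 threads
  // Mirrors the check above: this launch would be rejected on ROCm,
  // since 2^32 is not below the uint32_t maximum.
  const bool ok = total < uint64_t(std::numeric_limits<uint32_t>::max());
  std::printf("total threads = %llu, within HIP limit: %s\n",
              static_cast<unsigned long long>(total), ok ? "yes" : "no");
  return 0;
}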

13 changes: 13 additions & 0 deletions fbgemm_gpu/test/tbe/common.py
@@ -45,6 +45,8 @@
# For long running tests reduce the number of iterations to reduce timeout errors.
MAX_EXAMPLES_LONG_RUNNING = 15

FORWARD_MAX_THREADS = 512

VERBOSITY: Verbosity = Verbosity.verbose


@@ -83,3 +85,14 @@ def format_ref_tensors_in_mixed_B_layout(

def assert_torch_equal(tensor_a: torch.Tensor, tensor_b: torch.Tensor) -> None:
assert torch.equal(tensor_a, tensor_b)


def get_max_thread_blocks(stream: torch.cuda.streams.Stream) -> int:
# Based on the empirical studies, having a max grid size that is 64x larger than
# the number of SMs gives good performance across the board
MAX_THREAD_BLOCKS_FACTOR = 64
device = stream.device_index
return (
MAX_THREAD_BLOCKS_FACTOR
* torch.cuda.get_device_properties(device).multi_processor_count
)
81 changes: 67 additions & 14 deletions fbgemm_gpu/test/tbe/training/forward_test.py
@@ -9,6 +9,7 @@

# pyre-ignore-all-errors[56]

import math
import random
import unittest

@@ -37,7 +38,9 @@
from .. import common # noqa E402
from ..common import (
format_ref_tensors_in_mixed_B_layout,
FORWARD_MAX_THREADS,
gen_mixed_B_batch_sizes,
get_max_thread_blocks,
MAX_EXAMPLES_LONG_RUNNING,
open_source,
)
@@ -472,26 +475,16 @@ def test_forward_gpu_no_cache_int8(
False, # use_experimental_tbe
)

@unittest.skipIf(*gpu_unavailable)
@given(
use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False),
)
@settings(
verbosity=VERBOSITY,
max_examples=MAX_EXAMPLES_LONG_RUNNING,
deadline=None,
suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
)
def test_forward_gpu_no_cache_fp16(
def _test_forward_gpu_no_cache_fp16_impl(
self,
T: int,
B: int,
L: int,
use_experimental_tbe: bool,
) -> None:
weights_precision = SparseType.FP16
use_cpu = False
T = random.randint(1, 10)
D = random.randint(2, 256)
B = random.randint(1, 128)
L = random.randint(0, 20)
log_E = random.randint(3, 5)

use_cache = False
@@ -535,6 +528,66 @@ def test_forward_gpu_no_cache_fp16(
use_experimental_tbe,
)

@unittest.skipIf(*gpu_unavailable)
@given(
use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False),
)
@settings(
verbosity=VERBOSITY,
max_examples=MAX_EXAMPLES_LONG_RUNNING,
deadline=None,
suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
)
def test_forward_gpu_no_cache_fp16(
self,
use_experimental_tbe: bool,
) -> None:
return self._test_forward_gpu_no_cache_fp16_impl(
random.randint(1, 10),
random.randint(1, 128),
random.randint(0, 20),
use_experimental_tbe,
)

@unittest.skipIf(*gpu_unavailable)
@given(
use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False),
)
@settings(
verbosity=VERBOSITY,
max_examples=MAX_EXAMPLES_LONG_RUNNING,
deadline=None,
suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
)
def test_forward_gpu_no_cache_fp16_large(
self,
use_experimental_tbe: bool,
) -> None:
torch.cuda.empty_cache()

max_num_threads = FORWARD_MAX_THREADS * get_max_thread_blocks(
torch.cuda.current_stream()
)
# NOTE: L is arbitrarily chosen here
L = 10
# NOTE: Set B to the smallest power of two such that (B x L) = (number of
# indices) exceeds (allowed grid size x block size)
B = 2 ** (math.ceil(math.log2(max_num_threads / L)))
# NOTE: T is chosen to be small enough to avoid OOM errors given that
# B x L must be large enough
T = 3

assert (
B * L > max_num_threads
), "Should be testing the case where B * L is larger than max_num_threads"

return self._test_forward_gpu_no_cache_fp16_impl(
T,
B,
L,
use_experimental_tbe,
)
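As a concrete sizing example for the test above (assuming a 108-SM A100-class device): get_max_thread_blocks returns 64 x 108 = 6912, so max_num_threads = 512 x 6912 = 3,538,944. With L = 10, B = 2^ceil(log2(3,538,944 / 10)) = 2^19 = 524,288, so B x L = 5,242,880 indices, which exceeds the grid-limited thread count and forces the kernel onto its grid-striding path.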

@unittest.skipIf(*gpu_unavailable)
@given(
use_experimental_tbe=st.booleans(),