4 files changed, +51 -10 lines changed

File 1 of 4:

@@ -645,11 +645,12 @@ batch_index_select_dim0_codegen_forward_kernel(
   {%- endif %}

   // Determine the linearized warp ID, and exit early if needed
+  {%- if is_index_select %}
   auto b_t = blockIdx.x * blockDim.y + threadIdx.y;
-  {%- if not is_index_select %}
-  if (b_t >= offsets.size(0) - 1) {
-    return;
-  }
+  {%- else %}
+  const auto total_B = offsets.size(0) - 1;
+  // Since we place a limit on the grid size, we need to perform grid-striding
+  for (auto b_t = blockIdx.x * blockDim.y + threadIdx.y; b_t < total_B; b_t += blockDim.y * gridDim.x) {
   {%- endif %}

   // Determine the Table and Training Example IDs
@@ -832,6 +833,10 @@ batch_index_select_dim0_codegen_forward_kernel(

   }
   {%- endif %}
+
+  {%- if not is_index_select %}
+  } // for b_t
+  {%- endif %}
 }

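The kernel change above replaces the old early exit (return when b_t is past the last row) with a for loop over b_t, so a grid with fewer blocks than rows still visits every row. Below is a standalone sketch of that grid-stride pattern with a hypothetical kernel in place of the generated FBGEMM one; the kernel name, parameters, and the simple copy body are illustrative assumptions.

// Hypothetical grid-stride kernel, illustrating the b_t loop from the diff.
// Each (blockIdx.x, threadIdx.y) pair is one row worker; threadIdx.x lanes
// cooperate on the columns of the row that worker currently owns.
__global__ void grid_stride_rows_kernel(
    const float* in, float* out, int total_B, int D) {
  for (int b_t = blockIdx.x * blockDim.y + threadIdx.y; b_t < total_B;
       b_t += blockDim.y * gridDim.x) {
    for (int d = threadIdx.x; d < D; d += blockDim.x) {
      // Placeholder work: the real kernel gathers and accumulates embedding
      // rows here; a copy keeps the sketch self-contained.
      out[b_t * D + d] = in[b_t * D + d];
    }
  }
}

Launched with a block shape of dim3(kThreadGroupSize, kForwardMaxThreads / kThreadGroupSize), as at the FBGEMM launch site below, each block carries several row workers and the loop stride is blockDim.y * gridDim.x row workers per iteration.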
File 2 of 4:

 ////////////////////////////////////////////////////////////////////////////////
 #include "fbgemm_gpu/utils/ops_utils.h"
 {%- endif %}
+#include "fbgemm_gpu/utils/device_properties.cuh"
 #include "fbgemm_gpu/utils/kernel_launcher.cuh"
 #include "fbgemm_gpu/embedding_forward_template_helpers.cuh"
 #include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"
@@ -708,6 +709,10 @@ batch_index_select_dim0_codegen_forward_cuda(
   constexpr auto kMaxVecsPerThread = kFixedMaxVecsPerThread;
   {%- endif %}

+  const auto grid = min(
+      div_round_up(total_B, kForwardMaxThreads / kThreadGroupSize),
+      utils::cuda::get_max_thread_blocks(at::cuda::getCurrentCUDAStream()));
+
   FBGEMM_LAUNCH_KERNEL(
     ({{ mdesc }}_embedding_codegen_forward_{{ desc_suffix }}_kernel
       <emb_t,
@@ -719,7 +724,7 @@ batch_index_select_dim0_codegen_forward_cuda(
        index_t,
       kMaxVecsPerThread,
       kThreadGroupSize>),
-    div_round_up(total_B, kForwardMaxThreads / kThreadGroupSize),
+    grid,
     dim3(kThreadGroupSize, kForwardMaxThreads / kThreadGroupSize),
     0,
     at::cuda::getCurrentCUDAStream(),
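With the kernel now grid-striding over rows, the host side no longer needs one block per chunk of total_B; it requests at most a device-dependent number of blocks and lets each block loop. A rough sketch of that clamped grid computation follows; div_round_up mirrors the helper used in the diff, and the constant values are assumptions for illustration only.

#include <algorithm>
#include <cstdint>

// Hypothetical helper matching the diff's div_round_up usage.
constexpr int64_t div_round_up(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

// Sketch of the clamped launch configuration: each block hosts
// (kForwardMaxThreads / kThreadGroupSize) row workers, and the block count is
// capped so huge batches fall back to grid-striding inside the kernel.
int64_t compute_grid(int64_t total_B, int64_t max_thread_blocks) {
  constexpr int64_t kForwardMaxThreads = 512;  // assumed value
  constexpr int64_t kThreadGroupSize = 32;     // assumed warp-sized group
  const int64_t rows_per_block = kForwardMaxThreads / kThreadGroupSize;  // 16
  return std::min(div_round_up(total_B, rows_per_block), max_thread_blocks);
}

For example, with total_B = 1,000,000 and a 132-SM device, the uncapped launch would request 62,500 blocks, while a cap of 64 x 132 = 8,448 blocks keeps the grid bounded and each row worker strides over roughly 7 to 8 rows.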
File 3 of 4:

 #pragma once

+#include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAException.h>
+#include <c10/cuda/CUDAStream.h>
 #include <cuda.h>

-namespace fbgemm_gpu::utils {
+namespace fbgemm_gpu::utils::cuda {
+
+// Based on the empirical study, max grid size that is 64x larger than the
+// number of SMs gives good performance across the board
+constexpr int32_t MAX_THREAD_BLOCKS_FACTOR = 64;
+
+inline auto get_max_thread_blocks(const c10::cuda::CUDAStream& stream) {
+  const auto device = stream.device_index();
+  return MAX_THREAD_BLOCKS_FACTOR *
+      at::cuda::getDeviceProperties(device)->multiProcessorCount;
+}

 inline auto get_compute_versions() {
   static const auto versions = [] {
@@ -27,4 +39,4 @@ inline auto get_compute_versions() {
   return versions;
 }

-} // namespace fbgemm_gpu::utils
+} // namespace fbgemm_gpu::utils::cuda
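The new helper derives the grid ceiling from the SM count of the device behind a given stream, scaled by the empirical factor of 64. A hedged usage sketch is below; the wrapper function is hypothetical, and the include path for the header is assumed from the include line added in file 2.

#include <algorithm>
#include <cstdint>
#include <ATen/cuda/CUDAContext.h>
#include "fbgemm_gpu/utils/device_properties.cuh"  // path assumed from the diff

// Clamp a requested block count against the per-device ceiling. The helper
// looks up multiProcessorCount for the device that owns the current stream
// and multiplies it by MAX_THREAD_BLOCKS_FACTOR (64).
int64_t clamp_blocks_for_current_device(int64_t blocks_wanted) {
  const auto stream = at::cuda::getCurrentCUDAStream();
  const auto max_blocks =
      fbgemm_gpu::utils::cuda::get_max_thread_blocks(stream);
  return std::min<int64_t>(blocks_wanted, max_blocks);
}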
File 4 of 4:

@@ -180,7 +180,13 @@ struct KernelLauncher {
     TORCH_CHECK(
         threads_per_block <= properties.maxThreadsPerBlock,
         context.description(),
-        ": Threads per block ",
+        ": [block dim ",
+        block.x,
+        " x ",
+        block.y,
+        " x ",
+        block.z,
+        "] Threads per block ",
         threads_per_block,
         " is greater than the limit of ",
         properties.maxThreadsPerBlock);
@@ -190,15 +196,28 @@ struct KernelLauncher {
     // automatically work around problem like CUDA does (V100 or newer
     // architectures), see:
     //   https://github.com/ROCm/hip/issues/2253
+    //   https://rocm.docs.amd.com/projects/HIP/en/docs-develop/reference/hip_runtime_api/modules/occupancy.html
     const uint64_t total_threads = U64(grid.x) * U64(grid.y) * U64(grid.z) *
         U64(block.x) * U64(block.y) * U64(block.z);

     TORCH_CHECK(
         total_threads < U64(std::numeric_limits<uint32_t>::max()),
         context.description(),
-        ": Total number of threads ",
+        "[grid dim ",
+        grid.x,
+        " x ",
+        grid.y,
+        " x ",
+        grid.z,
+        "] [block dim ",
+        block.x,
+        " x ",
+        block.y,
+        " x ",
+        block.z,
+        "]: Total number of threads ",
         total_threads,
-        " is greater than the limit of 2^32");
+        " is greater than the HIP limit of 2^32");
 #endif
   }
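The launch-time check multiplies all six grid and block dimensions in 64-bit arithmetic before comparing against the 32-bit ceiling, which is what the U64 casts guard against. Below is a small self-contained sketch of the same guard; the names and the plain exception are stand-ins, not the FBGEMM KernelLauncher or TORCH_CHECK.

#include <cstdint>
#include <limits>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for the U64 cast used in the diff.
constexpr uint64_t U64(uint32_t v) { return static_cast<uint64_t>(v); }

struct Dim3 { uint32_t x, y, z; };  // stand-in for the CUDA/HIP dim3 type

// Reject launches whose total thread count would overflow 32 bits. Doing the
// multiplication in uint64_t is the point: the product of six uint32_t
// factors can exceed 2^32 long before any single factor does.
void check_total_threads(const Dim3& grid, const Dim3& block) {
  const uint64_t total_threads = U64(grid.x) * U64(grid.y) * U64(grid.z) *
      U64(block.x) * U64(block.y) * U64(block.z);
  if (total_threads >= U64(std::numeric_limits<uint32_t>::max())) {
    throw std::runtime_error(
        "Total number of threads " + std::to_string(total_threads) +
        " is greater than the limit of 2^32");
  }
}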