
Commit 3be1d8d

sryap authored and facebook-github-bot committed
Use overflow_safe_int_t in TBE forward training (pytorch#3953)
Summary:
Pull Request resolved: pytorch#3953
X-link: facebookresearch/FBGEMM#1037

This diff updates the TBE forward training kernel to use `overflow_safe_int_t` (`int64_t`) for indices into tensors whose number of elements can exceed the `int32_t` range (i.e., tensors accessed via `PackedTensorAccessor64`).

Reviewed By: jwfromm

Differential Revision: D72502238

fbshipit-source-id: d8756f8f1dcce4f82767f29f2947a5f371e58222
1 parent 9af600e commit 3be1d8d
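For context, the overflow this change guards against: once a weights or output tensor holds more than about 2.1 billion elements, a flat element offset such as idx * D_emb no longer fits in int32_t, so both the accessor (PackedTensorAccessor64) and the index arithmetic must be 64-bit. A minimal standalone sketch of the pattern, not taken from the patch; the names row_ptr, table, row, and row_width are illustrative:

    #include <cstdint>

    // Per the commit summary, overflow_safe_int_t is an alias for int64_t.
    using overflow_safe_int_t = int64_t;

    // Returns a pointer to the start of a row in a flat [num_rows x row_width]
    // table. With, say, 10M rows of width 512, row * row_width (~5.1e9) would
    // overflow int32_t; widening before the multiply keeps the product in 64 bits.
    const float* row_ptr(const float* table, int64_t row, int32_t row_width) {
      const overflow_safe_int_t offset =
          static_cast<overflow_safe_int_t>(row) * row_width;
      return table + offset;
    }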

File tree

1 file changed: +32 −29 lines

fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu

+32 −29
@@ -60,7 +60,7 @@ using namespace fbgemm_gpu;
 
 In-code variables that are defined outside:
 emb_t, cache_t, cache_t
-idx_j
+offset_idx_j
 D_emb
 lxu_cache_weights
 {{ locs_or_addrs_idx }}_j
@@ -96,7 +96,7 @@ using namespace fbgemm_gpu;
 {%- if from_cache %}
 cache_weights, // Load from the cache
 {%- else %}
-&weights[idx_j * D_emb], // Load from the embedding table
+&weights[offset_idx_j], // Load from the embedding table
 {%- endif %}
 D);
 
@@ -158,7 +158,7 @@ using namespace fbgemm_gpu;
 
 In-code variables that are defined outside:
 emb_t, cache_t, cache_t
-idx_j
+offset_idx_j
 inner_j
 D_emb
 lxu_cache_weights
@@ -194,7 +194,7 @@ using namespace fbgemm_gpu;
 {%- if from_cache %}
 cache_weights, // Load from the cache
 {%- else %}
-&weights[idx_j * D_emb], // Load from the embedding table
+&weights[offset_idx_j], // Load from the embedding table
 {%- endif %}
 D);
 
@@ -243,7 +243,7 @@ using namespace fbgemm_gpu;
 
 In-code variables that are defined outside:
 emb_t, cache_t, cache_t
-idx_j
+offset_idx_j
 inner_j
 D_emb
 lxu_cache_weights
@@ -289,20 +289,23 @@ using namespace fbgemm_gpu;
 // Determine the L index that this thread will load data from in cooperative load
 auto l = l_start + threadIdx.x;
 
-{%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
+{%- if (
+dense
+or (lxu_miss_rate != "cache_conflict_miss_rate::zero")
+or (lxu_miss_rate == "cache_conflict_miss_rate::zero" and is_gwd_kernel)
+)
+%}
 // Cooperatively load the indices
-[[maybe_unused]] int64_t idx = l < L ? indices[indices_start + l] : 0;
+const overflow_safe_int_t idx = l < L ? indices[indices_start + l] : 0;
+// If idx is loaded
+const auto offset_idx = idx * D_emb;
 {%- endif %}
 
 {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %}
 // Cooperatively load the cache's indices
 [[maybe_unused]] {{ locs_or_addrs_type }} {{ locs_or_addrs_idx }} = (use_lxu_cache && placement == PlacementType::MANAGED_CACHING && l < L) ? {{ locs_or_addrs_tensor }}[indices_start + l] : 0;
 {%- endif %}
 
-{%- if lxu_miss_rate == "cache_conflict_miss_rate::zero" and is_gwd_kernel %}
-int64_t idx = l < L ? indices[indices_start + l] : 0; // only used for accessing prev_iter
-{%- endif %}
-
 {%- if is_gwd_kernel %}
 // if l > L or prev_iter == 0, global_weight_decay = 1
 const auto prev_it = prev_iter[idx];
@@ -323,10 +326,10 @@ using namespace fbgemm_gpu;
 {
 {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
 // Load index from thread j in the group
-[[maybe_unused]] int64_t idx_j_[kManualUnrollLength];
+overflow_safe_int_t offset_idx_j_[kManualUnrollLength];
 for (auto inner_j = 0; inner_j < kManualUnrollLength; ++inner_j)
 {
-idx_j_[inner_j] = SHFL_SYNC(idx, outer_j + inner_j);
+offset_idx_j_[inner_j] = SHFL_SYNC(offset_idx, outer_j + inner_j);
 }
 {%- endif %}
 
@@ -353,13 +356,13 @@ using namespace fbgemm_gpu;
 {
 auto j = outer_j + inner_j;
 {%- if is_index_select %}
-int64_t output_j = L_start + l_start + j;
+overflow_safe_int_t output_j = L_start + l_start + j;
 {%- elif nobag %}
-int64_t output_j = indices_start + l_start + j;
+overflow_safe_int_t output_j = indices_start + l_start + j;
 {%- endif %}
 
 {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
-[[maybe_unused]] int64_t idx_j = idx_j_[inner_j];
+[[maybe_unused]] auto offset_idx_j = offset_idx_j_[inner_j];
 {%- endif %}
 {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %}
 [[maybe_unused]] {{ locs_or_addrs_type }} {{ locs_or_addrs_idx }}_j
@@ -411,13 +414,13 @@ using namespace fbgemm_gpu;
 auto j = outer_j + inner_j;
 
 {%- if is_index_select %}
-int64_t output_j = L_start + l_start + j;
+overflow_safe_int_t output_j = L_start + l_start + j;
 {%- elif nobag %}
-int64_t output_j = indices_start + l_start + j;
+overflow_safe_int_t output_j = indices_start + l_start + j;
 {%- endif %}
 
 {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
-[[maybe_unused]] int64_t idx_j = idx_j_[inner_j];
+[[maybe_unused]] auto offset_idx_j = offset_idx_j_[inner_j];
 {%- endif %}
 {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %}
 [[maybe_unused]] int32_t {{ locs_or_addrs_idx }}_j = {{ locs_or_addrs_idx }}_j_[inner_j];
@@ -473,13 +476,13 @@ using namespace fbgemm_gpu;
 {%- endif %}
 {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
 // Load index from thread j in the group
-[[maybe_unused]] int64_t idx_j = SHFL_SYNC(idx, j);
+[[maybe_unused]] auto offset_idx_j = SHFL_SYNC(offset_idx, j);
 {%- endif %}
 
 {%- if is_index_select %}
-int64_t output_j = L_start + l_start + j;
+overflow_safe_int_t output_j = L_start + l_start + j;
 {%- elif nobag %}
-int64_t output_j = indices_start + l_start + j;
+overflow_safe_int_t output_j = indices_start + l_start + j;
 {%- endif %}
 
 {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %}
@@ -664,23 +667,23 @@ batch_index_select_dim0_codegen_forward_kernel(
 int32_t T = weights_offsets.size(0);
 
 {%- if is_index_select %}
-index_t indices_start;
+overflow_safe_int_t indices_start;
 int32_t L;
-int32_t L_start;
+overflow_safe_int_t L_start;
 if (t >= T) {
 return;
 }
 const auto total_L_start = total_L_offsets[t];
 const auto total_L = total_L_offsets[t + 1] - total_L_start;
-L_start = b * fixed_L_per_warp;
+L_start = static_cast<overflow_safe_int_t>(b) * fixed_L_per_warp;
 if (L_start >= total_L) {
 return;
 }
 indices_start = total_L_start + L_start;
 L = (total_L - L_start >= fixed_L_per_warp) ? fixed_L_per_warp : (total_L - L_start);
 {%- else %}
 // Determine the number of indices Vec4(pooling factor) to look up within the bag
-index_t indices_start = offsets[b_t];
+overflow_safe_int_t indices_start = offsets[b_t];
 int32_t L = offsets[b_t + 1] - indices_start;
 {%- endif %}
 
@@ -694,8 +697,8 @@ batch_index_select_dim0_codegen_forward_kernel(
 {%- if is_index_select %}
 // Check D in the kernel to avoid iterating through the list on host
 CUDA_KERNEL_ASSERT(D % 4 == 0 && "The column size must be multiple of 4");
-const auto output_offset = permute_output_dim_0_1 ? D_start : output_offsets[t];
-const auto output_stride = permute_output_dim_0_1 ? D_offsets[T] : D;
+const overflow_safe_int_t output_offset = permute_output_dim_0_1 ? D_start : output_offsets[t];
+const overflow_safe_int_t output_stride = permute_output_dim_0_1 ? D_offsets[T] : D;
 {%- endif %}
 
 {%- if is_gwd_kernel %}
@@ -707,7 +710,7 @@ batch_index_select_dim0_codegen_forward_kernel(
 
 // From the Table ID, fetch its weight tensor offset, locate that position
 // in the input weights tensor, and set the weights table pointer
-int64_t weights_offset = weights_offsets[t];
+const auto weights_offset = weights_offsets[t];
 const emb_t* __restrict__ weights;
 {%- if not dense %}
 const auto placement = static_cast<PlacementType>(weights_placements[t]);
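A note on the pattern the diff adopts: each lane now computes its 64-bit element offset (offset_idx = idx * D_emb) once, and the warp broadcasts that offset with SHFL_SYNC, so downstream code indexes &weights[offset_idx_j] directly instead of multiplying a shuffled index by D_emb in every lane. A minimal, self-contained sketch of the same idea, assuming plain __shfl_sync rather than FBGEMM's SHFL_SYNC wrapper; the names gather_rows, table, and row_width are hypothetical and not part of the patch:

    #include <cstdint>

    // Illustrative only: one warp gathers up to 32 rows from a flat table.
    // The per-lane 64-bit offset is computed once and broadcast with
    // __shfl_sync, which accepts 64-bit operands (CUDA 9+). Assumes the
    // kernel is launched with a single warp (blockDim.x == 32).
    __global__ void gather_rows(const float* table,
                                const int64_t* indices,
                                float* out,
                                int32_t row_width,
                                int32_t num_rows_to_gather) {
      const unsigned full_mask = 0xffffffffu;
      const int lane = threadIdx.x & 31;
      // Each lane loads one index and widens the offset math to 64 bits.
      const int64_t idx = lane < num_rows_to_gather ? indices[lane] : 0;
      const int64_t offset = idx * static_cast<int64_t>(row_width);
      for (int j = 0; j < num_rows_to_gather && j < 32; ++j) {
        // Broadcast lane j's precomputed 64-bit offset to the whole warp.
        const int64_t offset_j = __shfl_sync(full_mask, offset, j);
        // Lanes cooperate to copy row j into the output.
        for (int d = lane; d < row_width; d += 32) {
          out[static_cast<int64_t>(j) * row_width + d] = table[offset_j + d];
        }
      }
    }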
