@@ -60,7 +60,7 @@ using namespace fbgemm_gpu;
 
     In-code variables that are defined outside:
         emb_t, cache_t, cache_t
-        idx_j
+        offset_idx_j
         D_emb
         lxu_cache_weights
         {{ locs_or_addrs_idx }}_j
@@ -96,7 +96,7 @@ using namespace fbgemm_gpu;
         {%- if from_cache %}
         cache_weights, // Load from the cache
         {%- else %}
-        &weights[idx_j * D_emb], // Load from the embedding table
+        &weights[offset_idx_j], // Load from the embedding table
         {%- endif %}
         D);
 
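For readers going through the hunks top-down: `offset_idx_j` is the row offset `idx_j * D_emb` precomputed once per index (see the hunk around line 289 below), so the load site can index `weights` directly instead of multiplying at every use. A minimal sketch of the equivalence, assuming `overflow_safe_int_t` is a 64-bit alias; the helper names are illustrative, not from the source:

    #include <cstdint>

    using overflow_safe_int_t = int64_t;  // assumption: a 64-bit alias
    using emb_t = float;                  // stand-in for the template's emb_t

    // Old form: recompute the element offset from the shuffled index at each use.
    const emb_t* row_from_index(const emb_t* weights, int64_t idx_j, int32_t D_emb) {
      return &weights[idx_j * D_emb];
    }

    // New form: index with the offset that was computed once and then shuffled.
    const emb_t* row_from_offset(const emb_t* weights, overflow_safe_int_t offset_idx_j) {
      return &weights[offset_idx_j];
    }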
@@ -158,7 +158,7 @@ using namespace fbgemm_gpu;
 
     In-code variables that are defined outside:
        emb_t, cache_t, cache_t
-       idx_j
+       offset_idx_j
        inner_j
        D_emb
        lxu_cache_weights
@@ -194,7 +194,7 @@ using namespace fbgemm_gpu;
         {%- if from_cache %}
         cache_weights, // Load from the cache
         {%- else %}
-        &weights[idx_j * D_emb], // Load from the embedding table
+        &weights[offset_idx_j], // Load from the embedding table
         {%- endif %}
         D);
 
@@ -243,7 +243,7 @@ using namespace fbgemm_gpu;
 
     In-code variables that are defined outside:
        emb_t, cache_t, cache_t
-       idx_j
+       offset_idx_j
        inner_j
        D_emb
        lxu_cache_weights
@@ -289,20 +289,23 @@ using namespace fbgemm_gpu;
     // Determine the L index that this thread will load data from in cooperative load
     auto l = l_start + threadIdx.x;
 
-    {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
+    {%- if (
+            dense
+            or (lxu_miss_rate != "cache_conflict_miss_rate::zero")
+            or (lxu_miss_rate == "cache_conflict_miss_rate::zero" and is_gwd_kernel)
+        )
+    %}
     // Cooperatively load the indices
-    [[maybe_unused]] int64_t idx = l < L ? indices[indices_start + l] : 0;
+    const overflow_safe_int_t idx = l < L ? indices[indices_start + l] : 0;
+    // If idx is loaded
+    const auto offset_idx = idx * D_emb;
     {%- endif %}
 
     {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %}
     // Cooperatively load the cache's indices
     [[maybe_unused]] {{ locs_or_addrs_type }} {{ locs_or_addrs_idx }} = (use_lxu_cache && placement == PlacementType::MANAGED_CACHING && l < L) ? {{ locs_or_addrs_tensor }}[indices_start + l] : 0;
     {%- endif %}
 
-    {%- if lxu_miss_rate == "cache_conflict_miss_rate::zero" and is_gwd_kernel %}
-    int64_t idx = l < L ? indices[indices_start + l] : 0; // only used for accessing prev_iter
-    {%- endif %}
-
     {%- if is_gwd_kernel %}
     // if l > L or prev_iter == 0, global_weight_decay = 1
     const auto prev_it = prev_iter[idx];
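The hunk above is where the rest of the changes originate: the cooperatively loaded index is now held in `overflow_safe_int_t`, and the row offset `idx * D_emb` is computed once, before the value is broadcast to the other lanes, instead of being recomputed from the shuffled index at every load site. Below is a minimal CUDA sketch of that precompute-then-broadcast pattern under stated assumptions: `overflow_safe_int_t` is treated as a 64-bit alias, one 32-lane warp per block, and `__shfl_sync` stands in for the template's `SHFL_SYNC` wrapper; the kernel and its arguments are illustrative, not FBGEMM's API.

    #include <cstdint>

    using overflow_safe_int_t = int64_t;  // assumption: a 64-bit alias

    // Illustrative kernel: each lane widens its index and multiplies once,
    // then every lane's precomputed row offset is broadcast with a warp
    // shuffle and used directly for indexing.
    __global__ void row_sum_kernel(
        const float* weights,      // [num_rows * D_emb]
        const int64_t* indices,    // [L] row ids to gather
        int32_t D_emb,
        int32_t L,                 // L <= 32 for this sketch
        float* out) {              // [32] one partial sum per lane
      const int32_t l = threadIdx.x;
      const overflow_safe_int_t idx = (l < L) ? indices[l] : 0;
      const overflow_safe_int_t offset_idx = idx * D_emb;  // widen, multiply once

      float acc = 0.f;
      for (int32_t j = 0; j < L; ++j) {
        // Broadcast lane j's precomputed offset to the whole warp.
        const overflow_safe_int_t offset_idx_j =
            __shfl_sync(0xffffffffu, offset_idx, j);
        // Every lane indexes the row directly; no per-use multiplication.
        acc += weights[offset_idx_j + (l % D_emb)];
      }
      out[l] = acc;
    }

Shuffling the offset rather than the raw index also means the 64-bit multiply happens once per loaded index rather than once per use site after the shuffle.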
@@ -323,10 +326,10 @@ using namespace fbgemm_gpu;
         {
         {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
         // Load index from thread j in the group
-        [[maybe_unused]] int64_t idx_j_[kManualUnrollLength];
+        overflow_safe_int_t offset_idx_j_[kManualUnrollLength];
         for (auto inner_j = 0; inner_j < kManualUnrollLength; ++inner_j)
         {
-            idx_j_[inner_j] = SHFL_SYNC(idx, outer_j + inner_j);
+            offset_idx_j_[inner_j] = SHFL_SYNC(offset_idx, outer_j + inner_j);
         }
         {%- endif %}
 
@@ -353,13 +356,13 @@ using namespace fbgemm_gpu;
         {
             auto j = outer_j + inner_j;
             {%- if is_index_select %}
-            int64_t output_j = L_start + l_start + j;
+            overflow_safe_int_t output_j = L_start + l_start + j;
             {%- elif nobag %}
-            int64_t output_j = indices_start + l_start + j;
+            overflow_safe_int_t output_j = indices_start + l_start + j;
             {%- endif %}
 
             {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
-            [[maybe_unused]] int64_t idx_j = idx_j_[inner_j];
+            [[maybe_unused]] auto offset_idx_j = offset_idx_j_[inner_j];
             {%- endif %}
             {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %}
             [[maybe_unused]] {{ locs_or_addrs_type }} {{ locs_or_addrs_idx }}_j
@@ -411,13 +414,13 @@ using namespace fbgemm_gpu;
             auto j = outer_j + inner_j;
 
             {%- if is_index_select %}
-            int64_t output_j = L_start + l_start + j;
+            overflow_safe_int_t output_j = L_start + l_start + j;
             {%- elif nobag %}
-            int64_t output_j = indices_start + l_start + j;
+            overflow_safe_int_t output_j = indices_start + l_start + j;
             {%- endif %}
 
             {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
-            [[maybe_unused]] int64_t idx_j = idx_j_[inner_j];
+            [[maybe_unused]] auto offset_idx_j = offset_idx_j_[inner_j];
             {%- endif %}
             {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %}
             [[maybe_unused]] int32_t {{ locs_or_addrs_idx }}_j = {{ locs_or_addrs_idx }}_j_[inner_j];
@@ -473,13 +476,13 @@ using namespace fbgemm_gpu;
         {%- endif %}
         {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %}
         // Load index from thread j in the group
-        [[maybe_unused]] int64_t idx_j = SHFL_SYNC(idx, j);
+        [[maybe_unused]] auto offset_idx_j = SHFL_SYNC(offset_idx, j);
         {%- endif %}
 
         {%- if is_index_select %}
-        int64_t output_j = L_start + l_start + j;
+        overflow_safe_int_t output_j = L_start + l_start + j;
         {%- elif nobag %}
-        int64_t output_j = indices_start + l_start + j;
+        overflow_safe_int_t output_j = indices_start + l_start + j;
        {%- endif %}
 
         {%- if not dense and lxu_miss_rate != "cache_conflict_miss_rate::all" %}
@@ -664,23 +667,23 @@ batch_index_select_dim0_codegen_forward_kernel(
     int32_t T = weights_offsets.size(0);
 
     {%- if is_index_select %}
-    index_t indices_start;
+    overflow_safe_int_t indices_start;
     int32_t L;
-    int32_t L_start;
+    overflow_safe_int_t L_start;
     if (t >= T) {
         return;
     }
     const auto total_L_start = total_L_offsets[t];
     const auto total_L = total_L_offsets[t + 1] - total_L_start;
-    L_start = b * fixed_L_per_warp;
+    L_start = static_cast<overflow_safe_int_t>(b) * fixed_L_per_warp;
     if (L_start >= total_L) {
         return;
     }
     indices_start = total_L_start + L_start;
     L = (total_L - L_start >= fixed_L_per_warp) ? fixed_L_per_warp : (total_L - L_start);
     {%- else %}
     // Determine the number of indices Vec4(pooling factor) to look up within the bag
-    index_t indices_start = offsets[b_t];
+    overflow_safe_int_t indices_start = offsets[b_t];
     int32_t L = offsets[b_t + 1] - indices_start;
     {%- endif %}
 
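One detail in the hunk above worth calling out: in `L_start = static_cast<overflow_safe_int_t>(b) * fixed_L_per_warp;`, the cast is what makes the line safe. With two 32-bit operands the multiplication itself is performed in 32-bit arithmetic and can overflow before the result is ever widened for the assignment; casting one operand first forces a 64-bit multiply. A minimal host-side sketch of the difference, assuming `overflow_safe_int_t` is a 64-bit alias and using made-up values large enough to exceed 32 bits:

    #include <cstdint>
    #include <iostream>

    using overflow_safe_int_t = int64_t;  // assumption: a 64-bit alias

    int main() {
      const int32_t b = 70000;                 // illustrative warp id
      const int32_t fixed_L_per_warp = 40000;  // illustrative per-warp length

      // 32-bit multiply overflows (undefined behaviour, typically wraps)
      // before the widening assignment takes place.
      const overflow_safe_int_t wrapped = b * fixed_L_per_warp;

      // Widening one operand first makes the multiply itself 64-bit.
      const overflow_safe_int_t safe =
          static_cast<overflow_safe_int_t>(b) * fixed_L_per_warp;

      std::cout << wrapped << " vs " << safe << '\n';  // 2,800,000,000 exceeds INT32_MAX
      return 0;
    }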
@@ -694,8 +697,8 @@ batch_index_select_dim0_codegen_forward_kernel(
     {%- if is_index_select %}
     // Check D in the kernel to avoid iterating through the list on host
     CUDA_KERNEL_ASSERT(D % 4 == 0 && "The column size must be multiple of 4");
-    const auto output_offset = permute_output_dim_0_1 ? D_start : output_offsets[t];
-    const auto output_stride = permute_output_dim_0_1 ? D_offsets[T] : D;
+    const overflow_safe_int_t output_offset = permute_output_dim_0_1 ? D_start : output_offsets[t];
+    const overflow_safe_int_t output_stride = permute_output_dim_0_1 ? D_offsets[T] : D;
     {%- endif %}
 
     {%- if is_gwd_kernel %}
@@ -707,7 +710,7 @@ batch_index_select_dim0_codegen_forward_kernel(
 
     // From the Table ID, fetch its weight tensor offset, locate that position
     // in the input weights tensor, and set the weights table pointer
-    int64_t weights_offset = weights_offsets[t];
+    const auto weights_offset = weights_offsets[t];
     const emb_t* __restrict__ weights;
     {%- if not dense %}
     const auto placement = static_cast<PlacementType>(weights_placements[t]);