
Commit 9af600e

sryap authored and facebook-github-bot committed

Fix weighted TBE inference NaN (un-init) row_weights (pytorch#4006)

Summary:
Pull Request resolved: pytorch#4006
X-link: facebookresearch/FBGEMM#1093

This diff fixes a problem introduced by D70855331: the `per_sample_weights` in the TBE inference kernels were not properly initialized and became NaNs, causing the embedding lookup output to contain NaNs.

Reviewed By: jwfromm

Differential Revision: D73387999

fbshipit-source-id: 005e41f829cd6e255d3b78a0bb06dfd491ae0ee4

1 parent: 0040646
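To illustrate the failure mode described above, here is a minimal CUDA sketch (illustrative names only, not the FBGEMM kernel): if lanes that handle invalid (pruned) rows skip the shared-memory store, the staged per-sample weight slot keeps whatever value was already in shared memory, possibly NaN, and that NaN poisons the pooled sum. The fix is for the designated lane to always store, writing 0.0 for invalid rows.

// Minimal sketch of the bug and the fix; all names are illustrative.
__global__ void pool_weighted_bag(
    const float* indice_weights, // per-sample weights, one per index
    const int* indices,          // -1 marks a pruned row
    float* out,
    int L) {
  __shared__ float staged_weights[64]; // assumes L <= 64 in this sketch
  const int l = threadIdx.x;
  if (l < L) {
    const bool valid = indices[l] >= 0;
    // Buggy pattern: `if (valid) staged_weights[l] = indice_weights[l];`
    // leaves slots for pruned rows uninitialized (possibly NaN).
    // Fixed pattern: always store, zero-filling invalid rows.
    staged_weights[l] = valid ? indice_weights[l] : 0.0f;
  }
  __syncthreads();
  if (l == 0) {
    float acc = 0.0f;
    for (int i = 0; i < L; ++i) {
      acc += staged_weights[i]; // pruned rows contribute exactly 0
    }
    *out = acc;
  }
}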

3 files changed: +100 −13

fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu

+6 −12

@@ -212,15 +212,12 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
           buffers[warp_idx][i][input_row_idx][row_load_idx] = data;
         }
         {% if weighted %}
-        {%- if is_rocm %}
-        if (valid && row_load_idx == 0) {
+        if (row_load_idx == 0) {
           // Use only one thread to load the index weight to prevent a race
           // condition when writing to the shared memory
-          buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] = indice_weights[indices_starts[i] + L_start + input_row_idx];
+          buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] =
+              valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0;
         }
-        {%- else %}
-        buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0;
-        {%- endif %}
         {% endif %}
       }
     }

@@ -255,15 +252,12 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
           cp_async_zfill_cg<sizeof(uint4)>(&buffers[warp_idx][i][input_row_idx][row_load_idx], &row[row_load_idx], valid);
         }
         {% if weighted %}
-        {%- if is_rocm %}
-        if (valid && row_load_idx == 0) {
+        if (row_load_idx == 0) {
           // Use only one thread to load the index weight to prevent a race
           // condition when writing to the shared memory
-          buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] = indice_weights[indices_starts[i] + L_start + input_row_idx];
+          buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] =
+              valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0;
         }
-        {%- else %}
-        buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0;
-        {%- endif %}
         {% endif %}
       }
       {%- if is_rocm %}
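Note on the kernel change: previously the ROCm branch stored the index weight only when `valid` was true, leaving the shared-memory slot uninitialized for pruned (invalid) rows, while the non-ROCm branch let every `row_load_idx` lane perform the store. The unified code keeps the single-lane store to avoid the write race mentioned in the comment, but always writes, zero-filling invalid rows; this mirrors the zero-fill behavior of the `cp_async_zfill_cg` row load in the second hunk.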

fbgemm_gpu/test/tbe/inference/failures_dict_fast.json

+5 −1

@@ -12,6 +12,7 @@
     "fbgemm::HFP8QuantizedToFloat": {},
     "fbgemm::asynchronous_complete_cumsum": {},
     "fbgemm::bounds_check_indices": {},
+    "fbgemm::check_feature_gate_key": {},
     "fbgemm::dense_embedding_codegen_lookup_function": {
       "BackwardDenseTest.test_autograd_registration__test_backward_dense": {
         "comment": "",

@@ -30,12 +31,14 @@
       }
     },
     "fbgemm::emb_inplace_update": {},
+    "fbgemm::get_infos_metadata": {},
     "fbgemm::get_unique_indices": {
       "LXUCacheTest.test_faketensor__test_unique_lxu_cache_lookup": {
         "comment": "",
         "status": "xfail"
       }
     },
+    "fbgemm::initialize_nan_shared_mem": {},
     "fbgemm::int_nbit_split_embedding_codegen_lookup_function": {
       "NBitForwardTest.test_faketensor__test_nbit_forward_fused_pooled_emb_quant": {
         "comment": "",

@@ -306,6 +309,7 @@
     "fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function_cpu": {},
     "fbgemm::split_embedding_codegen_lookup_rowwise_weighted_adagrad_function": {},
     "fbgemm::split_embedding_codegen_lookup_sgd_function": {},
-    "fbgemm::split_embedding_codegen_lookup_sgd_function_cpu": {}
+    "fbgemm::split_embedding_codegen_lookup_sgd_function_cpu": {},
+    "fbgemm::split_embedding_codegen_lookup_sgd_function_pt2": {}
   }
 }

fbgemm_gpu/test/tbe/inference/nbit_forward_test.py

+89 −0

@@ -115,6 +115,9 @@
             "Operator outputs int4 tensors which do not support opcheck tests"
         ),
     ],
+    "test_faketensor__test_nbit_forward_fused_pooled_emb_quant_nan_weighted": [
+        unittest.skip("Operator not implemented for fake tensors"),
+    ],
 }


@@ -354,6 +357,92 @@ def test_nbit_forward_fused_pooled_emb_quant_against_ref(
             **kwargs,
         )

+    @unittest.skipIf(*gpu_unavailable)
+    def test_nbit_forward_fused_pooled_emb_quant_nan_weighted(self) -> None:
+        # Hash size
+        E = 10
+        # Embedding dimension
+        D = 160
+        # Pooling factor
+        L = 64
+
+        # Use TBE training op as a reference
+        op_ref = SplitTableBatchedEmbeddingBagsCodegen(
+            [
+                (E, D, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
+            ],
+            weights_precision=SparseType.FP32,
+            output_dtype=SparseType.FP32,
+            device=torch.cuda.current_device(),
+        )
+
+        # Instantiate TBE inference
+        op = IntNBitTableBatchedEmbeddingBagsCodegen(
+            embedding_specs=[
+                (
+                    "",
+                    E,
+                    D,
+                    SparseType.INT4,
+                    EmbeddingLocation.DEVICE,
+                ),
+            ],
+            output_dtype=SparseType.FP16,
+        )
+
+        # Initialize weights_ref with 1.0
+        weights_ref = op_ref.split_embedding_weights()
+        weights_ref[0].fill_(1.0)
+
+        # Copy weights_ref to weights
+        op.initialize_weights()
+        weights = op.split_embedding_weights()
+        quant_weights, quant_scale_shift = quantize_embs(
+            weights_ref[0], SparseType.INT4
+        )
+        weights[0][0].copy_(quant_weights)
+        weights[0][1].copy_(quant_scale_shift)
+
+        # Generate inputs
+        indices = torch.as_tensor(
+            [0] * L, device=torch.cuda.current_device(), dtype=torch.int
+        )
+        offsets = torch.as_tensor(
+            [0, L], device=torch.cuda.current_device(), dtype=torch.int
+        )
+        per_sample_weights = torch.arange(
+            L, device=torch.cuda.current_device(), dtype=torch.float
+        )
+
+        # Set a bunch of indices to -1 to simulate pruning.
+        pruned_indices = indices.clone().detach()
+        prune_select = torch.arange(pruned_indices.numel()) % 8 == 0
+        pruned_indices[prune_select] = -1
+
+        # Pre-prune per_sample_weights for reference
+        pruned_per_sample_weights = per_sample_weights.clone().detach()
+        pruned_per_sample_weights[prune_select] = 0.0
+
+        # Run reference
+        output_ref = op_ref(
+            indices=indices,
+            offsets=offsets,
+            per_sample_weights=pruned_per_sample_weights,
+        )
+
+        # Initialize shared memory to NaNs.
+        torch.ops.fbgemm.initialize_nan_shared_mem(torch.cuda.current_device())
+
+        # Run test
+        output = op(
+            indices=pruned_indices,
+            offsets=offsets,
+            per_sample_weights=per_sample_weights,
+        )
+
+        # Expect the outputs to be bit-wise equivalent
+        assert torch.equal(output_ref, output)
+
     @unittest.skipIf(*gpu_unavailable)
     @given(
         T=st.integers(min_value=1, max_value=10),
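As a sanity check on the expected values in this test: with every embedding row filled with 1.0, each element of the weighted-pooled output is just the sum of the surviving per-sample weights. The weights are arange(64), which sums to 2016; pruning every 8th index removes 0 + 8 + ... + 56 = 224, so each of the D = 160 output elements should come out to 1792 in both the reference run and the NaN-seeded inference run. Any uninitialized NaN weight would surface immediately in the bit-wise comparison.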
