
Commit a1d1c20

TroyGarden authored and facebook-github-bot committed
fix failed test_kjt_bucketize_before_all2all_cpu
Summary:

# context

* Found a test failure in the OSS [test run](https://github.com/pytorch/torchrec/actions/runs/12816026713/job/35736016089): P1714445461
* The issue is that a recent change (D65912888) incorrectly calls the `_fx_wrap_tensor_to_device_dtype` function:
```
block_bucketize_pos=(
    _fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths())
    if block_bucketize_row_pos is not None
    else None
),
```
  Here `block_bucketize_row_pos: List[torch.Tensor]`, but the function only accepts a single `torch.Tensor`:
```
@torch.fx.wrap
def _fx_wrap_tensor_to_device_dtype(
    t: torch.Tensor, tensor_device_dtype: torch.Tensor
) -> torch.Tensor:
    return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype)
```
* The fix seems straightforward: apply a list comprehension over the function:
```
block_bucketize_pos=(
    [
        _fx_wrap_tensor_to_device_dtype(pos, kjt.lengths())  # <---- pay attention here, kjt.lengths()
        for pos in block_bucketize_row_pos
    ]
```
* According to the previous comments, the dtype of `block_bucketize_pos` should match `kjt._length`; however, that triggers the following error: {F1974430883}
* According to the operator implementation ([code pointer](https://fburl.com/code/9gyyl8h4)), `block_bucketize_pos` should instead have the same dtype as `kjt._values`: lengths use the type name `offset_t`, while values use the type name `index_t`, the same as `block_bucketize_pos`.

Reviewed By: dstaay-fb

Differential Revision: D68358894
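A minimal, self-contained sketch of the failure mode and the fix described above. The helper name `to_device_dtype` and the example tensors are illustrative stand-ins, not TorchRec code:

```python
import torch

# Sketch of the helper: mirrors _fx_wrap_tensor_to_device_dtype,
# minus the @torch.fx.wrap decorator.
def to_device_dtype(t: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
    return t.to(device=like.device, dtype=like.dtype)

# Hypothetical stand-ins for kjt.lengths() (offset_t) and kjt.values() (index_t).
lengths = torch.tensor([2, 1, 3], dtype=torch.int32)
values = torch.tensor([10, 11, 20, 30, 31, 32], dtype=torch.int64)

# block_bucketize_row_pos is a List[torch.Tensor]: one bucket-boundary
# tensor per feature. Passing the list itself to the helper fails,
# since a Python list has no .to() method.
block_bucketize_row_pos = [
    torch.tensor([0, 5, 10], dtype=torch.int32),
    torch.tensor([0, 4, 10], dtype=torch.int32),
]

# The fix: map the helper over each tensor, matching the dtype of values
# (index_t) rather than lengths (offset_t), per the operator's C++ types.
block_bucketize_pos = [to_device_dtype(pos, values) for pos in block_bucketize_row_pos]
assert all(p.dtype == values.dtype for p in block_bucketize_pos)
```

Targeting `values` rather than `lengths` reflects the operator implementation cited above, where `block_bucketize_pos` shares the `index_t` type with the values tensor.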
1 parent 9dfdfb8 commit a1d1c20

File tree

2 files changed, +5 −2 lines


torchrec/distributed/embedding_sharding.py

+4 −1

```diff
@@ -274,7 +274,10 @@ def bucketize_kjt_before_all2all(
         batch_size_per_feature=_fx_wrap_batch_size_per_feature(kjt),
         max_B=_fx_wrap_max_B(kjt),
         block_bucketize_pos=(
-            _fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths())
+            [
+                _fx_wrap_tensor_to_device_dtype(pos, kjt.values())
+                for pos in block_bucketize_row_pos
+            ]
             if block_bucketize_row_pos is not None
             else None
         ),
```

torchrec/distributed/tests/test_utils.py

+1 −1

```diff
@@ -364,7 +364,7 @@ def test_kjt_bucketize_before_all2all(
         batch_size=st.integers(1, 15),
         variable_bucket_pos=st.booleans(),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
     def test_kjt_bucketize_before_all2all_cpu(
         self,
         index_type: torch.dtype,
```
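For context, a minimal Hypothesis sketch of the `@settings` knob changed above. The test name and body are illustrative, and the motivation (drawing more random parameter combinations per run) is an assumption, not stated in the commit:

```python
import torch
from hypothesis import Verbosity, given, settings, strategies as st

@given(batch_size=st.integers(1, 15), variable_bucket_pos=st.booleans())
# max_examples=20 makes Hypothesis draw 20 random parameter combinations
# per run (up from 5); deadline=None disables the per-example timeout.
@settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
def test_bucket_pos_dtype(batch_size: int, variable_bucket_pos: bool) -> None:
    # Hypothetical body: convert a bucket-position tensor to the values dtype.
    pos = torch.arange(batch_size + 1, dtype=torch.int32)
    assert pos.to(dtype=torch.int64).dtype == torch.int64
```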
