You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
change dtype of block_bucketize_row_pos and fix flaky test_kjt_bucketize_before_all2all_cpu (#2689)
Summary:
Pull Request resolved: #2689
# context
* found a test failure from OSS [test run](https://github.com/pytorch/torchrec/actions/runs/12816026713/job/35736016089): P1714445461
* the issue is a recent change (D65912888) incorrectly calling the `_fx_wrap_tensor_to_device_dtype` function
```
block_bucketize_pos=(
_fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths())
if block_bucketize_row_pos is not None
else None
),
```
where `block_bucketize_row_pos: List[torch.tensor]`, but the function only accepts torch.Tensor
```
torch.fx.wrap
def _fx_wrap_tensor_to_device_dtype(
t: torch.Tensor, tensor_device_dtype: torch.Tensor
) -> torch.Tensor:
return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype)
```
* the fix is supposed to be straightforward to apply a list-comprehension over the function
```
block_bucketize_pos=(
[
_fx_wrap_tensor_to_device_dtype(pos, kjt.lengths()) # <---- pay attention here, kjt.lengths()
for pos in block_bucketize_row_pos
]
```
* according to the previous comments, the `block_bucketize_pos`'s `dtype` should be the same as `kjt._length`, however, it triggers the following error
{F1974430883}
* according to the operator implementation ([codepointer](https://fburl.com/code/9gyyl8h4)), the `block_bucketize_pos` should have the same dtype as `kjt._values`.
length has a type name of `offset_t`, values has a type name of `index_t`, the same as `block_bucketize_pos`.
Reviewed By: dstaay-fb
Differential Revision: D68358894
fbshipit-source-id: 13303c54288c99c6cf58d550365f8d3c698c34b1
0 commit comments