You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Summary:
# context
* found a test failure from OSS [test run](https://github.com/pytorch/torchrec/actions/runs/12816026713/job/35736016089): P1714445461
* the issue is a recent change (D65912888) incorrectly calling the `_fx_wrap_tensor_to_device_dtype` function
```
block_bucketize_pos=(
_fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths())
if block_bucketize_row_pos is not None
else None
),
```
where `block_bucketize_row_pos: List[torch.tensor]`, but the function only accepts torch.Tensor
```
torch.fx.wrap
def _fx_wrap_tensor_to_device_dtype(
t: torch.Tensor, tensor_device_dtype: torch.Tensor
) -> torch.Tensor:
return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype)
```
* the fix is supposed to be straightforward to apply a list-comprehension over the function
```
block_bucketize_pos=(
[
_fx_wrap_tensor_to_device_dtype(pos, kjt.lengths()) # <---- pay attention here, kjt.lengths()
for pos in block_bucketize_row_pos
]
```
* according to the previous comments, the `block_bucketize_pos`'s `dtype` should be the same as `kjt._length`, however, it triggers the following error
{F1974430883}
* according to the operator implementation ([codepointer](https://fburl.com/code/9gyyl8h4)), the `block_bucketize_pos` should have the same dtype as `kjt._values`.
length has a type name of `offset_t`, values has a type name of `index_t`, the same as `block_bucketize_pos`.
Reviewed By: dstaay-fb
Differential Revision: D68358894
0 commit comments