Skip to content

Commit fbe2939

Browse files
jyingl3facebook-github-bot
authored andcommitted
tag KJT related part to sc.INPUT_DIST (#2678)
Summary: tag KJT output to sc.INPUT_DIST to enforce placement on CPU Differential Revision: D67313214
1 parent 411876a commit fbe2939

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

torchrec/quant/embedding_modules.py

+17
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,23 @@ def _fx_trec_unwrap_kjt(
300300
return indices.int(), offsets.int()
301301

302302

303+
@torch.fx.wrap
304+
def _fx_trec_unwrap_jt(
305+
jt: JaggedTensor,
306+
) -> Tuple[torch.Tensor, torch.Tensor]:
307+
"""
308+
Forced conversions to support TBE
309+
CPU - int32 or int64, offsets dtype must match
310+
GPU - int32 only, offsets dtype must match
311+
"""
312+
indices = jt.values()
313+
offsets = jt.offsets()
314+
if jt.device().type == "cpu":
315+
return indices, offsets.type(dtype=indices.dtype)
316+
else:
317+
return indices.int(), offsets.int()
318+
319+
303320
class EmbeddingBagCollection(EmbeddingBagCollectionInterface, ModuleNoCopyMixin):
304321
"""
305322
This class represents a reimplemented version of the EmbeddingBagCollection

0 commit comments

Comments
 (0)