We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 1d9541b commit 30de365Copy full SHA for 30de365
torchrec/quant/embedding_modules.py
@@ -300,6 +300,23 @@ def _fx_trec_unwrap_kjt(
300
return indices.int(), offsets.int()
301
302
303
+@torch.fx.wrap
304
+def _fx_trec_unwrap_jt(
305
+ jt: JaggedTensor,
306
+) -> Tuple[torch.Tensor, torch.Tensor]:
307
+ """
308
+ Forced conversions to support TBE
309
+ CPU - int32 or int64, offsets dtype must match
310
+ GPU - int32 only, offsets dtype must match
311
312
+ indices = jt.values()
313
+ offsets = jt.offsets()
314
+ if jt.device().type == "cpu":
315
+ return indices, offsets.type(dtype=indices.dtype)
316
+ else:
317
+ return indices.int(), offsets.int()
318
+
319
320
class EmbeddingBagCollection(EmbeddingBagCollectionInterface, ModuleNoCopyMixin):
321
"""
322
This class represents a reimplemented version of the EmbeddingBagCollection
0 commit comments