File tree Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Original file line number Diff line number Diff line change @@ -300,6 +300,23 @@ def _fx_trec_unwrap_kjt(
300
300
return indices .int (), offsets .int ()
301
301
302
302
303
+ @torch .fx .wrap
304
+ def _fx_trec_unwrap_jt (
305
+ jt : JaggedTensor ,
306
+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
307
+ """
308
+ Forced conversions to support TBE
309
+ CPU - int32 or int64, offsets dtype must match
310
+ GPU - int32 only, offsets dtype must match
311
+ """
312
+ indices = jt .values ()
313
+ offsets = jt .offsets ()
314
+ if jt .device ().type == "cpu" :
315
+ return indices , offsets .type (dtype = indices .dtype )
316
+ else :
317
+ return indices .int (), offsets .int ()
318
+
319
+
303
320
class EmbeddingBagCollection (EmbeddingBagCollectionInterface , ModuleNoCopyMixin ):
304
321
"""
305
322
This class represents a reimplemented version of the EmbeddingBagCollection
You can’t perform that action at this time.
0 commit comments