Skip to content

Commit a8c47ee

Browse files
iamzainhudafacebook-github-bot
authored andcommitted
use inverse indices in KJT permute
Summary: calling a kjt.permute() on a VBE KJT makes the output KJT no longer VBE. this diff fixes it such that the output KJT is VBE. Reviewed By: joshuadeng Differential Revision: D65621958
1 parent 42c512c commit a8c47ee

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

torchrec/sparse/jagged_tensor.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -2489,7 +2489,10 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
24892489
return split_list
24902490

24912491
def permute(
2492-
self, indices: List[int], indices_tensor: Optional[torch.Tensor] = None
2492+
self,
2493+
indices: List[int],
2494+
indices_tensor: Optional[torch.Tensor] = None,
2495+
include_inverse_indices: bool = False,
24932496
) -> "KeyedJaggedTensor":
24942497
"""
24952498
Permutes the KeyedJaggedTensor.
@@ -2587,7 +2590,9 @@ def permute(
25872590
offset_per_key=None,
25882591
index_per_key=None,
25892592
jt_dict=None,
2590-
inverse_indices=None,
2593+
inverse_indices=(
2594+
self.inverse_indices_or_none() if include_inverse_indices else None
2595+
),
25912596
)
25922597
return kjt
25932598

torchrec/sparse/tests/test_jagged_tensor.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -1372,16 +1372,27 @@ def test_permute_vb(self) -> None:
13721372
lengths = torch.IntTensor([1, 0, 1, 3, 0, 1, 0, 2, 0])
13731373
keys = ["index_0", "index_1", "index_2"]
13741374
stride_per_key_per_rank = [[2], [4], [3]]
1375+
inverse_indices = (
1376+
["index_0", "index_1", "index_2"],
1377+
torch.Tensor(
1378+
[
1379+
[0, 0, 0, 0, 0, 1, 1, 1, 1],
1380+
[0, 0, 1, 1, 3, 3, 2, 2, 1],
1381+
[2, 2, 1, 0, 0, 2, 1, 2, 0],
1382+
]
1383+
),
1384+
)
13751385

13761386
jag_tensor = KeyedJaggedTensor.from_lengths_sync(
13771387
values=values,
13781388
keys=keys,
13791389
lengths=lengths,
13801390
stride_per_key_per_rank=stride_per_key_per_rank,
1391+
inverse_indices=inverse_indices,
13811392
)
13821393

13831394
indices = [1, 0, 2]
1384-
permuted_jag_tensor = jag_tensor.permute(indices)
1395+
permuted_jag_tensor = jag_tensor.permute(indices, include_inverse_indices=True)
13851396

13861397
self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"])
13871398
self.assertEqual(
@@ -1401,6 +1412,15 @@ def test_permute_vb(self) -> None:
14011412
)
14021413
)
14031414
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
1415+
self.assertEqual(
1416+
jag_tensor.inverse_indices()[0], permuted_jag_tensor.inverse_indices()[0]
1417+
)
1418+
self.assertTrue(
1419+
torch.equal(
1420+
jag_tensor.inverse_indices()[1],
1421+
permuted_jag_tensor.inverse_indices()[1],
1422+
)
1423+
)
14041424

14051425
def test_permute_vb_duplicate(self) -> None:
14061426
values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])

0 commit comments

Comments
 (0)