Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 76446e7

Browse files
che-shfacebook-github-bot
authored andcommittedMar 19, 2025·
JaggedTensor permute - less CPU ops (#2786)
Summary: Pull Request resolved: #2786 `JaggedTensor.permute` could be called with very large `indices` list (a few hundred items) - so calling python properties `self.keys()`, `self.variable_stride_per_key()` and `self.stride_per_key_per_rank()` in the loop over indices start to compound and take noticeable time **on CPU**. Reviewed By: sarckk Differential Revision: D70609204 fbshipit-source-id: 257a9a45b204514eef932afcf7df958b194912b6
1 parent 055119e commit 76446e7

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed
 

‎torchrec/sparse/jagged_tensor.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -2511,22 +2511,23 @@ def permute(
25112511
permuted_stride_per_key_per_rank: List[List[int]] = []
25122512
permuted_length_per_key: List[int] = []
25132513
permuted_length_per_key_sum = 0
2514+
keys = self._keys
2515+
variable_stride_per_key = self.variable_stride_per_key()
2516+
stride_per_key_per_rank = self.stride_per_key_per_rank()
25142517
for index in indices:
2515-
key = self.keys()[index]
2518+
key = keys[index]
25162519
permuted_keys.append(key)
25172520
permuted_length_per_key.append(length_per_key[index])
2518-
if self.variable_stride_per_key():
2519-
permuted_stride_per_key_per_rank.append(
2520-
self.stride_per_key_per_rank()[index]
2521-
)
2521+
if variable_stride_per_key:
2522+
permuted_stride_per_key_per_rank.append(stride_per_key_per_rank[index])
25222523

25232524
permuted_length_per_key_sum = sum(permuted_length_per_key)
25242525
if not torch.jit.is_scripting() and is_non_strict_exporting():
25252526
torch._check_is_size(permuted_length_per_key_sum)
25262527
torch._check(permuted_length_per_key_sum != -1)
25272528
torch._check(permuted_length_per_key_sum != 0)
25282529

2529-
if self.variable_stride_per_key():
2530+
if variable_stride_per_key:
25302531
length_per_key_tensor = _pin_and_move(
25312532
torch.tensor(self.length_per_key()), self.device()
25322533
)
@@ -2571,7 +2572,7 @@ def permute(
25712572
permuted_length_per_key_sum,
25722573
)
25732574
stride_per_key_per_rank = (
2574-
permuted_stride_per_key_per_rank if self.variable_stride_per_key() else None
2575+
permuted_stride_per_key_per_rank if variable_stride_per_key else None
25752576
)
25762577
kjt = KeyedJaggedTensor(
25772578
keys=permuted_keys,

0 commit comments

Comments
 (0)
Please sign in to comment.