
Commit 534df45

add stride into KJT pytree (#2587)

TroyGarden authored and facebook-github-bot committed
Summary:

# context

* Previously for a KJT, only the following fields, plus `_keys`, were stored in the pytree flatten spec; all other arguments/parameters were derived from them:

```
_fields = [
    "_values",
    "_weights",
    "_lengths",
    "_offsets",
]
```

* In particular, the `stride` (int) of a KJT, which represents the batch size, is computed by `_maybe_compute_stride_kjt`:

```
def _maybe_compute_stride_kjt(
    keys: List[str],
    stride: Optional[int],
    lengths: Optional[torch.Tensor],
    offsets: Optional[torch.Tensor],
    stride_per_key_per_rank: Optional[List[List[int]]],
) -> int:
    if stride is None:
        if len(keys) == 0:
            stride = 0
        elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
            stride = max([sum(s) for s in stride_per_key_per_rank])
        elif offsets is not None and offsets.numel() > 0:
            stride = (offsets.numel() - 1) // len(keys)
        elif lengths is not None:
            stride = lengths.numel() // len(keys)
        else:
            stride = 0
    return stride
```

* The previously stored pytree flatten spec is sufficient when the batch size is static; that no longer holds in a variable-batch-size scenario, where `stride_per_key_per_rank` is not `None`.
* One example is `dedup_ebc`, where the actual batch size varies (it depends on the dedup data) but the output of the EBC should always use the **true** (static) `stride`.
* During ir_export, the output shape is calculated from the `kjt.stride()` function, which would be incorrect if the pytree spec only contains the `keys`.
* This diff adds the `stride` to the KJT pytree flatten/unflatten functions so that a fakified KJT has the correct stride value.

Differential Revision: D66400821
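The stride derivation quoted above can be illustrated with a minimal pure-Python sketch. The helper name `maybe_compute_stride` and its count-based parameters are hypothetical stand-ins for this example; the real `_maybe_compute_stride_kjt` operates on `torch.Tensor`s, so tensor `.numel()` calls are replaced here by plain element counts:

```python
# Hypothetical pure-Python sketch of the stride derivation (not the torchrec API);
# lengths_numel / offsets_numel stand in for tensor.numel() in the real function.
from typing import List, Optional


def maybe_compute_stride(
    num_keys: int,
    stride: Optional[int] = None,
    lengths_numel: Optional[int] = None,
    offsets_numel: Optional[int] = None,
    stride_per_key_per_rank: Optional[List[List[int]]] = None,
) -> int:
    if stride is not None:
        return stride
    if num_keys == 0:
        return 0
    if stride_per_key_per_rank:
        # variable batch: stride is the max total batch size across keys
        return max(sum(per_rank) for per_rank in stride_per_key_per_rank)
    if offsets_numel:
        # offsets has one extra leading element relative to lengths
        return (offsets_numel - 1) // num_keys
    if lengths_numel is not None:
        # one length entry per (key, sample): batch size = entries per key
        return lengths_numel // num_keys
    return 0
```

For example, 2 keys with 8 length entries (or 9 offset entries) both give a stride of 4, and a variable-batch spec of `[[2, 1], [3, 1]]` also resolves to 4 via the per-key maximum.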
1 parent 8afe20e commit 534df45

File tree: 1 file changed, +11 −5


torchrec/sparse/jagged_tensor.py (+11, −5)
```diff
@@ -3031,13 +3031,17 @@ def dist_init(
 
 def _kjt_flatten(
     t: KeyedJaggedTensor,
-) -> Tuple[List[Optional[torch.Tensor]], List[str]]:
-    return [getattr(t, a) for a in KeyedJaggedTensor._fields], t._keys
+) -> Tuple[List[Optional[torch.Tensor]], Tuple[List[str], int]]:
+    # for variable batch scenario, the stride cannot be computed from lengths/len(keys),
+    # instead, it should be computed from stride_per_key_per_rank, which is not included
+    # in the flatten spec. The stride is needed for the EBC output shape, so we need to
+    # store it in the context.
+    return [getattr(t, a) for a in KeyedJaggedTensor._fields], (t._keys, t.stride())
 
 
 def _kjt_flatten_with_keys(
     t: KeyedJaggedTensor,
-) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], List[str]]:
+) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], Tuple[List[str], int]]:
     values, context = _kjt_flatten(t)
     # pyre can't tell that GetAttrKey implements the KeyEntry protocol
     return [  # pyre-ignore[7]
@@ -3046,9 +3050,11 @@ def _kjt_flatten_with_keys(
 
 
 def _kjt_unflatten(
-    values: List[Optional[torch.Tensor]], context: List[str]  # context is the _keys
+    values: List[Optional[torch.Tensor]],
+    context: List[str],  # context is (_keys, _stride)
 ) -> KeyedJaggedTensor:
-    return KeyedJaggedTensor(context, *values)
+    keys, stride = context
+    return KeyedJaggedTensor(keys, *values, stride=stride)
 
 
 def _kjt_flatten_spec(
```
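To see why the context must carry the stride, here is a toy stand-in (not the torchrec implementation; `FakeKJT`, `fake_flatten`, and `fake_unflatten` are hypothetical names for this sketch) showing the new `(keys, stride)` context round-tripping through flatten and unflatten:

```python
# Hypothetical toy model of the flatten/unflatten contract changed by this diff:
# the context is now (keys, stride) rather than keys alone, so unflatten can
# restore a variable-batch stride instead of re-deriving it (incorrectly) later.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class FakeKJT:
    keys: List[str]
    values: List[float]
    stride: int  # the "true" batch size, not recoverable from values alone


def fake_flatten(t: FakeKJT) -> Tuple[List[List[float]], Tuple[List[str], int]]:
    # tensor-like leaves go into values; static metadata goes into the context
    return [t.values], (t.keys, t.stride)


def fake_unflatten(
    values: List[List[float]], context: Tuple[List[str], int]
) -> FakeKJT:
    keys, stride = context
    return FakeKJT(keys=keys, values=values[0], stride=stride)


# round trip: the stride survives even though the leaves cannot encode it
kjt = FakeKJT(keys=["f1", "f2"], values=[1.0, 2.0, 3.0], stride=4)
leaves, ctx = fake_flatten(kjt)
restored = fake_unflatten(leaves, ctx)
```

With only the keys in the context (the old behavior), `restored.stride` would have to be re-derived from the leaves, which is exactly what goes wrong for variable-batch KJTs during ir_export.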
