
Commit 8b5124f

TroyGarden authored and facebook-github-bot committed
add stride into KJT pytree
Summary: # context * Previously for a KJT, only the following fields and `_keys` are stored in the pytree flatten specs. All other arguments/parameters would be derived accordingly. ``` _fields = [ "_values", "_weights", "_lengths", "_offsets", ] ``` * Particularly, the `stride` (int) of a KJT, which represents the `batch_size`, is computed by `_maybe_compute_stride_kjt`: ``` def _maybe_compute_stride_kjt( keys: List[str], stride: Optional[int], lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor], stride_per_key_per_rank: Optional[List[List[int]]], ) -> int: if stride is None: if len(keys) == 0: stride = 0 elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0: stride = max([sum(s) for s in stride_per_key_per_rank]) elif offsets is not None and offsets.numel() > 0: stride = (offsets.numel() - 1) // len(keys) elif lengths is not None: stride = lengths.numel() // len(keys) else: stride = 0 return stride ``` * The previously stored pytree flatten specs are enough if the `batch_size` is static, however, this no longer holds true in a variable batch size scenario, where the `stride_per_key_per_rank` is not `None`. * An example is that with `dedup_ebc`, where the actual batch_size is variable (depending on the dedup data), but the output of the ebc should always be the **true** `stride` (static). * During ir_export, the output shape will be calculated from `kjt.stride()` function, which would be incorrect if the pytree specs only contains the `keys`. * This diff adds the `stride` into the KJT pytree flatten/unflatten functions so that a fakified KJT would have the correct stride value. Differential Revision: D66400821
1 parent 7f3b7dc commit 8b5124f

File tree

1 file changed (+11, −5 lines)


torchrec/sparse/jagged_tensor.py (+11, −5)
```diff
@@ -3026,13 +3026,17 @@ def dist_init(
 
 def _kjt_flatten(
     t: KeyedJaggedTensor,
-) -> Tuple[List[Optional[torch.Tensor]], List[str]]:
-    return [getattr(t, a) for a in KeyedJaggedTensor._fields], t._keys
+) -> Tuple[List[Optional[torch.Tensor]], Tuple[List[str], int]]:
+    # for variable batch scenario, the stride cannot be computed from lengths/len(keys),
+    # instead, it should be computed from stride_per_key_per_rank, which is not included
+    # in the flatten spec. The stride is needed for the EBC output shape, so we need to
+    # store it in the context.
+    return [getattr(t, a) for a in KeyedJaggedTensor._fields], (t._keys, t.stride())
 
 
 def _kjt_flatten_with_keys(
     t: KeyedJaggedTensor,
-) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], List[str]]:
+) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], Tuple[List[str], int]]:
     values, context = _kjt_flatten(t)
     # pyre can't tell that GetAttrKey implements the KeyEntry protocol
     return [  # pyre-ignore[7]
@@ -3041,9 +3045,11 @@ def _kjt_flatten_with_keys(
 
 
 def _kjt_unflatten(
-    values: List[Optional[torch.Tensor]], context: List[str]  # context is the _keys
+    values: List[Optional[torch.Tensor]],
+    context: List[str],  # context is (_keys, _stride)
 ) -> KeyedJaggedTensor:
-    return KeyedJaggedTensor(context, *values)
+    keys, stride = context
+    return KeyedJaggedTensor(keys, *values, stride=stride)
 
 
 def _kjt_flatten_spec(
```
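The effect of the change can be sketched with a minimal flatten/unflatten round trip. `FakeKJT`, `kjt_flatten`, and `kjt_unflatten` below are hypothetical stand-ins (int lists instead of tensors, no real pytree registration); the point is only the `(keys, stride)` context tuple mirroring the diff:

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class FakeKJT:
    # Hypothetical stand-in for KeyedJaggedTensor with lists instead of tensors.
    keys: List[str]
    values: List[int]
    lengths: List[int]
    stride: int


def kjt_flatten(t: FakeKJT) -> Tuple[List[List[int]], Tuple[List[str], int]]:
    # Store the stride in the context alongside the keys, as in the diff,
    # so it survives flattening even when it cannot be re-derived.
    return [t.values, t.lengths], (t.keys, t.stride)


def kjt_unflatten(
    values: List[List[int]], context: Tuple[List[str], int]
) -> FakeKJT:
    keys, stride = context
    return FakeKJT(keys, values[0], values[1], stride=stride)


# Round trip: the stride of 4 is preserved through flatten/unflatten,
# even though the lengths alone would imply a different batch size.
kjt = FakeKJT(["f1", "f2"], [10, 20, 30], [1, 2], stride=4)
flat_values, ctx = kjt_flatten(kjt)
rebuilt = kjt_unflatten(flat_values, ctx)
```

With the old `List[str]`-only context, `rebuilt` would have had to recompute its stride from `lengths`; carrying the stride in the context makes the fakified KJT report the correct value during ir_export.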
