Commit fdc3144

TroyGarden authored and facebook-github-bot committed
add NJT/TD support for EBC and pipeline benchmark
Summary:

# Documents
* [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv)
* [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79)

{F1949248817}

# Context
* As depicted above, we are extending the TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict).
* Basically, we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EBC ==> Output (KT)`
* In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT (a minimal sketch follows this summary).
* In distributed mode, we do the conversion inside `ShardedEmbeddingBagCollection`, specifically in `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication.
* In the KJT scenario, the input KJT is permuted (and partially duplicated in some cases) before the `KJTAllToAll` communication. In the TD scenario, the input TD is converted directly into the permuted KJT that the following `KJTAllToAll` communication expects.
* ref: D63436011

# Details
* `td_to_kjt` is implemented in Python, which carries a CPU perf regression. It is not on the training critical path, though, so it has minimal impact on the overall training QPS (see the benchmark results in the test plan).
* Currently only the EBC use case is supported.

WARNING: `TensorDict` supports **neither** weighted jagged tensors **nor** variable batch sizes.

NOTE: All the following comparisons are between **`KJT.permute`** in the KJT input scenario and the **`TD-KJT conversion`** in the TD input scenario.

* Both `KJT.permute` and the `TD-KJT conversion` are correctly marked in the `TrainPipelineBase` traces. The `TD-KJT conversion` has more real execution on the CPU, but the heavy-lifting computation is on the GPU, which is delayed/blocked by the backward pass of the previous batch. The GPU runtime difference is small, ~10%. {F1949366822}
* For the `Copy-Batch-To-GPU` part, TD has more fragmented `HtoD` comms while KJT has a single contiguous `HtoD` comm. Runtime-wise they are similar, within ~10%. {F1949374305}
* In the most commonly used `TrainPipelineSparseDist`, where `Copy-Batch-To-GPU` and the CPU runtime are not on the critical path, we observe very similar training QPS in the pipeline benchmark, within ~1%. {F1949390271}
```
TrainPipelineSparseDist | Runtime (P90): 26.737 s | Memory (P90): 34.801 GB (TD)
TrainPipelineSparseDist | Runtime (P90): 26.539 s | Memory (P90): 34.765 GB (KJT)
```
* With the increased data size, the GPU runtime is 4x. {F1949386106}

# Conclusion
1. [Enablement] With this approach (replacing `KJT.permute` with the `TD-KJT conversion`), the EBC can now take `TensorDict` as the module input in both single-GPU and multi-GPU (sharded) scenarios, tested with TrainPipelineBase, TrainPipelineSparseDist, TrainPipelineSemiSync, and TrainPipelinePrefetch.
2. [Performance] The TD host-to-device data transfer is not necessarily a concern/blocker for the most commonly used train pipeline (TrainPipelineSparseDist).
3. [Feature Support] To become production-ready, TensorDict needs to (1) integrate the `KJT.weights` data and (2) support variable batch sizes, both of which are used in almost all production models.
4. [Improvement] There are two major operations we can improve: (1) moving the TensorDict from host to device, and (2) converting TD to KJT. Both are currently in a vanilla state; since we are not sure what the real traces would look like with production models, we can't tell whether these improvements are needed/helpful.

Differential Revision: D65103519
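As referenced in the Context section above, here is a minimal eager-mode sketch: the same jagged batch is fed to an unsharded `EmbeddingBagCollection` once as a `KeyedJaggedTensor` and once as a `TensorDict`. The table/feature names and sizes are made up, and the TensorDict-of-NJT construction (`torch.nested` with `layout=torch.jagged`, `batch_size=[]`) is an assumption based on the linked design docs rather than a tested recipe.

```python
import torch
from tensordict import TensorDict
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# A single table with a single feature; names and sizes are illustrative.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="table_0",
            num_embeddings=100,
            embedding_dim=8,
            feature_names=["feature_0"],
        )
    ]
)

# Existing KJT path: 2 samples holding 2 and 1 ids for feature_0.
kjt = KeyedJaggedTensor(
    keys=["feature_0"],
    values=torch.tensor([1, 2, 3]),
    lengths=torch.tensor([2, 1]),
)
kt_from_kjt = ebc(kjt)  # KeyedTensor with values of shape [2, 8]

# New TD path: the same jagged data as a jagged-layout nested tensor (NJT);
# EBC.forward converts it to a KJT via td_to_kjt before pooling.
td = TensorDict(
    {
        "feature_0": torch.nested.nested_tensor(
            [torch.tensor([1, 2]), torch.tensor([3])], layout=torch.jagged
        )
    },
    batch_size=[],  # assumed; batch dims are not needed for the conversion
)
kt_from_td = ebc(td)  # same KeyedTensor layout as kt_from_kjt
```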
1 parent 6b37bc0 commit fdc3144

File tree

4 files changed: +101 -44 lines changed

torchrec/distributed/embeddingbag.py

+66-40
```diff
@@ -27,6 +27,7 @@
 
 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
 from torch.distributed._tensor import DTensor
@@ -90,7 +91,12 @@
 )
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
-from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import (
+    _to_offsets,
+    KeyedJaggedTensor,
+    KeyedTensor,
+    td_to_kjt,
+)
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -655,9 +661,7 @@ def __init__(
         self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
         # to support mean pooling callback hook
         self._has_mean_pooling_callback: bool = (
-            True
-            if PoolingType.MEAN.value in self._pooling_type_to_rs_features
-            else False
+            PoolingType.MEAN.value in self._pooling_type_to_rs_features
         )
         self._dim_per_key: Optional[torch.Tensor] = None
         self._kjt_key_indices: Dict[str, int] = {}
@@ -1161,62 +1165,84 @@ def _create_inverse_indices_permute_indices(
 
     # pyre-ignore [14]
     def input_dist(
-        self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
+        self,
+        ctx: EmbeddingBagCollectionContext,
+        features: Union[KeyedJaggedTensor, TensorDict],
     ) -> Awaitable[Awaitable[KJTList]]:
-        ctx.variable_batch_per_feature = features.variable_stride_per_key()
-        ctx.inverse_indices = features.inverse_indices_or_none()
+        if isinstance(features, KeyedJaggedTensor):
+            ctx.variable_batch_per_feature = features.variable_stride_per_key()
+            ctx.inverse_indices = features.inverse_indices_or_none()
+            feature_keys = features.keys()
+        else:  # features is TensorDict
+            ctx.variable_batch_per_feature = False  # TD does not support variable batch
+            ctx.inverse_indices = None
+            feature_keys = list(features.keys())  # pyre-ignore[6]
         if self._has_uninitialized_input_dist:
-            self._create_input_dist(features.keys())
+            self._create_input_dist(feature_keys)
             self._has_uninitialized_input_dist = False
         if ctx.variable_batch_per_feature:
             self._create_inverse_indices_permute_indices(ctx.inverse_indices)
             if self._has_mean_pooling_callback:
-                self._init_mean_pooling_callback(features.keys(), ctx.inverse_indices)
+                self._init_mean_pooling_callback(feature_keys, ctx.inverse_indices)
         with torch.no_grad():
-            if self._has_features_permute:
-                features = features.permute(
-                    self._features_order,
-                    self._features_order_tensor,
-                )
-            if self._has_mean_pooling_callback:
-                ctx.divisor = _create_mean_pooling_divisor(
-                    lengths=features.lengths(),
-                    stride=features.stride(),
-                    keys=features.keys(),
-                    offsets=features.offsets(),
-                    pooling_type_to_rs_features=self._pooling_type_to_rs_features,
-                    stride_per_key=features.stride_per_key(),
-                    dim_per_key=self._dim_per_key,  # pyre-ignore[6]
-                    embedding_names=self._embedding_names,
-                    embedding_dims=self._embedding_dims,
-                    variable_batch_per_feature=ctx.variable_batch_per_feature,
-                    kjt_inverse_order=self._kjt_inverse_order,  # pyre-ignore[6]
-                    kjt_key_indices=self._kjt_key_indices,
-                    kt_key_ordering=self._kt_key_ordering,  # pyre-ignore[6]
-                    inverse_indices=ctx.inverse_indices,
-                    weights=features.weights_or_none(),
-                )
+            if isinstance(features, KeyedJaggedTensor):
+                if self._has_features_permute:
+                    features = features.permute(
+                        self._features_order,
+                        self._features_order_tensor,
+                    )
+                if self._has_mean_pooling_callback:
+                    ctx.divisor = _create_mean_pooling_divisor(
+                        lengths=features.lengths(),
+                        stride=features.stride(),
+                        keys=features.keys(),
+                        offsets=features.offsets(),
+                        pooling_type_to_rs_features=self._pooling_type_to_rs_features,
+                        stride_per_key=features.stride_per_key(),
+                        dim_per_key=self._dim_per_key,  # pyre-ignore[6]
+                        embedding_names=self._embedding_names,
+                        embedding_dims=self._embedding_dims,
+                        variable_batch_per_feature=ctx.variable_batch_per_feature,
+                        kjt_inverse_order=self._kjt_inverse_order,  # pyre-ignore[6]
+                        kjt_key_indices=self._kjt_key_indices,
+                        kt_key_ordering=self._kt_key_ordering,  # pyre-ignore[6]
+                        inverse_indices=ctx.inverse_indices,
+                        weights=features.weights_or_none(),
+                    )
 
-            features_by_shards = features.split(
-                self._feature_splits,
-            )
+                features_by_sharding_types = features.split(
+                    self._feature_splits,
+                )
+            else:  # features is TensorDict
+                feature_names = [feature_keys[i] for i in self._features_order]
+                feature_name_by_sharding_types: List[List[str]] = []
+                start = 0
+                for length in self._feature_splits:
+                    feature_name_by_sharding_types.append(
+                        feature_names[start : start + length]
+                    )
+                    start += length
+                features_by_sharding_types = [
+                    td_to_kjt(features, names)
+                    for names in feature_name_by_sharding_types
+                ]
             awaitables = []
-            for input_dist, features_by_shard, sharding_type in zip(
+            for input_dist, features_by_sharding_type, sharding_type in zip(
                 self._input_dists,
-                features_by_shards,
+                features_by_sharding_types,
                 self._sharding_types,
             ):
                 with maybe_annotate_embedding_event(
                     EmbeddingEvent.KJT_SPLITS_DIST,
                     self._module_fqn,
                     sharding_type,
                 ):
-                    awaitables.append(input_dist(features_by_shard))
+                    awaitables.append(input_dist(features_by_sharding_type))
 
                 ctx.sharding_contexts.append(
                     EmbeddingShardingContext(
-                        batch_size_per_feature_pre_a2a=features_by_shard.stride_per_key(),
-                        variable_batch_per_feature=features_by_shard.variable_stride_per_key(),
+                        batch_size_per_feature_pre_a2a=features_by_sharding_type.stride_per_key(),
+                        variable_batch_per_feature=features_by_sharding_type.variable_stride_per_key(),
                     )
                 )
             return KJTListSplitsAwaitable(
```
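To make the TD branch of `input_dist` above concrete, the standalone sketch below mimics the grouping step: `self._features_order` reorders the TensorDict keys and `self._feature_splits` slices them into one group per sharding type, after which each group is converted to its own KJT via `td_to_kjt` for the following `KJTAllToAll`. The key names, permutation, and splits here are hypothetical.

```python
# Hypothetical inputs: 4 features, a permutation from the sharding plan,
# and a 3 + 1 split of features across two sharding types.
feature_keys = ["f0", "f1", "f2", "f3"]   # keys as stored in the TensorDict
features_order = [2, 0, 3, 1]             # stand-in for self._features_order
feature_splits = [3, 1]                   # stand-in for self._feature_splits

# Reorder the TD keys into the order expected by the sharding plan.
feature_names = [feature_keys[i] for i in features_order]  # ['f2', 'f0', 'f3', 'f1']

# Slice the reordered names into one group per sharding type.
feature_name_by_sharding_types, start = [], 0
for length in feature_splits:
    feature_name_by_sharding_types.append(feature_names[start : start + length])
    start += length

print(feature_name_by_sharding_types)  # [['f2', 'f0', 'f3'], ['f1']]
# In input_dist, each group becomes td_to_kjt(features, names), i.e. one KJT
# per sharding type, skipping the KJT.permute step of the KJT input path.
```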

torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py

+2-2
```diff
@@ -160,7 +160,7 @@ def main(
 
     tables = [
         EmbeddingBagConfig(
-            num_embeddings=(i + 1) * 1000,
+            num_embeddings=max(i + 1, 100) * 1000,
             embedding_dim=dim_emb,
             name="table_" + str(i),
             feature_names=["feature_" + str(i)],
@@ -169,7 +169,7 @@ def main(
     ]
     weighted_tables = [
         EmbeddingBagConfig(
-            num_embeddings=(i + 1) * 1000,
+            num_embeddings=max(i + 1, 100) * 1000,
             embedding_dim=dim_emb,
             name="weighted_table_" + str(i),
             feature_names=["weighted_feature_" + str(i)],
```

torchrec/modules/embedding_modules.py

+10-2
```diff
@@ -12,13 +12,19 @@
 
 import torch
 import torch.nn as nn
+from tensordict import TensorDict
 from torchrec.modules.embedding_configs import (
     DataType,
     EmbeddingBagConfig,
     EmbeddingConfig,
     pooling_type_to_str,
 )
-from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import (
+    JaggedTensor,
+    KeyedJaggedTensor,
+    KeyedTensor,
+    td_to_kjt,
+)
 
 
 @torch.fx.wrap
@@ -217,7 +223,7 @@ def __init__(
         self._feature_names: List[List[str]] = [table.feature_names for table in tables]
         self.reset_parameters()
 
-    def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
+    def forward(self, features: Union[KeyedJaggedTensor, TensorDict]) -> KeyedTensor:
         """
         Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
         and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.
@@ -228,6 +234,8 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
             KeyedTensor
         """
         flat_feature_names: List[str] = []
+        if isinstance(features, TensorDict):
+            features = td_to_kjt(features)
         for names in self._feature_names:
             flat_feature_names.extend(names)
         inverse_indices = reorder_inverse_indices(
```

torchrec/sparse/jagged_tensor.py

+23
```diff
@@ -15,6 +15,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
+from tensordict import TensorDict
 from torch.autograd.profiler import record_function
 from torch.fx._pytree import register_pytree_flatten_spec, TreeSpec
 from torch.utils._pytree import GetAttrKey, KeyEntry, register_pytree_node
@@ -3024,6 +3025,28 @@ def dist_init(
     return kjt.sync()
 
 
+def td_to_kjt(td: TensorDict, keys: Optional[List[str]] = None) -> KeyedJaggedTensor:
+    if keys is None:
+        keys = list(td.keys())  # pyre-ignore[6]
+    values = torch.cat([td[key]._values for key in keys], dim=0)
+    lengths = torch.cat(
+        [
+            (
+                (td[key]._lengths)
+                if td[key]._lengths is not None
+                else torch.diff(td[key]._offsets)
+            )
+            for key in keys
+        ],
+        dim=0,
+    )
+    return KeyedJaggedTensor(
+        keys=keys,
+        values=values,
+        lengths=lengths,
+    )
+
+
 def _kjt_flatten(
     t: KeyedJaggedTensor,
 ) -> Tuple[List[Optional[torch.Tensor]], List[str]]:
```
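To illustrate the semantics of `td_to_kjt`, the sketch below builds the same result by hand from per-key jagged components: the per-key values are concatenated back to back, and so are the per-key lengths (which `td_to_kjt` derives from `torch.diff(offsets)` when the NJT stores offsets instead of lengths). The feature names and id values are hypothetical, and the TensorDict/NJT construction is skipped so the example stays in plain tensors.

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two hypothetical features, each with 2 samples:
#   f1: [[10, 11], [12]]      -> values [10, 11, 12],     lengths [2, 1]
#   f2: [[20], [21, 22, 23]]  -> values [20, 21, 22, 23], lengths [1, 3]
values_per_key = {
    "f1": torch.tensor([10, 11, 12]),
    "f2": torch.tensor([20, 21, 22, 23]),
}
lengths_per_key = {
    "f1": torch.tensor([2, 1]),
    "f2": torch.tensor([1, 3]),
}

keys = ["f1", "f2"]
# td_to_kjt concatenates per-key values and per-key lengths along dim 0
# and wraps them into a single KeyedJaggedTensor.
kjt = KeyedJaggedTensor(
    keys=keys,
    values=torch.cat([values_per_key[k] for k in keys], dim=0),
    lengths=torch.cat([lengths_per_key[k] for k in keys], dim=0),
)
print(kjt["f2"].to_dense())  # [tensor([20]), tensor([21, 22, 23])]
```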
