
Commit e8c3451

TroyGarden authored and facebook-github-bot committed
add NJT/TD support for EBC and pipeline benchmark (#2581)
Summary:

# Documents
* [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv)
* [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79)

{F1949248817}

# Context
* As depicted above, we are extending the TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict).
* Basically, we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EBC ==> Output (KT)`
* In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT (see the usage sketch after this summary).
* In distributed mode, we do the conversion inside `ShardedEmbeddingBagCollection`, specifically in `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication.
* In the KJT scenario, the input KJT is permuted (and partially duplicated in some cases) before the `KJTAllToAll` communication. In the TD scenario, the input TD is converted directly into the permuted KJT, ready for the following `KJTAllToAll` communication.
* ref: D63436011

# Details
* `td_to_kjt` is implemented in Python, which introduces a CPU perf regression. However, it is not on the training critical path, so it has minimal impact on the overall training QPS (see the benchmark results in the test plan).
* Currently only the EBC use case is supported.

WARNING: `TensorDict` supports **neither** weighted jagged tensors **nor** variable batch sizes.

NOTE: All the following comparisons are between **`KJT.permute`** in the KJT input scenario and the **`TD-KJT conversion`** in the TD input scenario.

* Both `KJT.permute` and `TD-KJT conversion` are correctly marked in the `TrainPipelineBase` traces. `TD-KJT conversion` has more real executions on the CPU, but the heavy-lifting computation is on the GPU, which is delayed/blocked by the backward pass of the previous batch. The GPU runtime difference is small, ~10%. {F1949366822}
* For the `Copy-Batch-To-GPU` part, TD has more fragmented `HtoD` comms while KJT has a single contiguous `HtoD` comm. Runtime-wise they are similar, within ~10%. {F1949374305}
* In the most commonly used `TrainPipelineSparseDist`, where `Copy-Batch-To-GPU` and the CPU runtime are not on the critical path, we observe very similar training QPS in the pipeline benchmark, within ~1%. {F1949390271}
```
TrainPipelineSparseDist | Runtime (P90): 26.737 s | Memory (P90): 34.801 GB (TD)
TrainPipelineSparseDist | Runtime (P90): 26.539 s | Memory (P90): 34.765 GB (KJT)
```
* With increased data size, the GPU runtime is ~4x. {F1949386106}

# Conclusion
1. [Enablement] With this approach (replacing `KJT.permute` with the `TD-KJT conversion`), the EBC can now take `TensorDict` as the module input in both single-GPU and multi-GPU (sharded) scenarios, tested with TrainPipelineBase, TrainPipelineSparseDist, TrainPipelineSemiSync, and TrainPipelinePrefetch.
2. [Performance] The TD host-to-device data transfer is not necessarily a concern/blocker for the most commonly used train pipeline (TrainPipelineSparseDist).
3. [Feature Support] To become production-ready, TensorDict needs to (1) integrate the `KJT.weights` data and (2) support variable batch sizes, both of which are used in almost all production models.
4. [Improvement] There are two major operations we can improve: (1) moving the TensorDict from host to device and (2) converting TD to KJT. Both are currently in a vanilla state. Since we are not sure what the real traces with production models would look like, we can't yet tell whether these improvements are needed/helpful.

Differential Revision: D65103519
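The eager-mode path above (`Input TD ==> EBC ==> Output KT`) can be exercised roughly as follows. This is a minimal sketch, not part of the commit: the table/feature names are made up, and it assumes the TensorDict entries are jagged-layout nested tensors (NJTs), which is the layout `maybe_td_to_kjt` consumes.

```python
# Minimal eager-mode sketch (assumptions: made-up table/feature names; TD entries
# are jagged-layout nested tensors, and the installed tensordict accepts them).
import torch
from tensordict import TensorDict

from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="table_0",
            num_embeddings=1000,
            embedding_dim=16,
            feature_names=["feature_0"],
        )
    ],
    device=torch.device("cpu"),
)

# One jagged feature for a batch of 2 samples: lengths [2, 1].
td = TensorDict(
    {
        "feature_0": torch.nested.nested_tensor(
            [torch.tensor([1, 2]), torch.tensor([3])], layout=torch.jagged
        )
    },
    batch_size=[2],
)

kt = ebc(td)  # forward converts the TD to a KJT via maybe_td_to_kjt
print(kt.keys(), kt.values().shape)  # ['feature_0'] torch.Size([2, 16])
```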
1 parent 7be5368 commit e8c3451

File tree

5 files changed: +65 −27 lines changed


torchrec/distributed/embeddingbag.py

+16 −13

```diff
@@ -27,6 +27,7 @@

 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
 from torch.distributed._tensor import DTensor
@@ -92,6 +93,7 @@
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.tensor_dict import maybe_td_to_kjt

 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -100,13 +102,6 @@
 except OSError:
     pass

-try:
-    from tensordict import TensorDict
-except ImportError:
-
-    class TensorDict:
-        pass
-

 def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
     return (
@@ -661,9 +656,7 @@ def __init__(
         self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
         # to support mean pooling callback hook
         self._has_mean_pooling_callback: bool = (
-            True
-            if PoolingType.MEAN.value in self._pooling_type_to_rs_features
-            else False
+            PoolingType.MEAN.value in self._pooling_type_to_rs_features
         )
         self._dim_per_key: Optional[torch.Tensor] = None
         self._kjt_key_indices: Dict[str, int] = {}
@@ -1163,17 +1156,27 @@ def _create_inverse_indices_permute_indices(

     # pyre-ignore [14]
     def input_dist(
-        self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
+        self,
+        ctx: EmbeddingBagCollectionContext,
+        features: Union[KeyedJaggedTensor, TensorDict],
     ) -> Awaitable[Awaitable[KJTList]]:
+        if isinstance(features, TensorDict):
+            feature_keys = list(features.keys())  # pyre-ignore[6]
+            if len(self._features_order) > 0:
+                feature_keys = [feature_keys[i] for i in self._features_order]
+                self._has_features_permute = False  # feature_keys are in order
+            features = maybe_td_to_kjt(features, feature_keys)  # pyre-ignore[6]
+        else:
+            feature_keys = features.keys()
         ctx.variable_batch_per_feature = features.variable_stride_per_key()
         ctx.inverse_indices = features.inverse_indices_or_none()
         if self._has_uninitialized_input_dist:
-            self._create_input_dist(features.keys())
+            self._create_input_dist(feature_keys)
             self._has_uninitialized_input_dist = False
         if ctx.variable_batch_per_feature:
             self._create_inverse_indices_permute_indices(ctx.inverse_indices)
         if self._has_mean_pooling_callback:
-            self._init_mean_pooling_callback(features.keys(), ctx.inverse_indices)
+            self._init_mean_pooling_callback(feature_keys, ctx.inverse_indices)
         with torch.no_grad():
             if self._has_features_permute:
                 features = features.permute(
```
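In the TD branch of `input_dist` above, the feature keys are reordered before the conversion so that the resulting KJT is already in the order the sharded module expects, and the usual `KJT.permute` is skipped. A standalone illustration of that reordering (hypothetical key names and order, not part of the commit):

```python
# Hypothetical values, only to illustrate the key pre-permutation above.
feature_keys = ["f0", "f1", "f2"]   # keys as they appear in the TensorDict
features_order = [2, 0, 1]          # order required by the sharded input_dist
feature_keys = [feature_keys[i] for i in features_order]  # -> ["f2", "f0", "f1"]
# maybe_td_to_kjt(features, feature_keys) then builds the KJT already permuted,
# so _has_features_permute is set to False and the later KJT.permute is skipped.
```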

torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py

+2 −2

```diff
@@ -160,7 +160,7 @@ def main(

     tables = [
         EmbeddingBagConfig(
-            num_embeddings=(i + 1) * 1000,
+            num_embeddings=max(i + 1, 100) * 1000,
             embedding_dim=dim_emb,
             name="table_" + str(i),
             feature_names=["feature_" + str(i)],
@@ -169,7 +169,7 @@
     ]
     weighted_tables = [
         EmbeddingBagConfig(
-            num_embeddings=(i + 1) * 1000,
+            num_embeddings=max(i + 1, 100) * 1000,
             embedding_dim=dim_emb,
             name="weighted_table_" + str(i),
             feature_names=["weighted_feature_" + str(i)],
```

torchrec/modules/embedding_modules.py

+2 −8

```diff
@@ -19,14 +19,7 @@
     pooling_type_to_str,
 )
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
-
-
-try:
-    from tensordict import TensorDict
-except ImportError:
-
-    class TensorDict:
-        pass
+from torchrec.sparse.tensor_dict import maybe_td_to_kjt


 @torch.fx.wrap
@@ -237,6 +230,7 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
             KeyedTensor
         """
         flat_feature_names: List[str] = []
+        features = maybe_td_to_kjt(features, None)
         for names in self._feature_names:
             flat_feature_names.extend(names)
         inverse_indices = reorder_inverse_indices(
```

torchrec/sparse/jagged_tensor.py

+2 −4

```diff
@@ -49,11 +49,9 @@

 # OSS
 try:
-    from tensordict import TensorDict
+    pass
 except ImportError:
-
-    class TensorDict:
-        pass
+    pass


 logger: logging.Logger = logging.getLogger()
```

torchrec/sparse/tensor_dict.py

+43 −0

```diff
@@ -0,0 +1,43 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional
+
+import torch
+from tensordict import TensorDict
+
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+
+
+def maybe_td_to_kjt(
+    features: KeyedJaggedTensor, keys: Optional[List[str]] = None
+) -> KeyedJaggedTensor:
+    if torch.jit.is_scripting():
+        assert isinstance(features, KeyedJaggedTensor)
+        return features
+    if isinstance(features, TensorDict):
+        if keys is None:
+            keys = list(features.keys())
+        values = torch.cat([features[key]._values for key in keys], dim=0)
+        lengths = torch.cat(
+            [
+                (
+                    (features[key]._lengths)
+                    if features[key]._lengths is not None
+                    else torch.diff(features[key]._offsets)
+                )
+                for key in keys
+            ],
+            dim=0,
+        )
+        return KeyedJaggedTensor(
+            keys=keys,
+            values=values,
+            lengths=lengths,
+        )
+    else:
+        return features
```
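A minimal usage sketch for `maybe_td_to_kjt`, not part of the commit, under the assumption that the TensorDict entries are jagged-layout nested tensors (which expose the `_values`/`_offsets` accessed above) and that the installed tensordict version accepts them as values; the feature names are made up:

```python
import torch
from tensordict import TensorDict

from torchrec.sparse.tensor_dict import maybe_td_to_kjt

# Two jagged features for a batch of 2 samples (assumed NJT entries).
f1 = torch.nested.nested_tensor(
    [torch.tensor([1, 2]), torch.tensor([3])], layout=torch.jagged
)
f2 = torch.nested.nested_tensor(
    [torch.tensor([4]), torch.tensor([5, 6, 7])], layout=torch.jagged
)
td = TensorDict({"f1": f1, "f2": f2}, batch_size=[2])

kjt = maybe_td_to_kjt(td)  # keys default to list(td.keys())
print(kjt.keys())          # ['f1', 'f2']
print(kjt.lengths())       # tensor([2, 1, 1, 3]) -- per-key, per-sample lengths
```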
