Commit 79f5c5e

TroyGarden authored and facebook-github-bot committed
add NJT/TD support for EC (#2596)
Summary:

# Documents
* [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv)
* [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79)

{F1949248817}

# Context
* Continued from the previous D66465376, which added NJT/TD support for EBC; this diff does the same for EC.
* As depicted above, we are extending the TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict).
* TensorDict is supported in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)`.
* In eager mode, we directly call `td_to_kjt` in the forward function to convert the TD to a KJT (see the sketch after this list).
* In distributed mode, the conversion happens inside `ShardedEmbeddingCollection`, specifically in `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication.
* In the KJT scenario, the input KJT is permuted (and partially duplicated in some cases) before the `KJTAllToAll` communication; in the TD scenario, the input TD is converted directly into the permuted KJT ready for the following `KJTAllToAll` communication.
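For illustration only, here is a minimal sketch of what a TD-to-KJT conversion along the lines of `td_to_kjt` boils down to. This is not the torchrec implementation; `td_like_to_kjt` is a hypothetical helper, and the per-feature values are assumed to be jagged tensors whose rows can be unbound (as with the NestedTensor inputs shown in the verification below).

```python
from typing import Dict, List

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def td_like_to_kjt(
    features: Dict[str, torch.Tensor], keys: List[str]
) -> KeyedJaggedTensor:
    # Flatten each feature's jagged rows into one values tensor plus per-row
    # lengths, laid out key-major (all rows of key 0, then key 1, ...),
    # which is the KJT layout.
    values: List[torch.Tensor] = []
    lengths: List[int] = []
    for key in keys:
        rows = features[key].unbind()  # one 1-D tensor per sample, varying lengths
        values.extend(rows)
        lengths.extend(row.numel() for row in rows)
    return KeyedJaggedTensor.from_lengths_sync(
        keys=keys,
        values=torch.cat(values),
        lengths=torch.tensor(lengths, dtype=torch.int64),
    )
```

For example, `td_like_to_kjt({"feature_0": torch.nested.nested_tensor([torch.tensor([1, 2]), torch.tensor([3])])}, ["feature_0"])` yields a KJT with values `[1, 2, 3]` and lengths `[2, 1]`.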
NOTE: This diff reuses a number of existing test frameworks/cases with minimal critical changes in `EmbeddingCollection` and `ShardedEmbeddingCollection`. Please see the following verifications for NJT/TD correctness.

# Verification - input with TensorDict
* breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd)
* sharded model
```
(Pdb) local_model
DistributedModelParallel(
  (_dmp_wrapped_module): DistributedDataParallel(
    (module): TestSequenceSparseNN(
      (dense): TestDenseArch(
        (linear): Linear(in_features=16, out_features=8, bias=True)
      )
      (sparse): TestSequenceSparseArch(
        (ec): ShardedEmbeddingCollection(
          (lookups): GroupedEmbeddingsLookup(
            (_emb_modules): ModuleList(
              (0): BatchedDenseEmbedding(
                (_emb_module): DenseTableBatchedEmbeddingBagsCodegen()
              )
            )
          )
          (_input_dists): RwSparseFeaturesDist(
            (_dist): KJTAllToAll()
          )
          (_output_dists): RwSequenceEmbeddingDist(
            (_dist): SequenceEmbeddingsAllToAll()
          )
          (embeddings): ModuleDict(
            (table_0): Module()
            (table_1): Module()
            (table_2): Module()
            (table_3): Module()
            (table_4): Module()
            (table_5): Module()
          )
        )
      )
      (over): TestSequenceOverArch(
        (linear): Linear(in_features=1928, out_features=16, bias=True)
      )
    )
  )
)
```
* TD input
```
(Pdb) local_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]], device='cuda:0'),
    idlist_features=TensorDict(
        fields={
            feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True),
            feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True),
            feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True),
            feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)},
        batch_size=torch.Size([]),
        device=cuda:0,
        is_shared=True),
    idscore_features=None,
    label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0'))
```
* unsharded model
```
(Pdb) global_model
TestSequenceSparseNN(
  (dense): TestDenseArch(
    (linear): Linear(in_features=16, out_features=8, bias=True)
  )
  (sparse): TestSequenceSparseArch(
    (ec): EmbeddingCollection(
      (embeddings): ModuleDict(
        (table_0): Embedding(11, 16)
        (table_1): Embedding(22, 16)
        (table_2): Embedding(33, 16)
        (table_3): Embedding(44, 16)
        (table_4): Embedding(11, 16)
        (table_5): Embedding(22, 16)
      )
    )
  )
  (over): TestSequenceOverArch(
    (linear): Linear(in_features=1928, out_features=16, bias=True)
  )
)
```
* TD input
```
(Pdb) global_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617],
        [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909, 0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366],
        [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026, 0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548],
        [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786, 0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380],
        [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964, 0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]], device='cuda:0'),
    idlist_features=TensorDict(
        fields={
            feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True),
            feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True),
            feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True),
            feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)},
        batch_size=torch.Size([]),
        device=cuda:0,
        is_shared=True),
    idscore_features=None,
    label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055], device='cuda:0'))
```

Differential Revision: D66521351
1 parent a502bbb commit 79f5c5e

4 files changed: +110 -21 lines changed


torchrec/distributed/embedding.py

+37-13
```diff
@@ -26,6 +26,7 @@
 )
 
 import torch
+from tensordict import TensorDict
 from torch import distributed as dist, nn
 from torch.autograd.profiler import record_function
 from torch.distributed._tensor import DTensor
@@ -88,21 +89,19 @@
 from torchrec.modules.utils import construct_jagged_tensors, SequenceVBEContext
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
-from torchrec.sparse.jagged_tensor import _to_offsets, JaggedTensor, KeyedJaggedTensor
+from torchrec.sparse.jagged_tensor import (
+    _to_offsets,
+    JaggedTensor,
+    KeyedJaggedTensor,
+    td_to_kjt,
+)
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
 except OSError:
     pass
 
-try:
-    from tensordict import TensorDict
-except ImportError:
-
-    class TensorDict:
-        pass
-
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -1146,25 +1145,50 @@ def _compute_sequence_vbe_context(
     def input_dist(
         self,
         ctx: EmbeddingCollectionContext,
-        features: KeyedJaggedTensor,
+        features: TypeUnion[KeyedJaggedTensor, TensorDict],
     ) -> Awaitable[Awaitable[KJTList]]:
+        # torch.distributed.breakpoint()
+        feature_keys = list(features.keys())  # pyre-ignore[6]
         if self._has_uninitialized_input_dist:
-            self._create_input_dist(input_feature_names=features.keys())
+            self._create_input_dist(input_feature_names=feature_keys)
             self._has_uninitialized_input_dist = False
         with torch.no_grad():
             unpadded_features = None
-            if features.variable_stride_per_key():
+            if (
+                isinstance(features, KeyedJaggedTensor)
+                and features.variable_stride_per_key()
+            ):
                 unpadded_features = features
                 features = pad_vbe_kjt_lengths(unpadded_features)
 
-            if self._features_order:
+            if isinstance(features, KeyedJaggedTensor) and self._features_order:
                 features = features.permute(
                     self._features_order,
                     # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]`
                     #  but got `TypeUnion[Module, Tensor]`.
                     self._features_order_tensor,
                 )
-            features_by_shards = features.split(self._feature_splits)
+
+            if isinstance(features, KeyedJaggedTensor):
+                features_by_shards = features.split(self._feature_splits)
+            else:  # TensorDict
+                feature_names = (
+                    [feature_keys[i] for i in self._features_order]
+                    if self._features_order  # empty features_order means no reordering
+                    else feature_keys
+                )
+                feature_names = [name.split("@")[0] for name in feature_names]
+                feature_name_by_sharding_types: List[List[str]] = []
+                start = 0
+                for length in self._feature_splits:
+                    feature_name_by_sharding_types.append(
+                        feature_names[start : start + length]
+                    )
+                    start += length
+                features_by_shards = [
+                    td_to_kjt(features, names)
+                    for names in feature_name_by_sharding_types
+                ]
             if self._use_index_dedup:
                 features_by_shards = self._dedup_indices(ctx, features_by_shards)
 
```
torchrec/distributed/test_utils/test_sharding.py

+26-6
```diff
@@ -148,6 +148,7 @@ def gen_model_and_input(
     long_indices: bool = True,
     global_constant_batch: bool = False,
     num_inputs: int = 1,
+    input_type: str = "kjt",  # "kjt" or "td"
 ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
     torch.manual_seed(0)
     if dedup_feature_names:
@@ -178,9 +179,9 @@ def gen_model_and_input(
         feature_processor_modules=feature_processor_modules,
     )
     inputs = []
-    for _ in range(num_inputs):
-        inputs.append(
-            (
+    if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input:
+        for _ in range(num_inputs):
+            inputs.append(
                 cast(VariableBatchModelInputCallable, generate)(
                     average_batch_size=batch_size,
                     world_size=world_size,
@@ -189,8 +190,26 @@ def gen_model_and_input(
                     weighted_tables=weighted_tables or [],
                     global_constant_batch=global_constant_batch,
                 )
-                if generate == ModelInput.generate_variable_batch_input
-                else cast(ModelInputCallable, generate)(
+            )
+    elif generate == ModelInput.generate:
+        for _ in range(num_inputs):
+            inputs.append(
+                ModelInput.generate(
+                    world_size=world_size,
+                    tables=tables,
+                    dedup_tables=dedup_tables,
+                    weighted_tables=weighted_tables or [],
+                    num_float_features=num_float_features,
+                    variable_batch_size=variable_batch_size,
+                    batch_size=batch_size,
+                    long_indices=long_indices,
+                    input_type=input_type,
+                )
+            )
+    else:
+        for _ in range(num_inputs):
+            inputs.append(
+                cast(ModelInputCallable, generate)(
                     world_size=world_size,
                     tables=tables,
                     dedup_tables=dedup_tables,
@@ -201,7 +220,6 @@ def gen_model_and_input(
                     long_indices=long_indices,
                 )
             )
-        )
     return (model, inputs)
 
 
@@ -287,6 +305,7 @@ def sharding_single_rank_test(
     global_constant_batch: bool = False,
     world_size_2D: Optional[int] = None,
     node_group_size: Optional[int] = None,
+    input_type: str = "kjt",  # "kjt" or "td"
 ) -> None:
 
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
@@ -310,6 +329,7 @@
             batch_size=batch_size,
             feature_processor_modules=feature_processor_modules,
             global_constant_batch=global_constant_batch,
+            input_type=input_type,
         )
         global_model = global_model.to(ctx.device)
         global_input = inputs[0][0].to(ctx.device)
```

torchrec/distributed/tests/test_sequence_model_parallel.py

+41
```diff
@@ -376,3 +376,44 @@ def _test_sharding(
             variable_batch_per_feature=variable_batch_per_feature,
             global_constant_batch=True,
         )
+
+
+@skip_if_asan_class
+class TDSequenceModelParallelTest(SequenceModelParallelTest):
+
+    def test_sharding_variable_batch(self) -> None:
+        pass
+
+    def _test_sharding(
+        self,
+        sharders: List[TestEmbeddingCollectionSharder],
+        backend: str = "gloo",
+        world_size: int = 2,
+        local_size: Optional[int] = None,
+        constraints: Optional[Dict[str, ParameterConstraints]] = None,
+        model_class: Type[TestSparseNNBase] = TestSequenceSparseNN,
+        qcomms_config: Optional[QCommsConfig] = None,
+        apply_optimizer_in_backward_config: Optional[
+            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
+        ] = None,
+        variable_batch_size: bool = False,
+        variable_batch_per_feature: bool = False,
+    ) -> None:
+        self._run_multi_process_test(
+            callable=sharding_single_rank_test,
+            world_size=world_size,
+            local_size=local_size,
+            model_class=model_class,
+            tables=self.tables,
+            embedding_groups=self.embedding_groups,
+            sharders=sharders,
+            optim=EmbOptimType.EXACT_SGD,
+            backend=backend,
+            constraints=constraints,
+            qcomms_config=qcomms_config,
+            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
+            variable_batch_size=variable_batch_size,
+            variable_batch_per_feature=variable_batch_per_feature,
+            global_constant_batch=True,
+            input_type="td",
+        )
```
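The new test class works by inheritance: every `test_*` method defined on `SequenceModelParallelTest` is re-run by `TDSequenceModelParallelTest`, but with `_test_sharding` overridden to pass `input_type="td"` (and `test_sharding_variable_batch` stubbed out, since the variable-batch generator in `gen_model_and_input` above is only taken for the KJT input type). Below is a minimal, generic sketch of that reuse pattern, with hypothetical class names unrelated to torchrec.

```python
import unittest


class BaseSuite(unittest.TestCase):
    # shared test bodies dispatch through an overridable helper
    def _test_sharding(self, input_type: str = "kjt") -> None:
        self.assertIn(input_type, ("kjt", "td"))

    def test_sharding(self) -> None:
        self._test_sharding()


class TDSuite(BaseSuite):
    # the inherited test_sharding now exercises the "td" code path
    def _test_sharding(self, input_type: str = "kjt") -> None:
        super()._test_sharding(input_type="td")

    def test_skipped_case(self) -> None:
        # mirror of test_sharding_variable_batch: overridden to a no-op
        pass


if __name__ == "__main__":
    unittest.main()
```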

torchrec/modules/embedding_modules.py

+6-2
```diff
@@ -456,7 +456,7 @@ def __init__(  # noqa C901
 
     def forward(
         self,
-        features: KeyedJaggedTensor,
+        features: Union[KeyedJaggedTensor, TensorDict],
     ) -> Dict[str, JaggedTensor]:
         """
         Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
@@ -470,7 +470,10 @@ def forward(
         """
 
         feature_embeddings: Dict[str, JaggedTensor] = {}
-        jt_dict: Dict[str, JaggedTensor] = features.to_dict()
+        if isinstance(features, KeyedJaggedTensor):
+            jt_dict: Dict[str, JaggedTensor] = features.to_dict()
+        else:
+            jt_dict = features
         for i, emb_module in enumerate(self.embeddings.values()):
             feature_names = self._feature_names[i]
             embedding_names = self._embedding_names_by_table[i]
@@ -483,6 +486,7 @@
                 feature_embeddings[embedding_name] = JaggedTensor(
                     values=lookup,
                     lengths=f.lengths(),
+                    offsets=f.offsets() if isinstance(features, TensorDict) else None,
                     weights=f.values() if self._need_indices else None,
                 )
         return feature_embeddings
```
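For context, here is a hedged eager-mode sketch of what the widened `forward` signature enables (`Input (Union[KJT, TD]) ==> EC ==> Output`). The KJT path below uses standard torchrec APIs; the TensorDict path is shown only as a comment, since the exact TD construction (jagged NestedTensor values, as in the verification output above) depends on the tensordict and PyTorch versions in use.

```python
import torch
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# A single-table EmbeddingCollection, matching table_0 of the unsharded model above.
ec = EmbeddingCollection(
    tables=[
        EmbeddingConfig(
            name="table_0",
            embedding_dim=16,
            num_embeddings=11,
            feature_names=["feature_0"],
        )
    ]
)

# KJT input: 4 samples for feature_0 with lengths [2, 0, 1, 2].
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["feature_0"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([2, 0, 1, 2]),
)
out = ec(kjt)  # Dict[str, JaggedTensor], keyed by feature name

# With this diff, a TensorDict of jagged per-feature tensors can be passed instead,
# e.g. (illustrative only):
#   td = TensorDict({"feature_0": <jagged NestedTensor>}, batch_size=[])
#   out = ec(td)
print(out["feature_0"].lengths())
```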
