Skip to content

Commit f22736d

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
add NJT/TD support for EC (#2596)
Summary: # Documents * [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv) * [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79) {F1949248817} # Context * Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC * As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict) * Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)` * In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT. * In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication. * In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication. NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness. # Verification - input with TensorDict * breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd) * sharded model ``` (Pdb) local_model DistributedModelParallel( (_dmp_wrapped_module): DistributedDataParallel( (module): TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): ShardedEmbeddingCollection( (lookups): GroupedEmbeddingsLookup( (_emb_modules): ModuleList( (0): BatchedDenseEmbedding( (_emb_module): DenseTableBatchedEmbeddingBagsCodegen() ) ) ) (_input_dists): RwSparseFeaturesDist( (_dist): KJTAllToAll() ) (_output_dists): RwSequenceEmbeddingDist( (_dist): SequenceEmbeddingsAllToAll() ) (embeddings): ModuleDict( (table_0): Module() (table_1): Module() (table_2): Module() (table_3): Module() (table_4): Module() (table_5): Module() ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ) ) ``` * TD input ``` (Pdb) local_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0')) ``` * unsharded model ``` (Pdb) global_model TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): EmbeddingCollection( (embeddings): ModuleDict( (table_0): Embedding(11, 16) (table_1): Embedding(22, 16) (table_2): Embedding(33, 16) (table_3): Embedding(44, 16) (table_4): Embedding(11, 16) (table_5): Embedding(22, 16) ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ``` * TD input ``` (Pdb) global_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617], [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909, 0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366], [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026, 0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548], [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786, 0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380], [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964, 0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055], device='cuda:0')) ``` Differential Revision: D66521351
1 parent 917690c commit f22736d

File tree

4 files changed

+105
-20
lines changed

4 files changed

+105
-20
lines changed

torchrec/distributed/embedding.py

+32-12
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
)
2727

2828
import torch
29+
from tensordict import TensorDict
2930
from torch import distributed as dist, nn
3031
from torch.autograd.profiler import record_function
3132
from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
@@ -90,20 +91,14 @@
9091
from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
9192
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
9293
from torchrec.sparse.jagged_tensor import _to_offsets, JaggedTensor, KeyedJaggedTensor
94+
from torchrec.sparse.tensor_dict import maybe_td_to_kjt
9395

9496
try:
9597
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
9698
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
9799
except OSError:
98100
pass
99101

100-
try:
101-
from tensordict import TensorDict
102-
except ImportError:
103-
104-
class TensorDict:
105-
pass
106-
107102

108103
logger: logging.Logger = logging.getLogger(__name__)
109104

@@ -1205,25 +1200,50 @@ def _compute_sequence_vbe_context(
12051200
def input_dist(
12061201
self,
12071202
ctx: EmbeddingCollectionContext,
1208-
features: KeyedJaggedTensor,
1203+
features: TypeUnion[KeyedJaggedTensor, TensorDict],
12091204
) -> Awaitable[Awaitable[KJTList]]:
1205+
# torch.distributed.breakpoint()
1206+
feature_keys = list(features.keys()) # pyre-ignore[6]
12101207
if self._has_uninitialized_input_dist:
1211-
self._create_input_dist(input_feature_names=features.keys())
1208+
self._create_input_dist(input_feature_names=feature_keys)
12121209
self._has_uninitialized_input_dist = False
12131210
with torch.no_grad():
12141211
unpadded_features = None
1215-
if features.variable_stride_per_key():
1212+
if (
1213+
isinstance(features, KeyedJaggedTensor)
1214+
and features.variable_stride_per_key()
1215+
):
12161216
unpadded_features = features
12171217
features = pad_vbe_kjt_lengths(unpadded_features)
12181218

1219-
if self._features_order:
1219+
if isinstance(features, KeyedJaggedTensor) and self._features_order:
12201220
features = features.permute(
12211221
self._features_order,
12221222
# pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]`
12231223
# but got `TypeUnion[Module, Tensor]`.
12241224
self._features_order_tensor,
12251225
)
1226-
features_by_shards = features.split(self._feature_splits)
1226+
1227+
if isinstance(features, KeyedJaggedTensor):
1228+
features_by_shards = features.split(self._feature_splits)
1229+
else: # TensorDict
1230+
feature_names = (
1231+
[feature_keys[i] for i in self._features_order]
1232+
if self._features_order # empty features_order means no reordering
1233+
else feature_keys
1234+
)
1235+
feature_names = [name.split("@")[0] for name in feature_names]
1236+
feature_name_by_sharding_types: List[List[str]] = []
1237+
start = 0
1238+
for length in self._feature_splits:
1239+
feature_name_by_sharding_types.append(
1240+
feature_names[start : start + length]
1241+
)
1242+
start += length
1243+
features_by_shards = [
1244+
maybe_td_to_kjt(features, names)
1245+
for names in feature_name_by_sharding_types
1246+
]
12271247
if self._use_index_dedup:
12281248
features_by_shards = self._dedup_indices(ctx, features_by_shards)
12291249

torchrec/distributed/test_utils/test_sharding.py

+26-6
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def gen_model_and_input(
147147
long_indices: bool = True,
148148
global_constant_batch: bool = False,
149149
num_inputs: int = 1,
150+
input_type: str = "kjt", # "kjt" or "td"
150151
) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
151152
torch.manual_seed(0)
152153
if dedup_feature_names:
@@ -177,9 +178,9 @@ def gen_model_and_input(
177178
feature_processor_modules=feature_processor_modules,
178179
)
179180
inputs = []
180-
for _ in range(num_inputs):
181-
inputs.append(
182-
(
181+
if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input:
182+
for _ in range(num_inputs):
183+
inputs.append(
183184
cast(VariableBatchModelInputCallable, generate)(
184185
average_batch_size=batch_size,
185186
world_size=world_size,
@@ -188,8 +189,26 @@ def gen_model_and_input(
188189
weighted_tables=weighted_tables or [],
189190
global_constant_batch=global_constant_batch,
190191
)
191-
if generate == ModelInput.generate_variable_batch_input
192-
else cast(ModelInputCallable, generate)(
192+
)
193+
elif generate == ModelInput.generate:
194+
for _ in range(num_inputs):
195+
inputs.append(
196+
ModelInput.generate(
197+
world_size=world_size,
198+
tables=tables,
199+
dedup_tables=dedup_tables,
200+
weighted_tables=weighted_tables or [],
201+
num_float_features=num_float_features,
202+
variable_batch_size=variable_batch_size,
203+
batch_size=batch_size,
204+
long_indices=long_indices,
205+
input_type=input_type,
206+
)
207+
)
208+
else:
209+
for _ in range(num_inputs):
210+
inputs.append(
211+
cast(ModelInputCallable, generate)(
193212
world_size=world_size,
194213
tables=tables,
195214
dedup_tables=dedup_tables,
@@ -200,7 +219,6 @@ def gen_model_and_input(
200219
long_indices=long_indices,
201220
)
202221
)
203-
)
204222
return (model, inputs)
205223

206224

@@ -286,6 +304,7 @@ def sharding_single_rank_test(
286304
global_constant_batch: bool = False,
287305
world_size_2D: Optional[int] = None,
288306
node_group_size: Optional[int] = None,
307+
input_type: str = "kjt", # "kjt" or "td"
289308
) -> None:
290309
with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
291310
# Generate model & inputs.
@@ -308,6 +327,7 @@ def sharding_single_rank_test(
308327
batch_size=batch_size,
309328
feature_processor_modules=feature_processor_modules,
310329
global_constant_batch=global_constant_batch,
330+
input_type=input_type,
311331
)
312332
global_model = global_model.to(ctx.device)
313333
global_input = inputs[0][0].to(ctx.device)

torchrec/distributed/tests/test_sequence_model_parallel.py

+41
Original file line numberDiff line numberDiff line change
@@ -376,3 +376,44 @@ def _test_sharding(
376376
variable_batch_per_feature=variable_batch_per_feature,
377377
global_constant_batch=True,
378378
)
379+
380+
381+
@skip_if_asan_class
382+
class TDSequenceModelParallelTest(SequenceModelParallelTest):
383+
384+
def test_sharding_variable_batch(self) -> None:
385+
pass
386+
387+
def _test_sharding(
388+
self,
389+
sharders: List[TestEmbeddingCollectionSharder],
390+
backend: str = "gloo",
391+
world_size: int = 2,
392+
local_size: Optional[int] = None,
393+
constraints: Optional[Dict[str, ParameterConstraints]] = None,
394+
model_class: Type[TestSparseNNBase] = TestSequenceSparseNN,
395+
qcomms_config: Optional[QCommsConfig] = None,
396+
apply_optimizer_in_backward_config: Optional[
397+
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
398+
] = None,
399+
variable_batch_size: bool = False,
400+
variable_batch_per_feature: bool = False,
401+
) -> None:
402+
self._run_multi_process_test(
403+
callable=sharding_single_rank_test,
404+
world_size=world_size,
405+
local_size=local_size,
406+
model_class=model_class,
407+
tables=self.tables,
408+
embedding_groups=self.embedding_groups,
409+
sharders=sharders,
410+
optim=EmbOptimType.EXACT_SGD,
411+
backend=backend,
412+
constraints=constraints,
413+
qcomms_config=qcomms_config,
414+
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
415+
variable_batch_size=variable_batch_size,
416+
variable_batch_per_feature=variable_batch_per_feature,
417+
global_constant_batch=True,
418+
input_type="td",
419+
)

torchrec/modules/embedding_modules.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,10 @@ def __init__(
219219
self._feature_names: List[List[str]] = [table.feature_names for table in tables]
220220
self.reset_parameters()
221221

222-
def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
222+
def forward(
223+
self,
224+
features: KeyedJaggedTensor, # can also take TensorDict as input
225+
) -> KeyedTensor:
223226
"""
224227
Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
225228
and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.
@@ -450,7 +453,7 @@ def __init__( # noqa C901
450453

451454
def forward(
452455
self,
453-
features: KeyedJaggedTensor,
456+
features: KeyedJaggedTensor, # can also take TensorDict as input
454457
) -> Dict[str, JaggedTensor]:
455458
"""
456459
Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
@@ -463,6 +466,7 @@ def forward(
463466
Dict[str, JaggedTensor]
464467
"""
465468

469+
features = maybe_td_to_kjt(features, None)
466470
feature_embeddings: Dict[str, JaggedTensor] = {}
467471
jt_dict: Dict[str, JaggedTensor] = features.to_dict()
468472
for i, emb_module in enumerate(self.embeddings.values()):

0 commit comments

Comments
 (0)