
Commit 3441ac3

Revert "Revert "add NJT/TD support in test data generator (#2528)""
This reverts commit e5e2565.
1 parent: e5e2565 · commit: 3441ac3

10 files changed: +121 / -44 lines
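
For orientation before the per-file diffs: the restored change teaches the test data generator to emit sparse inputs either as a KeyedJaggedTensor (the default) or as a TensorDict of nested jagged tensors (NJTs). A rough usage sketch follows; only the input_type argument is introduced by this commit, and the table configs, sizes, and the tables/weighted_tables argument names are assumed from the usual ModelInput.generate signature rather than taken from this diff.

# Sketch only: table configs are illustrative; input_type is the new argument.
from torchrec.distributed.test_utils.test_model import ModelInput
from torchrec.modules.embedding_configs import EmbeddingBagConfig

tables = [
    EmbeddingBagConfig(
        num_embeddings=1000,
        embedding_dim=16,
        name="table_0",
        feature_names=["feature_0"],
    )
]

# input_type="td" makes idlist_features a TensorDict of nested jagged tensors;
# the "td" path expects no weighted tables (see test_model.py below).
global_batch, local_batches = ModelInput.generate(
    batch_size=4,
    world_size=2,
    num_float_features=8,
    tables=tables,
    weighted_tables=[],
    input_type="td",
)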

install-requirements.txt (+1)

@@ -1,4 +1,5 @@
 fbgemm-gpu
+tensordict
 torchmetrics==1.0.3
 tqdm
 pyre-extensions

requirements.txt (+1)

@@ -7,6 +7,7 @@ numpy
 pandas
 pyre-extensions
 scikit-build
+tensordict
 torchmetrics==1.0.3
 torchx
 tqdm

torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py (+6, -3)

@@ -9,6 +9,8 @@

 #!/usr/bin/env python3

+from typing import Dict, List
+
 import click

 import torch
@@ -82,9 +84,10 @@ def op_bench(
     )

     def _func_to_benchmark(
-        kjt: KeyedJaggedTensor,
+        kjts: List[Dict[str, KeyedJaggedTensor]],
         model: torch.nn.Module,
     ) -> torch.Tensor:
+        kjt = kjts[0]["feature"]
         return model.forward(kjt.values(), kjt.offsets())

     # breakpoint() # import fbvscode; fbvscode.set_trace()
@@ -108,8 +111,8 @@ def _func_to_benchmark(

     result = benchmark_func(
         name=f"SplitTableBatchedEmbeddingBagsCodegen-{num_embeddings}-{embedding_dim}-{num_tables}-{batch_size}-{bag_size}",
-        bench_inputs=inputs,  # pyre-ignore
-        prof_inputs=inputs,  # pyre-ignore
+        bench_inputs=[{"feature": inputs}],
+        prof_inputs=[{"feature": inputs}],
         num_benchmarks=10,
         num_profiles=10,
         profile_dir=".",
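
For context on the new calling convention: bench_inputs is now a list of dicts keyed by "feature", and the benchmarked function pulls the first entry back out as a KeyedJaggedTensor before reading its flat values and offsets. A minimal, self-contained sketch of that unpacking (the KJT contents are made up for illustration):

# Illustrative only: a tiny KJT standing in for the generated benchmark inputs.
from typing import Dict, List

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

kjt = KeyedJaggedTensor(
    keys=["feature"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([2, 3]),
)
bench_inputs: List[Dict[str, KeyedJaggedTensor]] = [{"feature": kjt}]

# Mirrors _func_to_benchmark above: unpack, then read flat values/offsets.
unpacked = bench_inputs[0]["feature"]
print(unpacked.values())   # tensor([1, 2, 3, 4, 5])
print(unpacked.offsets())  # tensor([0, 2, 5])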

torchrec/distributed/benchmark/benchmark_utils.py (+4, -1)

@@ -374,11 +374,14 @@ def get_inputs(

         if train:
             sparse_features_by_rank = [
-                model_input.idlist_features for model_input in model_input_by_rank
+                model_input.idlist_features
+                for model_input in model_input_by_rank
+                if isinstance(model_input.idlist_features, KeyedJaggedTensor)
             ]
             inputs_batch.append(sparse_features_by_rank)
         else:
             sparse_features = model_input_by_rank[0].idlist_features
+            assert isinstance(sparse_features, KeyedJaggedTensor)
             inputs_batch.append([sparse_features])

     # Transpose if train, as inputs_by_rank is currently in [B X R] format
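
Because idlist_features may now be either a KeyedJaggedTensor or a TensorDict, KJT-only code paths narrow the type first. A small sketch of the same isinstance-based filtering over a mixed list (the contents are hypothetical):

# Hypothetical mixed inputs: keep only the KJT entries for KJT-only consumers.
from typing import List, Union

import torch
from tensordict import TensorDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

kjt = KeyedJaggedTensor(
    keys=["f0"],
    values=torch.tensor([7, 8, 9]),
    lengths=torch.tensor([1, 2]),
)
td = TensorDict(source={"f0": torch.randn(2, 4)})

features: List[Union[KeyedJaggedTensor, TensorDict]] = [kjt, td]
kjt_only = [f for f in features if isinstance(f, KeyedJaggedTensor)]
assert len(kjt_only) == 1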

torchrec/distributed/test_utils/infer_utils.py (+3, -1)

@@ -264,6 +264,7 @@ def model_input_to_forward_args_kjt(
     Optional[torch.Tensor],
 ]:
     kjt = mi.idlist_features
+    assert isinstance(kjt, KeyedJaggedTensor)
     return (
         kjt._keys,
         kjt._values,
@@ -291,7 +292,8 @@ def model_input_to_forward_args(
 ]:
     idlist_kjt = mi.idlist_features
     idscore_kjt = mi.idscore_features
-    assert idscore_kjt is not None
+    assert isinstance(idlist_kjt, KeyedJaggedTensor)
+    assert isinstance(idscore_kjt, KeyedJaggedTensor)
     return (
         mi.float_features,
         idlist_kjt._keys,

torchrec/distributed/test_utils/test_model.py (+87, -36)

@@ -14,6 +14,7 @@

 import torch
 import torch.nn as nn
+from tensordict import TensorDict
 from torchrec.distributed.embedding_tower_sharding import (
     EmbeddingTowerCollectionSharder,
     EmbeddingTowerSharder,
@@ -46,8 +47,8 @@
 @dataclass
 class ModelInput(Pipelineable):
     float_features: torch.Tensor
-    idlist_features: KeyedJaggedTensor
-    idscore_features: Optional[KeyedJaggedTensor]
+    idlist_features: Union[KeyedJaggedTensor, TensorDict]
+    idscore_features: Optional[Union[KeyedJaggedTensor, TensorDict]]
     label: torch.Tensor

     @staticmethod
@@ -76,11 +77,13 @@ def generate(
         randomize_indices: bool = True,
         device: Optional[torch.device] = None,
         max_feature_lengths: Optional[List[int]] = None,
+        input_type: str = "kjt",
     ) -> Tuple["ModelInput", List["ModelInput"]]:
         """
         Returns a global (single-rank training) batch
         and a list of local (multi-rank training) batches of world_size.
         """
+
         batch_size_by_rank = [batch_size] * world_size
         if variable_batch_size:
             batch_size_by_rank = [
@@ -199,11 +202,26 @@ def _validate_pooling_factor(
             )
             global_idlist_lengths.append(lengths)
             global_idlist_indices.append(indices)
-        global_idlist_kjt = KeyedJaggedTensor(
-            keys=idlist_features,
-            values=torch.cat(global_idlist_indices),
-            lengths=torch.cat(global_idlist_lengths),
-        )
+
+        if input_type == "kjt":
+            global_idlist_input = KeyedJaggedTensor(
+                keys=idlist_features,
+                values=torch.cat(global_idlist_indices),
+                lengths=torch.cat(global_idlist_lengths),
+            )
+        elif input_type == "td":
+            dict_of_nt = {
+                k: torch.nested.nested_tensor_from_jagged(
+                    values=values,
+                    lengths=lengths,
+                )
+                for k, values, lengths in zip(
+                    idlist_features, global_idlist_indices, global_idlist_lengths
+                )
+            }
+            global_idlist_input = TensorDict(source=dict_of_nt)
+        else:
+            raise ValueError(f"For IdList features, unknown input type {input_type}")

         for idx in range(len(idscore_ind_ranges)):
             ind_range = idscore_ind_ranges[idx]
@@ -245,16 +263,25 @@ def _validate_pooling_factor(
             global_idscore_lengths.append(lengths)
             global_idscore_indices.append(indices)
             global_idscore_weights.append(weights)
-        global_idscore_kjt = (
-            KeyedJaggedTensor(
-                keys=idscore_features,
-                values=torch.cat(global_idscore_indices),
-                lengths=torch.cat(global_idscore_lengths),
-                weights=torch.cat(global_idscore_weights),
-            )
-            if global_idscore_indices
-            else None
-        )
+
+        if input_type == "kjt":
+            global_idscore_input = (
+                KeyedJaggedTensor(
+                    keys=idscore_features,
+                    values=torch.cat(global_idscore_indices),
+                    lengths=torch.cat(global_idscore_lengths),
+                    weights=torch.cat(global_idscore_weights),
+                )
+                if global_idscore_indices
+                else None
+            )
+        elif input_type == "td":
+            assert (
+                len(idscore_features) == 0
+            ), "TensorDict does not support weighted features"
+            global_idscore_input = None
+        else:
+            raise ValueError(f"For weighted features, unknown input type {input_type}")

         if randomize_indices:
             global_float = torch.rand(
@@ -303,36 +330,57 @@ def _validate_pooling_factor(
                     weights[lengths_cumsum[r] : lengths_cumsum[r + 1]]
                 )

-            local_idlist_kjt = KeyedJaggedTensor(
-                keys=idlist_features,
-                values=torch.cat(local_idlist_indices),
-                lengths=torch.cat(local_idlist_lengths),
-            )
-
-            local_idscore_kjt = (
-                KeyedJaggedTensor(
-                    keys=idscore_features,
-                    values=torch.cat(local_idscore_indices),
-                    lengths=torch.cat(local_idscore_lengths),
-                    weights=torch.cat(local_idscore_weights),
-                )
-                if local_idscore_indices
-                else None
-            )
+            if input_type == "kjt":
+                local_idlist_input = KeyedJaggedTensor(
+                    keys=idlist_features,
+                    values=torch.cat(local_idlist_indices),
+                    lengths=torch.cat(local_idlist_lengths),
+                )
+
+                local_idscore_input = (
+                    KeyedJaggedTensor(
+                        keys=idscore_features,
+                        values=torch.cat(local_idscore_indices),
+                        lengths=torch.cat(local_idscore_lengths),
+                        weights=torch.cat(local_idscore_weights),
+                    )
+                    if local_idscore_indices
+                    else None
+                )
+            elif input_type == "td":
+                dict_of_nt = {
+                    k: torch.nested.nested_tensor_from_jagged(
+                        values=values,
+                        lengths=lengths,
+                    )
+                    for k, values, lengths in zip(
+                        idlist_features, local_idlist_indices, local_idlist_lengths
+                    )
+                }
+                local_idlist_input = TensorDict(source=dict_of_nt)
+                assert (
+                    len(idscore_features) == 0
+                ), "TensorDict does not support weighted features"
+                local_idscore_input = None
+
+            else:
+                raise ValueError(
+                    f"For weighted features, unknown input type {input_type}"
                )

             local_input = ModelInput(
                 float_features=global_float[r * batch_size : (r + 1) * batch_size],
-                idlist_features=local_idlist_kjt,
-                idscore_features=local_idscore_kjt,
+                idlist_features=local_idlist_input,
+                idscore_features=local_idscore_input,
                 label=global_label[r * batch_size : (r + 1) * batch_size],
             )
             local_inputs.append(local_input)

         return (
             ModelInput(
                 float_features=global_float,
-                idlist_features=global_idlist_kjt,
-                idscore_features=global_idscore_kjt,
+                idlist_features=global_idlist_input,
+                idscore_features=global_idscore_input,
                 label=global_label,
             ),
             local_inputs,
@@ -623,8 +671,9 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput":

     def record_stream(self, stream: torch.Stream) -> None:
         self.float_features.record_stream(stream)
-        self.idlist_features.record_stream(stream)
-        if self.idscore_features is not None:
+        if isinstance(self.idlist_features, KeyedJaggedTensor):
+            self.idlist_features.record_stream(stream)
+        if isinstance(self.idscore_features, KeyedJaggedTensor):
             self.idscore_features.record_stream(stream)
         self.label.record_stream(stream)

@@ -1831,6 +1880,8 @@ def forward(self, input: ModelInput) -> ModelInput:
         )

         # stride will be same but features will be joined
+        assert isinstance(modified_input.idlist_features, KeyedJaggedTensor)
+        assert isinstance(self._extra_input.idlist_features, KeyedJaggedTensor)
         modified_input.idlist_features = KeyedJaggedTensor.concat(
             [modified_input.idlist_features, self._extra_input.idlist_features]
         )
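
To make the two generated formats concrete, below is a minimal sketch (with made-up values) of the same jagged feature expressed both ways, mirroring the constructions in this file: a KJT built from flat values plus lengths, and a TensorDict holding one nested jagged tensor per key.

# Illustrative values: one feature over a batch of 3 samples with lengths [2, 0, 1].
import torch
from tensordict import TensorDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

values = torch.tensor([10, 11, 12])
lengths = torch.tensor([2, 0, 1])

# input_type="kjt": all keys share one flat values/lengths pair.
kjt = KeyedJaggedTensor(keys=["feature_0"], values=values, lengths=lengths)

# input_type="td": one nested (jagged) tensor per key, collected in a TensorDict.
njt = torch.nested.nested_tensor_from_jagged(values=values, lengths=lengths)
td = TensorDict(source={"feature_0": njt})

# Both carry the flat values [10, 11, 12] with offsets [0, 2, 2, 3].
print(kjt.values(), kjt.offsets())
print(td["feature_0"].values(), td["feature_0"].offsets())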

torchrec/distributed/tests/test_infer_shardings.py (+3)

@@ -1987,6 +1987,7 @@ def test_sharded_quant_fp_ebc_tw(
         inputs = []
         for model_input in model_inputs:
             kjt = model_input.idlist_features
+            assert isinstance(kjt, KeyedJaggedTensor)
             kjt = kjt.to(local_device)
             weights = torch.rand(
                 kjt._values.size(0), dtype=torch.float, device=local_device
@@ -2166,6 +2167,7 @@ def test_sharded_quant_mc_ec_rw(
         inputs = []
         for model_input in model_inputs:
             kjt = model_input.idlist_features
+            assert isinstance(kjt, KeyedJaggedTensor)
             kjt = kjt.to(local_device)
             weights = None
             inputs.append(
@@ -2301,6 +2303,7 @@ def test_sharded_quant_fp_ebc_tw_meta(self, compute_device: str) -> None:
         )
         inputs = []
         kjt = model_inputs[0].idlist_features
+        assert isinstance(kjt, KeyedJaggedTensor)
         kjt = kjt.to(local_device)
         weights = torch.rand(
             kjt._values.size(0), dtype=torch.float, device=local_device

torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py (+10, -2)

@@ -75,6 +75,11 @@ def _gen_pipelines(
     default=100,
     help="Total number of sparse embeddings to be used.",
 )
+@click.option(
+    "--ratio_features_weighted",
+    default=0.4,
+    help="percentage of features weighted vs unweighted",
+)
 @click.option(
     "--dim_emb",
     type=int,
@@ -132,6 +137,7 @@ def _gen_pipelines(
 def main(
     world_size: int,
     n_features: int,
+    ratio_features_weighted: float,
     dim_emb: int,
     n_batches: int,
     batch_size: int,
@@ -149,8 +155,9 @@ def main(
     os.environ["MASTER_ADDR"] = str("localhost")
     os.environ["MASTER_PORT"] = str(get_free_port())

-    num_features = n_features // 2
-    num_weighted_features = n_features // 2
+    num_weighted_features = int(n_features * ratio_features_weighted)
+    num_features = n_features - num_weighted_features
+
     tables = [
         EmbeddingBagConfig(
             num_embeddings=(i + 1) * 1000,
@@ -257,6 +264,7 @@ def _generate_data(
             world_size=world_size,
             num_float_features=num_float_features,
             pooling_avg=pooling_factor,
+            input_type=input_type,
         )[1]
         for i in range(num_batches)
     ]
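
A quick check of the new split arithmetic replacing the fixed 50/50 division: with the default --ratio_features_weighted of 0.4 and a hypothetical n_features of 100, the benchmark now builds 40 weighted and 60 unweighted tables.

# Worked example of the split (the n_features value is hypothetical).
n_features = 100
ratio_features_weighted = 0.4
num_weighted_features = int(n_features * ratio_features_weighted)  # 40
num_features = n_features - num_weighted_features                  # 60
assert (num_features, num_weighted_features) == (60, 40)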

torchrec/distributed/train_pipeline/tests/test_train_pipelines.py (+5, -1)

@@ -306,7 +306,11 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
         # `parameters`.
         optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)

-        data = [i.idlist_features for i in local_model_inputs]
+        data = [
+            i.idlist_features
+            for i in local_model_inputs
+            if isinstance(i.idlist_features, KeyedJaggedTensor)
+        ]
         dataloader = iter(data)
         pipeline = TrainPipelinePT2(
             model_gpu, optimizer_gpu, self.device, input_transformer=kjt_for_pt2_tracing

torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py (+1)

@@ -169,6 +169,7 @@ def generate_kjt(
        randomize_indices=True,
        device=device,
    )[0]
+   assert isinstance(global_input.idlist_features, KeyedJaggedTensor)
    return global_input.idlist_features
