
Commit 7c27553

TroyGarden authored and facebook-github-bot committed
fix failed test_kjt_bucketize_before_all2all_cpu (#2689)
Summary:

# context

* Found a test failure in an OSS [test run](https://github.com/pytorch/torchrec/actions/runs/12816026713/job/35736016089): P1714445461
* The issue is that a recent change (D65912888) incorrectly called the `_fx_wrap_tensor_to_device_dtype` function:

  ```
  block_bucketize_pos=(
      _fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths())
      if block_bucketize_row_pos is not None
      else None
  ),
  ```

  Here `block_bucketize_row_pos: List[torch.Tensor]`, but the function only accepts a single `torch.Tensor`:

  ```
  @torch.fx.wrap
  def _fx_wrap_tensor_to_device_dtype(
      t: torch.Tensor, tensor_device_dtype: torch.Tensor
  ) -> torch.Tensor:
      return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype)
  ```

* The straightforward fix is to apply the function in a list comprehension:

  ```
  block_bucketize_pos=(
      [
          _fx_wrap_tensor_to_device_dtype(pos, kjt.lengths())  # <---- pay attention here: kjt.lengths()
          for pos in block_bucketize_row_pos
      ]
  ```

* According to the earlier comments, `block_bucketize_pos`'s `dtype` should match `kjt._lengths`; however, that triggers the following error: {F1974430883}
* According to the operator implementation ([code pointer](https://fburl.com/code/9gyyl8h4)), `block_bucketize_pos` should instead have the same dtype as `kjt._values`: lengths carry the template type `offset_t`, while values carry `index_t`, which is also the type of `block_bucketize_pos`.

Reviewed By: dstaay-fb

Differential Revision: D68358894
1 parent 9dfdfb8 · commit 7c27553
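For illustration, here is a minimal standalone sketch of the failure mode and the per-tensor fix described in the summary. The `to_like` helper and the toy tensors are hypothetical stand-ins for `_fx_wrap_tensor_to_device_dtype` and the real KJT data; this is a sketch, not the torchrec implementation:

```
import torch

def to_like(t: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    # same body as _fx_wrap_tensor_to_device_dtype: move/cast t to ref's device and dtype
    return t.to(device=ref.device, dtype=ref.dtype)

values = torch.tensor([3, 1, 4], dtype=torch.long)  # stands in for kjt.values()
pos_list = [torch.tensor([0, 2], dtype=torch.int),  # stands in for block_bucketize_row_pos
            torch.tensor([0, 3], dtype=torch.int)]

# Buggy pattern: passing the whole list, as the broken call did, fails because
# a Python list has no .to() method.
try:
    to_like(pos_list, values)  # type: ignore[arg-type]
except AttributeError as e:
    print(f"fails as expected: {e}")

# Fixed pattern: convert each tensor individually via a list comprehension.
fixed = [to_like(pos, values) for pos in pos_list]
assert all(p.dtype == values.dtype for p in fixed)
```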

File tree

2 files changed: +18 −100 lines changed

torchrec/distributed/embedding_sharding.py

+4 −1

```
@@ -274,7 +274,10 @@ def bucketize_kjt_before_all2all(
         batch_size_per_feature=_fx_wrap_batch_size_per_feature(kjt),
         max_B=_fx_wrap_max_B(kjt),
         block_bucketize_pos=(
-            _fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths())
+            [
+                _fx_wrap_tensor_to_device_dtype(pos, kjt.values())
+                for pos in block_bucketize_row_pos
+            ]
             if block_bucketize_row_pos is not None
             else None
         ),
```
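To illustrate why the reference tensor here is `kjt.values()` rather than `kjt.lengths()`: in a KJT the lengths (`offset_t` in the kernel) and values (`index_t`) may legitimately carry different integer dtypes, and the operator templates `block_bucketize_pos` on `index_t`, so the positions must match the values' dtype. A minimal sketch, with illustrative dtype choices:

```
import torch

# lengths (offset_t) and values (index_t) can legitimately differ in dtype
lengths = torch.tensor([2, 1], dtype=torch.int32)
values = torch.tensor([10, 42, 7], dtype=torch.int64)

pos = torch.tensor([0, 5, 11])  # one bucketize-position tensor

# converting against lengths would give int32; the kernel expects index_t,
# i.e. the values' dtype, so convert against values instead
pos_fixed = pos.to(device=values.device, dtype=values.dtype)
assert pos_fixed.dtype == values.dtype and pos_fixed.dtype != lengths.dtype
```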

torchrec/distributed/tests/test_utils.py

+14 −99

```
@@ -263,98 +263,6 @@ def block_bucketize_ref(
 
 
 class KJTBucketizeTest(unittest.TestCase):
-    @unittest.skipIf(
-        torch.cuda.device_count() <= 0,
-        "CUDA is not available",
-    )
-    # pyre-ignore[56]
-    @given(
-        index_type=st.sampled_from([torch.int, torch.long]),
-        offset_type=st.sampled_from([torch.int, torch.long]),
-        world_size=st.integers(1, 129),
-        num_features=st.integers(1, 15),
-        batch_size=st.integers(1, 15),
-    )
-    @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
-    def test_kjt_bucketize_before_all2all(
-        self,
-        index_type: torch.dtype,
-        offset_type: torch.dtype,
-        world_size: int,
-        num_features: int,
-        batch_size: int,
-    ) -> None:
-        MAX_BATCH_SIZE = 15
-        MAX_LENGTH = 10
-        # max number of rows needed for a given feature to have unique row index
-        MAX_ROW_COUNT = MAX_LENGTH * MAX_BATCH_SIZE
-
-        lengths_list = [
-            random.randrange(MAX_LENGTH + 1) for _ in range(num_features * batch_size)
-        ]
-        keys_list = [f"feature_{i}" for i in range(num_features)]
-        # for each feature, generate unrepeated row indices
-        indices_lists = [
-            random.sample(
-                range(MAX_ROW_COUNT),
-                # number of indices needed is the length sum of all batches for a feature
-                sum(
-                    lengths_list[
-                        feature_offset * batch_size : (feature_offset + 1) * batch_size
-                    ]
-                ),
-            )
-            for feature_offset in range(num_features)
-        ]
-        indices_list = list(itertools.chain(*indices_lists))
-
-        weights_list = [random.randint(1, 100) for _ in range(len(indices_list))]
-
-        # for each feature, calculate the minimum block size needed to
-        # distribute all rows to the available trainers
-        block_sizes_list = [
-            (
-                math.ceil((max(feature_indices_list) + 1) / world_size)
-                if feature_indices_list
-                else 1
-            )
-            for feature_indices_list in indices_lists
-        ]
-
-        kjt = KeyedJaggedTensor(
-            keys=keys_list,
-            lengths=torch.tensor(lengths_list, dtype=offset_type)
-            .view(num_features * batch_size)
-            .cuda(),
-            values=torch.tensor(indices_list, dtype=index_type).cuda(),
-            weights=torch.tensor(weights_list, dtype=torch.float).cuda(),
-        )
-        """
-        each entry in block_sizes identifies how many hashes for each feature goes
-        to every rank; we have three featues in `self.features`
-        """
-        block_sizes = torch.tensor(block_sizes_list, dtype=index_type).cuda()
-
-        block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
-            kjt=kjt,
-            num_buckets=world_size,
-            block_sizes=block_sizes,
-        )
-
-        expected_block_bucketized_kjt = block_bucketize_ref(
-            kjt,
-            world_size,
-            block_sizes,
-        )
-
-        self.assertTrue(
-            keyed_jagged_tensor_equals(
-                block_bucketized_kjt,
-                expected_block_bucketized_kjt,
-                is_pooled_features=True,
-            )
-        )
-
     # pyre-ignore[56]
     @given(
         index_type=st.sampled_from([torch.int, torch.long]),
@@ -363,16 +271,20 @@ def test_kjt_bucketize_before_all2all(
         num_features=st.integers(1, 15),
         batch_size=st.integers(1, 15),
         variable_bucket_pos=st.booleans(),
+        device=st.sampled_from(
+            ["cpu"] + (["cuda"] if torch.cuda.device_count() > 0 else [])
+        ),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
-    def test_kjt_bucketize_before_all2all_cpu(
+    @settings(verbosity=Verbosity.verbose, max_examples=50, deadline=None)
+    def test_kjt_bucketize_before_all2all(
         self,
         index_type: torch.dtype,
         offset_type: torch.dtype,
         world_size: int,
         num_features: int,
         batch_size: int,
         variable_bucket_pos: bool,
+        device: str,
     ) -> None:
         MAX_BATCH_SIZE = 15
         MAX_LENGTH = 10
@@ -423,17 +335,17 @@ def test_kjt_bucketize_before_all2all_cpu(
 
         kjt = KeyedJaggedTensor(
             keys=keys_list,
-            lengths=torch.tensor(lengths_list, dtype=offset_type).view(
+            lengths=torch.tensor(lengths_list, dtype=offset_type, device=device).view(
                 num_features * batch_size
             ),
-            values=torch.tensor(indices_list, dtype=index_type),
-            weights=torch.tensor(weights_list, dtype=torch.float),
+            values=torch.tensor(indices_list, dtype=index_type, device=device),
+            weights=torch.tensor(weights_list, dtype=torch.float, device=device),
         )
         """
         each entry in block_sizes identifies how many hashes for each feature goes
         to every rank; we have three featues in `self.features`
         """
-        block_sizes = torch.tensor(block_sizes_list, dtype=index_type)
+        block_sizes = torch.tensor(block_sizes_list, dtype=index_type, device=device)
         block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
             kjt=kjt,
             num_buckets=world_size,
@@ -442,7 +354,10 @@ def test_kjt_bucketize_before_all2all_cpu(
         )
 
         expected_block_bucketized_kjt = block_bucketize_ref(
-            kjt, world_size, block_sizes, "cpu"
+            kjt,
+            world_size,
+            block_sizes,
+            device,
         )
 
         self.assertTrue(
```
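The test refactor above folds the CUDA-only test into the CPU test by sampling the device as a hypothesis parameter, skipping `"cuda"` when no GPU is present. A minimal sketch of that pattern on a toy test (the class and test names here are hypothetical, not torchrec code):

```
import unittest

import torch
from hypothesis import Verbosity, given, settings, strategies as st


class DeviceParamTest(unittest.TestCase):
    @given(
        # sample "cuda" only when a GPU is actually available
        device=st.sampled_from(
            ["cpu"] + (["cuda"] if torch.cuda.device_count() > 0 else [])
        ),
    )
    @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
    def test_tensor_roundtrip(self, device: str) -> None:
        # the test body runs unchanged on whichever device was sampled
        t = torch.arange(4, device=device)
        self.assertEqual(t.device.type, device)


if __name__ == "__main__":
    unittest.main()
```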
