
Commit 5014024

TroyGarden authored and facebook-github-bot committed on Mar 7, 2025
reland D70126859
Summary:

# context

* previous diff triggered S495021
* the error message is like

  ```
  ModelGenerationPlatformError("AttributeError: '_EmbeddingBagProxy' object has no attribute 'weight'")
  ```

* this diff works around accessing `embedding_bag.weight` for the weight's dtype; instead it uses the dtype from the table config (sketched below)

Differential Revision: D70712348
1 parent 592ed93 commit 5014024
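
Below is a minimal sketch of the workaround described in the summary, assuming TorchRec's public `EmbeddingBagConfig` and `DataType`; the helper `weights_dtype_from_config` and the toy config values are illustrative only and are not part of this commit. The idea is to derive the dtype for `per_sample_weights` from the table config rather than from `embedding_bag.weight.dtype`, which is unavailable when the bag has been replaced by an `_EmbeddingBagProxy` during tracing.

```python
# Sketch only: illustrates the dtype-from-config workaround; the helper and
# config values below are hypothetical, not part of the commit.
import torch
from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig


def weights_dtype_from_config(config: EmbeddingBagConfig) -> torch.dtype:
    # Use the table config's DataType instead of embedding_bag.weight.dtype,
    # which raises AttributeError on an _EmbeddingBagProxy during tracing.
    return torch.float32 if config.data_type == DataType.FP32 else torch.float16


# Hypothetical FP16 weighted table for illustration.
config = EmbeddingBagConfig(
    num_embeddings=16,
    embedding_dim=8,
    name="weighted_table_0",
    feature_names=["f_w"],
    data_type=DataType.FP16,
)
per_sample_weights = torch.rand(4)
per_sample_weights = per_sample_weights.to(weights_dtype_from_config(config))
print(per_sample_weights.dtype)  # torch.float16
```

The same mapping appears in the `embedding_modules.py` hunk below, where the computed dtype is applied to `f.weights()` before calling the embedding bag.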

File tree

3 files changed: +28 −7 lines changed

 

torchrec/distributed/test_utils/test_model.py

+1 −2

@@ -243,8 +243,7 @@ def _validate_pooling_factor(
             global_idlist_indices.append(indices)
             global_idlist_offsets.append(offsets)
 
-        for idx in range(len(idscore_ind_ranges)):
-            ind_range = idscore_ind_ranges[idx]
+        for idx, ind_range in enumerate(idscore_ind_ranges):
             lengths_ = torch.abs(
                 torch.randn(batch_size * world_size, device=device)
                 + (

torchrec/distributed/test_utils/test_sharding.py

+18 −4

@@ -59,7 +59,11 @@
     ShardingPlan,
     ShardingType,
 )
-from torchrec.modules.embedding_configs import BaseEmbeddingConfig, EmbeddingBagConfig
+from torchrec.modules.embedding_configs import (
+    BaseEmbeddingConfig,
+    DataType,
+    EmbeddingBagConfig,
+)
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
 from torchrec.optim.optimizers import in_backward_optimizer_filter
 
@@ -554,9 +558,7 @@ def _custom_hook(input: List[torch.Tensor]) -> None:
         )
 
         # Compare predictions of sharded vs unsharded models.
-        if qcomms_config is None:
-            torch.testing.assert_close(global_pred, torch.cat(all_local_pred))
-        else:
+        if qcomms_config is not None:
             # With quantized comms, we can relax constraints a bit
             rtol = 0.003
             if CommType.FP8 in [
@@ -568,6 +570,18 @@ def _custom_hook(input: List[torch.Tensor]) -> None:
             torch.testing.assert_close(
                 global_pred, torch.cat(all_local_pred), rtol=rtol, atol=atol
             )
+        elif (
+            weighted_tables is not None
+            and weighted_tables[0].data_type == DataType.FP16
+        ):  # https://www.internalfb.com/intern/diffing/?paste_number=1740410921
+            torch.testing.assert_close(
+                global_pred,
+                torch.cat(all_local_pred),
+                atol=1e-4,  # relaxed atol due to FP16 in weights
+                rtol=1e-4,  # relaxed rtol due to FP16 in weights
+            )
+        else:
+            torch.testing.assert_close(global_pred, torch.cat(all_local_pred))
 
 
 def create_device_mesh_for_2D(

torchrec/modules/embedding_modules.py

+9 −1

@@ -243,12 +243,20 @@ def forward(
         pooled_embeddings: List[torch.Tensor] = []
         feature_dict = features.to_dict()
         for i, embedding_bag in enumerate(self.embedding_bags.values()):
+            embedding_config = self._embedding_bag_configs[i]
+            dtype = (
+                torch.float32
+                if embedding_config.data_type == DataType.FP32
+                else torch.float16
+            )
             for feature_name in self._feature_names[i]:
                 f = feature_dict[feature_name]
                 res = embedding_bag(
                     input=f.values(),
                     offsets=f.offsets(),
-                    per_sample_weights=f.weights() if self._is_weighted else None,
+                    per_sample_weights=(
+                        f.weights().to(dtype) if self._is_weighted else None
+                    ),
                 ).float()
                 pooled_embeddings.append(res)
         return KeyedTensor(

Comments (0)