Commit 4fbffeb

emlin authored and facebook-github-bot committed
Fix wrong return type of QuantManagedCollisionEmbeddingCollection (#2830)
Summary:
Pull Request resolved: #2830

This forward will be called after diff D68991644, so we need to make sure the return type is consistent with ManagedCollisionEmbeddingCollection.

Reviewed By: kausv, jma99fb

Differential Revision: D71250093

fbshipit-source-id: a5c91a2f703fe1803ebcc54be9cb3634bcec1354
1 parent cc9064b commit 4fbffeb

File tree: 2 files changed, +5 -3 lines changed

torchrec/distributed/tests/test_infer_shardings.py (+2 -2)

@@ -2229,7 +2229,7 @@ def test_sharded_quant_mc_ec_rw(
         sharded_model.load_state_dict(quant_model.state_dict())
         sharded_output = sharded_model(*inputs[0])
 
-        assert_close(non_sharded_output, sharded_output)
+        assert_close(non_sharded_output[0], sharded_output[0])
         gm: torch.fx.GraphModule = symbolic_trace(
             sharded_model,
             leaf_modules=[
@@ -2242,7 +2242,7 @@ def test_sharded_quant_mc_ec_rw(
         gm_script = torch.jit.script(gm)
         print(f"gm_script:\n{gm_script}")
         gm_script_output = gm_script(*inputs[0])
-        assert_close(sharded_output, gm_script_output)
+        assert_close(sharded_output[0], gm_script_output[0])
 
     @unittest.skipIf(
         torch.cuda.device_count() <= 1,
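
In the updated assertions, both outputs are indexed at element 0 because the module now returns a tuple rather than a bare dict. A minimal, runnable sketch of that comparison pattern, where plain tensors and torch.testing.assert_close stand in for the test's JaggedTensor outputs and its assert_close helper (both stand-ins are assumptions for illustration, not part of this commit):

import torch

# Stand-ins for the unsharded and sharded module outputs, which are now
# 1-tuples wrapping a dict of per-feature embedding tensors.
non_sharded_output = ({"feature_0": torch.ones(3, 4)},)
sharded_output = ({"feature_0": torch.ones(3, 4)},)

# The comparison targets the dict payload, i.e. element 0 of each tuple.
torch.testing.assert_close(non_sharded_output[0], sharded_output[0])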

torchrec/quant/embedding_modules.py (+3 -1)

@@ -1093,7 +1093,9 @@ def forward(
     ) -> Dict[str, JaggedTensor]:
         features = self._managed_collision_collection(features)
 
-        return super().forward(features)
+        # mcec expects Tuple return type
+        # pyre-ignore
+        return (super().forward(features),)
 
     def _get_name(self) -> str:
         return "QuantManagedCollisionEmbeddingCollection"
