
Commit 859c0be

TroyGarden authored and facebook-github-bot committed on Feb 26, 2025
skip tests with internal/external discrepancy (pytorch#2759)
Summary:

# context
* In the torchrec GitHub repo (OSS env), a few tests are [failing](https://github.com/pytorch/torchrec/actions/runs/13449271251/job/37580767712).
* However, these tests pass internally due to a different setup:
  * torch.export uses the training IR externally but the inference IR internally.
  * The DLRM transformer tests use random.seed(0) to generate initial weights, and the numeric values can differ internally and externally.

Reviewed By: dstaay-fb, iamzainhuda

Differential Revision: D69996988
1 parent 856ff3c commit 859c0be

File tree

4 files changed: +22 −1 lines
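The skips below are implemented with unittest.skipIf, using torch.cuda.device_count() == 0 as a proxy for the OSS environment. A minimal self-contained sketch of that guard (the test body here is a placeholder, not TorchRec code):

    import unittest

    import torch

    class ExportIrTest(unittest.TestCase):
        @unittest.skipIf(
            torch.cuda.device_count() == 0,
            "skip in OSS (no GPU) because torch.export uses training ir in OSS",
        )
        def test_export_path(self) -> None:
            # placeholder body; the real tests exercise torch.export on TorchRec models
            self.assertTrue(torch.cuda.is_available())

    if __name__ == "__main__":
        unittest.main()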

 

torchrec/inference/inference_legacy/__init__.py (+1 −1)

@@ -24,4 +24,4 @@
 - `examples/dlrm/inference/dlrm_predict.py`: this shows how to use `PredictModule` and `PredictFactory` based on an existing model.
 """
 
-from . import model_packager, modules  # noqa # noqa
+from . import model_packager  # noqa

torchrec/inference/tests/test_inference.py (+4)

@@ -410,3 +410,7 @@ def test_fused_params_overwrite(self) -> None:
 
         # Make sure that overwrite of ebc_fused_params is not reflected in ec_fused_params
         self.assertEqual(ec_fused_params[FUSED_PARAM_REGISTER_TBE_BOOL], orig_value)
+
+        # change it back to the original value because it modifies the global variable
+        # otherwise it will affect other tests
+        ebc_fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = orig_value
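A note on the restore pattern above: assigning the original value back by hand only runs if every assertion before it passes. A more defensive variant registers the restore with addCleanup so it runs even when the test fails. A minimal sketch, with a hypothetical module-level dict standing in for the real fused-params registry:

    import unittest

    # hypothetical stand-in for a module-level fused-params registry
    FUSED_PARAMS = {"register_tbe": False}

    class FusedParamsTest(unittest.TestCase):
        def test_overwrite(self) -> None:
            orig_value = FUSED_PARAMS["register_tbe"]
            # runs even if an assertion below fails, so the global
            # is always restored for later tests
            self.addCleanup(FUSED_PARAMS.__setitem__, "register_tbe", orig_value)
            FUSED_PARAMS["register_tbe"] = True
            self.assertTrue(FUSED_PARAMS["register_tbe"])

    if __name__ == "__main__":
        unittest.main()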

torchrec/ir/tests/test_serializer.py (+5)

@@ -253,7 +253,12 @@ def test_serialize_deserialize_ebc(self) -> None:
             self.assertEqual(deserialized.shape, orginal.shape)
             self.assertTrue(torch.allclose(deserialized, orginal))
 
+    @unittest.skipIf(
+        torch.cuda.device_count() == 0,
+        "skip this test in OSS (no GPU available) because torch.export uses training ir in OSS",
+    )
     def test_dynamic_shape_ebc(self) -> None:
+        # TODO: https://fb.workplace.com/groups/1028545332188949/permalink/1138699244506890/
         model = self.generate_model()
         feature1 = KeyedJaggedTensor.from_offsets_sync(
             keys=["f1", "f2", "f3"],

torchrec/models/experimental/test_transformerdlrm.py (+12)

@@ -61,6 +61,10 @@ def test_larger(self) -> None:
         concat_dense = inter_arch(dense_features, sparse_features)
         self.assertEqual(concat_dense.size(), (B, D * (F + 1)))
 
+    @unittest.skipIf(
+        torch.cuda.device_count() == 0,
+        "skip this test in OSS (no GPU available) because torch.export uses training ir in OSS",
+    )
     def test_correctness(self) -> None:
         D = 4
         B = 3

@@ -165,6 +169,10 @@ def test_correctness(self) -> None:
             )
         )
 
+    @unittest.skipIf(
+        torch.cuda.device_count() == 0,
+        "skip this test in OSS (no GPU available) because torch.export uses training ir in OSS",
+    )
     def test_numerical_stability(self) -> None:
         D = 4
         B = 3

@@ -194,6 +202,10 @@ def test_numerical_stability(self) -> None:
 
 
 class DLRMTransformerTest(unittest.TestCase):
+    @unittest.skipIf(
+        torch.cuda.device_count() == 0,
+        "skip this test in OSS (no GPU available) because torch.export uses training ir in OSS",
+    )
     def test_basic(self) -> None:
         torch.manual_seed(0)
         B = 2
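On the seeding point from the summary: a fixed seed makes weight initialization reproducible within one process and build, but the same seed is not guaranteed to yield the same stream across different builds or backends, which the summary gives as the reason these numeric tests pass internally yet fail in OSS. A minimal illustration in plain PyTorch:

    import torch

    torch.manual_seed(0)
    w1 = torch.randn(3)
    torch.manual_seed(0)
    w2 = torch.randn(3)

    # identical within one process/build ...
    assert torch.equal(w1, w2)
    # ... but a different build or backend may draw different values
    # from the same seed, so cross-environment golden values can diverge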
