Skip to content

Commit 53752ea

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
fix hypothesis strategy that skips entire test without CUDA (#2690)
Summary: Pull Request resolved: #2690 # context * original implementation will skip the entire test set if there is no cuda available * the actual intention is to loop all devices (cpu, meta, cuda) and only skip cuda if not available. Reviewed By: dstaay-fb Differential Revision: D68373224 fbshipit-source-id: 28c8b12a61213ebfc794b07f51e5eff77f13938a
1 parent d0bf444 commit 53752ea

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

torchrec/sparse/tests/test_tensor_dict.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717
from torchrec.sparse.tensor_dict import maybe_td_to_kjt
1818

1919

20-
class TestTensorDIct(unittest.TestCase):
21-
@given(device_str=st.sampled_from(["cpu", "cuda", "meta"]))
22-
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
20+
class TestTensorDict(unittest.TestCase):
2321
# pyre-ignore[56]
24-
@unittest.skipIf(
25-
torch.cuda.device_count() <= 0,
26-
"CUDA is not available",
22+
@given(
23+
device_str=st.sampled_from(
24+
["cpu", "meta"] + (["cuda"] if torch.cuda.device_count() > 0 else [])
25+
)
2726
)
27+
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
2828
def test_kjt_input(self, device_str: str) -> None:
2929
device = torch.device(device_str)
3030
values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)
@@ -36,13 +36,13 @@ def test_kjt_input(self, device_str: str) -> None:
3636
features = maybe_td_to_kjt(kjt)
3737
self.assertEqual(features, kjt)
3838

39-
@given(device_str=st.sampled_from(["cpu", "cuda", "meta"]))
40-
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
4139
# pyre-ignore[56]
42-
@unittest.skipIf(
43-
torch.cuda.device_count() <= 0,
44-
"CUDA is not available",
40+
@given(
41+
device_str=st.sampled_from(
42+
["cpu", "meta"] + (["cuda"] if torch.cuda.device_count() > 0 else [])
43+
)
4544
)
45+
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
4646
def test_td_kjt(self, device_str: str) -> None:
4747
device = torch.device(device_str)
4848
values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)

0 commit comments

Comments
 (0)