17
17
from torchrec .sparse .tensor_dict import maybe_td_to_kjt
18
18
19
19
20
- class TestTensorDIct (unittest .TestCase ):
21
- @given (device_str = st .sampled_from (["cpu" , "cuda" , "meta" ]))
22
- @settings (verbosity = Verbosity .verbose , max_examples = 5 , deadline = None )
20
+ class TestTensorDict (unittest .TestCase ):
23
21
# pyre-ignore[56]
24
- @unittest .skipIf (
25
- torch .cuda .device_count () <= 0 ,
26
- "CUDA is not available" ,
22
+ @given (
23
+ device_str = st .sampled_from (
24
+ ["cpu" , "meta" ] + (["cuda" ] if torch .cuda .is_available () else [])
25
+ )
27
26
)
27
+ @settings (verbosity = Verbosity .verbose , max_examples = 5 , deadline = None )
28
28
def test_kjt_input (self , device_str : str ) -> None :
29
29
device = torch .device (device_str )
30
30
values = torch .tensor ([0 , 1 , 2 , 3 , 2 , 3 , 4 ], device = device )
@@ -36,13 +36,13 @@ def test_kjt_input(self, device_str: str) -> None:
36
36
features = maybe_td_to_kjt (kjt )
37
37
self .assertEqual (features , kjt )
38
38
39
- @given (device_str = st .sampled_from (["cpu" , "cuda" , "meta" ]))
40
- @settings (verbosity = Verbosity .verbose , max_examples = 5 , deadline = None )
41
39
# pyre-ignore[56]
42
- @unittest .skipIf (
43
- torch .cuda .device_count () <= 0 ,
44
- "CUDA is not available" ,
40
+ @given (
41
+ device_str = st .sampled_from (
42
+ ["cpu" , "meta" ] + (["cuda" ] if torch .cuda .is_available () else [])
43
+ )
45
44
)
45
+ @settings (verbosity = Verbosity .verbose , max_examples = 5 , deadline = None )
46
46
def test_td_kjt (self , device_str : str ) -> None :
47
47
device = torch .device (device_str )
48
48
values = torch .tensor ([0 , 1 , 2 , 3 , 2 , 3 , 4 ], device = device )
0 commit comments