File tree: 2 files changed, +30 / -4 lines changed.

File 1 of 2: the faster_hash benchmark script.
@@ -11,15 +11,23 @@
 import contextlib
 import logging
 import random
+import sys
 import time
 from typing import Any, Generator

 import torch

 logger: logging.Logger = logging.getLogger(__name__)

-torch.ops.load_library("//caffe2/torch/fb/retrieval:faster_hash_cpu")
-torch.ops.load_library("//caffe2/torch/fb/retrieval:faster_hash_cuda")
+
+def load_required_libraries() -> bool:
+    try:
+        torch.ops.load_library("//torchrec/ops:faster_hash_cpu")
+        torch.ops.load_library("//torchrec/ops:faster_hash_cuda")
+        return True
+    except Exception as e:
+        logger.error(f"Failed to load faster_hash libraries, skipping test: {e}")
+        return False


 @contextlib.contextmanager
@@ -347,6 +355,9 @@ def _run_benchmark_with_eviction(


 if __name__ == "__main__":
+    if not load_required_libraries():
+        print("Skipping test because libraries were not loaded")
+        sys.exit(0)
     logger.setLevel(logging.INFO)
     handler = logging.StreamHandler()
     handler.setLevel(logging.INFO)
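
Taken together, the benchmark-side change follows a simple pattern: load the custom ops through a helper and, from the script entry point, exit cleanly instead of crashing when they are unavailable. Below is a minimal, self-contained sketch assembled from the hunks above (the benchmark body is omitted; the op library targets are the ones used in the diff):

import logging
import sys

import torch

logger: logging.Logger = logging.getLogger(__name__)


def load_required_libraries() -> bool:
    # Return True only if both custom-op libraries load; any failure is
    # logged and reported as False so the caller can skip gracefully.
    try:
        torch.ops.load_library("//torchrec/ops:faster_hash_cpu")
        torch.ops.load_library("//torchrec/ops:faster_hash_cuda")
        return True
    except Exception as e:
        logger.error(f"Failed to load faster_hash libraries, skipping test: {e}")
        return False


if __name__ == "__main__":
    if not load_required_libraries():
        print("Skipping test because libraries were not loaded")
        sys.exit(0)  # exit code 0 so the missing ops count as a skip, not a failure
    # ... benchmark setup and runs would follow here ...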
File 2 of 2: the FasterHashTest unit tests.

@@ -13,8 +13,15 @@
 import torch
 from hypothesis import settings

-torch.ops.load_library("//torchrec/ops:faster_hash_cpu")
-torch.ops.load_library("//torchrec/ops:faster_hash_cuda")
+
+def load_required_libraries() -> bool:
+    try:
+        torch.ops.load_library("//torchrec/ops:faster_hash_cpu")
+        torch.ops.load_library("//torchrec/ops:faster_hash_cuda")
+        return True
+    except Exception as e:
+        print(f"Skipping tests because libraries were not loaded: {e}")
+        return False


 class HashZchKernelEvictionPolicy(IntEnum):
@@ -23,6 +30,14 @@ class HashZchKernelEvictionPolicy(IntEnum):


 class FasterHashTest(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        if not load_required_libraries():
+            raise unittest.SkipTest(
+                "Libraries not loaded, skipping all tests in MyTestCase"
+            )
+
     @unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
     @settings(deadline=None)
     def test_simple_zch_no_evict(self) -> None:
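
For context, raising unittest.SkipTest from setUpClass is what makes every test method in the class report as skipped rather than errored when the libraries are missing. A minimal sketch of that mechanism, with a stub loader standing in for the real load_required_libraries above:

import unittest


def load_required_libraries() -> bool:
    # Stub for illustration: pretend the faster_hash libraries failed to load.
    return False


class FasterHashTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        # SkipTest raised here skips every test in the class, and the
        # overall test run still exits with status 0.
        if not load_required_libraries():
            raise unittest.SkipTest("Libraries not loaded, skipping all tests")

    def test_simple_zch_no_evict(self) -> None:
        self.fail("never reached when the libraries are unavailable")


if __name__ == "__main__":
    unittest.main()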