File tree Expand file tree Collapse file tree 1 file changed +33
-0
lines changed Expand file tree Collapse file tree 1 file changed +33
-0
lines changed Original file line number Diff line number Diff line change
1
+ import unittest
2
+
3
+ from benchmarks .torchbench_model import TorchBenchModel
4
+
5
+
6
+ class MockExperiment :
7
+
8
+ def __init__ (self , accelerator , test ):
9
+ self .accelerator = accelerator
10
+ self .test = "train"
11
+
12
+
13
+ class TorchBenchModelTest (unittest .TestCase ):
14
+
15
+ def test_do_not_use_amp_on_cpu_and_warns (self ):
16
+ experiment = MockExperiment ("cpu" , "train" )
17
+ model = TorchBenchModel ("torchbench or other" , "super_deep_model" ,
18
+ experiment )
19
+ with self .assertLogs ('benchmarks.torchbench_model' , level = 'WARNING' ) as cm :
20
+ use_amp = model .use_amp ()
21
+ self .assertEqual (len (cm .output ), 1 )
22
+ self .assertIn ("AMP is not used" , cm .output [0 ])
23
+ self .assertFalse (use_amp )
24
+
25
+ def test_use_amp_on_cuda (self ):
26
+ experiment = MockExperiment ("cuda" , "train" )
27
+ model = TorchBenchModel ("torchbench or other" , "super_deep_model" ,
28
+ experiment )
29
+ self .assertTrue (model .use_amp ())
30
+
31
+
32
+ if __name__ == '__main__' :
33
+ unittest .main ()
You can’t perform that action at this time.
0 commit comments