File tree Expand file tree Collapse file tree 1 file changed +34
-0
lines changed Expand file tree Collapse file tree 1 file changed +34
-0
lines changed Original file line number Diff line number Diff line change
1
+ import unittest
2
+
3
+ from benchmarks .torchbench_model import TorchBenchModel
4
+ from benchmarks .benchmark_experiment import BenchmarkExperiment
5
+
6
+
7
+ class MockExperiment :
8
+
9
+ def __init__ (self , accelerator , test ):
10
+ self .accelerator = accelerator
11
+ self .test = "train"
12
+
13
+
14
+ class TorchBenchModelTest (unittest .TestCase ):
15
+
16
+ def test_do_not_use_amp_on_cpu_and_warns (self ):
17
+ experiment = MockExperiment ("cpu" , "train" )
18
+ model = TorchBenchModel ("torchbench or other" , "super_deep_model" ,
19
+ experiment )
20
+ with self .assertLogs ('benchmarks.torchbench_model' , level = 'WARNING' ) as cm :
21
+ use_amp = model .use_amp ()
22
+ self .assertEqual (len (cm .output ), 1 )
23
+ self .assertIn ("AMP is not used" , cm .output [0 ])
24
+ self .assertFalse (use_amp )
25
+
26
+ def test_use_amp_on_cuda (self ):
27
+ experiment = MockExperiment ("cuda" , "train" )
28
+ model = TorchBenchModel ("torchbench or other" , "super_deep_model" ,
29
+ experiment )
30
+ self .assertTrue (model .use_amp ())
31
+
32
+
33
+ if __name__ == '__main__' :
34
+ unittest .main ()
You can’t perform that action at this time.
0 commit comments