Skip to content

Commit faef9f0

Browse files
committed
add unit tests
1 parent 79f66aa commit faef9f0

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import unittest
2+
3+
from benchmarks.torchbench_model import TorchBenchModel
4+
from benchmarks.benchmark_experiment import BenchmarkExperiment
5+
6+
7+
class MockExperiment:
8+
9+
def __init__(self, accelerator, test):
10+
self.accelerator = accelerator
11+
self.test = "train"
12+
13+
14+
class TorchBenchModelTest(unittest.TestCase):
15+
16+
def test_do_not_use_amp_on_cpu_and_warns(self):
17+
experiment = MockExperiment("cpu", "train")
18+
model = TorchBenchModel("torchbench or other", "super_deep_model",
19+
experiment)
20+
with self.assertLogs('benchmarks.torchbench_model', level='WARNING') as cm:
21+
use_amp = model.use_amp()
22+
self.assertEqual(len(cm.output), 1)
23+
self.assertIn("AMP is not used", cm.output[0])
24+
self.assertFalse(use_amp)
25+
26+
def test_use_amp_on_cuda(self):
27+
experiment = MockExperiment("cuda", "train")
28+
model = TorchBenchModel("torchbench or other", "super_deep_model",
29+
experiment)
30+
self.assertTrue(model.use_amp())
31+
32+
33+
if __name__ == '__main__':
34+
unittest.main()

0 commit comments

Comments
 (0)