Skip to content

Commit cec5500

Browse files
committed
add unit tests
1 parent d9361c8 commit cec5500

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import unittest
2+
3+
from benchmarks.torchbench_model import TorchBenchModel
4+
5+
6+
class MockExperiment:
7+
8+
def __init__(self, accelerator, test):
9+
self.accelerator = accelerator
10+
self.test = "train"
11+
12+
13+
class TorchBenchModelTest(unittest.TestCase):
14+
15+
def test_do_not_use_amp_on_cpu_and_warns(self):
16+
experiment = MockExperiment("cpu", "train")
17+
model = TorchBenchModel("torchbench or other", "super_deep_model",
18+
experiment)
19+
with self.assertLogs('benchmarks.torchbench_model', level='WARNING') as cm:
20+
use_amp = model.use_amp()
21+
self.assertEqual(len(cm.output), 1)
22+
self.assertIn("AMP is not used", cm.output[0])
23+
self.assertFalse(use_amp)
24+
25+
def test_use_amp_on_cuda(self):
26+
experiment = MockExperiment("cuda", "train")
27+
model = TorchBenchModel("torchbench or other", "super_deep_model",
28+
experiment)
29+
self.assertTrue(model.use_amp())
30+
31+
32+
if __name__ == '__main__':
33+
unittest.main()

0 commit comments

Comments
 (0)