forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_models_onnxruntime.py
32 lines (24 loc) · 1.07 KB
/
test_models_onnxruntime.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import unittest
import onnxruntime # noqa
from test_models import TestModels
from test_pytorch_onnx_onnxruntime import run_model_test
import torch
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
    """Run ``run_model_test`` on ``model`` once for every requested opset version.

    For each version, ``self.opset_version`` is set and the model is passed
    to ``run_model_test`` together with ``inputs`` and the tolerances.  When
    ``self.is_script_test_enabled`` is truthy and the version is greater
    than 11, the model is additionally compiled with ``torch.jit.script``
    and checked again, using the eager model's outputs as
    ``example_outputs`` and ``use_new_jit_passes=True``.

    Args:
        self: Object that carries ``opset_version`` and
            ``is_script_test_enabled``; forwarded to ``run_model_test``.
        model: Callable model; invoked as ``model(inputs)`` in the
            scripted path.
        inputs: Forwarded to ``run_model_test`` as ``input``.
        rtol: Relative tolerance forwarded to ``run_model_test``.
        atol: Absolute tolerance forwarded to ``run_model_test``.
        opset_versions: Iterable of opset versions to exercise.  Any falsy
            value (``None`` or an empty sequence) selects 7 through 12.
    """
    versions = opset_versions or [7, 8, 9, 10, 11, 12]

    for version in versions:
        # Record the active opset on ``self`` before handing it over;
        # presumably run_model_test reads it from there.
        self.opset_version = version
        run_model_test(self, model, False,
                       input=inputs, rtol=rtol, atol=atol)

        # The scripted variant only runs when enabled and for opsets above 11.
        if not self.is_script_test_enabled or version <= 11:
            continue

        # Eager outputs are computed before scripting and passed along as
        # the example outputs for the scripted model.
        eager_outputs = model(inputs)
        scripted = torch.jit.script(model)
        run_model_test(self, scripted, False, example_outputs=eager_outputs,
                       input=inputs, rtol=rtol, atol=atol,
                       use_new_jit_passes=True)
if __name__ == '__main__':
    # Patch TestModels before the runner collects tests: replace its
    # exportTest with the one defined in this file and switch on the
    # flag that exportTest consults for the torch.jit.script path.
    TestModels.exportTest = exportTest
    TestModels.is_script_test_enabled = True

    unittest.main()