# test_pytorch_common.py
import functools
import os
import unittest
import sys

import torch
import torch.autograd.function as function

pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.insert(-1, pytorch_test_dir)

from torch.testing._internal.common_utils import *  # noqa: F401,F403

torch.set_default_tensor_type("torch.FloatTensor")

BATCH_SIZE = 2

RNN_BATCH_SIZE = 7
RNN_SEQUENCE_LENGTH = 11
RNN_INPUT_SIZE = 5
RNN_HIDDEN_SIZE = 3
def _skipper(condition, reason):
    """Build a skip decorator: the wrapped test is skipped when ``condition()`` is truthy."""
    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if condition():
                raise unittest.SkipTest(reason)
            return f(*args, **kwargs)
        return wrapper
    return decorator


skipIfNoCuda = _skipper(lambda: not torch.cuda.is_available(),
                        "CUDA is not available")

skipIfTravis = _skipper(lambda: os.getenv("TRAVIS"),
                        "Skip In Travis")
# Skips tests for all opset versions below min_opset_version.
# If exporting an op is only supported from a specific opset version onward,
# add this wrapper to prevent the test from running when the currently
# tested opset_version is smaller than min_opset_version.
def skipIfUnsupportedMinOpsetVersion(min_opset_version):
    def skip_dec(func):
        def wrapper(self):
            if self.opset_version < min_opset_version:
                raise unittest.SkipTest("Skip verify test for unsupported opset_version")
            return func(self)
        return wrapper
    return skip_dec
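# Illustrative usage (the test method below is a hypothetical example):
#
#     @skipIfUnsupportedMinOpsetVersion(9)
#     def test_some_op(self):
#         ...  # skipped whenever self.opset_version < 9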
# Skips tests for all opset versions above max_opset_version.
def skipIfUnsupportedMaxOpsetVersion(max_opset_version):
    def skip_dec(func):
        def wrapper(self):
            if self.opset_version > max_opset_version:
                raise unittest.SkipTest("Skip verify test for unsupported opset_version")
            return func(self)
        return wrapper
    return skip_dec
# Skips tests for all opset versions.
def skipForAllOpsetVersions():
    def skip_dec(func):
        def wrapper(self):
            if self.opset_version:
                raise unittest.SkipTest("Skip verify test for unsupported opset_version")
            return func(self)
        return wrapper
    return skip_dec
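# Illustrative usage (hypothetical test methods, shown only as examples):
#
#     @skipIfUnsupportedMaxOpsetVersion(12)
#     def test_legacy_behavior(self):
#         ...  # skipped whenever self.opset_version > 12
#
#     @skipForAllOpsetVersions()
#     def test_known_broken_export(self):
#         ...  # skipped for every opset_version under test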
# Enables tests for scripting, instead of only tracing the model.
def enableScriptTest():
    def script_dec(func):
        def wrapper(self):
            self.is_script_test_enabled = True
            return func(self)
        return wrapper
    return script_dec


# Disables tests for scripting.
def disableScriptTest():
    def script_dec(func):
        def wrapper(self):
            self.is_script_test_enabled = False
            return func(self)
        return wrapper
    return script_dec
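# Illustrative usage (hypothetical test method, shown only as an example):
#
#     @disableScriptTest()
#     def test_trace_only_model(self):
#         ...  # self.is_script_test_enabled is set to False before the test body runs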
# Skips tests for the opset versions listed in unsupported_opset_versions.
# If the caffe2 test cannot be run for a specific opset version, add this
# wrapper (for example, when an op was modified but the change is not
# supported in caffe2).
def skipIfUnsupportedOpsetVersion(unsupported_opset_versions):
    def skip_dec(func):
        def wrapper(self):
            if self.opset_version in unsupported_opset_versions:
                raise unittest.SkipTest("Skip verify test for unsupported opset_version")
            return func(self)
        return wrapper
    return skip_dec
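# Illustrative usage (hypothetical test method and version list, shown only as an example):
#
#     @skipIfUnsupportedOpsetVersion([7, 8])
#     def test_op_not_supported_in_caffe2(self):
#         ...  # skipped whenever self.opset_version is 7 or 8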
# Skips tests when the current ONNX shape inference setting matches onnx_shape_inference.
def skipIfONNXShapeInference(onnx_shape_inference):
    def skip_dec(func):
        def wrapper(self):
            if self.onnx_shape_inference is onnx_shape_inference:
                raise unittest.SkipTest("Skip verify test for unsupported ONNX shape inference setting")
            return func(self)
        return wrapper
    return skip_dec
# Flattens a nested structure of lists/tuples into a flat tuple of the tensors it contains.
def flatten(x):
    return tuple(function._iter_filter(lambda o: isinstance(o, torch.Tensor))(x))
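# Illustrative usage (assumed inputs, shown only as an example):
#
#     flatten((torch.ones(2), [torch.zeros(3)]))
#     # -> (tensor([1., 1.]), tensor([0., 0., 0.]))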