forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_ops.py
128 lines (102 loc) · 4.86 KB
/
test_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from functools import partial, wraps
import torch
from torch.testing._internal.common_utils import \
(TestCase, run_tests)
from torch.testing._internal.common_methods_invocations import \
(op_db)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, dtypes, onlyOnCPUAndCUDA, skipCUDAIfRocm)
from torch.autograd.gradcheck import gradcheck, gradgradcheck
# Tests that apply to all operators
class TestOpInfo(TestCase):
exact_dtype = True
# Verifies that ops have their unsupported dtypes
# registered correctly by testing that each claimed unsupported dtype
# throws a runtime error
@skipCUDAIfRocm
@onlyOnCPUAndCUDA
@ops(op_db, unsupported_dtypes_only=True)
def test_unsupported_dtypes(self, device, dtype, op):
samples = op.sample_inputs(device, dtype)
if len(samples) == 0:
self.skipTest("Skipped! No sample inputs!")
# NOTE: only tests on first sample
sample = samples[0]
with self.assertRaises(RuntimeError):
op(sample.input, *sample.args, **sample.kwargs)
# Verifies that ops have their supported dtypes
# registered correctly by testing that each claimed supported dtype
# does NOT throw a runtime error
@skipCUDAIfRocm
@onlyOnCPUAndCUDA
@ops(op_db)
def test_supported_dtypes(self, device, dtype, op):
samples = op.sample_inputs(device, dtype)
if len(samples) == 0:
self.skipTest("Skipped! No sample inputs!")
# NOTE: only tests on first sample
sample = samples[0]
op(sample.input, *sample.args, **sample.kwargs)
class TestGradients(TestCase):
exact_dtype = True
# Copies inputs to inplace operations to avoid inplace modifications
# to leaves requiring gradient
def _get_safe_inplace(self, inplace_variant):
@wraps(inplace_variant)
def _fn(t, *args, **kwargs):
return inplace_variant(t.clone(), *args, **kwargs)
return _fn
def _check_helper(self, device, dtype, op, variant, check):
if variant is None:
self.skipTest("Skipped! Variant not implemented.")
if not op.supports_dtype(dtype, torch.device(device).type):
self.skipTest("Skipped! ", op.name, ' does not support dtype ', str(dtype))
samples = op.sample_inputs(device, dtype, requires_grad=True)
for sample in samples:
partial_fn = partial(variant, **sample.kwargs)
if check == 'gradcheck':
self.assertTrue(gradcheck(partial_fn, (sample.input,) + sample.args,
check_grad_dtypes=True))
elif check == 'gradgradcheck':
self.assertTrue(gradgradcheck(partial_fn, (sample.input,) + sample.args,
gen_non_contig_grad_outputs=False,
check_grad_dtypes=True))
self.assertTrue(gradgradcheck(partial_fn, (sample.input,) + sample.args,
gen_non_contig_grad_outputs=True,
check_grad_dtypes=True))
else:
self.assertTrue(False, msg="Unknown check requested!")
def _grad_test_helper(self, device, dtype, op, variant):
return self._check_helper(device, dtype, op, variant, 'gradcheck')
def _gradgrad_test_helper(self, device, dtype, op, variant):
return self._check_helper(device, dtype, op, variant, 'gradgradcheck')
# Tests that gradients are computed correctly
@dtypes(torch.double, torch.cdouble)
@ops(op_db)
def test_fn_grad(self, device, dtype, op):
self._grad_test_helper(device, dtype, op, op.get_op())
@dtypes(torch.double, torch.cdouble)
@ops(op_db)
def test_method_grad(self, device, dtype, op):
self._grad_test_helper(device, dtype, op, op.get_method())
@dtypes(torch.double, torch.cdouble)
@ops(op_db)
def test_inplace_grad(self, device, dtype, op):
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
# Test that gradients of gradients are computed correctly
@dtypes(torch.double, torch.cdouble)
@ops(op_db)
def test_fn_gradgrad(self, device, dtype, op):
self._gradgrad_test_helper(device, dtype, op, op.get_op())
@dtypes(torch.double, torch.cdouble)
@ops(op_db)
def test_method_gradgrad(self, device, dtype, op):
self._gradgrad_test_helper(device, dtype, op, op.get_method())
@dtypes(torch.double, torch.cdouble)
@ops(op_db)
def test_inplace_gradgrad(self, device, dtype, op):
self._gradgrad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
instantiate_device_type_tests(TestOpInfo, globals())
instantiate_device_type_tests(TestGradients, globals())
if __name__ == '__main__':
run_tests()