-
Notifications
You must be signed in to change notification settings - Fork 699
/
Copy pathshape_inference_test.py
149 lines (110 loc) · 4.34 KB
/
shape_inference_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# isort:skip_file
# pyre-ignore-all-errors
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch_glow
from glow.glow.torch_glow.tests.tests import utils
class TestGlowShapeInference(utils.TorchGlowTestCase):
    """Tests for the Glow shape inference API exposed through torch_glow.

    Covers: basic success, input-count mismatch error handling, detection of
    unsupported symbols (with and without a blocklist), and skipping of nodes
    that follow the last fusion group.
    """

    def test_shape_inference_basics(self):
        """Test Glow shape inference basic usage."""

        def f(a):
            return a * a

        a = torch.randn(1)
        # Trace with a one-element input tuple; (a,) is the intended tuple
        # form (a bare `(a)` is just `a`).
        jit_f = torch.jit.trace(f, (a,))
        jit_f_graph = jit_f.graph_for(a)
        args = (a,)
        actual = torch_glow.glow_shape_inference(
            jit_f_graph,
            args,
        )
        # Shape inference is expected to succeed on this trivial graph.
        assert actual

    def test_shape_inference_input_mismatch(self):
        """Test Glow shape inference basic error handling."""

        def f(a):
            return a * a

        a = torch.randn(1)
        jit_f = torch.jit.trace(f, (a,))
        jit_f_graph = jit_f.graph_for(a)
        # Input/args is empty, but the function expects one input.
        # Shape Inference should raise an exception in this case.
        args = ()
        self.assertRaises(
            Exception,
            lambda: torch_glow.glow_shape_inference(
                jit_f_graph,
                args,
            ),
        )

    def test_shape_inference_supported_symbols(self):
        """Test Glow shape inference with only supported symbols.

        All ops in the graph are supported, so no unsupported symbols
        should be reported.
        """

        def f(a):
            return a * a

        a = torch.randn(1)
        jit_f = torch.jit.trace(f, (a,))
        jit_f_graph = jit_f.graph_for(a)
        args = (a,)
        actual = torch_glow.glow_shape_inference_find_unsupported_symbols(
            jit_f_graph, args
        )
        expected = []
        self.assertEqual(set(expected), set(actual))

    def test_shape_inference_unsupported_symbols(self):
        """Test Glow shape inference unsupported symbols."""

        def f(a):
            # linalg.multi_dot is currently not supported by shape inference engine
            return torch.matrix_power(torch.linalg.multi_dot([a * 3, a + 4]), 3)

        a = torch.randn(3, 3)
        jit_f = torch.jit.trace(f, (a,))
        jit_f_graph = jit_f.graph_for(a)
        args = (a,)
        actual = torch_glow.glow_shape_inference_find_unsupported_symbols(
            jit_f_graph, args
        )
        expected = ["aten::linalg_multi_dot", "aten::linalg_matrix_power"]
        self.assertEqual(set(expected), set(actual))

        # A blocklisted symbol should be suppressed from the report.
        blocklist = ["aten::linalg_multi_dot"]
        actual = torch_glow.glow_shape_inference_find_unsupported_symbols(
            jit_f_graph, args, blocklist
        )
        expected = ["aten::linalg_matrix_power"]
        self.assertEqual(set(expected), set(actual))

    def test_shape_inference_unsupported_symbols_skip_fusion_group(self):
        """Test Glow shape inference unsupported symbols including skipping of
        symbols after a secondary fusion group."""

        def f(a, b):
            x1 = a * b
            x2 = x1 * b
            x3 = x2 * a
            x4 = x3 / b
            x5 = x4 / a
            x6 = x5 / b
            x7 = x6 * a
            x8 = x7 * b
            return x8 * torch.linalg.multi_dot([x8, x8])

        # Restrict the fusion pass to a sub-range of nodes so that the
        # unsupported op (linalg.multi_dot) lands after the fusion group.
        torch_glow.enableFusionPass_DO_NOT_USE_THIS()
        torch_glow.setFusionStartIndex(3)
        torch_glow.setFusionEndIndex(6)
        a = torch.randn(5, 5)
        b = torch.randn(5, 5)
        jit_f = torch.jit.trace(f, (a, b))
        jit_f_graph = jit_f.graph_for(a, b)
        # Reset fusion indices so later tests are not affected.
        torch_glow.clearFusionIndices()
        args = (a, b)

        # Don't skip nodes after the last fusion node.
        # in this case, one of the nodes (linalg.multi_dot) following the last fusion node
        # is not supported, and should be reported.
        actual = torch_glow.glow_shape_inference_find_unsupported_symbols(
            jit_f_graph, args, skip_last_fusion_node=False
        )
        expected = [
            "aten::linalg_multi_dot",
        ]
        self.assertEqual(set(expected), set(actual))

        # DO skip nodes after the last fusion node.
        # in this case, one of the nodes (linalg.multi_dot) following the last fusion node
        # is not supported, but is suppressed due to the skip_last_fusion_node flag.
        actual = torch_glow.glow_shape_inference_find_unsupported_symbols(
            jit_f_graph, args, skip_last_fusion_node=True
        )
        expected = []
        self.assertEqual(set(expected), set(actual))