This repository was archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.7k
Expand file tree
/
Copy pathtest_fusion.py
More file actions
371 lines (323 loc) · 12.8 KB
/
test_fusion.py
File metadata and controls
371 lines (323 loc) · 12.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import sys
import os
import random
import itertools
import mxnet as mx
import numpy as np
from mxnet import autograd, gluon
from mxnet.test_utils import *
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
def check_fused_symbol(sym, **kwargs):
    """Check that a symbol gives matching results with fusion off and on.

    Every output of ``sym`` is wrapped in two ``identity`` ops and the result
    is bound twice with ``_simple_bind``: once under ``MXNET_USE_FUSION=0``
    and once under ``MXNET_USE_FUSION=1``.  Both executors run a training-mode
    forward pass and a backward pass with all-ones head gradients; their
    outputs and gradient arrays must agree within dtype-dependent tolerances.
    The comparison is repeated for float16/float32/float64 inputs and for the
    'write' and 'add' gradient requests.

    Parameters
    ----------
    sym : Symbol
        Symbol under test; iterating over it must yield its outputs.
    **kwargs
        One array per name in ``sym.list_inputs()`` (each must provide
        ``.shape`` and ``.astype``), plus an optional ``ctx`` entry selecting
        the context to bind on (defaults to ``mx.gpu(0)``).

    Raises
    ------
    KeyError
        If an input of ``sym`` has no matching keyword argument.
    AssertionError
        If the two executors disagree on the number or on the values of their
        outputs or gradients.
    """
    inputs = sym.list_inputs()
    shapes = {inp: kwargs[inp].shape for inp in inputs}
    ctx = kwargs.get('ctx', mx.gpu(0))
    # Double identity so that there is always something to fuse
    test_sym = mx.sym.Group([mx.sym.identity(mx.sym.identity(s)) for s in sym])
    rtol = {'float16': 1e-2,
            'float32': 1.5e-6,
            'float64': 1.5e-6,
            }
    atol = {'float16': 1e-3,
            'float32': 1e-7,
            'float64': 1e-7,
            }
    for dtype in ['float16', 'float32', 'float64']:
        data = {inp: kwargs[inp].astype(dtype) for inp in inputs}
        for grad_req in ['write', 'add']:
            type_dict = {inp: dtype for inp in inputs}
            with environment('MXNET_USE_FUSION', '0'):
                orig_exec = test_sym._simple_bind(ctx=ctx, grad_req=grad_req,
                                                  type_dict=type_dict, **shapes)
            with environment('MXNET_USE_FUSION', '1'):
                fused_exec = test_sym._simple_bind(ctx=ctx, grad_req=grad_req,
                                                   type_dict=type_dict, **shapes)
            fwd_orig = orig_exec.forward(is_train=True, **data)
            out_grads = [mx.nd.ones_like(arr) for arr in fwd_orig]
            orig_exec.backward(out_grads=out_grads)
            fwd_fused = fused_exec.forward(is_train=True, **data)
            fused_exec.backward(out_grads=out_grads)

            # Identifies the failing configuration in assertion messages.
            where = "dtype={}, grad_req={}".format(dtype, grad_req)

            # zip() stops at the shorter sequence, so a missing output would
            # otherwise go unnoticed; compare the counts explicitly first.
            assert len(fwd_orig) == len(fwd_fused), \
                "number of outputs differs ({})".format(where)
            for orig, fused in zip(fwd_orig, fwd_fused):
                np.testing.assert_allclose(
                    orig.asnumpy(), fused.asnumpy(),
                    rtol=rtol[dtype], atol=atol[dtype],
                    err_msg="forward output mismatch ({})".format(where))

            orig_grads = orig_exec.grad_arrays
            fused_grads = fused_exec.grad_arrays
            assert len(orig_grads) == len(fused_grads), \
                "number of gradient arrays differs ({})".format(where)
            for orig, fused in zip(orig_grads, fused_grads):
                # A gradient may be absent, but then it must be absent in both.
                if orig is None and fused is None:
                    continue
                assert orig is not None, \
                    "gradient missing without fusion ({})".format(where)
                assert fused is not None, \
                    "gradient missing with fusion ({})".format(where)
                np.testing.assert_allclose(
                    orig.asnumpy(), fused.asnumpy(),
                    rtol=rtol[dtype], atol=atol[dtype],
                    err_msg="gradient mismatch ({})".format(where))
def check_unary_ops():
    """Run the fused-vs-unfused comparison over single-input operators.

    Operators that need nothing but the input are looked up by name on
    ``mx.sym``; those needing a restricted input range or extra attributes
    are exercised individually afterwards.  One random 2-D array is shared by
    all checks, and a line is printed before each operator family.
    """
    def log_op(label):
        print("Checking fusion of " + label)

    # Operators callable as op(a) with no further arguments.
    plain_ops = (
        'relu', 'sigmoid', 'softsign',
        'exp', 'expm1',
        'log', 'log10', 'log2', 'log1p',
        'degrees', 'radians',
        'sin', 'cos', 'tan',
        'arcsin', 'arccos', 'arctan',
        'sinh', 'cosh', 'tanh',
        'arcsinh', 'arctanh',
        'sqrt', 'rsqrt', 'cbrt', 'rcbrt',
        'square',
        'squeeze',
        'zeros_like', 'ones_like',
        'flatten',
        'round', 'rint', 'fix', 'floor', 'ceil', 'trunc',
        'sign',
        'reciprocal',
        'abs',
        'gamma', 'gammaln',
        'erf',
        'negative',
        'logical_not',
    )

    arr = mx.random.uniform(shape=rand_shape_2d())
    a = mx.sym.Variable('a')

    for name in plain_ops:
        log_op(name)
        check_fused_symbol(getattr(mx.sym, name)(a), a=arr)

    # unary ops requiring special treatment

    # arccosh needs input to be >= 1
    shifted = arr + 1
    log_op('arccosh')
    check_fused_symbol(mx.sym.arccosh(a), a=shifted)

    # erfinv needs -1 < input < 1, but we avoid the limits of this range where the slope nears +inf.
    centered = (arr - 0.5) * 1.99
    log_op('erfinv')
    check_fused_symbol(mx.sym.erfinv(a), a=centered)

    # Activation requires act_type attribute
    for act_type in ['relu', 'sigmoid', 'tanh', 'softrelu', 'softsign']:
        log_op("Activation(act_type='{}')".format(act_type))
        activation = mx.sym.Activation(a, act_type=act_type)
        check_fused_symbol(activation, a=arr)
        if act_type != 'softrelu':
            continue
        # Check that softrelu implementation doesn't overflow on large inputs
        check_fused_symbol(mx.sym.Activation(a, act_type=act_type), a=1000 * arr)

    # Cast requires dtype
    for target in ['float16', 'float32', 'float64', 'int32']:
        log_op("Cast(dtype='{}')".format(target))
        check_fused_symbol(mx.sym.Cast(a, dtype=target), a=arr)

    # reshape requires shape
    log_op('reshape')
    check_fused_symbol(mx.sym.reshape(a, shape=(-1,)), a=arr)

    # expand_dims requires axis
    log_op('expand_dims')
    check_fused_symbol(mx.sym.expand_dims(a, axis=1), a=arr)

    # clip requires a_min, a_max; finite, infinite and NaN bounds are all covered
    log_op('clip')
    clip_bounds = (
        (0.3, 0.7),
        (-np.inf, 0.7),
        (-np.inf, np.inf),
        (0, np.nan),
    )
    for lower, upper in clip_bounds:
        check_fused_symbol(mx.sym.clip(a, a_min=lower, a_max=upper), a=arr)

    # smooth_l1 requires a scalar
    log_op('smooth_l1')
    check_fused_symbol(mx.sym.smooth_l1(a, scalar=0.3), a=arr)
def check_binary_ops():
    """Run the fused-vs-unfused comparison over two-operand expressions.

    Covers arithmetic operators in symbol/symbol and symbol/scalar form
    (scalar on either side where the original test did so), plus ``pow``,
    ``maximum``, ``minimum`` and ``hypot``.  Both inputs are random arrays of
    one shared random 2-D shape.
    """
    a = mx.sym.Variable('a')
    b = mx.sym.Variable('b')
    shape = rand_shape_2d()
    arr1 = mx.random.uniform(shape=shape)
    arr2 = mx.random.uniform(shape=shape)

    both = {'a': arr1, 'b': arr2}
    only_a = {'a': arr1}

    def cases():
        # Lazily yields (symbol, inputs) so each symbol is still built right
        # before it is checked.
        yield a + b, both
        yield a + 3, only_a
        yield a - b, both
        yield a - 3, only_a
        yield 3 - a, only_a
        yield a * b, both
        yield a * 3, only_a
        yield a / (b + 1), both
        yield a / 3, only_a
        yield 3 / a, only_a
        yield a ** b, both
        yield a ** 3, only_a
        yield mx.sym.pow(3, a), only_a
        yield mx.sym.maximum(a, b), both
        yield mx.sym.minimum(a, b), both
        yield mx.sym.hypot(a, b), both
        yield mx.sym.hypot(a, 3), only_a

    for expr, feed in cases():
        check_fused_symbol(expr, **feed)
def check_other_ops():
    """Run the fused-vs-unfused comparison over n-ary and slicing operators.

    Exercises ``add_n``, ``slice_axis`` (positive and negative axis),
    ``slice`` with random non-negative and random negative bounds,
    ``slice_like`` and ``broadcast_like``.
    """
    a = mx.sym.Variable('a')
    b = mx.sym.Variable('b')
    c = mx.sym.Variable('c')

    shape = [5, *rand_shape_2d()]
    # Make sure there is at least 2 elements for the test with negative indices
    for axis in (1, 2):
        shape[axis] += 1

    arr1 = mx.random.uniform(shape=shape)
    arr2 = mx.random.uniform(shape=shape)
    arr3 = mx.random.uniform(shape=shape)

    check_fused_symbol(mx.sym.add_n(a, b, c), a=arr1, b=arr2, c=arr3)

    check_fused_symbol(mx.sym.slice_axis(a, axis=0, begin=1, end=4), a=arr1)
    # Testing handling of negative axis
    check_fused_symbol(mx.sym.slice_axis(a, axis=-3, begin=1, end=4), a=arr1)

    # Random window given by non-negative indices: begin is drawn first for
    # every axis, then end, each end strictly after its begin.
    begin = tuple(random.randint(0, dim - 1) for dim in shape)
    end = tuple(random.randint(start + 1, dim) for start, dim in zip(begin, shape))
    check_fused_symbol(mx.sym.slice(a, begin=begin, end=end), a=arr1)

    # Same again with indices counted from the end of each axis.
    begin = tuple(random.randint(-dim, -2) for dim in shape)
    end = tuple(random.randint(start + 1, -1) for start in begin)
    check_fused_symbol(mx.sym.slice(a, begin=begin, end=end), a=arr1)

    lhs = mx.random.uniform(shape=(2, 3, 4, 5))
    rhs = mx.random.uniform(shape=(1, 2, 3))
    check_fused_symbol(mx.sym.slice_like(a, b, axes=[-2, 0]), a=lhs, b=rhs)

    lhs = mx.random.uniform(shape=(1, 1, 2, 3))
    rhs = mx.random.uniform(shape=(2, 2, 2, 3))
    check_fused_symbol(mx.sym.broadcast_like(a, b, lhs_axes=[0], rhs_axes=[0]),
                       a=lhs, b=rhs)
def check_leakyrelu_ops():
    """Run the fused-vs-unfused comparison for ``LeakyReLU`` in 'gelu' mode.

    The activation is applied to the sum of two inputs, each a random array
    of one shared random 2-D shape.
    """
    lhs = mx.sym.Variable('a')
    rhs = mx.sym.Variable('b')
    shape = rand_shape_2d()
    feed = {
        'a': mx.random.uniform(shape=shape),
        'b': mx.random.uniform(shape=shape),
    }

    # Testing gelu
    print("Checking fusion of LeakyReLU:gelu")
    gelu = mx.sym.LeakyReLU(lhs + rhs, act_type='gelu')
    check_fused_symbol(gelu, **feed)
def test_fusion():
    """Run every operator-family fusion check, in a fixed order."""
    checks = (
        check_unary_ops,
        check_binary_ops,
        check_other_ops,
        check_leakyrelu_ops,
    )
    for run_check in checks:
        run_check()
def test_fusion_compiler_cache():
    """Check the same fused expression repeatedly and on a second GPU.

    Stresses the internal cache of CUfunctions by creating the same kernel
    multiple times and on multiple GPUs if available.
    """
    a = mx.sym.Variable('a')
    b = mx.sym.Variable('b')
    shape = rand_shape_2d()
    feed = {
        'a': mx.random.uniform(shape=shape),
        'b': mx.random.uniform(shape=shape),
    }

    # Invoke the same model twice, second time will exercise compile cache
    for _ in range(2):
        check_fused_symbol(a + b, ctx=mx.gpu(0), **feed)

    # On multi-GPU systems, invoke the same model on other GPUs
    if mx.device.num_gpus() > 1:
        check_fused_symbol(a + b, ctx=mx.gpu(1), **feed)
@use_np
def test_fusion_boolean_inputs():
    """Run a hybridized block that casts a boolean input to float32 twice.

    The test only requires the forward pass to finish; it makes no assertion
    on the returned values.
    """
    from mxnet.gluon import HybridBlock

    class Foo(HybridBlock):
        """Multiplies two float32 casts of the input, one with a trailing axis."""

        def forward(self, valid_length):
            # Two separate casts of the same boolean input, as in the original.
            lhs = valid_length.astype(np.float32)
            rhs = valid_length.astype(np.float32)
            return lhs * mx.np.expand_dims(rhs, axis=-1)

    net = Foo()
    net.hybridize(static_alloc=True)
    bool_input = mx.np.ones((10,), ctx=mx.gpu(), dtype=bool)
    # Keep the result referenced until all pending work has finished.
    result = net(bool_input)
    mx.npx.waitall()
@use_np
def test_fusion_different_dimensions():
    """Call one hybridized block with 1-D and then 2-D input.

    The block casts its input to float32 and appends a trailing axis, so an
    all-ones input of shape ``s`` must give an all-ones output of shape
    ``s + (1,)``.
    """
    from mxnet.gluon import HybridBlock

    class Foo(HybridBlock):
        def __init__(self):
            super(Foo, self).__init__()

        def forward(self, x):
            mask2 = x.astype(np.float32)
            mask = mx.np.expand_dims(mask2, axis=-1)
            return mask

    foo = Foo()
    foo.hybridize(static_alloc=True)

    # Pass 1-D data, then 2-D data, through the same block instance.
    for in_shape in ((10,), (10, 10)):
        expected_shape = in_shape + (1,)
        out = foo(mx.np.ones(in_shape, ctx=mx.gpu()))
        assert out.shape == expected_shape
        # Compare against an array of the exact output shape: a lower-rank
        # reference would still compare equal through NumPy broadcasting and
        # so would not pin down the result.
        assert np.all(out.asnumpy() == np.ones(expected_shape))
@use_np
def test_input_reorder():
    """Compare results and input gradients with MXNET_USE_FUSION '0' and '1'.

    A three-input block is run forward and backward on the GPU under both
    settings, for ``static_alloc`` False and True; the result and the
    gradient of every input must match between the two settings.
    """
    class Block(gluon.HybridBlock):
        def forward(self, x, y, z):
            doubled = x * 2
            rhs = doubled + z
            lhs = mx.np.add(doubled, y * y)
            return mx.np.dot(lhs, rhs)

    def run_once(static_alloc, host_inputs):
        # One forward/backward pass on fresh GPU copies of the inputs;
        # returns the result under 'result' and each gradient under its
        # input position.
        net = Block()
        net.hybridize(static_alloc=static_alloc)
        gpu_args = [item.copyto(mx.gpu()) for item in host_inputs]
        for gpu_arg in gpu_args:
            gpu_arg.attach_grad()
        with autograd.record():
            result = net(*gpu_args)
        collected = {'result': result}
        result.backward()
        collected.update({pos: gpu_arg.grad for pos, gpu_arg in enumerate(gpu_args)})
        return collected

    for static_alloc in (False, True):
        arg_shapes = [(10, 10), (10, 1), (10, 10)]
        arg_data = [mx.np.random.uniform(size=s) for s in arg_shapes]

        arrays = {}
        for use_fusion in ('0', '1'):
            with environment('MXNET_USE_FUSION', use_fusion):
                arrays[use_fusion] = run_once(static_alloc, arg_data)

        unfused, fused = arrays['0'], arrays['1']
        for key in ['result', *range(len(arg_data))]:
            assert_allclose(unfused[key].asnumpy(), fused[key].asnumpy())
@use_np
def test_fusion_cycle():
    """Run a block whose two outputs each depend on a reduction of the other input.

    The test only requires the hybridized forward pass to finish; it makes no
    assertion on the returned values.
    """
    class Test(gluon.nn.HybridBlock):
        def forward(self, x, y):
            act_x = mx.npx.relu(x)
            act_y = mx.npx.relu(y)
            # Row sums, restored to column shape via expand_dims.
            sum_x = mx.np.expand_dims(mx.np.sum(act_x, axis=1), axis=1)
            sum_y = mx.np.expand_dims(mx.np.sum(act_y, axis=1), axis=1)
            # Each output mixes one activation with the other input's sum.
            return act_x + sum_y, act_y + sum_x

    net = Test()
    first = mx.np.zeros(shape=(10, 1), ctx=mx.gpu())
    second = mx.np.zeros(shape=(10, 1), ctx=mx.gpu())
    net.hybridize(static_alloc=True, static_shape=True)
    # Keep the outputs referenced until all pending work has finished.
    outputs = net(first, second)
    mx.npx.waitall()