Skip to content

Commit 9fc8d67

Browse files
authored
Merge pull request #647 from yeliang2258/unfold_dev
Add unfold op
2 parents 68c89d5 + cf31a2b commit 9fc8d67

File tree

4 files changed

+269
-2
lines changed

4 files changed

+269
-2
lines changed

paddle2onnx/op_mapper/mapper_helper.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,15 @@ def is_static_shape(shape):
2828
)
2929

3030

31+
def shape_helper(graph, input, dim=None):
    """Create an ONNX ``Shape`` node for *input*.

    When *dim* is None the full shape node is returned; otherwise the
    shape is sliced down to the single dimension ``[dim, dim + 1)`` via
    ``slice_helper`` and that 1-element shape node is returned.
    """
    # A Shape node is always emitted first; slicing only happens on demand.
    full_shape = graph.make_node('Shape', inputs=[input])
    if dim is None:
        return full_shape
    return slice_helper(graph, full_shape, [0], [dim], [dim + 1])
38+
39+
3140
def split_helper(graph, input, axis=0, split=None, outputs=None):
3241
assert outputs is not None, "outputs can not be None in split_helper."
3342
inputs = []

paddle2onnx/op_mapper/math.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,6 +1242,143 @@ def opset_13(cls, graph, node, **kw):
12421242
outputs=node.output('Out'))
12431243

12441244

1245+
@op_mapper('unfold')
class Unfold():
    """Map Paddle's ``unfold`` (im2col) operator to ONNX nodes.

    Extracts sliding local blocks from a 4-D NCHW input by gathering
    precomputed row/column indices from a padded copy of the input, then
    transposing and reshaping into Paddle's unfold output layout
    ``(N, C * kh * kw, L)``.
    """
    support_opset_version_range = (11, 15)

    @classmethod
    def opset_11(cls, graph, node, **kw):
        """Emit the im2col node graph for one ``unfold`` op.

        Reads the op attributes (strides, paddings, dilations,
        kernel_sizes — all already normalized to per-dimension lists by
        Paddle), gathers blocks along H (axis 2) and W (axis 4 after the
        first Gather), then transposes and reshapes to the final output.
        """
        strides = node.attr('strides')
        stride_h = strides[0]
        stride_w = strides[1]

        # Paddle padding order: [h_before, w_before, h_after, w_after].
        paddings = node.attr('paddings')
        padding_h_1 = paddings[0]
        padding_w_1 = paddings[1]
        padding_h_2 = paddings[2]
        padding_w_2 = paddings[3]

        dilations = node.attr('dilations')
        dilation_h = dilations[0]
        dilation_w = dilations[1]

        kernel_sizes = node.attr('kernel_sizes')
        kernel_h = kernel_sizes[0]
        kernel_w = kernel_sizes[1]

        # NOTE(review): the original code built an unused width-shape node
        # here (mapper_helper.shape_helper(graph, node.input('X', 0), 3)),
        # which only added dead Shape/Slice nodes to the graph; removed.

        blocks_row_indices_node = cls._get_im2col_indices_along_dim(
            graph, node, 2, kernel_h, dilation_h, padding_h_1, padding_h_2,
            stride_h)
        blocks_col_indices_node = cls._get_im2col_indices_along_dim(
            graph, node, 3, kernel_w, dilation_w, padding_w_1, padding_w_2,
            stride_w)

        output_shape = cls._get_im2col_output_shape(graph, node, kernel_h,
                                                    kernel_w)
        padded_input = cls._get_im2col_padded_input(
            graph, node, padding_h_1, padding_h_2, padding_w_1, padding_w_2)

        # Gather rows along H (axis 2): result rank grows to 5.
        output = graph.make_node(
            'Gather', inputs=[padded_input, blocks_row_indices_node], axis=2)
        # Gather cols along the (shifted) W axis: result rank grows to 6.
        output = graph.make_node(
            'Gather', inputs=[output, blocks_col_indices_node], axis=4)
        # Bring the two kernel axes next to the channel axis before flattening.
        output = graph.make_node(
            'Transpose', inputs=[output], perm=[0, 1, 2, 4, 3, 5])

        graph.make_node(
            'Reshape', inputs=[output, output_shape], outputs=node.output('Y'))

    @classmethod
    def _get_im2col_indices_along_dim(cls, graph, node, index, kernel_size_d,
                                      dilation_d, padding_d_1, padding_d_2,
                                      stride_d):
        """Return a node with 2-D gather indices along one spatial dim.

        The result is ``kernel_grid[:, None] + block_starts`` — each row of
        the (kernel_size_d, num_blocks) matrix holds the block start offsets
        shifted by one dilated kernel tap. When the dimension is dynamic
        (-1 in the inferred shape) the block count is computed at runtime
        with Shape/Add/Sub/Range nodes; otherwise it is folded into a
        Constant via numpy.
        """
        input_shape = node.input_shape('X', 0)
        if input_shape[index] == -1:
            # Dynamic dim: blocks_d = input_d + pad_total - dilation*(k-1),
            # then Range(0, blocks_d, stride) yields the block start offsets.
            input_d_node = mapper_helper.shape_helper(graph,
                                                      node.input('X', 0),
                                                      index)
            padding_d_node = graph.make_node(
                'Constant',
                dtype=dtypes.ONNX.INT64,
                value=[padding_d_1 + padding_d_2])
            blocks_d_node = graph.make_node(
                'Add', inputs=[input_d_node, padding_d_node])

            dilation_kernel_size_node = graph.make_node(
                'Constant',
                dtype=dtypes.ONNX.INT64,
                value=[dilation_d * (kernel_size_d - 1)])
            blocks_d_node = graph.make_node(
                'Sub', inputs=[blocks_d_node, dilation_kernel_size_node])

            # NOTE(review): ONNX Range formally expects scalar inputs; these
            # constants are 1-element tensors. Most runtimes accept this —
            # verify against the targeted runtimes.
            zero_node = graph.make_node(
                'Constant', dtype=dtypes.ONNX.INT64, value=[0])
            stride_node = graph.make_node(
                'Constant', dtype=dtypes.ONNX.INT64, value=[stride_d])
            blocks_d_indices_node = graph.make_node(
                'Range', inputs=[zero_node, blocks_d_node, stride_node])
        else:
            # Static dim: fold the block start offsets into a Constant.
            end = input_shape[
                index] + padding_d_1 + padding_d_2 - dilation_d * (
                    kernel_size_d - 1)
            blocks_d_indices = np.arange(0, end, stride_d)
            blocks_d_indices_node = graph.make_node(
                'Constant',
                dtype=dtypes.ONNX.INT64,
                value=blocks_d_indices.flatten().tolist())

        # Offsets of the dilated kernel taps: 0, d, 2d, ..., (k-1)*d.
        kernel_grid = np.arange(0, kernel_size_d * dilation_d, dilation_d)
        kernel_grid_node = graph.make_node(
            'Constant',
            dtype=dtypes.ONNX.INT64,
            value=kernel_grid.flatten().tolist())

        # Reshape to a column vector so the Add below broadcasts to
        # (kernel_size_d, num_blocks).
        shape_node = graph.make_node(
            'Constant', dtype=dtypes.ONNX.INT64, value=[-1, 1])
        kernel_mask_node = graph.make_node(
            'Reshape', inputs=[kernel_grid_node, shape_node])

        block_mask_node = graph.make_node(
            'Add', inputs=[blocks_d_indices_node, kernel_mask_node])
        return block_mask_node

    @classmethod
    def _get_im2col_output_shape(cls, graph, node, kernel_h, kernel_w):
        """Return a shape node for the final output: (N, C*kh*kw, -1)."""
        batch_dim = mapper_helper.shape_helper(graph, node.input('X', 0), 0)
        channel_dim = mapper_helper.shape_helper(graph, node.input('X', 0), 1)

        constant_node = graph.make_node(
            'Constant', dtype=dtypes.ONNX.INT64, value=[kernel_h * kernel_w])
        channel_unfolded = graph.make_node(
            'Mul', inputs=[channel_dim, constant_node])

        # -1 lets Reshape infer the block-count dimension L.
        concat_const_node = graph.make_node(
            'Constant', dtype=dtypes.ONNX.INT64, value=[-1])
        result_node = graph.make_node(
            'Concat',
            inputs=[batch_dim, channel_unfolded, concat_const_node],
            axis=0)

        return result_node

    @classmethod
    def _get_im2col_padded_input(cls, graph, node, padding_h_1, padding_h_2,
                                 padding_w_1, padding_w_2):
        """Return a Pad node: zero-pad H/W of the NCHW input, N/C untouched.

        ONNX pads layout is [x1_begin, x2_begin, ..., x1_end, x2_end, ...],
        hence [0, 0, ph1, pw1, 0, 0, ph2, pw2] for a 4-D input.
        """
        pad_const_node = graph.make_node(
            'Constant',
            dtype=dtypes.ONNX.INT64,
            value=[
                0, 0, padding_h_1, padding_w_1, 0, 0, padding_h_2, padding_w_2
            ])
        result_node = graph.make_node(
            'Pad', inputs=[node.input('X', 0), pad_const_node])
        return result_node
1380+
1381+
12451382
@op_mapper('softmax_with_cross_entropy')
12461383
class SoftmaxCrossEntropyLoss():
12471384
support_opset_version_range = (12, 15)

paddle2onnx/op_mapper/tensor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
@op_mapper('set_value')
27-
class Set_value():
27+
class SetValue():
2828
support_opset_version_range = (11, 15)
2929

3030
@classmethod
@@ -71,7 +71,8 @@ def opset_11(cls, graph, node, **kw):
7171
onnx_paddings[axis] = starts[i]
7272
onnx_paddings[axis + len(input_x_shape)] = input_x_shape[
7373
axis] - ends[i]
74-
74+
if onnx_paddings[axis + len(input_x_shape)] < 0:
75+
onnx_paddings[axis + len(input_x_shape)] = 0
7576
dtype = node.input_dtype('Input', 0)
7677
dtype = dtypes.DTYPE_PADDLE_ONNX_MAP[dtype]
7778

tests/test_auto_scan_unfold.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from auto_scan_test import OPConvertAutoScanTest, BaseNet
16+
from hypothesis import reproduce_failure
17+
import hypothesis.strategies as st
18+
import numpy as np
19+
import unittest
20+
import paddle
21+
22+
23+
class Net(BaseNet):
    """Minimal network that applies a single unfold op from config."""

    def forward(self, inputs):
        """Run paddle.nn.functional.unfold with the sampled parameters."""
        cfg = self.config
        out = paddle.nn.functional.unfold(
            inputs,
            cfg["kernel_size"],
            strides=cfg["strides"],
            paddings=cfg["paddings"],
            dilations=cfg["dilations"])
        return out
39+
40+
41+
class TestUnfoldConvert(OPConvertAutoScanTest):
    """
    api: paddle.nn.functional.unfold
    OPset version: 11, 15
    """

    def sample_convert_config(self, draw):
        """Sample a random unfold configuration for the auto-scan harness."""

        def draw_int_list(lo, hi, min_size, max_size):
            # One hypothesis draw per parameter — order matters for replay.
            return draw(
                st.lists(
                    st.integers(
                        min_value=lo, max_value=hi),
                    min_size=min_size,
                    max_size=max_size))

        def collapse(values):
            # A 1-element list is passed to unfold as a plain scalar.
            return values[0] if len(values) == 1 else values

        input_shape = draw_int_list(20, 30, 4, 4)
        kernel_size = collapse(draw_int_list(1, 5, 1, 2))
        strides = collapse(draw_int_list(1, 5, 1, 2))

        # Paddings are either a scalar/pair or an explicit 4-element list.
        if draw(st.booleans()):
            paddings = collapse(draw_int_list(1, 5, 1, 2))
        else:
            paddings = draw_int_list(1, 5, 4, 4)

        dilations = collapse(draw_int_list(1, 3, 1, 2))

        dtype = draw(st.sampled_from(["float32", "float64"]))

        # Optionally exercise the dynamic-shape path (only C is fixed).
        input_spec_shape = []
        if draw(st.booleans()):
            input_spec_shape = [[-1, input_shape[1], -1, -1]]

        config = {
            "op_names": ["unfold"],
            "test_data_shapes": [input_shape],
            "test_data_types": [[dtype]],
            "opset_version": [11, 12, 13, 14, 15],
            "input_spec_shape": input_spec_shape,
            "kernel_size": kernel_size,
            "strides": strides,
            "dilations": dilations,
            "paddings": paddings,
            "use_gpu": False,
        }

        models = Net(config)
        return (config, models)

    def test(self):
        """Run the randomized conversion check."""
        self.run_and_statis(max_examples=25)
117+
118+
119+
if __name__ == "__main__":
120+
unittest.main()

0 commit comments

Comments
 (0)