Skip to content

Commit 4eaa73e

Browse files
authored
fix pool and add avgpool test (#573)
- fix pool and add avgpool test
- add float64 support
1 parent 40ad325 commit 4eaa73e

File tree

2 files changed

+243
-17
lines changed

2 files changed

+243
-17
lines changed

paddle2onnx/op_mapper/nn.py

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -137,15 +137,38 @@ def opset_1(cls, graph, node, **kw):
137137
assert node.attrs['data_format'] == 'NCHW' or node.attrs['data_format'] == "AnyLayout", \
138138
"The conv data format should be 'NCHW', but received data format " \
139139
"is %s." % node.attrs['data_format']
140+
x_dtype = node.input_dtype('X', 0)
141+
need_dtype_convert = False
142+
input_name = node.input('X', 0)
143+
if x_dtype != paddle.float32:
144+
need_dtype_convert = True
145+
input_name = graph.make_node(
146+
'Cast', inputs=node.input('X'), to=dtypes.ONNX.FLOAT)
147+
140148
if node.attr('global_pooling') or (node.attr('adaptive') and
141149
node.attr('ksize') == [1, 1]):
142-
onnx_node = graph.make_node(
143-
cls.pool_type[node.attr('pooling_type')][1],
144-
inputs=node.input('X'),
145-
outputs=node.output('Out'))
150+
if need_dtype_convert:
151+
onnx_node = graph.make_node(
152+
cls.pool_type[node.attr('pooling_type')][1],
153+
inputs=[input_name])
154+
graph.make_node(
155+
'Cast',
156+
inputs=[onnx_node],
157+
outputs=node.output('Out'),
158+
to=dtypes.ONNX.DOUBLE)
159+
else:
160+
onnx_node = graph.make_node(
161+
cls.pool_type[node.attr('pooling_type')][1],
162+
inputs=[input_name],
163+
outputs=node.output('Out'))
146164
elif node.attr('adaptive'):
147165
# if pool is adaptive, check if input shape of pool is fixed.
148-
mapper_helper.is_static_shape(node.input_shape('X', 0))
166+
if node.input_shape('X', 0)[2:].count(-1) > 0:
167+
raise Exception(
168+
"Converting this model to ONNX need with static input shape," \
169+
" please fix input shape of this model, see doc Q2 in" \
170+
" https://github.com/PaddlePaddle/paddle2onnx/blob/develop/docs/en/FAQ.md."
171+
)
149172
input_h, input_w = node.input_shape('X', 0)[2:]
150173
output_h, output_w = node.output_shape('Out', 0)[2:]
151174
stride_h = int(input_h / output_h)
@@ -179,11 +202,22 @@ def opset_1(cls, graph, node, **kw):
179202
attrs['auto_pad'] = 'VALID'
180203
if node.attr('pooling_type') == 'avg':
181204
attrs['count_include_pad'] = not node.attr('exclusive')
182-
onnx_node = graph.make_node(
183-
cls.pool_type[node.attr('pooling_type')][0],
184-
inputs=node.input('X'),
185-
outputs=node.output('Out'),
186-
attrs=attrs)
205+
if need_dtype_convert:
206+
onnx_node = graph.make_node(
207+
cls.pool_type[node.attr('pooling_type')][0],
208+
inputs=[input_name],
209+
attrs=attrs)
210+
graph.make_node(
211+
'Cast',
212+
inputs=[onnx_node],
213+
outputs=node.output('Out'),
214+
to=dtypes.ONNX.DOUBLE)
215+
else:
216+
onnx_node = graph.make_node(
217+
cls.pool_type[node.attr('pooling_type')][0],
218+
inputs=[input_name],
219+
outputs=node.output('Out'),
220+
attrs=attrs)
187221
else:
188222
input_shape = node.input_shape('X', 0)
189223
k_size = node.attr('ksize')
@@ -200,7 +234,7 @@ def opset_1(cls, graph, node, **kw):
200234
if input_shape[3] > 0 and input_shape[3] + pads[1] < k_size[1]:
201235
k_size[1] = input_shape[3] + pads[1]
202236

203-
input_x = node.input('X')
237+
input_x = [input_name]
204238
if max(k_size) <= max(pads):
205239
onnx_paddings = [0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]]
206240
attrs_pad = {'mode': 'constant', }
@@ -243,11 +277,22 @@ def opset_1(cls, graph, node, **kw):
243277

244278
if node.attr('pooling_type') == 'avg':
245279
attrs['count_include_pad'] = not node.attr('exclusive')
246-
onnx_node = graph.make_node(
247-
cls.pool_type[node.attr('pooling_type')][0],
248-
inputs=input_x,
249-
outputs=node.output('Out'),
250-
attrs=attrs)
280+
if need_dtype_convert:
281+
onnx_node = graph.make_node(
282+
cls.pool_type[node.attr('pooling_type')][0],
283+
inputs=input_x,
284+
attrs=attrs)
285+
graph.make_node(
286+
'Cast',
287+
inputs=[onnx_node],
288+
outputs=node.output('Out'),
289+
to=dtypes.ONNX.DOUBLE)
290+
else:
291+
onnx_node = graph.make_node(
292+
cls.pool_type[node.attr('pooling_type')][0],
293+
inputs=input_x,
294+
outputs=node.output('Out'),
295+
attrs=attrs)
251296

252297

253298
@op_mapper('pool3d')
@@ -283,7 +328,12 @@ def opset_1(cls, graph, node, **kw):
283328
outputs=node.output('Out'))
284329
elif node.attr('adaptive'):
285330
# if pool is adaptive, check if input shape of pool is fixed.
286-
mapper_helper.is_static_shape(node.input_shape('X', 0))
331+
if node.input_shape('X', 0)[2:].count(-1) > 0:
332+
raise Exception(
333+
"Converting this model to ONNX need with static input shape," \
334+
" please fix input shape of this model, see doc Q2 in" \
335+
" https://github.com/PaddlePaddle/paddle2onnx/blob/develop/docs/en/FAQ.md."
336+
)
287337
input_d, input_h, input_w = node.input_shape('X', 0)[2:]
288338
output_d, output_h, output_w = node.output_shape('Out', 0)[2:]
289339
stride_d = int(input_d / output_d)

tests/test_auto_scan_avgpool.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from auto_scan_test import OPConvertAutoScanTest, BaseNet
16+
from hypothesis import reproduce_failure
17+
import hypothesis.strategies as st
18+
import numpy as np
19+
import unittest
20+
import paddle
21+
22+
23+
class NetAvgpool2d(BaseNet):
    """Minimal network wrapping ``paddle.nn.functional.avg_pool2d``.

    All pooling hyper-parameters are read from ``self.config`` (populated
    by ``BaseNet``), so one class serves every sampled test configuration.
    """

    def forward(self, inputs):
        """Apply 2-D average pooling to ``inputs``.

        The kernel size, stride, padding, ceil_mode and data_format are
        taken from the test config; the pooled tensor is returned.
        """
        cfg = self.config
        return paddle.nn.functional.avg_pool2d(
            inputs,
            kernel_size=cfg['kernel_size'],
            stride=cfg['stride'],
            padding=cfg['padding'],
            ceil_mode=cfg['ceil_mode'],
            data_format=cfg['data_format'])
45+
46+
47+
class TestMaxpool2dConvert(OPConvertAutoScanTest):
    """Auto-scan conversion test for ``paddle.nn.functional.avg_pool2d``.

    api: paddle.nn.functional.avg_pool2d
    OPset version: 7, 9, 15
    """

    def sample_convert_config(self, draw):
        """Draw one random avg_pool2d configuration via hypothesis.

        Returns a ``(config, model)`` pair consumed by the auto-scan
        harness. Draw order is significant for hypothesis reproducibility.
        """
        # 4-D NCHW input, each dim in [10, 20].
        input_shape = draw(
            st.lists(
                st.integers(
                    min_value=10, max_value=20), min_size=4, max_size=4))

        dtype = draw(st.sampled_from(["float32", "float64"]))
        data_format = draw(st.sampled_from(["NCHW"]))

        # max_pool2d_with_index
        # NOTE(review): this boolean is drawn but immediately pinned to
        # False (avg_pool2d has no return_mask); the draw is kept only so
        # the hypothesis draw sequence stays unchanged.
        return_mask = draw(st.booleans())
        return_mask = False
        ceil_mode = draw(st.booleans())

        # Kernel: scalar int or [h, w] pair, each in [7, 10].
        kernel_type = draw(st.sampled_from(["int", "list"]))
        if kernel_type == "int":
            kernel_size = draw(st.integers(min_value=7, max_value=10))
        elif kernel_type == "list":
            kernel_size = draw(
                st.lists(
                    st.integers(
                        min_value=7, max_value=10),
                    min_size=2,
                    max_size=2))

        # Stride: None (defaults to kernel size), scalar, or [h, w] pair.
        stride_type = draw(st.sampled_from(["None", "int", "list"]))
        if stride_type == "int":
            stride = draw(st.integers(min_value=1, max_value=5))
        elif stride_type == "list":
            stride = draw(
                st.lists(
                    st.integers(
                        min_value=1, max_value=5),
                    min_size=2,
                    max_size=2))
        else:
            stride = None

        # Padding: every form avg_pool2d accepts — algorithm string,
        # scalar, [h, w], [top, bottom, left, right], or the nested
        # per-axis [[before, after], ...] form ("list8").
        padding_type = draw(
            st.sampled_from(["None", "str", "int", "list2", "list4", "list8"]))
        if padding_type == "str":
            padding = draw(st.sampled_from(["SAME", "VALID"]))
        elif padding_type == "int":
            padding = draw(st.integers(min_value=1, max_value=5))
        elif padding_type == "list2":
            padding = draw(
                st.lists(
                    st.integers(
                        min_value=1, max_value=5),
                    min_size=2,
                    max_size=2))
        elif padding_type == "list4":
            padding = draw(
                st.lists(
                    st.integers(
                        min_value=1, max_value=5),
                    min_size=4,
                    max_size=4))
        elif padding_type == "list8":
            # Wrapping the drawn pair in a list is equivalent to the
            # original np.expand_dims(...).tolist() round-trip.
            pad_hw = [draw(
                st.lists(
                    st.integers(
                        min_value=1, max_value=5),
                    min_size=2,
                    max_size=2))]
            pad_wh = [draw(
                st.lists(
                    st.integers(
                        min_value=1, max_value=5),
                    min_size=2,
                    max_size=2))]
            if data_format == "NCHW":
                padding = [[0, 0]] + [[0, 0]] + pad_hw + pad_wh
            else:
                padding = [[0, 0]] + pad_hw + pad_wh + [[0, 0]]
        else:
            padding = 0

        # Dead while return_mask is pinned False; kept for draw parity.
        if return_mask and padding_type in ["list2", "list4", "list8"]:
            padding = draw(st.integers(min_value=1, max_value=5))

        if return_mask:
            opset_version = [[9, 15]]
        else:
            opset_version = [[7, 9, 15]]
        if ceil_mode:
            # NOTE(review): flat list here vs the nested [[...]] form
            # above — looks inconsistent; confirm which shape
            # run_and_statis expects.
            opset_version = [10, 15]

        # "VALID" forbids ceil_mode in paddle, so force it off.
        if padding == "VALID":
            ceil_mode = False
        op_names = 'max_pool2d_with_index' if return_mask else 'pool2d'
        config = {
            "op_names": [op_names],
            "test_data_shapes": [input_shape],
            "test_data_types": [[dtype]],
            "opset_version": opset_version,
            "input_spec_shape": [],
            "kernel_size": kernel_size,
            "stride": stride,
            "padding": padding,
            "return_mask": return_mask,
            "ceil_mode": ceil_mode,
            "data_format": data_format
        }

        models = NetAvgpool2d(config)

        return (config, models)

    def test(self):
        """Run the auto-scan harness over sampled configurations."""
        self.run_and_statis(max_examples=30)
173+
174+
175+
if __name__ == "__main__":
176+
unittest.main()

0 commit comments

Comments (0)