@@ -137,15 +137,38 @@ def opset_1(cls, graph, node, **kw):
137137 assert node .attrs ['data_format' ] == 'NCHW' or node .attrs ['data_format' ] == "AnyLayout" , \
138138 "The conv data format should be 'NCHW', but received data format " \
139139 "is %s." % node .attrs ['data_format' ]
140+ x_dtype = node .input_dtype ('X' , 0 )
141+ need_dtype_convert = False
142+ input_name = node .input ('X' , 0 )
143+ if x_dtype != paddle .float32 :
144+ need_dtype_convert = True
145+ input_name = graph .make_node (
146+ 'Cast' , inputs = node .input ('X' ), to = dtypes .ONNX .FLOAT )
147+
140148 if node .attr ('global_pooling' ) or (node .attr ('adaptive' ) and
141149 node .attr ('ksize' ) == [1 , 1 ]):
142- onnx_node = graph .make_node (
143- cls .pool_type [node .attr ('pooling_type' )][1 ],
144- inputs = node .input ('X' ),
145- outputs = node .output ('Out' ))
150+ if need_dtype_convert :
151+ onnx_node = graph .make_node (
152+ cls .pool_type [node .attr ('pooling_type' )][1 ],
153+ inputs = [input_name ])
154+ graph .make_node (
155+ 'Cast' ,
156+ inputs = [onnx_node ],
157+ outputs = node .output ('Out' ),
158+ to = dtypes .ONNX .DOUBLE )
159+ else :
160+ onnx_node = graph .make_node (
161+ cls .pool_type [node .attr ('pooling_type' )][1 ],
162+ inputs = [input_name ],
163+ outputs = node .output ('Out' ))
146164 elif node .attr ('adaptive' ):
147165 # if pool is adaptive, check if input shape of pool is fixed.
148- mapper_helper .is_static_shape (node .input_shape ('X' , 0 ))
166+ if node .input_shape ('X' , 0 )[2 :].count (- 1 ) > 0 :
167+ raise Exception (
168+ "Converting this model to ONNX need with static input shape," \
169+ " please fix input shape of this model, see doc Q2 in" \
170+ " https://github.com/PaddlePaddle/paddle2onnx/blob/develop/docs/en/FAQ.md."
171+ )
149172 input_h , input_w = node .input_shape ('X' , 0 )[2 :]
150173 output_h , output_w = node .output_shape ('Out' , 0 )[2 :]
151174 stride_h = int (input_h / output_h )
@@ -179,11 +202,22 @@ def opset_1(cls, graph, node, **kw):
179202 attrs ['auto_pad' ] = 'VALID'
180203 if node .attr ('pooling_type' ) == 'avg' :
181204 attrs ['count_include_pad' ] = not node .attr ('exclusive' )
182- onnx_node = graph .make_node (
183- cls .pool_type [node .attr ('pooling_type' )][0 ],
184- inputs = node .input ('X' ),
185- outputs = node .output ('Out' ),
186- attrs = attrs )
205+ if need_dtype_convert :
206+ onnx_node = graph .make_node (
207+ cls .pool_type [node .attr ('pooling_type' )][0 ],
208+ inputs = [input_name ],
209+ attrs = attrs )
210+ graph .make_node (
211+ 'Cast' ,
212+ inputs = [onnx_node ],
213+ outputs = node .output ('Out' ),
214+ to = dtypes .ONNX .DOUBLE )
215+ else :
216+ onnx_node = graph .make_node (
217+ cls .pool_type [node .attr ('pooling_type' )][0 ],
218+ inputs = [input_name ],
219+ outputs = node .output ('Out' ),
220+ attrs = attrs )
187221 else :
188222 input_shape = node .input_shape ('X' , 0 )
189223 k_size = node .attr ('ksize' )
@@ -200,7 +234,7 @@ def opset_1(cls, graph, node, **kw):
200234 if input_shape [3 ] > 0 and input_shape [3 ] + pads [1 ] < k_size [1 ]:
201235 k_size [1 ] = input_shape [3 ] + pads [1 ]
202236
203- input_x = node . input ( 'X' )
237+ input_x = [ input_name ]
204238 if max (k_size ) <= max (pads ):
205239 onnx_paddings = [0 , 0 , pads [0 ], pads [1 ], 0 , 0 , pads [2 ], pads [3 ]]
206240 attrs_pad = {'mode' : 'constant' , }
@@ -243,11 +277,22 @@ def opset_1(cls, graph, node, **kw):
243277
244278 if node .attr ('pooling_type' ) == 'avg' :
245279 attrs ['count_include_pad' ] = not node .attr ('exclusive' )
246- onnx_node = graph .make_node (
247- cls .pool_type [node .attr ('pooling_type' )][0 ],
248- inputs = input_x ,
249- outputs = node .output ('Out' ),
250- attrs = attrs )
280+ if need_dtype_convert :
281+ onnx_node = graph .make_node (
282+ cls .pool_type [node .attr ('pooling_type' )][0 ],
283+ inputs = input_x ,
284+ attrs = attrs )
285+ graph .make_node (
286+ 'Cast' ,
287+ inputs = [onnx_node ],
288+ outputs = node .output ('Out' ),
289+ to = dtypes .ONNX .DOUBLE )
290+ else :
291+ onnx_node = graph .make_node (
292+ cls .pool_type [node .attr ('pooling_type' )][0 ],
293+ inputs = input_x ,
294+ outputs = node .output ('Out' ),
295+ attrs = attrs )
251296
252297
253298@op_mapper ('pool3d' )
@@ -283,7 +328,12 @@ def opset_1(cls, graph, node, **kw):
283328 outputs = node .output ('Out' ))
284329 elif node .attr ('adaptive' ):
285330 # if pool is adaptive, check if input shape of pool is fixed.
286- mapper_helper .is_static_shape (node .input_shape ('X' , 0 ))
331+ if node .input_shape ('X' , 0 )[2 :].count (- 1 ) > 0 :
332+ raise Exception (
333+ "Converting this model to ONNX need with static input shape," \
334+ " please fix input shape of this model, see doc Q2 in" \
335+ " https://github.com/PaddlePaddle/paddle2onnx/blob/develop/docs/en/FAQ.md."
336+ )
287337 input_d , input_h , input_w = node .input_shape ('X' , 0 )[2 :]
288338 output_d , output_h , output_w = node .output_shape ('Out' , 0 )[2 :]
289339 stride_d = int (input_d / output_d )
0 commit comments