1717import numpy as np
1818from paddle2onnx .constant import dtypes
1919from paddle2onnx .op_mapper import OpMapper as op_mapper
20+ from paddle2onnx .op_mapper import mapper_helper
2021
2122
2223@op_mapper ('matmul' )
@@ -27,8 +28,24 @@ class MatMul():
2728 def opset_1 (cls , graph , node , ** kw ):
2829 x = node .input ('X' , idx = 0 )
2930 y = node .input ('Y' , idx = 0 )
30- graph .make_node ('MatMul' , inputs = [x , y ], outputs = node .output ('Out' ))
31-
31+ if node .attr ('transpose_X' ):
32+ perm = list (range (len (node .input_shape ('X' , 0 ))))
33+ perm [- 1 ], perm [- 2 ] = perm [- 2 ], perm [- 1 ]
34+ x = graph .make_node ('Transpose' , inputs = [x ], perm = perm )
35+ if node .attr ('transpose_Y' ):
36+ perm = list (range (len (node .input_shape ('Y' , 0 ))))
37+ perm [- 1 ], perm [- 2 ] = perm [- 2 ], perm [- 1 ]
38+ y = graph .make_node ('Transpose' , inputs = [y ], perm = perm )
39+ if node .attr ('alpha' ) == 1.0 :
40+ graph .make_node ('MatMul' , inputs = [x , y ], outputs = node .output ('Out' ))
41+ else :
42+ matmul = graph .make_node ('MatMul' , inputs = [x , y ])
43+ scale = graph .make_node (
44+ 'Constant' ,
45+ dtype = dtypes .ONNX .FLOAT ,
46+ value = node .attr ('alpha' ))
47+ onnx_node = graph .make_node (
48+ 'Mul' , inputs = [matmul , scale ], outputs = node .output ('Out' ))
3249
3350@op_mapper ('matmul_v2' )
3451class MatMul ():
@@ -40,9 +57,13 @@ def opset_1(cls, graph, node, **kw):
4057 y = node .input ('Y' , idx = 0 )
4158 out = node .output ('Out' )
4259 if node .attr ('trans_x' ):
43- x = graph .make_node ('Transpose' , inputs = [x ])
60+ perm = list (range (len (node .input_shape ('X' , 0 ))))
61+ perm [- 1 ], perm [- 2 ] = perm [- 2 ], perm [- 1 ]
62+ x = graph .make_node ('Transpose' , inputs = [x ], perm = perm )
4463 if node .attr ('trans_y' ):
45- y = graph .make_node ('Transpose' , inputs = [y ])
64+ perm = list (range (len (node .input_shape ('Y' , 0 ))))
65+ perm [- 1 ], perm [- 2 ] = perm [- 2 ], perm [- 1 ]
66+ y = graph .make_node ('Transpose' , inputs = [y ], perm = perm )
4667 graph .make_node ('MatMul' , inputs = [x , y ], outputs = out )
4768
4869
@@ -131,6 +152,30 @@ def opset_8(cls, graph, node, **kw):
131152 'Pow' , inputs = [x , factor_broadcast ], outputs = node .output ('Out' ))
132153
133154
@op_mapper('square')
class Square():
    """Map Paddle `square` (element-wise x*x) to an ONNX Mul node."""
    # Fixed typo: was `support_opset_verison_range`, which the mapper framework
    # would never read; other mappers in this file spell it `version`.
    support_opset_version_range = (7, 12)

    @classmethod
    def opset_7(cls, graph, node, **kw):
        x = node.input('X', 0)
        # x * x avoids emitting a Pow node with a constant exponent.
        graph.make_node('Mul', inputs=[x, x], outputs=node.output('Out'))
164+
@op_mapper('cumsum')
class CumSum():
    """Map Paddle `cumsum` to ONNX CumSum (available from opset 11)."""
    support_opset_version_range = (11, 12)

    @classmethod
    def opset_11(cls, graph, node, **kw):
        # ONNX CumSum takes the axis as a tensor input (not an attribute),
        # so materialize the Paddle `axis` attr as an INT64 Constant.
        axis_node = graph.make_node(
            'Constant', dtype=dtypes.ONNX.INT64, value=node.attr('axis'))
        inputs = [node.input('X', 0), axis_node]
        graph.make_node('CumSum', inputs=inputs, outputs=node.output('Out'))
178+
134179@op_mapper ('mul' )
135180class Mul ():
136181 support_opset_version_range = (1 , 12 )
@@ -140,16 +185,24 @@ def opset_1(cls, graph, node, **kw):
140185 x = node .input ('X' , 0 )
141186 y = node .input ('Y' , 0 )
142187 out = node .output ('Out' , 0 )
143- x_shape = node .input_shape ('X' , 0 )
144- y_shape = node .input_shape ('Y' , 0 )
145188 x_num_col_dims = node .attr ('x_num_col_dims' )
146189 y_num_col_dims = node .attr ('y_num_col_dims' )
147190 flatten_x = graph .make_node (
148191 'Flatten' , inputs = node .input ('X' ), attrs = {'axis' : x_num_col_dims })
149192 flatten_y = graph .make_node (
150193 'Flatten' , inputs = node .input ('Y' ), attrs = {'axis' : y_num_col_dims })
151- mul_node = graph .make_node (
152- 'MatMul' , inputs = [flatten_x , flatten_y ], outputs = node .output ('Out' ))
194+ mul_node = graph .make_node ('MatMul' , inputs = [flatten_x , flatten_y ])
195+
196+ x_shape = graph .make_node ('Shape' , inputs = [x ])
197+ l_shape = mapper_helper .slice_helper (
198+ graph , x_shape , axes = [0 ], starts = [0 ], ends = [x_num_col_dims ])
199+ y_shape = graph .make_node ('Shape' , inputs = [y ])
200+ y_rank = len (node .input_shape ('Y' , 0 ))
201+ r_shape = mapper_helper .slice_helper (
202+ graph , y_shape , axes = [0 ], starts = [y_num_col_dims ], ends = [y_rank ])
203+
204+ out_shape = graph .make_node ('Concat' , inputs = [l_shape , r_shape ], axis = 0 )
205+ graph .make_node ('Reshape' , [mul_node , out_shape ], node .output ('Out' ))
153206
154207
155208@op_mapper ('affine_channel' )
@@ -244,6 +297,19 @@ def opset_1(cls, graph, node, **kw):
244297 axes = [0 ])
245298
246299
@op_mapper('mean')
class Mean():
    """Map Paddle `mean` (mean over all elements) to ONNX ReduceMean."""
    # Fixed typo: was `support_opset_verison_range`; other mappers in this
    # file use `support_opset_version_range`, which the framework reads.
    support_opset_version_range = (1, 12)

    @classmethod
    def opset_1(cls, graph, node, **kw):
        # With no `axes` attribute ReduceMean reduces over every dimension,
        # and keepdims=0 drops them all, matching Paddle's scalar mean.
        graph.make_node(
            'ReduceMean',
            inputs=node.input('X'),
            outputs=node.output('Out'),
            keepdims=0)
312+
247313@op_mapper ('arg_max' )
248314class ArgMax ():
249315 support_opset_version_range = (1 , 12 )
@@ -282,26 +348,28 @@ def opset_7(cls, graph, node, **kw):
282348 'Identity' , inputs = node .input ('X' ), outputs = node .output ('Out' ))
283349 else :
284350 scale_node = graph .make_node (
285- 'Constant' , attrs = {'dtype' : dtypes .ONNX .FLOAT ,
286- 'value' : scale })
351+ 'Constant' ,
352+ attrs = {'dtype' : dtypes .ONNX .FLOAT ,
353+ 'value' : [scale ]})
287354 bias_node = graph .make_node (
288- 'Constant' , attrs = {'dtype' : dtypes .ONNX .FLOAT ,
289- 'value' : bias })
355+ 'Constant' ,
356+ attrs = {'dtype' : dtypes .ONNX .FLOAT ,
357+ 'value' : [bias ]})
290358 cast_node = graph .make_node (
291359 'Cast' , inputs = node .input ('X' ),
292360 attrs = {'to' : dtypes .ONNX .FLOAT })
293361 if node .attr ('bias_after_scale' ):
294- node1 = graph .make_node ('Mul' , inputs = [scale_node , cast_node ])
362+ node1 = graph .make_node ('Mul' , inputs = [cast_node , scale_node ])
295363 node2 = graph .make_node (
296364 'Add' ,
297- inputs = [bias_node , node1 ],
365+ inputs = [node1 , bias_node ],
298366 outputs = node .output ('Out' ))
299367 else :
300- node1 = graph .make_node ('Add' , inputs = [bias_node , cast_node ])
368+ node1 = graph .make_node ('Add' , inputs = [cast_node , bias_node ])
301369 node2 = graph .make_node (
302370 'Mul' ,
303- inputs = [scale_node , node1 ],
304- outputs = [node .output ('Out' )])
371+ inputs = [node1 , scale_node ],
372+ outputs = [node .output ('Out' , 0 )])
305373
306374
307375@op_mapper ('softmax' )
@@ -333,3 +401,50 @@ def opset_1(cls, graph, node, **kw):
333401 inputs = [softmax_node ],
334402 outputs = node .output ('Out' ),
335403 attrs = {'perm' : perm })
404+
405+
@op_mapper('softmax_with_cross_entropy')
class SoftmaxCrossEntropyLoss():
    """Map Paddle `softmax_with_cross_entropy` to ONNX SoftmaxCrossEntropyLoss
    (opset 12). Hard labels only; soft labels are rejected."""
    # Fixed typo: was `support_opset_verison_range`.
    support_opset_version_range = (12, 12)

    @classmethod
    def opset_12(cls, graph, node, **kw):
        if node.attr('soft_label'):
            raise Exception(
                "SoftmaxCrossEntropyLoss in onnx not support soft label.")

        labels = node.input('Label', 0)
        scores = node.input('Logits', 0)

        # Loss is always produced; the Softmax output is optional in Paddle.
        outputs = [node.output('Loss', 0)]
        if 'Softmax' in node.outputs:
            outputs.append(node.output('Softmax', 0))

        shape = node.input_shape('Logits', 0)
        axis = node.attr('axis')
        if axis < 0:
            axis += len(shape)
        if axis == len(shape) - 1:
            graph.make_node(
                'SoftmaxCrossEntropyLoss',
                inputs=[scores, labels],
                outputs=outputs,
                ignore_index=node.attr('ignore_index'),
                reduction='mean')
        else:
            # Move the class axis to the last position before the loss.
            # BUG FIX: the previous code transposed `node.input('X')` (a key
            # this op does not have), fed the *untransposed* scores into the
            # loss node, rebound `node` to a make_node return value, and then
            # referenced an undefined `softmax_node` — a guaranteed NameError.
            perm = list(range(len(shape)))
            perm[-1], perm[axis] = axis, len(shape) - 1
            transposed_scores = graph.make_node(
                'Transpose', inputs=[scores], attrs={'perm': perm})
            # NOTE(review): ONNX SoftmaxCrossEntropyLoss expects the class dim
            # at position 1 for rank>2 inputs, and labels may need the same
            # permutation; the Softmax output (if requested) is produced in the
            # transposed layout. Confirm these against Paddle's label/output
            # shapes before relying on the non-last-axis path.
            graph.make_node(
                'SoftmaxCrossEntropyLoss',
                inputs=[transposed_scores, labels],
                outputs=outputs,
                ignore_index=node.attr('ignore_index'),
                reduction='mean')
0 commit comments