Skip to content

Commit 2c8421c

Browse files
author
channings
authored
Merge pull request #186 from Channingss/support_nlp
Support ERNIE of PaddleNLP
2 parents e6ebabd + 1d4debb commit 2c8421c

File tree

7 files changed

+173
-23
lines changed

7 files changed

+173
-23
lines changed

docs/en/op_list.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
| conv2d | 1~12 |
1919
| conv2d_transpose | 1~12 |
2020
| collect_fpn_proposals | 11~12 |
21+
| cumsum | 11~12 |
2122
| deformable_conv | 11~12 |
2223
| depthwise_conv2d | 1~12 |
2324
| distribute_fpn_proposals | 11~12 |
@@ -51,9 +52,12 @@
5152
| leaky_relu | 1~12 |
5253
| less_equal| 12~ |
5354
| log | 1~12 |
55+
| lookup_table | 1~12 |
56+
| lookup_table_v2 | 1~12 |
5457
| logical_and | 1~12 |
5558
| matmul | 1~12 |
5659
| matmul_v2 | 1~12 |
60+
| mean | 1~12 |
5761
| mul | 1~12 |
5862
| muticlass_nms | 10~12 |
5963
| muticlass_nms2 | 10~12 |
@@ -80,11 +84,13 @@
8084
| softmax | 1~12 |
8185
| scale | 1~12 | opset 1~6 limited supported |
8286
| sequence_expand | 1~12 |
87+
| softmax_with_cross_entropy | 12 |
8388
| shape | 1~12 |
8489
| sigmoid | 1~12 |
8590
| slice | 1~12 |
8691
| split | 1~12 |
8792
| squeeze2 | 1~12 |
93+
| square | 7~12 |
8894
| sqrt | 1~12 |
8995
| stack | 1~12 |
9096
| stride_slice | 1~12 |

docs/zh/model_zoo.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,17 @@
4747
## 图像检测
4848
待测试
4949

50+
## 自然语言处理
51+
目前支持的模型有ERNIE系列模型,测试模型来自于PaddleNLP [2.0-beta 分支](https://github.com/PaddlePaddle/models/tree/release/2.0-beta/PaddleNLP)
52+
53+
| 模型 | 来源 |
54+
|-------|--------|
55+
|ERNIE-1.0|[PaddleNLP](https://github.com/PaddlePaddle/models/blob/develop/PaddleNLP/docs/models.md#paddlenlpmodels) |
56+
|ERNIE-2.0|[PaddleNLP](https://github.com/PaddlePaddle/models/blob/develop/PaddleNLP/docs/models.md#paddlenlpmodels) |
5057

5158
# 静态图模型
5259
## 图像分类
53-
图像分类模型支持比较完善,测试模型来自于 PaddleCls [master/](https://github.com/PaddlePaddle/PaddleClas/tree/master)
60+
图像分类模型支持比较完善,测试模型来自于 PaddleClas [master 分支](https://github.com/PaddlePaddle/PaddleClas/tree/master)
5461

5562
| 模型 | 来源 |
5663
|-------|--------|

docs/zh/op_list.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
| conv2d | 1~12 |
1919
| conv2d_transpose | 1~12 |
2020
| collect_fpn_proposals | 11~12 |
21+
| cumsum | 11~12 |
2122
| deformable_conv | 11~12 |
2223
| depthwise_conv2d | 1~12 |
2324
| distribute_fpn_proposals | 11~12 |
@@ -51,9 +52,12 @@
5152
| leaky_relu | 1~12 |
5253
| less_equal| 12~ |
5354
| log | 1~12 |
55+
| lookup_table | 1~12 |
56+
| lookup_table_v2 | 1~12 |
5457
| logical_and | 1~12 |
5558
| matmul | 1~12 |
5659
| matmul_v2 | 1~12 |
60+
| mean | 1~12 |
5761
| mul | 1~12 |
5862
| muticlass_nms | 10~12 |
5963
| muticlass_nms2 | 10~12 |
@@ -80,11 +84,13 @@
8084
| softmax | 1~12 |
8185
| scale | 1~12 | opset 1~6 limited supported |
8286
| sequence_expand | 1~12 |
87+
| softmax_with_cross_entropy | 12 |
8388
| shape | 1~12 |
8489
| sigmoid | 1~12 |
8590
| slice | 1~12 |
8691
| split | 1~12 |
8792
| squeeze2 | 1~12 |
93+
| square | 7~12 |
8894
| sqrt | 1~12 |
8995
| stack | 1~12 |
9096
| stride_slice | 1~12 |

paddle2onnx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
from __future__ import absolute_import
1515

16-
__version__ = "0.4"
16+
__version__ = "0.5"
1717

1818
from .convert import dygraph2onnx, program2onnx
1919
from .op_mapper import register_op_mapper

paddle2onnx/op_mapper/math.py

Lines changed: 132 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy as np
1818
from paddle2onnx.constant import dtypes
1919
from paddle2onnx.op_mapper import OpMapper as op_mapper
20+
from paddle2onnx.op_mapper import mapper_helper
2021

2122

2223
@op_mapper('matmul')
@@ -27,8 +28,24 @@ class MatMul():
2728
def opset_1(cls, graph, node, **kw):
2829
x = node.input('X', idx=0)
2930
y = node.input('Y', idx=0)
30-
graph.make_node('MatMul', inputs=[x, y], outputs=node.output('Out'))
31-
31+
if node.attr('transpose_X'):
32+
perm = list(range(len(node.input_shape('X', 0))))
33+
perm[-1], perm[-2] = perm[-2], perm[-1]
34+
x = graph.make_node('Transpose', inputs=[x], perm=perm)
35+
if node.attr('transpose_Y'):
36+
perm = list(range(len(node.input_shape('Y', 0))))
37+
perm[-1], perm[-2] = perm[-2], perm[-1]
38+
y = graph.make_node('Transpose', inputs=[y], perm=perm)
39+
if node.attr('alpha') == 1.0:
40+
graph.make_node('MatMul', inputs=[x, y], outputs=node.output('Out'))
41+
else:
42+
matmul = graph.make_node('MatMul', inputs=[x, y])
43+
scale = graph.make_node(
44+
'Constant',
45+
dtype=dtypes.ONNX.FLOAT,
46+
value=node.attr('alpha'))
47+
onnx_node = graph.make_node(
48+
'Mul', inputs=[matmul, scale], outputs=node.output('Out'))
3249

3350
@op_mapper('matmul_v2')
3451
class MatMul():
@@ -40,9 +57,13 @@ def opset_1(cls, graph, node, **kw):
4057
y = node.input('Y', idx=0)
4158
out = node.output('Out')
4259
if node.attr('trans_x'):
43-
x = graph.make_node('Transpose', inputs=[x])
60+
perm = list(range(len(node.input_shape('X', 0))))
61+
perm[-1], perm[-2] = perm[-2], perm[-1]
62+
x = graph.make_node('Transpose', inputs=[x], perm=perm)
4463
if node.attr('trans_y'):
45-
y = graph.make_node('Transpose', inputs=[y])
64+
perm = list(range(len(node.input_shape('Y', 0))))
65+
perm[-1], perm[-2] = perm[-2], perm[-1]
66+
y = graph.make_node('Transpose', inputs=[y], perm=perm)
4667
graph.make_node('MatMul', inputs=[x, y], outputs=out)
4768

4869

@@ -131,6 +152,30 @@ def opset_8(cls, graph, node, **kw):
131152
'Pow', inputs=[x, factor_broadcast], outputs=node.output('Out'))
132153

133154

155+
@op_mapper('square')
156+
class Square():
157+
support_opset_verision_range = (7, 12)
158+
159+
@classmethod
160+
def opset_7(cls, graph, node, **kw):
161+
x = node.input('X', 0)
162+
onnx_node = graph.make_node(
163+
'Mul', inputs=[x, x], outputs=node.output('Out'))
164+
165+
@op_mapper('cumsum')
166+
class CumSum():
167+
support_opset_version_range = (11, 12)
168+
169+
@classmethod
170+
def opset_11(cls, graph, node, **kw):
171+
172+
axis = graph.make_node('Constant', dtype=dtypes.ONNX.INT64, value=node.attr('axis'))
173+
graph.make_node(
174+
'CumSum',
175+
inputs=[node.input('X', 0), axis],
176+
outputs=node.output('Out'))
177+
178+
134179
@op_mapper('mul')
135180
class Mul():
136181
support_opset_version_range = (1, 12)
@@ -140,16 +185,24 @@ def opset_1(cls, graph, node, **kw):
140185
x = node.input('X', 0)
141186
y = node.input('Y', 0)
142187
out = node.output('Out', 0)
143-
x_shape = node.input_shape('X', 0)
144-
y_shape = node.input_shape('Y', 0)
145188
x_num_col_dims = node.attr('x_num_col_dims')
146189
y_num_col_dims = node.attr('y_num_col_dims')
147190
flatten_x = graph.make_node(
148191
'Flatten', inputs=node.input('X'), attrs={'axis': x_num_col_dims})
149192
flatten_y = graph.make_node(
150193
'Flatten', inputs=node.input('Y'), attrs={'axis': y_num_col_dims})
151-
mul_node = graph.make_node(
152-
'MatMul', inputs=[flatten_x, flatten_y], outputs=node.output('Out'))
194+
mul_node = graph.make_node('MatMul', inputs=[flatten_x, flatten_y])
195+
196+
x_shape = graph.make_node('Shape', inputs=[x])
197+
l_shape = mapper_helper.slice_helper(
198+
graph, x_shape, axes=[0], starts=[0], ends=[x_num_col_dims])
199+
y_shape = graph.make_node('Shape', inputs=[y])
200+
y_rank = len(node.input_shape('Y', 0))
201+
r_shape = mapper_helper.slice_helper(
202+
graph, y_shape, axes=[0], starts=[y_num_col_dims], ends=[y_rank])
203+
204+
out_shape = graph.make_node('Concat', inputs=[l_shape, r_shape], axis=0)
205+
graph.make_node('Reshape', [mul_node, out_shape], node.output('Out'))
153206

154207

155208
@op_mapper('affine_channel')
@@ -244,6 +297,19 @@ def opset_1(cls, graph, node, **kw):
244297
axes=[0])
245298

246299

300+
@op_mapper('mean')
301+
class Mean():
302+
support_opset_verison_range = (1, 12)
303+
304+
@classmethod
305+
def opset_1(cls, graph, node, **kw):
306+
graph.make_node(
307+
'ReduceMean',
308+
inputs=node.input('X'),
309+
outputs=node.output('Out'),
310+
keepdims=0)
311+
312+
247313
@op_mapper('arg_max')
248314
class ArgMax():
249315
support_opset_version_range = (1, 12)
@@ -282,26 +348,28 @@ def opset_7(cls, graph, node, **kw):
282348
'Identity', inputs=node.input('X'), outputs=node.output('Out'))
283349
else:
284350
scale_node = graph.make_node(
285-
'Constant', attrs={'dtype': dtypes.ONNX.FLOAT,
286-
'value': scale})
351+
'Constant',
352+
attrs={'dtype': dtypes.ONNX.FLOAT,
353+
'value': [scale]})
287354
bias_node = graph.make_node(
288-
'Constant', attrs={'dtype': dtypes.ONNX.FLOAT,
289-
'value': bias})
355+
'Constant',
356+
attrs={'dtype': dtypes.ONNX.FLOAT,
357+
'value': [bias]})
290358
cast_node = graph.make_node(
291359
'Cast', inputs=node.input('X'),
292360
attrs={'to': dtypes.ONNX.FLOAT})
293361
if node.attr('bias_after_scale'):
294-
node1 = graph.make_node('Mul', inputs=[scale_node, cast_node])
362+
node1 = graph.make_node('Mul', inputs=[cast_node, scale_node])
295363
node2 = graph.make_node(
296364
'Add',
297-
inputs=[bias_node, node1],
365+
inputs=[node1, bias_node],
298366
outputs=node.output('Out'))
299367
else:
300-
node1 = graph.make_node('Add', inputs=[bias_node, cast_node])
368+
node1 = graph.make_node('Add', inputs=[cast_node, bias_node])
301369
node2 = graph.make_node(
302370
'Mul',
303-
inputs=[scale_node, node1],
304-
outputs=[node.output('Out')])
371+
inputs=[node1, scale_node],
372+
outputs=[node.output('Out', 0)])
305373

306374

307375
@op_mapper('softmax')
@@ -333,3 +401,50 @@ def opset_1(cls, graph, node, **kw):
333401
inputs=[softmax_node],
334402
outputs=node.output('Out'),
335403
attrs={'perm': perm})
404+
405+
406+
@op_mapper('softmax_with_cross_entropy')
407+
class SoftmaxCrossEntropyLoss():
408+
support_opset_verison_range = (12, 12)
409+
410+
@classmethod
411+
def opset_12(cls, graph, node, **kw):
412+
if node.attr('soft_label'):
413+
raise Exception(
414+
"SoftmaxCrossEntropyLoss in onnx not support soft label.")
415+
416+
labels = node.input('Label', 0)
417+
scores = node.input('Logits', 0)
418+
419+
outputs = [node.output('Loss', 0)]
420+
if 'Softmax' in node.outputs:
421+
outputs.append(node.output('Softmax', 0))
422+
423+
shape = node.input_shape('Logits', 0)
424+
axis = node.attr('axis')
425+
if axis < 0:
426+
axis += len(shape)
427+
if axis == len(shape) - 1:
428+
graph.make_node(
429+
'SoftmaxCrossEntropyLoss',
430+
inputs=[scores, labels],
431+
outputs=outputs,
432+
ignore_index=node.attr('ignore_index'),
433+
reduction='mean')
434+
else:
435+
perm = [i for i in range(len(shape))]
436+
perm[-1] = axis
437+
perm[axis] = len(shape) - 1
438+
transpose_node = graph.make_node(
439+
'Transpose', inputs=node.input('X'), attrs={'perm': perm})
440+
node = graph.make_node(
441+
'SoftmaxCrossEntropyLoss',
442+
inputs=[scores, labels],
443+
outputs=outputs,
444+
ignore_index=node.attr('ignore_index'),
445+
reduction='mean')
446+
transpose_node1 = graph.make_node(
447+
'Transpose',
448+
inputs=[softmax_node],
449+
outputs=node.output('Out'),
450+
attrs={'perm': perm})

paddle2onnx/op_mapper/tensor.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,21 @@ def opset_1(cls, graph, node, **kw):
365365
})
366366

367367

368+
@op_mapper(['lookup_table_v2', 'lookup_table'])
369+
class Embedding():
370+
support_opset_verison_range = (1, 12)
371+
372+
@classmethod
373+
def opset_1(cls, graph, node, **kw):
374+
ids = node.input('Ids', 0)
375+
if node.type == 'lookup_table' and node.input_shape('Ids', 0)[-1] == 1:
376+
ids = graph.make_node(
377+
'Squeeze', inputs=node.input('Ids', 0), axes=[-1])
378+
graph.make_node(
379+
'Gather',
380+
inputs=[node.input('W', 0), ids],
381+
outputs=node.output('Out'))
382+
368383
@op_mapper('fill_constant_batch_size_like')
369384
class FillConstantBatchSizeLike():
370385
support_opset_verison_range = (9, 12)
@@ -414,14 +429,15 @@ def opset_9(cls, graph, node, **kw):
414429
input_dtype = node.input_var('X', 0).dtype
415430
if dtype is None:
416431
dtype = input_dtype
417-
dtype = dtypes.DTYPE_PADDLE_ONNX_MAP[dtype]
432+
np_dtype = dtypes.DTYPE_PADDLE_STR_MAP[dtype]
433+
onnx_dtype = dtypes.DTYPE_PADDLE_ONNX_MAP[dtype]
418434
graph.make_node(
419435
'ConstantOfShape',
420436
inputs=[shape_node],
421437
outputs=node.output('Out'),
422438
dims=[1],
423-
dtype=dtype,
424-
value=value)
439+
dtype=onnx_dtype,
440+
value=np.array(value).astype(np_dtype))
425441

426442

427443
@op_mapper('gather')

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
setuptools.setup(
2424
name="paddle2onnx",
25-
version=0.4,
25+
version=0.5,
2626
author="dltp-sz",
2727
author_email="dltp-sz@baidu.com",
2828
description="a toolkit for converting trained model of PaddlePaddle to ONNX.",

0 commit comments

Comments
 (0)