Skip to content

Commit cf31a2b

Browse files
committed
fix shape_helper
1 parent ac5b79d commit cf31a2b

File tree

1 file changed

+1
-7
lines changed

1 file changed

+1
-7
lines changed

paddle2onnx/op_mapper/mapper_helper.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,7 @@ def shape_helper(graph, input, dim=None):
3333
shape_node = graph.make_node('Shape', inputs=[input])
3434
return shape_node
3535
full_shape = graph.make_node('Shape', inputs=[input])
36-
start_node = graph.make_node(
37-
'Constant', dtype=dtypes.ONNX.INT64, value=[dim])
38-
ends_node = graph.make_node(
39-
'Constant', dtype=dtypes.ONNX.INT64, value=[dim + 1])
40-
axes_node = graph.make_node('Constant', dtype=dtypes.ONNX.INT64, value=[0])
41-
shape_node = graph.make_node(
42-
"Slice", inputs=[full_shape, start_node, ends_node, axes_node])
36+
shape_node = slice_helper(graph, full_shape, [0], [dim], [dim + 1])
4337
return shape_node
4438

4539

0 commit comments

Comments
 (0)