Skip to content

Commit be8b921

Browse files
committed
Disable fuse_consecutive_squeezes pass for negative axes
Signed-off-by: take-cheeze <takechi101010@gmail.com>
1 parent 5627f51 commit be8b921

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

onnxoptimizer/passes/fuse_consecutive_squeezes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ struct FuseConsecutiveSqueezes final : public PredicateBasedPass {
4444
!GetValueFromAttrOrInput(n, kaxes, 1, axes_2)) {
4545
return false;
4646
}
47+
if (std::any_of(axes_1.begin(), axes_1.end(), [](int64_t v) { return v < 0; }) ||
48+
std::any_of(axes_2.begin(), axes_2.end(), [](int64_t v) { return v < 0; })) {
49+
return false;
50+
}
4751

4852
std::vector<int64_t> &ret = composed_axes;
4953
ret.clear();

onnxoptimizer/test/optimizer_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2872,6 +2872,18 @@ def test_fuse_consecutive_squeezes_multi_uses(self): # type: () -> None
28722872
if init.name == optimized_model.graph.node[2].input[1]:
28732873
assert list(to_array(init)) == [0, 1, 4, 5, 6]
28742874

2875+
def test_fuse_consecutive_squeezes_negative_axes(self): # type: () -> None
2876+
graph = parser.parse_graph("""
2877+
agraph (float[5, 7, 1, 1] X) => (float[5, 7] Z)
2878+
{
2879+
Axes = Constant<value=int64[1]{-1}> ()
2880+
Y = Squeeze (X, Axes)
2881+
Z = Squeeze (Y, Axes)
2882+
}
2883+
""")
2884+
optimized_model = self._optimized(graph, ["fuse_consecutive_squeezes"])
2885+
assert len(optimized_model.graph.node) == 3
2886+
28752887
@pytest.mark.xfail
28762888
def test_fuse_consecutive_softmax_log_axis(self): # type: () -> None
28772889
for axis in range(3):

0 commit comments

Comments
 (0)