Skip to content

Commit 78925bf

Browse files
matteosalMatteo Salvarezzaandifejustinchuby
authored
Fix default attribute values in shape inference (onnx#7602)
This is an attempt to fix onnx#7573 When using the default value for an attribute, there are two separate scenarios: * The attribute is not present * The attribute is present but the numerical value for the attribute is not set The previous logic was using the schema default in both cases, but in the second case one should use the protobuf default for the relevant data type instead. This change fixes the example reported in the issue. --------- Signed-off-by: Matteo Salvarezza <matteos@wolfram.com> Co-authored-by: Matteo Salvarezza <matteos@wolfram.com> Co-authored-by: Andreas Fehlner <fehlner@arcor.de> Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent ee910d0 commit 78925bf

File tree

2 files changed

+68
-3
lines changed

2 files changed

+68
-3
lines changed

onnx/defs/shape_inference.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,22 +161,31 @@ inline int64_t getAttribute(const InferenceContext& ctx, const std::string& attr
161161
const auto* attr_proto = ctx.getAttribute(attributeName);
162162
if ((nullptr != attr_proto) && attr_proto->has_i())
163163
return attr_proto->i();
164-
return defaultValue;
164+
else if (nullptr != attr_proto)
165+
return 0; // protobuf default for integers
166+
else
167+
return defaultValue;
165168
}
166169

167170
inline int64_t getAttribute(const DataPropagationContext& ctx, const std::string& attributeName, int64_t defaultValue) {
168171
const auto* attr_proto = ctx.getAttribute(attributeName);
169172
if ((nullptr != attr_proto) && attr_proto->has_i())
170173
return attr_proto->i();
171-
return defaultValue;
174+
else if (nullptr != attr_proto)
175+
return 0; // protobuf default for integers
176+
else
177+
return defaultValue;
172178
}
173179

174180
inline std::string
175181
getAttribute(const InferenceContext& ctx, const std::string& attributeName, const std::string& defaultValue) {
176182
const auto* attr_proto = ctx.getAttribute(attributeName);
177183
if ((nullptr != attr_proto) && attr_proto->has_s())
178184
return attr_proto->s();
179-
return defaultValue;
185+
else if (nullptr != attr_proto)
186+
return ""; // protobuf default for strings
187+
else
188+
return defaultValue;
180189
}
181190

182191
inline TensorShapeProto::Dimension operator*(

onnx/test/shape_inference_test.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import numpy as np
1313
import pytest
14+
from google.protobuf import text_format
1415
from parameterized import parameterized
1516

1617
import onnx.shape_inference
@@ -10855,6 +10856,61 @@ def test_issue_constantofshape_6135(self, _, version):
1085510856
opset_imports=[helper.make_opsetid(ONNX_DOMAIN, version)],
1085610857
)
1085710858

10859+
def test_protobuf_default(self) -> None:
10860+
model_text = """
10861+
ir_version: 8
10862+
producer_name: "test"
10863+
graph {
10864+
node {
10865+
input: "in"
10866+
output: "out"
10867+
op_type: "Flatten"
10868+
attribute {
10869+
name: "axis"
10870+
type: INT
10871+
}
10872+
}
10873+
name: "g"
10874+
input {
10875+
name: "in"
10876+
type {
10877+
tensor_type {
10878+
elem_type: 1
10879+
shape {
10880+
dim {
10881+
dim_value: 2
10882+
}
10883+
dim {
10884+
dim_value: 3
10885+
}
10886+
}
10887+
}
10888+
}
10889+
}
10890+
output {
10891+
name: "out"
10892+
type {
10893+
tensor_type {
10894+
elem_type: 1
10895+
shape {
10896+
dim {
10897+
dim_value: 1
10898+
}
10899+
dim {
10900+
dim_value: 6
10901+
}
10902+
}
10903+
}
10904+
}
10905+
}
10906+
}
10907+
opset_import {
10908+
version: 18
10909+
}
10910+
"""
10911+
model = text_format.Parse(model_text, onnx.ModelProto())
10912+
self._assert_inferred(model, [])
10913+
1085810914

1085910915
class TestCustomSchemaShapeInference(TestShapeInferenceHelper):
1086010916
custom_op_type: str = "CustomOp"

0 commit comments

Comments
 (0)