Skip to content

Commit c339df5

Browse files
authored
Add static shape propagation and broadcasting support for ONNX IR operations (tracel-ai#3763)
* Add ONNX tests for mixed concat inputs Introduces new ONNX models and Rust tests for concat operations with mixed inputs (Shape and constant rank-1 tensors). Updates concat output rank estimation logic in onnx-ir to better handle rank-1 tensor contributions for mixed input scenarios. * Propagate static shapes for constants, where, and broadcast ops Adds static shape propagation for Constant and Where nodes, and updates broadcast utilities to propagate static shapes when possible. Introduces new ONNX test cases for concat, expand, and where ops with static shapes, and updates tests to verify correct shape inference. Also clarifies MatMulInteger zero point assertions to allow scalar or per-channel zero points. * Implement NumPy-style broadcasting for static shapes Updated `compute_broadcast_static_shape` to follow NumPy-style broadcasting rules for compatible shapes, replacing the previous conservative approach. Added comprehensive tests to verify correct broadcasting behavior for same shapes, compatible shapes, different ranks, scalar broadcasting, and incompatible shapes.
1 parent 2e4f6d2 commit c339df5

22 files changed

+999
-37
lines changed

crates/burn-import/onnx-tests/build.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ fn main() {
6969
.input("tests/concat/concat.onnx")
7070
.input("tests/concat/concat_shape.onnx")
7171
.input("tests/concat/concat_shape_with_constant.onnx")
72+
.input("tests/concat/concat_mixed_single_element.onnx")
73+
.input("tests/concat/concat_mixed_three_elements.onnx")
74+
.input("tests/concat/concat_multiple_mixed.onnx")
75+
.input("tests/concat/concat_with_constants.onnx")
7276
.input("tests/constant/constant_f32.onnx")
7377
.input("tests/constant/constant_f64.onnx")
7478
.input("tests/constant/constant_i32.onnx")
@@ -112,6 +116,7 @@ fn main() {
112116
.input("tests/expand/expand.onnx")
113117
.input("tests/expand/expand_tensor.onnx")
114118
.input("tests/expand/expand_shape.onnx")
119+
.input("tests/expand/expand_with_where_shape.onnx")
115120
.input("tests/eye_like/eye_like.onnx")
116121
.input("tests/eye_like/eye_like_k1.onnx")
117122
.input("tests/eye_like/eye_like_int.onnx")
@@ -188,6 +193,7 @@ fn main() {
188193
.input("tests/where_op/where_shape_all_shapes.onnx")
189194
.input("tests/where_op/where_shape_scalar_cond.onnx")
190195
.input("tests/where_op/where_shapes_from_inputs.onnx")
196+
.input("tests/where_op/where_static_shape.onnx")
191197
.input("tests/matmul/matmul.onnx")
192198
.input("tests/matmulinteger/matmulinteger.onnx")
193199
.input("tests/matmulinteger/matmulinteger_ranks.onnx")
Binary file not shown.
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
#!/usr/bin/env python3
2+
3+
# Test: concat with Shape and single-element constant tensor
4+
5+
import onnx
6+
import onnx.helper
7+
import numpy as np
8+
9+
10+
def build_model():
11+
# Create a constant node with a single-element rank-1 tensor
12+
const_node = onnx.helper.make_node(
13+
"Constant",
14+
inputs=[],
15+
outputs=["const_single"],
16+
value=onnx.helper.make_tensor(
17+
name="const_value",
18+
data_type=onnx.TensorProto.INT64,
19+
dims=[1],
20+
vals=[100]
21+
),
22+
name="/Constant"
23+
)
24+
25+
# Create Shape node to extract shape from input tensor
26+
shape_node = onnx.helper.make_node(
27+
"Shape",
28+
inputs=["input1"],
29+
outputs=["shape1"],
30+
name="/Shape"
31+
)
32+
33+
# Create a Concat node that concatenates shape and single-element constant
34+
concat_node = onnx.helper.make_node(
35+
"Concat",
36+
inputs=["shape1", "const_single"],
37+
outputs=["concatenated"],
38+
axis=0,
39+
name="/Concat"
40+
)
41+
42+
# Create the graph
43+
graph = onnx.helper.make_graph(
44+
name="main_graph",
45+
nodes=[const_node, shape_node, concat_node],
46+
inputs=[
47+
onnx.helper.make_value_info(
48+
name="input1",
49+
type_proto=onnx.helper.make_tensor_type_proto(
50+
elem_type=onnx.TensorProto.FLOAT, shape=[2, 3]
51+
),
52+
),
53+
],
54+
outputs=[
55+
onnx.helper.make_value_info(
56+
name="concatenated",
57+
type_proto=onnx.helper.make_tensor_type_proto(
58+
elem_type=onnx.TensorProto.INT64, shape=[3] # 2 + 1 = 3
59+
),
60+
)
61+
]
62+
)
63+
64+
# Create the model
65+
model = onnx.helper.make_model(
66+
graph,
67+
ir_version=8,
68+
opset_imports=[onnx.helper.make_operatorsetid("", 16)]
69+
)
70+
71+
return model
72+
73+
74+
def main():
75+
onnx_model = build_model()
76+
file_name = "concat_mixed_single_element.onnx"
77+
onnx.save(onnx_model, file_name)
78+
onnx.checker.check_model(file_name)
79+
80+
print(f"Finished exporting model to {file_name}")
81+
82+
# Test with onnx.reference.ReferenceEvaluator
83+
try:
84+
from onnx.reference import ReferenceEvaluator
85+
86+
# Create test data
87+
test_input = np.ones((2, 3), dtype=np.float32)
88+
89+
# Run inference
90+
sess = ReferenceEvaluator(onnx_model)
91+
result = sess.run(None, {"input1": test_input})
92+
93+
print(f"Test input shape: {test_input.shape}")
94+
print(f"Concatenated output: {result[0]}")
95+
print(f"Expected: [2, 3, 100]")
96+
97+
except ImportError:
98+
print("onnx.reference not available, skipping inference test")
99+
100+
101+
if __name__ == "__main__":
102+
main()
Binary file not shown.
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
#!/usr/bin/env python3
2+
3+
# Test: concat with Shape and three-element constant tensor
4+
5+
import onnx
6+
import onnx.helper
7+
import numpy as np
8+
9+
10+
def build_model():
11+
# Create a constant node with a three-element rank-1 tensor
12+
const_node = onnx.helper.make_node(
13+
"Constant",
14+
inputs=[],
15+
outputs=["const_three"],
16+
value=onnx.helper.make_tensor(
17+
name="const_value",
18+
data_type=onnx.TensorProto.INT64,
19+
dims=[3],
20+
vals=[10, 20, 30]
21+
),
22+
name="/Constant"
23+
)
24+
25+
# Create Shape node to extract shape from input tensor
26+
shape_node = onnx.helper.make_node(
27+
"Shape",
28+
inputs=["input1"],
29+
outputs=["shape1"],
30+
name="/Shape"
31+
)
32+
33+
# Create a Concat node that concatenates shape and three-element constant
34+
concat_node = onnx.helper.make_node(
35+
"Concat",
36+
inputs=["shape1", "const_three"],
37+
outputs=["concatenated"],
38+
axis=0,
39+
name="/Concat"
40+
)
41+
42+
# Create the graph
43+
graph = onnx.helper.make_graph(
44+
name="main_graph",
45+
nodes=[const_node, shape_node, concat_node],
46+
inputs=[
47+
onnx.helper.make_value_info(
48+
name="input1",
49+
type_proto=onnx.helper.make_tensor_type_proto(
50+
elem_type=onnx.TensorProto.FLOAT, shape=[4, 5, 6]
51+
),
52+
),
53+
],
54+
outputs=[
55+
onnx.helper.make_value_info(
56+
name="concatenated",
57+
type_proto=onnx.helper.make_tensor_type_proto(
58+
elem_type=onnx.TensorProto.INT64, shape=[6] # 3 + 3 = 6
59+
),
60+
)
61+
]
62+
)
63+
64+
# Create the model
65+
model = onnx.helper.make_model(
66+
graph,
67+
ir_version=8,
68+
opset_imports=[onnx.helper.make_operatorsetid("", 16)]
69+
)
70+
71+
return model
72+
73+
74+
def main():
75+
onnx_model = build_model()
76+
file_name = "concat_mixed_three_elements.onnx"
77+
onnx.save(onnx_model, file_name)
78+
onnx.checker.check_model(file_name)
79+
80+
print(f"Finished exporting model to {file_name}")
81+
82+
# Test with onnx.reference.ReferenceEvaluator
83+
try:
84+
from onnx.reference import ReferenceEvaluator
85+
86+
# Create test data
87+
test_input = np.ones((4, 5, 6), dtype=np.float32)
88+
89+
# Run inference
90+
sess = ReferenceEvaluator(onnx_model)
91+
result = sess.run(None, {"input1": test_input})
92+
93+
print(f"Test input shape: {test_input.shape}")
94+
print(f"Concatenated output: {result[0]}")
95+
print(f"Expected: [4, 5, 6, 10, 20, 30]")
96+
97+
except ImportError:
98+
print("onnx.reference not available, skipping inference test")
99+
100+
101+
if __name__ == "__main__":
102+
main()
Binary file not shown.
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
#!/usr/bin/env python3
2+
3+
# Test: concat with multiple Shapes and multiple constant tensors
4+
5+
import onnx
6+
import onnx.helper
7+
import numpy as np
8+
9+
10+
def build_model():
11+
# Create first constant with 2 elements
12+
const1_node = onnx.helper.make_node(
13+
"Constant",
14+
inputs=[],
15+
outputs=["const1"],
16+
value=onnx.helper.make_tensor(
17+
name="const1_value",
18+
data_type=onnx.TensorProto.INT64,
19+
dims=[2],
20+
vals=[100, 200]
21+
),
22+
name="/Constant1"
23+
)
24+
25+
# Create second constant with 1 element
26+
const2_node = onnx.helper.make_node(
27+
"Constant",
28+
inputs=[],
29+
outputs=["const2"],
30+
value=onnx.helper.make_tensor(
31+
name="const2_value",
32+
data_type=onnx.TensorProto.INT64,
33+
dims=[1],
34+
vals=[300]
35+
),
36+
name="/Constant2"
37+
)
38+
39+
# Create Shape nodes for two input tensors
40+
shape1_node = onnx.helper.make_node(
41+
"Shape",
42+
inputs=["input1"],
43+
outputs=["shape1"],
44+
name="/Shape1"
45+
)
46+
47+
shape2_node = onnx.helper.make_node(
48+
"Shape",
49+
inputs=["input2"],
50+
outputs=["shape2"],
51+
name="/Shape2"
52+
)
53+
54+
# Create a Concat node with mixed inputs: shape, const, shape, const
55+
concat_node = onnx.helper.make_node(
56+
"Concat",
57+
inputs=["shape1", "const1", "shape2", "const2"],
58+
outputs=["concatenated"],
59+
axis=0,
60+
name="/Concat"
61+
)
62+
63+
# Create the graph
64+
graph = onnx.helper.make_graph(
65+
name="main_graph",
66+
nodes=[const1_node, const2_node, shape1_node, shape2_node, concat_node],
67+
inputs=[
68+
onnx.helper.make_value_info(
69+
name="input1",
70+
type_proto=onnx.helper.make_tensor_type_proto(
71+
elem_type=onnx.TensorProto.FLOAT, shape=[2, 3]
72+
),
73+
),
74+
onnx.helper.make_value_info(
75+
name="input2",
76+
type_proto=onnx.helper.make_tensor_type_proto(
77+
elem_type=onnx.TensorProto.FLOAT, shape=[4, 5, 6]
78+
),
79+
),
80+
],
81+
outputs=[
82+
onnx.helper.make_value_info(
83+
name="concatenated",
84+
type_proto=onnx.helper.make_tensor_type_proto(
85+
elem_type=onnx.TensorProto.INT64, shape=[8] # 2 + 2 + 3 + 1 = 8
86+
),
87+
)
88+
]
89+
)
90+
91+
# Create the model
92+
model = onnx.helper.make_model(
93+
graph,
94+
ir_version=8,
95+
opset_imports=[onnx.helper.make_operatorsetid("", 16)]
96+
)
97+
98+
return model
99+
100+
101+
def main():
102+
onnx_model = build_model()
103+
file_name = "concat_multiple_mixed.onnx"
104+
onnx.save(onnx_model, file_name)
105+
onnx.checker.check_model(file_name)
106+
107+
print(f"Finished exporting model to {file_name}")
108+
109+
# Test with onnx.reference.ReferenceEvaluator
110+
try:
111+
from onnx.reference import ReferenceEvaluator
112+
113+
# Create test data
114+
test_input1 = np.ones((2, 3), dtype=np.float32)
115+
test_input2 = np.ones((4, 5, 6), dtype=np.float32)
116+
117+
# Run inference
118+
sess = ReferenceEvaluator(onnx_model)
119+
result = sess.run(None, {"input1": test_input1, "input2": test_input2})
120+
121+
print(f"Test input1 shape: {test_input1.shape}")
122+
print(f"Test input2 shape: {test_input2.shape}")
123+
print(f"Concatenated output: {result[0]}")
124+
print(f"Expected: [2, 3, 100, 200, 4, 5, 6, 300]")
125+
126+
except ImportError:
127+
print("onnx.reference not available, skipping inference test")
128+
129+
130+
if __name__ == "__main__":
131+
main()
Binary file not shown.

0 commit comments

Comments
 (0)