Skip to content

Commit 4e6eace

Browse files
committed
Merge branch 'main' of https://github.com/onnx/tensorflow-onnx into ci
2 parents e9b45bc + 3dd7729 commit 4e6eace

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

tools/onnx-optimize.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import logging
1414

1515
import onnx
16-
from onnx import helper
16+
from onnx import helper, shape_inference
1717

1818
from tf2onnx.graph import GraphUtil
1919
from tf2onnx import logging, optimizer, constants
@@ -46,6 +46,12 @@ def load_graph(fname, target):
4646
return g, model_proto
4747

4848

49+
def model_shape_inference(onnx_model_proto):
50+
inferred_model = shape_inference.infer_shapes(onnx_model_proto)
51+
onnx.checker.check_model(inferred_model)
52+
return inferred_model
53+
54+
4955
def main():
5056
args = get_args()
5157

@@ -64,10 +70,12 @@ def main():
6470

6571
model_proto = helper.make_model(onnx_graph, **kwargs)
6672

73+
model_proto_inferred = model_shape_inference(model_proto)
74+
6775
# write onnx graph
6876
if args.output:
6977
with open(args.output, "wb") as f:
70-
f.write(model_proto.SerializeToString())
78+
f.write(model_proto_inferred.SerializeToString())
7179

7280

7381
if __name__ == "__main__":

0 commit comments

Comments
 (0)