File tree 1 file changed +10
-2
lines changed
1 file changed +10
-2
lines changed Original file line number Diff line number Diff line change 13
13
import logging
14
14
15
15
import onnx
16
- from onnx import helper
16
+ from onnx import helper , shape_inference
17
17
18
18
from tf2onnx .graph import GraphUtil
19
19
from tf2onnx import logging , optimizer , constants
@@ -46,6 +46,12 @@ def load_graph(fname, target):
46
46
return g , model_proto
47
47
48
48
49
+ def model_shape_inference (onnx_model_proto ):
50
+ inferred_model = shape_inference .infer_shapes (onnx_model_proto )
51
+ onnx .checker .check_model (inferred_model )
52
+ return inferred_model
53
+
54
+
49
55
def main ():
50
56
args = get_args ()
51
57
@@ -64,10 +70,12 @@ def main():
64
70
65
71
model_proto = helper .make_model (onnx_graph , ** kwargs )
66
72
73
+ model_proto_inferred = model_shape_inference (model_proto )
74
+
67
75
# write onnx graph
68
76
if args .output :
69
77
with open (args .output , "wb" ) as f :
70
- f .write (model_proto .SerializeToString ())
78
+ f .write (model_proto_inferred .SerializeToString ())
71
79
72
80
73
81
if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments