Skip to content

Commit 4161071

Browse files
authored
Merge pull request #651 from robertknight/timm-export-dynamo
Add option to use dynamo ONNX exporter in export-timm-model.py
2 parents 0101455 + 4ec3975 commit 4161071

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

Diff for: tools/export-timm-model.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def print_predictions(scores: torch.Tensor):
5151
print(f" {label_desc} ({prob:.2f})")
5252

5353

54-
def export_timm_model(config: str, onnx_path: str):
54+
def export_timm_model(config: str, onnx_path: str, dynamo: bool = False):
5555
"""
5656
Export a PyTorch model from timm to ONNX.
5757
@@ -89,16 +89,13 @@ def export_timm_model(config: str, onnx_path: str):
8989
print_predictions(output)
9090

9191
print(f"Exporting model to {onnx_path}")
92-
torch.onnx.export(model, input_img, onnx_path)
92+
torch.onnx.export(model, input_img, onnx_path, dynamo=dynamo)
9393

9494
# Test exported model with ONNX Runtime as a reference implementation.
9595
#
9696
# We test both with graph optimizations disabled and enabled, to show the
9797
# impact of running the ONNX model "as is" vs. with the various fusions that
9898
# ONNX Runtime does.
99-
#
100-
# RTen currently doesn't do any fusions, so the unoptimized performance
101-
# is a "fairer" comparison.
10299
print(f"Testing model with ONNX Runtime (unoptimized)...")
103100
sess_options = ort.SessionOptions()
104101
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
@@ -145,14 +142,20 @@ def main():
145142
help="Name of the model configuration or Hugging Face model URL or path",
146143
)
147144
parser.add_argument("onnx_path", nargs="?", help="Path to ONNX file")
145+
parser.add_argument(
146+
"-d",
147+
"--dynamo",
148+
action="store_true",
149+
help="Use PyTorch's newer TorchDynamo-based ONNX exporter",
150+
)
148151
args = parser.parse_args()
149152

150153
config_name = extract_config_name(args.model_config)
151154
onnx_path = args.onnx_path
152155
if onnx_path is None:
153156
onnx_path = config_name + ".onnx"
154157

155-
export_timm_model(config_name, onnx_path)
158+
export_timm_model(config_name, onnx_path, dynamo=args.dynamo)
156159

157160

158161
if __name__ == "__main__":

0 commit comments

Comments
 (0)