@@ -34,7 +34,7 @@ Here's an example usage
34
34
model = MyModel().eval().cuda()
35
35
inputs = [torch.randn((1 , 3 , 224 , 224 )).cuda()]
36
36
# trt_ep is a torch.fx.GraphModule object
37
- trt_gm = torch_tensorrt.compile(model, ir = " dynamo" , inputs)
37
+ trt_gm = torch_tensorrt.compile(model, ir = " dynamo" , inputs = inputs)
38
38
torch_tensorrt.save(trt_gm, " trt.ep" , inputs = inputs)
39
39
40
40
# Later, you can load it and run inference
@@ -52,7 +52,7 @@ b) Torchscript
52
52
model = MyModel().eval().cuda()
53
53
inputs = [torch.randn((1 , 3 , 224 , 224 )).cuda()]
54
54
# trt_gm is a torch.fx.GraphModule object
55
- trt_gm = torch_tensorrt.compile(model, ir = " dynamo" , inputs)
55
+ trt_gm = torch_tensorrt.compile(model, ir = " dynamo" , inputs = inputs )
56
56
torch_tensorrt.save(trt_gm, " trt.ts" , output_format = " torchscript" , inputs = inputs)
57
57
58
58
# Later, you can load it and run inference
@@ -73,7 +73,7 @@ For `ir=ts`, this behavior stays the same in 2.X versions as well.
73
73
74
74
model = MyModel().eval().cuda()
75
75
inputs = [torch.randn((1 , 3 , 224 , 224 )).cuda()]
76
- trt_ts = torch_tensorrt.compile(model, ir = " ts" , inputs) # Output is a ScriptModule object
76
+ trt_ts = torch_tensorrt.compile(model, ir = " ts" , inputs = inputs ) # Output is a ScriptModule object
77
77
torch.jit.save(trt_ts, " trt_model.ts" )
78
78
79
79
# Later, you can load it and run inference
0 commit comments