Skip to content

Commit 8388709

Browse files
authored
Fix usage example (#3337)
1 parent 25075e2 commit 8388709

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

docsrc/user_guide/saving_models.rst

+3-3
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ Here's an example usage
3434
model = MyModel().eval().cuda()
3535
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
3636
# trt_ep is a torch.fx.GraphModule object
37-
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs)
37+
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
3838
torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs)
3939
4040
# Later, you can load it and run inference
@@ -52,7 +52,7 @@ b) Torchscript
5252
model = MyModel().eval().cuda()
5353
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
5454
# trt_gm is a torch.fx.GraphModule object
55-
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs)
55+
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
5656
torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)
5757
5858
# Later, you can load it and run inference
@@ -73,7 +73,7 @@ For `ir=ts`, this behavior stays the same in 2.X versions as well.
7373
7474
model = MyModel().eval().cuda()
7575
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
76-
trt_ts = torch_tensorrt.compile(model, ir="ts", inputs) # Output is a ScriptModule object
76+
trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs) # Output is a ScriptModule object
7777
torch.jit.save(trt_ts, "trt_model.ts")
7878
7979
# Later, you can load it and run inference

0 commit comments

Comments
 (0)