Skip to content

Commit 55d7d16

Browse files
committed
fix: store the symbolic functions traced at compile time in the metadata so they can be reused during re-export; this also removes the need to access the real TensorRT engine during re-export
1 parent c341d98 commit 55d7d16

19 files changed

+2173
-144
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,10 @@ std::string TRTEngine::get_engine_layer_info() {
325325
return inspector->getEngineInformation(nvinfer1::LayerInformationFormat::kJSON);
326326
}
327327

328+
std::string TRTEngine::get_serialized_metadata() {
329+
return this->serialized_metadata;
330+
}
331+
328332
std::vector<at::Tensor> TRTEngine::infer_outputs(std::vector<std::vector<int64_t>> input_shapes) {
329333
std::vector<at::Tensor> outputs;
330334
TORCHTRT_CHECK(

core/runtime/TRTEngine.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ struct TRTEngine : torch::CustomClassHolder {
158158
void set_profile_format(std::string profile_format);
159159
void disable_profiling();
160160
std::string get_engine_layer_info();
161+
std::string get_serialized_metadata();
161162

162163
void dump_engine_layer_info_to_file(const std::string& path);
163164
void dump_engine_layer_info();

core/runtime/register_jit_hooks.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
8888
.def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
8989
.def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
9090
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
91+
.def("get_serialized_metadata", &TRTEngine::get_serialized_metadata)
9192
.def("infer_outputs", &TRTEngine::infer_outputs)
9293
.def("reset_captured_graph", &TRTEngine::reset_captured_graph)
9394
.def("set_output_tensors_as_unowned", &TRTEngine::set_output_tensors_as_unowned)

docsrc/user_guide/dynamic_shapes.rst

Lines changed: 162 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ Custom Dynamic Shape Constraints
4949
---------------------------------
5050

5151
Given an input ``x = torch_tensorrt.Input(min_shape, opt_shape, max_shape, dtype)``,
52-
Torch-TensorRT attempts to automatically set the constraints during ``torch.export`` tracing by constructing
52+
Torch-TensorRT attempts to automatically set the constraints during ``torch.export`` tracing by constructing
5353
`torch.export.Dim` objects with the provided dynamic dimensions accordingly. Sometimes, we might need to set additional constraints and Torchdynamo errors out if we don't specify them.
5454
If you have to set any custom constraints to your model (by using `torch.export.Dim`), we recommend exporting your program first before compiling with Torch-TensorRT.
5555
Please refer to this `documentation <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html#constraints-dynamic-shapes>`_ to export the Pytorch module with dynamic shapes.
@@ -78,7 +78,6 @@ Here's a simple example that exports a matmul layer with some restrictions on dy
7878
# Run inference
7979
trt_gm(*inputs)
8080
81-
8281
Dynamic shapes using torch.compile (JIT)
8382
------------------------------------
8483

@@ -102,3 +101,164 @@ to avoid recompilation of TensorRT engines.
102101
# No recompilation of TRT engines with modified batch size
103102
inputs_bs2 = torch.randn((2, 3, 224, 224), dtype=torch.float32)
104103
trt_gm(inputs_bs2)
104+
105+
106+
Saving and Loading Models with Dynamic Shapes
107+
----------------------------------------------
108+
109+
When you compile a model with dynamic shapes and want to save it for later use, you need to preserve the dynamic shape
110+
specifications. Torch-TensorRT provides two methods to accomplish this:
111+
112+
Method 1: Automatic Inference from torch_tensorrt.Input
113+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
114+
115+
The simplest approach is to pass the same ``torch_tensorrt.Input`` objects (with min/opt/max shapes) to both ``compile()`` and ``save()``.
116+
The dynamic shape specifications will be inferred automatically:
117+
118+
.. code-block:: python
119+
120+
import torch
121+
import torch_tensorrt
122+
123+
model = MyModel().eval().cuda()
124+
125+
# Define Input with dynamic shapes once
126+
inputs = [
127+
torch_tensorrt.Input(
128+
min_shape=(1, 3, 224, 224),
129+
opt_shape=(8, 3, 224, 224),
130+
max_shape=(32, 3, 224, 224),
131+
dtype=torch.float32,
132+
name="x" # Optional: provides better dimension naming
133+
)
134+
]
135+
136+
# Compile with dynamic shapes
137+
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
138+
139+
# Save - dynamic shapes inferred automatically!
140+
torch_tensorrt.save(trt_model, "model.ep", arg_inputs=inputs)
141+
142+
# Load and use with different batch sizes
143+
loaded_model = torch_tensorrt.load("model.ep").module()
144+
output1 = loaded_model(torch.randn(4, 3, 224, 224).cuda()) # Works!
145+
output2 = loaded_model(torch.randn(16, 3, 224, 224).cuda()) # Works!
146+
147+
148+
Method 2: Explicit torch.export.Dim Specification
149+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
150+
151+
For advanced use cases or when you need fine-grained control over dimension naming, you can explicitly provide ``dynamic_shapes``
152+
using ``torch.export.Dim``:
153+
154+
.. code-block:: python
155+
156+
import torch
157+
import torch_tensorrt
158+
159+
model = MyModel().eval().cuda()
160+
example_input = torch.randn((2, 3, 224, 224)).cuda()
161+
162+
# Define dynamic dimensions explicitly
163+
dyn_batch = torch.export.Dim("batch", min=1, max=32)
164+
dynamic_shapes = {"x": {0: dyn_batch}}
165+
166+
# Export with dynamic shapes
167+
exp_program = torch.export.export(
168+
model, (example_input,),
169+
dynamic_shapes=dynamic_shapes,
170+
strict=False
171+
)
172+
173+
# Compile
174+
trt_model = torch_tensorrt.dynamo.compile(
175+
exp_program,
176+
inputs=[torch_tensorrt.Input(
177+
min_shape=(1, 3, 224, 224),
178+
opt_shape=(8, 3, 224, 224),
179+
max_shape=(32, 3, 224, 224),
180+
)]
181+
)
182+
183+
# Save with explicit dynamic_shapes
184+
torch_tensorrt.save(
185+
trt_model,
186+
"model.ep",
187+
arg_inputs=[example_input],
188+
dynamic_shapes=dynamic_shapes # Same as used during export
189+
)
190+
191+
# Load and use
192+
loaded_model = torch_tensorrt.load("model.ep").module()
193+
194+
**When to use this method:**
195+
- You need specific dimension names for torch.export compatibility
196+
- You're working with existing torch.export workflows
197+
- You require fine-grained control over dynamic dimension specifications
198+
199+
Multiple Dynamic Dimensions
200+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
201+
202+
Both methods support multiple dynamic dimensions (e.g., dynamic batch, height, and width):
203+
204+
.. code-block:: python
205+
206+
# Method 1 (Automatic): Multiple dynamic dimensions
207+
inputs = [
208+
torch_tensorrt.Input(
209+
min_shape=(1, 3, 64, 64),
210+
opt_shape=(8, 3, 256, 256),
211+
max_shape=(16, 3, 512, 512),
212+
name="image"
213+
)
214+
]
215+
216+
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
217+
torch_tensorrt.save(trt_model, "model.ep", arg_inputs=inputs) # All 3 dims inferred!
218+
219+
# Load and test with various sizes
220+
loaded = torch_tensorrt.load("model.ep").module()
221+
loaded(torch.randn(4, 3, 128, 128).cuda())
222+
loaded(torch.randn(12, 3, 384, 384).cuda())
223+
224+
Saving with Keyword Arguments
225+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
226+
227+
If your model uses keyword arguments with dynamic shapes, both methods support them:
228+
229+
.. code-block:: python
230+
231+
# Define dynamic inputs for both args and kwargs
232+
arg_inputs = [
233+
torch_tensorrt.Input(
234+
min_shape=(1, 10),
235+
opt_shape=(4, 10),
236+
max_shape=(8, 10),
237+
name="x"
238+
)
239+
]
240+
241+
kwarg_inputs = {
242+
"mask": torch_tensorrt.Input(
243+
min_shape=(1, 5),
244+
opt_shape=(4, 5),
245+
max_shape=(8, 5),
246+
name="mask"
247+
)
248+
}
249+
250+
# Compile
251+
trt_model = torch_tensorrt.compile(
252+
model,
253+
ir="dynamo",
254+
arg_inputs=arg_inputs,
255+
kwarg_inputs=kwarg_inputs
256+
)
257+
258+
# Save - both arg and kwarg dynamic shapes inferred automatically
259+
torch_tensorrt.save(
260+
trt_model,
261+
"model.ep",
262+
arg_inputs=arg_inputs,
263+
kwarg_inputs=kwarg_inputs
264+
)

docsrc/user_guide/saving_models.rst

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,98 @@ Here's an example usage
4242
model = torch.export.load("trt.ep").module()
4343
model(*inputs)
4444
45+
46+
Saving Models with Dynamic Shapes
47+
""""""""""""""""""""""""""""""""""
48+
49+
When saving models compiled with dynamic shapes, you have two methods to preserve
50+
the dynamic shape specifications:
51+
52+
**Method 1: Using torch.export.Dim (explicit)**
53+
54+
Provide explicit ``dynamic_shapes`` parameter following torch.export's pattern:
55+
56+
.. code-block:: python
57+
58+
import torch
59+
import torch_tensorrt
60+
61+
model = MyModel().eval().cuda()
62+
example_input = torch.randn((2, 3, 224, 224)).cuda()
63+
64+
# Define dynamic batch dimension
65+
dyn_batch = torch.export.Dim("batch", min=1, max=32)
66+
dynamic_shapes = {"x": {0: dyn_batch}}
67+
68+
# Export with dynamic shapes
69+
exp_program = torch.export.export(
70+
model, (example_input,),
71+
dynamic_shapes=dynamic_shapes,
72+
strict=False
73+
)
74+
75+
# Compile with dynamic input specifications
76+
trt_gm = torch_tensorrt.dynamo.compile(
77+
exp_program,
78+
inputs=[torch_tensorrt.Input(
79+
min_shape=(1, 3, 224, 224),
80+
opt_shape=(8, 3, 224, 224),
81+
max_shape=(32, 3, 224, 224),
82+
)]
83+
)
84+
85+
# Save with dynamic_shapes to preserve dynamic behavior
86+
torch_tensorrt.save(
87+
trt_gm,
88+
"trt_dynamic.ep",
89+
arg_inputs=[example_input],
90+
dynamic_shapes=dynamic_shapes # Same as used during export
91+
)
92+
93+
# Load and use with different batch sizes
94+
loaded_model = torch_tensorrt.load("trt_dynamic.ep").module()
95+
output_bs4 = loaded_model(torch.randn(4, 3, 224, 224).cuda())
96+
output_bs16 = loaded_model(torch.randn(16, 3, 224, 224).cuda())
97+
98+
**Method 2: Using torch_tensorrt.Input**
99+
100+
Pass ``torch_tensorrt.Input`` objects with min/opt/max shapes directly, and the
101+
dynamic shapes will be inferred automatically:
102+
103+
.. code-block:: python
104+
105+
import torch
106+
import torch_tensorrt
107+
108+
model = MyModel().eval().cuda()
109+
110+
# Define Input with dynamic shapes
111+
inputs = [
112+
torch_tensorrt.Input(
113+
min_shape=(1, 3, 224, 224),
114+
opt_shape=(8, 3, 224, 224),
115+
max_shape=(32, 3, 224, 224),
116+
dtype=torch.float32,
117+
name="x" # Optional: provides better dimension naming
118+
)
119+
]
120+
121+
# Compile with Torch-TensorRT
122+
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
123+
124+
# Save with Input objects - dynamic_shapes inferred automatically!
125+
torch_tensorrt.save(
126+
trt_gm,
127+
"trt_dynamic.ep",
128+
arg_inputs=inputs # Dynamic shapes inferred from Input objects
129+
)
130+
131+
# Load and use with different batch sizes
132+
loaded_model = torch_tensorrt.load("trt_dynamic.ep").module()
133+
output_bs4 = loaded_model(torch.randn(4, 3, 224, 224).cuda())
134+
output_bs16 = loaded_model(torch.randn(16, 3, 224, 224).cuda())
135+
136+
45137
b) Torchscript
46138
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
47139

0 commit comments

Comments
 (0)