1+ """
2+ .. _compile_with_dynamic_inputs:
3+
4+ Compiling Models with Dynamic Input Shapes
5+ ==========================================================
6+
7+ Dynamic shapes are essential when your model
8+ needs to handle varying batch sizes or sequence lengths at inference time without recompilation.
9+
10+ The example uses a Vision Transformer-style model with expand and reshape operations,
11+ which are common patterns that benefit from dynamic shape handling.
12+ """

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import logging

import torch
import torch.nn as nn

import torch_tensorrt

torch.manual_seed(0)

# %%


# Define a model with expand and reshape operations.
# This is a simplified Vision Transformer pattern with:
# - A learnable class token that needs to expand to match the batch size
# - A QKV projection followed by reshaping for multi-head attention
class ExpandReshapeModel(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
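        # NOTE: the remainder of this class was elided in the source excerpt.
        # The parameters and forward pass below are a minimal, hypothetical
        # sketch consistent with the description above (a class token expanded
        # to the batch size plus a QKV projection reshaped for multi-head
        # attention); they are not necessarily the original implementation.
        self.num_heads = 12
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)

    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]
        # Expand the learnable class token to the (possibly dynamic) batch size
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        # Project to QKV and reshape for multi-head attention:
        # (B, N, 3 * embed_dim) -> (B, N, 3, num_heads, head_dim)
        qkv = self.qkv(x).reshape(batch_size, x.shape[1], 3, self.num_heads, -1)
        return qkv
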
model = ExpandReshapeModel(embed_dim=768).cuda().eval()
x = torch.randn(4, 196, 768).cuda()

# %%
# Approach 1: JIT Compilation with ``torch.compile``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# The first approach uses PyTorch's ``torch.compile`` with the TensorRT backend.
# This is a Just-In-Time (JIT) compilation method: the model is compiled during
# the first inference call.
#
# Key points:
#
# - Use ``torch._dynamo.mark_dynamic()`` to specify which dimensions are dynamic
# - The ``index`` parameter selects the dimension (0 = batch dimension)
# - Provide ``min`` and ``max`` bounds for the dynamic dimension
# - The compiled model then works for any batch size within the specified range,
#   as demonstrated right after the code below

x1 = x.clone()
torch._dynamo.mark_dynamic(x1, index=0, min=2, max=32)
trt_module = torch.compile(model, backend="tensorrt")
out1 = trt_module(x1)
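
# %%
# Because dim 0 of ``x1`` was marked dynamic with bounds [2, 32], the same
# compiled module can serve other batch sizes in that range without triggering
# recompilation. The extra inputs below are illustrative and not part of the
# original example.

for extra_bs in (2, 16, 32):
    extra_out = trt_module(torch.randn(extra_bs, 196, 768).cuda())
    print(f"batch size {extra_bs}: output shape {tuple(extra_out.shape)}")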

# %%
# Approach 2: AOT Compilation with ``torch_tensorrt.compile``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# The second approach uses Ahead-Of-Time (AOT) compilation with
# ``torch_tensorrt.compile``, which compiles the model up front, before any
# inference call.
#
# Key points:
#
# - Use ``torch_tensorrt.Input()`` to specify the dynamic shape range
# - Provide ``min_shape``, ``opt_shape``, and ``max_shape`` for each input
# - ``opt_shape`` is the shape TensorRT optimizes for and should represent a typical input size
# - Set ``ir="dynamo"`` to use the Dynamo frontend

x2 = x.clone()
example_input = torch_tensorrt.Input(
    min_shape=[1, 196, 768],
    # The opt/max shapes were elided in the source excerpt; the values below are
    # assumed to mirror the batch range used by the other approaches.
    opt_shape=[4, 196, 768],
    max_shape=[32, 196, 768],
    dtype=torch.float32,
)
trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=example_input)
out2 = trt_module(x2)
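
# %%
# The AOT-compiled module is valid for the whole [``min_shape``, ``max_shape``]
# range declared above. As a quick illustrative check (not part of the original
# example), run it at the smallest and largest supported batch sizes.

out_min = trt_module(torch.randn(1, 196, 768).cuda())
out_max = trt_module(torch.randn(32, 196, 768).cuda())
print(out_min.shape, out_max.shape)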

# %%
# Approach 3: AOT with ``torch.export`` + Dynamo Compile
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# The third approach uses PyTorch's ``torch.export`` API combined with
# Torch-TensorRT's Dynamo compiler. This provides the most explicit control over
# dynamic shapes.
#
# Key points:
#
# - Use ``torch.export.Dim()`` to define symbolic dimensions with range constraints
# - Create a ``dynamic_shapes`` dictionary mapping input names to their dynamic dimensions
# - Export the model to an ``ExportedProgram`` with these constraints
# - Compile the exported program with ``torch_tensorrt.dynamo.compile``

x3 = x.clone()
bs = torch.export.Dim("bs", min=1, max=32)
dynamic_shapes = {"x": {0: bs}}
exp_program = torch.export.export(model, (x3,), dynamic_shapes=dynamic_shapes)
trt_module = torch_tensorrt.dynamo.compile(exp_program, (x3,))
out3 = trt_module(x3)
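
# %%
# The same mechanism extends beyond the batch dimension. As a hypothetical
# sketch (not part of the original example), the sequence-length dimension could
# also be declared dynamic with a second ``torch.export.Dim``; whether this
# works for a given model depends on how its operations constrain that
# dimension.

seq = torch.export.Dim("seq", min=16, max=512)
dynamic_shapes_2d = {"x": {0: bs, 1: seq}}
exp_program_2d = torch.export.export(model, (x3,), dynamic_shapes=dynamic_shapes_2d)
trt_module_2d = torch_tensorrt.dynamo.compile(exp_program_2d, (x3,))
print(trt_module_2d(torch.randn(8, 128, 768).cuda()).shape)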

# %%
# Verify All Approaches Produce Identical Results
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# All three approaches should produce the same numerical results (within
# ``torch.allclose`` tolerances). This check confirms that dynamic shape
# handling behaves consistently across the different compilation methods.

assert torch.allclose(out1, out2)
assert torch.allclose(out1, out3)
assert torch.allclose(out2, out3)

print("All three approaches produced identical results!")