4 changes: 4 additions & 0 deletions docsrc/index.rst
@@ -58,6 +58,7 @@ Tutorials
------------

* :ref:`torch_compile_advanced_usage`
* :ref:`compile_with_dynamic_inputs`
* :ref:`vgg16_ptq`
* :ref:`engine_caching_example`
* :ref:`engine_caching_bert_example`
@@ -70,6 +71,7 @@ Tutorials
* :ref:`auto_generate_plugins`
* :ref:`mutable_torchtrt_module_example`
* :ref:`weight_streaming_example`
* :ref:`dynamic_memory_allocation`
* :ref:`pre_allocated_output_example`
* :ref:`debugger_example`

@@ -79,6 +81,7 @@ Tutorials
:hidden:

tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
tutorials/_rendered_examples/dynamo/compile_with_dynamic_inputs
tutorials/_rendered_examples/dynamo/vgg16_ptq
tutorials/_rendered_examples/dynamo/engine_caching_example
tutorials/_rendered_examples/dynamo/engine_caching_bert_example
@@ -91,6 +94,7 @@ Tutorials
tutorials/_rendered_examples/dynamo/auto_generate_plugins
tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
tutorials/_rendered_examples/dynamo/weight_streaming_example
tutorials/_rendered_examples/dynamo/dynamic_memory_allocation
tutorials/_rendered_examples/dynamo/pre_allocated_output_example

Dynamo Frontend
80 changes: 77 additions & 3 deletions examples/dynamo/compile_with_dynamic_inputs.py
@@ -1,3 +1,20 @@
"""
.. _compile_with_dynamic_inputs:

Compiling Models with Dynamic Input Shapes
==========================================================

Dynamic input shapes let a compiled model handle varying batch sizes or sequence
lengths at inference time without recompilation. This example walks through three
ways to enable them: JIT compilation with ``torch.compile``, AOT compilation with
``torch_tensorrt.compile``, and AOT compilation with ``torch.export`` followed by
``torch_tensorrt.dynamo.compile``.

The example uses a Vision Transformer-style model with expand and reshape operations,
which are common patterns that benefit from dynamic shape handling.
"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import logging

import torch
@@ -8,7 +25,13 @@

torch.manual_seed(0)

# %%


# Define a model with expand and reshape operations.
# This is a simplified Vision Transformer pattern with:
# - a learnable class token that must be expanded to match the batch size
# - a QKV projection followed by a reshape for multi-head attention
class ExpandReshapeModel(nn.Module):
def __init__(self, embed_dim: int):
super().__init__()
@@ -28,13 +51,40 @@ def forward(self, x: torch.Tensor):
model = ExpandReshapeModel(embed_dim=768).cuda().eval()
x = torch.randn(4, 196, 768).cuda()
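# %%
# The forward pass of ``ExpandReshapeModel`` is collapsed in this diff. Purely as a
# hypothetical illustration (not the code from this example), the expand + reshape
# pattern described above typically looks like the sketch below, where the class
# token expands to the runtime batch size and the QKV projection output is reshaped
# per attention head. The helper name and arguments here are assumptions for
# illustration only.


def _illustrative_expand_reshape(
    x: torch.Tensor, cls_token: torch.Tensor, qkv: nn.Linear, num_heads: int
) -> torch.Tensor:
    B, N, C = x.shape
    # Expand the learnable class token from [1, 1, C] to [B, 1, C] and prepend it
    tokens = torch.cat([cls_token.expand(B, -1, -1), x], dim=1)
    # Project to QKV and reshape for multi-head attention
    return qkv(tokens).reshape(B, N + 1, 3, num_heads, C // num_heads)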

# 1. JIT: torch.compile
# %%
# Approach 1: JIT Compilation with ``torch.compile``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# The first approach uses PyTorch's ``torch.compile`` with the TensorRT backend.
# This is a Just-In-Time (JIT) compilation method: the model is compiled
# during the first inference call.
#
# Key points:
#
# - Use ``torch._dynamo.mark_dynamic()`` to specify which dimensions are dynamic
# - The ``index`` parameter selects the dimension (0 = batch dimension)
# - Provide ``min`` and ``max`` bounds for the dynamic dimension
# - The compiled module works for any batch size within the specified range (see the sketch below)

x1 = x.clone()
torch._dynamo.mark_dynamic(x1, index=0, min=2, max=32)
trt_module = torch.compile(model, backend="tensorrt")
out1 = trt_module(x1)
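# %%
# Because dim 0 of ``x1`` was marked dynamic with ``min=2`` and ``max=32``, the same
# compiled module can serve other batch sizes in that range. A minimal sketch,
# assuming ``trt_module`` from above is reused:

x1_other = torch.randn(16, 196, 768).cuda()
out1_other = trt_module(x1_other)  # no recompilation expected for batch sizes in [2, 32]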

# 2. AOT: torch_tensorrt.compile
# %%
# Approach 2: AOT Compilation with ``torch_tensorrt.compile``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# The second approach uses Ahead-Of-Time (AOT) compilation with ``torch_tensorrt.compile``,
# which compiles the model up front, before any inference runs.
#
# Key points:
#
# - Use ``torch_tensorrt.Input()`` to specify dynamic shape ranges (a full spec is sketched below)
# - Provide ``min_shape``, ``opt_shape``, and ``max_shape`` for each input
# - ``opt_shape`` is the shape TensorRT optimizes for and should represent typical input sizes
# - Set ``ir="dynamo"`` to use the Dynamo frontend

x2 = x.clone()
example_input = torch_tensorrt.Input(
min_shape=[1, 196, 768],
@@ -45,14 +95,38 @@ def forward(self, x: torch.Tensor):
trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=example_input)
out2 = trt_module(x2)
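# %%
# The full ``torch_tensorrt.Input`` specification is collapsed in this diff. As a
# minimal sketch (the ``opt``/``max`` values and dtype here are illustrative
# assumptions, not the exact values used above), a complete dynamic-shape spec
# provides all three bounds:

example_input_sketch = torch_tensorrt.Input(
    min_shape=[1, 196, 768],  # smallest batch the engine must support
    opt_shape=[4, 196, 768],  # typical batch size, used for optimization
    max_shape=[32, 196, 768],  # largest batch the engine must support
    dtype=torch.float32,
)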

# 3. AOT: torch.export + Dynamo compile
# %%
# Approach 3: AOT with ``torch.export`` + Dynamo Compile
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# The third approach uses the ``torch.export`` API combined with
# Torch-TensorRT's Dynamo compiler. This provides the most explicit control
# over dynamic shapes.
#
# Key points:
#
# - Use ``torch.export.Dim()`` to define symbolic dimensions with constraints
# - Create a ``dynamic_shapes`` dictionary mapping inputs to their dynamic dimensions
# - Export the model to an ``ExportedProgram`` with these constraints
# - Compile the exported program with ``torch_tensorrt.dynamo.compile``

x3 = x.clone()
bs = torch.export.Dim("bs", min=1, max=32)
dynamic_shapes = {"x": {0: bs}}
exp_program = torch.export.export(model, (x3,), dynamic_shapes=dynamic_shapes)
trt_module = torch_tensorrt.dynamo.compile(exp_program, (x3,))
out3 = trt_module(x3)
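# %%
# As with the JIT path, the AOT-compiled module accepts any batch size inside the
# ``Dim`` bounds. A quick sketch, reusing ``trt_module`` from the export-based
# compilation above:

for other_bs in (1, 8, 32):
    _ = trt_module(torch.randn(other_bs, 196, 768).cuda())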

# %%
# Verify All Approaches Produce Identical Results
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# All three approaches should produce the same numerical results.
# This verification ensures that dynamic shape handling works correctly
# across different compilation methods.

assert torch.allclose(out1, out2)
assert torch.allclose(out1, out3)
assert torch.allclose(out2, out3)

print("All three approaches produced identical results!")
4 changes: 1 addition & 3 deletions examples/dynamo/custom_kernel_plugins.py
@@ -514,9 +514,7 @@ def deserialize_plugin(self, name: str, data: bytes) -> CircularPaddingPlugin:
from torch_tensorrt.fx.converters.converter_utils import set_layer_name


@dynamo_tensorrt_converter(
torch.ops.torchtrt_ex.triton_circular_pad.default
) # type: ignore
@dynamo_tensorrt_converter(torch.ops.torchtrt_ex.triton_circular_pad.default) # type: ignore
# Recall the schema defined above:
# torch.ops.torchtrt_ex.triton_circular_pad.default(Tensor x, IntList padding) -> Tensor
def circular_padding_converter(
1 change: 0 additions & 1 deletion examples/dynamo/debugger_example.py
@@ -50,7 +50,6 @@
"remove_detach"
], # fx graph visualization before certain lowering pass
):

trt_gm = torch_trt.dynamo.compile(
exp_program,
tuple(inputs),
64 changes: 64 additions & 0 deletions examples/dynamo/dynamic_memory_allocation.py
@@ -1,4 +1,24 @@
"""
.. _dynamic_memory_allocation:

Dynamic Memory Allocation
==========================================================

This script demonstrates how to use dynamic memory allocation with Torch-TensorRT
to reduce GPU memory footprint. When enabled, TensorRT engines allocate and deallocate resources
dynamically during inference, which can significantly reduce peak memory usage.

This is particularly useful when:

- Running multiple models on the same GPU
- Working with limited GPU memory
- Minimizing memory usage between inference calls
"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import gc
import time

@@ -11,6 +31,19 @@
torch.manual_seed(5)
inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]

# %%
# Compilation Settings with Dynamic Memory Allocation
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Key settings for dynamic memory allocation:
#
# - ``dynamically_allocate_resources=True``: Enables dynamic resource allocation
# - ``lazy_engine_init=True``: Delays engine initialization until first inference
# - ``immutable_weights=False``: Allows weight refitting if needed
#
# With these settings, the engine will allocate GPU memory only when needed
# and deallocate it after inference completes.

settings = {
"ir": "dynamo",
"use_python_runtime": False,
@@ -25,6 +58,20 @@
# Used GPU memory in GiB: torch.cuda.mem_get_info() returns (free, total) bytes
print((torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3)
compiled_module(*inputs)
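# %%
# The full settings dict is collapsed in this diff. A minimal sketch of the
# configuration described above (the commented-out compile call assumes a ``model``
# like the one defined earlier in this example):

sketch_settings = {
    "ir": "dynamo",
    "use_python_runtime": False,
    "dynamically_allocate_resources": True,  # allocate/deallocate engine resources per inference
    "lazy_engine_init": True,  # defer engine initialization until the first inference
    "immutable_weights": False,  # keep weights refittable
}
# sketch_module = torch_trt.compile(model, inputs=inputs, **sketch_settings)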

# %%
# Runtime Resource Allocation Control
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can control resource allocation behavior at runtime using the
# ``ResourceAllocationStrategy`` context manager. This allows you to:
#
# - Switch between dynamic and static allocation modes
# - Control when resources are allocated and deallocated
# - Optimize memory usage for specific inference patterns
#
# In this example, we temporarily disable dynamic allocation to keep
# resources allocated between inference calls, which can improve performance
# when running multiple consecutive inferences.

time.sleep(30)
with torch_trt.dynamo.runtime.ResourceAllocationStrategy(
@@ -43,3 +90,20 @@
(torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3,
)
compiled_module(*inputs)
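# %%
# The full context-manager call is collapsed in this diff. As a hedged sketch of
# the pattern described above (the argument names here are assumptions, not the
# documented signature), keeping resources pinned for a burst of consecutive
# inferences looks roughly like:

with torch_trt.dynamo.runtime.ResourceAllocationStrategy(
    compiled_module, dynamically_allocate_resources=False
):
    for _ in range(10):
        compiled_module(*inputs)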

# %%
# Memory Usage Comparison
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Dynamic memory allocation trades off some performance for reduced memory footprint:
#
# **Benefits:**
#
# - Lower peak GPU memory usage
# - Reduced memory pressure on shared GPUs
#
# **Considerations:**
#
# - Slight overhead from allocation/deallocation
# - Best suited for scenarios where memory is constrained
# - May not be necessary for single-model deployments with ample memory
1 change: 0 additions & 1 deletion examples/dynamo/low_cpu_memory_compilation.py
@@ -73,7 +73,6 @@ def forward(self, x):
logging_dir="/home/profile/logging/moe",
engine_builder_monitor=False,
):

exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torchtrt.dynamo.compile(
exp_program,
5 changes: 3 additions & 2 deletions examples/dynamo/refit_engine_example.py
@@ -100,8 +100,9 @@

# Check the output
with torch.no_grad():
expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
*inputs
expected_outputs, refitted_outputs = (
exp_program2.module()(*inputs),
new_trt_gm(*inputs),
)
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
assert torch.allclose(
11 changes: 11 additions & 0 deletions pyproject.toml
@@ -79,6 +79,17 @@ test = [
"transformers>=4.49.0",
]

docs = [
"sphinx==5.0.1",
"sphinx-gallery==0.13.0",
"breathe==4.34.0",
"exhale==0.3.7",
"pytorch_sphinx_theme @ git+https://github.com/pytorch/pytorch_sphinx_theme.git",
"nbsphinx==0.9.3",
"docutils==0.17.1",
"pillow",
]

quantization = ["nvidia-modelopt[all]>=0.27.1"]

[project.urls]