Commit 33a9cfe

keehyuna and narendasan authored
docs: Updated cuda graphs doc (#3357)
Co-authored-by: Naren Dasan <[email protected]>
1 parent c8155f5 commit 33a9cfe

File tree

4 files changed: +23 -12 lines

[Two new image files added (67.1 KB and 59.7 KB): the CUDA graph diagrams referenced in the updated docs below.]

examples/dynamo/torch_export_cudagraphs.py

+17 -6
@@ -4,7 +4,12 @@
 Torch Export with Cudagraphs
 ======================================================

-This interactive script is intended as an overview of the process by which the Torch-TensorRT Cudagraphs integration can be used in the `ir="dynamo"` path. The functionality works similarly in the `torch.compile` path as well.
+CUDA Graphs allow multiple GPU operations to be launched through a single CPU operation, reducing launch overheads and improving GPU utilization. Torch-TensorRT provides a simple interface to enable CUDA graphs. This feature allows users to easily leverage the performance benefits of CUDA graphs without managing the complexities of capture and replay manually.
+
+.. image:: /tutorials/images/cuda_graphs.png
+
+This interactive script is intended as an overview of the process by which the Torch-TensorRT Cudagraphs integration can be used in the `ir="dynamo"` path. The functionality works similarly in the
+`torch.compile` path as well.
 """

 # %%
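
For context on the workflow this docstring introduces, here is a minimal sketch of enabling the CUDA graphs runtime around a Torch-TensorRT compiled module. It assumes a CUDA-capable GPU; the toy model, input shape, and compile options are illustrative, not taken from this commit, and the `torch_tensorrt.runtime.enable_cudagraphs` context-manager usage follows the pattern this example file documents.

    # Hedged sketch: model, shapes, and options are illustrative, not from this commit.
    import torch
    import torch_tensorrt

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU()).eval().cuda()
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]

    # Compile through the dynamo path, as the example file does.
    # min_block_size=1 lets this tiny model still be converted to TensorRT.
    opt = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, min_block_size=1)

    # The context manager enables the CUDA graphs runtime for the compiled module:
    # kernel launches are captured once and replayed on subsequent calls.
    with torch_tensorrt.runtime.enable_cudagraphs(opt) as cudagraphs_module:
        out = cudagraphs_module(*inputs)
        out = cudagraphs_module(*inputs)  # replayed from the captured graph
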
@@ -70,19 +75,25 @@

 # %%
 # Cuda graphs with module that contains graph breaks
-# ----------------------------------
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 #
 # When CUDA Graphs are applied to a TensorRT model that contains graph breaks, each break introduces additional
 # overhead. This occurs because graph breaks prevent the entire model from being executed as a single, continuous
 # optimized unit. As a result, some of the performance benefits typically provided by CUDA Graphs, such as reduced
 # kernel launch overhead and improved execution efficiency, may be diminished.
+#
 # Using a wrapped runtime module with CUDA Graphs allows you to encapsulate sequences of operations into graphs
-# that can be executed efficiently, even in the presence of graph breaks.
-# If TensorRT module has graph breaks, CUDA Graph context manager returns a wrapped_module. This module captures entire
-# execution graph, enabling efficient replay during subsequent inferences by reducing kernel launch overheads
-# and improving performance. Note that initializing with the wrapper module involves a warm-up phase where the
+# that can be executed efficiently, even in the presence of graph breaks. If the TensorRT module has graph breaks,
+# the CUDA Graph context manager returns a wrapped_module. This module captures the entire execution graph, enabling
+# efficient replay during subsequent inferences by reducing kernel launch overheads and improving performance.
+#
+# Note that initializing with the wrapper module involves a warm-up phase where the
 # module is executed several times. This warm-up ensures that memory allocations and initializations are not
 # recorded in CUDA Graphs, which helps maintain consistent execution paths and optimize performance.
+#
+# .. image:: /tutorials/images/cuda_graphs_breaks.png
+#    :scale: 60 %
+#    :align: left


 class SampleModel(torch.nn.Module):
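
To make the graph-break passage above concrete, a hedged sketch follows: excluding an operator via `torch_executed_ops` forces part of the model to run in PyTorch, creating graph breaks between TensorRT segments, and the CUDA graphs context manager then returns a wrapped module that captures the whole execution after a warm-up. The model body and the choice of excluded op are illustrative assumptions, not copied from this commit.

    # Hedged sketch: the model body and excluded op are illustrative.
    import torch
    import torch_tensorrt

    class SampleModel(torch.nn.Module):
        def forward(self, x):
            return torch.relu((x + 2.0) * 0.5)

    model = SampleModel().eval().cuda()
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]

    # Keeping aten.mul in PyTorch introduces graph breaks between the
    # TensorRT-compiled segments.
    opt_with_breaks = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs,
        min_block_size=1,
        torch_executed_ops={"torch.ops.aten.mul.Tensor"},
    )

    # With graph breaks present, the context manager returns a wrapped module.
    # The first calls serve as warm-up; afterwards the whole execution (TensorRT
    # segments plus PyTorch fallbacks) is captured and replayed as one CUDA graph.
    with torch_tensorrt.runtime.enable_cudagraphs(opt_with_breaks) as cudagraphs_module:
        cudagraphs_module(*inputs)  # warm-up / capture
        cudagraphs_module(*inputs)  # replay
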

py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py

+6 -6
@@ -115,12 +115,12 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
                     contiguous_inputs[i].dtype == self.inputs[i].dtype
                 ), f"Dtype mismatch for {i}th input. Expect {self.inputs[i].dtype}, got {contiguous_inputs[i].dtype}."

-            if need_cudagraphs_record:
-                # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
-                # Clone is required to avoid re-using user-provided GPU memory
-                self._input_buffers[i] = contiguous_inputs[i].clone()
-            else:
-                self._input_buffers[i].copy_(contiguous_inputs[i])
+                if need_cudagraphs_record:
+                    # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
+                    # Clone is required to avoid re-using user-provided GPU memory
+                    self._input_buffers[i] = contiguous_inputs[i].clone()
+                else:
+                    self._input_buffers[i].copy_(contiguous_inputs[i])

             self._caller_stream = torch.cuda.current_stream()
             if (
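
The clone-then-copy_ logic in this hunk is the standard CUDA Graphs input-buffer pattern: captured kernels always read from the same memory addresses, so user inputs are cloned into reserved buffers at record time and copied into those same buffers before each replay. Below is a standalone PyTorch sketch of that pattern; the toy Linear model, shapes, and warm-up count are assumptions, not part of this commit.

    # Hedged sketch of static input buffers for CUDA graph capture and replay.
    import torch

    model = torch.nn.Linear(64, 64).cuda().eval()

    # Reserved buffer, analogous to self._input_buffers[i]: the captured graph
    # will always read from this tensor's memory.
    static_input = torch.randn(8, 64, device="cuda")

    # Warm-up on a side stream so one-time allocations are not baked into the capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # Record: the graph captures kernels reading from static_input's address.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_output = model(static_input)

    # Replay: copy new data into the same buffer, then replay the captured kernels.
    new_input = torch.randn(8, 64, device="cuda")
    static_input.copy_(new_input)
    g.replay()
    print(static_output[0, :4])
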
