
Commit c36fba4

Browse files
tp5uiuc and claude committed
feat: add TRT-RTX native CUDA graph support
Add cuda_graph_strategy compilation setting and automatic RTX-native CUDA graph integration for the Python runtime path.

Key changes:

- New cuda_graph_strategy setting ("disabled" / "whole_graph_capture") on CompilationSettings, mapped to trt.CudaGraphStrategy on IRuntimeConfig (same pattern as dynamic_shapes_kernel_specialization)
- In SUBGRAPH cudagraphs mode on RTX, always use RTX-native CUDA graphs (manual torch.cuda.CUDAGraph capture is not safe due to lazy kernel specialization and potential runtime allocation)
- _is_monolithic_capturable() check using context.is_stream_capturable() and strategy != "lazy" for WHOLE_GRAPH mode safety validation
- _enable_rtx_native_cudagraphs() for runtime context recreation
- _check_monolithic_capturability() in CudaGraphsTorchTensorRTModule for mixed TRT + PyTorch graph validation
- Comprehensive unit tests covering all code paths

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 12166c6 commit c36fba4

7 files changed: 674 additions & 7 deletions
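A minimal usage sketch of the new setting (illustrative, not from the commit; `model` and `x` are placeholder objects, and a TensorRT-RTX build is assumed):

```python
import torch
import torch_tensorrt

# Opt in to TRT-RTX native CUDA graphs at compile time (new in this commit).
trt_mod = torch_tensorrt.compile(
    model,
    inputs=[x],
    cuda_graph_strategy="whole_graph_capture",
)

# With native capture enabled, the manual torch.cuda.CUDAGraph path in
# forward() is bypassed; TRT-RTX captures and replays internally.
torch_tensorrt.runtime.set_cudagraphs_mode(True)
out = trt_mod(x)
```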


py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 9 additions & 0 deletions
@@ -94,6 +94,7 @@ def cross_compile_for_windows(
     timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
     runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH,
     dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
+    cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY,
     lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
     cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
     reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
@@ -176,6 +177,7 @@ def cross_compile_for_windows(
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX.
         runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT.
         dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy".
+        cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled".
         lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
         cache_built_engines (bool): Whether to save the compiled TRT engines to storage
         reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
@@ -342,6 +344,7 @@ def cross_compile_for_windows(
         "timing_cache_path": timing_cache_path,
         "runtime_cache_path": runtime_cache_path,
         "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy,
+        "cuda_graph_strategy": cuda_graph_strategy,
         "lazy_engine_init": lazy_engine_init,
         "cache_built_engines": cache_built_engines,
         "reuse_cached_engines": reuse_cached_engines,
@@ -455,6 +458,7 @@ def compile(
     timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
     runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH,
     dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
+    cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY,
     lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
     cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
     reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
@@ -552,6 +556,7 @@ def compile(
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX.
         runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT.
         dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy".
+        cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled".
         lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
         cache_built_engines (bool): Whether to save the compiled TRT engines to storage
         reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
@@ -761,6 +766,7 @@ def compile(
         "timing_cache_path": timing_cache_path,
         "runtime_cache_path": runtime_cache_path,
         "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy,
+        "cuda_graph_strategy": cuda_graph_strategy,
         "lazy_engine_init": lazy_engine_init,
         "cache_built_engines": cache_built_engines,
         "reuse_cached_engines": reuse_cached_engines,
@@ -1176,6 +1182,7 @@ def convert_exported_program_to_serialized_trt_engine(
     timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
     runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH,
     dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
+    cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY,
     lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
     cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
     reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
@@ -1254,6 +1261,7 @@ def convert_exported_program_to_serialized_trt_engine(
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX.
         runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT.
         dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy".
+        cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled".
         lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
         cache_built_engines (bool): Whether to save the compiled TRT engines to storage
         reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
@@ -1429,6 +1437,7 @@ def convert_exported_program_to_serialized_trt_engine(
         "timing_cache_path": timing_cache_path,
         "runtime_cache_path": runtime_cache_path,
         "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy,
+        "cuda_graph_strategy": cuda_graph_strategy,
         "lazy_engine_init": lazy_engine_init,
         "cache_built_engines": cache_built_engines,
         "reuse_cached_engines": reuse_cached_engines,

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@
 DECOMPOSE_ATTENTION = False
 ATTN_BIAS_IS_CAUSAL = True
 DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy"
+CUDA_GRAPH_STRATEGY = "disabled"

 if platform.system() == "Linux":
     import pwd

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 3 additions & 0 deletions
@@ -17,6 +17,7 @@
     AUTOCAST_MAX_OUTPUT_THRESHOLD,
     CACHE_BUILT_ENGINES,
     CPU_MEMORY_BUDGET,
+    CUDA_GRAPH_STRATEGY,
     DECOMPOSE_ATTENTION,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
@@ -102,6 +103,7 @@ class CompilationSettings:
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX (no autotuning).
         runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. The cache is loaded on engine setup and saved on module cleanup. Uses file locking for concurrent access safety. Not used for standard TensorRT.
         dynamic_shapes_kernel_specialization_strategy (str): Strategy for compiling shape-specialized kernels at runtime for dynamic shapes (TensorRT-RTX only). Options: "lazy" (compile in background, use fallback until ready), "eager" (compile immediately, blocking), "none" (always use fallback kernels). Default: "lazy".
+        cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (no native CUDA graphs, uses manual capture if cudagraphs mode is enabled), "whole_graph_capture" (TRT-RTX handles CUDA graph capture internally). When set to "whole_graph_capture", the manual torch CUDA graph capture/replay in forward() is bypassed. Default: "disabled".
         cache_built_engines (bool): Whether to save the compiled TRT engines to storage
         reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
         use_strong_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs.
@@ -159,6 +161,7 @@ class CompilationSettings:
     dynamic_shapes_kernel_specialization_strategy: str = (
         DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY
     )
+    cuda_graph_strategy: str = CUDA_GRAPH_STRATEGY
     lazy_engine_init: bool = LAZY_ENGINE_INIT
     cache_built_engines: bool = CACHE_BUILT_ENGINES
     reuse_cached_engines: bool = REUSE_CACHED_ENGINES
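As a summary of the interaction described in this docstring, a small illustrative sketch (not shipped code) of which capture path the Python runtime ends up on for an RTX build:

```python
# Illustrative only -- summarizes the effective behavior on a TensorRT-RTX
# build given the compile-time strategy and the runtime cudagraphs mode.
def effective_capture_path(cudagraphs_mode: bool, cuda_graph_strategy: str) -> str:
    if not cudagraphs_mode:
        return "no CUDA graphs"  # the strategy is dormant without cudagraphs mode
    if cuda_graph_strategy == "whole_graph_capture":
        return "RTX-native capture (manual torch capture bypassed)"
    # "disabled" + cudagraphs mode: forward() warns and auto-switches, since
    # manual torch.cuda.CUDAGraph capture is unsafe on TRT-RTX
    return "RTX-native capture (auto-enabled with a warning)"
```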

py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py

Lines changed: 48 additions & 0 deletions
@@ -114,6 +114,53 @@ def __del__(self) -> None:
     def set_use_output_allocator(self, enable: bool) -> None:
         self.use_output_allocator_outputs = enable

+    def _check_monolithic_capturability(self, stream: torch.cuda.Stream) -> None:
+        """Verify all TRT submodules are monolithically capturable on RTX.
+
+        For whole-graph CUDA graph mode with mixed TRT + PyTorch ops,
+        all TRT engines must be safe for manual stream capture. If any
+        engine has lazy kernel specialization or non-capturable conditions,
+        raises RuntimeError.
+        """
+        from torch_tensorrt._features import ENABLED_FEATURES
+
+        if not ENABLED_FEATURES.tensorrt_rtx:
+            return  # non-RTX: no check needed
+        from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import (
+            PythonTorchTensorRTModule,
+        )
+
+        for name, mod in self.compiled_module.named_modules():
+            if isinstance(mod, PythonTorchTensorRTModule):
+                if not mod._is_monolithic_capturable(stream):
+                    raise RuntimeError(
+                        f"CUDA graph capture failed: TRT submodule "
+                        f"'{name}' is not monolithically capturable "
+                        f"(lazy kernel specialization or non-capturable "
+                        f"stream). Whole-graph CUDA graph mode with mixed "
+                        f"TRT + PyTorch ops requires all TRT engines to be "
+                        f"capturable. Consider using "
+                        f"cuda_graph_strategy='whole_graph_capture' with "
+                        f"set_cudagraphs_mode(True) instead of "
+                        f"enable_cudagraphs()."
+                    )
+                # Ensure RTX-native is DISABLED so TRT engines do not
+                # interfere with the outer monolithic capture
+                if mod._rtx_native_cudagraphs:
+                    from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import (
+                        _get_cuda_graph_strategy,
+                    )
+
+                    mod.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy(
+                        "disabled"
+                    )
+                    mod.context = mod._create_context()
+                    mod._rtx_native_cudagraphs = False
+                    logger.info(
+                        f"Disabled RTX-native CUDA graphs for '{name}' "
+                        f"(using outer monolithic capture instead)"
+                    )
+
     def forward(
         self, *args: Any, **kwargs: Any
     ) -> torch.Tensor | Tuple[torch.Tensor, ...]:
@@ -183,6 +230,7 @@ def forward(

         with torch.cuda.stream(self._engine_stream):
             if need_cudagraphs_record:
+                self._check_monolithic_capturability(self._engine_stream)
                 self.cudagraph = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(self.cudagraph, stream=self._engine_stream):
                     self._output_buffers = self.compiled_module(*args, **kwargs)
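A hedged usage sketch of the path this check guards: `enable_cudagraphs()` wraps a compiled module that has graph breaks in this class, so the first recorded call now validates every TRT submodule before capture (`model` and `x` are placeholders; a TensorRT-RTX build is assumed):

```python
import torch
import torch_tensorrt

# "lazy" kernel specialization would fail _is_monolithic_capturable(), so
# choose "eager" (or "none") when relying on the outer monolithic capture.
trt_mod = torch_tensorrt.compile(
    model,
    inputs=[x],
    dynamic_shapes_kernel_specialization_strategy="eager",
)

with torch_tensorrt.runtime.enable_cudagraphs(trt_mod) as cg_mod:
    out = cg_mod(x)  # first call records; raises RuntimeError if not capturable
```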

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 73 additions & 7 deletions
@@ -36,6 +36,14 @@ def _get_dynamic_shapes_kernel_strategy(strategy_str: str) -> Any:
     }.get(strategy_str, trt.DynamicShapesKernelSpecializationStrategy.LAZY)


+def _get_cuda_graph_strategy(strategy_str: str) -> Any:
+    """Map strategy string to TRT CudaGraphStrategy enum. Only called on RTX builds."""
+    return {
+        "disabled": trt.CudaGraphStrategy.DISABLED,
+        "whole_graph_capture": trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE,
+    }.get(strategy_str, trt.CudaGraphStrategy.DISABLED)
+
+
 class DynamicOutputAllocator(trt.IOutputAllocator):  # type: ignore[misc]
     def __init__(self, output_dtypes: Dict[str, torch.dtype]) -> None:
         trt.IOutputAllocator.__init__(self)
@@ -241,6 +249,7 @@ def __init__(
         self.runtime_config: Any = None
         self.runtime_cache: Any = None
         self.runtime_cache_path = settings.runtime_cache_path
+        self._rtx_native_cudagraphs = False

         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()
@@ -309,6 +318,10 @@ def setup_engine(self) -> None:

         if ENABLED_FEATURES.tensorrt_rtx:
             self._setup_runtime_config()
+            self._rtx_native_cudagraphs = (
+                ENABLED_FEATURES.tensorrt_rtx
+                and self.settings.cuda_graph_strategy != "disabled"
+            )

         self.context = self._create_context()
         assert self.context is not None, "Failed to create execution context"
@@ -336,7 +349,10 @@ def setup_engine(self) -> None:
         if self.requires_output_allocator:
             self.create_output_allocator()

-        if torch_tensorrt.runtime.get_cudagraphs_mode():
+        if (
+            torch_tensorrt.runtime.get_cudagraphs_mode()
+            and not self._rtx_native_cudagraphs
+        ):
             self.cudagraph = torch.cuda.CUDAGraph()

         self.is_shape_inference_io = {
@@ -362,6 +378,10 @@ def _setup_runtime_config(self) -> None:
         logger.info(
             f"Dynamic shapes kernel specialization strategy: {self.settings.dynamic_shapes_kernel_specialization_strategy}"
         )
+        self.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy(
+            self.settings.cuda_graph_strategy
+        )
+        logger.info(f"CUDA graph strategy: {self.settings.cuda_graph_strategy}")
         self.runtime_cache = self.runtime_config.create_runtime_cache()
         self._load_runtime_cache()
         self.runtime_config.set_runtime_cache(self.runtime_cache)
@@ -466,6 +486,32 @@ def _reset_captured_graph(self) -> None:
             self.cudagraph.reset()
             self.cudagraph = None

+    def _is_monolithic_capturable(self, stream: torch.cuda.Stream) -> bool:
+        """Check if manual torch.cuda.CUDAGraph capture is safe for this engine.
+
+        Returns False on RTX if the engine has conditions that prevent
+        manual stream capture (runtime allocation, DDS, lazy kernels).
+        """
+        if not ENABLED_FEATURES.tensorrt_rtx:
+            return True  # non-RTX: assume capturable (existing behavior)
+        # Check 1: TRT-RTX stream capturability (runtime allocation, DDS, etc.)
+        if not self.context.is_stream_capturable(stream.cuda_stream):
+            return False
+        # Check 2: Lazy kernel specialization would invalidate captured graph
+        if self.settings.dynamic_shapes_kernel_specialization_strategy == "lazy":
+            return False
+        return True
+
+    def _enable_rtx_native_cudagraphs(self) -> None:
+        """Switch to RTX-native CUDA graphs by recreating the execution context."""
+        if self.runtime_config is not None:
+            self.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy(
+                "whole_graph_capture"
+            )
+        self.context = self._create_context()
+        self._rtx_native_cudagraphs = True
+        logger.info("Switched to TRT-RTX native CUDA graphs")
+
     def __del__(self) -> None:
         self._save_runtime_cache()
         self._reset_captured_graph()
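To make the "lazy" restriction concrete, an illustrative sketch (placeholders `trt_engine` and `x`; this is reasoning, not shipped code) of why a manually captured graph goes stale under lazy specialization:

```python
import torch

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    y = trt_engine(x)  # records whatever (fallback) kernels are loaded right now

# Later, lazy specialization finishes compiling faster shape-specialized
# kernels in the background -- but replay keeps launching the stale recorded
# fallback kernels, so the captured graph never picks up the new kernels.
g.replay()
```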
@@ -559,13 +605,32 @@ def create_output_allocator(self) -> None:

     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
+            # On RTX + SUBGRAPH cudagraphs: always use RTX-native CUDA graphs.
+            # Manual torch.cuda.CUDAGraph capture is not safe on TRT-RTX because
+            # lazy kernel specialization can invalidate captured graphs and
+            # runtime allocation can prevent stream capture.
+            if ENABLED_FEATURES.tensorrt_rtx and self.cudagraphs_enabled:
+                if not self._rtx_native_cudagraphs:
+                    logger.warning(
+                        "Manual CUDA graph capture is not guaranteed to work "
+                        "on TRT-RTX (lazy kernel specialization or "
+                        "non-capturable stream). Switching to TRT-RTX native "
+                        "CUDA graphs. Set cuda_graph_strategy="
+                        '"whole_graph_capture" at compile time to avoid '
+                        "this warning."
+                    )
+                    self._enable_rtx_native_cudagraphs()
+
+            effective_cudagraphs = (
+                self.cudagraphs_enabled and not self._rtx_native_cudagraphs
+            )
             shape_changed = self.validate_input_shapes(contiguous_inputs)
             (
                 need_cudagraphs_record,
                 can_use_pre_allocated_outputs,
                 need_cudagraphs_reset,
             ) = self.runtime_states.set_runtime_states(
-                self.cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
+                effective_cudagraphs, self.use_pre_allocated_outputs, shape_changed
             )

             if need_cudagraphs_reset:
@@ -587,7 +652,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
             ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."

             self.setup_input_tensors(
-                contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record
+                contiguous_inputs, effective_cudagraphs, need_cudagraphs_record
             )

             if shape_changed:
@@ -623,7 +688,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     if need_cudagraphs_record:
                         self._output_buffers[o] = outputs[o].clone()

-                    if self.cudagraphs_enabled:
+                    if effective_cudagraphs:
                         self.context.set_tensor_address(
                             output_name, self._output_buffers[o].data_ptr()
                         )
@@ -649,7 +714,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 self._engine_stream.wait_stream(self._caller_stream)

             with torch.cuda.stream(self._engine_stream):
-                if self.cudagraphs_enabled:
+                if effective_cudagraphs:
                     if need_cudagraphs_record:
                         self.cudagraph = torch.cuda.CUDAGraph()

@@ -683,7 +748,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
             ):
                 self.pre_allocated_outputs = self.create_output_tensors()

-            if self.cudagraphs_enabled:
+            if effective_cudagraphs:
                 for idx, o in enumerate(outputs):
                     o.copy_(self._output_buffers[idx])

@@ -840,7 +905,8 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
            return run_output_allocator()
        else:
            logger.debug(
-                f"Using the standard execution runtime mode with cudagraphs={self.cudagraphs_enabled}."
+                f"Using the standard execution runtime mode with cudagraphs={self.cudagraphs_enabled}"
+                + (" (RTX native)" if self._rtx_native_cudagraphs else "")
            )
            return run_standard_execution()
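For completeness, a hedged sketch of the default path that triggers the warning above (`model` and `x` are placeholders; TensorRT-RTX build assumed):

```python
import torch
import torch_tensorrt

# Compiled with the default cuda_graph_strategy="disabled".
trt_mod = torch_tensorrt.compile(model, inputs=[x])

torch_tensorrt.runtime.set_cudagraphs_mode(True)
out = trt_mod(x)  # on TRT-RTX: warns once, calls _enable_rtx_native_cudagraphs(),
                  # recreates the context, then runs with RTX-native CUDA graphs
```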
