Skip to content

Commit 2b630e8

Browse files
committed
feat(runtime): add TensorRT-RTX native CUDA graph strategy to C++ runtime
Wire cuda_graph_strategy into the C++ runtime and make the execute_engine CUDA graph path TensorRT-RTX-aware. Fills in the apply_cuda_graph_strategy stub and adds coexistence handling for outer whole-graph capture. What - apply_cuda_graph_strategy() now calls IRuntimeConfig::setCudaGraphStrategy with either kDISABLED (default) or kWHOLE_GRAPH_CAPTURE. On RTX this hands capture/replay off to the TRT-RTX runtime, avoiding the lazy-kernel and dynamic-shape hazards of wrapping enqueueV3 in at::cuda::CUDAGraph. - is_monolithic_capturable(stream) returns whether an engine can safely be captured by an outer torch.cuda.CUDAGraph: RTX builds check IExecutionContext::isStreamCapturable and require a non-lazy kernel strategy; non-RTX builds always return true. - disable_rtx_native_cudagraphs() is a one-shot switch that turns off the engine internal capture and recreates the execution context so that outer stream captures contain the kernel launches directly. - execute_engine.cpp now computes effective_cudagraphs. On RTX, if a cuda_graph_strategy is set or SUBGRAPH cudagraphs is enabled, it bypasses the manual at::cuda::CUDAGraph path (the TRT-RTX runtime handles that inside enqueueV3). It also polls cudaStreamIsCapturing on the engine stream and, if an outer capture is already running, invokes disable_rtx_native_cudagraphs() so the outer capture proceeds without collision. Why - On TRT-RTX, the manual at::cuda::CUDAGraph wrapper around enqueueV3 can freeze fallback kernels in the captured graph (kLAZY specialisation would swap them later), and fails outright when the engine needs runtime allocation, DDS, control flow, or weight streaming. - Letting the TRT-RTX runtime own capture fixes both problems, and the outer-capture detection keeps the feature compatible with the existing CudaGraphsTorchTensorRTModule whole-graph wrapper without requiring it to know anything about RTX internals. 
Tests - tests/py/dynamo/runtime/test_000_cuda_graph_strategy.py validates the setting default, both {disabled, whole_graph_capture} through the C++ runtime, the RTX-native override when set_cudagraphs_mode(True) is combined with a strategy, repeated inference correctness, and ValueError rejection of unknown strategy names.
1 parent 481455f commit 2b630e8

4 files changed

Lines changed: 188 additions & 7 deletions

File tree

core/runtime/TRTEngine.cpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,33 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt
552552
}
553553
}
554554

555+
// Report whether this engine can safely be captured by an outer monolithic CUDA graph
// (e.g. CudaGraphsTorchTensorRTModule wrapping the whole module in torch.cuda.CUDAGraph).
// Non-RTX builds have no engine-internal capture machinery and are always capturable.
bool TRTEngine::is_monolithic_capturable(cudaStream_t stream) const {
#if defined(TRT_MAJOR_RTX) && defined(ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION)
  // The "lazy" kernel strategy (value 0) swaps specialized kernels in mid-run, which
  // would invalidate a captured graph. Any other strategy (eager/none) combined with a
  // capturable stream is safe for outer monolithic capture.
  const bool stream_ok = exec_ctx->isStreamCapturable(stream);
  const bool kernels_stable = (dynamic_shapes_kernel_strategy != 0);
  return stream_ok && kernels_stable;
#else
  (void)stream; // unused outside RTX builds
  return true;
#endif
}
566+
567+
// One-shot switch: permanently turn off the TRT-RTX native CUDA graph capture on this
// engine so that an outer stream capture (detected by execute_engine) can proceed
// without colliding with the engine's internal capture/replay. No-op on non-RTX builds.
void TRTEngine::disable_rtx_native_cudagraphs() {
#ifdef TRT_MAJOR_RTX
  // Nothing to do if we already disabled it, or if native capture was never enabled.
  if (rtx_native_cudagraphs_disabled || cuda_graph_strategy == 0) {
    return;
  }
  LOG_WARNING(
      "Outer CUDA stream capture detected; disabling TRT-RTX native CUDA graph strategy on engine "
      << name << " for the remainder of its lifetime.");
  // Clear the strategy, push it down to the runtime config, and rebuild the execution
  // context so subsequent enqueueV3 calls launch kernels directly on the stream.
  cuda_graph_strategy = 0;
  apply_cuda_graph_strategy();
  recreate_execution_context();
  rtx_native_cudagraphs_disabled = true;
#endif
}
581+
555582
void TRTEngine::recreate_execution_context() {
556583
#ifdef TRT_MAJOR_RTX
557584
if (!runtime_config) {
@@ -605,7 +632,12 @@ void TRTEngine::apply_dynamic_shapes_kernel_strategy() {
605632
}
606633

607634
// Push the configured CUDA graph strategy down to the TRT-RTX runtime config.
// cuda_graph_strategy: 0 = disabled (default), 1 = whole-graph capture handled
// internally by the TRT-RTX runtime inside enqueueV3.
void TRTEngine::apply_cuda_graph_strategy() {
#ifdef TRT_MAJOR_RTX
  // runtime_config is only declared on RTX builds and may not have been created yet
  // (recreate_execution_context() guards on it the same way); bail out rather than
  // dereference a null shared_ptr.
  if (!runtime_config) {
    LOG_WARNING("No runtime config available; CUDA graph strategy not applied.");
    return;
  }
  bool ok = runtime_config->setCudaGraphStrategy(
      cuda_graph_strategy == 1 ? nvinfer1::CudaGraphStrategy::kWHOLE_GRAPH_CAPTURE
                               : nvinfer1::CudaGraphStrategy::kDISABLED);
  if (!ok) {
    LOG_WARNING("Failed to set CUDA graph strategy; continuing with default.");
  }
#endif
}
610642

611643
void TRTEngine::load_runtime_cache() {

core/runtime/TRTEngine.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,12 +233,24 @@ struct TRTEngine : torch::CustomClassHolder {
233233
std::string runtime_cache_path = "";
234234
int dynamic_shapes_kernel_strategy = 0; // 0=lazy, 1=eager, 2=none
235235
int cuda_graph_strategy = 0; // 0=disabled, 1=whole_graph_capture
236+
// One-shot flag: set the first time execute_engine detects an outer stream capture around
237+
// this engine, at which point its TRT-RTX native CUDA graph capture is turned off so the
238+
// two do not fight. The flag stays set for the remainder of the engine's lifetime.
239+
bool rtx_native_cudagraphs_disabled = false;
236240

237241
#ifdef TRT_MAJOR_RTX
238242
std::shared_ptr<nvinfer1::IRuntimeConfig> runtime_config;
239243
std::shared_ptr<nvinfer1::IRuntimeCache> runtime_cache;
240244
#endif
241245

246+
// Monolithic-capturability check used when this engine is wrapped by an outer whole-graph
247+
// capture (e.g. CudaGraphsTorchTensorRTModule). Non-RTX builds always return true.
248+
bool is_monolithic_capturable(cudaStream_t stream) const;
249+
250+
// Disable TRT-RTX native CUDA graph capture on this engine (one-shot, invoked when an
251+
// outer stream capture is detected around execute_engine). No-op on non-RTX.
252+
void disable_rtx_native_cudagraphs();
253+
242254
private:
243255
// Single entry point that (re)creates exec_ctx. On RTX builds this also creates / reuses
244256
// the IRuntimeConfig and applies all runtime config settings.

core/runtime/execute_engine.cpp

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -217,11 +217,29 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
217217

218218
auto run_standard_execution = [&]() {
219219
bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);
220+
// effective_cudagraphs controls the manual at::cuda::CUDAGraph path below. On TRT-RTX
221+
// builds we bypass that path whenever the engine has a cuda_graph_strategy set or the
222+
// outer runtime has requested subgraph cudagraphs - the TRT-RTX runtime handles capture
223+
// and replay internally inside enqueueV3. If an outer stream capture is already in
224+
// progress (e.g. the caller wraps this module in CudaGraphsTorchTensorRTModule for
225+
// whole-graph capture), RTX-native capture would conflict, so we disable it one-shot.
226+
bool effective_cudagraphs = cudagraphs_enabled;
227+
#ifdef TRT_MAJOR_RTX
228+
if (compiled_engine->cuda_graph_strategy != 0 || cudagraphs_enabled) {
229+
effective_cudagraphs = false;
230+
cudaStreamCaptureStatus capture_status;
231+
cudaStreamIsCapturing(compiled_engine->engine_stream.stream(), &capture_status);
232+
if (capture_status != cudaStreamCaptureStatusNone) {
233+
compiled_engine->disable_rtx_native_cudagraphs();
234+
}
235+
}
236+
#endif
237+
220238
bool shape_changed = _validate_shapes(inputs, compiled_engine);
221239

222240
// Whether cudagraphs needs to record the graph on this pass
223241
auto result = compiled_engine->runtime_states.set_runtime_states(
224-
cudagraphs_enabled, compiled_engine->use_pre_allocated_outputs, shape_changed);
242+
effective_cudagraphs, compiled_engine->use_pre_allocated_outputs, shape_changed);
225243

226244
bool need_cudagraphs_record = std::get<0>(result);
227245
bool can_use_pre_allocated_outputs = std::get<1>(result);
@@ -244,7 +262,8 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
244262
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
245263
}
246264

247-
setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record, inputShapeTensorValues);
265+
setup_input_tensors(
266+
inputs, compiled_engine, effective_cudagraphs, need_cudagraphs_record, inputShapeTensorValues);
248267
// Check if input shapes can be inferred.
249268
int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
250269
std::vector<char const*> names(io_size);
@@ -276,7 +295,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
276295
compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
277296
}
278297

279-
if (cudagraphs_enabled) {
298+
if (effective_cudagraphs) {
280299
TORCHTRT_CHECK(
281300
compiled_engine->exec_ctx->setTensorAddress(
282301
name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()),
@@ -316,8 +335,10 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
316335
caller_exec_complete.record(compiled_engine->caller_stream);
317336
caller_exec_complete.block(compiled_engine->engine_stream);
318337

319-
if (!cudagraphs_enabled) {
320-
// Direct execution uses the caller buffers directly
338+
if (!effective_cudagraphs) {
339+
// Direct execution uses the caller buffers directly. On TRT-RTX with a
340+
// cuda_graph_strategy set, the engine captures/replays internally during
341+
// this enqueueV3 call.
321342
compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream);
322343
} else {
323344
if (need_cudagraphs_record) {
@@ -350,7 +371,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
350371
trt_exec_complete.record(compiled_engine->engine_stream);
351372
trt_exec_complete.block(compiled_engine->caller_stream);
352373

353-
if (cudagraphs_enabled) {
374+
if (effective_cudagraphs) {
354375
// If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream)
355376
for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) {
356377
outputs[o].copy_(compiled_engine->output_buffers[o], false);
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import unittest
2+
3+
import torch
4+
import torch_tensorrt as torchtrt
5+
from torch.testing._internal.common_utils import TestCase, run_tests
6+
from torch_tensorrt._features import ENABLED_FEATURES
7+
from torch_tensorrt.dynamo._defaults import CUDA_GRAPH_STRATEGY
8+
from torch_tensorrt.dynamo._settings import CompilationSettings
9+
10+
11+
class CudaGraphModel(torch.nn.Module):
    """Minimal conv + relu model used to exercise the CUDA graph strategy paths."""

    def __init__(self):
        super().__init__()
        # 3 -> 8 channels, 3x3 kernel; padding=1 preserves spatial dimensions.
        self.conv = torch.nn.Conv2d(3, 8, 3, padding=1)

    def forward(self, x):
        features = self.conv(x)
        return torch.relu(features)
18+
19+
20+
def _compile_cpp(strategy):
    """Compile CudaGraphModel through the dynamo IR / C++ runtime with `strategy`.

    Returns the compiled module together with the sample inputs used to trace it.
    """
    module = CudaGraphModel().eval().cuda()
    sample_inputs = [torch.randn(2, 3, 16, 16).cuda()]
    trt_module = torchtrt.compile(
        module,
        ir="dynamo",
        inputs=sample_inputs,
        enabled_precisions={torch.float32},
        use_python_runtime=False,
        min_block_size=1,
        cuda_graph_strategy=strategy,
    )
    # Reset dynamo caches so later compiles in the same process start from scratch.
    torch._dynamo.reset()
    return trt_module, sample_inputs
34+
35+
36+
class TestCudaGraphStrategySettings(TestCase):
    """Setting-level validation that runs on every build (RTX and non-RTX)."""

    def test_default_value(self):
        # The default in CompilationSettings must track the declared constant.
        self.assertEqual(CompilationSettings().cuda_graph_strategy, CUDA_GRAPH_STRATEGY)

    def test_settable_values(self):
        for strategy in ("disabled", "whole_graph_capture"):
            configured = CompilationSettings(cuda_graph_strategy=strategy)
            self.assertEqual(configured.cuda_graph_strategy, strategy)
47+
48+
49+
@unittest.skipIf(
    not ENABLED_FEATURES.torch_tensorrt_runtime,
    "C++ runtime is not available",
)
@unittest.skipIf(
    not ENABLED_FEATURES.tensorrt_rtx,
    "CUDA graph strategy is a TensorRT-RTX feature",
)
class TestCudaGraphStrategyCpp(TestCase):
    """End-to-end: compile + infer through the C++ runtime with each strategy."""

    def tearDown(self):
        # Never let cudagraphs mode leak into other tests in the suite.
        torchtrt.runtime.set_cudagraphs_mode(False)

    @staticmethod
    def _infer(compiled, inputs):
        # Clone inputs so repeated calls never alias the tracing tensors.
        return compiled(*[tensor.clone() for tensor in inputs])

    def _check_output(self, result):
        self.assertEqual(tuple(result.shape), (2, 8, 16, 16))
        self.assertTrue(torch.isfinite(result).all().item())

    def test_disabled(self):
        compiled, inputs = _compile_cpp("disabled")
        self._check_output(self._infer(compiled, inputs))

    def test_whole_graph_capture(self):
        compiled, inputs = _compile_cpp("whole_graph_capture")
        self._check_output(self._infer(compiled, inputs))

    def test_whole_graph_capture_with_subgraph_cudagraphs(self):
        """Subgraph cudagraph mode + RTX strategy: RTX-native should take over without errors."""
        compiled, inputs = _compile_cpp("whole_graph_capture")
        torchtrt.runtime.set_cudagraphs_mode(True)
        self._check_output(self._infer(compiled, inputs))

    def test_repeated_inference(self):
        """Repeated inference exercises the RTX-native capture/replay path."""
        compiled, inputs = _compile_cpp("whole_graph_capture")
        reference = self._infer(compiled, inputs)
        for _ in range(4):
            result = self._infer(compiled, inputs)
            self.assertEqual(result.shape, reference.shape)
            self.assertTrue(torch.isfinite(result).all().item())
91+
92+
93+
@unittest.skipIf(
    not ENABLED_FEATURES.torch_tensorrt_runtime,
    "C++ runtime is not available",
)
class TestCudaGraphStrategyInvalidValue(TestCase):
    """Invalid strategy names are rejected at engine-packing time."""

    def test_invalid_strategy_raises(self):
        module = CudaGraphModel().eval().cuda()
        sample_inputs = [torch.randn(2, 3, 16, 16).cuda()]
        # Either a ValueError (Python-side validation) or RuntimeError (C++ side)
        # is acceptable; what matters is that compilation does not silently succeed.
        with self.assertRaises((ValueError, RuntimeError)):
            torchtrt.compile(
                module,
                ir="dynamo",
                inputs=sample_inputs,
                enabled_precisions={torch.float32},
                use_python_runtime=False,
                min_block_size=1,
                cuda_graph_strategy="not_a_real_strategy",
            )
113+
114+
115+
if __name__ == "__main__":
    # Support direct invocation: python test_000_cuda_graph_strategy.py
    run_tests()

0 commit comments

Comments
 (0)