Skip to content

Commit 4556b66

Browse files
committed
TensorRT-RTX 1.2 Release
1 parent 49ece6a commit 4556b66

File tree

9 files changed

+817
-792
lines changed

9 files changed

+817
-792
lines changed

demo/flux1.dev/flux_demo.ipynb

Lines changed: 746 additions & 759 deletions
Large diffs are not rendered by default.

demo/flux1.dev/flux_demo.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@ def main():
6262
choices=["bf16", "fp8", "fp4"],
6363
)
6464
parser.add_argument("--enable-runtime-cache", action="store_true", help="Enable runtime caching")
65+
parser.add_argument(
66+
"--cuda-graph-strategy",
67+
type=str,
68+
default="disabled",
69+
help="Cuda graph strategy (default: disabled)",
70+
choices=["disabled", "whole_graph_capture"],
71+
)
6572
parser.add_argument("--low-vram", action="store_true", help="Enable low VRAM mode")
6673
parser.add_argument("--dynamic-shape", action="store_true", default=False, help="Enable dynamic-shape engines")
6774
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
@@ -83,6 +90,7 @@ def main():
8390
num_inference_steps=args.num_inference_steps,
8491
hf_token=args.hf_token,
8592
low_vram=args.low_vram,
93+
cuda_graph_strategy=args.cuda_graph_strategy,
8694
enable_runtime_cache=args.enable_runtime_cache,
8795
)
8896

@@ -99,6 +107,7 @@ def main():
99107
logger.info(f"Guidance scale: {args.guidance_scale}")
100108
logger.info(f"Cache directory: {args.cache_dir}")
101109
logger.info(f"Low VRAM mode: {args.low_vram}")
110+
logger.info(f"Cudagraphs: {args.cuda_graph_strategy}")
102111
logger.info(f"Dynamic shape: {args.dynamic_shape}")
103112
logger.info(f"Runtime caching: {args.enable_runtime_cache}")
104113
logger.info(f"Cache mode: {args.cache_mode}")

demo/flux1.dev/pipelines/flux_pipeline.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(
6666
low_vram: bool = False,
6767
log_level: str = "INFO",
6868
enable_runtime_cache: bool = False,
69+
cuda_graph_strategy: str = "disabled",
6970
):
7071
super().__init__(
7172
pipeline_name="flux_1_dev",
@@ -77,6 +78,7 @@ def __init__(
7778
low_vram=low_vram,
7879
log_level=log_level,
7980
enable_runtime_cache=enable_runtime_cache,
81+
cuda_graph_strategy=cuda_graph_strategy,
8082
)
8183

8284
# Flux-specific parameters
@@ -250,7 +252,7 @@ def build_and_load_engine(
250252
)
251253

252254
if is_compatible:
253-
engine = Engine(engine_path, precision, model_id, self.runtime_cache_path)
255+
engine = Engine(engine_path, precision, model_id, self.runtime_cache_path, self.cuda_graph_strategy)
254256
try:
255257
if not self.low_vram:
256258
engine.load()
@@ -285,7 +287,7 @@ def build_and_load_engine(
285287
)
286288

287289
logger.debug(f"Building engine for path {engine_path}")
288-
engine = Engine(engine_path, precision, model_id, self.runtime_cache_path)
290+
engine = Engine(engine_path, precision, model_id, self.runtime_cache_path, self.cuda_graph_strategy)
289291
engine.build(
290292
onnx_path=str(onnx_path),
291293
input_profile=input_profile,

demo/tests/test_license_headers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ def find_files_by_pattern(cls, root_path, patterns):
5454
# Directories to exclude from license header checks (only within the repository)
5555
exclude_dirs = {
5656
"build",
57-
".venv",
5857
}
5958

6059
files = []

demo/utils/engine.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -82,17 +82,18 @@ def __init__(
8282
precision: str,
8383
model_name: str,
8484
runtime_cache_path: Optional[str] = None,
85+
cuda_graph_strategy: str = "disabled",
8586
):
8687
self.engine_path = engine_path
8788
self.engine = None
8889
self.context = None
8990
self.tensors = OrderedDict()
90-
self.cuda_graph_instance = None
9191
self.precision = precision
9292
self.model_name = model_name
9393
self.runtime_config = None
9494
self.runtime_cache = None
9595
self.runtime_cache_path = runtime_cache_path
96+
self.cuda_graph_strategy = cuda_graph_strategy
9697

9798
def __del__(self):
9899
del self.tensors
@@ -154,7 +155,7 @@ def build(
154155
)
155156

156157
# Build command with arguments
157-
build_command = [f"polygraphy convert {onnx_path} --convert-to trt --output {self.engine_path}"]
158+
build_command = [f"polygraphy convert {onnx_path} --convert-to trt --use-gpu --output {self.engine_path}"]
158159

159160
build_args = []
160161
verbosity = "extra_verbose" if verbose else "error"
@@ -254,6 +255,10 @@ def activate(self, device_memory: Optional[int] = None, defer_memory_allocation:
254255
"""Create execution context"""
255256

256257
self.runtime_config = self.engine.create_runtime_config()
258+
259+
if self.cuda_graph_strategy == "whole_graph_capture":
260+
self.runtime_config.cuda_graph_strategy = trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE
261+
257262
if self.runtime_cache_path:
258263
if self.runtime_cache is None:
259264
logger.debug("Creating runtime cache")
@@ -383,7 +388,7 @@ def deallocate_buffers(self):
383388
gc.collect()
384389
torch.cuda.empty_cache()
385390

386-
def infer(self, feed_dict: dict[str, Any], stream: torch.cuda.Stream, use_cuda_graph: bool = False):
391+
def infer(self, feed_dict: dict[str, Any], stream: torch.cuda.Stream):
387392
"""Run inference with the engine"""
388393
# Copy input data to tensors
389394
for name, buf in feed_dict.items():
@@ -394,26 +399,8 @@ def infer(self, feed_dict: dict[str, Any], stream: torch.cuda.Stream, use_cuda_g
394399
self.context.set_tensor_address(name, tensor.data_ptr())
395400

396401
# Execute inference
397-
if use_cuda_graph:
398-
if self.cuda_graph_instance is not None:
399-
_CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream))
400-
_CUASSERT(cudart.cudaStreamSynchronize(stream))
401-
else:
402-
# Initial inference before CUDA graph capture
403-
noerror = self.context.execute_async_v3(stream)
404-
if not noerror:
405-
raise ValueError(f"ERROR: Inference with {self.engine_path} failed.")
406-
407-
# Capture CUDA graph
408-
_CUASSERT(
409-
cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal)
410-
)
411-
self.context.execute_async_v3(stream)
412-
self.graph = _CUASSERT(cudart.cudaStreamEndCapture(stream))
413-
self.cuda_graph_instance = _CUASSERT(cudart.cudaGraphInstantiate(self.graph, 0))
414-
else:
415-
noerror = self.context.execute_async_v3(stream)
416-
if not noerror:
417-
raise ValueError(f"ERROR: Inference with {self.engine_path} failed.")
402+
noerror = self.context.execute_async_v3(stream)
403+
if not noerror:
404+
raise ValueError(f"ERROR: Inference with {self.engine_path} failed.")
418405

419406
return self.tensors

demo/utils/pipeline.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(
6060
low_vram: bool = False,
6161
log_level: str = "INFO",
6262
enable_runtime_cache: bool = False,
63+
cuda_graph_strategy: str = "disabled",
6364
):
6465
"""
6566
Initialize pipeline.
@@ -75,6 +76,7 @@ def __init__(
7576
low_vram: Enable low VRAM mode
7677
log_level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
7778
enable_runtime_cache: Enable use of serialized runtime cache to improve JIT compilation times
79+
cuda_graph_strategy: Enable use of Cudagraphs for accelerated inference (disabled, whole_graph_capture)
7880
"""
7981
# Configure logging FIRST, before any other operations
8082
self.configure_logging(verbose, log_level)
@@ -89,6 +91,13 @@ def __init__(
8991
self.verbose = verbose
9092
self.hf_token = hf_token
9193
self.low_vram = low_vram
94+
self.enable_runtime_cache = enable_runtime_cache
95+
96+
assert cuda_graph_strategy in ["disabled", "whole_graph_capture"], (
97+
"Invalid cuda graph strategy {cuda_graph_strategy}, must be either 'disabled' or 'whole_graph_capture'"
98+
)
99+
logger.debug(f"Cuda graph strategy: {cuda_graph_strategy}")
100+
self.cuda_graph_strategy = cuda_graph_strategy
92101

93102
if enable_runtime_cache:
94103
self.runtime_cache_path = os.path.join(cache_dir, "runtime.cache")
@@ -276,7 +285,7 @@ def calculate_max_device_memory(self) -> int:
276285
def run_engine(self, model_name: str, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
277286
"""Run inference on a specific engine"""
278287
engine = self.engines[model_name]
279-
return engine.infer(inputs, self.stream, use_cuda_graph=False)
288+
return engine.infer(inputs, self.stream)
280289

281290
def infer(self, *args, **kwargs):
282291
"""Run the full pipeline inference - to be implemented by subclasses"""

samples/apiUsage/cpp/apiUsage.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,24 @@ int main()
564564
}
565565
useOptionalAdvancedDynamicShapesAPI(runtimeConfig.get(), inferenceEngine.get());
566566

567+
// Enable Cudagraphs Whole Graph Capture for accelerated inference
568+
{
569+
// TensorRT-RTX can record CUDA graphs to reduce kernel launch overhead during JIT inference.
570+
// kDISABLED skips graph capture and runs kernels directly on the stream
571+
// kWHOLE_GRAPH_CAPTURE captures the complete computational graph of the model
572+
// and executes it atomically on the GPU stream. It automatically handles dynamic shape
573+
// cases, capturing the CUDA graph after shape-specialized kernels are compiled for a given shape.
574+
bool const setCudaGraphStrategySuccess
575+
= runtimeConfig->setCudaGraphStrategy(nvinfer1::CudaGraphStrategy::kWHOLE_GRAPH_CAPTURE);
576+
if (!setCudaGraphStrategySuccess)
577+
{
578+
std::cerr << "Failed to set cuda graph strategy!" << std::endl;
579+
return EXIT_FAILURE;
580+
}
581+
// Query API to illustrate retrieval.
582+
(void) runtimeConfig->getCudaGraphStrategy();
583+
}
584+
567585
// Create an engine execution context out of the deserialized engine.
568586
// TRT-RTX performs "Just-in-Time" (JIT) optimization here, targeting the current GPU.
569587
// JIT phase is faster than AOT phase, and typically completes in under 15 seconds.

samples/apiUsage/python/api_usage.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,14 @@ def run_inference(serialized_engine: trt.IHostMemory, fc1_weights: trt.Weights,
364364

365365
use_optional_advanced_dynamic_shapes_api(runtime_config, inference_engine)
366366

367+
# Enable Cudagraphs Whole Graph Capture for accelerated inference
368+
# TensorRT-RTX can record CUDA graphs to reduce kernel launch overhead during JIT inference.
369+
# DISABLED skips graph capture and runs kernels directly on the stream
370+
# WHOLE_GRAPH_CAPTURE captures the complete computational graph of the model
371+
# and executes it atomically on the GPU stream. It automatically handles dynamic shape
372+
# cases, capturing the CUDA graph after shape-specialized kernels are compiled for a given shape.
373+
runtime_config.cuda_graph_strategy = trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE
374+
367375
# Create an engine execution context out of the deserialized engine.
368376
# TRT-RTX performs "Just-in-Time" (JIT) optimization here, targeting the current GPU.
369377
# JIT phase is faster than AOT phase, and typically completes in under 15 seconds.

samples/cmake/modules/get_version.cmake

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,20 @@ function(get_version include_dir version_variable soversion_variable)
4343
endif()
4444

4545
foreach(type MAJOR MINOR PATCH)
46-
string(REGEX MATCH "TRT_${type}_RTX [0-9]+" TRT_TYPE_STRING ${VERSION_STRINGS})
47-
string(REGEX MATCH "[0-9]+" TRT_${type} ${TRT_TYPE_STRING})
48-
if(NOT TRT_${type})
46+
set(trt_${type} "")
47+
foreach(version_line ${VERSION_STRINGS})
48+
string(REGEX MATCH "TRT_${type}_RTX [0-9]+" trt_type_string "${version_line}")
49+
if(trt_type_string)
50+
string(REGEX MATCH "[0-9]+" trt_${type} "${trt_type_string}")
51+
break()
52+
endif()
53+
endforeach()
54+
if(NOT DEFINED trt_${type})
4955
message(FATAL_ERROR "Failed to extract TRT_${type}_RTX from ${header_file}")
5056
endif()
5157
endforeach(type)
52-
set(${version_variable} ${TRT_MAJOR}.${TRT_MINOR}.${TRT_PATCH} PARENT_SCOPE)
53-
set(${soversion_variable} ${TRT_MAJOR}_${TRT_MINOR} PARENT_SCOPE)
58+
set(${version_variable} ${trt_MAJOR}.${trt_MINOR}.${trt_PATCH} PARENT_SCOPE)
59+
set(${soversion_variable} ${trt_MAJOR}_${trt_MINOR} PARENT_SCOPE)
5460
endfunction()
5561

5662
# -----------------------------------------------------------------------------

0 commit comments

Comments
 (0)