
Commit 54f9ccd

refactor(runtime): extract TRTRuntimeConfig, address PR review
Address the structural PR feedback by extracting TensorRT-RTX-specific IRuntimeConfig state into its own type and collapsing the per-feature appliers that previously scattered `#ifdef TRT_MAJOR_RTX` through TRTEngine.

What

- New core/runtime/TRTRuntimeConfig.{h,cpp} owns the IRuntimeConfig shared_ptr plus (on TRT-RTX) the IRuntimeCache, runtime-cache path, dynamic shapes kernel strategy, CUDA graph strategy, and the rtx_native_cudagraphs_disabled one-shot flag. All per-feature appliers live there as public members and are no-ops on non-RTX builds, keeping the only `#ifdef TRT_MAJOR_RTX` scatter contained in this new file.
- Strategy fields are now strongly-typed enums (`DynamicShapesKernelStrategy`, `CudaGraphStrategyOption`) with matching `to_string`/`to_int` helpers, validated at engine construction via `to_dynamic_shapes_kernel_strategy` / `to_cuda_graph_strategy_option` rather than raw int ranges.
- `TRTEngine::recreate_execution_context` is now backend-agnostic: it calls `runtime_cfg.ensure_initialized`, applies the allocation strategy, and creates the execution context via `createExecutionContext(IRuntimeConfig*)`. Both standard TensorRT and TRT-RTX go through this uniform path; only the three RTX-only setters (`setRuntimeCache`, `setDynamicShapesKernelSpecializationStrategy`, `setCudaGraphStrategy`) stay behind an `#ifdef TRT_MAJOR_RTX` guard inside the struct.
- `~TRTEngine` now wraps cleanup in try/catch and delegates cache persistence to `TRTRuntimeConfig::save_runtime_cache_nothrow`, so stack unwinding can no longer propagate a cache-save failure out of the destructor.
- `save_runtime_cache_nothrow` uses `std::filesystem` + atomic `tmp+rename` only; file locking is out of scope for this PR and will be introduced in a follow-up once we pick a portable mechanism.
- `is_monolithic_capturable` asserts `exec_ctx` is non-null; the three RTX-only appliers `TORCHTRT_ASSERT` that `config` is live before dereferencing.
- `disable_rtx_native_cudagraphs` persists the runtime cache before flipping the strategy so any kernels compiled under the internal capture survive to the next reload.
- `TRTEngine::to_str` now emits human-readable strategy names (via `to_string(enum)`) instead of integer codes.
- New serialization indices (`RUNTIME_CACHE_PATH_IDX`, `DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX`, `CUDA_GRAPH_STRATEGY_IDX`) are now `#ifdef TRT_MAJOR_RTX`-gated in runtime.h, register_jit_hooks.cpp, the FlattenedState tuple, the serialize/deserialize constructors, and `__obj_flatten__`. Standard TRT builds keep `SERIALIZATION_LEN == 11`, so engines serialized there do not carry RTX-only slots.
- Python `_TorchTensorRTModule` reads the RTX-only index accessors and writes the RTX-only engine-info slots only when `ENABLED_FEATURES.tensorrt_rtx` is true. Standard TRT users see no new behavior at runtime.
- Deduplicated `_compiler.py` arguments after rebase on upstream main, where PR pytorch#4184 had already added `dynamic_shapes_kernel_specialization_strategy`. Kept one copy of each arg; `cuda_graph_strategy` is threaded through all three compile() entry points.

Build + tests

- RTX build on A100 / L40S: libtorchtrt.so and libtorchtrt_runtime.so link clean, no `#ifdef` diagnostics. Pre-commit checks pass (clang-format, black, isort, ruff, mypy, typos, buildifier).
- All 35 runtime-cache/strategy tests pass; regression across test_000_runtime_cache.py (Python runtime), test_002_cudagraphs_cpp.py, and test_005_dynamic_allocation.py is green.

Addresses review comments on PR pytorch#4202:

- Guarding of new IDX entries and Python accessors on TRT_MAJOR_RTX / ENABLED_FEATURES.tensorrt_rtx.
- Encapsulation of RTX-specific state in a dedicated type with enumerated strategies and transparent standard-TRT/RTX behavior.
- Destructor exception safety.
- Unification of the execution-context creation path via IRuntimeConfig.
- Removal of file locking for runtime-cache persistence.
- Debug asserts before dereferencing the live IRuntimeConfig.
- Human-readable to_str output.
- save_runtime_cache invoked from disable_rtx_native_cudagraphs.
1 parent 2b630e8 commit 54f9ccd

12 files changed

Lines changed: 434 additions & 278 deletions

File tree

core/runtime/BUILD

Lines changed: 5 additions & 1 deletion
@@ -1,6 +1,7 @@
 load("@rules_cc//cc:defs.bzl", "cc_library")
 load("@rules_pkg//:pkg.bzl", "pkg_tar")
 load("@rules_pkg//pkg:mappings.bzl", "pkg_files")
+
 package(default_visibility = ["//visibility:public"])

 config_setting(
@@ -66,6 +67,7 @@ cc_library(
         "RTDevice.cpp",
         "TRTEngine.cpp",
         "TRTEngineProfiler.cpp",
+        "TRTRuntimeConfig.cpp",
         "execute_engine.cpp",
         "register_jit_hooks.cpp",
         "runtime.cpp",
@@ -75,6 +77,7 @@ cc_library(
         "RTDevice.h",
         "TRTEngine.h",
         "TRTEngineProfiler.h",
+        "TRTRuntimeConfig.h",
         "runtime.h",
     ],
     linkopts = [
@@ -107,6 +110,7 @@ filegroup(
         "RTDevice.h",
         "TRTEngine.h",
         "TRTEngineProfiler.h",
+        "TRTRuntimeConfig.h",
         "runtime.h",
     ],
     visibility = ["//visibility:public"],
@@ -121,6 +125,6 @@ pkg_tar(
 pkg_files(
     name = "include_pkg_files",
     srcs = [":include_files"],
-    visibility = ["//visibility:public"],
     prefix = "include/torch_tensorrt/core/runtime/",
+    visibility = ["//visibility:public"],
 )

core/runtime/TRTEngine.cpp

Lines changed: 40 additions & 199 deletions
@@ -1,6 +1,5 @@
 #include <algorithm>
 #include <filesystem>
-#include <fstream>

 #include <cuda_runtime.h>
 #include "NvInfer.h"
@@ -12,12 +11,6 @@
 #include "core/util/prelude.h"
 #include "torch/torch.h"

-#if defined(TRT_MAJOR_RTX) && !defined(_WIN32)
-#include <fcntl.h>
-#include <sys/file.h>
-#include <unistd.h>
-#endif
-
 namespace torch_tensorrt {
 namespace core {
 namespace runtime {
@@ -102,10 +95,15 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
           serialized_info[SERIALIZED_METADATA_IDX],
           (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]))
                ? ResourceAllocationStrategy::kDynamic
-               : ResourceAllocationStrategy::kStatic),
+               : ResourceAllocationStrategy::kStatic)
+#ifdef TRT_MAJOR_RTX
+          ,
           serialized_info[RUNTIME_CACHE_PATH_IDX],
           std::stoi(serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX]),
-          std::stoi(serialized_info[CUDA_GRAPH_STRATEGY_IDX])) {}
+          std::stoi(serialized_info[CUDA_GRAPH_STRATEGY_IDX])
+#endif
+          ) {
+}

 TRTEngine::TRTEngine(
     const std::string& mod_name,
@@ -121,16 +119,9 @@ TRTEngine::TRTEngine(
     const std::string& runtime_cache_path,
     int dynamic_shapes_kernel_strategy,
     int cuda_graph_strategy) {
-  this->runtime_cache_path = runtime_cache_path;
-  TORCHTRT_CHECK(
-      dynamic_shapes_kernel_strategy >= 0 && dynamic_shapes_kernel_strategy <= 2,
-      "Invalid dynamic_shapes_kernel_strategy: " << dynamic_shapes_kernel_strategy
-          << ". Expected 0 (lazy), 1 (eager), or 2 (none).");
-  this->dynamic_shapes_kernel_strategy = dynamic_shapes_kernel_strategy;
-  TORCHTRT_CHECK(
-      cuda_graph_strategy >= 0 && cuda_graph_strategy <= 1,
-      "Invalid cuda_graph_strategy: " << cuda_graph_strategy << ". Expected 0 (disabled) or 1 (whole_graph_capture).");
-  this->cuda_graph_strategy = cuda_graph_strategy;
+  runtime_cfg.runtime_cache_path = runtime_cache_path;
+  runtime_cfg.dynamic_shapes_kernel_strategy = to_dynamic_shapes_kernel_strategy(dynamic_shapes_kernel_strategy);
+  runtime_cfg.cuda_graph_strategy = to_cuda_graph_strategy_option(cuda_graph_strategy);
   TORCHTRT_CHECK(
       is_supported_on_current_platform(target_platform),
       "This engine was not built to run on this platform (built for: " << target_platform << ", current platform: "
@@ -288,12 +279,13 @@ TRTEngine::TRTEngine(
 }

 TRTEngine::~TRTEngine() {
-  torch::cuda::synchronize(device_info.id);
-#ifdef TRT_MAJOR_RTX
-  save_runtime_cache();
-  runtime_cache.reset();
-  runtime_config.reset();
-#endif
+  // Destructors must not throw; `save_runtime_cache_nothrow` is itself no-throw but we
+  // wrap it defensively to keep stack unwinding safe in all circumstances.
+  try {
+    torch::cuda::synchronize(device_info.id);
+    runtime_cfg.save_runtime_cache_nothrow();
+  } catch (...) {
+  }
   trt_engine_profiler.reset();
   exec_ctx.reset();
   cuda_engine.reset();
@@ -453,12 +445,8 @@ std::string TRTEngine::to_str() const {
   ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
   ss << " Target Platform: " << target_platform << std::endl;
   ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl;
-  ss << " Runtime Cache Path: " << (runtime_cache_path.empty() ? "<disabled>" : runtime_cache_path) << std::endl;
-  ss << " Dynamic Shapes Kernel Strategy: " << dynamic_shapes_kernel_strategy
-     << " (0=lazy, 1=eager, 2=none)" << std::endl;
-  ss << " CUDA Graph Strategy: " << cuda_graph_strategy
-     << " (0=disabled, 1=whole_graph_capture)" << std::endl;
   // clang-format on
+  runtime_cfg.write_to_str(ss);
   return ss.str();
 }

@@ -502,10 +490,14 @@ FlattenedState TRTEngine::__obj_flatten__() {
       std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]),
       std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]),
       std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]),
-      std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]),
+      std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])
+#ifdef TRT_MAJOR_RTX
+          ,
       std::tuple("runtime_cache_path", serialized_info[RUNTIME_CACHE_PATH_IDX]),
       std::tuple("dynamic_shapes_kernel_strategy", serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX]),
-      std::tuple("cuda_graph_strategy", serialized_info[CUDA_GRAPH_STRATEGY_IDX]));
+      std::tuple("cuda_graph_strategy", serialized_info[CUDA_GRAPH_STRATEGY_IDX])
+#endif
+          );
 }

 std::vector<std::string> TRTEngine::serialize() {
@@ -530,9 +522,12 @@ std::vector<std::string> TRTEngine::serialize() {
   serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();
   serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] =
       this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0";
-  serialized_info[RUNTIME_CACHE_PATH_IDX] = this->runtime_cache_path;
-  serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = std::to_string(this->dynamic_shapes_kernel_strategy);
-  serialized_info[CUDA_GRAPH_STRATEGY_IDX] = std::to_string(this->cuda_graph_strategy);
+#ifdef TRT_MAJOR_RTX
+  serialized_info[RUNTIME_CACHE_PATH_IDX] = runtime_cfg.runtime_cache_path;
+  serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] =
+      std::to_string(static_cast<int>(runtime_cfg.dynamic_shapes_kernel_strategy));
+  serialized_info[CUDA_GRAPH_STRATEGY_IDX] = std::to_string(static_cast<int>(runtime_cfg.cuda_graph_strategy));
+#endif

   return serialized_info;
 }
@@ -553,183 +548,29 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt
 }

 bool TRTEngine::is_monolithic_capturable(cudaStream_t stream) const {
-#if defined(TRT_MAJOR_RTX) && defined(ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION)
-  // "lazy" strategy (0) swaps specialized kernels in mid-run, which would invalidate a
-  // captured graph. Any other strategy (eager/none) combined with a capturable stream is
-  // safe for outer monolithic capture.
-  return exec_ctx->isStreamCapturable(stream) && dynamic_shapes_kernel_strategy != 0;
-#else
-  (void)stream;
-  return true;
-#endif
+  return runtime_cfg.is_monolithic_capturable(exec_ctx.get(), stream);
 }

 void TRTEngine::disable_rtx_native_cudagraphs() {
-#ifdef TRT_MAJOR_RTX
-  if (rtx_native_cudagraphs_disabled || cuda_graph_strategy == 0) {
-    return;
+  bool was_disabled = runtime_cfg.rtx_native_cudagraphs_disabled;
+  runtime_cfg.disable_rtx_native_cudagraphs(name);
+  if (!was_disabled && runtime_cfg.rtx_native_cudagraphs_disabled) {
+    // The CUDA graph strategy on the IRuntimeConfig has been flipped; rebuild exec_ctx
+    // so the new strategy takes effect for subsequent enqueueV3 calls.
+    recreate_execution_context();
   }
-  LOG_WARNING(
-      "Outer CUDA stream capture detected; disabling TRT-RTX native CUDA graph strategy on engine "
-      << name << " for the remainder of its lifetime.");
-  cuda_graph_strategy = 0;
-  apply_cuda_graph_strategy();
-  recreate_execution_context();
-  rtx_native_cudagraphs_disabled = true;
-#endif
 }

 void TRTEngine::recreate_execution_context() {
-#ifdef TRT_MAJOR_RTX
-  if (!runtime_config) {
-    runtime_config = make_trt(cuda_engine->createRuntimeConfig());
-    TORCHTRT_CHECK(runtime_config.get() != nullptr, "Unable to create TensorRT IRuntimeConfig");
-    apply_runtime_cache();
-    apply_dynamic_shapes_kernel_strategy();
-    apply_cuda_graph_strategy();
-  }
-  runtime_config->setExecutionContextAllocationStrategy(
+  runtime_cfg.ensure_initialized(cuda_engine.get());
+  runtime_cfg.set_execution_context_allocation_strategy(
       resource_allocation_strategy == ResourceAllocationStrategy::kDynamic
           ? nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED
          : nvinfer1::ExecutionContextAllocationStrategy::kSTATIC);
-  exec_ctx = make_trt(cuda_engine->createExecutionContext(runtime_config.get()));
-#else
-  if (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) {
-    exec_ctx =
-        make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
-  } else {
-    exec_ctx = make_trt(cuda_engine->createExecutionContext());
-  }
-#endif
+  exec_ctx = make_trt(cuda_engine->createExecutionContext(runtime_cfg.config.get()));
   TORCHTRT_CHECK(exec_ctx.get() != nullptr, "Unable to (re)create TensorRT execution context");
 }

-#ifdef TRT_MAJOR_RTX
-void TRTEngine::apply_runtime_cache() {
-  if (runtime_cache_path.empty()) {
-    LOG_DEBUG("Runtime cache disabled (no path configured).");
-    return;
-  }
-  runtime_cache = make_trt(runtime_config->createRuntimeCache());
-  if (runtime_cache.get() == nullptr) {
-    LOG_WARNING("Failed to create TensorRT IRuntimeCache; runtime cache will be skipped.");
-    return;
-  }
-  load_runtime_cache();
-  bool ok = runtime_config->setRuntimeCache(*runtime_cache);
-  if (!ok) {
-    LOG_WARNING("Failed to attach runtime cache to IRuntimeConfig; cache will be unused.");
-    runtime_cache.reset();
-    return;
-  }
-  LOG_DEBUG("TensorRT-RTX runtime cache configured at " << runtime_cache_path);
-}
-
-void TRTEngine::apply_dynamic_shapes_kernel_strategy() {
-  runtime_config->setDynamicShapesKernelSpecializationStrategy(
-      static_cast<nvinfer1::DynamicShapesKernelSpecializationStrategy>(dynamic_shapes_kernel_strategy));
-  LOG_DEBUG("Dynamic shapes kernel specialization strategy set to " << dynamic_shapes_kernel_strategy);
-}
-
-void TRTEngine::apply_cuda_graph_strategy() {
-  bool ok = runtime_config->setCudaGraphStrategy(
-      cuda_graph_strategy == 1 ? nvinfer1::CudaGraphStrategy::kWHOLE_GRAPH_CAPTURE
-                               : nvinfer1::CudaGraphStrategy::kDISABLED);
-  if (!ok) {
-    LOG_WARNING("Failed to set CUDA graph strategy; continuing with default.");
-  }
-}
-
-void TRTEngine::load_runtime_cache() {
-  if (runtime_cache == nullptr || runtime_cache_path.empty()) {
-    return;
-  }
-  if (!std::filesystem::exists(runtime_cache_path)) {
-    LOG_DEBUG("No existing runtime cache at " << runtime_cache_path);
-    return;
-  }
-#ifndef _WIN32
-  int fd = ::open(runtime_cache_path.c_str(), O_RDONLY);
-  if (fd < 0) {
-    LOG_WARNING("Failed to open runtime cache for reading: " << runtime_cache_path);
-    return;
-  }
-  if (::flock(fd, LOCK_SH) != 0) {
-    LOG_WARNING("Failed to acquire shared lock on runtime cache; skipping load.");
-    ::close(fd);
-    return;
-  }
-#endif
-  try {
-    std::ifstream f(runtime_cache_path, std::ios::binary);
-    std::vector<char> buf((std::istreambuf_iterator<char>(f)), std::istreambuf_iterator<char>());
-    if (!buf.empty()) {
-      bool ok = runtime_cache->deserialize(buf.data(), buf.size());
-      if (ok) {
-        LOG_INFO("Loaded runtime cache from " << runtime_cache_path << " (" << buf.size() << " bytes)");
-      } else {
-        LOG_WARNING("runtime_cache->deserialize returned false for " << runtime_cache_path);
-      }
-    }
-  } catch (const std::exception& e) {
-    LOG_WARNING("Failed to load runtime cache: " << e.what());
-  }
-#ifndef _WIN32
-  ::flock(fd, LOCK_UN);
-  ::close(fd);
-#endif
-}
-
-void TRTEngine::save_runtime_cache() {
-  if (runtime_cache == nullptr || runtime_cache_path.empty()) {
-    return;
-  }
-  auto host_mem = make_trt(runtime_cache->serialize());
-  if (host_mem.get() == nullptr || host_mem->size() == 0) {
-    return;
-  }
-  try {
-    std::filesystem::path path(runtime_cache_path);
-    if (path.has_parent_path()) {
-      std::filesystem::create_directories(path.parent_path());
-    }
-    std::filesystem::path tmp_path = path;
-    tmp_path += ".tmp";
-
-#ifndef _WIN32
-    int fd = ::open(tmp_path.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644);
-    if (fd < 0) {
-      LOG_WARNING("Failed to open runtime cache tmp file for writing: " << tmp_path.string());
-      return;
-    }
-    if (::flock(fd, LOCK_EX) != 0) {
-      LOG_WARNING("Failed to acquire exclusive lock on runtime cache tmp file; skipping save.");
-      ::close(fd);
-      return;
-    }
-    ssize_t written = ::write(fd, host_mem->data(), host_mem->size());
-    ::flock(fd, LOCK_UN);
-    ::close(fd);
-    if (written != static_cast<ssize_t>(host_mem->size())) {
-      LOG_WARNING("Short write when saving runtime cache to " << tmp_path.string());
-      return;
-    }
-#else
-    // Windows: best-effort write without a cross-process lock. Follow-up: LockFileEx.
-    {
-      std::ofstream out(tmp_path, std::ios::binary);
-      out.write(reinterpret_cast<const char*>(host_mem->data()), host_mem->size());
-    }
-    LOG_WARNING("Runtime cache save on Windows runs without advisory locking; concurrent writers may race.");
-#endif
-    std::filesystem::rename(tmp_path, path);
-    LOG_INFO("Saved runtime cache to " << runtime_cache_path << " (" << host_mem->size() << " bytes)");
-  } catch (const std::exception& e) {
-    LOG_WARNING("Failed to save runtime cache: " << e.what());
-  }
-}
-#endif // TRT_MAJOR_RTX
-
 } // namespace runtime
 } // namespace core
 } // namespace torch_tensorrt
