diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc index 90e488a1eda18..a7c37cd481894 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc @@ -28,17 +28,32 @@ extern TensorrtLogger& GetTensorrtLogger(bool verbose); * So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation. */ common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths) { + // Domain for TRT plugin custom ops (domain name: "trt.plugins"). Owns the OrtCustomOpDomain object. + // Raw pointers from .get() are handed out to callers via domain_list and may be held by InferenceSession. static std::unique_ptr custom_op_domain = std::make_unique(); + + // Owns the TensorRTCustomOp objects for TRT plugins. Raw pointers are stored in custom_op_domain->custom_ops_. static std::vector> created_custom_op_list; + + // Domain for native custom ops (domain name: "trt"). Owns the OrtCustomOpDomain object. + // Raw pointers from .get() are handed out to callers via domain_list and may be held by InferenceSession. static std::unique_ptr native_custom_op_domain = std::make_unique(); + + // Owns the TensorRTCustomOp objects for native custom ops. Raw pointers are stored in native_custom_op_domain->custom_ops_. + // Non-empty list indicates native custom ops have been registered (used to avoid re-registration on subsequent calls). static std::vector> native_custom_op_list; + + // Protects concurrent access to all the above static members. static std::mutex mutex; std::lock_guard lock(mutex); + + // Add already-initialized native ops to domain list + if (!native_custom_op_list.empty()) { + domain_list.push_back(native_custom_op_domain.get()); + } + if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) { domain_list.push_back(custom_op_domain.get()); - if (native_custom_op_domain->domain_ != "" && native_custom_op_domain->custom_ops_.size() > 0) { - domain_list.push_back(native_custom_op_domain.get()); - } return Status::OK(); } @@ -132,35 +147,36 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& } // Register native custom ops (register these independent of TRT plugin library availability) - const char* native_custom_ops_names[] = {"TRT_FP4DynamicQuantize", "TRT_FP8QuantizeLinear", "TRT_FP8DequantizeLinear"}; - int num_native_custom_ops = std::size(native_custom_ops_names); + if (native_custom_op_list.empty()) { + const char* native_custom_ops_names[] = {"TRT_FP4DynamicQuantize", "TRT_FP8QuantizeLinear", "TRT_FP8DequantizeLinear"}; + size_t num_native_custom_ops = std::size(native_custom_ops_names); + + for (size_t i = 0; i < num_native_custom_ops; i++) { + native_custom_op_list.push_back(std::make_unique(onnxruntime::kNvTensorRTRTXExecutionProvider, nullptr)); + native_custom_op_list.back()->SetName(native_custom_ops_names[i]); + native_custom_op_domain->custom_ops_.push_back(native_custom_op_list.back().get()); + } - for (int i = 0; i < num_native_custom_ops; i++) { - native_custom_op_list.push_back(std::make_unique(onnxruntime::kNvTensorRTRTXExecutionProvider, nullptr)); - native_custom_op_list.back()->SetName(native_custom_ops_names[i]); - native_custom_op_domain->custom_ops_.push_back(native_custom_op_list.back().get()); + native_custom_op_domain->domain_ = "trt"; + domain_list.push_back(native_custom_op_domain.get()); } - native_custom_op_domain->domain_ = "trt"; - domain_list.push_back(native_custom_op_domain.get()); return Status::OK(); } void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) { - if (domain != nullptr) { - for (auto ptr : domain->custom_ops_) { - if (ptr != nullptr) { - delete ptr; - } - } - delete domain; - } + (void)domain; // Suppress unused parameter warning + // The domain and its custom ops are owned by static unique_ptrs in CreateTensorRTCustomOpDomainList(). + // Callers receive raw pointers via .get(). + // 1. Manually deleting them would cause a double-free when the static unique_ptrs are destroyed at program exit. + // 2. Resetting the static unique_ptrs is also unsafe because other EP instances or InferenceSession objects + // may still hold raw pointers to these same objects (handed out via domain_list). + // The static objects would be shared across EP instances and would persist for the program lifetime. } void ReleaseTensorRTCustomOpDomainList(std::vector& custom_op_domain_list) { - for (auto ptr : custom_op_domain_list) { - ReleaseTensorRTCustomOpDomain(ptr); - } + // Only clear the reference vector, don't delete the static domain objects. + custom_op_domain_list.clear(); } } // namespace onnxruntime