Skip to content

Commit 3506183

Browse files
avoid new variable for native-ops check
1 parent 872f539 commit 3506183

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,15 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
4040
static std::unique_ptr<OrtCustomOpDomain> native_custom_op_domain = std::make_unique<OrtCustomOpDomain>();
4141

4242
// Owns the TensorRTCustomOp objects for native custom ops. Raw pointers are stored in native_custom_op_domain->custom_ops_.
43+
// Non-empty list indicates native custom ops have been registered (used to avoid re-registration on subsequent calls).
4344
static std::vector<std::unique_ptr<TensorRTCustomOp>> native_custom_op_list;
4445

45-
// Tracks whether native custom ops have been registered to avoid re-registration on subsequent calls.
46-
static bool native_custom_ops_initialized = false;
47-
4846
// Protects concurrent access to all the above static members.
4947
static std::mutex mutex;
5048
std::lock_guard<std::mutex> lock(mutex);
5149

5250
// Add already-initialized native ops to domain list
53-
if (native_custom_ops_initialized) {
51+
if (!native_custom_op_list.empty()) {
5452
domain_list.push_back(native_custom_op_domain.get());
5553
}
5654

@@ -149,7 +147,7 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
149147
}
150148

151149
// Register native custom ops (register these independent of TRT plugin library availability)
152-
if (!native_custom_ops_initialized) {
150+
if (native_custom_op_list.empty()) {
153151
const char* native_custom_ops_names[] = {"TRT_FP4DynamicQuantize", "TRT_FP8QuantizeLinear", "TRT_FP8DequantizeLinear"};
154152
size_t num_native_custom_ops = std::size(native_custom_ops_names);
155153

@@ -161,7 +159,6 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
161159

162160
native_custom_op_domain->domain_ = "trt";
163161
domain_list.push_back(native_custom_op_domain.get());
164-
native_custom_ops_initialized = true;
165162
}
166163

167164
return Status::OK();

0 commit comments

Comments
 (0)