Skip to content

Commit 304db7c

Browse files
Avoid repeated creation of FP4 native custom-op domains
1 parent f83d4d0 commit 304db7c

File tree

1 file changed

+25
-22
lines changed

1 file changed

+25
-22
lines changed

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,17 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
3232
static std::vector<std::unique_ptr<TensorRTCustomOp>> created_custom_op_list;
3333
static std::unique_ptr<OrtCustomOpDomain> native_custom_op_domain = std::make_unique<OrtCustomOpDomain>();
3434
static std::vector<std::unique_ptr<TensorRTCustomOp>> native_custom_op_list;
35+
static bool native_custom_ops_initialized = false;
3536
static std::mutex mutex;
3637
std::lock_guard<std::mutex> lock(mutex);
38+
39+
// If the native custom ops were already initialized, add their domain to the domain list
40+
if (native_custom_ops_initialized) {
41+
domain_list.push_back(native_custom_op_domain.get());
42+
}
43+
3744
if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) {
3845
domain_list.push_back(custom_op_domain.get());
39-
if (native_custom_op_domain->domain_ != "" && native_custom_op_domain->custom_ops_.size() > 0) {
40-
domain_list.push_back(native_custom_op_domain.get());
41-
}
4246
return Status::OK();
4347
}
4448

@@ -132,35 +136,34 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
132136
}
133137

134138
// Register native custom ops (register these independent of TRT plugin library availability)
135-
const char* native_custom_ops_names[] = {"TRT_FP4DynamicQuantize", "TRT_FP8QuantizeLinear", "TRT_FP8DequantizeLinear"};
136-
int num_native_custom_ops = std::size(native_custom_ops_names);
139+
if (!native_custom_ops_initialized) {
140+
const char* native_custom_ops_names[] = {"TRT_FP4DynamicQuantize", "TRT_FP8QuantizeLinear", "TRT_FP8DequantizeLinear"};
141+
int num_native_custom_ops = std::size(native_custom_ops_names);
142+
143+
for (int i = 0; i < num_native_custom_ops; i++) {
144+
native_custom_op_list.push_back(std::make_unique<TensorRTCustomOp>(onnxruntime::kNvTensorRTRTXExecutionProvider, nullptr));
145+
native_custom_op_list.back()->SetName(native_custom_ops_names[i]);
146+
native_custom_op_domain->custom_ops_.push_back(native_custom_op_list.back().get());
147+
}
137148

138-
for (int i = 0; i < num_native_custom_ops; i++) {
139-
native_custom_op_list.push_back(std::make_unique<TensorRTCustomOp>(onnxruntime::kNvTensorRTRTXExecutionProvider, nullptr));
140-
native_custom_op_list.back()->SetName(native_custom_ops_names[i]);
141-
native_custom_op_domain->custom_ops_.push_back(native_custom_op_list.back().get());
149+
native_custom_op_domain->domain_ = "trt";
150+
domain_list.push_back(native_custom_op_domain.get());
151+
native_custom_ops_initialized = true;
142152
}
143153

144-
native_custom_op_domain->domain_ = "trt";
145-
domain_list.push_back(native_custom_op_domain.get());
146154
return Status::OK();
147155
}
148156

149157
void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) {
150-
if (domain != nullptr) {
151-
for (auto ptr : domain->custom_ops_) {
152-
if (ptr != nullptr) {
153-
delete ptr;
154-
}
155-
}
156-
delete domain;
157-
}
158+
(void)domain; // Suppress unused parameter warning
159+
// The custom ops (TensorRTCustomOp) and the domain (OrtCustomOpDomain) are owned by static
160+
// unique_ptrs created in CreateTensorRTCustomOpDomainList().
161+
// Deleting them here would cause a double delete.
158162
}
159163

160164
void ReleaseTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) {
161-
for (auto ptr : custom_op_domain_list) {
162-
ReleaseTensorRTCustomOpDomain(ptr);
163-
}
165+
// Only clear the reference vector, don't delete the static domain objects.
166+
custom_op_domain_list.clear();
164167
}
165168

166169
} // namespace onnxruntime

0 commit comments

Comments
 (0)