Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,32 @@ extern TensorrtLogger& GetTensorrtLogger(bool verbose);
* So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation.
*/
common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths) {
// Domain for TRT plugin custom ops (domain name: "trt.plugins"). Owns the OrtCustomOpDomain object.
// Raw pointers from .get() are handed out to callers via domain_list and may be held by InferenceSession.
static std::unique_ptr<OrtCustomOpDomain> custom_op_domain = std::make_unique<OrtCustomOpDomain>();

// Owns the TensorRTCustomOp objects for TRT plugins. Raw pointers are stored in custom_op_domain->custom_ops_.
static std::vector<std::unique_ptr<TensorRTCustomOp>> created_custom_op_list;

// Domain for native custom ops (domain name: "trt"). Owns the OrtCustomOpDomain object.
// Raw pointers from .get() are handed out to callers via domain_list and may be held by InferenceSession.
static std::unique_ptr<OrtCustomOpDomain> native_custom_op_domain = std::make_unique<OrtCustomOpDomain>();

// Owns the TensorRTCustomOp objects for native custom ops. Raw pointers are stored in native_custom_op_domain->custom_ops_.
// Non-empty list indicates native custom ops have been registered (used to avoid re-registration on subsequent calls).
static std::vector<std::unique_ptr<TensorRTCustomOp>> native_custom_op_list;

// Protects concurrent access to all the above static members.
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);

// Add already-initialized native ops to domain list
if (!native_custom_op_list.empty()) {
domain_list.push_back(native_custom_op_domain.get());
}

if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) {
domain_list.push_back(custom_op_domain.get());
if (native_custom_op_domain->domain_ != "" && native_custom_op_domain->custom_ops_.size() > 0) {
domain_list.push_back(native_custom_op_domain.get());
}
return Status::OK();
}

Expand Down Expand Up @@ -132,35 +147,36 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
}

// Register native custom ops (register these independent of TRT plugin library availability)
const char* native_custom_ops_names[] = {"TRT_FP4DynamicQuantize", "TRT_FP8QuantizeLinear", "TRT_FP8DequantizeLinear"};
int num_native_custom_ops = std::size(native_custom_ops_names);
if (native_custom_op_list.empty()) {
const char* native_custom_ops_names[] = {"TRT_FP4DynamicQuantize", "TRT_FP8QuantizeLinear", "TRT_FP8DequantizeLinear"};
size_t num_native_custom_ops = std::size(native_custom_ops_names);

for (size_t i = 0; i < num_native_custom_ops; i++) {
native_custom_op_list.push_back(std::make_unique<TensorRTCustomOp>(onnxruntime::kNvTensorRTRTXExecutionProvider, nullptr));
native_custom_op_list.back()->SetName(native_custom_ops_names[i]);
native_custom_op_domain->custom_ops_.push_back(native_custom_op_list.back().get());
}

for (int i = 0; i < num_native_custom_ops; i++) {
native_custom_op_list.push_back(std::make_unique<TensorRTCustomOp>(onnxruntime::kNvTensorRTRTXExecutionProvider, nullptr));
native_custom_op_list.back()->SetName(native_custom_ops_names[i]);
native_custom_op_domain->custom_ops_.push_back(native_custom_op_list.back().get());
native_custom_op_domain->domain_ = "trt";
domain_list.push_back(native_custom_op_domain.get());
}

native_custom_op_domain->domain_ = "trt";
domain_list.push_back(native_custom_op_domain.get());
return Status::OK();
}

void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) {
if (domain != nullptr) {
for (auto ptr : domain->custom_ops_) {
if (ptr != nullptr) {
delete ptr;
}
}
delete domain;
}
(void)domain; // Suppress unused parameter warning
// The domain and its custom ops are owned by static unique_ptrs in CreateTensorRTCustomOpDomainList().
// Callers receive raw pointers via .get().
// 1. Manually deleting them would cause a double-free when the static unique_ptrs are destroyed at program exit.
// 2. Resetting the static unique_ptrs is also unsafe because other EP instances or InferenceSession objects
// may still hold raw pointers to these same objects (handed out via domain_list).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this indicate a different problem of someone calling to destroy objects that are in-use? Should we fix that bug?

Another question, static objects would be destroyed just prior to this DLL being unloaded. We want to make sure that the entities being destroyed do not refer to another DLL that could potentially be unloaded first.

It is for the reason people usually introduce a special API to have control of the process and to destroy things at a safe time and not to delegate it to a OS dependent specifics when shared objects are unloaded and the order of static destruction.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this indicate a different problem of someone calling to destroy objects that are in-use?

Yes, this is a potential use-after-free scenario. I think it should get mitigated with current change.

We want to make sure that the entities being destroyed do not refer to another DLL
usually introduce a special API to have control of the process and to destroy things at a safe time

I see your point. Usually, we could have ref-counted concerned objects for handling this (or, make them part of EP instance, or session to avoid shared usage). However, I believe no cross-DLL memory is actually accessed during destruction today.

I think it will be better to decouple current change about avoid-repetition handling with any potential design changes on this part. Please let me know if this sounds okay to you.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here my take on this. There is not a firmly defined policy here on handling these objects. I think we need to make a choice here:

  • Remove the Release functions and give away shared_ptrs OR
  • Use the Release functions so client code can destroy the objects when it KNOWS that raw pointers are no longer in use.

Until that happens, this is going to be never-ending chasing of the tail with different OS dependent issues.

// The static objects would be shared across EP instances and would persist for the program lifetime.
}

void ReleaseTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) {
for (auto ptr : custom_op_domain_list) {
ReleaseTensorRTCustomOpDomain(ptr);
}
// Only clear the reference vector, don't delete the static domain objects.
custom_op_domain_list.clear();
}

} // namespace onnxruntime
Loading