Plugin TensorRT EP using ORT EP ABI #527
Conversation
…e multiple GPU devices
do we need all of these helper files? this one doesn't seem to be compiled, with the suffix ".ccc".
Thanks for catching that, I removed them.
                  : severity == Severity::kWARNING ? "WARNING"
                  : severity == Severity::kINFO    ? " INFO"
                                                   : "UNKNOWN");
if (severity <= Severity::kERROR) {
would be good to actually log something
Added the ORT default logger so the TRT logger can print/log messages.
Will also add back the default logger for the plugin TRT EP itself.
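For context, here is a hedged sketch of how a TensorRT ILogger can forward messages to an ORT logger through the C API. The class and member names are illustrative and not taken from the PR; only the OrtApi::Logger_LogMessage call and the TensorRT ILogger interface are assumed.

#include <NvInfer.h>
#include "onnxruntime_cxx_api.h"

// Illustrative: forwards TensorRT log messages to an ORT logger.
struct TrtOrtLogger : public nvinfer1::ILogger {
  TrtOrtLogger(const OrtApi& api, const OrtLogger* logger)
      : ort_api_(api), ort_logger_(logger) {}

  void log(Severity severity, const char* msg) noexcept override {
    // Map TensorRT severities onto ORT logging levels.
    OrtLoggingLevel level = severity <= Severity::kERROR    ? ORT_LOGGING_LEVEL_ERROR
                            : severity == Severity::kWARNING ? ORT_LOGGING_LEVEL_WARNING
                                                             : ORT_LOGGING_LEVEL_INFO;
    // Ort::Status takes ownership of the returned OrtStatus* and releases it.
    Ort::Status status{ort_api_.Logger_LogMessage(ort_logger_, level, msg,
                                                  ORT_FILE, __LINE__, __FUNCTION__)};
  }

  const OrtApi& ort_api_;
  const OrtLogger* ort_logger_;
};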
general comment: can we put all the code that doesn't need to be in the global namespace into a top-level namespace? maybe trt_ep or something. there is some existing code in the onnxruntime namespace, but we probably should change that too.
The trt_ep namespace is added and the onnxruntime namespace is removed. Thanks for the suggestion.
// char hostname[HOST_NAME_MAX];
// if (gethostname(hostname, HOST_NAME_MAX) != 0)
//   strcpy(hostname, "?");
// #endif
general: there seems to be quite a lot of commented out code in this PR. it's not ideal because it can easily get out of date. can we avoid adding commented out code?
Most of the commented code is removed.
@@ -0,0 +1,161 @@
# usage:
# cd build/
# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DORT_HOME=/home/lochi/onnxruntime-win-x64-gpu-1.23.0 -DCMAKE_CUDA_ARCHITECTURES=80 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc -DTENSORRT_HOME=/home/lochi/tensorrt/TensorRT-10.3.0.26 -DCMAKE_POSITION_INDEPENDENT_CODE=ON (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits")
nit: perhaps should replace lochi with a generic user or something like it
Could it be put in the c_cxx folder along with other C/C++ examples?
Removed the specific username in the instructions.
/*
std::vector<const OrtOpAttr*> node_attributes(num_node_attributes);
RETURN_IF_ERROR(ort_api.Node_GetAttributes(node, node_attributes.data(), node_attributes.size()));
*/
nit: not needed anymore?
Yes, I removed almost all of the commented code.
auto node = nodes[0];

size_t num_node_attributes = 0;
looks like this is not used.
yes, it's removed.
const OrtOpAttr* node_attr = nullptr;
RETURN_IF_ERROR(ort_api.Node_GetAttributeByName(node, "embed_mode", &node_attr));
const int64_t embed_mode = reinterpret_cast<const ONNX_NAMESPACE::AttributeProto*>(node_attr)->i();
Since this EP is largely an example of how to develop an EP, should we try to use the public C APIs to get the attribute values (i.e., ReadOpAttr) when possible? I think we want to show that an EP doesn't necessarily have to build with onnx to use these APIs.
Perhaps this wasn't done initially because the C API is cumbersome. But now that we have the C++ ORT APIs, getting the attribute values should hopefully be a one-liner.
That's a good suggestion. I used the C++ API to get attribute values instead.
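For reference, a minimal sketch of what the C-API route suggested above could look like for the embed_mode integer attribute, using OrtApi::ReadOpAttr; this is an illustration of the suggestion, not the exact code that landed in the PR.

const OrtOpAttr* node_attr = nullptr;
RETURN_IF_ERROR(ort_api.Node_GetAttributeByName(node, "embed_mode", &node_attr));

// Read the integer attribute directly into an int64_t; no onnx protobuf types needed.
int64_t embed_mode = 0;
size_t bytes_read = 0;
RETURN_IF_ERROR(ort_api.ReadOpAttr(node_attr, ORT_OP_ATTR_INT, &embed_mode, sizeof(embed_mode), &bytes_read));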
// Get engine from byte stream.
node_attr = nullptr;
RETURN_IF_ERROR(ort_api.Node_GetAttributeByName(node, "ep_cache_context", &node_attr));
const std::string& context_binary = reinterpret_cast<const ONNX_NAMESPACE::AttributeProto*>(node_attr)->s();
Same here. Could potentially use the C++ ORT API to get attr value?
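For illustration, a hedged sketch of reading the string attribute with ReadOpAttr, assuming the first call (with an empty buffer) reports the required size; the actual PR may use the C++ wrapper instead.

const OrtOpAttr* node_attr = nullptr;
RETURN_IF_ERROR(ort_api.Node_GetAttributeByName(node, "ep_cache_context", &node_attr));

// First call queries the required size; the "buffer too small" status is expected
// and released by Ort::Status.
size_t attr_size = 0;
Ort::Status size_query{ort_api.ReadOpAttr(node_attr, ORT_OP_ATTR_STRING, nullptr, 0, &attr_size)};

std::string context_binary(attr_size, '\0');
RETURN_IF_ERROR(ort_api.ReadOpAttr(node_attr, ORT_OP_ATTR_STRING, context_binary.data(), context_binary.size(), &attr_size));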
} else {
  output_tensors[i] = ctx.GetOutput(output_index, output_shapes);
  auto& output_tensor = output_tensors[i];
  const auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount();
C++ API functions like this one can throw exceptions. Are these exceptions caught/handled somewhere in the EP (and maybe converted to a Status that can be returned to ORT)?
Good catch! I added a try/catch there.
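A minimal sketch of that pattern, assuming the surrounding compute function returns an OrtStatus*; names are illustrative rather than the exact PR code.

try {
  output_tensors[i] = ctx.GetOutput(output_index, output_shapes);
  auto& output_tensor = output_tensors[i];
  const auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount();
  // ... bind the output buffer ...
} catch (const Ort::Exception& e) {
  // Convert the C++ API exception into a status that ORT can handle.
  return ort_api.CreateStatus(ORT_EP_FAIL, e.what());
}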
  // LOGS_DEFAULT(WARNING) << "[TensorRT EP] No graph will run on TensorRT execution provider";
} else if (number_of_trt_nodes == nodes.size()) {
  // LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider";
} else {
  // LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs;
nit: should the log statements be uncommented? (or maybe remove the if statements).
I added a default logger for this plugin TRT EP, and it can now log something.
@chilo-ms Can you please also guide us on the changes needed for wheel creation, so the Python APIs can be used independently with the ORT TRT EP (or any standalone custom EP) against the latest API/ABI interfaces offered by ORT Core (available from ORT version 1.23.0)? Since we have decoupled the TRT EP from the ORT source code, we can no longer access & compile the below file -
You don't need to make any changes for creating the ORT GPU wheel. Here is the reference code:

import onnxruntime as onnxrt
import numpy as np

ep_lib_path = "C:\\path\\to\\plugin_trt_ep\\TensorRTEp.dll"
ep_name = "TensorRTEp"
ep_registration_name = ep_name
onnxrt.register_execution_provider_library(ep_registration_name, ep_lib_path)

ep_devices = onnxrt.get_ep_devices()
trt_ep_device = None
for ep_device in ep_devices:
    if ep_device.ep_name == ep_name:
        trt_ep_device = ep_device
assert trt_ep_device is not None

sess_options = onnxrt.SessionOptions()
sess_options.add_provider_for_devices([trt_ep_device], {'trt_engine_cache_enable': '1'})
assert sess_options.has_providers()

# Run sample model and check output
sess = onnxrt.InferenceSession("C:\\models\\mul_1.onnx", sess_options=sess_options)
x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
input_name = sess.get_inputs()[0].name
res = sess.run([], {input_name: x})
output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32)
np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08)

onnxrt.unregister_execution_provider_library(ep_registration_name)
add_definitions(-DONNX_NAMESPACE=onnx)
add_definitions(-DONNX_ML)
add_definitions(-DNV_TENSORRT_MAJOR=10)
why does NV_TENSORRT_MAJOR need to be defined here? should we leave that to TensorRT?
Good catch, I removed it.
  OrtAllocator::AllocOnStream = nullptr;  // Allocate memory, handling usage across different Streams. Not used for TRT EP.
}
// TODO: Handle destructor
//~CUDAAllocator();
does anything need to be done for the CUDAAllocator and CUDAPinnedAllocator destructors or is the default implementation fine?
The default implementation is fine. I removed the comment.
ENFORCE(num_nodes == 1);

std::vector<const OrtNode*> nodes(num_nodes);
RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size()));
RETURN_IF_ERROR returns an OrtStatus*, right? But this function returns a bool, and we probably don't want to convert a non-nullptr OrtStatus* to true.
Also, there are multiple error handling mechanisms used in this function. Is it possible to simplify the error handling by consistently returning an OrtStatus*?
Made them all return an OrtStatus*
try {
  ENFORCE(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING);
} catch (const Ort::Exception& e) {
  return ort_api.CreateStatus(ORT_EP_FAIL, e.what());
}
general: can this try-enforce-catch pattern be replaced with RETURN_IF_NOT()?
Nice suggestion, replaced it with RETURN_IF_NOT().
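For illustration, a RETURN_IF_NOT-style macro along these lines would cover the pattern above; this is a sketch, and the macro actually added in the PR may differ in name and message handling.

#define RETURN_IF_NOT(condition, msg)                    \
  do {                                                   \
    if (!(condition)) {                                  \
      /* Fail with a status instead of throwing. */      \
      return ort_api.CreateStatus(ORT_EP_FAIL, (msg));   \
    }                                                    \
  } while (0)

// Usage replacing the try-ENFORCE-catch block:
RETURN_IF_NOT(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING,
              "ep_cache_context attribute must be a string");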
    true,  // serialize refitted engine to disk
    detailed_build_log_);
if (status != nullptr) {
  return ort_api.CreateStatus(ORT_EP_FAIL, "RefitEngine failed.");
status should be freed or returned directly if it is not nullptr
Good catch, I made the code return it directly.
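A minimal sketch of the fix (illustrative):

if (status != nullptr) {
  // Propagate the original status (and its message) instead of leaking it
  // and creating a new, generic one.
  return status;
}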
namespace tensorrt_ptr {

struct TensorrtInferDeleter {
why do we need TensorrtInferDeleter? does std::default_delete<T> work?
Okay, yup, std::default_delete<T> works and it's much more robust.
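As an illustration of why the custom deleter can go away (assuming TensorRT 8+ where objects are destroyed with plain delete; trt_logger, engine_data, and engine_size are hypothetical variables):

#include <memory>
#include <NvInfer.h>

// std::unique_ptr with the default deleter is sufficient for TensorRT objects
// such as IRuntime and ICudaEngine.
std::unique_ptr<nvinfer1::IRuntime> runtime{nvinfer1::createInferRuntime(trt_logger)};
std::unique_ptr<nvinfer1::ICudaEngine> engine{
    runtime->deserializeCudaEngine(engine_data, engine_size)};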
("Plugin EP has been created with name " + name_).c_str(), | ||
ORT_FILE, __LINE__, __FUNCTION__); | ||
// ignore status for now | ||
(void)ort_status; |
maybe at least store it in Ort::Status so it still gets released if it's not nullptr
Good catch, Ort::Status is used now.
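A minimal sketch of the pattern (the logger variable name default_logger_ is assumed, not taken from the PR):

// Wrap the returned OrtStatus* in Ort::Status so it is released automatically
// even when the result is intentionally ignored.
Ort::Status log_status{ort_api.Logger_LogMessage(
    default_logger_, ORT_LOGGING_LEVEL_INFO,
    ("Plugin EP has been created with name " + name_).c_str(),
    ORT_FILE, __LINE__, __FUNCTION__)};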
*/
OrtStatus* EPContextNodeReader::ValidateEPCtxNode(const OrtGraph* graph) const {
  size_t num_nodes = 0;
  THROW_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes));
Suggested change:
- THROW_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes));
+ RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes));
changed.
const OrtApi* ort_api);

bool IsAbsolutePath(const std::string& path_string) {
#ifdef _WIN32
general: do we need separate windows/non-windows implementations or is it possible to have a single cross-platform implementation of these path helper functions using std::filesystem?
Yes, we can. I rewrote the function to have a single cross-platform implementation.
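For illustration, a cross-platform version could be as simple as the following sketch; the PR's final implementation may differ in details.

#include <filesystem>
#include <string>

// Single implementation for both Windows and POSIX paths.
bool IsAbsolutePath(const std::string& path_string) {
  return std::filesystem::path(path_string).is_absolute();
}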
if (ValidateEPCtxNode(&graph) != nullptr) {
  return ort_api.CreateStatus(ORT_EP_FAIL, "It's not a valid EPContext node");
}
if ValidateEPCtxNode() returns a non-null OrtStatus*, it will be leaked. maybe just return it directly.

Suggested change:
- if (ValidateEPCtxNode(&graph) != nullptr) {
-   return ort_api.CreateStatus(ORT_EP_FAIL, "It's not a valid EPContext node");
- }
+ RETURN_IF_ERROR(ValidateEPCtxNode(&graph));
good catch. updated.
  } \
} while (0)

#define RETURN_IF_ORT_STATUS_ERROR(fn) \
hm, it is confusing to have RETURN_IF_ORTSTATUS_ERROR, RETURN_IF_ORT_STATUS_ERROR, and RETURN_IF_ERROR. can we just have a single one?
RETURN_IF_ERROR is kept and the rest are removed.
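For illustration, a single OrtStatus*-based macro along these lines can cover all call sites; this is a sketch, and the PR's actual definition may differ.

#define RETURN_IF_ERROR(expr)                  \
  do {                                         \
    OrtStatus* _status = (expr);               \
    if (_status != nullptr) {                  \
      /* Propagate the error to the caller. */ \
      return _status;                          \
    }                                          \
  } while (0)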
#endif  // #ifdef _WIN32

#ifdef NO_EXCEPTIONS
void PrintFinalMessage(const char* msg) {
is this function used?
No, it's removed now.
}  // namespace trt_ep

// To make symbols visible on macOS/iOS
#ifdef __APPLE__
do we need to support macOS or iOS?
No, the macro is removed now.
Description
This plugin TRT EP is migrated from the original TRT EP and provides the implementations of OrtEpFactory, OrtEp, OrtNodeComputeInfo, OrtDataTransferImpl, ... that are required for a plugin EP to interact with ONNX Runtime via the EP ABI (introduced in ORT 1.23.0).
The plugin EP should be built independently of the ORT source code, as it relies only on the API/ABI provided by ORT. Therefore, it should reside in a separate repository outside the main ORT repository.
This plugin TRT EP can be built on Linux and Windows in "Debug" and "Release" modes.
Build plugin TRT EP on Windows:
(Note: ORT_HOME should contain the include and lib folders as below)
Build plugin TRT EP on Linux:
Run the plugin TRT EP:
Please use onnxruntime_perf_test or onnx_test_runner.
TODO
- Currently, GetCapability assumes the whole graph is TRT-eligible. A follow-up PR will add the TRT parser call for partitioning.
- Add a simple unit test.