
Commit b49e69e

[NV TensorRT RTX] Handle unsupported data types (#25953)
### Description
The EP will reject nodes with unsupported data types instead of attempting to build an engine for them.

### Motivation and Context
Previously, the user would face a crash when a model containing an unsupported data type was used.
1 parent 31dcc60 commit b49e69e
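
Not part of the commit, but for illustration: with this change a session that appends the NV TensorRT RTX EP should load such a model and let the offending node fall back to another provider rather than crash. A minimal sketch using the ONNX Runtime C++ API; the provider name string "NvTensorRTRTX" and the model path are assumptions here, so substitute whatever your ONNX Runtime build and application actually use.

```cpp
// Illustrative sketch only (not part of the commit).
// Assumptions: the EP is registered under the name "NvTensorRTRTX" and
// "model_with_double.onnx" contains a node with a DOUBLE tensor.
#include <string>
#include <unordered_map>
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "nv_trt_rtx_fallback_demo");
  Ort::SessionOptions so;

  // Append the NV TensorRT RTX EP; nodes it does not claim in GetCapability()
  // are left to the remaining providers (e.g. CPU).
  std::unordered_map<std::string, std::string> provider_options;  // none needed for this sketch
  so.AppendExecutionProvider("NvTensorRTRTX", provider_options);

  // With this commit, session creation succeeds and the DOUBLE node is simply
  // not assigned to the EP; previously this scenario could crash.
  Ort::Session session(env, ORT_TSTR("model_with_double.onnx"), so);
  return 0;
}
```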

File tree

1 file changed: +119 -0 lines changed


onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc

Lines changed: 119 additions & 0 deletions
@@ -85,6 +85,108 @@ struct ShutdownProtobuf {
 
 namespace onnxruntime {
 
+// Helper function to check if a data type is supported by NvTensorRTRTX EP
+static bool IsSupportedDataType(ONNXTensorElementDataType data_type) {
+  switch (data_type) {
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:         // kFLOAT - 32-bit floating point
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:       // kHALF - IEEE 16-bit floating-point
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:      // kBF16 - Brain float 16
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:          // kBOOL - 8-bit boolean
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4:          // kINT4 - 4-bit signed integer
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:          // kINT8 - 8-bit signed integer
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:         // kUINT8 - 8-bit unsigned integer
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:         // kINT32 - 32-bit signed integer
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:         // kINT64 - 64-bit signed integer
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN:  // kFP8 - 8-bit floating point
+      return true;
+    default:
+      return false;
+  }
+}
+
+// Helper function to get data type name as string
+static std::string GetDataTypeName(ONNXTensorElementDataType data_type) {
+  switch (data_type) {
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
+      return "FLOAT";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
+      return "FLOAT16";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
+      return "BFLOAT16";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
+      return "BOOL";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4:
+      return "INT4";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
+      return "INT8";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
+      return "UINT8";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
+      return "INT32";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
+      return "INT64";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN:
+      return "FLOAT8E4M3FN";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
+      return "DOUBLE";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
+      return "STRING";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
+      return "UINT16";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
+      return "UINT32";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
+      return "UINT64";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
+      return "INT16";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
+      return "COMPLEX64";
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
+      return "COMPLEX128";
+    default:
+      return "UNKNOWN(" + std::to_string(static_cast<int>(data_type)) + ")";
+  }
+}
+
+// Helper function to check if a node has supported data types
+static bool CheckNodeDataTypes(const Node* node) {
+  // Check input data types
+  for (const auto* input_def : node->InputDefs()) {
+    if (input_def->Exists()) {
+      const auto* type_proto = input_def->TypeAsProto();
+      if (type_proto && type_proto->has_tensor_type()) {
+        auto data_type = static_cast<ONNXTensorElementDataType>(type_proto->tensor_type().elem_type());
+        if (!IsSupportedDataType(data_type)) {
+          LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Node '" << node->Name()
+                                << "' (OpType: " << node->OpType()
+                                << ") has unsupported input data type: " << GetDataTypeName(data_type)
+                                << " for input '" << input_def->Name() << "'";
+          return false;
+        }
+      }
+    }
+  }
+
+  // Check output data types
+  for (const auto* output_def : node->OutputDefs()) {
+    if (output_def->Exists()) {
+      const auto* type_proto = output_def->TypeAsProto();
+      if (type_proto && type_proto->has_tensor_type()) {
+        auto data_type = static_cast<ONNXTensorElementDataType>(type_proto->tensor_type().elem_type());
+        if (!IsSupportedDataType(data_type)) {
+          LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Node '" << node->Name()
+                                << "' (OpType: " << node->OpType()
+                                << ") has unsupported output data type: " << GetDataTypeName(data_type)
+                                << " for output '" << output_def->Name() << "'";
+          return false;
+        }
+      }
+    }
+  }
+
+  return true;
+}
+
 void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size,
                                              uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept {
   // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr
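
Aside from the diff itself, the classification above can be exercised standalone. The sketch below is not part of the commit; it assumes only the `ONNXTensorElementDataType` enum from `onnxruntime_c_api.h` and re-declares a local copy of the predicate, since the real `IsSupportedDataType` is file-static inside `nv_execution_provider.cc`.

```cpp
// Standalone sketch (not part of the diff): mirrors the element types the
// NvTensorRTRTX EP will claim after this change.
#include <cstdio>
#include <onnxruntime_c_api.h>

static bool SupportedByEp(ONNXTensorElementDataType t) {
  switch (t) {
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN:
      return true;
    default:
      return false;
  }
}

int main() {
  // DOUBLE is not claimed: the node is rejected with a warning instead of crashing later.
  std::printf("DOUBLE:  %s\n", SupportedByEp(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) ? "claimed" : "rejected");
  std::printf("FLOAT16: %s\n", SupportedByEp(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) ? "claimed" : "rejected");
  std::printf("INT4:    %s\n", SupportedByEp(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4) ? "claimed" : "rejected");
  return 0;
}
```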
@@ -478,10 +580,12 @@ Status BindContextInput(Ort::KernelContext& ctx,
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t)
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t)
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool)
+    CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, uint8_t)
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
     CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
+    CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t)
     default: {
       return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                              "NvTensorRTRTX EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported.");
@@ -562,10 +666,12 @@ Status BindContextOutput(Ort::KernelContext& ctx,
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t)
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t)
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool)
+    CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, uint8_t)
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
     CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
+    CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t)
     default: {
       return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                              "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported.");
@@ -624,10 +730,12 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t)
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t)
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool)
+    CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, uint8_t)
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
     CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
+    CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t)
     default: {
       return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                              "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported.");
@@ -1878,6 +1986,7 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph,
   /* Iterate all the nodes and exclude the node if:
    * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible.
    * 2. It's a DDS op.
+   * 3. It has unsupported data types.
    */
   for (const auto& index : nodes_vector) {
     const auto& node = graph.GetNode(node_index[index]);
@@ -1917,6 +2026,16 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph,
       supported_node = false;
     }
 
+    // Check data types and print warnings for unsupported types
+    if (supported_node) {
+      if (!CheckNodeDataTypes(node)) {
+        supported_node = false;
+        LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Node '" << node->Name()
+                           << "' (OpType: " << node->OpType()
+                           << ") excluded due to unsupported data types";
+      }
+    }
+
     if (supported_node) {
       if (new_subgraph) {
         parser_nodes_vector.emplace_back();

0 commit comments