@@ -85,6 +85,108 @@ struct ShutdownProtobuf {
8585
8686namespace onnxruntime {
8787
88+ // Helper function to check if a data type is supported by NvTensorRTRTX EP
89+ static bool IsSupportedDataType (ONNXTensorElementDataType data_type) {
90+ switch (data_type) {
91+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: // kFLOAT - 32-bit floating point
92+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: // kHALF - IEEE 16-bit floating-point
93+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: // kBF16 - Brain float 16
94+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: // kBOOL - 8-bit boolean
95+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: // kINT4 - 4-bit signed integer
96+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: // kINT8 - 8-bit signed integer
97+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: // kUINT8 - 8-bit unsigned integer
98+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: // kINT32 - 32-bit signed integer
99+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: // kINT64 - 64-bit signed integer
100+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN: // kFP8 - 8-bit floating point
101+ return true ;
102+ default :
103+ return false ;
104+ }
105+ }
106+
107+ // Helper function to get data type name as string
108+ static std::string GetDataTypeName (ONNXTensorElementDataType data_type) {
109+ switch (data_type) {
110+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
111+ return " FLOAT" ;
112+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
113+ return " FLOAT16" ;
114+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
115+ return " BFLOAT16" ;
116+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
117+ return " BOOL" ;
118+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4:
119+ return " INT4" ;
120+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
121+ return " INT8" ;
122+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
123+ return " UINT8" ;
124+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
125+ return " INT32" ;
126+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
127+ return " INT64" ;
128+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN:
129+ return " FLOAT8E4M3FN" ;
130+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
131+ return " DOUBLE" ;
132+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
133+ return " STRING" ;
134+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
135+ return " UINT16" ;
136+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
137+ return " UINT32" ;
138+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
139+ return " UINT64" ;
140+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
141+ return " INT16" ;
142+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
143+ return " COMPLEX64" ;
144+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
145+ return " COMPLEX128" ;
146+ default :
147+ return " UNKNOWN(" + std::to_string (static_cast <int >(data_type)) + " )" ;
148+ }
149+ }
150+
151+ // Helper function to check if a node has supported data types
152+ static bool CheckNodeDataTypes (const Node* node) {
153+ // Check input data types
154+ for (const auto * input_def : node->InputDefs ()) {
155+ if (input_def->Exists ()) {
156+ const auto * type_proto = input_def->TypeAsProto ();
157+ if (type_proto && type_proto->has_tensor_type ()) {
158+ auto data_type = static_cast <ONNXTensorElementDataType>(type_proto->tensor_type ().elem_type ());
159+ if (!IsSupportedDataType (data_type)) {
160+ LOGS_DEFAULT (WARNING) << " [NvTensorRTRTX EP] Node '" << node->Name ()
161+ << " ' (OpType: " << node->OpType ()
162+ << " ) has unsupported input data type: " << GetDataTypeName (data_type)
163+ << " for input '" << input_def->Name () << " '" ;
164+ return false ;
165+ }
166+ }
167+ }
168+ }
169+
170+ // Check output data types
171+ for (const auto * output_def : node->OutputDefs ()) {
172+ if (output_def->Exists ()) {
173+ const auto * type_proto = output_def->TypeAsProto ();
174+ if (type_proto && type_proto->has_tensor_type ()) {
175+ auto data_type = static_cast <ONNXTensorElementDataType>(type_proto->tensor_type ().elem_type ());
176+ if (!IsSupportedDataType (data_type)) {
177+ LOGS_DEFAULT (WARNING) << " [NvTensorRTRTX EP] Node '" << node->Name ()
178+ << " ' (OpType: " << node->OpType ()
179+ << " ) has unsupported output data type: " << GetDataTypeName (data_type)
180+ << " for output '" << output_def->Name () << " '" ;
181+ return false ;
182+ }
183+ }
184+ }
185+ }
186+
187+ return true ;
188+ }
189+
88190void * OutputAllocator::reallocateOutputAsync (char const * /* tensorName*/ , void * /* currentMemory*/ , uint64_t size,
89191 uint64_t /* alignment*/ , cudaStream_t /* stream*/ ) noexcept {
90192 // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr
@@ -478,10 +580,12 @@ Status BindContextInput(Ort::KernelContext& ctx,
478580 CASE_GET_INPUT_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t )
479581 CASE_GET_INPUT_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t )
480582 CASE_GET_INPUT_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool )
583+ CASE_GET_INPUT_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, uint8_t )
481584 CASE_GET_INPUT_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t )
482585 CASE_GET_INPUT_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t )
483586 CASE_GET_INPUT_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t )
484587 CASE_GET_INPUT_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t )
588+ CASE_GET_INPUT_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t )
485589 default : {
486590 return ORT_MAKE_STATUS (ONNXRUNTIME, EP_FAIL,
487591 " NvTensorRTRTX EP input onnx tensor data type: " + std::to_string (tensor_type) + " not supported." );
@@ -562,10 +666,12 @@ Status BindContextOutput(Ort::KernelContext& ctx,
562666 CASE_GET_OUTPUT_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t )
563667 CASE_GET_OUTPUT_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t )
564668 CASE_GET_OUTPUT_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool )
669+ CASE_GET_OUTPUT_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, uint8_t )
565670 CASE_GET_OUTPUT_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t )
566671 CASE_GET_OUTPUT_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t )
567672 CASE_GET_OUTPUT_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t )
568673 CASE_GET_OUTPUT_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t )
674+ CASE_GET_OUTPUT_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t )
569675 default : {
570676 return ORT_MAKE_STATUS (ONNXRUNTIME, EP_FAIL,
571677 " NvTensorRTRTX EP output tensor data type: " + std::to_string (output_type) + " not supported." );
@@ -624,10 +730,12 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
624730 CASE_COPY_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t )
625731 CASE_COPY_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t )
626732 CASE_COPY_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool )
733+ CASE_COPY_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, uint8_t )
627734 CASE_COPY_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t )
628735 CASE_COPY_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t )
629736 CASE_COPY_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t )
630737 CASE_COPY_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t )
738+ CASE_COPY_TENSOR (ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t )
631739 default : {
632740 return ORT_MAKE_STATUS (ONNXRUNTIME, EP_FAIL,
633741 " NvTensorRTRTX EP output tensor data type: " + std::to_string (output_type) + " not supported." );
@@ -1878,6 +1986,7 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph,
18781986 /* Iterate all the nodes and exclude the node if:
18791987 * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible.
18801988 * 2. It's a DDS op.
1989+ * 3. It has unsupported data types.
18811990 */
18821991 for (const auto & index : nodes_vector) {
18831992 const auto & node = graph.GetNode (node_index[index]);
@@ -1917,6 +2026,16 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph,
19172026 supported_node = false ;
19182027 }
19192028
2029+ // Check data types and print warnings for unsupported types
2030+ if (supported_node) {
2031+ if (!CheckNodeDataTypes (node)) {
2032+ supported_node = false ;
2033+ LOGS_DEFAULT (INFO) << " [NvTensorRTRTX EP] Node '" << node->Name ()
2034+ << " ' (OpType: " << node->OpType ()
2035+ << " ) excluded due to unsupported data types" ;
2036+ }
2037+ }
2038+
19202039 if (supported_node) {
19212040 if (new_subgraph) {
19222041 parser_nodes_vector.emplace_back ();
0 commit comments