|
11 | 11 | #include "tensor_helper.h" |
12 | 12 | #include "inference_session_wrap.h" |
13 | 13 |
|
| 14 | +// napi_float16_array was added in Node.js 23 (N-API version 10). |
| 15 | +// Define it for older Node.js versions to support Float16Array input tensors. |
| 16 | +#ifndef napi_float16_array |
| 17 | +#define napi_float16_array static_cast<napi_typedarray_type>(11) |
| 18 | +#endif |
| 19 | + |
14 | 20 | // make sure consistent with origin definition |
15 | 21 | static_assert(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == 0, "definition not consistent with OnnxRuntime"); |
16 | 22 | static_assert(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT == 1, "definition not consistent with OnnxRuntime"); |
@@ -196,9 +202,19 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo* |
196 | 202 |
|
197 | 203 | auto tensorDataTypedArray = tensorDataValue.As<Napi::TypedArray>(); |
198 | 204 | std::underlying_type_t<napi_typedarray_type> typedArrayType = tensorDataValue.As<Napi::TypedArray>().TypedArrayType(); |
199 | | - ORT_NAPI_THROW_TYPEERROR_IF(DATA_TYPE_TYPEDARRAY_MAP[elemType] != typedArrayType, env, |
200 | | - "Tensor.data must be a typed array (", DATA_TYPE_TYPEDARRAY_MAP[elemType], ") for ", |
201 | | - tensorTypeString, " tensors, but got typed array (", typedArrayType, ")."); |
| 205 | + |
| 206 | + // For float16 tensors, accept both Uint16Array and Float16Array. |
| 207 | + // Float16Array is a newer JavaScript type (ES2024) that may be passed by users. |
| 208 | + // Both use 16-bit storage, so they are compatible at the binary level. |
| 209 | + bool isValidTypedArray = (DATA_TYPE_TYPEDARRAY_MAP[elemType] == typedArrayType); |
| 210 | + if (!isValidTypedArray && elemType == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { |
| 211 | + // Accept Float16Array (napi_float16_array = 11) for float16 tensors |
| 212 | + isValidTypedArray = (typedArrayType == napi_float16_array); |
| 213 | + } |
| 214 | + |
| 215 | + ORT_NAPI_THROW_TYPEERROR_IF(!isValidTypedArray, env, |
| 216 | + "Tensor.data must be a typed array (", DATA_TYPE_TYPEDARRAY_MAP[elemType], |
| 217 | + " or Float16Array) for ", tensorTypeString, " tensors, but got typed array (", typedArrayType, ")."); |
202 | 218 |
|
203 | 219 | char* buffer = reinterpret_cast<char*>(tensorDataTypedArray.ArrayBuffer().Data()); |
204 | 220 | size_t bufferByteOffset = tensorDataTypedArray.ByteOffset(); |
|
0 commit comments