Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions js/node/src/tensor_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
#include "tensor_helper.h"
#include "inference_session_wrap.h"

// napi_float16_array was added in Node.js 23 (N-API version 10).
// Define it for older Node.js versions to support Float16Array input tensors.
#ifndef napi_float16_array
#define napi_float16_array static_cast<napi_typedarray_type>(11)
#endif

// make sure consistent with origin definition
static_assert(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == 0, "definition not consistent with OnnxRuntime");
static_assert(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT == 1, "definition not consistent with OnnxRuntime");
Expand Down Expand Up @@ -196,9 +202,20 @@ Ort::Value NapiValueToOrtValue(Napi::Env env, Napi::Value value, OrtMemoryInfo*

auto tensorDataTypedArray = tensorDataValue.As<Napi::TypedArray>();
std::underlying_type_t<napi_typedarray_type> typedArrayType = tensorDataValue.As<Napi::TypedArray>().TypedArrayType();
ORT_NAPI_THROW_TYPEERROR_IF(DATA_TYPE_TYPEDARRAY_MAP[elemType] != typedArrayType, env,
"Tensor.data must be a typed array (", DATA_TYPE_TYPEDARRAY_MAP[elemType], ") for ",
tensorTypeString, " tensors, but got typed array (", typedArrayType, ").");

// For float16 tensors, accept both Uint16Array and Float16Array.
// Float16Array is a newer JavaScript type (ES2024) that may be passed by users.
// Both use 16-bit storage, so they are compatible at the binary level.
bool isValidTypedArray = (DATA_TYPE_TYPEDARRAY_MAP[elemType] == typedArrayType);
if (!isValidTypedArray && elemType == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
// Accept Float16Array (napi_float16_array = 11) for float16 tensors
isValidTypedArray = (typedArrayType == napi_float16_array);
}

ORT_NAPI_THROW_TYPEERROR_IF(!isValidTypedArray, env,
"Tensor.data must be a typed array (", DATA_TYPE_TYPEDARRAY_MAP[elemType],
elemType == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 ? " or Float16Array" : "",
") for ", tensorTypeString, " tensors, but got typed array (", typedArrayType, ").");

char* buffer = reinterpret_cast<char*>(tensorDataTypedArray.ArrayBuffer().Data());
size_t bufferByteOffset = tensorDataTypedArray.ByteOffset();
Expand Down