Skip to content

Commit b524229

Browse files
authored
[WebNN] Accept Float16Array for float16 data type if it is available (#23894)
Float16Array is now shipping and WebNN Chromium implementation has accepted it. We should allow it in WebNN EP as well.
1 parent 95225dd commit b524229

File tree

4 files changed

+25
-8
lines changed

4 files changed

+25
-8
lines changed

js/web/lib/wasm/jsep/backend-webnn.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,8 @@ export class WebNNBackend {
314314
bufferView = new Float32Array(buffer);
315315
break;
316316
case 'float16':
317-
bufferView = new Uint16Array(buffer);
317+
bufferView =
318+
typeof Float16Array !== 'undefined' && Float16Array.from ? new Float16Array(buffer) : new Uint16Array(buffer);
318319
break;
319320
case 'int32':
320321
bufferView = new Int32Array(buffer);

onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,17 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build
219219
sign_buffer.set(0, -1.0f);
220220
sign_buffer.set(1, 1.0f);
221221
} else if (input_data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
222-
sign_buffer = emscripten::val::global("Uint16Array").new_(2);
223-
sign_buffer.set(0, PackFloat32ToUint16AsFloat16(-1.0f));
224-
sign_buffer.set(1, PackFloat32ToUint16AsFloat16(1.0f));
222+
if (model_builder.IsFloat16ArrayAvailable()) {
223+
// Float16Array is avaliable - use Float16Array.
224+
sign_buffer = emscripten::val::global("Float16Array").new_(2);
225+
sign_buffer.set(0, -1.0f);
226+
sign_buffer.set(1, 1.0f);
227+
} else {
228+
// Float16Array is not available - use Uint16Array instead.
229+
sign_buffer = emscripten::val::global("Uint16Array").new_(2);
230+
sign_buffer.set(0, PackFloat32ToUint16AsFloat16(-1.0f));
231+
sign_buffer.set(1, PackFloat32ToUint16AsFloat16(1.0f));
232+
}
225233
} else {
226234
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported input data type: ", input_data_type);
227235
}

onnxruntime/core/providers/webnn/builders/model_builder.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,8 @@ Status ModelBuilder::RegisterInitializers() {
197197

198198
// Wasm memory grow will cause all array buffers reallocation, which will be treated as detached
199199
// buffers in JS side. Simply create a copy to fix it.
200-
operand = wnn_builder_.call<emscripten::val>("constant", desc, view.call<emscripten::val>("slice"));
200+
view = view.call<emscripten::val>("slice");
201+
operand = wnn_builder_.call<emscripten::val>("constant", desc, view["buffer"]);
201202
}
202203
} else {
203204
// TODO: support other type.
@@ -350,7 +351,8 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer(
350351
emscripten::val operand = emscripten::val::object();
351352
// Wasm memory grow will cause all array buffers reallocation, which will be treated as detached
352353
// buffers in JS side. Simply create a copy to fix it.
353-
operand = wnn_builder_.call<emscripten::val>("constant", desc, view.call<emscripten::val>("slice"));
354+
view = view.call<emscripten::val>("slice");
355+
operand = wnn_builder_.call<emscripten::val>("constant", desc, view["buffer"]);
354356

355357
AddOperand(name, operand);
356358
mem_persist_buffers_.push_back(std::move(persist_buffer));

onnxruntime/core/providers/webnn/builders/model_builder.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class ModelBuilder {
3030
Status Compile(std::unique_ptr<Model>& model) ORT_MUST_USE_RESULT;
3131

3232
// Accessors for members.
33+
bool IsFloat16ArrayAvailable() const { return is_float16array_available_; }
3334
const GraphViewer& GetGraphViewer() const { return graph_viewer_; }
3435
InitializedTensorSet GetInitializerTensors();
3536

@@ -68,6 +69,8 @@ class ModelBuilder {
6869
private:
6970
const GraphViewer& graph_viewer_;
7071
const logging::Logger& logger_;
72+
const bool is_float16array_available_ = !emscripten::val::global("Float16Array").isUndefined() &&
73+
emscripten::val::global("Float16Array").hasOwnProperty("from");
7174

7275
emscripten::val wnn_context_ = emscripten::val::undefined();
7376
emscripten::val wnn_builder_ = emscripten::val::undefined();
@@ -172,9 +175,12 @@ const emscripten::val& ModelBuilder::CreateOrGetConstant(const int32_t& data_typ
172175
}
173176
break;
174177
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
175-
buffer = emscripten::val::global("Uint16Array").new_(num_elements);
178+
buffer = is_float16array_available_
179+
? emscripten::val::global("Float16Array").new_(num_elements)
180+
: emscripten::val::global("Uint16Array").new_(num_elements);
176181
if (value) {
177-
buffer.call<void>("fill", emscripten::val(PackFloat32ToUint16AsFloat16(value)));
182+
buffer.call<void>("fill",
183+
emscripten::val(is_float16array_available_ ? value : PackFloat32ToUint16AsFloat16(value)));
178184
}
179185
break;
180186
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:

0 commit comments

Comments
 (0)