
Commit 3cd14fd

[WebNN] Always execute decomposed *SimplifiedLayerNormalization in FP32
Decomposed [Skip]SimplifiedLayerNormalization loses precision in FP16, so we add cast (to fp32) ops around it in the WebNN EP to preserve precision, rather than manually adding cast nodes in each model file.
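For illustration, a minimal sketch of the pattern this commit applies: cast the fp16 inputs up to fp32, build the decomposed normalization in fp32, then cast the result back to fp16. The helper name, the free-standing builder/operand parameters, and the omission of the skip/bias handling are simplifications for this sketch only; the actual change is in NormalizationOpBuilder::AddToModelBuilderImpl, shown in the diff below.

// Illustrative sketch only, not the EP code itself. Assumes `builder` is the WebNN
// MLGraphBuilder wrapped in an emscripten::val and `input`/`scale` are fp16 operands.
#include <emscripten/val.h>

emscripten::val BuildSimplifiedLayerNormInFp32(emscripten::val builder,
                                               emscripten::val input,    // fp16 operand
                                               emscripten::val scale,    // fp16 operand
                                               emscripten::val two,      // fp32 scalar constant 2
                                               emscripten::val epsilon,  // fp32 scalar constant eps
                                               emscripten::val axes) {   // reduceMean options (axes)
  // Cast the fp16 inputs up to fp32 so the intermediate math keeps precision.
  emscripten::val x = builder.call<emscripten::val>("cast", input, emscripten::val("float32"));
  emscripten::val s = builder.call<emscripten::val>("cast", scale, emscripten::val("float32"));
  // Decomposed SimplifiedLayerNormalization: x / sqrt(mean(x^2) + eps) * scale.
  emscripten::val pow = builder.call<emscripten::val>("pow", x, two);
  emscripten::val mean = builder.call<emscripten::val>("reduceMean", pow, axes);
  emscripten::val add = builder.call<emscripten::val>("add", mean, epsilon);
  emscripten::val sqrt = builder.call<emscripten::val>("sqrt", add);
  emscripten::val div = builder.call<emscripten::val>("div", x, sqrt);
  emscripten::val mul = builder.call<emscripten::val>("mul", div, s);
  // Cast the result back to fp16 so downstream nodes still see the original data type.
  return builder.call<emscripten::val>("cast", mul, emscripten::val("float16"));
}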
1 parent c18e06d commit 3cd14fd

File tree

1 file changed: +38 -3 lines changed


Diff for: onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc (+38 -3)

@@ -114,9 +114,29 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
       ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_type, logger), "Cannot get input type");
       emscripten::val common_options = emscripten::val::object();
 
+      if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
+        // Decomposed *SimplifiedLayerNormalization may lose precision if its data type is float16,
+        // cast all inputs to float32 to ensure precision.
+        common_options.set("label", node.Name() + "_cast_input_to_fp32");
+        input = model_builder.GetBuilder().call<emscripten::val>("cast", input, emscripten::val("float32"));
+
+        common_options.set("label", node.Name() + "_cast_scale_to_fp32");
+        scale = model_builder.GetBuilder().call<emscripten::val>("cast", scale, emscripten::val("float32"));
+
+        if (!bias.isUndefined()) {
+          common_options.set("label", node.Name() + "_cast_bias_to_fp32");
+          bias = model_builder.GetBuilder().call<emscripten::val>("cast", bias, emscripten::val("float32"));
+        }
+      }
+
       // If it is SkipSimplifiedLayerNormalization, add the skip and bias (if it exists) to the input.
       if (op_type == "SkipSimplifiedLayerNormalization") {
         emscripten::val skip = model_builder.GetOperand(input_defs[1]->Name());
+        if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
+          // Cast skip to float32
+          common_options.set("label", node.Name() + "_cast_skip_to_fp32");
+          skip = model_builder.GetBuilder().call<emscripten::val>("cast", skip, emscripten::val("float32"));
+        }
         common_options.set("label", node.Name() + "_add_skip");
         input = model_builder.GetBuilder().call<emscripten::val>("add", input, skip, common_options);
         if (!bias.isUndefined()) {
@@ -127,12 +147,20 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
         // Add SkipSimplifiedLayerNormalization's output input_skip_bias_sum if it exists.
         // Now input equals to input_skip_bias_sum.
         if (TensorExists(output_defs, 3)) {
-          model_builder.AddOperand(output_defs[3]->Name(), input);
+          emscripten::val input_skip_bias_sum = input;
+          if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
+            // Cast input_skip_bias_sum back to float16.
+            common_options.set("label", node.Name() + "_cast_input_skip_bias_sum_to_fp16");
+            input_skip_bias_sum = model_builder.GetBuilder().call<emscripten::val>("cast", input_skip_bias_sum,
+                                                                                   emscripten::val("float16"));
+          }
+          model_builder.AddOperand(output_defs[3]->Name(), input_skip_bias_sum);
         }
       }
 
       // Pow
-      emscripten::val pow_constant = model_builder.CreateOrGetConstant<float>(input_type, 2);
+      emscripten::val pow_constant =
+          model_builder.CreateOrGetConstant<float>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 2);
       common_options.set("label", node.Name() + "_pow");
       emscripten::val pow =
           model_builder.GetBuilder().call<emscripten::val>("pow", input, pow_constant, common_options);
@@ -145,7 +173,8 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
       emscripten::val reduce_mean = model_builder.GetBuilder().call<emscripten::val>("reduceMean", pow, reduce_options);
 
       // Add
-      emscripten::val add_constant = model_builder.CreateOrGetConstant<float>(input_type, epsilon);
+      emscripten::val add_constant =
+          model_builder.CreateOrGetConstant<float>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, epsilon);
       common_options.set("label", node.Name() + "_add");
       emscripten::val add =
           model_builder.GetBuilder().call<emscripten::val>("add", reduce_mean, add_constant, common_options);
@@ -167,6 +196,12 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
         common_options.set("label", node.Name() + "_add_bias");
         output = model_builder.GetBuilder().call<emscripten::val>("add", output, bias, common_options);
       }
+
+      if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
+        // Cast output back to float16.
+        common_options.set("label", node.Name() + "_cast_output_to_fp16");
+        output = model_builder.GetBuilder().call<emscripten::val>("cast", output, emscripten::val("float16"));
+      }
     }
   } else if (op_type == "InstanceNormalization") {
     // WebNN spec only supports 4D input for instanceNormalization.
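For intuition on the precision issue (an illustration, not part of the commit): the decomposition computes

  y = x / sqrt(mean(x^2) + epsilon) * scale

and fp16's largest finite value is 65504, so the intermediate x^2 already overflows to infinity for any element with |x| >= 256 (256^2 = 65536 > 65504); even without overflow, accumulating mean(x^2) with fp16's ~10-bit mantissa rounds away small contributions. Running the decomposed graph in fp32 and casting only the final output (and the optional input_skip_bias_sum output) back to fp16 avoids both problems.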
