@@ -114,9 +114,29 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
114
114
ORT_RETURN_IF_NOT (GetType (*input_defs[0 ], input_type, logger), " Cannot get input type" );
115
115
emscripten::val common_options = emscripten::val::object ();
116
116
117
+ if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
118
+ // Decomposed *SimplifiedLayerNormalization may lose precision if its data type is float16,
119
+ // cast all inputs to float32 to ensure precision.
120
+ common_options.set (" label" , node.Name () + " _cast_input_to_fp32" );
121
+ input = model_builder.GetBuilder ().call <emscripten::val>(" cast" , input, emscripten::val (" float32" ));
122
+
123
+ common_options.set (" label" , node.Name () + " _cast_scale_to_fp32" );
124
+ scale = model_builder.GetBuilder ().call <emscripten::val>(" cast" , scale, emscripten::val (" float32" ));
125
+
126
+ if (!bias.isUndefined ()) {
127
+ common_options.set (" label" , node.Name () + " _cast_bias_to_fp32" );
128
+ bias = model_builder.GetBuilder ().call <emscripten::val>(" cast" , bias, emscripten::val (" float32" ));
129
+ }
130
+ }
131
+
117
132
// If it is SkipSimplifiedLayerNormalization, add the skip and bias (if it exists) to the input.
118
133
if (op_type == " SkipSimplifiedLayerNormalization" ) {
119
134
emscripten::val skip = model_builder.GetOperand (input_defs[1 ]->Name ());
135
+ if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
136
+ // Cast skip to float32
137
+ common_options.set (" label" , node.Name () + " _cast_skip_to_fp32" );
138
+ skip = model_builder.GetBuilder ().call <emscripten::val>(" cast" , skip, emscripten::val (" float32" ));
139
+ }
120
140
common_options.set (" label" , node.Name () + " _add_skip" );
121
141
input = model_builder.GetBuilder ().call <emscripten::val>(" add" , input, skip, common_options);
122
142
if (!bias.isUndefined ()) {
@@ -127,12 +147,20 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
127
147
// Add SkipSimplifiedLayerNormalization's output input_skip_bias_sum if it exists.
128
148
// Now input equals input_skip_bias_sum.
129
149
if (TensorExists (output_defs, 3 )) {
130
- model_builder.AddOperand (output_defs[3 ]->Name (), input);
150
+ emscripten::val input_skip_bias_sum = input;
151
+ if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
152
+ // Cast input_skip_bias_sum back to float16.
153
+ common_options.set (" label" , node.Name () + " _cast_input_skip_bias_sum_to_fp16" );
154
+ input_skip_bias_sum = model_builder.GetBuilder ().call <emscripten::val>(" cast" , input_skip_bias_sum,
155
+ emscripten::val (" float16" ));
156
+ }
157
+ model_builder.AddOperand (output_defs[3 ]->Name (), input_skip_bias_sum);
131
158
}
132
159
}
133
160
134
161
// Pow
135
- emscripten::val pow_constant = model_builder.CreateOrGetConstant <float >(input_type, 2 );
162
+ emscripten::val pow_constant =
163
+ model_builder.CreateOrGetConstant <float >(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 2 );
136
164
common_options.set (" label" , node.Name () + " _pow" );
137
165
emscripten::val pow =
138
166
model_builder.GetBuilder ().call <emscripten::val>(" pow" , input, pow_constant, common_options);
@@ -145,7 +173,8 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
145
173
emscripten::val reduce_mean = model_builder.GetBuilder ().call <emscripten::val>(" reduceMean" , pow , reduce_options);
146
174
147
175
// Add
148
- emscripten::val add_constant = model_builder.CreateOrGetConstant <float >(input_type, epsilon);
176
+ emscripten::val add_constant =
177
+ model_builder.CreateOrGetConstant <float >(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, epsilon);
149
178
common_options.set (" label" , node.Name () + " _add" );
150
179
emscripten::val add =
151
180
model_builder.GetBuilder ().call <emscripten::val>(" add" , reduce_mean, add_constant, common_options);
@@ -167,6 +196,12 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
167
196
common_options.set (" label" , node.Name () + " _add_bias" );
168
197
output = model_builder.GetBuilder ().call <emscripten::val>(" add" , output, bias, common_options);
169
198
}
199
+
200
+ if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
201
+ // Cast output back to float16.
202
+ common_options.set (" label" , node.Name () + " _cast_output_to_fp16" );
203
+ output = model_builder.GetBuilder ().call <emscripten::val>(" cast" , output, emscripten::val (" float16" ));
204
+ }
170
205
}
171
206
} else if (op_type == " InstanceNormalization" ) {
172
207
// WebNN spec only supports 4D input for instanceNormalization.
0 commit comments