@@ -24,7 +24,13 @@ Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
2424 if (has_bias_) {
2525 shader.AddInput (" bias" , ShaderUsage::UseUniform);
2626 }
27- shader.AddOutput (" output" , ShaderUsage::UseUniform);
27+ shader.AddOutput (" y" , ShaderUsage::UseUniform);
28+ if (has_mean_output_) {
29+ shader.AddOutput (" mean_output" , ShaderUsage::None);
30+ }
31+ if (has_inv_std_dev_output_) {
32+ shader.AddOutput (" inv_std_dev_output" , ShaderUsage::None);
33+ }
2834
2935 int components = x.NumComponents ();
3036 std::string bias = (has_bias_) ? " + bias[j]" : " " ;
@@ -48,8 +54,14 @@ Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
4854 << " for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {\n "
4955 << " let f32input = f32_val_t(x[j + offset]);\n "
5056 << " let f32scale = f32_val_t(scale[j]);\n "
51- << " output [j + offset] = x_value_t((f32input" << simpl2 << " ) * inv_std_dev * f32scale)" << bias << " ;\n "
57+ << " y [j + offset] = x_value_t((f32input" << simpl2 << " ) * inv_std_dev * f32scale)" << bias << " ;\n "
5258 << " }\n " ;
59+ if (has_mean_output_) {
60+ shader.MainFunctionBody () << " mean_output[global_idx] = mean;\n " ;
61+ }
62+ if (has_inv_std_dev_output_) {
63+ shader.MainFunctionBody () << " inv_std_dev_output[global_idx] = inv_std_dev;\n " ;
64+ }
5365
5466 return Status::OK ();
5567}
@@ -62,8 +74,6 @@ Status LayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeContex
6274
6375 const auto x_shape = x->Shape ();
6476
65- auto * output = context.Output (0 , x_shape);
66-
6777 if (x_shape.Size () == 0 ) {
6878 return Status::OK ();
6979 }
@@ -85,13 +95,27 @@ Status LayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeContex
8595 scale_size, " and bias size of " , bias_size);
8696 }
8797
88- LayerNormProgram program{bias != nullptr , is_fp16, simplified};
98+ TensorShapeVector mean_dim;
99+ for (size_t i = 0 ; i < x_shape.NumDimensions (); ++i) {
100+ if (i < axis) {
101+ mean_dim.push_back (x_shape[i]);
102+ } else {
103+ mean_dim.push_back (1 );
104+ }
105+ }
106+ TensorShape mean_shape (mean_dim);
107+
108+ auto * y = context.Output (0 , x_shape);
109+ auto * mean = context.Output (1 , mean_shape);
110+ auto * inv_std_dev = context.Output (2 , mean_shape);
111+
112+ LayerNormProgram program{bias != nullptr , is_fp16, simplified, mean != nullptr , inv_std_dev != nullptr };
89113
90114 program
91115 .CacheHint (simplified)
92116 .AddInputs ({{x, ProgramTensorMetadataDependency::Type, components}})
93117 .AddInputs ({{scale, ProgramTensorMetadataDependency::Type, components}})
94- .AddOutputs ({{output , ProgramTensorMetadataDependency::None, components}})
118+ .AddOutputs ({{y , ProgramTensorMetadataDependency::None, components}})
95119 .SetDispatchGroupSize ((norm_count + WORKGROUP_SIZE - 1 ) / WORKGROUP_SIZE)
96120 .AddUniformVariables ({
97121 {static_cast <uint32_t >(norm_count)},
@@ -109,25 +133,26 @@ Status LayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeContex
109133 if (bias != nullptr ) {
110134 program.AddInput ({bias, ProgramTensorMetadataDependency::Type, components});
111135 }
136+
137+ if (mean != nullptr ) {
138+ program.AddOutputs ({{mean, ProgramTensorMetadataDependency::None}});
139+ }
140+ if (inv_std_dev != nullptr ) {
141+ program.AddOutputs ({{inv_std_dev, ProgramTensorMetadataDependency::None}});
142+ }
143+
112144 return context.RunProgram (program);
113145}
114146
115- ONNX_OPERATOR_KERNEL_EX (
116- LayerNormalization,
117- kOnnxDomain ,
118- 17 ,
119- kWebGpuExecutionProvider ,
120- (*KernelDefBuilder::Create ())
121- .TypeConstraint(" T" , WebGpuSupportedFloatTypes()),
122- LayerNorm<false>);
123-
124- ONNX_OPERATOR_KERNEL_EX (
125- SimplifiedLayerNormalization,
126- kOnnxDomain ,
127- 1 ,
128- kWebGpuExecutionProvider ,
129- (*KernelDefBuilder::Create ()).TypeConstraint(" T" , WebGpuSupportedFloatTypes()),
130- LayerNorm<true>);
147+ ONNX_OPERATOR_KERNEL_EX (LayerNormalization, kOnnxDomain , 17 , kWebGpuExecutionProvider ,
148+ (*KernelDefBuilder::Create ()).TypeConstraint(" T" , WebGpuSupportedFloatTypes()),
149+ LayerNorm<false>);
150+
151+ ONNX_OPERATOR_KERNEL_EX (SimplifiedLayerNormalization, kOnnxDomain , 1 , kWebGpuExecutionProvider ,
152+ (*KernelDefBuilder::Create ())
153+ .TypeConstraint(" T" , WebGpuSupportedFloatTypes())
154+ .TypeConstraint(" U" , WebGpuSupportedFloatTypes()),
155+ LayerNorm<true>);
131156
132157} // namespace webgpu
133158} // namespace onnxruntime
0 commit comments