Skip to content

Commit d71aa4d

Browse files
jchen10 and fs-eire authored
[webgpu] Fix test_layer_normalization_2d_axis0 (#24223)
The optional 'Mean' and 'InvStdDev' outputs of the LayerNormalization were not implemented. --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
1 parent 8de342a commit d71aa4d

File tree

2 files changed

+61
-33
lines changed

2 files changed

+61
-33
lines changed

onnxruntime/core/providers/webgpu/nn/layer_norm.cc

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,13 @@ Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
2424
if (has_bias_) {
2525
shader.AddInput("bias", ShaderUsage::UseUniform);
2626
}
27-
shader.AddOutput("output", ShaderUsage::UseUniform);
27+
shader.AddOutput("y", ShaderUsage::UseUniform);
28+
if (has_mean_output_) {
29+
shader.AddOutput("mean_output", ShaderUsage::None);
30+
}
31+
if (has_inv_std_dev_output_) {
32+
shader.AddOutput("inv_std_dev_output", ShaderUsage::None);
33+
}
2834

2935
int components = x.NumComponents();
3036
std::string bias = (has_bias_) ? " + bias[j]" : "";
@@ -48,8 +54,14 @@ Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
4854
<< "for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {\n"
4955
<< " let f32input = f32_val_t(x[j + offset]);\n"
5056
<< " let f32scale = f32_val_t(scale[j]);\n"
51-
<< " output[j + offset] = x_value_t((f32input" << simpl2 << ") * inv_std_dev * f32scale)" << bias << ";\n"
57+
<< " y[j + offset] = x_value_t((f32input" << simpl2 << ") * inv_std_dev * f32scale)" << bias << ";\n"
5258
<< "}\n";
59+
if (has_mean_output_) {
60+
shader.MainFunctionBody() << "mean_output[global_idx] = mean;\n";
61+
}
62+
if (has_inv_std_dev_output_) {
63+
shader.MainFunctionBody() << "inv_std_dev_output[global_idx] = inv_std_dev;\n";
64+
}
5365

5466
return Status::OK();
5567
}
@@ -62,8 +74,6 @@ Status LayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeContex
6274

6375
const auto x_shape = x->Shape();
6476

65-
auto* output = context.Output(0, x_shape);
66-
6777
if (x_shape.Size() == 0) {
6878
return Status::OK();
6979
}
@@ -85,13 +95,27 @@ Status LayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeContex
8595
scale_size, " and bias size of ", bias_size);
8696
}
8797

88-
LayerNormProgram program{bias != nullptr, is_fp16, simplified};
98+
TensorShapeVector mean_dim;
99+
for (size_t i = 0; i < x_shape.NumDimensions(); ++i) {
100+
if (i < axis) {
101+
mean_dim.push_back(x_shape[i]);
102+
} else {
103+
mean_dim.push_back(1);
104+
}
105+
}
106+
TensorShape mean_shape(mean_dim);
107+
108+
auto* y = context.Output(0, x_shape);
109+
auto* mean = context.Output(1, mean_shape);
110+
auto* inv_std_dev = context.Output(2, mean_shape);
111+
112+
LayerNormProgram program{bias != nullptr, is_fp16, simplified, mean != nullptr, inv_std_dev != nullptr};
89113

90114
program
91115
.CacheHint(simplified)
92116
.AddInputs({{x, ProgramTensorMetadataDependency::Type, components}})
93117
.AddInputs({{scale, ProgramTensorMetadataDependency::Type, components}})
94-
.AddOutputs({{output, ProgramTensorMetadataDependency::None, components}})
118+
.AddOutputs({{y, ProgramTensorMetadataDependency::None, components}})
95119
.SetDispatchGroupSize((norm_count + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
96120
.AddUniformVariables({
97121
{static_cast<uint32_t>(norm_count)},
@@ -109,25 +133,26 @@ Status LayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeContex
109133
if (bias != nullptr) {
110134
program.AddInput({bias, ProgramTensorMetadataDependency::Type, components});
111135
}
136+
137+
if (mean != nullptr) {
138+
program.AddOutputs({{mean, ProgramTensorMetadataDependency::None}});
139+
}
140+
if (inv_std_dev != nullptr) {
141+
program.AddOutputs({{inv_std_dev, ProgramTensorMetadataDependency::None}});
142+
}
143+
112144
return context.RunProgram(program);
113145
}
114146

115-
ONNX_OPERATOR_KERNEL_EX(
116-
LayerNormalization,
117-
kOnnxDomain,
118-
17,
119-
kWebGpuExecutionProvider,
120-
(*KernelDefBuilder::Create())
121-
.TypeConstraint("T", WebGpuSupportedFloatTypes()),
122-
LayerNorm<false>);
123-
124-
ONNX_OPERATOR_KERNEL_EX(
125-
SimplifiedLayerNormalization,
126-
kOnnxDomain,
127-
1,
128-
kWebGpuExecutionProvider,
129-
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()),
130-
LayerNorm<true>);
147+
ONNX_OPERATOR_KERNEL_EX(LayerNormalization, kOnnxDomain, 17, kWebGpuExecutionProvider,
148+
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()),
149+
LayerNorm<false>);
150+
151+
ONNX_OPERATOR_KERNEL_EX(SimplifiedLayerNormalization, kOnnxDomain, 1, kWebGpuExecutionProvider,
152+
(*KernelDefBuilder::Create())
153+
.TypeConstraint("T", WebGpuSupportedFloatTypes())
154+
.TypeConstraint("U", WebGpuSupportedFloatTypes()),
155+
LayerNorm<true>);
131156

132157
} // namespace webgpu
133158
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/nn/layer_norm.h

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,28 @@ namespace webgpu {
1111

1212
class LayerNormProgram final : public Program<LayerNormProgram> {
1313
public:
14-
LayerNormProgram(bool has_bias,
15-
bool is_fp16,
16-
bool simplified) : Program{"LayerNorm"},
17-
has_bias_{has_bias},
18-
is_fp16_{is_fp16},
19-
simplified_{simplified} {}
14+
LayerNormProgram(bool has_bias, bool is_fp16, bool simplified, bool has_mean_output,
15+
bool has_inv_std_dev_output)
16+
: Program{"LayerNorm"},
17+
has_bias_{has_bias},
18+
is_fp16_{is_fp16},
19+
simplified_{simplified},
20+
has_mean_output_{has_mean_output},
21+
has_inv_std_dev_output_{has_inv_std_dev_output} {}
2022

2123
Status GenerateShaderCode(ShaderHelper& sh) const override;
2224

23-
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
24-
{"norm_count", ProgramUniformVariableDataType::Uint32},
25-
{"norm_size", ProgramUniformVariableDataType::Uint32},
26-
{"norm_size_vectorized", ProgramUniformVariableDataType::Uint32},
27-
{"epsilon", ProgramUniformVariableDataType::Float32});
25+
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"norm_count", ProgramUniformVariableDataType::Uint32},
26+
{"norm_size", ProgramUniformVariableDataType::Uint32},
27+
{"norm_size_vectorized", ProgramUniformVariableDataType::Uint32},
28+
{"epsilon", ProgramUniformVariableDataType::Float32});
2829

2930
private:
3031
bool has_bias_;
3132
bool is_fp16_;
3233
bool simplified_;
34+
bool has_mean_output_;
35+
bool has_inv_std_dev_output_;
3336
};
3437

3538
template <bool simplified>

0 commit comments

Comments (0)