Skip to content

Commit 85dddea

Browse files
authored
[webgpu] fix broadcast for SkipLayerNorm (#27107)
### Description Fix the bug discovered by #27014: ``` SkipLayerNormTest.SkipLayerNormBatch2_Skip_Broadcast_No_Batch_Size SkipLayerNormTest.SkipLayerNormBatch2_Skip_Broadcast_Batch_Size_1 ```
1 parent ca0cd21 commit 85dddea

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
4040
<< " let workgroup_half_idx = uniforms.hidden_size / (workgroup_size_x * 4);\n"
4141
<< " if (workgroup_idx >= workgroup_half_idx) {\n"
4242
<< " offset = (workgroup_idx - workgroup_half_idx) * workgroup_size_x + local_idx;\n"
43-
<< " let skip_value = skip[offset];\n"
43+
<< " let skip_offset = offset % (uniforms.skip_size / 4);\n"
44+
<< " let skip_value = skip[skip_offset];\n"
4445
<< " let input_value = x[offset];\n"
4546
<< " let value = input_value + skip_value" << (hasBias_ ? " + bias[offset]" : "") << ";\n"
4647
<< " input_skip_bias_sum[offset] = value;\n"
@@ -56,7 +57,8 @@ Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
5657
<< " var cur_input_skip_bias_sum = x_value_t(0);\n"
5758
<< " for (var i: u32 = 0; i < uniforms.hidden_size / (workgroup_size_x * 4); i++) {\n"
5859
<< " let input_offset = i * workgroup_size_x + local_idx;\n"
59-
<< " let skip_value = skip[input_offset];\n"
60+
<< " let skip_input_offset = input_offset % (uniforms.skip_size / 4);\n"
61+
<< " let skip_value = skip[skip_input_offset];\n"
6062
<< " let input_value = x[input_offset];\n"
6163
<< " let value = input_value + skip_value" << (hasBias_ ? " + bias[input_offset]" : "") << ";\n"
6264
<< " if (i == workgroup_idx) {\n"
@@ -106,7 +108,8 @@ Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
106108
<< " stride = hidden_size_vectorized - stride * ix;\n"
107109
<< "}\n"
108110
<< "for (var i: u32 = 0; i < stride; i++) {\n"
109-
<< " let skip_value = skip[offset + i];\n"
111+
<< " let skip_offset = (offset + i) % (uniforms.skip_size / uniforms.components);\n"
112+
<< " let skip_value = skip[skip_offset];\n"
110113
<< " let input_value = x[offset + i];\n"
111114
<< " let value = input_value + skip_value" << bias << ";\n"
112115
<< " output[offset + i] = value;\n"
@@ -162,6 +165,9 @@ Status SkipLayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeCo
162165
const uint32_t norm_count = onnxruntime::narrow<uint32_t>(x_shape.SizeToDimension(x_shape.NumDimensions() - 1));
163166
const bool split_hidden_dim = hidden_size % 512 == 0 && norm_count == 1;
164167

168+
const auto skip_shape = skip->Shape();
169+
const uint32_t skip_size = onnxruntime::narrow<uint32_t>(skip_shape.Size());
170+
165171
SkipLayerNormProgram program{beta != nullptr, bias != nullptr, epsilon_, hidden_size, has_input_skip_bias_sum, simplified, split_hidden_dim};
166172
program
167173
.CacheHint(simplified, has_input_skip_bias_sum, split_hidden_dim)
@@ -178,6 +184,9 @@ Status SkipLayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeCo
178184
})
179185
.AddUniformVariables({
180186
{static_cast<float>(epsilon_)},
187+
})
188+
.AddUniformVariables({
189+
{static_cast<uint32_t>(skip_size)},
181190
});
182191

183192
if (split_hidden_dim) {

onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ class SkipLayerNormProgram final : public Program<SkipLayerNormProgram> {
3131
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
3232
{"components", ProgramUniformVariableDataType::Uint32},
3333
{"hidden_size", ProgramUniformVariableDataType::Uint32},
34-
{"epsilon", ProgramUniformVariableDataType::Float32});
34+
{"epsilon", ProgramUniformVariableDataType::Float32},
35+
{"skip_size", ProgramUniformVariableDataType::Uint32});
3536

3637
private:
3738
bool hasBeta_;

0 commit comments

Comments
 (0)