@@ -40,7 +40,8 @@ Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
4040 << " let workgroup_half_idx = uniforms.hidden_size / (workgroup_size_x * 4);\n "
4141 << " if (workgroup_idx >= workgroup_half_idx) {\n "
4242 << " offset = (workgroup_idx - workgroup_half_idx) * workgroup_size_x + local_idx;\n "
43- << " let skip_value = skip[offset];\n "
43+ << " let skip_offset = offset % (uniforms.skip_size / 4);\n "
44+ << " let skip_value = skip[skip_offset];\n "
4445 << " let input_value = x[offset];\n "
4546 << " let value = input_value + skip_value" << (hasBias_ ? " + bias[offset]" : " " ) << " ;\n "
4647 << " input_skip_bias_sum[offset] = value;\n "
@@ -56,7 +57,8 @@ Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
5657 << " var cur_input_skip_bias_sum = x_value_t(0);\n "
5758 << " for (var i: u32 = 0; i < uniforms.hidden_size / (workgroup_size_x * 4); i++) {\n "
5859 << " let input_offset = i * workgroup_size_x + local_idx;\n "
59- << " let skip_value = skip[input_offset];\n "
60+ << " let skip_input_offset = input_offset % (uniforms.skip_size / 4);\n "
61+ << " let skip_value = skip[skip_input_offset];\n "
6062 << " let input_value = x[input_offset];\n "
6163 << " let value = input_value + skip_value" << (hasBias_ ? " + bias[input_offset]" : " " ) << " ;\n "
6264 << " if (i == workgroup_idx) {\n "
@@ -106,7 +108,8 @@ Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
106108 << " stride = hidden_size_vectorized - stride * ix;\n "
107109 << " }\n "
108110 << " for (var i: u32 = 0; i < stride; i++) {\n "
109- << " let skip_value = skip[offset + i];\n "
111+ << " let skip_offset = (offset + i) % (uniforms.skip_size / uniforms.components);\n "
112+ << " let skip_value = skip[skip_offset];\n "
110113 << " let input_value = x[offset + i];\n "
111114 << " let value = input_value + skip_value" << bias << " ;\n "
112115 << " output[offset + i] = value;\n "
@@ -162,6 +165,9 @@ Status SkipLayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeCo
162165 const uint32_t norm_count = onnxruntime::narrow<uint32_t >(x_shape.SizeToDimension (x_shape.NumDimensions () - 1 ));
163166 const bool split_hidden_dim = hidden_size % 512 == 0 && norm_count == 1 ;
164167
168+ const auto skip_shape = skip->Shape ();
169+ const uint32_t skip_size = onnxruntime::narrow<uint32_t >(skip_shape.Size ());
170+
165171 SkipLayerNormProgram program{beta != nullptr , bias != nullptr , epsilon_, hidden_size, has_input_skip_bias_sum, simplified, split_hidden_dim};
166172 program
167173 .CacheHint (simplified, has_input_skip_bias_sum, split_hidden_dim)
@@ -178,6 +184,9 @@ Status SkipLayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeCo
178184 })
179185 .AddUniformVariables ({
180186 {static_cast <float >(epsilon_)},
187+ })
188+ .AddUniformVariables ({
189+ {static_cast <uint32_t >(skip_size)},
181190 });
182191
183192 if (split_hidden_dim) {
0 commit comments