We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent dea3b47 commit 54dced4Copy full SHA for 54dced4
onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
@@ -76,14 +76,10 @@ Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const
76
let scale = max(max_temp[0], max_temp[1]);
77
let norm_a = a_values[local_row][local_col]/scale;
78
output[global_idx] = pack4x8snorm(vec4<f32>(norm_a));
79
- if (local_idx == 0u)
+ if (local_col == 0u)
80
{
81
// 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
82
- scales[workgroup_idx * 2] = scale/127;
83
- } else if (local_idx == 32u)
84
- {
85
- // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
86
- scales[workgroup_idx * 2 + 1] = scale/127;
+ scales[workgroup_idx * 2 + local_row] = scale/127;
87
}
88
)MAIN_FN";
89
return Status::OK();
0 commit comments