Skip to content

Commit 4418d1b

Browse files
qjia7alex-spacemit
authored andcommitted
[webgpu] Fix errors found by matmul2bits tests (microsoft#26880)
Fix bugs found in microsoft#26862.
1 parent 1f8fd82 commit 4418d1b

File tree

2 files changed

+62
-6
lines changed

2 files changed

+62
-6
lines changed

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,21 @@ $MAIN {
8484
#if n_bits == 4
8585
var sum = output_element_t(0);
8686
var a_offset = idx * (8 / component_a) * component_b;
87+
#if component_b == 1
88+
let b_value_lower = vec4<output_element_t>(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4<output_element_t>(zero);
89+
let b_value_upper = vec4<output_element_t>(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4<output_element_t>(zero);
90+
let b0 = vec4<output_element_t>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]) * scale_b;
91+
let b1 = vec4<output_element_t>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]) * scale_b;
92+
#if component_a == 1
93+
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b0) +
94+
dot(vec4<output_element_t>(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]), b1);
95+
#elif component_a == 2
96+
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b0) +
97+
dot(vec4<output_element_t>(tile_A[a_offset + 2], tile_A[a_offset + 3]), b1);
98+
#elif component_a == 4
99+
sum += dot(tile_A[a_offset], b0) + dot(tile_A[a_offset + 1], b1);
100+
#endif
101+
#else
87102
for (var i = 0u; i < component_b; i++) {
88103
let b_value_lower = vec4<output_element_t>(unpack4xU8(b_value[i] & 0x0F0F0F0Fu)) - vec4<output_element_t>(zero);
89104
let b_value_upper = vec4<output_element_t>(unpack4xU8((b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4<output_element_t>(zero);
@@ -102,25 +117,63 @@ $MAIN {
102117
a_offset += 2;
103118
#endif
104119
}
120+
#endif
105121
#elif n_bits == 8
106122
var sum = output_element_t(0);
107123
var a_offset = idx * (4 / component_a) * component_b;
124+
#if component_b == 1
125+
let b_value_unpacked = (vec4<output_element_t>(unpack4xU8(b_value)) - vec4<output_element_t>(zero)) * scale_b;
126+
#if component_a == 1
127+
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value_unpacked);
128+
#elif component_a == 2
129+
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b_value_unpacked);
130+
#elif component_a == 4
131+
sum += dot(tile_A[a_offset], b_value_unpacked);
132+
#endif
133+
#else
108134
for (var i = 0u; i < component_b; i++) {
109-
let b_value = (vec4<output_element_t>(unpack4xU8(b_value[i])) - vec4<output_element_t>(zero)) * scale_b;
135+
let b_value_unpacked = (vec4<output_element_t>(unpack4xU8(b_value[i])) - vec4<output_element_t>(zero)) * scale_b;
110136
#if component_a == 1
111-
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value);
137+
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value_unpacked);
112138
a_offset += 4;
113139
#elif component_a == 2
114-
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b_value);
140+
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b_value_unpacked);
115141
a_offset += 2;
116142
#elif component_a == 4
117-
sum += dot(tile_A[a_offset], b_value);
143+
sum += dot(tile_A[a_offset], b_value_unpacked);
118144
a_offset += 1;
119145
#endif
120146
}
147+
#endif
121148
#elif n_bits == 2
122149
var sum = output_element_t(0);
123150
var a_offset = idx * (16 / component_a) * component_b;
151+
#if component_b == 1
152+
let b_data_0 = vec4<output_element_t>(unpack4xU8(b_value & 0x03030303u)) - vec4<output_element_t>(zero);
153+
let b_data_1 = vec4<output_element_t>(unpack4xU8((b_value >> 2) & 0x03030303u)) - vec4<output_element_t>(zero);
154+
let b_data_2 = vec4<output_element_t>(unpack4xU8((b_value >> 4) & 0x03030303u)) - vec4<output_element_t>(zero);
155+
let b_data_3 = vec4<output_element_t>(unpack4xU8((b_value >> 6) & 0x03030303u)) - vec4<output_element_t>(zero);
156+
157+
let b0 = vec4<output_element_t>(b_data_0[0], b_data_1[0], b_data_2[0], b_data_3[0]) * scale_b;
158+
let b1 = vec4<output_element_t>(b_data_0[1], b_data_1[1], b_data_2[1], b_data_3[1]) * scale_b;
159+
let b2 = vec4<output_element_t>(b_data_0[2], b_data_1[2], b_data_2[2], b_data_3[2]) * scale_b;
160+
let b3 = vec4<output_element_t>(b_data_0[3], b_data_1[3], b_data_2[3], b_data_3[3]) * scale_b;
161+
162+
#if component_a == 1
163+
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b0) +
164+
dot(vec4<output_element_t>(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]), b1) +
165+
dot(vec4<output_element_t>(tile_A[a_offset + 8], tile_A[a_offset + 9], tile_A[a_offset + 10], tile_A[a_offset + 11]), b2) +
166+
dot(vec4<output_element_t>(tile_A[a_offset + 12], tile_A[a_offset + 13], tile_A[a_offset + 14], tile_A[a_offset + 15]), b3);
167+
#elif component_a == 2
168+
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b0) +
169+
dot(vec4<output_element_t>(tile_A[a_offset + 2], tile_A[a_offset + 3]), b1) +
170+
dot(vec4<output_element_t>(tile_A[a_offset + 4], tile_A[a_offset + 5]), b2) +
171+
dot(vec4<output_element_t>(tile_A[a_offset + 6], tile_A[a_offset + 7]), b3);
172+
#elif component_a == 4
173+
sum += dot(tile_A[a_offset], b0) + dot(tile_A[a_offset + 1], b1) +
174+
dot(tile_A[a_offset + 2], b2) + dot(tile_A[a_offset + 3], b3);
175+
#endif
176+
#else
124177
for (var i = 0u; i < component_b; i++) {
125178
let b_data_0 = vec4<output_element_t>(unpack4xU8(b_value[i] & 0x03030303u)) - vec4<output_element_t>(zero);
126179
let b_data_1 = vec4<output_element_t>(unpack4xU8((b_value[i] >> 2) & 0x03030303u)) - vec4<output_element_t>(zero);
@@ -150,6 +203,7 @@ $MAIN {
150203
a_offset += 4;
151204
#endif
152205
}
206+
#endif
153207
#endif
154208

155209
inter_results[local_row_offset + idy][idx] += sum;

onnxruntime/test/contrib_ops/matmul_2bits_test.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,12 @@ void RunTest2Bits(const TestOptions2Bits& opts) {
197197

198198
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
199199
if constexpr (std::is_same<T1, float>::value) {
200-
execution_providers.emplace_back(DefaultCpuExecutionProvider());
201200
#ifdef USE_WEBGPU
202-
execution_providers.push_back(DefaultWebGpuExecutionProvider());
201+
if (!opts.has_zero_point) {
202+
execution_providers.push_back(DefaultWebGpuExecutionProvider());
203+
}
203204
#endif
205+
execution_providers.emplace_back(DefaultCpuExecutionProvider());
204206
test.ConfigEps(std::move(execution_providers));
205207
test.RunWithConfig();
206208
}

0 commit comments

Comments
 (0)