Skip to content

Commit 40d36b1

Browse files
authored
Improve WebGPU MatMulNBits to support zero point for 2 bits (#27285)
### Description The existing WebGPU MatMulNBits op does not support a zero-point input for 2-bit quantization, which blocks some models. This PR enables zero-point support for the 2-bit path. Unit tests are included for coverage.
1 parent bcd5605 commit 40d36b1

File tree

5 files changed

+198
-53
lines changed

5 files changed

+198
-53
lines changed

js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts

Lines changed: 90 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -123,21 +123,42 @@ export const createMatMulNBitsProgramInfo = (
123123
}
124124
})();
125125

126+
// Number of quantized values per u32 word and passes needed (each pass extracts 8 values).
127+
const valuesPerWord = Math.floor(32 / attributes.bits); // Q4=8, Q2=16
128+
const passesPerWord = Math.floor(valuesPerWord / 8); // Q4=1, Q2=2
129+
126130
const processOneWord = (): string => {
127-
let calcStr = `
128-
// reuse a data
129-
var input_offset = ${a.indicesToOffset(`${a.type.indices}(batch, row, word_offset)`)};
130-
var a_data: ${qDqDataType};
131-
for (var j: u32 = 0; j < ${8 / aComponents}; j++) {
132-
a_data[j] = ${a.getByOffset('input_offset')};
133-
input_offset++;
131+
let calcStr = '';
132+
for (let pass = 0; pass < passesPerWord; pass++) {
133+
// Each pass processes 8 values from the current u32 word.
134+
// For Q4 (pass=0): shift by 0 and 4. For Q2 (pass 0: shift 0,2; pass 1: shift 4,6).
135+
const lowerShift = pass * attributes.bits * 4; // bit offset for lower group within each byte
136+
const upperShift = lowerShift + attributes.bits;
137+
calcStr += `
138+
// reuse a data (pass ${pass})
139+
var input_offset${pass > 0 ? pass : ''} = ${pass === 0 ? a.indicesToOffset(`${a.type.indices}(batch, row, word_offset)`) : `input_offset`};
140+
var a_data${pass > 0 ? pass : ''}: ${qDqDataType};
141+
for (var j${pass > 0 ? pass : ''}: u32 = 0; j${pass > 0 ? pass : ''} < ${8 / aComponents}; j${pass > 0 ? pass : ''}++) {
142+
a_data${pass > 0 ? pass : ''}[j${pass > 0 ? pass : ''}] = ${a.getByOffset(`input_offset${pass > 0 ? pass : ''}`)};
143+
input_offset${pass > 0 ? pass : ''}++;
134144
}
135145
`;
136-
for (let c = 0; c < components * outputNumber; c++) {
137-
calcStr += `
146+
for (let c = 0; c < components * outputNumber; c++) {
147+
calcStr += `
138148
b_value = ${bComponents === 1 ? `b${c}_data` : `b${c}_data[i]`};
139-
b_value_lower = unpack4xU8(b_value & b_mask);
140-
b_value_upper = unpack4xU8((b_value >> 4) & b_mask);
149+
${
150+
attributes.bits === 2
151+
? `{
152+
let half_word = b_value >> ${pass * 16}u;
153+
let byte_lo = half_word & 0xFFu;
154+
let byte_hi = (half_word >> 8u) & 0xFFu;
155+
let spread_word = (byte_lo & 0xFu) | ((byte_lo >> 4u) << 8u) | ((byte_hi & 0xFu) << 16u) | ((byte_hi >> 4u) << 24u);
156+
b_value_lower = unpack4xU8(spread_word & b_mask);
157+
b_value_upper = unpack4xU8((spread_word >> 2u) & b_mask);
158+
}`
159+
: `b_value_lower = unpack4xU8((b_value >> ${lowerShift}u) & b_mask);
160+
b_value_upper = unpack4xU8((b_value >> ${upperShift}u) & b_mask);`
161+
}
141162
b_quantized_values = ${qDqDataType}(${Array.from(
142163
{ length: 4 },
143164
(_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
@@ -159,11 +180,12 @@ export const createMatMulNBitsProgramInfo = (
159180
(_, i) =>
160181
`${
161182
aComponents === 1
162-
? `a_data[${i}] * b_dequantized_values[${i}]`
163-
: `dot(a_data[${i}], b_dequantized_values[${i}])`
183+
? `a_data${pass > 0 ? pass : ''}[${i}] * b_dequantized_values[${i}]`
184+
: `dot(a_data${pass > 0 ? pass : ''}[${i}], b_dequantized_values[${i}])`
164185
}`,
165186
).join(' + ')};
166187
`;
188+
}
167189
}
168190
return calcStr;
169191
};
@@ -173,16 +195,17 @@ export const createMatMulNBitsProgramInfo = (
173195
${
174196
zeroPoints
175197
? `
176-
let zero_point_bytes_per_col = (nBlocksPerCol + 1) / 2;
198+
let zero_point_values_per_byte: u32 = ${Math.floor(8 / attributes.bits)}u;
199+
let zero_point_bytes_per_col = (nBlocksPerCol + zero_point_values_per_byte - 1u) / zero_point_values_per_byte;
177200
var zero_point_byte_count: u32;
178201
var zero_point_word_index: u32;
179202
var zero_point_byte_offset: u32;
180-
let zero_point_nibble_offset: u32 = block & 0x1u;
203+
let zero_point_sub_offset: u32 = block % zero_point_values_per_byte;
181204
var zero_point_bits_offset: u32;
182205
var zero_point_word: u32;`
183206
: `
184-
// The default zero point is 8 for unsigned 4-bit quantization.
185-
let zero_point = ${dataType}(${8.0});`
207+
// The default zero point is ${Math.pow(2, attributes.bits - 1)} for unsigned ${attributes.bits}-bit quantization.
208+
let zero_point = ${dataType}(${Math.pow(2, attributes.bits - 1).toFixed(1)});`
186209
}
187210
`;
188211
for (let c = 0; c < components * outputNumber; c++) {
@@ -191,12 +214,12 @@ export const createMatMulNBitsProgramInfo = (
191214
${
192215
zeroPoints
193216
? `
194-
zero_point_byte_count = col_index * zero_point_bytes_per_col + (block >> 0x1u);
217+
zero_point_byte_count = col_index * zero_point_bytes_per_col + (block / zero_point_values_per_byte);
195218
zero_point_word_index = zero_point_byte_count >> 0x2u;
196219
zero_point_byte_offset = zero_point_byte_count & 0x3u;
197-
zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
220+
zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_sub_offset * ${attributes.bits}u);
198221
zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;
199-
let zero_point${c} = ${dataType}((zero_point_word) & 0xFu);`
222+
let zero_point${c} = ${dataType}((zero_point_word) & ${attributes.bits === 2 ? '0x3u' : '0xFu'});`
200223
: ''
201224
}
202225
col_index += 1;`;
@@ -212,7 +235,7 @@ export const createMatMulNBitsProgramInfo = (
212235
}
213236
calcStr += `
214237
var b_value: u32;
215-
let b_mask: u32 = 0x0F0F0F0Fu;
238+
let b_mask: u32 = ${attributes.bits === 2 ? '0x03030303u' : '0x0F0F0F0Fu'};
216239
var b_value_lower: vec4<u32>;
217240
var b_value_upper: vec4<u32>;
218241
var b_quantized_values: ${qDqDataType};
@@ -237,7 +260,7 @@ export const createMatMulNBitsProgramInfo = (
237260
${prepareBData()}
238261
for (var i: u32 = 0; i < ${bComponents}; i++) {
239262
${processOneWord()}
240-
word_offset += ${8 / aComponents};
263+
word_offset += ${valuesPerWord / aComponents};
241264
}
242265
}
243266
}
@@ -291,7 +314,8 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
291314
const workgroupSize = 128;
292315
const workgroupY = dimBOuter % 8 === 0 ? 8 : dimBOuter % 4 === 0 ? 4 : 1;
293316
const workgroupX = workgroupSize / workgroupY;
294-
const tileSize = workgroupX * bComponents * 8; // each uint32 has 8 data.
317+
const valuesPerWordBs32 = Math.floor(32 / attributes.bits); // Q4=8, Q2=16
318+
const tileSize = workgroupX * bComponents * valuesPerWordBs32; // each uint32 has valuesPerWord data.
295319
const aLengthPerTile = tileSize / aComponents;
296320
const blocksPerTile = tileSize / attributes.blockSize;
297321
const dispatchSize = ShapeUtil.size(outputShape) / workgroupY;
@@ -376,36 +400,59 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
376400
${
377401
zeroPoints
378402
? `
379-
let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;
380-
let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u);
403+
let zero_point_values_per_byte: u32 = ${Math.floor(8 / attributes.bits)}u;
404+
let zero_point_bytes_per_col = (n_blocks_per_col + zero_point_values_per_byte - 1u) / zero_point_values_per_byte;
405+
let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block / zero_point_values_per_byte);
381406
let zero_point_word_index = zero_point_byte_count >> 0x2u;
382407
let zero_point_byte_offset = zero_point_byte_count & 0x3u;
383-
let zero_point_nibble_offset: u32 = block & 0x1u;
384-
let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
408+
let zero_point_sub_offset: u32 = block % zero_point_values_per_byte;
409+
let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_sub_offset * ${attributes.bits}u);
385410
let zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;
386-
let zero_point = ${dataType}((zero_point_word) & 0xFu);`
411+
let zero_point = ${dataType}((zero_point_word) & ${attributes.bits === 2 ? '0x3u' : '0xFu'});`
387412
: `
388-
// The default zero point is 8 for unsigned 4-bit quantization.
389-
let zero_point = ${dataType}(${8.0});`
413+
// The default zero point is ${Math.pow(2, attributes.bits - 1)} for unsigned ${attributes.bits}-bit quantization.
414+
let zero_point = ${dataType}(${Math.pow(2, attributes.bits - 1).toFixed(1)});`
390415
}
391416
let scale = ${scales.getByOffset(`b_row * n_blocks_per_col + block`)};
392417
let b_data = ${b.getByIndices(`${b.type.indices}(b_row, block, 0)`)};
393418
var word_offset = local_id.x * ${attributes.blockSize / aComponents};
394419
for (var i: u32 = 0; i < ${bComponents}; i++) {
395-
${readA()}
396420
let b_value = ${bComponents === 1 ? `b_data` : `b_data[i]`};
397-
let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
398-
let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
399-
let b_quantized_values = mat2x4<${dataType}>(${Array.from(
400-
{ length: 4 },
401-
(_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
402-
).join(', ')});
403-
let b_dequantized_values = (b_quantized_values - mat2x4<${dataType}>(${Array(8).fill('zero_point').join(',')})) * scale;
404-
inter_results[local_id.y][local_id.x] += ${Array.from(
405-
{ length: 2 },
406-
(_, i) => `${`dot(a_data${i}, b_dequantized_values[${i}])`}`,
407-
).join(' + ')};
408-
word_offset += ${8 / aComponents};
421+
${(() => {
422+
const passesPerWordBs32 = Math.floor(valuesPerWordBs32 / 8);
423+
let code = '';
424+
for (let pass = 0; pass < passesPerWordBs32; pass++) {
425+
const lowerShift = pass * attributes.bits * 4;
426+
const upperShift = lowerShift + attributes.bits;
427+
code += `
428+
${readA()}
429+
{${
430+
attributes.bits === 2
431+
? `
432+
let half_word = b_value >> ${pass * 16}u;
433+
let byte_lo = half_word & 0xFFu;
434+
let byte_hi = (half_word >> 8u) & 0xFFu;
435+
let spread_word = (byte_lo & 0xFu) | ((byte_lo >> 4u) << 8u) | ((byte_hi & 0xFu) << 16u) | ((byte_hi >> 4u) << 24u);
436+
let b_value_lower = unpack4xU8(spread_word & 0x03030303u);
437+
let b_value_upper = unpack4xU8((spread_word >> 2u) & 0x03030303u);`
438+
: `
439+
let b_value_lower = unpack4xU8((b_value >> ${lowerShift}u) & 0x0F0F0F0Fu);
440+
let b_value_upper = unpack4xU8((b_value >> ${upperShift}u) & 0x0F0F0F0Fu);`
441+
}
442+
let b_quantized_values = mat2x4<${dataType}>(${Array.from(
443+
{ length: 4 },
444+
(_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
445+
).join(', ')});
446+
let b_dequantized_values = (b_quantized_values - mat2x4<${dataType}>(${Array(8).fill('zero_point').join(',')})) * scale;
447+
inter_results[local_id.y][local_id.x] += ${Array.from(
448+
{ length: 2 },
449+
(_, i) => `${`dot(a_data${i}, b_dequantized_values[${i}])`}`,
450+
).join(' + ')};
451+
}
452+
word_offset += ${8 / aComponents};`;
453+
}
454+
return code;
455+
})()}
409456
}
410457
workgroupBarrier();
411458
}

onnxruntime/contrib_ops/js/quantization/matmul_nbits.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ class MatMulNBits final : public JsKernel {
1717
accuracy_level_{info.GetAttrOrDefault<int64_t>("accuracy_level", 0)},
1818
nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))},
1919
block_size_{narrow<size_t>(info.GetAttr<int64_t>("block_size"))} {
20-
ORT_ENFORCE(nbits_ == 4,
21-
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
20+
ORT_ENFORCE(nbits_ == 4 || nbits_ == 2,
21+
"Only 2b and 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
2222
ORT_ENFORCE(block_size_ >= 16 && !(block_size_ & (block_size_ - 1)),
2323
"Block size must be a power of 2 and greater than or equal to 16.");
2424
JSEP_INIT_KERNEL_ATTRIBUTE(MatMulNBits, ({

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
117117
ORT_ENFORCE(g_idx == nullptr, "group_idx as input is not supported yet.");
118118

119119
const bool has_zero_points = zero_points != nullptr;
120-
const uint32_t nbits = onnxruntime::narrow<uint32_t>(bits_);
121120
if (has_zero_points) {
122121
ORT_ENFORCE(zero_points->DataType() == DataTypeImpl::GetType<uint8_t>(), "Currently, only uint8 is supported for zero points, but got ", zero_points->DataType());
123-
ORT_ENFORCE(nbits != 2, "Currently, zero points are not supported for Q2 quantization.");
124122
}
125123

126124
MatMulComputeHelper helper;
@@ -205,9 +203,13 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
205203
const uint32_t components_a = GetMaxComponents(K);
206204
const uint32_t components_b = GetMaxComponents(blob_size_in_words);
207205
uint32_t components = GetMaxComponents(N);
208-
// zero_points has shape[N * CeilDiv(n_blocks_per_col * bits, 8)]. So here we need to check whether n_blocks_per_col is divisible by 8/nbits.
209-
// For bits==4, this is counted by elements of uint4. Need add 1 if not divisible by 2.
210-
uint32_t zero_blocks_per_col = n_blocks_per_col % (8 / nbits) == 0 ? n_blocks_per_col : n_blocks_per_col + 1;
206+
// zero_points has shape[N * CeilDiv(n_blocks_per_col * bits, 8)].
207+
// The shader uses a flat linear index to address individual n-bit zero point values.
208+
// Since each column's zero points are byte-aligned in the packed buffer, we must round
209+
// n_blocks_per_col up to the next multiple of (8/nbits) — the number of zero point
210+
// values per byte — so that the linear stride correctly skips byte-boundary padding.
211+
const uint32_t zp_elements_per_byte = 8 / static_cast<uint32_t>(nbits);
212+
uint32_t zero_blocks_per_col = (n_blocks_per_col + zp_elements_per_byte - 1) / zp_elements_per_byte * zp_elements_per_byte;
211213

212214
#if !defined(__wasm__)
213215
int32_t subgroup_matrix_config_index = -1;
@@ -219,7 +221,9 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
219221
#endif
220222

221223
// On FP32 only GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M.
224+
// DP4A Q2 path uses a hardcoded LUT with zero_point=2, so skip DP4A for Q2 with custom zero points.
222225
if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType<float>() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) &&
226+
!(has_zero_points && nbits == 2) &&
223227
CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a)) {
224228
return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, bias, batch_count, M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, static_cast<uint32_t>(nbits), context, y, weight_index);
225229
}

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
const bit_mask = 0xFFu;
2121
#elif n_bits == 2
2222
const default_zero_point = 2;
23+
const bit_mask = 0x3u;
2324
#endif
2425

2526
#if has_zero_points

0 commit comments

Comments
 (0)