Skip to content

Commit bf95ce1

Browse files
committed
Merge branch 'main' into copilot/fix-opset-versions-sync
2 parents fc0866d + 0a478c0 commit bf95ce1

File tree

9 files changed

+774
-66
lines changed

9 files changed

+774
-66
lines changed

.github/workflows/linux_ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ jobs:
9494
dockerfile_path: tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile
9595
docker_image_repo: onnxruntimecpubuildciaarch64
9696
# ASan disabled due to excessive runtime (>4hr). Includes wheel build for basic checks.
97-
extra_build_flags: '--use_binskim_compliant_compile_flags --build_shared_lib'
97+
extra_build_flags: '--use_binskim_compliant_compile_flags --build_shared_lib --enable_arm_neon_nchwc'
9898
job_identifier: build-linux-arm64-debug
9999
secrets:
100100
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts

Lines changed: 90 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -123,21 +123,42 @@ export const createMatMulNBitsProgramInfo = (
123123
}
124124
})();
125125

126+
// Number of quantized values per u32 word and passes needed (each pass extracts 8 values).
127+
const valuesPerWord = Math.floor(32 / attributes.bits); // Q4=8, Q2=16
128+
const passesPerWord = Math.floor(valuesPerWord / 8); // Q4=1, Q2=2
129+
126130
const processOneWord = (): string => {
127-
let calcStr = `
128-
// reuse a data
129-
var input_offset = ${a.indicesToOffset(`${a.type.indices}(batch, row, word_offset)`)};
130-
var a_data: ${qDqDataType};
131-
for (var j: u32 = 0; j < ${8 / aComponents}; j++) {
132-
a_data[j] = ${a.getByOffset('input_offset')};
133-
input_offset++;
131+
let calcStr = '';
132+
for (let pass = 0; pass < passesPerWord; pass++) {
133+
// Each pass processes 8 values from the current u32 word.
134+
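// Q4 (single pass): lowerShift=0 and upperShift=4 select the low/high nibble of each byte. Q2 does not use these shifts: each pass takes one 16-bit half of the word (b_value >> pass * 16), spreads its four nibbles into separate bytes, then extracts the 2-bit values with shifts 0 and 2.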
135+
const lowerShift = pass * attributes.bits * 4; // bit offset for lower group within each byte
136+
const upperShift = lowerShift + attributes.bits;
137+
calcStr += `
138+
// reuse a data (pass ${pass})
139+
var input_offset${pass > 0 ? pass : ''} = ${pass === 0 ? a.indicesToOffset(`${a.type.indices}(batch, row, word_offset)`) : `input_offset`};
140+
var a_data${pass > 0 ? pass : ''}: ${qDqDataType};
141+
for (var j${pass > 0 ? pass : ''}: u32 = 0; j${pass > 0 ? pass : ''} < ${8 / aComponents}; j${pass > 0 ? pass : ''}++) {
142+
a_data${pass > 0 ? pass : ''}[j${pass > 0 ? pass : ''}] = ${a.getByOffset(`input_offset${pass > 0 ? pass : ''}`)};
143+
input_offset${pass > 0 ? pass : ''}++;
134144
}
135145
`;
136-
for (let c = 0; c < components * outputNumber; c++) {
137-
calcStr += `
146+
for (let c = 0; c < components * outputNumber; c++) {
147+
calcStr += `
138148
b_value = ${bComponents === 1 ? `b${c}_data` : `b${c}_data[i]`};
139-
b_value_lower = unpack4xU8(b_value & b_mask);
140-
b_value_upper = unpack4xU8((b_value >> 4) & b_mask);
149+
${
150+
attributes.bits === 2
151+
? `{
152+
let half_word = b_value >> ${pass * 16}u;
153+
let byte_lo = half_word & 0xFFu;
154+
let byte_hi = (half_word >> 8u) & 0xFFu;
155+
let spread_word = (byte_lo & 0xFu) | ((byte_lo >> 4u) << 8u) | ((byte_hi & 0xFu) << 16u) | ((byte_hi >> 4u) << 24u);
156+
b_value_lower = unpack4xU8(spread_word & b_mask);
157+
b_value_upper = unpack4xU8((spread_word >> 2u) & b_mask);
158+
}`
159+
: `b_value_lower = unpack4xU8((b_value >> ${lowerShift}u) & b_mask);
160+
b_value_upper = unpack4xU8((b_value >> ${upperShift}u) & b_mask);`
161+
}
141162
b_quantized_values = ${qDqDataType}(${Array.from(
142163
{ length: 4 },
143164
(_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
@@ -159,11 +180,12 @@ export const createMatMulNBitsProgramInfo = (
159180
(_, i) =>
160181
`${
161182
aComponents === 1
162-
? `a_data[${i}] * b_dequantized_values[${i}]`
163-
: `dot(a_data[${i}], b_dequantized_values[${i}])`
183+
? `a_data${pass > 0 ? pass : ''}[${i}] * b_dequantized_values[${i}]`
184+
: `dot(a_data${pass > 0 ? pass : ''}[${i}], b_dequantized_values[${i}])`
164185
}`,
165186
).join(' + ')};
166187
`;
188+
}
167189
}
168190
return calcStr;
169191
};
@@ -173,16 +195,17 @@ export const createMatMulNBitsProgramInfo = (
173195
${
174196
zeroPoints
175197
? `
176-
let zero_point_bytes_per_col = (nBlocksPerCol + 1) / 2;
198+
let zero_point_values_per_byte: u32 = ${Math.floor(8 / attributes.bits)}u;
199+
let zero_point_bytes_per_col = (nBlocksPerCol + zero_point_values_per_byte - 1u) / zero_point_values_per_byte;
177200
var zero_point_byte_count: u32;
178201
var zero_point_word_index: u32;
179202
var zero_point_byte_offset: u32;
180-
let zero_point_nibble_offset: u32 = block & 0x1u;
203+
let zero_point_sub_offset: u32 = block % zero_point_values_per_byte;
181204
var zero_point_bits_offset: u32;
182205
var zero_point_word: u32;`
183206
: `
184-
// The default zero point is 8 for unsigned 4-bit quantization.
185-
let zero_point = ${dataType}(${8.0});`
207+
// The default zero point is ${Math.pow(2, attributes.bits - 1)} for unsigned ${attributes.bits}-bit quantization.
208+
let zero_point = ${dataType}(${Math.pow(2, attributes.bits - 1).toFixed(1)});`
186209
}
187210
`;
188211
for (let c = 0; c < components * outputNumber; c++) {
@@ -191,12 +214,12 @@ export const createMatMulNBitsProgramInfo = (
191214
${
192215
zeroPoints
193216
? `
194-
zero_point_byte_count = col_index * zero_point_bytes_per_col + (block >> 0x1u);
217+
zero_point_byte_count = col_index * zero_point_bytes_per_col + (block / zero_point_values_per_byte);
195218
zero_point_word_index = zero_point_byte_count >> 0x2u;
196219
zero_point_byte_offset = zero_point_byte_count & 0x3u;
197-
zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
220+
zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_sub_offset * ${attributes.bits}u);
198221
zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;
199-
let zero_point${c} = ${dataType}((zero_point_word) & 0xFu);`
222+
let zero_point${c} = ${dataType}((zero_point_word) & ${attributes.bits === 2 ? '0x3u' : '0xFu'});`
200223
: ''
201224
}
202225
col_index += 1;`;
@@ -212,7 +235,7 @@ export const createMatMulNBitsProgramInfo = (
212235
}
213236
calcStr += `
214237
var b_value: u32;
215-
let b_mask: u32 = 0x0F0F0F0Fu;
238+
let b_mask: u32 = ${attributes.bits === 2 ? '0x03030303u' : '0x0F0F0F0Fu'};
216239
var b_value_lower: vec4<u32>;
217240
var b_value_upper: vec4<u32>;
218241
var b_quantized_values: ${qDqDataType};
@@ -237,7 +260,7 @@ export const createMatMulNBitsProgramInfo = (
237260
${prepareBData()}
238261
for (var i: u32 = 0; i < ${bComponents}; i++) {
239262
${processOneWord()}
240-
word_offset += ${8 / aComponents};
263+
word_offset += ${valuesPerWord / aComponents};
241264
}
242265
}
243266
}
@@ -291,7 +314,8 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
291314
const workgroupSize = 128;
292315
const workgroupY = dimBOuter % 8 === 0 ? 8 : dimBOuter % 4 === 0 ? 4 : 1;
293316
const workgroupX = workgroupSize / workgroupY;
294-
const tileSize = workgroupX * bComponents * 8; // each uint32 has 8 data.
317+
const valuesPerWordBs32 = Math.floor(32 / attributes.bits); // Q4=8, Q2=16
318+
const tileSize = workgroupX * bComponents * valuesPerWordBs32; // each uint32 has valuesPerWord data.
295319
const aLengthPerTile = tileSize / aComponents;
296320
const blocksPerTile = tileSize / attributes.blockSize;
297321
const dispatchSize = ShapeUtil.size(outputShape) / workgroupY;
@@ -376,36 +400,59 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
376400
${
377401
zeroPoints
378402
? `
379-
let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;
380-
let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u);
403+
let zero_point_values_per_byte: u32 = ${Math.floor(8 / attributes.bits)}u;
404+
let zero_point_bytes_per_col = (n_blocks_per_col + zero_point_values_per_byte - 1u) / zero_point_values_per_byte;
405+
let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block / zero_point_values_per_byte);
381406
let zero_point_word_index = zero_point_byte_count >> 0x2u;
382407
let zero_point_byte_offset = zero_point_byte_count & 0x3u;
383-
let zero_point_nibble_offset: u32 = block & 0x1u;
384-
let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
408+
let zero_point_sub_offset: u32 = block % zero_point_values_per_byte;
409+
let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_sub_offset * ${attributes.bits}u);
385410
let zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;
386-
let zero_point = ${dataType}((zero_point_word) & 0xFu);`
411+
let zero_point = ${dataType}((zero_point_word) & ${attributes.bits === 2 ? '0x3u' : '0xFu'});`
387412
: `
388-
// The default zero point is 8 for unsigned 4-bit quantization.
389-
let zero_point = ${dataType}(${8.0});`
413+
// The default zero point is ${Math.pow(2, attributes.bits - 1)} for unsigned ${attributes.bits}-bit quantization.
414+
let zero_point = ${dataType}(${Math.pow(2, attributes.bits - 1).toFixed(1)});`
390415
}
391416
let scale = ${scales.getByOffset(`b_row * n_blocks_per_col + block`)};
392417
let b_data = ${b.getByIndices(`${b.type.indices}(b_row, block, 0)`)};
393418
var word_offset = local_id.x * ${attributes.blockSize / aComponents};
394419
for (var i: u32 = 0; i < ${bComponents}; i++) {
395-
${readA()}
396420
let b_value = ${bComponents === 1 ? `b_data` : `b_data[i]`};
397-
let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
398-
let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
399-
let b_quantized_values = mat2x4<${dataType}>(${Array.from(
400-
{ length: 4 },
401-
(_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
402-
).join(', ')});
403-
let b_dequantized_values = (b_quantized_values - mat2x4<${dataType}>(${Array(8).fill('zero_point').join(',')})) * scale;
404-
inter_results[local_id.y][local_id.x] += ${Array.from(
405-
{ length: 2 },
406-
(_, i) => `${`dot(a_data${i}, b_dequantized_values[${i}])`}`,
407-
).join(' + ')};
408-
word_offset += ${8 / aComponents};
421+
${(() => {
422+
const passesPerWordBs32 = Math.floor(valuesPerWordBs32 / 8);
423+
let code = '';
424+
for (let pass = 0; pass < passesPerWordBs32; pass++) {
425+
const lowerShift = pass * attributes.bits * 4;
426+
const upperShift = lowerShift + attributes.bits;
427+
code += `
428+
${readA()}
429+
{${
430+
attributes.bits === 2
431+
? `
432+
let half_word = b_value >> ${pass * 16}u;
433+
let byte_lo = half_word & 0xFFu;
434+
let byte_hi = (half_word >> 8u) & 0xFFu;
435+
let spread_word = (byte_lo & 0xFu) | ((byte_lo >> 4u) << 8u) | ((byte_hi & 0xFu) << 16u) | ((byte_hi >> 4u) << 24u);
436+
let b_value_lower = unpack4xU8(spread_word & 0x03030303u);
437+
let b_value_upper = unpack4xU8((spread_word >> 2u) & 0x03030303u);`
438+
: `
439+
let b_value_lower = unpack4xU8((b_value >> ${lowerShift}u) & 0x0F0F0F0Fu);
440+
let b_value_upper = unpack4xU8((b_value >> ${upperShift}u) & 0x0F0F0F0Fu);`
441+
}
442+
let b_quantized_values = mat2x4<${dataType}>(${Array.from(
443+
{ length: 4 },
444+
(_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
445+
).join(', ')});
446+
let b_dequantized_values = (b_quantized_values - mat2x4<${dataType}>(${Array(8).fill('zero_point').join(',')})) * scale;
447+
inter_results[local_id.y][local_id.x] += ${Array.from(
448+
{ length: 2 },
449+
(_, i) => `${`dot(a_data${i}, b_dequantized_values[${i}])`}`,
450+
).join(' + ')};
451+
}
452+
word_offset += ${8 / aComponents};`;
453+
}
454+
return code;
455+
})()}
409456
}
410457
workgroupBarrier();
411458
}

onnxruntime/contrib_ops/js/quantization/matmul_nbits.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ class MatMulNBits final : public JsKernel {
1717
accuracy_level_{info.GetAttrOrDefault<int64_t>("accuracy_level", 0)},
1818
nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))},
1919
block_size_{narrow<size_t>(info.GetAttr<int64_t>("block_size"))} {
20-
ORT_ENFORCE(nbits_ == 4,
21-
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
20+
ORT_ENFORCE(nbits_ == 4 || nbits_ == 2,
21+
"Only 2b and 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
2222
ORT_ENFORCE(block_size_ >= 16 && !(block_size_ & (block_size_ - 1)),
2323
"Block size must be a power of 2 and greater than or equal to 16.");
2424
JSEP_INIT_KERNEL_ATTRIBUTE(MatMulNBits, ({

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
117117
ORT_ENFORCE(g_idx == nullptr, "group_idx as input is not supported yet.");
118118

119119
const bool has_zero_points = zero_points != nullptr;
120-
const uint32_t nbits = onnxruntime::narrow<uint32_t>(bits_);
121120
if (has_zero_points) {
122121
ORT_ENFORCE(zero_points->DataType() == DataTypeImpl::GetType<uint8_t>(), "Currently, only uint8 is supported for zero points, but got ", zero_points->DataType());
123-
ORT_ENFORCE(nbits != 2, "Currently, zero points are not supported for Q2 quantization.");
124122
}
125123

126124
MatMulComputeHelper helper;
@@ -205,9 +203,13 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
205203
const uint32_t components_a = GetMaxComponents(K);
206204
const uint32_t components_b = GetMaxComponents(blob_size_in_words);
207205
uint32_t components = GetMaxComponents(N);
208-
// zero_points has shape[N * CeilDiv(n_blocks_per_col * bits, 8)]. So here we need to check whether n_blocks_per_col is divisible by 8/nbits.
209-
// For bits==4, this is counted by elements of uint4. Need add 1 if not divisible by 2.
210-
uint32_t zero_blocks_per_col = n_blocks_per_col % (8 / nbits) == 0 ? n_blocks_per_col : n_blocks_per_col + 1;
206+
// zero_points has shape[N * CeilDiv(n_blocks_per_col * bits, 8)].
207+
// The shader uses a flat linear index to address individual n-bit zero point values.
208+
// Since each column's zero points are byte-aligned in the packed buffer, we must round
209+
// n_blocks_per_col up to the next multiple of (8/nbits) — the number of zero point
210+
// values per byte — so that the linear stride correctly skips byte-boundary padding.
211+
const uint32_t zp_elements_per_byte = 8 / static_cast<uint32_t>(nbits);
212+
uint32_t zero_blocks_per_col = (n_blocks_per_col + zp_elements_per_byte - 1) / zp_elements_per_byte * zp_elements_per_byte;
211213

212214
#if !defined(__wasm__)
213215
int32_t subgroup_matrix_config_index = -1;
@@ -219,7 +221,9 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
219221
#endif
220222

221223
// On FP32 only GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M.
224+
// DP4A Q2 path uses a hardcoded LUT with zero_point=2, so skip DP4A for Q2 with custom zero points.
222225
if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType<float>() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) &&
226+
!(has_zero_points && nbits == 2) &&
223227
CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a)) {
224228
return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, bias, batch_count, M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, static_cast<uint32_t>(nbits), context, y, weight_index);
225229
}

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
const bit_mask = 0xFFu;
2121
#elif n_bits == 2
2222
const default_zero_point = 2;
23+
const bit_mask = 0x3u;
2324
#endif
2425

2526
#if has_zero_points

0 commit comments

Comments
 (0)