Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 90 additions & 43 deletions js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,21 +123,42 @@ export const createMatMulNBitsProgramInfo = (
}
})();

// Number of quantized values per u32 word and passes needed (each pass extracts 8 values).
const valuesPerWord = Math.floor(32 / attributes.bits); // Q4=8, Q2=16
const passesPerWord = Math.floor(valuesPerWord / 8); // Q4=1, Q2=2

const processOneWord = (): string => {
let calcStr = `
// reuse a data
var input_offset = ${a.indicesToOffset(`${a.type.indices}(batch, row, word_offset)`)};
var a_data: ${qDqDataType};
for (var j: u32 = 0; j < ${8 / aComponents}; j++) {
a_data[j] = ${a.getByOffset('input_offset')};
input_offset++;
let calcStr = '';
for (let pass = 0; pass < passesPerWord; pass++) {
// Each pass processes 8 values from the current u32 word.
// For Q4 (pass=0): shift by 0 and 4. For Q2 (pass 0: shift 0,2; pass 1: shift 4,6).
const lowerShift = pass * attributes.bits * 4; // bit offset for lower group within each byte
const upperShift = lowerShift + attributes.bits;
calcStr += `
// reuse a data (pass ${pass})
var input_offset${pass > 0 ? pass : ''} = ${pass === 0 ? a.indicesToOffset(`${a.type.indices}(batch, row, word_offset)`) : `input_offset`};
var a_data${pass > 0 ? pass : ''}: ${qDqDataType};
for (var j${pass > 0 ? pass : ''}: u32 = 0; j${pass > 0 ? pass : ''} < ${8 / aComponents}; j${pass > 0 ? pass : ''}++) {
a_data${pass > 0 ? pass : ''}[j${pass > 0 ? pass : ''}] = ${a.getByOffset(`input_offset${pass > 0 ? pass : ''}`)};
input_offset${pass > 0 ? pass : ''}++;
}
`;
for (let c = 0; c < components * outputNumber; c++) {
calcStr += `
for (let c = 0; c < components * outputNumber; c++) {
calcStr += `
b_value = ${bComponents === 1 ? `b${c}_data` : `b${c}_data[i]`};
b_value_lower = unpack4xU8(b_value & b_mask);
b_value_upper = unpack4xU8((b_value >> 4) & b_mask);
${
attributes.bits === 2
? `{
let half_word = b_value >> ${pass * 16}u;
let byte_lo = half_word & 0xFFu;
let byte_hi = (half_word >> 8u) & 0xFFu;
let spread_word = (byte_lo & 0xFu) | ((byte_lo >> 4u) << 8u) | ((byte_hi & 0xFu) << 16u) | ((byte_hi >> 4u) << 24u);
b_value_lower = unpack4xU8(spread_word & b_mask);
b_value_upper = unpack4xU8((spread_word >> 2u) & b_mask);
}`
: `b_value_lower = unpack4xU8((b_value >> ${lowerShift}u) & b_mask);
b_value_upper = unpack4xU8((b_value >> ${upperShift}u) & b_mask);`
}
b_quantized_values = ${qDqDataType}(${Array.from(
{ length: 4 },
(_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
Expand All @@ -159,11 +180,12 @@ export const createMatMulNBitsProgramInfo = (
(_, i) =>
`${
aComponents === 1
? `a_data[${i}] * b_dequantized_values[${i}]`
: `dot(a_data[${i}], b_dequantized_values[${i}])`
? `a_data${pass > 0 ? pass : ''}[${i}] * b_dequantized_values[${i}]`
: `dot(a_data${pass > 0 ? pass : ''}[${i}], b_dequantized_values[${i}])`
}`,
).join(' + ')};
`;
}
}
return calcStr;
};
Expand All @@ -173,16 +195,17 @@ export const createMatMulNBitsProgramInfo = (
${
zeroPoints
? `
let zero_point_bytes_per_col = (nBlocksPerCol + 1) / 2;
let zero_point_values_per_byte: u32 = ${Math.floor(8 / attributes.bits)}u;
let zero_point_bytes_per_col = (nBlocksPerCol + zero_point_values_per_byte - 1u) / zero_point_values_per_byte;
var zero_point_byte_count: u32;
var zero_point_word_index: u32;
var zero_point_byte_offset: u32;
let zero_point_nibble_offset: u32 = block & 0x1u;
let zero_point_sub_offset: u32 = block % zero_point_values_per_byte;
var zero_point_bits_offset: u32;
var zero_point_word: u32;`
: `
// The default zero point is 8 for unsigned 4-bit quantization.
let zero_point = ${dataType}(${8.0});`
// The default zero point is ${Math.pow(2, attributes.bits - 1)} for unsigned ${attributes.bits}-bit quantization.
let zero_point = ${dataType}(${Math.pow(2, attributes.bits - 1).toFixed(1)});`
}
`;
for (let c = 0; c < components * outputNumber; c++) {
Expand All @@ -191,12 +214,12 @@ export const createMatMulNBitsProgramInfo = (
${
zeroPoints
? `
zero_point_byte_count = col_index * zero_point_bytes_per_col + (block >> 0x1u);
zero_point_byte_count = col_index * zero_point_bytes_per_col + (block / zero_point_values_per_byte);
zero_point_word_index = zero_point_byte_count >> 0x2u;
zero_point_byte_offset = zero_point_byte_count & 0x3u;
zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_sub_offset * ${attributes.bits}u);
zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;
let zero_point${c} = ${dataType}((zero_point_word) & 0xFu);`
let zero_point${c} = ${dataType}((zero_point_word) & ${attributes.bits === 2 ? '0x3u' : '0xFu'});`
: ''
}
col_index += 1;`;
Expand All @@ -212,7 +235,7 @@ export const createMatMulNBitsProgramInfo = (
}
calcStr += `
var b_value: u32;
let b_mask: u32 = 0x0F0F0F0Fu;
let b_mask: u32 = ${attributes.bits === 2 ? '0x03030303u' : '0x0F0F0F0Fu'};
var b_value_lower: vec4<u32>;
var b_value_upper: vec4<u32>;
var b_quantized_values: ${qDqDataType};
Expand All @@ -237,7 +260,7 @@ export const createMatMulNBitsProgramInfo = (
${prepareBData()}
for (var i: u32 = 0; i < ${bComponents}; i++) {
${processOneWord()}
word_offset += ${8 / aComponents};
word_offset += ${valuesPerWord / aComponents};
}
}
}
Expand Down Expand Up @@ -291,7 +314,8 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
const workgroupSize = 128;
const workgroupY = dimBOuter % 8 === 0 ? 8 : dimBOuter % 4 === 0 ? 4 : 1;
const workgroupX = workgroupSize / workgroupY;
const tileSize = workgroupX * bComponents * 8; // each uint32 has 8 data.
const valuesPerWordBs32 = Math.floor(32 / attributes.bits); // Q4=8, Q2=16
const tileSize = workgroupX * bComponents * valuesPerWordBs32; // each uint32 has valuesPerWord data.
const aLengthPerTile = tileSize / aComponents;
const blocksPerTile = tileSize / attributes.blockSize;
const dispatchSize = ShapeUtil.size(outputShape) / workgroupY;
Expand Down Expand Up @@ -376,36 +400,59 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
${
zeroPoints
? `
let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;
let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u);
let zero_point_values_per_byte: u32 = ${Math.floor(8 / attributes.bits)}u;
let zero_point_bytes_per_col = (n_blocks_per_col + zero_point_values_per_byte - 1u) / zero_point_values_per_byte;
let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block / zero_point_values_per_byte);
let zero_point_word_index = zero_point_byte_count >> 0x2u;
let zero_point_byte_offset = zero_point_byte_count & 0x3u;
let zero_point_nibble_offset: u32 = block & 0x1u;
let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
let zero_point_sub_offset: u32 = block % zero_point_values_per_byte;
let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_sub_offset * ${attributes.bits}u);
let zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;
let zero_point = ${dataType}((zero_point_word) & 0xFu);`
let zero_point = ${dataType}((zero_point_word) & ${attributes.bits === 2 ? '0x3u' : '0xFu'});`
: `
// The default zero point is 8 for unsigned 4-bit quantization.
let zero_point = ${dataType}(${8.0});`
// The default zero point is ${Math.pow(2, attributes.bits - 1)} for unsigned ${attributes.bits}-bit quantization.
let zero_point = ${dataType}(${Math.pow(2, attributes.bits - 1).toFixed(1)});`
}
let scale = ${scales.getByOffset(`b_row * n_blocks_per_col + block`)};
let b_data = ${b.getByIndices(`${b.type.indices}(b_row, block, 0)`)};
var word_offset = local_id.x * ${attributes.blockSize / aComponents};
for (var i: u32 = 0; i < ${bComponents}; i++) {
${readA()}
let b_value = ${bComponents === 1 ? `b_data` : `b_data[i]`};
let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
let b_quantized_values = mat2x4<${dataType}>(${Array.from(
{ length: 4 },
(_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
).join(', ')});
let b_dequantized_values = (b_quantized_values - mat2x4<${dataType}>(${Array(8).fill('zero_point').join(',')})) * scale;
inter_results[local_id.y][local_id.x] += ${Array.from(
{ length: 2 },
(_, i) => `${`dot(a_data${i}, b_dequantized_values[${i}])`}`,
).join(' + ')};
word_offset += ${8 / aComponents};
${(() => {
const passesPerWordBs32 = Math.floor(valuesPerWordBs32 / 8);
let code = '';
for (let pass = 0; pass < passesPerWordBs32; pass++) {
const lowerShift = pass * attributes.bits * 4;
const upperShift = lowerShift + attributes.bits;
code += `
${readA()}
{${
attributes.bits === 2
? `
let half_word = b_value >> ${pass * 16}u;
let byte_lo = half_word & 0xFFu;
let byte_hi = (half_word >> 8u) & 0xFFu;
let spread_word = (byte_lo & 0xFu) | ((byte_lo >> 4u) << 8u) | ((byte_hi & 0xFu) << 16u) | ((byte_hi >> 4u) << 24u);
let b_value_lower = unpack4xU8(spread_word & 0x03030303u);
let b_value_upper = unpack4xU8((spread_word >> 2u) & 0x03030303u);`
: `
let b_value_lower = unpack4xU8((b_value >> ${lowerShift}u) & 0x0F0F0F0Fu);
let b_value_upper = unpack4xU8((b_value >> ${upperShift}u) & 0x0F0F0F0Fu);`
}
let b_quantized_values = mat2x4<${dataType}>(${Array.from(
{ length: 4 },
(_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
).join(', ')});
let b_dequantized_values = (b_quantized_values - mat2x4<${dataType}>(${Array(8).fill('zero_point').join(',')})) * scale;
inter_results[local_id.y][local_id.x] += ${Array.from(
{ length: 2 },
(_, i) => `${`dot(a_data${i}, b_dequantized_values[${i}])`}`,
).join(' + ')};
}
word_offset += ${8 / aComponents};`;
}
return code;
})()}
}
workgroupBarrier();
}
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/js/quantization/matmul_nbits.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ class MatMulNBits final : public JsKernel {
accuracy_level_{info.GetAttrOrDefault<int64_t>("accuracy_level", 0)},
nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))},
block_size_{narrow<size_t>(info.GetAttr<int64_t>("block_size"))} {
ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
ORT_ENFORCE(nbits_ == 4 || nbits_ == 2,
"Only 2b and 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
ORT_ENFORCE(block_size_ >= 16 && !(block_size_ & (block_size_ - 1)),
"Block size must be a power of 2 and greater than or equal to 16.");
JSEP_INIT_KERNEL_ATTRIBUTE(MatMulNBits, ({
Expand Down
14 changes: 9 additions & 5 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
ORT_ENFORCE(g_idx == nullptr, "group_idx as input is not supported yet.");

const bool has_zero_points = zero_points != nullptr;
const uint32_t nbits = onnxruntime::narrow<uint32_t>(bits_);
if (has_zero_points) {
ORT_ENFORCE(zero_points->DataType() == DataTypeImpl::GetType<uint8_t>(), "Currently, only uint8 is supported for zero points, but got ", zero_points->DataType());
ORT_ENFORCE(nbits != 2, "Currently, zero points are not supported for Q2 quantization.");
}

MatMulComputeHelper helper;
Expand Down Expand Up @@ -205,9 +203,13 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
const uint32_t components_a = GetMaxComponents(K);
const uint32_t components_b = GetMaxComponents(blob_size_in_words);
uint32_t components = GetMaxComponents(N);
// zero_points has shape[N * CeilDiv(n_blocks_per_col * bits, 8)]. So here we need to check whether n_blocks_per_col is divisible by 8/nbits.
// For bits==4, this is counted by elements of uint4. Need add 1 if not divisible by 2.
uint32_t zero_blocks_per_col = n_blocks_per_col % (8 / nbits) == 0 ? n_blocks_per_col : n_blocks_per_col + 1;
// zero_points has shape[N * CeilDiv(n_blocks_per_col * bits, 8)].
// The shader uses a flat linear index to address individual n-bit zero point values.
// Since each column's zero points are byte-aligned in the packed buffer, we must round
// n_blocks_per_col up to the next multiple of (8/nbits) — the number of zero point
// values per byte — so that the linear stride correctly skips byte-boundary padding.
const uint32_t zp_elements_per_byte = 8 / static_cast<uint32_t>(nbits);
uint32_t zero_blocks_per_col = (n_blocks_per_col + zp_elements_per_byte - 1) / zp_elements_per_byte * zp_elements_per_byte;

#if !defined(__wasm__)
int32_t subgroup_matrix_config_index = -1;
Expand All @@ -219,7 +221,9 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
#endif

// On FP32 only GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M.
// DP4A Q2 path uses a hardcoded LUT with zero_point=2, so skip DP4A for Q2 with custom zero points.
if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType<float>() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) &&
!(has_zero_points && nbits == 2) &&
CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a)) {
return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, bias, batch_count, M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, static_cast<uint32_t>(nbits), context, y, weight_index);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
const bit_mask = 0xFFu;
#elif n_bits == 2
const default_zero_point = 2;
const bit_mask = 0x3u;
#endif

#if has_zero_points
Expand Down
Loading
Loading