@@ -123,21 +123,42 @@ export const createMatMulNBitsProgramInfo = (
123123 }
124124 } ) ( ) ;
125125
126+ // Number of quantized values per u32 word and passes needed (each pass extracts 8 values).
127+ const valuesPerWord = Math . floor ( 32 / attributes . bits ) ; // Q4=8, Q2=16
128+ const passesPerWord = Math . floor ( valuesPerWord / 8 ) ; // Q4=1, Q2=2
129+
126130 const processOneWord = ( ) : string => {
127- let calcStr = `
128- // reuse a data
129- var input_offset = ${ a . indicesToOffset ( `${ a . type . indices } (batch, row, word_offset)` ) } ;
130- var a_data: ${ qDqDataType } ;
131- for (var j: u32 = 0; j < ${ 8 / aComponents } ; j++) {
132- a_data[j] = ${ a . getByOffset ( 'input_offset' ) } ;
133- input_offset++;
131+ let calcStr = '' ;
132+ for ( let pass = 0 ; pass < passesPerWord ; pass ++ ) {
133+ // Each pass processes 8 values from the current u32 word.
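134+ // lowerShift/upperShift are used only on the Q4 path (pass 0: shifts 0 and 4). The Q2 path ignores them: each pass selects one 16-bit half of the word (b_value >> pass * 16) and spreads its four nibbles into separate bytes before masking.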
135+ const lowerShift = pass * attributes . bits * 4 ; // bit offset for lower group within each byte
136+ const upperShift = lowerShift + attributes . bits ;
137+ calcStr += `
138+ // reuse a data (pass ${ pass } )
139+ var input_offset${ pass > 0 ? pass : '' } = ${ pass === 0 ? a . indicesToOffset ( `${ a . type . indices } (batch, row, word_offset)` ) : `input_offset` } ;
140+ var a_data${ pass > 0 ? pass : '' } : ${ qDqDataType } ;
141+ for (var j${ pass > 0 ? pass : '' } : u32 = 0; j${ pass > 0 ? pass : '' } < ${ 8 / aComponents } ; j${ pass > 0 ? pass : '' } ++) {
142+ a_data${ pass > 0 ? pass : '' } [j${ pass > 0 ? pass : '' } ] = ${ a . getByOffset ( `input_offset${ pass > 0 ? pass : '' } ` ) } ;
143+ input_offset${ pass > 0 ? pass : '' } ++;
134144 }
135145 ` ;
136- for ( let c = 0 ; c < components * outputNumber ; c ++ ) {
137- calcStr += `
146+ for ( let c = 0 ; c < components * outputNumber ; c ++ ) {
147+ calcStr += `
138148 b_value = ${ bComponents === 1 ? `b${ c } _data` : `b${ c } _data[i]` } ;
139- b_value_lower = unpack4xU8(b_value & b_mask);
140- b_value_upper = unpack4xU8((b_value >> 4) & b_mask);
149+ ${
150+ attributes . bits === 2
151+ ? `{
152+ let half_word = b_value >> ${ pass * 16 } u;
153+ let byte_lo = half_word & 0xFFu;
154+ let byte_hi = (half_word >> 8u) & 0xFFu;
155+ let spread_word = (byte_lo & 0xFu) | ((byte_lo >> 4u) << 8u) | ((byte_hi & 0xFu) << 16u) | ((byte_hi >> 4u) << 24u);
156+ b_value_lower = unpack4xU8(spread_word & b_mask);
157+ b_value_upper = unpack4xU8((spread_word >> 2u) & b_mask);
158+ }`
159+ : `b_value_lower = unpack4xU8((b_value >> ${ lowerShift } u) & b_mask);
160+ b_value_upper = unpack4xU8((b_value >> ${ upperShift } u) & b_mask);`
161+ }
141162 b_quantized_values = ${ qDqDataType } (${ Array . from (
142163 { length : 4 } ,
143164 ( _ , i ) => `${ dataType } (b_value_lower[${ i } ]), ${ dataType } (b_value_upper[${ i } ])` ,
@@ -159,11 +180,12 @@ export const createMatMulNBitsProgramInfo = (
159180 ( _ , i ) =>
160181 `${
161182 aComponents === 1
162- ? `a_data[${ i } ] * b_dequantized_values[${ i } ]`
163- : `dot(a_data[${ i } ], b_dequantized_values[${ i } ])`
183+ ? `a_data${ pass > 0 ? pass : '' } [${ i } ] * b_dequantized_values[${ i } ]`
184+ : `dot(a_data${ pass > 0 ? pass : '' } [${ i } ], b_dequantized_values[${ i } ])`
164185 } `,
165186 ) . join ( ' + ' ) } ;
166187 ` ;
188+ }
167189 }
168190 return calcStr ;
169191 } ;
@@ -173,16 +195,17 @@ export const createMatMulNBitsProgramInfo = (
173195 ${
174196 zeroPoints
175197 ? `
176- let zero_point_bytes_per_col = (nBlocksPerCol + 1) / 2;
198+ let zero_point_values_per_byte: u32 = ${ Math . floor ( 8 / attributes . bits ) } u;
199+ let zero_point_bytes_per_col = (nBlocksPerCol + zero_point_values_per_byte - 1u) / zero_point_values_per_byte;
177200 var zero_point_byte_count: u32;
178201 var zero_point_word_index: u32;
179202 var zero_point_byte_offset: u32;
180- let zero_point_nibble_offset : u32 = block & 0x1u ;
203+ let zero_point_sub_offset : u32 = block % zero_point_values_per_byte ;
181204 var zero_point_bits_offset: u32;
182205 var zero_point_word: u32;`
183206 : `
184- // The default zero point is 8 for unsigned 4 -bit quantization.
185- let zero_point = ${ dataType } (${ 8.0 } );`
207+ // The default zero point is ${ Math . pow ( 2 , attributes . bits - 1 ) } for unsigned ${ attributes . bits } -bit quantization.
208+ let zero_point = ${ dataType } (${ Math . pow ( 2 , attributes . bits - 1 ) . toFixed ( 1 ) } );`
186209 }
187210 ` ;
188211 for ( let c = 0 ; c < components * outputNumber ; c ++ ) {
@@ -191,12 +214,12 @@ export const createMatMulNBitsProgramInfo = (
191214 ${
192215 zeroPoints
193216 ? `
194- zero_point_byte_count = col_index * zero_point_bytes_per_col + (block >> 0x1u );
217+ zero_point_byte_count = col_index * zero_point_bytes_per_col + (block / zero_point_values_per_byte );
195218 zero_point_word_index = zero_point_byte_count >> 0x2u;
196219 zero_point_byte_offset = zero_point_byte_count & 0x3u;
197- zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2 );
220+ zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_sub_offset * ${ attributes . bits } u );
198221 zero_point_word = ${ zeroPoints . getByOffset ( 'zero_point_word_index' ) } >> zero_point_bits_offset;
199- let zero_point${ c } = ${ dataType } ((zero_point_word) & 0xFu);`
222+ let zero_point${ c } = ${ dataType } ((zero_point_word) & ${ attributes . bits === 2 ? '0x3u' : ' 0xFu' } );`
200223 : ''
201224 }
202225 col_index += 1;` ;
@@ -212,7 +235,7 @@ export const createMatMulNBitsProgramInfo = (
212235 }
213236 calcStr += `
214237 var b_value: u32;
215- let b_mask: u32 = 0x0F0F0F0Fu;
238+ let b_mask: u32 = ${ attributes . bits === 2 ? '0x03030303u' : ' 0x0F0F0F0Fu' } ;
216239 var b_value_lower: vec4<u32>;
217240 var b_value_upper: vec4<u32>;
218241 var b_quantized_values: ${ qDqDataType } ;
@@ -237,7 +260,7 @@ export const createMatMulNBitsProgramInfo = (
237260 ${ prepareBData ( ) }
238261 for (var i: u32 = 0; i < ${ bComponents } ; i++) {
239262 ${ processOneWord ( ) }
240- word_offset += ${ 8 / aComponents } ;
263+ word_offset += ${ valuesPerWord / aComponents } ;
241264 }
242265 }
243266 }
@@ -291,7 +314,8 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
291314 const workgroupSize = 128 ;
292315 const workgroupY = dimBOuter % 8 === 0 ? 8 : dimBOuter % 4 === 0 ? 4 : 1 ;
293316 const workgroupX = workgroupSize / workgroupY ;
294- const tileSize = workgroupX * bComponents * 8 ; // each uint32 has 8 data.
317+ const valuesPerWordBs32 = Math . floor ( 32 / attributes . bits ) ; // Q4=8, Q2=16
318+ const tileSize = workgroupX * bComponents * valuesPerWordBs32 ; // each uint32 has valuesPerWord data.
295319 const aLengthPerTile = tileSize / aComponents ;
296320 const blocksPerTile = tileSize / attributes . blockSize ;
297321 const dispatchSize = ShapeUtil . size ( outputShape ) / workgroupY ;
@@ -376,36 +400,59 @@ export const createMatMulNBitsBlockSize32ProgramInfo = (
376400 ${
377401 zeroPoints
378402 ? `
379- let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;
380- let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u);
403+ let zero_point_values_per_byte: u32 = ${ Math . floor ( 8 / attributes . bits ) } u;
404+ let zero_point_bytes_per_col = (n_blocks_per_col + zero_point_values_per_byte - 1u) / zero_point_values_per_byte;
405+ let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block / zero_point_values_per_byte);
381406 let zero_point_word_index = zero_point_byte_count >> 0x2u;
382407 let zero_point_byte_offset = zero_point_byte_count & 0x3u;
383- let zero_point_nibble_offset : u32 = block & 0x1u ;
384- let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2 );
408+ let zero_point_sub_offset : u32 = block % zero_point_values_per_byte ;
409+ let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_sub_offset * ${ attributes . bits } u );
385410 let zero_point_word = ${ zeroPoints . getByOffset ( 'zero_point_word_index' ) } >> zero_point_bits_offset;
386- let zero_point = ${ dataType } ((zero_point_word) & 0xFu);`
411+ let zero_point = ${ dataType } ((zero_point_word) & ${ attributes . bits === 2 ? '0x3u' : ' 0xFu' } );`
387412 : `
388- // The default zero point is 8 for unsigned 4 -bit quantization.
389- let zero_point = ${ dataType } (${ 8.0 } );`
413+ // The default zero point is ${ Math . pow ( 2 , attributes . bits - 1 ) } for unsigned ${ attributes . bits } -bit quantization.
414+ let zero_point = ${ dataType } (${ Math . pow ( 2 , attributes . bits - 1 ) . toFixed ( 1 ) } );`
390415 }
391416 let scale = ${ scales . getByOffset ( `b_row * n_blocks_per_col + block` ) } ;
392417 let b_data = ${ b . getByIndices ( `${ b . type . indices } (b_row, block, 0)` ) } ;
393418 var word_offset = local_id.x * ${ attributes . blockSize / aComponents } ;
394419 for (var i: u32 = 0; i < ${ bComponents } ; i++) {
395- ${ readA ( ) }
396420 let b_value = ${ bComponents === 1 ? `b_data` : `b_data[i]` } ;
397- let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
398- let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
399- let b_quantized_values = mat2x4<${ dataType } >(${ Array . from (
400- { length : 4 } ,
401- ( _ , i ) => `${ dataType } (b_value_lower[${ i } ]), ${ dataType } (b_value_upper[${ i } ])` ,
402- ) . join ( ', ' ) } );
403- let b_dequantized_values = (b_quantized_values - mat2x4<${ dataType } >(${ Array ( 8 ) . fill ( 'zero_point' ) . join ( ',' ) } )) * scale;
404- inter_results[local_id.y][local_id.x] += ${ Array . from (
405- { length : 2 } ,
406- ( _ , i ) => `${ `dot(a_data${ i } , b_dequantized_values[${ i } ])` } ` ,
407- ) . join ( ' + ' ) } ;
408- word_offset += ${ 8 / aComponents } ;
421+ ${ ( ( ) => {
422+ const passesPerWordBs32 = Math . floor ( valuesPerWordBs32 / 8 ) ;
423+ let code = '' ;
424+ for ( let pass = 0 ; pass < passesPerWordBs32 ; pass ++ ) {
425+ const lowerShift = pass * attributes . bits * 4 ;
426+ const upperShift = lowerShift + attributes . bits ;
427+ code += `
428+ ${ readA ( ) }
429+ {${
430+ attributes . bits === 2
431+ ? `
432+ let half_word = b_value >> ${ pass * 16 } u;
433+ let byte_lo = half_word & 0xFFu;
434+ let byte_hi = (half_word >> 8u) & 0xFFu;
435+ let spread_word = (byte_lo & 0xFu) | ((byte_lo >> 4u) << 8u) | ((byte_hi & 0xFu) << 16u) | ((byte_hi >> 4u) << 24u);
436+ let b_value_lower = unpack4xU8(spread_word & 0x03030303u);
437+ let b_value_upper = unpack4xU8((spread_word >> 2u) & 0x03030303u);`
438+ : `
439+ let b_value_lower = unpack4xU8((b_value >> ${ lowerShift } u) & 0x0F0F0F0Fu);
440+ let b_value_upper = unpack4xU8((b_value >> ${ upperShift } u) & 0x0F0F0F0Fu);`
441+ }
442+ let b_quantized_values = mat2x4<${ dataType } >(${ Array . from (
443+ { length : 4 } ,
444+ ( _ , i ) => `${ dataType } (b_value_lower[${ i } ]), ${ dataType } (b_value_upper[${ i } ])` ,
445+ ) . join ( ', ' ) } );
446+ let b_dequantized_values = (b_quantized_values - mat2x4<${ dataType } >(${ Array ( 8 ) . fill ( 'zero_point' ) . join ( ',' ) } )) * scale;
447+ inter_results[local_id.y][local_id.x] += ${ Array . from (
448+ { length : 2 } ,
449+ ( _ , i ) => `${ `dot(a_data${ i } , b_dequantized_values[${ i } ])` } ` ,
450+ ) . join ( ' + ' ) } ;
451+ }
452+ word_offset += ${ 8 / aComponents } ;` ;
453+ }
454+ return code ;
455+ } ) ( ) }
409456 }
410457 workgroupBarrier();
411458 }
0 commit comments