Skip to content

Commit 3a8563b

Browse files
committed
Add zero point support to dp4a 2-bit dequantization in the WebGPU MatMulNBits kernel. Previously, the dp4a path for 2-bit quantization used a hardcoded 256-entry LUT assuming zero_point=2, and was blocked from running when custom zero points were provided.
1 parent a70ac2f commit 3a8563b

File tree

4 files changed

+312
-4
lines changed

4 files changed

+312
-4
lines changed

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,13 @@ fn loadSHMA(batch:u32, a_global_base:u32, kidx_v:u32, row: u32, col: u32)
140140
const b_weight_offset : u32 = 0;
141141
let b_value = b.getByOffset(b_global * uniforms.K16 + kidx_v + col);
142142
#endif
143-
tile_B[col][row] = DequantizedFrom2BitsTo8Bits(b_value);
144143
let block_idx = kidx_v/(block_size/16);
144+
#if has_zero_points
145+
let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col);
146+
tile_B[col][row] = DequantizedFrom2BitsTo8Bits(b_value, zero);
147+
#else
148+
tile_B[col][row] = DequantizedFrom2BitsTo8Bits(b_value);
149+
#endif
145150
let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K/block_size);
146151
scale_B[row] = scales_b.getByOffset(b_scale_offset + b_global*(uniforms.K/block_size) + block_idx);
147152
}
@@ -150,6 +155,11 @@ fn loadSHMA(batch:u32, a_global_base:u32, kidx_v:u32, row: u32, col: u32)
150155
$MAIN {
151156
#if n_bits == 2
152157
LoadDequantizationTable(local_idx);
158+
#if has_zero_points
159+
LoadDequantizationTable(local_idx + 256);
160+
LoadDequantizationTable(local_idx + 512);
161+
LoadDequantizationTable(local_idx + 768);
162+
#endif
153163
workgroupBarrier();
154164
#endif
155165
// During the load phase we use all 256 threads to load 64 rows of A/B.

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_common.wgsl.template

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,287 @@
4040

4141
#if n_bits == 2
4242
alias mul_precision = output_element_t;
43+
#if has_zero_points
44+
const lut_size = 1024;
45+
var<workgroup> shm_dequantization_table : array<u32, lut_size>;
46+
// 1024-entry LUT: 4 sections of 256 entries, one per zero_point value (0-3).
47+
// Index as: zero * 256 + byte_value
48+
const q2_dequantization_table = array<u32, lut_size>(
49+
// zero_point = 0: entries 0-255
50+
0x00000000, 0x00000001, 0x00000002, 0x00000003,
51+
0x00000100, 0x00000101, 0x00000102, 0x00000103,
52+
0x00000200, 0x00000201, 0x00000202, 0x00000203,
53+
0x00000300, 0x00000301, 0x00000302, 0x00000303,
54+
0x00010000, 0x00010001, 0x00010002, 0x00010003,
55+
0x00010100, 0x00010101, 0x00010102, 0x00010103,
56+
0x00010200, 0x00010201, 0x00010202, 0x00010203,
57+
0x00010300, 0x00010301, 0x00010302, 0x00010303,
58+
0x00020000, 0x00020001, 0x00020002, 0x00020003,
59+
0x00020100, 0x00020101, 0x00020102, 0x00020103,
60+
0x00020200, 0x00020201, 0x00020202, 0x00020203,
61+
0x00020300, 0x00020301, 0x00020302, 0x00020303,
62+
0x00030000, 0x00030001, 0x00030002, 0x00030003,
63+
0x00030100, 0x00030101, 0x00030102, 0x00030103,
64+
0x00030200, 0x00030201, 0x00030202, 0x00030203,
65+
0x00030300, 0x00030301, 0x00030302, 0x00030303,
66+
0x01000000, 0x01000001, 0x01000002, 0x01000003,
67+
0x01000100, 0x01000101, 0x01000102, 0x01000103,
68+
0x01000200, 0x01000201, 0x01000202, 0x01000203,
69+
0x01000300, 0x01000301, 0x01000302, 0x01000303,
70+
0x01010000, 0x01010001, 0x01010002, 0x01010003,
71+
0x01010100, 0x01010101, 0x01010102, 0x01010103,
72+
0x01010200, 0x01010201, 0x01010202, 0x01010203,
73+
0x01010300, 0x01010301, 0x01010302, 0x01010303,
74+
0x01020000, 0x01020001, 0x01020002, 0x01020003,
75+
0x01020100, 0x01020101, 0x01020102, 0x01020103,
76+
0x01020200, 0x01020201, 0x01020202, 0x01020203,
77+
0x01020300, 0x01020301, 0x01020302, 0x01020303,
78+
0x01030000, 0x01030001, 0x01030002, 0x01030003,
79+
0x01030100, 0x01030101, 0x01030102, 0x01030103,
80+
0x01030200, 0x01030201, 0x01030202, 0x01030203,
81+
0x01030300, 0x01030301, 0x01030302, 0x01030303,
82+
0x02000000, 0x02000001, 0x02000002, 0x02000003,
83+
0x02000100, 0x02000101, 0x02000102, 0x02000103,
84+
0x02000200, 0x02000201, 0x02000202, 0x02000203,
85+
0x02000300, 0x02000301, 0x02000302, 0x02000303,
86+
0x02010000, 0x02010001, 0x02010002, 0x02010003,
87+
0x02010100, 0x02010101, 0x02010102, 0x02010103,
88+
0x02010200, 0x02010201, 0x02010202, 0x02010203,
89+
0x02010300, 0x02010301, 0x02010302, 0x02010303,
90+
0x02020000, 0x02020001, 0x02020002, 0x02020003,
91+
0x02020100, 0x02020101, 0x02020102, 0x02020103,
92+
0x02020200, 0x02020201, 0x02020202, 0x02020203,
93+
0x02020300, 0x02020301, 0x02020302, 0x02020303,
94+
0x02030000, 0x02030001, 0x02030002, 0x02030003,
95+
0x02030100, 0x02030101, 0x02030102, 0x02030103,
96+
0x02030200, 0x02030201, 0x02030202, 0x02030203,
97+
0x02030300, 0x02030301, 0x02030302, 0x02030303,
98+
0x03000000, 0x03000001, 0x03000002, 0x03000003,
99+
0x03000100, 0x03000101, 0x03000102, 0x03000103,
100+
0x03000200, 0x03000201, 0x03000202, 0x03000203,
101+
0x03000300, 0x03000301, 0x03000302, 0x03000303,
102+
0x03010000, 0x03010001, 0x03010002, 0x03010003,
103+
0x03010100, 0x03010101, 0x03010102, 0x03010103,
104+
0x03010200, 0x03010201, 0x03010202, 0x03010203,
105+
0x03010300, 0x03010301, 0x03010302, 0x03010303,
106+
0x03020000, 0x03020001, 0x03020002, 0x03020003,
107+
0x03020100, 0x03020101, 0x03020102, 0x03020103,
108+
0x03020200, 0x03020201, 0x03020202, 0x03020203,
109+
0x03020300, 0x03020301, 0x03020302, 0x03020303,
110+
0x03030000, 0x03030001, 0x03030002, 0x03030003,
111+
0x03030100, 0x03030101, 0x03030102, 0x03030103,
112+
0x03030200, 0x03030201, 0x03030202, 0x03030203,
113+
0x03030300, 0x03030301, 0x03030302, 0x03030303,
114+
// zero_point = 1: entries 256-511
115+
0xFFFFFFFF, 0xFFFFFF00, 0xFFFFFF01, 0xFFFFFF02,
116+
0xFFFF00FF, 0xFFFF0000, 0xFFFF0001, 0xFFFF0002,
117+
0xFFFF01FF, 0xFFFF0100, 0xFFFF0101, 0xFFFF0102,
118+
0xFFFF02FF, 0xFFFF0200, 0xFFFF0201, 0xFFFF0202,
119+
0xFF00FFFF, 0xFF00FF00, 0xFF00FF01, 0xFF00FF02,
120+
0xFF0000FF, 0xFF000000, 0xFF000001, 0xFF000002,
121+
0xFF0001FF, 0xFF000100, 0xFF000101, 0xFF000102,
122+
0xFF0002FF, 0xFF000200, 0xFF000201, 0xFF000202,
123+
0xFF01FFFF, 0xFF01FF00, 0xFF01FF01, 0xFF01FF02,
124+
0xFF0100FF, 0xFF010000, 0xFF010001, 0xFF010002,
125+
0xFF0101FF, 0xFF010100, 0xFF010101, 0xFF010102,
126+
0xFF0102FF, 0xFF010200, 0xFF010201, 0xFF010202,
127+
0xFF02FFFF, 0xFF02FF00, 0xFF02FF01, 0xFF02FF02,
128+
0xFF0200FF, 0xFF020000, 0xFF020001, 0xFF020002,
129+
0xFF0201FF, 0xFF020100, 0xFF020101, 0xFF020102,
130+
0xFF0202FF, 0xFF020200, 0xFF020201, 0xFF020202,
131+
0x00FFFFFF, 0x00FFFF00, 0x00FFFF01, 0x00FFFF02,
132+
0x00FF00FF, 0x00FF0000, 0x00FF0001, 0x00FF0002,
133+
0x00FF01FF, 0x00FF0100, 0x00FF0101, 0x00FF0102,
134+
0x00FF02FF, 0x00FF0200, 0x00FF0201, 0x00FF0202,
135+
0x0000FFFF, 0x0000FF00, 0x0000FF01, 0x0000FF02,
136+
0x000000FF, 0x00000000, 0x00000001, 0x00000002,
137+
0x000001FF, 0x00000100, 0x00000101, 0x00000102,
138+
0x000002FF, 0x00000200, 0x00000201, 0x00000202,
139+
0x0001FFFF, 0x0001FF00, 0x0001FF01, 0x0001FF02,
140+
0x000100FF, 0x00010000, 0x00010001, 0x00010002,
141+
0x000101FF, 0x00010100, 0x00010101, 0x00010102,
142+
0x000102FF, 0x00010200, 0x00010201, 0x00010202,
143+
0x0002FFFF, 0x0002FF00, 0x0002FF01, 0x0002FF02,
144+
0x000200FF, 0x00020000, 0x00020001, 0x00020002,
145+
0x000201FF, 0x00020100, 0x00020101, 0x00020102,
146+
0x000202FF, 0x00020200, 0x00020201, 0x00020202,
147+
0x01FFFFFF, 0x01FFFF00, 0x01FFFF01, 0x01FFFF02,
148+
0x01FF00FF, 0x01FF0000, 0x01FF0001, 0x01FF0002,
149+
0x01FF01FF, 0x01FF0100, 0x01FF0101, 0x01FF0102,
150+
0x01FF02FF, 0x01FF0200, 0x01FF0201, 0x01FF0202,
151+
0x0100FFFF, 0x0100FF00, 0x0100FF01, 0x0100FF02,
152+
0x010000FF, 0x01000000, 0x01000001, 0x01000002,
153+
0x010001FF, 0x01000100, 0x01000101, 0x01000102,
154+
0x010002FF, 0x01000200, 0x01000201, 0x01000202,
155+
0x0101FFFF, 0x0101FF00, 0x0101FF01, 0x0101FF02,
156+
0x010100FF, 0x01010000, 0x01010001, 0x01010002,
157+
0x010101FF, 0x01010100, 0x01010101, 0x01010102,
158+
0x010102FF, 0x01010200, 0x01010201, 0x01010202,
159+
0x0102FFFF, 0x0102FF00, 0x0102FF01, 0x0102FF02,
160+
0x010200FF, 0x01020000, 0x01020001, 0x01020002,
161+
0x010201FF, 0x01020100, 0x01020101, 0x01020102,
162+
0x010202FF, 0x01020200, 0x01020201, 0x01020202,
163+
0x02FFFFFF, 0x02FFFF00, 0x02FFFF01, 0x02FFFF02,
164+
0x02FF00FF, 0x02FF0000, 0x02FF0001, 0x02FF0002,
165+
0x02FF01FF, 0x02FF0100, 0x02FF0101, 0x02FF0102,
166+
0x02FF02FF, 0x02FF0200, 0x02FF0201, 0x02FF0202,
167+
0x0200FFFF, 0x0200FF00, 0x0200FF01, 0x0200FF02,
168+
0x020000FF, 0x02000000, 0x02000001, 0x02000002,
169+
0x020001FF, 0x02000100, 0x02000101, 0x02000102,
170+
0x020002FF, 0x02000200, 0x02000201, 0x02000202,
171+
0x0201FFFF, 0x0201FF00, 0x0201FF01, 0x0201FF02,
172+
0x020100FF, 0x02010000, 0x02010001, 0x02010002,
173+
0x020101FF, 0x02010100, 0x02010101, 0x02010102,
174+
0x020102FF, 0x02010200, 0x02010201, 0x02010202,
175+
0x0202FFFF, 0x0202FF00, 0x0202FF01, 0x0202FF02,
176+
0x020200FF, 0x02020000, 0x02020001, 0x02020002,
177+
0x020201FF, 0x02020100, 0x02020101, 0x02020102,
178+
0x020202FF, 0x02020200, 0x02020201, 0x02020202,
179+
// zero_point = 2: entries 512-767
180+
0xFEFEFEFE, 0xFEFEFEFF, 0xFEFEFE00, 0xFEFEFE01,
181+
0xFEFEFFFE, 0xFEFEFFFF, 0xFEFEFF00, 0xFEFEFF01,
182+
0xFEFE00FE, 0xFEFE00FF, 0xFEFE0000, 0xFEFE0001,
183+
0xFEFE01FE, 0xFEFE01FF, 0xFEFE0100, 0xFEFE0101,
184+
0xFEFFFEFE, 0xFEFFFEFF, 0xFEFFFE00, 0xFEFFFE01,
185+
0xFEFFFFFE, 0xFEFFFFFF, 0xFEFFFF00, 0xFEFFFF01,
186+
0xFEFF00FE, 0xFEFF00FF, 0xFEFF0000, 0xFEFF0001,
187+
0xFEFF01FE, 0xFEFF01FF, 0xFEFF0100, 0xFEFF0101,
188+
0xFE00FEFE, 0xFE00FEFF, 0xFE00FE00, 0xFE00FE01,
189+
0xFE00FFFE, 0xFE00FFFF, 0xFE00FF00, 0xFE00FF01,
190+
0xFE0000FE, 0xFE0000FF, 0xFE000000, 0xFE000001,
191+
0xFE0001FE, 0xFE0001FF, 0xFE000100, 0xFE000101,
192+
0xFE01FEFE, 0xFE01FEFF, 0xFE01FE00, 0xFE01FE01,
193+
0xFE01FFFE, 0xFE01FFFF, 0xFE01FF00, 0xFE01FF01,
194+
0xFE0100FE, 0xFE0100FF, 0xFE010000, 0xFE010001,
195+
0xFE0101FE, 0xFE0101FF, 0xFE010100, 0xFE010101,
196+
0xFFFEFEFE, 0xFFFEFEFF, 0xFFFEFE00, 0xFFFEFE01,
197+
0xFFFEFFFE, 0xFFFEFFFF, 0xFFFEFF00, 0xFFFEFF01,
198+
0xFFFE00FE, 0xFFFE00FF, 0xFFFE0000, 0xFFFE0001,
199+
0xFFFE01FE, 0xFFFE01FF, 0xFFFE0100, 0xFFFE0101,
200+
0xFFFFFEFE, 0xFFFFFEFF, 0xFFFFFE00, 0xFFFFFE01,
201+
0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFF00, 0xFFFFFF01,
202+
0xFFFF00FE, 0xFFFF00FF, 0xFFFF0000, 0xFFFF0001,
203+
0xFFFF01FE, 0xFFFF01FF, 0xFFFF0100, 0xFFFF0101,
204+
0xFF00FEFE, 0xFF00FEFF, 0xFF00FE00, 0xFF00FE01,
205+
0xFF00FFFE, 0xFF00FFFF, 0xFF00FF00, 0xFF00FF01,
206+
0xFF0000FE, 0xFF0000FF, 0xFF000000, 0xFF000001,
207+
0xFF0001FE, 0xFF0001FF, 0xFF000100, 0xFF000101,
208+
0xFF01FEFE, 0xFF01FEFF, 0xFF01FE00, 0xFF01FE01,
209+
0xFF01FFFE, 0xFF01FFFF, 0xFF01FF00, 0xFF01FF01,
210+
0xFF0100FE, 0xFF0100FF, 0xFF010000, 0xFF010001,
211+
0xFF0101FE, 0xFF0101FF, 0xFF010100, 0xFF010101,
212+
0x00FEFEFE, 0x00FEFEFF, 0x00FEFE00, 0x00FEFE01,
213+
0x00FEFFFE, 0x00FEFFFF, 0x00FEFF00, 0x00FEFF01,
214+
0x00FE00FE, 0x00FE00FF, 0x00FE0000, 0x00FE0001,
215+
0x00FE01FE, 0x00FE01FF, 0x00FE0100, 0x00FE0101,
216+
0x00FFFEFE, 0x00FFFEFF, 0x00FFFE00, 0x00FFFE01,
217+
0x00FFFFFE, 0x00FFFFFF, 0x00FFFF00, 0x00FFFF01,
218+
0x00FF00FE, 0x00FF00FF, 0x00FF0000, 0x00FF0001,
219+
0x00FF01FE, 0x00FF01FF, 0x00FF0100, 0x00FF0101,
220+
0x0000FEFE, 0x0000FEFF, 0x0000FE00, 0x0000FE01,
221+
0x0000FFFE, 0x0000FFFF, 0x0000FF00, 0x0000FF01,
222+
0x000000FE, 0x000000FF, 0x00000000, 0x00000001,
223+
0x000001FE, 0x000001FF, 0x00000100, 0x00000101,
224+
0x0001FEFE, 0x0001FEFF, 0x0001FE00, 0x0001FE01,
225+
0x0001FFFE, 0x0001FFFF, 0x0001FF00, 0x0001FF01,
226+
0x000100FE, 0x000100FF, 0x00010000, 0x00010001,
227+
0x000101FE, 0x000101FF, 0x00010100, 0x00010101,
228+
0x01FEFEFE, 0x01FEFEFF, 0x01FEFE00, 0x01FEFE01,
229+
0x01FEFFFE, 0x01FEFFFF, 0x01FEFF00, 0x01FEFF01,
230+
0x01FE00FE, 0x01FE00FF, 0x01FE0000, 0x01FE0001,
231+
0x01FE01FE, 0x01FE01FF, 0x01FE0100, 0x01FE0101,
232+
0x01FFFEFE, 0x01FFFEFF, 0x01FFFE00, 0x01FFFE01,
233+
0x01FFFFFE, 0x01FFFFFF, 0x01FFFF00, 0x01FFFF01,
234+
0x01FF00FE, 0x01FF00FF, 0x01FF0000, 0x01FF0001,
235+
0x01FF01FE, 0x01FF01FF, 0x01FF0100, 0x01FF0101,
236+
0x0100FEFE, 0x0100FEFF, 0x0100FE00, 0x0100FE01,
237+
0x0100FFFE, 0x0100FFFF, 0x0100FF00, 0x0100FF01,
238+
0x010000FE, 0x010000FF, 0x01000000, 0x01000001,
239+
0x010001FE, 0x010001FF, 0x01000100, 0x01000101,
240+
0x0101FEFE, 0x0101FEFF, 0x0101FE00, 0x0101FE01,
241+
0x0101FFFE, 0x0101FFFF, 0x0101FF00, 0x0101FF01,
242+
0x010100FE, 0x010100FF, 0x01010000, 0x01010001,
243+
0x010101FE, 0x010101FF, 0x01010100, 0x01010101,
244+
// zero_point = 3: entries 768-1023
245+
0xFDFDFDFD, 0xFDFDFDFE, 0xFDFDFDFF, 0xFDFDFD00,
246+
0xFDFDFEFD, 0xFDFDFEFE, 0xFDFDFEFF, 0xFDFDFE00,
247+
0xFDFDFFFD, 0xFDFDFFFE, 0xFDFDFFFF, 0xFDFDFF00,
248+
0xFDFD00FD, 0xFDFD00FE, 0xFDFD00FF, 0xFDFD0000,
249+
0xFDFEFDFD, 0xFDFEFDFE, 0xFDFEFDFF, 0xFDFEFD00,
250+
0xFDFEFEFD, 0xFDFEFEFE, 0xFDFEFEFF, 0xFDFEFE00,
251+
0xFDFEFFFD, 0xFDFEFFFE, 0xFDFEFFFF, 0xFDFEFF00,
252+
0xFDFE00FD, 0xFDFE00FE, 0xFDFE00FF, 0xFDFE0000,
253+
0xFDFFFDFD, 0xFDFFFDFE, 0xFDFFFDFF, 0xFDFFFD00,
254+
0xFDFFFEFD, 0xFDFFFEFE, 0xFDFFFEFF, 0xFDFFFE00,
255+
0xFDFFFFFD, 0xFDFFFFFE, 0xFDFFFFFF, 0xFDFFFF00,
256+
0xFDFF00FD, 0xFDFF00FE, 0xFDFF00FF, 0xFDFF0000,
257+
0xFD00FDFD, 0xFD00FDFE, 0xFD00FDFF, 0xFD00FD00,
258+
0xFD00FEFD, 0xFD00FEFE, 0xFD00FEFF, 0xFD00FE00,
259+
0xFD00FFFD, 0xFD00FFFE, 0xFD00FFFF, 0xFD00FF00,
260+
0xFD0000FD, 0xFD0000FE, 0xFD0000FF, 0xFD000000,
261+
0xFEFDFDFD, 0xFEFDFDFE, 0xFEFDFDFF, 0xFEFDFD00,
262+
0xFEFDFEFD, 0xFEFDFEFE, 0xFEFDFEFF, 0xFEFDFE00,
263+
0xFEFDFFFD, 0xFEFDFFFE, 0xFEFDFFFF, 0xFEFDFF00,
264+
0xFEFD00FD, 0xFEFD00FE, 0xFEFD00FF, 0xFEFD0000,
265+
0xFEFEFDFD, 0xFEFEFDFE, 0xFEFEFDFF, 0xFEFEFD00,
266+
0xFEFEFEFD, 0xFEFEFEFE, 0xFEFEFEFF, 0xFEFEFE00,
267+
0xFEFEFFFD, 0xFEFEFFFE, 0xFEFEFFFF, 0xFEFEFF00,
268+
0xFEFE00FD, 0xFEFE00FE, 0xFEFE00FF, 0xFEFE0000,
269+
0xFEFFFDFD, 0xFEFFFDFE, 0xFEFFFDFF, 0xFEFFFD00,
270+
0xFEFFFEFD, 0xFEFFFEFE, 0xFEFFFEFF, 0xFEFFFE00,
271+
0xFEFFFFFD, 0xFEFFFFFE, 0xFEFFFFFF, 0xFEFFFF00,
272+
0xFEFF00FD, 0xFEFF00FE, 0xFEFF00FF, 0xFEFF0000,
273+
0xFE00FDFD, 0xFE00FDFE, 0xFE00FDFF, 0xFE00FD00,
274+
0xFE00FEFD, 0xFE00FEFE, 0xFE00FEFF, 0xFE00FE00,
275+
0xFE00FFFD, 0xFE00FFFE, 0xFE00FFFF, 0xFE00FF00,
276+
0xFE0000FD, 0xFE0000FE, 0xFE0000FF, 0xFE000000,
277+
0xFFFDFDFD, 0xFFFDFDFE, 0xFFFDFDFF, 0xFFFDFD00,
278+
0xFFFDFEFD, 0xFFFDFEFE, 0xFFFDFEFF, 0xFFFDFE00,
279+
0xFFFDFFFD, 0xFFFDFFFE, 0xFFFDFFFF, 0xFFFDFF00,
280+
0xFFFD00FD, 0xFFFD00FE, 0xFFFD00FF, 0xFFFD0000,
281+
0xFFFEFDFD, 0xFFFEFDFE, 0xFFFEFDFF, 0xFFFEFD00,
282+
0xFFFEFEFD, 0xFFFEFEFE, 0xFFFEFEFF, 0xFFFEFE00,
283+
0xFFFEFFFD, 0xFFFEFFFE, 0xFFFEFFFF, 0xFFFEFF00,
284+
0xFFFE00FD, 0xFFFE00FE, 0xFFFE00FF, 0xFFFE0000,
285+
0xFFFFFDFD, 0xFFFFFDFE, 0xFFFFFDFF, 0xFFFFFD00,
286+
0xFFFFFEFD, 0xFFFFFEFE, 0xFFFFFEFF, 0xFFFFFE00,
287+
0xFFFFFFFD, 0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFF00,
288+
0xFFFF00FD, 0xFFFF00FE, 0xFFFF00FF, 0xFFFF0000,
289+
0xFF00FDFD, 0xFF00FDFE, 0xFF00FDFF, 0xFF00FD00,
290+
0xFF00FEFD, 0xFF00FEFE, 0xFF00FEFF, 0xFF00FE00,
291+
0xFF00FFFD, 0xFF00FFFE, 0xFF00FFFF, 0xFF00FF00,
292+
0xFF0000FD, 0xFF0000FE, 0xFF0000FF, 0xFF000000,
293+
0x00FDFDFD, 0x00FDFDFE, 0x00FDFDFF, 0x00FDFD00,
294+
0x00FDFEFD, 0x00FDFEFE, 0x00FDFEFF, 0x00FDFE00,
295+
0x00FDFFFD, 0x00FDFFFE, 0x00FDFFFF, 0x00FDFF00,
296+
0x00FD00FD, 0x00FD00FE, 0x00FD00FF, 0x00FD0000,
297+
0x00FEFDFD, 0x00FEFDFE, 0x00FEFDFF, 0x00FEFD00,
298+
0x00FEFEFD, 0x00FEFEFE, 0x00FEFEFF, 0x00FEFE00,
299+
0x00FEFFFD, 0x00FEFFFE, 0x00FEFFFF, 0x00FEFF00,
300+
0x00FE00FD, 0x00FE00FE, 0x00FE00FF, 0x00FE0000,
301+
0x00FFFDFD, 0x00FFFDFE, 0x00FFFDFF, 0x00FFFD00,
302+
0x00FFFEFD, 0x00FFFEFE, 0x00FFFEFF, 0x00FFFE00,
303+
0x00FFFFFD, 0x00FFFFFE, 0x00FFFFFF, 0x00FFFF00,
304+
0x00FF00FD, 0x00FF00FE, 0x00FF00FF, 0x00FF0000,
305+
0x0000FDFD, 0x0000FDFE, 0x0000FDFF, 0x0000FD00,
306+
0x0000FEFD, 0x0000FEFE, 0x0000FEFF, 0x0000FE00,
307+
0x0000FFFD, 0x0000FFFE, 0x0000FFFF, 0x0000FF00,
308+
0x000000FD, 0x000000FE, 0x000000FF, 0x00000000);
309+
fn LoadDequantizationTable(local_idx:u32)
310+
{
311+
// Move dequantization table into on chip memory.
312+
shm_dequantization_table[local_idx] = q2_dequantization_table[local_idx];
313+
}
314+
fn DequantizedFrom2BitsTo8Bits(in: u32, zero: i32) -> vec4<u32>
315+
{
316+
let base = u32(zero) * 256;
317+
let unpacked = unpack4xU8(in);
318+
return vec4<u32>(shm_dequantization_table[base + unpacked[0]],
319+
shm_dequantization_table[base + unpacked[1]],
320+
shm_dequantization_table[base + unpacked[2]],
321+
shm_dequantization_table[base + unpacked[3]]);
322+
}
323+
#else
43324
const lut_size = 256;
44325
var<workgroup> shm_dequantization_table : array<u32, lut_size>;
45326
const q2_dequantization_table = array<u32, lut_size>(
@@ -313,6 +594,7 @@
313594
shm_dequantization_table[unpacked[3]]);
314595
}
315596
#endif
597+
#endif
316598

317599
#if has_zero_points && n_bits == 8
318600
// If has_zero_points is true, vec4<i32>(unpack4xU8(b_data)) - vec4<i32>(zero) may be out of the range [-128, 127] since zero can be any value between [0, 255].

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_small_m.wgsl.template

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,21 @@ $MAIN {
8080
#endif
8181

8282
#if n_bits == 2
83+
#if has_zero_points
84+
// The workgroup size is 128, LoadDequantizationTable needs to load 1024 entries.
85+
LoadDequantizationTable(local_idx);
86+
LoadDequantizationTable(local_idx + 128);
87+
LoadDequantizationTable(local_idx + 256);
88+
LoadDequantizationTable(local_idx + 384);
89+
LoadDequantizationTable(local_idx + 512);
90+
LoadDequantizationTable(local_idx + 640);
91+
LoadDequantizationTable(local_idx + 768);
92+
LoadDequantizationTable(local_idx + 896);
93+
#else
8394
// The workgroup size is 128, LoadDequantizationTable needs to be called twice.
8495
LoadDequantizationTable(local_idx);
85-
LoadDequantizationTable(local_idx+127);
96+
LoadDequantizationTable(local_idx+128);
97+
#endif
8698
workgroupBarrier();
8799
#endif
88100
#if single_scale_weights
@@ -141,8 +153,13 @@ $MAIN {
141153

142154
#elif n_bits == 2
143155
let b_value = b.getByOffset(b_offset);
156+
#if has_zero_points
157+
let own_b = DequantizedFrom2BitsTo8Bits(b_value.x, zero);
158+
let own_b1 = DequantizedFrom2BitsTo8Bits(b_value.y, zero);
159+
#else
144160
let own_b = DequantizedFrom2BitsTo8Bits(b_value.x);
145161
let own_b1 = DequantizedFrom2BitsTo8Bits(b_value.y);
162+
#endif
146163
inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b);
147164
#endif
148165
}

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,8 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
221221
#endif
222222

223223
// On FP32 only GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M.
224-
// DP4A Q2 path uses a hardcoded LUT with zero_point=2, so skip DP4A for Q2 with custom zero points.
224+
// DP4A Q2 path now supports custom zero points via a 1024-entry LUT (4 zero-point sections × 256 byte values).
225225
if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType<float>() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) &&
226-
!(has_zero_points && nbits == 2) &&
227226
CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a)) {
228227
return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, bias, batch_count, M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, static_cast<uint32_t>(nbits), context, y, weight_index);
229228
}

0 commit comments

Comments
 (0)