Skip to content

Commit 1962000

Browse files
authored
vulkan: Block-load Q3_K/Q6_K block data and subtract on 32b ints (#23056)
Q2_K/Q3_K/Q6_K do much better when using MMVQ on Intel BMG even though they're only 2-byte aligned, and Q3_K still wins on NVIDIA as well. mesa isn't all that great at coalescing back-to-back loads from alternating arrays, so we force it instead. Further, we can do subtraction directly on a full int32_t rather than an i8vec4 with bit twiddling because the high bit is always free to start. On Intel BMG on mesa, the switch to MMVQ provides an immediate ~57% perf increase in tg128 for unsloth/Qwen3.5-9B-GGUF:Q3_K and ~78% perf increase in tg128 for unsloth/Qwen3.5-9B-GGUF:Q6_K. The further switch to block loads leads to a ~24% perf increase in tg128 for unsloth/Qwen3.5-9B-GGUF:Q3_K and a ~48% perf increase in tg128 for unsloth/Qwen3.5-9B-GGUF:Q6_K. Finally, Xe2 wins on MMVQ even for small k, so we take the NVIDIA override for K quants on Xe2 as well.
1 parent f8c0a19 commit 1962000

2 files changed

Lines changed: 80 additions & 47 deletions

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8336,8 +8336,10 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
83368336
return false;
83378337
}
83388338

8339-
// General performance issue with q3_k and q6_k due to 2-byte alignment
8340-
if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) {
8339+
// q6_k only has 2-byte alignment which makes it somewhat problematic,
8340+
// using MMVQ is only a win on Intel.
8341+
bool mmvq_q6 = device->vendor_id == VK_VENDOR_ID_INTEL;
8342+
if (src0_type == GGML_TYPE_Q6_K && !mmvq_q6) {
83418343
return false;
83428344
}
83438345

@@ -8349,7 +8351,7 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
83498351
// Quantization overhead is not worth it for small k
83508352
switch (device->vendor_id) {
83518353
case VK_VENDOR_ID_NVIDIA:
8352-
if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) {
8354+
if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) {
83538355
return true;
83548356
}
83558357

@@ -8376,9 +8378,16 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
83768378
return true;
83778379
}
83788380
case VK_VENDOR_ID_INTEL:
8381+
if (device->architecture == vk_device_architecture::INTEL_XE2) {
8382+
if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) {
8383+
return true;
8384+
}
8385+
}
8386+
83798387
if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) {
8380-
// Intel Windows proprietary driver MMVQ performance is worse than fp16, see
8381-
// https://github.com/ggml-org/llama.cpp/issues/17628
8388+
// Intel Windows proprietary driver MMVQ performance for !Q2/Q3/Q6 is worse than fp16,
8389+
// see https://github.com/ggml-org/llama.cpp/issues/17628 and
8390+
// https://github.com/ggml-org/llama.cpp/pull/23056
83828391
return false;
83838392
}
83848393

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl

Lines changed: 66 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -212,28 +212,40 @@ i32vec4 repack4(uint ib, uint iqs) {
212212
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
213213
const uint hm_shift = iqs_k / 8;
214214

215+
const uvec4 qs = uvec4( uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 ]) |
216+
(uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 1]) << 16),
217+
uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 2]) |
218+
(uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 3]) << 16),
219+
uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 4]) |
220+
(uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 5]) << 16),
221+
uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 6]) |
222+
(uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 7]) << 16));
223+
224+
const uvec4 hmask = uvec4( uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 ]) |
225+
(uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 1]) << 16),
226+
uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 2]) |
227+
(uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 3]) << 16),
228+
uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 4]) |
229+
(uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 5]) << 16),
230+
uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 6]) |
231+
(uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 7]) << 16));
232+
215233
// bitwise OR to add 4 if hmask is set, subtract later
216-
const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
217-
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
218-
const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1] >> qs_shift) & uint16_t(0x0303))) |
219-
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
220-
const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2] >> qs_shift) & uint16_t(0x0303))) |
221-
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
222-
const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3] >> qs_shift) & uint16_t(0x0303))) |
223-
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
224-
const i8vec2 vals20 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 4] >> qs_shift) & uint16_t(0x0303))) |
225-
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 4] >> hm_shift) & uint16_t(0x0101)) << 2));
226-
const i8vec2 vals21 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 5] >> qs_shift) & uint16_t(0x0303))) |
227-
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 5] >> hm_shift) & uint16_t(0x0101)) << 2));
228-
const i8vec2 vals30 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 6] >> qs_shift) & uint16_t(0x0303))) |
229-
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 6] >> hm_shift) & uint16_t(0x0101)) << 2));
230-
const i8vec2 vals31 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 7] >> qs_shift) & uint16_t(0x0303))) |
231-
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 7] >> hm_shift) & uint16_t(0x0101)) << 2));
232-
233-
return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y) - int8_t(4)),
234-
pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y) - int8_t(4)),
235-
pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y) - int8_t(4)),
236-
pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y) - int8_t(4)));
234+
const uint vals0 = (( qs.x >> qs_shift) & 0x03030303) |
235+
(((hmask.x >> hm_shift) & 0x01010101) << 2);
236+
const uint vals1 = (( qs.y >> qs_shift) & 0x03030303) |
237+
(((hmask.y >> hm_shift) & 0x01010101) << 2);
238+
const uint vals2 = (( qs.z >> qs_shift) & 0x03030303) |
239+
(((hmask.z >> hm_shift) & 0x01010101) << 2);
240+
const uint vals3 = (( qs.w >> qs_shift) & 0x03030303) |
241+
(((hmask.w >> hm_shift) & 0x01010101) << 2);
242+
243+
// Subtract 4 by twiddling bits rather than using re-packing as mesa
244+
// compiles repacking poorly.
245+
return i32vec4(int32_t(((vals0 ^ 0x80808080) - 0x04040404) ^ 0x80808080),
246+
int32_t(((vals1 ^ 0x80808080) - 0x04040404) ^ 0x80808080),
247+
int32_t(((vals2 ^ 0x80808080) - 0x04040404) ^ 0x80808080),
248+
int32_t(((vals3 ^ 0x80808080) - 0x04040404) ^ 0x80808080));
237249
}
238250

239251
float get_d_scale(uint ib, uint iqs) {
@@ -343,27 +355,39 @@ i32vec4 repack4(uint ib, uint iqs) {
343355
const uint qh_idx = (iqs_k / 32) * 8 + iqs;
344356
const uint qh_shift = ((iqs_k % 32) / 8) * 2;
345357

346-
const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
347-
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
348-
const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
349-
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
350-
const i8vec2 vals10 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 2] >> ql_shift) & uint16_t(0x0F0F))) |
351-
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 2] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
352-
const i8vec2 vals11 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 3] >> ql_shift) & uint16_t(0x0F0F))) |
353-
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 3] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
354-
const i8vec2 vals20 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 4] >> ql_shift) & uint16_t(0x0F0F))) |
355-
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 4] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
356-
const i8vec2 vals21 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 5] >> ql_shift) & uint16_t(0x0F0F))) |
357-
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 5] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
358-
const i8vec2 vals30 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 6] >> ql_shift) & uint16_t(0x0F0F))) |
359-
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 6] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
360-
const i8vec2 vals31 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 7] >> ql_shift) & uint16_t(0x0F0F))) |
361-
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 7] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
362-
363-
return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)),
364-
pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y)),
365-
pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y)),
366-
pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y)));
358+
const uvec4 ql = uvec4( uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 ]) |
359+
(uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 1]) << 16),
360+
uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 2]) |
361+
(uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 3]) << 16),
362+
uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 4]) |
363+
(uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 5]) << 16),
364+
uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 6]) |
365+
(uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 7]) << 16));
366+
367+
const uvec4 qh = uvec4( uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 ]) |
368+
(uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 1]) << 16),
369+
uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 2]) |
370+
(uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 3]) << 16),
371+
uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 4]) |
372+
(uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 5]) << 16),
373+
uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 6]) |
374+
(uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 7]) << 16));
375+
376+
const uint vals0 = (( ql.x >> ql_shift) & 0x0F0F0F0F) |
377+
(((qh.x >> qh_shift) & 0x03030303) << 4);
378+
const uint vals1 = (( ql.y >> ql_shift) & 0x0F0F0F0F) |
379+
(((qh.y >> qh_shift) & 0x03030303) << 4);
380+
const uint vals2 = (( ql.z >> ql_shift) & 0x0F0F0F0F) |
381+
(((qh.z >> qh_shift) & 0x03030303) << 4);
382+
const uint vals3 = (( ql.w >> ql_shift) & 0x0F0F0F0F) |
383+
(((qh.w >> qh_shift) & 0x03030303) << 4);
384+
385+
// Subtract 32 by twiddling bits rather than using re-packing as mesa
386+
// compiles repacking poorly.
387+
return i32vec4(int32_t(((vals0 ^ 0x80808080) - 0x20202020) ^ 0x80808080),
388+
int32_t(((vals1 ^ 0x80808080) - 0x20202020) ^ 0x80808080),
389+
int32_t(((vals2 ^ 0x80808080) - 0x20202020) ^ 0x80808080),
390+
int32_t(((vals3 ^ 0x80808080) - 0x20202020) ^ 0x80808080));
367391
}
368392

369393
float get_d_scale(uint ib, uint iqs) {

0 commit comments

Comments
 (0)