Skip to content

Commit 44e211c

Browse files
authored
sycl : Optimize Q3_K mul_mat by reorder (#23725)
1 parent af6528e commit 44e211c

7 files changed

Lines changed: 340 additions & 4 deletions

File tree

ggml/src/ggml-sycl/convert.cpp

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,19 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
107107
#endif
108108
}
109109

110+
template <typename dst_t>
111+
static void dequantize_row_q3_K_sycl_reorder(const void *vx, dst_t *y, const int64_t k,
112+
dpct::queue_ptr stream) {
113+
const int64_t nb = k / QK_K;
114+
115+
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
116+
stream->parallel_for(
117+
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
118+
[=](sycl::nd_item<3> item_ct1) {
119+
dequantize_block_q3_K_reorder(vx, y, item_ct1, nb);
120+
});
121+
}
122+
110123
template <typename dst_t>
111124
static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
112125
dpct::queue_ptr stream) {
@@ -652,7 +665,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
652665
case GGML_TYPE_Q2_K:
653666
return dequantize_row_q2_K_sycl;
654667
case GGML_TYPE_Q3_K:
655-
return dequantize_row_q3_K_sycl;
668+
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
669+
return dequantize_row_q3_K_sycl_reorder;
670+
} else {
671+
return dequantize_row_q3_K_sycl;
672+
}
656673
case GGML_TYPE_Q4_K:
657674
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
658675
return dequantize_row_q4_K_sycl_reorder;
@@ -730,7 +747,11 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
730747
case GGML_TYPE_Q2_K:
731748
return dequantize_row_q2_K_sycl;
732749
case GGML_TYPE_Q3_K:
733-
return dequantize_row_q3_K_sycl;
750+
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
751+
return dequantize_row_q3_K_sycl_reorder;
752+
} else {
753+
return dequantize_row_q3_K_sycl;
754+
}
734755
case GGML_TYPE_Q4_K:
735756
if (dst->src[0]->extra &&
736757
((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {

ggml/src/ggml-sycl/dequantize.hpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,63 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
390390

391391
}
392392

393+
template<typename dst_t>
394+
static void dequantize_block_q3_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy,
395+
const sycl::nd_item<3> & item_ct1, int64_t n_blocks) {
396+
#if QK_K == 256
397+
const int64_t i = item_ct1.get_group(2);
398+
if (i >= n_blocks) {
399+
return;
400+
}
401+
402+
const uint8_t * base = static_cast<const uint8_t *>(vx);
403+
const size_t qs_offset = i * (QK_K / 4);
404+
const size_t hmask_offset = n_blocks * (QK_K / 4) + i * (QK_K / 8);
405+
const size_t scales_offset = n_blocks * (QK_K / 4) + n_blocks * (QK_K / 8) + i * 12;
406+
const size_t d_offset = n_blocks * (QK_K / 4) + n_blocks * (QK_K / 8) + n_blocks * 12 +
407+
i * sizeof(ggml_half);
408+
409+
const uint8_t * qs = base + qs_offset;
410+
const uint8_t * hmask = base + hmask_offset;
411+
const uint8_t * scales = base + scales_offset;
412+
const float d_all = static_cast<float>(*reinterpret_cast<const ggml_half *>(base + d_offset));
413+
414+
const int64_t r = item_ct1.get_local_id(2) / 4;
415+
const int64_t tid = r / 2;
416+
const int64_t is0 = r % 2;
417+
const int64_t l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);
418+
const int64_t n = tid / 4;
419+
const int64_t j = tid - 4 * n;
420+
const int64_t is = 8 * n + 2 * j + is0;
421+
const int shift = 2 * j;
422+
uint8_t m = 1 << (4 * n + j);
423+
424+
uint8_t us = is < 4
425+
? (scales[is - 0] & 0xF) | (((scales[is + 8] >> 0) & 3) << 4)
426+
: is < 8
427+
? (scales[is - 0] & 0xF) | (((scales[is + 4] >> 2) & 3) << 4)
428+
: is < 12
429+
? (scales[is - 8] >> 4) | (((scales[is + 0] >> 4) & 3) << 4)
430+
: (scales[is - 8] >> 4) | (((scales[is - 4] >> 6) & 3) << 4);
431+
432+
const float dl = d_all * (us - 32);
433+
434+
dst_t * y = yy + i * QK_K + 128 * n + 32 * j;
435+
const uint8_t * q = qs + 32 * n;
436+
const uint8_t * hm = hmask;
437+
438+
for (int l = l0; l < l0 + 4; ++l) {
439+
y[l] = dl * ((int8_t) ((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
440+
}
441+
#else
442+
GGML_UNUSED(vx);
443+
GGML_UNUSED(yy);
444+
GGML_UNUSED(item_ct1);
445+
GGML_UNUSED(n_blocks);
446+
GGML_ABORT("Q3_K reorder dequantize not supported for QK_K != 256");
447+
#endif
448+
}
449+
393450
#if QK_K == 256
394451
static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
395452
if (j < 4) {

ggml/src/ggml-sycl/dmmv.cpp

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,103 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
501501
}
502502
}
503503

504+
static void dequantize_mul_mat_vec_q3_k_reorder(const void *__restrict__ vx,
505+
const float *__restrict__ yy,
506+
float *__restrict__ dst,
507+
const int ncols, int nrows,
508+
const sycl::nd_item<3> &item_ct1) {
509+
510+
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
511+
item_ct1.get_local_id(1);
512+
if (row > nrows) return;
513+
514+
const int num_blocks_per_row = ncols / QK_K;
515+
const int ib0 = row*num_blocks_per_row;
516+
517+
// SOA base pointers for the reordered layout:
518+
// [qs: nb * (QK_K/4)] [hmask: nb * (QK_K/8)] [scales: nb * 12] [d: nb * sizeof(half)]
519+
const int nb = nrows * num_blocks_per_row;
520+
const uint8_t * qs_base = (const uint8_t *)vx;
521+
const uint8_t * hmask_base = qs_base + (size_t)nb * (QK_K / 4);
522+
const uint8_t * scales_base = hmask_base + (size_t)nb * (QK_K / 8);
523+
const sycl::half * d_base = (const sycl::half *)(scales_base + (size_t)nb * 12);
524+
525+
float tmp = 0; // partial sum for thread in warp
526+
527+
#if QK_K == 256
528+
529+
const uint16_t kmask1 = 0x0303;
530+
const uint16_t kmask2 = 0x0f0f;
531+
532+
const int tid =
533+
item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
534+
const int ix =
535+
item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
536+
537+
const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop
538+
const int step = 16/K_QUANTS_PER_ITERATION;
539+
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
540+
const int in = tid - step*im; // 0....15 or 0...7
541+
542+
const uint8_t m = 1 << (4*im);
543+
544+
const int l0 = n*in; // 0...15 or 0...14 in steps of 2
545+
const int q_offset = 32*im + l0;
546+
const int y_offset = 128*im + l0;
547+
548+
uint16_t utmp[4];
549+
const int8_t * s = (const int8_t *)utmp;
550+
551+
const uint16_t s_shift = 4*im;
552+
553+
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
554+
const int bi = ib0 + i;
555+
556+
const float * y = yy + i * QK_K + y_offset;
557+
const uint8_t * q = qs_base + bi * (QK_K / 4) + q_offset;
558+
const uint8_t * h = hmask_base + bi * (QK_K / 8) + l0;
559+
560+
const uint16_t * a = (const uint16_t *)(scales_base + bi * 12);
561+
utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
562+
utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
563+
utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
564+
utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
565+
566+
const float d = d_base[bi];
567+
568+
float sum = 0;
569+
for (int l = 0; l < n; ++l) {
570+
sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
571+
+ y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
572+
+ y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
573+
+ y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
574+
sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
575+
+ y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
576+
+ y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
577+
+ y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
578+
}
579+
tmp += d * sum;
580+
}
581+
#else
582+
GGML_UNUSED(vx);
583+
GGML_UNUSED(yy);
584+
GGML_UNUSED(ncols);
585+
GGML_UNUSED(item_ct1);
586+
GGML_ABORT("Q3_K reorder DMMV not supported for QK_K != 256");
587+
#endif
588+
589+
// sum up partial sums and write back result
590+
#pragma unroll
591+
for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
592+
tmp +=
593+
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
594+
}
595+
596+
if (item_ct1.get_local_id(2) == 0) {
597+
dst[row] = tmp;
598+
}
599+
}
600+
504601
/*
505602
DPCT1110:6: The total declared local variable size in device function
506603
dequantize_mul_mat_vec_q4_k exceeds 128 bytes and may cause high register
@@ -1440,6 +1537,22 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
14401537
});
14411538
}
14421539

1540+
static void dequantize_mul_mat_vec_q3_K_sycl_reorder(const void *vx, const float *y,
1541+
float *dst, const int ncols,
1542+
const int nrows,
1543+
dpct::queue_ptr stream) {
1544+
GGML_ASSERT(ncols % QK_K == 0);
1545+
const int ny = 2 / K_QUANTS_PER_ITERATION;
1546+
const int block_num_y = (nrows + ny - 1) / ny;
1547+
const sycl::range<3> block_nums(1, 1, block_num_y);
1548+
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
1549+
stream->parallel_for(
1550+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
1551+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1552+
dequantize_mul_mat_vec_q3_k_reorder(vx, y, dst, ncols, nrows, item_ct1);
1553+
});
1554+
}
1555+
14431556
static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
14441557
float *dst, const int ncols,
14451558
const int nrows,
@@ -1581,7 +1694,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
15811694
dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
15821695
break;
15831696
case GGML_TYPE_Q3_K:
1584-
dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
1697+
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
1698+
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
1699+
dequantize_mul_mat_vec_q3_K_sycl_reorder(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
1700+
} else {
1701+
dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
1702+
}
15851703
break;
15861704
case GGML_TYPE_Q4_K:
15871705
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3549,6 +3549,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
35493549
case GGML_TYPE_Q4_0:
35503550
case GGML_TYPE_Q8_0:
35513551
return true;
3552+
case GGML_TYPE_Q3_K:
35523553
case GGML_TYPE_Q4_K:
35533554
case GGML_TYPE_Q5_K:
35543555
case GGML_TYPE_Q6_K:
@@ -3572,6 +3573,7 @@ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
35723573
switch (type) {
35733574
case GGML_TYPE_Q4_0:
35743575
case GGML_TYPE_Q8_0:
3576+
case GGML_TYPE_Q3_K:
35753577
case GGML_TYPE_Q4_K:
35763578
case GGML_TYPE_Q5_K:
35773579
case GGML_TYPE_Q6_K:
@@ -3791,6 +3793,54 @@ static bool reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
37913793
return true;
37923794
}
37933795

3796+
static bool reorder_qw_q3_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
3797+
GGML_ASSERT(size % sizeof(block_q3_K) == 0);
3798+
GGML_ASSERT(offset % sizeof(block_q3_K) == 0);
3799+
3800+
const int nblocks = size / sizeof(block_q3_K);
3801+
3802+
sycl_reorder_temp_buffer tmp(stream, size);
3803+
if (!tmp) {
3804+
GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size);
3805+
return false;
3806+
}
3807+
uint8_t * tmp_buf = static_cast<uint8_t *>(tmp.ptr);
3808+
3809+
sycl::event copy_event;
3810+
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
3811+
if (!g_ggml_sycl_use_async_mem_op) {
3812+
copy_event.wait();
3813+
}
3814+
3815+
auto * qs_ptr = data_device;
3816+
auto * hmask_ptr = qs_ptr + (QK_K / 4) * nblocks;
3817+
auto * scales_ptr = hmask_ptr + (QK_K / 8) * nblocks;
3818+
sycl::half * d_ptr = (sycl::half *) (scales_ptr + 12 * nblocks);
3819+
3820+
auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
3821+
const block_q3_K * x = (const block_q3_K *) tmp_buf;
3822+
const int ib = i;
3823+
3824+
for (int j = 0; j < QK_K / 4; ++j) {
3825+
qs_ptr[ib * (QK_K / 4) + j] = x[ib].qs[j];
3826+
}
3827+
3828+
for (int j = 0; j < QK_K / 8; ++j) {
3829+
hmask_ptr[ib * (QK_K / 8) + j] = x[ib].hmask[j];
3830+
}
3831+
3832+
for (int j = 0; j < 12; ++j) {
3833+
scales_ptr[ib * 12 + j] = x[ib].scales[j];
3834+
}
3835+
3836+
d_ptr[ib] = x[ib].d;
3837+
});
3838+
if (!g_ggml_sycl_use_async_mem_op) {
3839+
reorder_event.wait_and_throw();
3840+
}
3841+
return true;
3842+
}
3843+
37943844
static bool reorder_qw_q5_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
37953845
GGML_ASSERT(size % sizeof(block_q5_K) == 0);
37963846
GGML_ASSERT(offset % sizeof(block_q5_K) == 0);
@@ -3903,6 +3953,8 @@ static bool reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
39033953
return reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
39043954
case GGML_TYPE_Q8_0:
39053955
return reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream);
3956+
case GGML_TYPE_Q3_K:
3957+
return reorder_qw_q3_k(data_device, size, 0, stream);
39063958
case GGML_TYPE_Q4_K:
39073959
return reorder_qw_q4_k(data_device, size, 0, stream);
39083960
case GGML_TYPE_Q5_K:

ggml/src/ggml-sycl/mmvq.cpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,26 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
770770
}
771771
}
772772

773+
static void reorder_mul_mat_vec_q3_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
774+
const int nrows, dpct::queue_ptr stream) {
775+
GGML_ASSERT(ncols % QK_K == 0);
776+
777+
// Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel.
778+
constexpr size_t num_subgroups = WARP_SIZE;
779+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups;
780+
781+
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
782+
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
783+
784+
stream->submit([&](sycl::handler & cgh) {
785+
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
786+
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
787+
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K>>(vx, vy, dst, ncols, nrows,
788+
nd_item);
789+
});
790+
});
791+
}
792+
773793
static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
774794
float *dst, const int ncols,
775795
const int nrows,
@@ -1153,7 +1173,15 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
11531173
mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
11541174
break;
11551175
case GGML_TYPE_Q3_K:
1156-
mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1176+
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
1177+
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
1178+
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q3_k_q8_1_sycl\n");
1179+
reorder_mul_mat_vec_q3_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff,
1180+
stream);
1181+
} else {
1182+
GGML_SYCL_DEBUG("Calling mul_mat_vec_q3_K_q8_1_sycl\n");
1183+
mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1184+
}
11571185
break;
11581186
case GGML_TYPE_Q4_K:
11591187
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&

0 commit comments

Comments
 (0)