@@ -501,6 +501,103 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
501501 }
502502}
503503
504+ static void dequantize_mul_mat_vec_q3_k_reorder (const void *__restrict__ vx,
505+ const float *__restrict__ yy,
506+ float *__restrict__ dst,
507+ const int ncols, int nrows,
508+ const sycl::nd_item<3 > &item_ct1) {
509+
510+ const int row = item_ct1.get_group (2 ) * item_ct1.get_local_range (1 ) +
511+ item_ct1.get_local_id (1 );
512+ if (row > nrows) return ;
513+
514+ const int num_blocks_per_row = ncols / QK_K ;
515+ const int ib0 = row*num_blocks_per_row;
516+
517+ // SOA base pointers for the reordered layout:
518+ // [qs: nb * (QK_K/4)] [hmask: nb * (QK_K/8)] [scales: nb * 12] [d: nb * sizeof(half)]
519+ const int nb = nrows * num_blocks_per_row;
520+ const uint8_t * qs_base = (const uint8_t *)vx;
521+ const uint8_t * hmask_base = qs_base + (size_t )nb * (QK_K / 4 );
522+ const uint8_t * scales_base = hmask_base + (size_t )nb * (QK_K / 8 );
523+ const sycl::half * d_base = (const sycl::half *)(scales_base + (size_t )nb * 12 );
524+
525+ float tmp = 0 ; // partial sum for thread in warp
526+
527+ #if QK_K == 256
528+
529+ const uint16_t kmask1 = 0x0303 ;
530+ const uint16_t kmask2 = 0x0f0f ;
531+
532+ const int tid =
533+ item_ct1.get_local_id (2 ) / K_QUANTS_PER_ITERATION ; // 0...31 or 0...16
534+ const int ix =
535+ item_ct1.get_local_id (2 ) % K_QUANTS_PER_ITERATION ; // 0 or 0,1
536+
537+ const int n = K_QUANTS_PER_ITERATION ; // iterations in the inner loop
538+ const int step = 16 /K_QUANTS_PER_ITERATION ;
539+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
540+ const int in = tid - step*im; // 0....15 or 0...7
541+
542+ const uint8_t m = 1 << (4 *im);
543+
544+ const int l0 = n*in; // 0...15 or 0...14 in steps of 2
545+ const int q_offset = 32 *im + l0;
546+ const int y_offset = 128 *im + l0;
547+
548+ uint16_t utmp[4 ];
549+ const int8_t * s = (const int8_t *)utmp;
550+
551+ const uint16_t s_shift = 4 *im;
552+
553+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION ) {
554+ const int bi = ib0 + i;
555+
556+ const float * y = yy + i * QK_K + y_offset;
557+ const uint8_t * q = qs_base + bi * (QK_K / 4 ) + q_offset;
558+ const uint8_t * h = hmask_base + bi * (QK_K / 8 ) + l0;
559+
560+ const uint16_t * a = (const uint16_t *)(scales_base + bi * 12 );
561+ utmp[0 ] = ((a[0 ] >> s_shift) & kmask2) | (((a[4 ] >> (s_shift + 0 )) & kmask1) << 4 );
562+ utmp[1 ] = ((a[1 ] >> s_shift) & kmask2) | (((a[5 ] >> (s_shift + 0 )) & kmask1) << 4 );
563+ utmp[2 ] = ((a[2 ] >> s_shift) & kmask2) | (((a[4 ] >> (s_shift + 2 )) & kmask1) << 4 );
564+ utmp[3 ] = ((a[3 ] >> s_shift) & kmask2) | (((a[5 ] >> (s_shift + 2 )) & kmask1) << 4 );
565+
566+ const float d = d_base[bi];
567+
568+ float sum = 0 ;
569+ for (int l = 0 ; l < n; ++l) {
570+ sum += y[l+ 0 ] * (s[0 ] - 32 ) * (((q[l] >> 0 ) & 3 ) - (h[l] & (m << 0 ) ? 0 : 4 ))
571+ + y[l+32 ] * (s[2 ] - 32 ) * (((q[l] >> 2 ) & 3 ) - (h[l] & (m << 1 ) ? 0 : 4 ))
572+ + y[l+64 ] * (s[4 ] - 32 ) * (((q[l] >> 4 ) & 3 ) - (h[l] & (m << 2 ) ? 0 : 4 ))
573+ + y[l+96 ] * (s[6 ] - 32 ) * (((q[l] >> 6 ) & 3 ) - (h[l] & (m << 3 ) ? 0 : 4 ));
574+ sum += y[l+16 ] * (s[1 ] - 32 ) * (((q[l+16 ] >> 0 ) & 3 ) - (h[l+16 ] & (m << 0 ) ? 0 : 4 ))
575+ + y[l+48 ] * (s[3 ] - 32 ) * (((q[l+16 ] >> 2 ) & 3 ) - (h[l+16 ] & (m << 1 ) ? 0 : 4 ))
576+ + y[l+80 ] * (s[5 ] - 32 ) * (((q[l+16 ] >> 4 ) & 3 ) - (h[l+16 ] & (m << 2 ) ? 0 : 4 ))
577+ + y[l+112 ] * (s[7 ] - 32 ) * (((q[l+16 ] >> 6 ) & 3 ) - (h[l+16 ] & (m << 3 ) ? 0 : 4 ));
578+ }
579+ tmp += d * sum;
580+ }
581+ #else
582+ GGML_UNUSED (vx);
583+ GGML_UNUSED (yy);
584+ GGML_UNUSED (ncols);
585+ GGML_UNUSED (item_ct1);
586+ GGML_ABORT (" Q3_K reorder DMMV not supported for QK_K != 256" );
587+ #endif
588+
589+ // sum up partial sums and write back result
590+ #pragma unroll
591+ for (int mask = QK_WARP_SIZE / 2 ; mask > 0 ; mask >>= 1 ) {
592+ tmp +=
593+ dpct::permute_sub_group_by_xor (item_ct1.get_sub_group (), tmp, mask);
594+ }
595+
596+ if (item_ct1.get_local_id (2 ) == 0 ) {
597+ dst[row] = tmp;
598+ }
599+ }
600+
504601/*
505602DPCT1110:6: The total declared local variable size in device function
506603dequantize_mul_mat_vec_q4_k exceeds 128 bytes and may cause high register
@@ -1440,6 +1537,22 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
14401537 });
14411538}
14421539
1540+ static void dequantize_mul_mat_vec_q3_K_sycl_reorder (const void *vx, const float *y,
1541+ float *dst, const int ncols,
1542+ const int nrows,
1543+ dpct::queue_ptr stream) {
1544+ GGML_ASSERT (ncols % QK_K == 0 );
1545+ const int ny = 2 / K_QUANTS_PER_ITERATION ;
1546+ const int block_num_y = (nrows + ny - 1 ) / ny;
1547+ const sycl::range<3 > block_nums (1 , 1 , block_num_y);
1548+ const sycl::range<3 > block_dims (1 , ny, QK_WARP_SIZE );
1549+ stream->parallel_for (
1550+ sycl::nd_range<3 >(block_nums * block_dims, block_dims),
1551+ [=](sycl::nd_item<3 > item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1552+ dequantize_mul_mat_vec_q3_k_reorder (vx, y, dst, ncols, nrows, item_ct1);
1553+ });
1554+ }
1555+
14431556static void dequantize_mul_mat_vec_q4_K_sycl (const void *vx, const float *y,
14441557 float *dst, const int ncols,
14451558 const int nrows,
@@ -1581,7 +1694,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
15811694 dequantize_mul_mat_vec_q2_K_sycl (src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
15821695 break ;
15831696 case GGML_TYPE_Q3_K :
1584- dequantize_mul_mat_vec_q3_K_sycl (src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
1697+ if ((ggml_tensor_extra_gpu *) dst->src [0 ]->extra &&
1698+ ((ggml_tensor_extra_gpu *) dst->src [0 ]->extra )->optimized_feature .reorder ) {
1699+ dequantize_mul_mat_vec_q3_K_sycl_reorder (src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
1700+ } else {
1701+ dequantize_mul_mat_vec_q3_K_sycl (src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
1702+ }
15851703 break ;
15861704 case GGML_TYPE_Q4_K :
15871705 if ((ggml_tensor_extra_gpu *) dst->src [0 ]->extra &&
0 commit comments