Commit 4d3daf8

opencl: add general Q6_K mm and Q4_K mv (ggml-org#19347)
* opencl: add general q6_k mm
* opencl: refine condition for q6_K mm
* opencl: add general q4_K mv
* opencl: fix whitespace
1 parent 914dde7 commit 4d3daf8

File tree

4 files changed: +461 -2 lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 2 additions & 0 deletions

@@ -85,6 +85,7 @@ set(GGML_OPENCL_KERNELS
     mul_mv_q4_0_f32_8x_flat
     mul_mv_q4_0_f32_1d_8x_flat
     mul_mv_q4_0_f32_1d_16x_flat
+    mul_mv_q4_k_f32
     mul_mv_q6_k_f32
     mul_mv_q6_k_f32_flat
     mul_mv_q8_0_f32
@@ -101,6 +102,7 @@ set(GGML_OPENCL_KERNELS
     mul_mm_f32_f32_l4_lm
     mul_mm_f16_f32_l4_lm
     mul_mm_q8_0_f32_l4_lm
+    mul_mm_q6_k_f32_l4_lm
     mul_mm_q8_0_f32_8x4
     gemv_noshuffle_general_q8_0_f32
     mul
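Both new kernels operate on K-quant super-blocks of 256 weights. For reference, a sketch of the two block layouts involved, paraphrased from ggml-common.h (ggml_half is an fp16 scalar); the flattened ql/qh/s/d buffers used by the mm kernel below are these block_q6_K fields split into separate device allocations:

#define QK_K 256

// Q6_K: 6-bit quants; effective weight = d * scales[g] * (q - 32).
typedef struct {
    uint8_t   ql[QK_K/2];      // lower 4 bits of each 6-bit quant
    uint8_t   qh[QK_K/4];      // upper 2 bits, packed 4 per byte
    int8_t    scales[QK_K/16]; // 16 group scales, one per 16 weights
    ggml_half d;               // super-block scale
} block_q6_K;

// Q4_K: 4-bit quants with 6-bit quantized scales/mins per 32-weight group.
typedef struct {
    ggml_half d;               // super-block scale for scales
    ggml_half dmin;            // super-block scale for mins
    uint8_t   scales[12];      // packed 6-bit scales and mins
    uint8_t   qs[QK_K/2];      // 4-bit quants
} block_q4_K;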

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 121 additions & 2 deletions

@@ -532,6 +532,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_restore_block_q4_0_noshuffle;
     cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K;
     cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
+    cl_kernel kernel_mul_mv_q4_K_f32;
     cl_kernel kernel_mul_mv_q6_K_f32;
     cl_kernel kernel_mul_mv_q6_K_f32_flat;
     cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat;
@@ -564,6 +565,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_mul_mm_f32_f32_l4_lm;
     cl_kernel kernel_mul_mm_f16_f32_l4_lm;
     cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
+    cl_kernel kernel_mul_mm_q6_k_f32_l4_lm;
 
     std::vector<ProfilingInfo> profiling_info;
 
@@ -1117,6 +1119,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // mul_mv_q4_k_f32
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mv_q4_k_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mv_q4_k_f32.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mv_q4_K_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_K_f32", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
     // mul_mv_q6_k_f32
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1358,6 +1377,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // mul_mm_q6_k_f32_l4_lm
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mm_q6_k_f32_l4_lm.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mm_q6_k_f32_l4_lm.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q6_k_f32_l4_lm", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
     // mul_mm_f16_f32_kq_kqv
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
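Both loader blocks above follow the backend's standard pattern: take the embedded source when GGML_OPENCL_EMBED_KERNELS is defined (the generated .cl.h header expands to a string literal), otherwise read the .cl file at runtime, then build the program, create the kernel, and release the program. Schematically, as a sketch rather than an upstream helper (build_program_from_source and CL_CHECK are the file's existing utilities):

static cl_kernel load_one_kernel(ggml_backend_opencl_context * backend_ctx,
                                 const std::string & kernel_src,
                                 const char * kernel_name,
                                 const std::string & compile_opts) {
    cl_int err;
    cl_program prog = build_program_from_source(
        backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);

    cl_kernel kernel = clCreateKernel(prog, kernel_name, &err);
    CL_CHECK(err);
    // The kernel object retains the program, so releasing it here is safe.
    CL_CHECK(clReleaseProgram(prog));
    return kernel;
}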
@@ -3364,6 +3400,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             } else if (op->src[0]->type == GGML_TYPE_F32) {
                 return op->src[1]->type == GGML_TYPE_F32;
             } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_MXFP4 ||
+                       op->src[0]->type == GGML_TYPE_Q4_K ||
                        op->src[0]->type == GGML_TYPE_Q6_K) {
                 return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
             } else if (op->src[0]->type == GGML_TYPE_Q8_0) {
@@ -8927,6 +8964,50 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
             backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
             return;
         }
+        case GGML_TYPE_Q6_K: {
+            if (ne11 < 32) {
+                break;
+            }
+            if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
+                break;
+            }
+
+            kernel = backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm;
+            nth0 = 128; // calculated as (BM*BN)/(TM*TN)
+
+            int batch_stride_a = ne00*ne01;
+            int batch_stride_b = ne10*ne11;
+            int batch_stride_d = ne0*ne1;
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q6_K->ql));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q6_K->qh));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra0_q6_K->s));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_mem),   &extra0_q6_K->d));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne11));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10)); // stride_a
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne10)); // stride_b
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne01)); // stride_d
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_a));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &batch_stride_b));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &batch_stride_d));
+            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int),      &r3));
+
+            // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
+            size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
+            size_t local_work_size[] = {(size_t)nth0, 1, 1};
+
+            backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+            return;
+        }
         default:
             break;
     }
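To make the launch geometry concrete: nth0 = 128 because each work-group computes one BM x BN = 64 x 64 tile of dst, with each work-item accumulating a TM x TN = 4 x 8 sub-tile, giving (64*64)/(4*8) = 128 work-items per group. A worked example with hypothetical shapes (ne01 = 4096 output rows, ne11 = 512 columns, a single batch):

#include <stdio.h>

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

int main(void) {
    // Hypothetical sizes, for illustration only.
    const int ne01 = 4096, ne11 = 512, ne12 = 1, ne13 = 1;
    const int BM = 64, BN = 64, TM = 4, TN = 8;
    const int nth0 = (BM * BN) / (TM * TN); // 128 work-items per work-group

    size_t global[3] = {
        (size_t)(CEIL_DIV(ne01, BM) * nth0), // 64 row tiles * 128 = 8192
        (size_t) CEIL_DIV(ne11, BN),         // 8 column tiles
        (size_t)(ne12 * ne13),               // 1 batch
    };
    // 64 * 8 * 1 = 512 work-groups, each producing one 64x64 tile of dst.
    printf("global = %zu x %zu x %zu\n", global[0], global[1], global[2]);
    return 0;
}

The ne11 < 32 guard breaks out to the matrix-vector path for small batches, where a 64x64 output tile would be mostly padding.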
@@ -9262,7 +9343,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
             }
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q4_K: {
+            kernel = backend_ctx->kernel_mul_mv_q4_K_f32;
+
+            if (backend_ctx->gpu_family == INTEL) {
+                nth0 = 16;
+                nth1 = 1;
+                ndst = 4;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                nth0 = 64;
+                nth1 = 1;
+                ndst = 4;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(int),      &offset0));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(int),      &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),      &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb01));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb02));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r3));
+            break;
+        }
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
 #ifdef GGML_OPENCL_SOA_Q
@@ -9424,7 +9540,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
 
         backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
     } else if (src0t == GGML_TYPE_Q4_K) {
-        GGML_ASSERT(false && "not implemented");
+        size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
+        size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
     } else if (src0t == GGML_TYPE_Q3_K) {
         GGML_ASSERT(false && "not implemented");
     } else if (src0t == GGML_TYPE_Q5_K) {
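The first global dimension rounds ne01 up to a multiple of ndst*nth1 and multiplies by nth0, so each work-group of nth0 x nth1 work-items produces ndst rows of dst. With the Adreno parameters chosen above (nth0 = 64, nth1 = 1, ndst = 4) and a hypothetical decode shape (ne01 = 4096 rows, ne11 = 1 token):

#include <stdio.h>

int main(void) {
    // Hypothetical sizes, for illustration only.
    const int nth0 = 64, nth1 = 1, ndst = 4;
    const int ne01 = 4096, ne11 = 1, ne12 = 1, ne13 = 1;

    size_t global[3] = {
        (size_t)(ne01 + ndst*nth1 - 1) / (ndst*nth1) * nth0, // 1024 * 64 = 65536
        (size_t)(ne11 * nth1),                               // 1
        (size_t)(ne12 * ne13),                               // 1
    };
    size_t local[3] = { (size_t)nth0, (size_t)nth1, 1 };
    // 65536 / 64 = 1024 work-groups; 1024 groups * 4 rows each covers all 4096 rows.
    printf("global = %zu x %zu x %zu, local = %zu x %zu x %zu\n",
           global[0], global[1], global[2], local[0], local[1], local[2]);
    return 0;
}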
ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#define LOAD_VEC_A 2
+#define LOAD_VEC_B 4
+
+#define BM 64
+#define BN 64
+#define BK 32
+#define TM 4
+#define TN 8
+
+kernel void kernel_mul_mm_q6_k_f32_l4_lm(
+        global uchar  * src0_ql,
+        global uchar  * src0_qh,
+        global char   * src0_s,
+        global half   * src0_d,
+        global float4 * src1,
+        ulong offset1,
+        global float  * dst,
+        ulong offsetd,
+
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne11,
+        int ne12,
+
+        int stride_a,
+        int stride_b,
+        int stride_d,
+
+        int batch_stride_a,
+        int batch_stride_b,
+        int batch_stride_d,
+
+        int r2,
+        int r3
+) {
+    src1 = (global float4*)((global char*)src1 + offset1);
+    dst  = (global float *)((global char*)dst + offsetd);
+
+    local float buf_a[BM * BK];
+    local float buf_b[BN * BK];
+
+    const int batch_idx = get_global_id(2);
+
+    const int i13 = batch_idx / ne12;
+    const int i12 = batch_idx % ne12;
+
+    const int i03 = i13 / r3;
+    const int i02 = i12 / r2;
+
+    const int batch_idx_a = i03 * ne02 + i02;
+
+    const int ir = get_group_id(0);
+    const int ic = get_group_id(1);
+
+    const int tid  = get_local_id(0);
+    const int th_r = tid % (BM / TM);
+    const int th_c = tid / (BM / TM);
+
+    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
+    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
+    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
+    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
+
+    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
+    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
+
+    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
+    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
+
+    float sums[TM * TN];
+    float cache_a[TM];
+    float cache_b[TN];
+
+    for (int i = 0; i < TM * TN; i++) {
+        sums[i] = 0.0f;
+    }
+
+    for (int block = 0; block < ne00; block += BK) {
+        for (int l = 0; l < BM; l += loadstride_a) {
+            if (ir*BM + loadc_a + l < ne01) {
+                int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
+
+                int ib  = idx / 128; // 2 values per idx
+                int iqs = idx % 128; // 0..127
+
+                int n       = iqs / 64;                // 0,1
+                int b       = (iqs % 64) / 32;         // 0,1
+                int is_b    = (iqs % 16) / 8;          // 0,1
+                int qhshift = ((iqs % 64) / 16) * 2;   // 0,2,4,6
+                int is      = 8 * n + qhshift + is_b;  // 0..15
+                int qsi     = n * 64 + (iqs % 32) * 2; // 0,2,4..126
+                int qhi     = n * 32 + (iqs % 16) * 2; // 0,2,4..62
+
+                float dscale = (float)src0_d[ib] * (float)src0_s[ib*16 + is];
+
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 0] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 0] >> qhshift) & 3) << 4)) - 32);
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 1] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 1] >> qhshift) & 3) << 4)) - 32);
+            } else {
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
+            }
+        }
+
+        for (int l = 0; l < BN; l += loadstride_b) {
+            if (ic*BN + loadc_b + l < ne11) {
+                int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            } else {
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
+            }
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        pos_a += BK / LOAD_VEC_A;
+        pos_b += BK / LOAD_VEC_B;
+
+        for (int i = 0; i < BK; i++) {
+            for (int j = 0; j < TM; j++) {
+                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
+            }
+
+            for (int j = 0; j < TN; j++) {
+                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
+            }
+
+            for (int cc = 0; cc < TN; cc++) {
+                for (int cr = 0; cr < TM; cr++) {
+                    const int sums_idx = cc*TM + cr;
+                    sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    const int dr = ir * BM + th_r * TM;
+    const int dc = ic * BN + th_c * TN;
+
+    const int offsets = batch_idx * batch_stride_d;
+
+    for (int cc = 0; cc < TN; cc++) {
+        for (int cr = 0; cr < TM; cr++) {
+            if (dr + cr < ne01 && dc + cc < ne11) {
+                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
+            }
+        }
+    }
+}
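buf_a and buf_b each hold 64 x 32 floats (8 KiB), so a work-group uses 16 KiB of local memory for the two tiles. The index arithmetic in the buf_a load loop unpacks two adjacent 6-bit weights per iteration: ib selects a 256-weight super-block, the low 4 bits come from ql, the high 2 bits from qh, and is picks one of the 16 group scales. The same math for a single element j, as a scalar C reference over the flattened ql/qh/scales/d buffers (a sketch for illustration, not upstream API):

#include <stdint.h>

// Dequantize element j (0..255) of one Q6_K super-block.
// ql: 128 bytes (low 4 bits), qh: 64 bytes (high 2 bits),
// scales: 16 signed bytes, d: fp16 super-block scale converted to float.
static float q6_k_dequant_one(const uint8_t * ql, const uint8_t * qh,
                              const int8_t * scales, float d, int j) {
    const int n     = j / 128;                     // which 128-weight half
    const int b     = (j % 128) / 64;              // low (0) or high (1) nibble of ql
    const int shift = ((j % 128) / 32) * 2;        // qh bit position: 0, 2, 4, 6
    const int is    = 8*n + shift + (j % 32) / 16; // group scale index, 0..15
    const int qsi   = n*64 + (j % 64);             // index into ql
    const int qhi   = n*32 + (j % 32);             // index into qh

    const int q = (((ql[qsi] >> (4*b)) & 0xF) | (((qh[qhi] >> shift) & 3) << 4)) - 32;
    return d * (float)scales[is] * (float)q;
}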
