Skip to content

Commit ba1df05

Browse files
shaofeiqilhez
andauthored
opencl: add q5_0/q5_1 gemm and gemv kernels for Adreno (#24319)
* opencl: add q5_0 adreno support * opencl: add q5_1 adreno support * opencl: cosmetic fix --------- Co-authored-by: Li He <lih@qti.qualcomm.com>
1 parent 1593d56 commit ba1df05

7 files changed

Lines changed: 1642 additions & 55 deletions

File tree

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ set(GGML_OPENCL_KERNELS
142142
gemm_noshuffle_q4_0_f32
143143
gemv_noshuffle_q4_1_f32
144144
gemm_noshuffle_q4_1_f32
145+
gemv_noshuffle_q5_0_f32
146+
gemm_noshuffle_q5_0_f32
147+
gemv_noshuffle_q5_1_f32
148+
gemm_noshuffle_q5_1_f32
145149
gemv_noshuffle_iq4_nl_f32
146150
gemm_noshuffle_iq4_nl_f32
147151
gemv_noshuffle_q8_0_f32

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 674 additions & 55 deletions
Large diffs are not rendered by default.

ggml/src/ggml-opencl/kernels/cvt.cl

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,60 @@ kernel void kernel_restore_block_q5_0(
584584
}
585585
}
586586

587+
kernel void kernel_convert_block_q5_0_noshuffle(
588+
global struct block_q5_0 * src0,
589+
global uchar * dst_q,
590+
global uint * dst_qh,
591+
global half * dst_d
592+
) {
593+
global struct block_q5_0 * b = (global struct block_q5_0 *) src0 + get_global_id(0);
594+
global uchar * q = (global uchar *) dst_q + QK5_0/2*get_global_id(0);
595+
global uint * qh = (global uint *) dst_qh + get_global_id(0);
596+
global half * d = (global half *) dst_d + get_global_id(0);
597+
598+
*d = b->d;
599+
*qh = *((global uint *)(b->qh));
600+
601+
for (int i = 0; i < QK5_0/4; ++i) {
602+
uchar x0 = b->qs[2*i + 0];
603+
uchar x1 = b->qs[2*i + 1];
604+
605+
q[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
606+
q[i + QK5_0/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
607+
608+
#ifdef ADRENO_GPU
609+
if (get_global_id(0) == 65536*4096) {
610+
printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0));
611+
}
612+
#endif
613+
}
614+
}
615+
616+
kernel void kernel_restore_block_q5_0_noshuffle(
617+
global uchar * src_q,
618+
global uint * src_qh,
619+
global half * src_d,
620+
global struct block_q5_0 * dst,
621+
uchar mask_0F,
622+
uchar mask_F0
623+
) {
624+
global struct block_q5_0 * b = (global struct block_q5_0 *) dst + get_global_id(0);
625+
global uchar * q = (global uchar *) src_q + QK5_0/2*get_global_id(0);
626+
global uint * qh = (global uint *) src_qh + get_global_id(0);
627+
global half * d = (global half *) src_d + get_global_id(0);
628+
629+
b->d = *d;
630+
*((global uint *)(b->qh)) = *qh;
631+
632+
for (int i = 0; i < QK5_0/4; ++i) {
633+
uchar x0 = q[i + 0 ];
634+
uchar x1 = q[i + QK5_0/4];
635+
636+
b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4));
637+
b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0));
638+
}
639+
}
640+
587641
kernel void kernel_convert_block_q5_0_trans4_ns(
588642
__global struct block_q5_0 * src0,
589643
__global uint * dst_qs,
@@ -736,6 +790,66 @@ kernel void kernel_restore_block_q5_1(
736790
}
737791
}
738792

793+
kernel void kernel_convert_block_q5_1_noshuffle(
794+
global struct block_q5_1 * src0,
795+
global uchar * dst_q,
796+
global uint * dst_qh,
797+
global half * dst_d,
798+
global half * dst_m
799+
) {
800+
global struct block_q5_1 * b = (global struct block_q5_1 *) src0 + get_global_id(0);
801+
global uchar * q = (global uchar *) dst_q + QK5_1/2*get_global_id(0);
802+
global uint * qh = (global uint *) dst_qh + get_global_id(0);
803+
global half * d = (global half *) dst_d + get_global_id(0);
804+
global half * m = (global half *) dst_m + get_global_id(0);
805+
806+
*d = b->d;
807+
*m = b->m;
808+
*qh = *((global uint *)(b->qh));
809+
810+
for (int i = 0; i < QK5_1/4; ++i) {
811+
uchar x0 = b->qs[2*i + 0];
812+
uchar x1 = b->qs[2*i + 1];
813+
814+
q[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
815+
q[i + QK5_1/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
816+
817+
#ifdef ADRENO_GPU
818+
if (get_global_id(0) == 65536*4096) {
819+
printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0));
820+
}
821+
#endif
822+
}
823+
}
824+
825+
kernel void kernel_restore_block_q5_1_noshuffle(
826+
global uchar * src_q,
827+
global uint * src_qh,
828+
global half * src_d,
829+
global half * src_m,
830+
global struct block_q5_1 * dst,
831+
uchar mask_0F,
832+
uchar mask_F0
833+
) {
834+
global struct block_q5_1 * b = (global struct block_q5_1 *) dst + get_global_id(0);
835+
global uchar * q = (global uchar *) src_q + QK5_1/2*get_global_id(0);
836+
global uint * qh = (global uint *) src_qh + get_global_id(0);
837+
global half * d = (global half *) src_d + get_global_id(0);
838+
global half * m = (global half *) src_m + get_global_id(0);
839+
840+
b->d = *d;
841+
b->m = *m;
842+
*((global uint *)(b->qh)) = *qh;
843+
844+
for (int i = 0; i < QK5_1/4; ++i) {
845+
uchar x0 = q[i + 0 ];
846+
uchar x1 = q[i + QK5_1/4];
847+
848+
b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4));
849+
b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0));
850+
}
851+
}
852+
739853
kernel void kernel_convert_block_q5_1_trans4_ns(
740854
__global struct block_q5_1 * src0,
741855
__global uint * dst_qs,
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2+
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
3+
4+
#ifdef cl_qcom_reqd_sub_group_size
5+
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
6+
#define ADRENO_GPU 1
7+
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
8+
#endif
9+
10+
#ifdef ADRENO_GPU
11+
REQD_SUBGROUP_SIZE_128
12+
#endif
13+
14+
kernel void kernel_gemm_noshuffle_q5_0_f32(
15+
global const ushort * src0_qs, // quantized A
16+
global const uchar * src0_qh, // 5th bits
17+
global const half * src0_d, // A scales
18+
__read_only image1d_buffer_t src1, // B (1d image)
19+
global float * dst, // C
20+
int m, // M
21+
int n, // N with padding
22+
int k, // K
23+
int n_no_padding // N without padding
24+
) {
25+
26+
int n_4 = n >> 2;
27+
28+
int gy = get_global_id(0);
29+
int gx = get_global_id(1);
30+
int gx_2 = gx << 2;
31+
32+
half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;
33+
half8 B;
34+
half4 dequantized_weights;
35+
36+
global const ushort * weight_ptr = src0_qs + gx_2;
37+
global const uchar * qh_ptr = src0_qh + gx_2;
38+
global const half * scale_ptr = src0_d + gx_2;
39+
40+
for (int i = 0; i < k; i += 4) {
41+
42+
B.s0123 = read_imageh(src1, gy*2 + i*n_4);
43+
B.s4567 = read_imageh(src1, gy*2 + i*n_4 + 1);
44+
45+
ushort4 bits4 = vload4(0, weight_ptr + (i >> 2)*m);
46+
uchar4 bits1 = vload4(0, qh_ptr + (i >> 3)*m);
47+
uchar4 qh = bits1 >> (uchar4)(i & 4);
48+
49+
half4 scale = vload4(0, scale_ptr + (i >> 5)*m);
50+
51+
// j=0
52+
dequantized_weights.s0 = (convert_half((bits4.s0 & 0x000F) | ((qh.s0 & 0x01) << 4)) - 16.0h) * scale.s0;
53+
dequantized_weights.s1 = (convert_half((bits4.s1 & 0x000F) | ((qh.s1 & 0x01) << 4)) - 16.0h) * scale.s1;
54+
dequantized_weights.s2 = (convert_half((bits4.s2 & 0x000F) | ((qh.s2 & 0x01) << 4)) - 16.0h) * scale.s2;
55+
dequantized_weights.s3 = (convert_half((bits4.s3 & 0x000F) | ((qh.s3 & 0x01) << 4)) - 16.0h) * scale.s3;
56+
c0 += B * dequantized_weights.s0;
57+
c1 += B * dequantized_weights.s1;
58+
c2 += B * dequantized_weights.s2;
59+
c3 += B * dequantized_weights.s3;
60+
61+
// j=1
62+
B.s0123 = read_imageh(src1, gy*2 + (i+1)*n_4);
63+
B.s4567 = read_imageh(src1, gy*2 + (i+1)*n_4 + 1);
64+
dequantized_weights.s0 = (convert_half(((bits4.s0 & 0x00F0) >> 4) | ((qh.s0 & 0x02) << 3)) - 16.0h) * scale.s0;
65+
dequantized_weights.s1 = (convert_half(((bits4.s1 & 0x00F0) >> 4) | ((qh.s1 & 0x02) << 3)) - 16.0h) * scale.s1;
66+
dequantized_weights.s2 = (convert_half(((bits4.s2 & 0x00F0) >> 4) | ((qh.s2 & 0x02) << 3)) - 16.0h) * scale.s2;
67+
dequantized_weights.s3 = (convert_half(((bits4.s3 & 0x00F0) >> 4) | ((qh.s3 & 0x02) << 3)) - 16.0h) * scale.s3;
68+
c0 += B * dequantized_weights.s0;
69+
c1 += B * dequantized_weights.s1;
70+
c2 += B * dequantized_weights.s2;
71+
c3 += B * dequantized_weights.s3;
72+
73+
// j=2
74+
B.s0123 = read_imageh(src1, gy*2 + (i+2)*n_4);
75+
B.s4567 = read_imageh(src1, gy*2 + (i+2)*n_4 + 1);
76+
dequantized_weights.s0 = (convert_half(((bits4.s0 & 0x0F00) >> 8) | ((qh.s0 & 0x04) << 2)) - 16.0h) * scale.s0;
77+
dequantized_weights.s1 = (convert_half(((bits4.s1 & 0x0F00) >> 8) | ((qh.s1 & 0x04) << 2)) - 16.0h) * scale.s1;
78+
dequantized_weights.s2 = (convert_half(((bits4.s2 & 0x0F00) >> 8) | ((qh.s2 & 0x04) << 2)) - 16.0h) * scale.s2;
79+
dequantized_weights.s3 = (convert_half(((bits4.s3 & 0x0F00) >> 8) | ((qh.s3 & 0x04) << 2)) - 16.0h) * scale.s3;
80+
c0 += B * dequantized_weights.s0;
81+
c1 += B * dequantized_weights.s1;
82+
c2 += B * dequantized_weights.s2;
83+
c3 += B * dequantized_weights.s3;
84+
85+
// j=3
86+
B.s0123 = read_imageh(src1, gy*2 + (i+3)*n_4);
87+
B.s4567 = read_imageh(src1, gy*2 + (i+3)*n_4 + 1);
88+
dequantized_weights.s0 = (convert_half(((bits4.s0 & 0xF000) >> 12) | ((qh.s0 & 0x08) << 1)) - 16.0h) * scale.s0;
89+
dequantized_weights.s1 = (convert_half(((bits4.s1 & 0xF000) >> 12) | ((qh.s1 & 0x08) << 1)) - 16.0h) * scale.s1;
90+
dequantized_weights.s2 = (convert_half(((bits4.s2 & 0xF000) >> 12) | ((qh.s2 & 0x08) << 1)) - 16.0h) * scale.s2;
91+
dequantized_weights.s3 = (convert_half(((bits4.s3 & 0xF000) >> 12) | ((qh.s3 & 0x08) << 1)) - 16.0h) * scale.s3;
92+
c0 += B * dequantized_weights.s0;
93+
c1 += B * dequantized_weights.s1;
94+
c2 += B * dequantized_weights.s2;
95+
c3 += B * dequantized_weights.s3;
96+
}
97+
98+
int idx = (gy<<3)*m + (gx<<2);
99+
100+
if(idx+3 < m*n_no_padding){
101+
vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
102+
idx += m;
103+
}
104+
if(idx+3 < m*n_no_padding){
105+
vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
106+
idx += m;
107+
}
108+
if(idx+3 < m*n_no_padding){
109+
vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
110+
idx += m;
111+
}
112+
if(idx+3 < m*n_no_padding){
113+
vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
114+
idx += m;
115+
}
116+
if(idx+3 < m*n_no_padding){
117+
vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
118+
idx += m;
119+
}
120+
if(idx+3 < m*n_no_padding){
121+
vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
122+
idx += m;
123+
}
124+
if(idx+3 < m*n_no_padding){
125+
vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
126+
idx += m;
127+
}
128+
if(idx+3 < m*n_no_padding){
129+
vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
130+
}
131+
}

0 commit comments

Comments
 (0)