|
// Half-precision scalar and vector types (half, half4, half8) are used by the
// kernel in this file, so fp16 support is required unconditionally.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Qualcomm sub-group size control.
//
// The extension is enabled only when the compiler advertises it through the
// cl_qcom_reqd_sub_group_size macro. Requesting it unconditionally as well is
// redundant with the guarded pragma below and is diagnosed by compilers that
// do not implement the extension.
//
// ADRENO_GPU           - defined when the vendor extension is available; code
//                        that uses REQD_SUBGROUP_SIZE_128 must be guarded by it.
// REQD_SUBGROUP_SIZE_128 - kernel attribute requesting the "full" sub-group
//                        size. NOTE(review): the macro name suggests "full"
//                        means 128 work-items - confirm against vendor docs.
#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
| 9 | + |
// GEMM of a 5-bit quantized matrix A (src0_*) with a half-precision matrix B
// (src1, read as half4 texels), writing float results to dst.
//
// Each work-item owns an 8x4 output tile: 8 consecutive B values (two half4
// texels selected by get_global_id(0)) times 4 consecutive A rows (selected by
// get_global_id(1)). K is consumed 4 values per loop iteration.
//
// Data layout, as implied by the indexing below:
//   src0_qs : one ushort per (A row, group of 4 K values); nibble j (lowest
//             nibble first) holds the low 4 bits for K offset j.
//             Element index: (k/4)*m + row.
//   src0_qh : one uchar per (A row, group of 8 K values); bit b holds the 5th
//             bit for K offset b. Element index: (k/8)*m + row.
//   src0_d  : one half scale per (A row, group of 32 K values).
//             Element index: (k/32)*m + row.
//   src1    : n/4 half4 texels per K index.
//   dst     : m floats per output row; a tile covers rows gid0*8 .. gid0*8+7
//             and columns gid1*4 .. gid1*4+3.
//
// Each weight is rebuilt as (low4 | bit5 << 4), shifted by -16 and multiplied
// by its scale. Accumulation is done in half precision; values are widened to
// float only when stored.
//
// NOTE(review): the loop reads 4 K values per iteration and shifts the qh byte
// by (k & 4), so it assumes k is a multiple of 4 - confirm with the host code.

// Dequantize the weight of one row (LANE = s0..s3) for the current K offset.
#define Q5_0_WEIGHT(LANE, NIB_MASK, NIB_SHIFT, HI_MASK, HI_SHIFT)                  \
    ((convert_half(((nibbles.LANE & (NIB_MASK)) >> (NIB_SHIFT)) |                  \
                   ((hi_bits.LANE & (HI_MASK)) << (HI_SHIFT))) - 16.0h) * d.LANE)

// One K offset J (0..3): fetch the 8 activations for that K index, dequantize
// the 4 row weights and accumulate into the 8x4 tile.
#define Q5_0_STEP(J, NIB_MASK, NIB_SHIFT, HI_MASK, HI_SHIFT)                       \
    do {                                                                           \
        act.s0123 = read_imageh(src1, tile_n*2 + (kk + (J))*texels_per_k);         \
        act.s4567 = read_imageh(src1, tile_n*2 + (kk + (J))*texels_per_k + 1);     \
        w.s0 = Q5_0_WEIGHT(s0, NIB_MASK, NIB_SHIFT, HI_MASK, HI_SHIFT);            \
        w.s1 = Q5_0_WEIGHT(s1, NIB_MASK, NIB_SHIFT, HI_MASK, HI_SHIFT);            \
        w.s2 = Q5_0_WEIGHT(s2, NIB_MASK, NIB_SHIFT, HI_MASK, HI_SHIFT);            \
        w.s3 = Q5_0_WEIGHT(s3, NIB_MASK, NIB_SHIFT, HI_MASK, HI_SHIFT);            \
        acc0 += act * w.s0;                                                        \
        acc1 += act * w.s1;                                                        \
        acc2 += act * w.s2;                                                        \
        acc3 += act * w.s3;                                                        \
    } while (0)

// Store one output row of the tile (LANE = s0..s6) if it lies inside dst, then
// advance to the next row. Once a row is rejected, `out` stops advancing, so
// all later rows are rejected too.
#define Q5_0_STORE_ROW(LANE)                                                       \
    do {                                                                           \
        if (out + 3 < dst_limit) {                                                 \
            vstore4((float4)(acc0.LANE, acc1.LANE, acc2.LANE, acc3.LANE), 0,       \
                    dst + out);                                                    \
            out += m;                                                              \
        }                                                                          \
    } while (0)

#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_128
#endif

kernel void kernel_gemm_noshuffle_q5_0_f32(
        global const ushort * src0_qs,      // quantized A
        global const uchar * src0_qh,       // 5th bits
        global const half * src0_d,         // A scales
        __read_only image1d_buffer_t src1,  // B (1d image)
        global float * dst,                 // C
        int m,                              // M
        int n,                              // N with padding
        int k,                              // K
        int n_no_padding                    // N without padding
) {
    const int tile_n = get_global_id(0);    // selects the 8-wide strip of B
    const int tile_m = get_global_id(1);    // selects the group of 4 A rows
    const int row0 = tile_m << 2;           // first of the 4 A rows
    const int texels_per_k = n >> 2;        // half4 texels per K index in src1

    // Per-row base pointers; the K-dependent offset is added inside the loop.
    global const ushort * qs_base = src0_qs + row0;
    global const uchar * qh_base = src0_qh + row0;
    global const half * d_base = src0_d + row0;

    half8 acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0;  // one accumulator per A row
    half8 act;                                      // 8 activations for one K index
    half4 w;                                        // dequantized weights, one per A row

    for (int kk = 0; kk < k; kk += 4) {
        // Low nibbles for K = kk..kk+3, one ushort per row.
        const ushort4 nibbles = vload4(0, qs_base + (kk >> 2)*m);

        // The qh byte covers 8 K values; (kk & 4) moves the half that belongs
        // to this iteration into bits 0..3.
        const uchar4 qh_raw = vload4(0, qh_base + (kk >> 3)*m);
        const uchar4 hi_bits = qh_raw >> (uchar4)(kk & 4);

        // One scale per row per 32 K values.
        const half4 d = vload4(0, d_base + (kk >> 5)*m);

        Q5_0_STEP(0, 0x000F,  0, 0x01, 4);
        Q5_0_STEP(1, 0x00F0,  4, 0x02, 3);
        Q5_0_STEP(2, 0x0F00,  8, 0x04, 2);
        Q5_0_STEP(3, 0xF000, 12, 0x08, 1);
    }

    // Rows past the unpadded N extent are not written.
    const int dst_limit = m*n_no_padding;
    int out = (tile_n << 3)*m + (tile_m << 2);

    Q5_0_STORE_ROW(s0);
    Q5_0_STORE_ROW(s1);
    Q5_0_STORE_ROW(s2);
    Q5_0_STORE_ROW(s3);
    Q5_0_STORE_ROW(s4);
    Q5_0_STORE_ROW(s5);
    Q5_0_STORE_ROW(s6);
    if (out + 3 < dst_limit) {
        vstore4((float4)(acc0.s7, acc1.s7, acc2.s7, acc3.s7), 0, dst + out);
    }
}

#undef Q5_0_STORE_ROW
#undef Q5_0_STEP
#undef Q5_0_WEIGHT
0 commit comments