@@ -434,46 +434,58 @@ void weight_quant_gpu(const GPUContext& dev_ctx,
   }
 }

-// pack int8 weight to 2int4 int one int8
 __global__ void weight_permute_transpose_interleave_kernel_w4a8(
     const int8_t* input_data_ptr,
     int8_t* output_data_ptr,
     int numel,
     int total_k,
     int total_n) {
-  const int interleave = 4;
-  const int interleave_group = 64;
-  const int permute_group = 32;
-
-  for (int block_n = blockIdx.x; block_n < total_n / interleave;
-       block_n += gridDim.x) {
-    const int8_t* src_ptr = input_data_ptr + block_n * interleave;
-    int8_t* dst_ptr = output_data_ptr + block_n * interleave * total_k / 2;
-
-    for (int block_k = threadIdx.y; block_k < total_k / interleave_group;
-         block_k += blockDim.y) {
-      const int8_t* src_ptr_1 = src_ptr + block_k * interleave_group * total_n;
-      int8_t* dst_ptr_1 = dst_ptr + block_k * interleave_group * interleave / 2;
-
-      int tid_div_16 = threadIdx.x / 16;
-      int tid_mod_16 = threadIdx.x % 16;
-
-      int src_offset = (tid_div_16 * permute_group + tid_mod_16) * total_n;
+  // Every 4 k-direction 4-bit values pack into one uint16_t (2 int8s),
+  // hence numel / 4.
+  numel = numel / 4;
+  for (int linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
+       linear_idx < numel;
+       linear_idx += blockDim.x * gridDim.x) {
+    const int k_group_id = linear_idx / total_n;
+    const int n_id = linear_idx % total_n;
+
+    uint16_t res = 0;
+    for (int j = 0; j < 4; j++) {
+      const int k_id = k_group_id * 4 + j;
+      uint16_t val = input_data_ptr[k_id * total_n + n_id];
+      val = val & 0x0F;
+      val = val << (j * 4);
+      res |= val;
+    }

-      #pragma unroll
-      for (int idx = 0; idx < interleave; idx++) {
-        const int8_t* src_ptr_2 = src_ptr_1 + idx;
-        int8_t* dst_ptr_2 = dst_ptr_1 + idx * interleave_group / 2;
+    constexpr int map[8] = {0, 2, 4, 6, 1, 3, 5, 7};
+    // Remember: the output's shape (in 16-bit granularity) is
+    // [16, 4, total_k / 64, total_n / 4], so the index of each packed
+    // uint16_t is [k_group_id % 16, n_id % 4, k_group_id / 16, n_id / 4].
+    const int32_t new_index = map[k_group_id % 8] + k_group_id % 16 / 8 * 8 +
+                              (n_id % 4) * 16 + k_group_id / 16 * (16 * 4) +
+                              n_id / 4 * total_k;
+
+    reinterpret_cast<uint16_t*>(output_data_ptr)[new_index] = res;
+  }
+}

-        int8_t tmp0 = src_ptr_2[src_offset];
-        int8_t tmp1 = src_ptr_2[src_offset + permute_group / 2 * total_n];
+__global__ void w4a8_inplace_permute(uint32_t* output_data_ptr, int numel) {
+  for (int linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
+       linear_idx < numel;
+       linear_idx += blockDim.x * gridDim.x) {
+    const uint32_t value = output_data_ptr[linear_idx];

-        int8_t packed_val = (tmp0 & 0x0f) | ((tmp1 & 0x0f) << 4);
+    uint32_t res = 0;

-        int dst_offset = threadIdx.x;
-        dst_ptr_2[dst_offset] = packed_val;
-      }
+    const int map[8] = {0, 2, 4, 6, 1, 3, 5, 7};
+    for (int i = 0; i < 8; i++) {
+      uint32_t tmp = value >> (i * 4);
+      tmp = tmp & 0x0F;
+      tmp = tmp << (map[i] * 4);
+      res |= tmp;
     }
+    output_data_ptr[linear_idx] = res;
   }
 }

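Note: the new kernel replaces the transpose-interleave loop nest with a flat grid-stride loop; every thread packs four k-adjacent 4-bit values into one uint16_t and scatters it through the [16, 4, total_k/64, total_n/4] layout. The scatter expression is easy to get wrong, so a host-side reference like the hypothetical sketch below (not part of this patch; it assumes total_k is a multiple of 64, total_n a multiple of 4, and int8 inputs that already fit in 4 bits) can be diffed against the kernel output in a unit test:

#include <cstdint>
#include <vector>

// Hypothetical CPU mirror of weight_permute_transpose_interleave_kernel_w4a8.
std::vector<uint16_t> pack_w4a8_reference(const std::vector<int8_t>& in,
                                          int total_k, int total_n) {
  // 4 nibbles per uint16_t, so the packed tensor has total_k * total_n / 4
  // 16-bit elements, viewed as [16, 4, total_k/64, total_n/4].
  std::vector<uint16_t> out(static_cast<size_t>(total_k) * total_n / 4);
  const int map[8] = {0, 2, 4, 6, 1, 3, 5, 7};
  for (int k_group_id = 0; k_group_id < total_k / 4; ++k_group_id) {
    for (int n_id = 0; n_id < total_n; ++n_id) {
      uint16_t res = 0;
      for (int j = 0; j < 4; ++j) {  // pack 4 consecutive k nibbles
        const int k_id = k_group_id * 4 + j;
        res |= static_cast<uint16_t>((in[k_id * total_n + n_id] & 0x0F) << (j * 4));
      }
      // Same scatter expression as the kernel.
      const int new_index = map[k_group_id % 8] + k_group_id % 16 / 8 * 8 +
                            (n_id % 4) * 16 + k_group_id / 16 * (16 * 4) +
                            n_id / 4 * total_k;
      out[new_index] = res;
    }
  }
  return out;
}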
@@ -495,9 +507,11 @@ void weight_permute_gpu_w4a8(const GPUContext& dev_ctx,
           << " block size = " << block_size;
   if (arch > 70) {
     if (algo == "w4a8") {
-      dim3 block_dim(32, block_size / 32);
+      dim3 block_dim(128);
       weight_permute_transpose_interleave_kernel_w4a8<<<grid_size, block_dim>>>(
           input_data, output_data, numel, total_k, total_n);
+      w4a8_inplace_permute<<<grid_size, block_dim>>>(
+          reinterpret_cast<uint32_t*>(output_data), numel / 8);
     }
   } else {
     phi::errors::Unimplemented(
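Note: w4a8_inplace_permute only shuffles nibbles inside each 32-bit word: a uint32_t holds 8 int4 values, which is why the second launch covers numel / 8 elements. Nibble i moves to slot map[i], so the low four nibbles land on the even positions and the high four on the odd ones (presumably the order the w4a8 GEMM unpacking expects). A minimal host-side sketch of the same shuffle, under a hypothetical name, for checking one word:

#include <cstdint>

// Hypothetical CPU mirror of the per-word shuffle in w4a8_inplace_permute.
uint32_t permute_nibbles_reference(uint32_t value) {
  const int map[8] = {0, 2, 4, 6, 1, 3, 5, 7};
  uint32_t res = 0;
  for (int i = 0; i < 8; ++i) {
    // Extract nibble i and deposit it at position map[i].
    res |= ((value >> (i * 4)) & 0x0Fu) << (map[i] * 4);
  }
  return res;
}

For example, permute_nibbles_reference(0x76543210u) returns 0x73625140u: nibbles 0-3 spread to the even slots and nibbles 4-7 to the odd slots.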