Skip to content

Commit dbf3d52

Browse files
committed
support w4a8
1 parent 8536b48 commit dbf3d52

1 file changed

Lines changed: 44 additions & 30 deletions

File tree

paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -434,46 +434,58 @@ void weight_quant_gpu(const GPUContext& dev_ctx,
434434
}
435435
}
436436

// Packs groups of four k-direction 4-bit values into 16-bit words and scatters
// them to a permuted/interleaved destination for the w4a8 weight layout.
// Launched as a 1-D grid-stride loop: each thread owns one (k-group, n) pair.
//
// NOTE(review): judging from the index arithmetic below, the output (viewed at
// 16-bit granularity) is laid out as [16, 4, total_k/64, total_n/4], indexed by
// [k_group_id % 16, n_id % 4, k_group_id / 16, n_id / 4] — confirm against the
// consuming GEMM kernel.
__global__ void weight_permute_transpose_interleave_kernel_w4a8(
    const int8_t* input_data_ptr,
    int8_t* output_data_ptr,
    int numel,
    int total_k,
    int total_n) {
  // Four k-direction 4-bit values fold into one 16-bit word, so the number of
  // packed output elements (and loop iterations) is numel / 4.
  const int packed_count = numel / 4;
  const int grid_stride = blockDim.x * gridDim.x;
  for (int linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
       linear_idx < packed_count;
       linear_idx += grid_stride) {
    const int k_group_id = linear_idx / total_n;
    const int n_id = linear_idx % total_n;

    // Gather four consecutive k-direction nibbles into one 16-bit word;
    // the lowest k index lands in the lowest nibble.
    uint16_t packed = 0;
    for (int j = 0; j < 4; j++) {
      const int k_id = k_group_id * 4 + j;
      const uint16_t nibble = input_data_ptr[k_id * total_n + n_id];
      packed |= static_cast<uint16_t>((nibble & 0x0F) << (j * 4));
    }

    // Fixed 8-entry permutation applied to the row index within each group
    // of 16 k-groups (even rows first, then odd rows).
    constexpr int row_map[8] = {0, 2, 4, 6, 1, 3, 5, 7};
    const int32_t new_index =
        row_map[k_group_id % 8] + (k_group_id % 16) / 8 * 8 +
        (n_id % 4) * 16 + (k_group_id / 16) * (16 * 4) +
        (n_id / 4) * total_k;

    reinterpret_cast<uint16_t*>(output_data_ptr)[new_index] = packed;
  }
}
// Permutes, in place, the eight 4-bit nibbles inside every 32-bit word:
// nibble i of the input word moves to nibble slot map[i] of the output word
// (even nibbles to the low half, odd nibbles to the high half).
// Launched as a 1-D grid-stride loop; numel is the number of uint32_t words
// to process. NOTE(review): the caller passes `numel / 8` of the original
// element count — presumably 8 nibbles per 32-bit word; confirm at call site.
__global__ void w4a8_inplace_permute(uint32_t* output_data_ptr, int numel) {
  // Hoisted out of the loop and made constexpr (the original rebuilt a local
  // `const int` array on every iteration), matching the constexpr table used
  // by weight_permute_transpose_interleave_kernel_w4a8.
  constexpr int map[8] = {0, 2, 4, 6, 1, 3, 5, 7};
  for (int linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
       linear_idx < numel;
       linear_idx += blockDim.x * gridDim.x) {
    const uint32_t value = output_data_ptr[linear_idx];
    uint32_t res = 0;
#pragma unroll
    for (int i = 0; i < 8; i++) {
      // Extract nibble i and re-deposit it at nibble slot map[i].
      res |= ((value >> (i * 4)) & 0x0Fu) << (map[i] * 4);
    }
    output_data_ptr[linear_idx] = res;
  }
}
479491

@@ -495,9 +507,11 @@ void weight_permute_gpu_w4a8(const GPUContext& dev_ctx,
495507
<< " block size = " << block_size;
496508
if (arch > 70) {
497509
if (algo == "w4a8") {
498-
dim3 block_dim(32, block_size / 32);
510+
dim3 block_dim(128);
499511
weight_permute_transpose_interleave_kernel_w4a8<<<grid_size, block_dim>>>(
500512
input_data, output_data, numel, total_k, total_n);
513+
w4a8_inplace_permute<<<grid_size, block_dim>>>(
514+
reinterpret_cast<uint32_t*>(output_data), numel / 8);
501515
}
502516
} else {
503517
phi::errors::Unimplemented(

0 commit comments

Comments
 (0)