Skip to content

Commit 2326609

Browse files
Evie Wrightdoviethoa-at-work
authored andcommitted
Improve packing performance for quantized Int4 per-block
Improves performance of ‘kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon’ by vectorizing row summation Signed-off-by: Evie Wright <evie.wright@arm.com> Approved-by: Anton Bondarenko <anton.bondarenko@arm.com>
1 parent d18f620 commit 2326609

2 files changed

Lines changed: 20 additions & 54 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ KleidiAI follows the [Semantic Versioning](https://semver.org/) specification fo
1212

1313
- New Advanced SIMD micro-kernels:
1414
- Optimized version of kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0 kernel for block depth of 4 bytes (`kai_rhs_pack_nxk_qsi4c32pnrx4_qsu4c32s1s0_neon`)
15+
- Improve performance of `kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon`
1516

1617
## v1.10.0
1718

kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon.c

Lines changed: 19 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -205,66 +205,22 @@ void kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
205205
const float d2 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 2]);
206206
const float d3 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 3]);
207207

208-
// Take zero-point (-8) into account
208+
// Initialize partial sum taking new zero-point (8) into account
209209
int32_t partial_sum0 = -(32 * 8);
210210
int32_t partial_sum1 = -(32 * 8);
211211
int32_t partial_sum2 = -(32 * 8);
212212
int32_t partial_sum3 = -(32 * 8);
213213

214214
const uint8_t* src_block_base = rhs + ((k0_idx_i / 2) + dst_byte_idx);
215215

216-
// Load elements as uint64_ts to calculate sums more efficiently
217-
uint64_t ld0_0 = *(const uint64_t*)(src_block_base + n0_idx * rhs_stride);
218-
uint64_t ld0_1 = *(const uint64_t*)(src_block_base + n0_idx * rhs_stride + 8);
219-
220-
uint64_t ld1_0 = *(const uint64_t*)(src_block_base + n1_idx * rhs_stride);
221-
uint64_t ld1_1 = *(const uint64_t*)(src_block_base + n1_idx * rhs_stride + 8);
222-
223-
uint64_t ld2_0 = *(const uint64_t*)(src_block_base + n2_idx * rhs_stride);
224-
uint64_t ld2_1 = *(const uint64_t*)(src_block_base + n2_idx * rhs_stride + 8);
225-
226-
uint64_t ld3_0 = *(const uint64_t*)(src_block_base + n3_idx * rhs_stride);
227-
uint64_t ld3_1 = *(const uint64_t*)(src_block_base + n3_idx * rhs_stride + 8);
228-
229-
// Copy to vector registers
230-
const uint8x8_t vld0_0 = vcreate_u8(ld0_0);
231-
const uint8x8_t vld0_1 = vcreate_u8(ld0_1);
232-
233-
const uint8x8_t vld1_0 = vcreate_u8(ld1_0);
234-
const uint8x8_t vld1_1 = vcreate_u8(ld1_1);
235-
236-
const uint8x8_t vld2_0 = vcreate_u8(ld2_0);
237-
const uint8x8_t vld2_1 = vcreate_u8(ld2_1);
238-
239-
const uint8x8_t vld3_0 = vcreate_u8(ld3_0);
240-
const uint8x8_t vld3_1 = vcreate_u8(ld3_1);
241-
242-
// Calculate sums
243-
for (size_t idx = 0; idx < 16; ++idx) {
244-
const int32_t e0_0 = (int32_t)(ld0_0 & 0x0F);
245-
const int32_t e0_1 = (int32_t)(ld0_1 & 0x0F);
246-
partial_sum0 += e0_0 + e0_1;
247-
ld0_0 = ld0_0 >> 4;
248-
ld0_1 = ld0_1 >> 4;
249-
250-
const int32_t e1_0 = (int32_t)(ld1_0 & 0x0F);
251-
const int32_t e1_1 = (int32_t)(ld1_1 & 0x0F);
252-
partial_sum1 += e1_0 + e1_1;
253-
ld1_0 = ld1_0 >> 4;
254-
ld1_1 = ld1_1 >> 4;
255-
256-
const int32_t e2_0 = (int32_t)(ld2_0 & 0x0F);
257-
const int32_t e2_1 = (int32_t)(ld2_1 & 0x0F);
258-
partial_sum2 += e2_0 + e2_1;
259-
ld2_0 = ld2_0 >> 4;
260-
ld2_1 = ld2_1 >> 4;
261-
262-
const int32_t e3_0 = (int32_t)(ld3_0 & 0x0F);
263-
const int32_t e3_1 = (int32_t)(ld3_1 & 0x0F);
264-
partial_sum3 += e3_0 + e3_1;
265-
ld3_0 = ld3_0 >> 4;
266-
ld3_1 = ld3_1 >> 4;
267-
}
216+
const uint8x8_t vld0_0 = vld1_u8(src_block_base + n0_idx * rhs_stride);
217+
const uint8x8_t vld0_1 = vld1_u8(src_block_base + n0_idx * rhs_stride + 8);
218+
const uint8x8_t vld1_0 = vld1_u8(src_block_base + n1_idx * rhs_stride);
219+
const uint8x8_t vld1_1 = vld1_u8(src_block_base + n1_idx * rhs_stride + 8);
220+
const uint8x8_t vld2_0 = vld1_u8(src_block_base + n2_idx * rhs_stride);
221+
const uint8x8_t vld2_1 = vld1_u8(src_block_base + n2_idx * rhs_stride + 8);
222+
const uint8x8_t vld3_0 = vld1_u8(src_block_base + n3_idx * rhs_stride);
223+
const uint8x8_t vld3_1 = vld1_u8(src_block_base + n3_idx * rhs_stride + 8);
268224

269225
const uint8x8_t vld0_s1s = vand_u8(vld0_0, bottom_mask);
270226
const uint8x8_t vld0_s0s = vshr_n_u8(vld0_0, 4);
@@ -325,7 +281,16 @@ void kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
325281
(uint8_t*)dst_row + (nr * block_length_in_bytes) + 24,
326282
veor_u8(vld3_s16s0s_upper, zero_point_conversion_mask));
327283

328-
// Add to row sums
284+
// Calculate and store row sums
285+
partial_sum0 += (int32_t)vaddlvq_u16(vaddl_u8(
286+
vadd_u8(vld0_s1s, vand_u8(vld0_1, bottom_mask)), vadd_u8(vld0_s0s, vshr_n_u8(vld0_1, 4))));
287+
partial_sum1 += (int32_t)vaddlvq_u16(vaddl_u8(
288+
vadd_u8(vld1_s1s, vand_u8(vld1_1, bottom_mask)), vadd_u8(vld1_s0s, vshr_n_u8(vld1_1, 4))));
289+
partial_sum2 += (int32_t)vaddlvq_u16(vaddl_u8(
290+
vadd_u8(vld2_s1s, vand_u8(vld2_1, bottom_mask)), vadd_u8(vld2_s0s, vshr_n_u8(vld2_1, 4))));
291+
partial_sum3 += (int32_t)vaddlvq_u16(vaddl_u8(
292+
vadd_u8(vld3_s1s, vand_u8(vld3_1, bottom_mask)), vadd_u8(vld3_s0s, vshr_n_u8(vld3_1, 4))));
293+
329294
// NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
330295
sums[nr_idx + 0] += (float)partial_sum0 * d0;
331296
sums[nr_idx + 1] += (float)partial_sum1 * d1;

0 commit comments

Comments
 (0)