@@ -205,66 +205,22 @@ void kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
205205 const float d2 = kai_cast_f32_bf16 (((uint16_t * )rhs_packed_scale )[nr_idx + 2 ]);
206206 const float d3 = kai_cast_f32_bf16 (((uint16_t * )rhs_packed_scale )[nr_idx + 3 ]);
207207
208- // Take zero-point (- 8) into account
208+ // Initialize partial sum taking new zero-point (8) into account
209209 int32_t partial_sum0 = - (32 * 8 );
210210 int32_t partial_sum1 = - (32 * 8 );
211211 int32_t partial_sum2 = - (32 * 8 );
212212 int32_t partial_sum3 = - (32 * 8 );
213213
214214 const uint8_t * src_block_base = rhs + ((k0_idx_i / 2 ) + dst_byte_idx );
215215
216- // Load elements as uint64_ts to calculate sums more efficiently
217- uint64_t ld0_0 = * (const uint64_t * )(src_block_base + n0_idx * rhs_stride );
218- uint64_t ld0_1 = * (const uint64_t * )(src_block_base + n0_idx * rhs_stride + 8 );
219-
220- uint64_t ld1_0 = * (const uint64_t * )(src_block_base + n1_idx * rhs_stride );
221- uint64_t ld1_1 = * (const uint64_t * )(src_block_base + n1_idx * rhs_stride + 8 );
222-
223- uint64_t ld2_0 = * (const uint64_t * )(src_block_base + n2_idx * rhs_stride );
224- uint64_t ld2_1 = * (const uint64_t * )(src_block_base + n2_idx * rhs_stride + 8 );
225-
226- uint64_t ld3_0 = * (const uint64_t * )(src_block_base + n3_idx * rhs_stride );
227- uint64_t ld3_1 = * (const uint64_t * )(src_block_base + n3_idx * rhs_stride + 8 );
228-
229- // Copy to vector registers
230- const uint8x8_t vld0_0 = vcreate_u8 (ld0_0 );
231- const uint8x8_t vld0_1 = vcreate_u8 (ld0_1 );
232-
233- const uint8x8_t vld1_0 = vcreate_u8 (ld1_0 );
234- const uint8x8_t vld1_1 = vcreate_u8 (ld1_1 );
235-
236- const uint8x8_t vld2_0 = vcreate_u8 (ld2_0 );
237- const uint8x8_t vld2_1 = vcreate_u8 (ld2_1 );
238-
239- const uint8x8_t vld3_0 = vcreate_u8 (ld3_0 );
240- const uint8x8_t vld3_1 = vcreate_u8 (ld3_1 );
241-
242- // Calculate sums
243- for (size_t idx = 0 ; idx < 16 ; ++ idx ) {
244- const int32_t e0_0 = (int32_t )(ld0_0 & 0x0F );
245- const int32_t e0_1 = (int32_t )(ld0_1 & 0x0F );
246- partial_sum0 += e0_0 + e0_1 ;
247- ld0_0 = ld0_0 >> 4 ;
248- ld0_1 = ld0_1 >> 4 ;
249-
250- const int32_t e1_0 = (int32_t )(ld1_0 & 0x0F );
251- const int32_t e1_1 = (int32_t )(ld1_1 & 0x0F );
252- partial_sum1 += e1_0 + e1_1 ;
253- ld1_0 = ld1_0 >> 4 ;
254- ld1_1 = ld1_1 >> 4 ;
255-
256- const int32_t e2_0 = (int32_t )(ld2_0 & 0x0F );
257- const int32_t e2_1 = (int32_t )(ld2_1 & 0x0F );
258- partial_sum2 += e2_0 + e2_1 ;
259- ld2_0 = ld2_0 >> 4 ;
260- ld2_1 = ld2_1 >> 4 ;
261-
262- const int32_t e3_0 = (int32_t )(ld3_0 & 0x0F );
263- const int32_t e3_1 = (int32_t )(ld3_1 & 0x0F );
264- partial_sum3 += e3_0 + e3_1 ;
265- ld3_0 = ld3_0 >> 4 ;
266- ld3_1 = ld3_1 >> 4 ;
267- }
216+ const uint8x8_t vld0_0 = vld1_u8 (src_block_base + n0_idx * rhs_stride );
217+ const uint8x8_t vld0_1 = vld1_u8 (src_block_base + n0_idx * rhs_stride + 8 );
218+ const uint8x8_t vld1_0 = vld1_u8 (src_block_base + n1_idx * rhs_stride );
219+ const uint8x8_t vld1_1 = vld1_u8 (src_block_base + n1_idx * rhs_stride + 8 );
220+ const uint8x8_t vld2_0 = vld1_u8 (src_block_base + n2_idx * rhs_stride );
221+ const uint8x8_t vld2_1 = vld1_u8 (src_block_base + n2_idx * rhs_stride + 8 );
222+ const uint8x8_t vld3_0 = vld1_u8 (src_block_base + n3_idx * rhs_stride );
223+ const uint8x8_t vld3_1 = vld1_u8 (src_block_base + n3_idx * rhs_stride + 8 );
268224
269225 const uint8x8_t vld0_s1s = vand_u8 (vld0_0 , bottom_mask );
270226 const uint8x8_t vld0_s0s = vshr_n_u8 (vld0_0 , 4 );
@@ -325,7 +281,16 @@ void kai_run_rhs_pack_nxk_qsi4c32pnrx8_qsu4c32s1s0_neon(
325281 (uint8_t * )dst_row + (nr * block_length_in_bytes ) + 24 ,
326282 veor_u8 (vld3_s16s0s_upper , zero_point_conversion_mask ));
327283
328- // Add to row sums
284+ // Calculate and store row sums
285+ partial_sum0 += (int32_t )vaddlvq_u16 (vaddl_u8 (
286+ vadd_u8 (vld0_s1s , vand_u8 (vld0_1 , bottom_mask )), vadd_u8 (vld0_s0s , vshr_n_u8 (vld0_1 , 4 ))));
287+ partial_sum1 += (int32_t )vaddlvq_u16 (vaddl_u8 (
288+ vadd_u8 (vld1_s1s , vand_u8 (vld1_1 , bottom_mask )), vadd_u8 (vld1_s0s , vshr_n_u8 (vld1_1 , 4 ))));
289+ partial_sum2 += (int32_t )vaddlvq_u16 (vaddl_u8 (
290+ vadd_u8 (vld2_s1s , vand_u8 (vld2_1 , bottom_mask )), vadd_u8 (vld2_s0s , vshr_n_u8 (vld2_1 , 4 ))));
291+ partial_sum3 += (int32_t )vaddlvq_u16 (vaddl_u8 (
292+ vadd_u8 (vld3_s1s , vand_u8 (vld3_1 , bottom_mask )), vadd_u8 (vld3_s0s , vshr_n_u8 (vld3_1 , 4 ))));
293+
329294 // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
330295 sums [nr_idx + 0 ] += (float )partial_sum0 * d0 ;
331296 sums [nr_idx + 1 ] += (float )partial_sum1 * d1 ;
0 commit comments