@@ -208,6 +208,145 @@ inline float euclidean_distance_squared<distance_comp_l2, float, ::std::uint8_t>
208
208
return dsum;
209
209
}
210
210
211
+ template <>
212
+ inline float euclidean_distance_squared<distance_comp_inner, float , float >(
213
+ float const * a, float const * b, size_t n) {
214
+
215
+ int n_rounded = n - (n % 4 );
216
+
217
+ float32x4_t vreg_dsum = vdupq_n_f32 (0 .f );
218
+ for (int i = 0 ; i < n_rounded; i += 4 ) {
219
+ float32x4_t vreg_a = vld1q_f32 (&a[i]);
220
+ float32x4_t vreg_b = vld1q_f32 (&b[i]);
221
+ vreg_a = vnegq_f32 (vreg_a);
222
+ vreg_dsum = vfmaq_f32 (vreg_dsum, vreg_a, vreg_b);
223
+ }
224
+
225
+ float dsum = vaddvq_f32 (vreg_dsum);
226
+ for (int i = n_rounded; i < n; ++i) {
227
+ dsum += -a[i] * b[i];
228
+ }
229
+
230
+ return dsum;
231
+ }
232
+
233
+ template <>
234
+ inline float euclidean_distance_squared<distance_comp_inner, float , ::std::int8_t >(
235
+ ::std::int8_t const * a, ::std::int8_t const * b, size_t n) {
236
+
237
+ int n_rounded = n - (n % 16 );
238
+ float dsum = 0 .f ;
239
+
240
+ if (n_rounded > 0 ) {
241
+ float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32 (0 .f );
242
+ float32x4_t vreg_dsum_fp32_1 = vreg_dsum_fp32_0;
243
+ float32x4_t vreg_dsum_fp32_2 = vreg_dsum_fp32_0;
244
+ float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;
245
+
246
+ for (int i = 0 ; i < n_rounded; i += 16 ) {
247
+ int8x16_t vreg_a = vld1q_s8 (&a[i]);
248
+ int16x8_t vreg_a_s16_0 = vmovl_s8 (vget_low_s8 (vreg_a));
249
+ int16x8_t vreg_a_s16_1 = vmovl_s8 (vget_high_s8 (vreg_a));
250
+
251
+ int8x16_t vreg_b = vld1q_s8 (&b[i]);
252
+ int16x8_t vreg_b_s16_0 = vmovl_s8 (vget_low_s8 (vreg_b));
253
+ int16x8_t vreg_b_s16_1 = vmovl_s8 (vget_high_s8 (vreg_b));
254
+
255
+ # if 1
256
+ vreg_a_s16_0 = vmulq_s16 (vreg_a_s16_0, vreg_b_s16_0);
257
+ vreg_a_s16_1 = vmulq_s16 (vreg_a_s16_1, vreg_b_s16_1);
258
+
259
+ float32x4_t vreg_res_fp32_0 = vcvtq_f32_s32 (vmovl_s16 (vget_low_s16 (vreg_a_s16_0)));
260
+ float32x4_t vreg_res_fp32_1 = vcvtq_f32_s32 (vmovl_s16 (vget_high_s16 (vreg_a_s16_0)));
261
+ float32x4_t vreg_res_fp32_2 = vcvtq_f32_s32 (vmovl_s16 (vget_low_s16 (vreg_a_s16_1)));
262
+ float32x4_t vreg_res_fp32_3 = vcvtq_f32_s32 (vmovl_s16 (vget_high_s16 (vreg_a_s16_1)));
263
+
264
+ vreg_dsum_fp32_0 = vsubq_f32 (vreg_dsum_fp32_0, vreg_res_fp32_0);
265
+ vreg_dsum_fp32_1 = vsubq_f32 (vreg_dsum_fp32_1, vreg_res_fp32_1);
266
+ vreg_dsum_fp32_2 = vsubq_f32 (vreg_dsum_fp32_2, vreg_res_fp32_2);
267
+ vreg_dsum_fp32_3 = vsubq_f32 (vreg_dsum_fp32_3, vreg_res_fp32_3);
268
+ #else
269
+ // TODO: WILL BE REMOVED BEFORE MERGE
270
+ vreg_a_s16_0 = vnegq_s16 (vreg_a_s16_0);
271
+ vreg_a_s16_1 = vnegq_s16 (vreg_a_s16_1);
272
+
273
+ float32x4_t vreg_a_fp32_0 = vcvtq_f32_s32 (vmovl_s16 (vget_low_s16 (vreg_a_s16_0)));
274
+ float32x4_t vreg_b_fp32_0 = vcvtq_f32_s32 (vmovl_s16 (vget_low_s16 (vreg_b_s16_0)));
275
+ vreg_dsum_fp32_0 = vfmaq_f32 (vreg_dsum_fp32_0, vreg_a_fp32_0, vreg_b_fp32_0);
276
+ float32x4_t vreg_a_fp32_1 = vcvtq_f32_s32 (vmovl_s16 (vget_high_s16 (vreg_a_s16_0)));
277
+ float32x4_t vreg_b_fp32_1 = vcvtq_f32_s32 (vmovl_s16 (vget_high_s16 (vreg_b_s16_0)));
278
+ vreg_dsum_fp32_1 = vfmaq_f32 (vreg_dsum_fp32_1, vreg_a_fp32_1, vreg_b_fp32_1);
279
+ float32x4_t vreg_a_fp32_2 = vcvtq_f32_s32 (vmovl_s16 (vget_low_s16 (vreg_a_s16_1)));
280
+ float32x4_t vreg_b_fp32_2 = vcvtq_f32_s32 (vmovl_s16 (vget_low_s16 (vreg_b_s16_1)));
281
+ vreg_dsum_fp32_2 = vfmaq_f32 (vreg_dsum_fp32_2, vreg_a_fp32_2, vreg_b_fp32_2);
282
+ float32x4_t vreg_a_fp32_3 = vcvtq_f32_s32 (vmovl_s16 (vget_high_s16 (vreg_a_s16_1)));
283
+ float32x4_t vreg_b_fp32_3 = vcvtq_f32_s32 (vmovl_s16 (vget_high_s16 (vreg_b_s16_1)));
284
+ vreg_dsum_fp32_3 = vfmaq_f32 (vreg_dsum_fp32_3, vreg_a_fp32_3, vreg_b_fp32_3);
285
+ #endif
286
+ }
287
+
288
+ vreg_dsum_fp32_0 = vaddq_f32 (vreg_dsum_fp32_0, vreg_dsum_fp32_1);
289
+ vreg_dsum_fp32_2 = vaddq_f32 (vreg_dsum_fp32_2, vreg_dsum_fp32_3);
290
+ vreg_dsum_fp32_0 = vaddq_f32 (vreg_dsum_fp32_0, vreg_dsum_fp32_2);
291
+
292
+ dsum = vaddvq_f32 (vreg_dsum_fp32_0); // faddp
293
+ }
294
+
295
+ for (int i = n_rounded; i < n; ++i) {
296
+ dsum += -a[i] * b[i];
297
+ }
298
+
299
+ return dsum;
300
+ }
301
+
302
+ template <>
303
+ inline float euclidean_distance_squared<distance_comp_inner, float , ::std::uint8_t >(::std::uint8_t const * a, ::std::uint8_t const * b, size_t n) {
304
+ int n_rounded = n - (n % 16 );
305
+ float dsum = 0 .f ;
306
+
307
+ if (n_rounded > 0 ) {
308
+ float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32 (0 .f );
309
+ float32x4_t vreg_dsum_fp32_1 = vreg_dsum_fp32_0;
310
+ float32x4_t vreg_dsum_fp32_2 = vreg_dsum_fp32_0;
311
+ float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;
312
+
313
+ for (int i = 0 ; i < n_rounded; i += 16 ) {
314
+ uint8x16_t vreg_a = vld1q_u8 (&a[i]);
315
+ uint16x8_t vreg_a_u16_0 = vmovl_u8 (vget_low_u8 (vreg_a));
316
+ uint16x8_t vreg_a_u16_1 = vmovl_u8 (vget_high_u8 (vreg_a));
317
+
318
+ uint8x16_t vreg_b = vld1q_u8 (&b[i]);
319
+ uint16x8_t vreg_b_u16_0 = vmovl_u8 (vget_low_u8 (vreg_b));
320
+ uint16x8_t vreg_b_u16_1 = vmovl_u8 (vget_high_u8 (vreg_b));
321
+
322
+ vreg_a_u16_0 = vmulq_u16 (vreg_a_u16_0, vreg_b_u16_0);
323
+ vreg_a_u16_1 = vmulq_u16 (vreg_a_u16_1, vreg_b_u16_1);
324
+
325
+ float32x4_t vreg_res_fp32_0 = vcvtq_f32_u32 (vmovl_u16 (vget_low_u16 (vreg_a_u16_0)));
326
+ float32x4_t vreg_res_fp32_1 = vcvtq_f32_u32 (vmovl_u16 (vget_high_u16 (vreg_a_u16_0)));
327
+ float32x4_t vreg_res_fp32_2 = vcvtq_f32_u32 (vmovl_u16 (vget_low_u16 (vreg_a_u16_1)));
328
+ float32x4_t vreg_res_fp32_3 = vcvtq_f32_u32 (vmovl_u16 (vget_high_u16 (vreg_a_u16_1)));
329
+
330
+ vreg_dsum_fp32_0 = vsubq_f32 (vreg_dsum_fp32_0, vreg_res_fp32_0);
331
+ vreg_dsum_fp32_1 = vsubq_f32 (vreg_dsum_fp32_1, vreg_res_fp32_1);
332
+ vreg_dsum_fp32_2 = vsubq_f32 (vreg_dsum_fp32_2, vreg_res_fp32_2);
333
+ vreg_dsum_fp32_3 = vsubq_f32 (vreg_dsum_fp32_3, vreg_res_fp32_3);
334
+ }
335
+
336
+ vreg_dsum_fp32_0 = vaddq_f32 (vreg_dsum_fp32_0, vreg_dsum_fp32_1);
337
+ vreg_dsum_fp32_2 = vaddq_f32 (vreg_dsum_fp32_2, vreg_dsum_fp32_3);
338
+ vreg_dsum_fp32_0 = vaddq_f32 (vreg_dsum_fp32_0, vreg_dsum_fp32_2);
339
+
340
+ dsum = vaddvq_f32 (vreg_dsum_fp32_0); // faddp
341
+ }
342
+
343
+ for (int i = n_rounded; i < n; ++i) {
344
+ dsum += -a[i] * b[i];
345
+ }
346
+
347
+ return dsum;
348
+ }
349
+
211
350
#endif // defined(__arm__) || defined(__aarch64__)
212
351
213
352
// -----------------------------------------------------------------------------
0 commit comments