Skip to content

Commit d6370a2

Browse files
committed
Add NEON implementations for distance_comp_inner comparator
1 parent 5a2abba commit d6370a2

File tree

1 file changed

+139
-0
lines changed

1 file changed

+139
-0
lines changed

cpp/include/raft/neighbors/detail/refine_host-inl.hpp

+139
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,145 @@ inline float euclidean_distance_squared<distance_comp_l2, float, ::std::uint8_t>
208208
return dsum;
209209
}
210210

211+
template<>
212+
inline float euclidean_distance_squared<distance_comp_inner, float, float>(
213+
float const* a, float const* b, size_t n) {
214+
215+
int n_rounded = n - (n % 4);
216+
217+
float32x4_t vreg_dsum = vdupq_n_f32(0.f);
218+
for (int i = 0; i < n_rounded; i += 4) {
219+
float32x4_t vreg_a = vld1q_f32(&a[i]);
220+
float32x4_t vreg_b = vld1q_f32(&b[i]);
221+
vreg_a = vnegq_f32(vreg_a);
222+
vreg_dsum = vfmaq_f32(vreg_dsum, vreg_a, vreg_b);
223+
}
224+
225+
float dsum = vaddvq_f32(vreg_dsum);
226+
for (int i = n_rounded; i < n; ++i) {
227+
dsum += -a[i] * b[i];
228+
}
229+
230+
return dsum;
231+
}
232+
233+
template<>
234+
inline float euclidean_distance_squared<distance_comp_inner, float, ::std::int8_t>(
235+
::std::int8_t const* a, ::std::int8_t const* b, size_t n) {
236+
237+
int n_rounded = n - (n % 16);
238+
float dsum = 0.f;
239+
240+
if (n_rounded > 0) {
241+
float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
242+
float32x4_t vreg_dsum_fp32_1 = vreg_dsum_fp32_0;
243+
float32x4_t vreg_dsum_fp32_2 = vreg_dsum_fp32_0;
244+
float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;
245+
246+
for (int i = 0; i < n_rounded; i += 16) {
247+
int8x16_t vreg_a = vld1q_s8(&a[i]);
248+
int16x8_t vreg_a_s16_0 = vmovl_s8(vget_low_s8(vreg_a));
249+
int16x8_t vreg_a_s16_1 = vmovl_s8(vget_high_s8(vreg_a));
250+
251+
int8x16_t vreg_b = vld1q_s8(&b[i]);
252+
int16x8_t vreg_b_s16_0 = vmovl_s8(vget_low_s8(vreg_b));
253+
int16x8_t vreg_b_s16_1 = vmovl_s8(vget_high_s8(vreg_b));
254+
255+
# if 1
256+
vreg_a_s16_0 = vmulq_s16(vreg_a_s16_0, vreg_b_s16_0);
257+
vreg_a_s16_1 = vmulq_s16(vreg_a_s16_1, vreg_b_s16_1);
258+
259+
float32x4_t vreg_res_fp32_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vreg_a_s16_0)));
260+
float32x4_t vreg_res_fp32_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vreg_a_s16_0)));
261+
float32x4_t vreg_res_fp32_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vreg_a_s16_1)));
262+
float32x4_t vreg_res_fp32_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vreg_a_s16_1)));
263+
264+
vreg_dsum_fp32_0 = vsubq_f32(vreg_dsum_fp32_0, vreg_res_fp32_0);
265+
vreg_dsum_fp32_1 = vsubq_f32(vreg_dsum_fp32_1, vreg_res_fp32_1);
266+
vreg_dsum_fp32_2 = vsubq_f32(vreg_dsum_fp32_2, vreg_res_fp32_2);
267+
vreg_dsum_fp32_3 = vsubq_f32(vreg_dsum_fp32_3, vreg_res_fp32_3);
268+
#else
269+
// TODO: WILL BE REMOVED BEFORE MERGE
270+
vreg_a_s16_0 = vnegq_s16(vreg_a_s16_0);
271+
vreg_a_s16_1 = vnegq_s16(vreg_a_s16_1);
272+
273+
float32x4_t vreg_a_fp32_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vreg_a_s16_0)));
274+
float32x4_t vreg_b_fp32_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vreg_b_s16_0)));
275+
vreg_dsum_fp32_0 = vfmaq_f32(vreg_dsum_fp32_0, vreg_a_fp32_0, vreg_b_fp32_0);
276+
float32x4_t vreg_a_fp32_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vreg_a_s16_0)));
277+
float32x4_t vreg_b_fp32_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vreg_b_s16_0)));
278+
vreg_dsum_fp32_1 = vfmaq_f32(vreg_dsum_fp32_1, vreg_a_fp32_1, vreg_b_fp32_1);
279+
float32x4_t vreg_a_fp32_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vreg_a_s16_1)));
280+
float32x4_t vreg_b_fp32_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vreg_b_s16_1)));
281+
vreg_dsum_fp32_2 = vfmaq_f32(vreg_dsum_fp32_2, vreg_a_fp32_2, vreg_b_fp32_2);
282+
float32x4_t vreg_a_fp32_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vreg_a_s16_1)));
283+
float32x4_t vreg_b_fp32_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vreg_b_s16_1)));
284+
vreg_dsum_fp32_3 = vfmaq_f32(vreg_dsum_fp32_3, vreg_a_fp32_3, vreg_b_fp32_3);
285+
#endif
286+
}
287+
288+
vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_1);
289+
vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
290+
vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);
291+
292+
dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
293+
}
294+
295+
for (int i = n_rounded; i < n; ++i) {
296+
dsum += -a[i] * b[i];
297+
}
298+
299+
return dsum;
300+
}
301+
302+
template<>
303+
inline float euclidean_distance_squared<distance_comp_inner, float, ::std::uint8_t>(::std::uint8_t const* a, ::std::uint8_t const* b, size_t n) {
304+
int n_rounded = n - (n % 16);
305+
float dsum = 0.f;
306+
307+
if (n_rounded > 0) {
308+
float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
309+
float32x4_t vreg_dsum_fp32_1 = vreg_dsum_fp32_0;
310+
float32x4_t vreg_dsum_fp32_2 = vreg_dsum_fp32_0;
311+
float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;
312+
313+
for (int i = 0; i < n_rounded; i += 16) {
314+
uint8x16_t vreg_a = vld1q_u8(&a[i]);
315+
uint16x8_t vreg_a_u16_0 = vmovl_u8(vget_low_u8(vreg_a));
316+
uint16x8_t vreg_a_u16_1 = vmovl_u8(vget_high_u8(vreg_a));
317+
318+
uint8x16_t vreg_b = vld1q_u8(&b[i]);
319+
uint16x8_t vreg_b_u16_0 = vmovl_u8(vget_low_u8(vreg_b));
320+
uint16x8_t vreg_b_u16_1 = vmovl_u8(vget_high_u8(vreg_b));
321+
322+
vreg_a_u16_0 = vmulq_u16(vreg_a_u16_0, vreg_b_u16_0);
323+
vreg_a_u16_1 = vmulq_u16(vreg_a_u16_1, vreg_b_u16_1);
324+
325+
float32x4_t vreg_res_fp32_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_a_u16_0)));
326+
float32x4_t vreg_res_fp32_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_a_u16_0)));
327+
float32x4_t vreg_res_fp32_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_a_u16_1)));
328+
float32x4_t vreg_res_fp32_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_a_u16_1)));
329+
330+
vreg_dsum_fp32_0 = vsubq_f32(vreg_dsum_fp32_0, vreg_res_fp32_0);
331+
vreg_dsum_fp32_1 = vsubq_f32(vreg_dsum_fp32_1, vreg_res_fp32_1);
332+
vreg_dsum_fp32_2 = vsubq_f32(vreg_dsum_fp32_2, vreg_res_fp32_2);
333+
vreg_dsum_fp32_3 = vsubq_f32(vreg_dsum_fp32_3, vreg_res_fp32_3);
334+
}
335+
336+
vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_1);
337+
vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
338+
vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);
339+
340+
dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
341+
}
342+
343+
for (int i = n_rounded; i < n; ++i) {
344+
dsum += -a[i] * b[i];
345+
}
346+
347+
return dsum;
348+
}
349+
211350
#endif // defined(__arm__) || defined(__aarch64__)
212351

213352
// -----------------------------------------------------------------------------

0 commit comments

Comments
 (0)