Skip to content

Commit 5a2abba

Browse files
committed
Add euclidian distance optimization: generic + written in NEON
1 parent e15a112 commit 5a2abba

File tree

1 file changed

+188
-3
lines changed

1 file changed

+188
-3
lines changed

cpp/include/raft/neighbors/detail/refine_host-inl.hpp

Lines changed: 188 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,195 @@
2525

2626
#include <algorithm>
2727

28+
#if defined(__arm__) || defined(__aarch64__)
29+
#include <arm_neon.h>
30+
#endif
31+
2832
namespace raft::neighbors::detail {
2933

34+
// -----------------------------------------------------------------------------
35+
// Generic implementation
36+
// -----------------------------------------------------------------------------
37+
38+
template <typename DC, typename DistanceT, typename DataT>
39+
DistanceT euclidean_distance_squared_generic(DataT const* a, DataT const* b, size_t n) {
40+
// vector register capacity in elements
41+
size_t constexpr vreg_len = (128 / 8) / sizeof(DistanceT);
42+
// unroll factor = vector register capacity * number of ports;
43+
size_t constexpr unroll_factor = vreg_len * 4;
44+
45+
// unroll factor is a power of two
46+
size_t n_rounded = n & (0xFFFFFFFF ^ (unroll_factor - 1));
47+
DistanceT distance[unroll_factor] = {0};
48+
49+
for (size_t i = 0; i < n_rounded; i += unroll_factor) {
50+
for (size_t j = 0; j < unroll_factor; ++j) {
51+
distance[j] += DC::template eval<DistanceT>(a[i + j], b[i + j]);
52+
}
53+
}
54+
55+
for (size_t i = n_rounded; i < n; ++i) {
56+
distance[i] += DC::template eval<DistanceT>(a[i], b[i]);
57+
}
58+
59+
for (size_t i = 1; i < unroll_factor; ++i) {
60+
distance[0] += distance[i];
61+
}
62+
63+
return distance[0];
64+
}
65+
66+
// -----------------------------------------------------------------------------
67+
// NEON implementation
68+
// -----------------------------------------------------------------------------
69+
70+
struct distance_comp_l2;
71+
struct distance_comp_inner;
72+
73+
// fallback
74+
template<typename DC, typename DistanceT, typename DataT>
75+
DistanceT euclidean_distance_squared(DataT const* a, DataT const* b, size_t n) {
76+
return euclidean_distance_squared_generic<DC, DistanceT, DataT>(a, b, n);
77+
}
78+
79+
#if defined(__arm__) || defined(__aarch64__)
80+
81+
template<>
82+
inline float euclidean_distance_squared<distance_comp_l2, float, float>(
83+
float const* a, float const* b, size_t n) {
84+
85+
int n_rounded = n - (n % 4);
86+
87+
float32x4_t vreg_dsum = vdupq_n_f32(0.f);
88+
for (int i = 0; i < n_rounded; i += 4) {
89+
float32x4_t vreg_a = vld1q_f32(&a[i]);
90+
float32x4_t vreg_b = vld1q_f32(&b[i]);
91+
float32x4_t vreg_d = vsubq_f32(vreg_a, vreg_b);
92+
vreg_dsum = vfmaq_f32(vreg_dsum, vreg_d, vreg_d);
93+
}
94+
95+
float dsum = vaddvq_f32(vreg_dsum);
96+
for (int i = n_rounded; i < n; ++i) {
97+
float d = a[i] - b[i];
98+
dsum += d * d;
99+
}
100+
101+
return dsum;
102+
}
103+
104+
template<>
105+
inline float euclidean_distance_squared<distance_comp_l2, float, ::std::int8_t>(
106+
::std::int8_t const* a, ::std::int8_t const* b, size_t n) {
107+
108+
int n_rounded = n - (n % 16);
109+
float dsum = 0.f;
110+
111+
if (n_rounded > 0) {
112+
float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
113+
float32x4_t vreg_dsum_fp32_1 = vreg_dsum_fp32_0;
114+
float32x4_t vreg_dsum_fp32_2 = vreg_dsum_fp32_0;
115+
float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;
116+
117+
for (int i = 0; i < n_rounded; i += 16) {
118+
int8x16_t vreg_a = vld1q_s8(&a[i]);
119+
int16x8_t vreg_a_s16_0 = vmovl_s8(vget_low_s8(vreg_a));
120+
int16x8_t vreg_a_s16_1 = vmovl_s8(vget_high_s8(vreg_a));
121+
122+
int8x16_t vreg_b = vld1q_s8(&b[i]);
123+
int16x8_t vreg_b_s16_0 = vmovl_s8(vget_low_s8(vreg_b));
124+
int16x8_t vreg_b_s16_1 = vmovl_s8(vget_high_s8(vreg_b));
125+
126+
int16x8_t vreg_d_s16_0 = vsubq_s16(vreg_a_s16_0, vreg_b_s16_0);
127+
int16x8_t vreg_d_s16_1 = vsubq_s16(vreg_a_s16_1, vreg_b_s16_1);
128+
129+
float32x4_t vreg_d_fp32_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vreg_d_s16_0)));
130+
float32x4_t vreg_d_fp32_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vreg_d_s16_0)));
131+
float32x4_t vreg_d_fp32_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vreg_d_s16_1)));
132+
float32x4_t vreg_d_fp32_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vreg_d_s16_1)));
133+
134+
vreg_dsum_fp32_0 = vfmaq_f32(vreg_dsum_fp32_0, vreg_d_fp32_0, vreg_d_fp32_0);
135+
vreg_dsum_fp32_1 = vfmaq_f32(vreg_dsum_fp32_1, vreg_d_fp32_1, vreg_d_fp32_1);
136+
vreg_dsum_fp32_2 = vfmaq_f32(vreg_dsum_fp32_2, vreg_d_fp32_2, vreg_d_fp32_2);
137+
vreg_dsum_fp32_3 = vfmaq_f32(vreg_dsum_fp32_3, vreg_d_fp32_3, vreg_d_fp32_3);
138+
}
139+
140+
vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_1);
141+
vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
142+
vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);
143+
144+
dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
145+
}
146+
147+
for (int i = n_rounded; i < n; ++i) {
148+
float d = a[i] - b[i];
149+
dsum += d * d; // [nvc++] faddp, [clang] fadda, [gcc] vecsum+fadda
150+
}
151+
152+
return dsum;
153+
}
154+
155+
template<>
156+
inline float euclidean_distance_squared<distance_comp_l2, float, ::std::uint8_t>(
157+
::std::uint8_t const* a, ::std::uint8_t const* b, size_t n) {
158+
159+
int n_rounded = n - (n % 16);
160+
float dsum = 0.f;
161+
162+
if (n_rounded > 0) {
163+
float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
164+
float32x4_t vreg_dsum_fp32_1 = vreg_dsum_fp32_0;
165+
float32x4_t vreg_dsum_fp32_2 = vreg_dsum_fp32_0;
166+
float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;
167+
168+
for (int i = 0; i < n_rounded; i += 16) {
169+
uint8x16_t vreg_a = vld1q_u8(&a[i]);
170+
uint16x8_t vreg_a_u16_0 = vmovl_u8(vget_low_u8(vreg_a));
171+
uint16x8_t vreg_a_u16_1 = vmovl_u8(vget_high_u8(vreg_a));
172+
float32x4_t vreg_a_fp32_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_a_u16_0)));
173+
float32x4_t vreg_a_fp32_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_a_u16_0)));
174+
float32x4_t vreg_a_fp32_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_a_u16_1)));
175+
float32x4_t vreg_a_fp32_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_a_u16_1)));
176+
177+
uint8x16_t vreg_b = vld1q_u8(&b[i]);
178+
uint16x8_t vreg_b_u16_0 = vmovl_u8(vget_low_u8(vreg_b));
179+
uint16x8_t vreg_b_u16_1 = vmovl_u8(vget_high_u8(vreg_b));
180+
float32x4_t vreg_b_fp32_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_b_u16_0)));
181+
float32x4_t vreg_b_fp32_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_b_u16_0)));
182+
float32x4_t vreg_b_fp32_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_b_u16_1)));
183+
float32x4_t vreg_b_fp32_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_b_u16_1)));
184+
185+
float32x4_t vreg_d_fp32_0 = vsubq_f32(vreg_a_fp32_0, vreg_b_fp32_0);
186+
float32x4_t vreg_d_fp32_1 = vsubq_f32(vreg_a_fp32_1, vreg_b_fp32_1);
187+
float32x4_t vreg_d_fp32_2 = vsubq_f32(vreg_a_fp32_2, vreg_b_fp32_2);
188+
float32x4_t vreg_d_fp32_3 = vsubq_f32(vreg_a_fp32_3, vreg_b_fp32_3);
189+
190+
vreg_dsum_fp32_0 = vfmaq_f32(vreg_dsum_fp32_0, vreg_d_fp32_0, vreg_d_fp32_0);
191+
vreg_dsum_fp32_1 = vfmaq_f32(vreg_dsum_fp32_1, vreg_d_fp32_1, vreg_d_fp32_1);
192+
vreg_dsum_fp32_2 = vfmaq_f32(vreg_dsum_fp32_2, vreg_d_fp32_2, vreg_d_fp32_2);
193+
vreg_dsum_fp32_3 = vfmaq_f32(vreg_dsum_fp32_3, vreg_d_fp32_3, vreg_d_fp32_3);
194+
}
195+
196+
vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_1);
197+
vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
198+
vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);
199+
200+
dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
201+
}
202+
203+
for (int i = n_rounded; i < n; ++i) {
204+
float d = a[i] - b[i];
205+
dsum += d * d; // [nvc++] faddp, [clang] fadda, [gcc] vecsum+fadda
206+
}
207+
208+
return dsum;
209+
}
210+
211+
#endif // defined(__arm__) || defined(__aarch64__)
212+
213+
// -----------------------------------------------------------------------------
214+
// Refine kernel
215+
// -----------------------------------------------------------------------------
216+
30217
template <typename DC, typename IdxT, typename DataT, typename DistanceT, typename ExtentsT>
31218
[[gnu::optimize(3), gnu::optimize("tree-vectorize")]] void refine_host_impl(
32219
raft::host_matrix_view<const DataT, ExtentsT, row_major> dataset,
@@ -112,9 +299,7 @@ template <typename DC, typename IdxT, typename DataT, typename DistanceT, typena
112299
distance = std::numeric_limits<DistanceT>::max();
113300
} else {
114301
const DataT* row = dataset.data_handle() + dim * id;
115-
for (size_t k = 0; k < dim; k++) {
116-
distance += DC::template eval<DistanceT>(query[k], row[k]);
117-
}
302+
distance = euclidean_distance_squared<DC, DistanceT, DataT>(query, row, dim);
118303
}
119304
refined_pairs[j] = std::make_tuple(distance, id);
120305
}

0 commit comments

Comments
 (0)