@@ -38,11 +38,12 @@ namespace detail {
38
38
// -----------------------------------------------------------------------------
39
39
40
40
template <typename DC, typename DistanceT, typename DataT>
41
- DistanceT euclidean_distance_squared_generic (DataT const * a, DataT const * b, size_t n) {
41
+ DistanceT euclidean_distance_squared_generic (DataT const * a, DataT const * b, size_t n)
42
+ {
42
43
size_t constexpr max_vreg_len = 512 / (8 * sizeof (DistanceT));
43
44
44
45
// max_vreg_len is a power of two
45
- size_t n_rounded = n & (0xFFFFFFFF ^ (max_vreg_len - 1 ));
46
+ size_t n_rounded = n & (0xFFFFFFFF ^ (max_vreg_len - 1 ));
46
47
DistanceT distance[max_vreg_len] = {0 };
47
48
48
49
for (size_t i = 0 ; i < n_rounded; i += max_vreg_len) {
@@ -70,42 +71,44 @@ struct distance_comp_l2;
70
71
struct distance_comp_inner ;
71
72
72
73
// fallback
73
- template <typename DC, typename DistanceT, typename DataT>
74
- DistanceT euclidean_distance_squared (DataT const * a, DataT const * b, size_t n) {
74
+ template <typename DC, typename DistanceT, typename DataT>
75
+ DistanceT euclidean_distance_squared (DataT const * a, DataT const * b, size_t n)
76
+ {
75
77
return euclidean_distance_squared_generic<DC, DistanceT, DataT>(a, b, n);
76
78
}
77
79
78
80
#if defined(__arm__) || defined(__aarch64__)
79
81
80
- template <>
81
- inline float euclidean_distance_squared<distance_comp_l2, float , float >(
82
- float const * a, float const * b, size_t n) {
83
-
82
+ template <>
83
+ inline float euclidean_distance_squared<distance_comp_l2, float , float >(float const * a,
84
+ float const * b,
85
+ size_t n)
86
+ {
84
87
size_t n_rounded = n - (n % 4 );
85
88
86
89
float32x4_t vreg_dsum = vdupq_n_f32 (0 .f );
87
90
for (size_t i = 0 ; i < n_rounded; i += 4 ) {
88
91
float32x4_t vreg_a = vld1q_f32 (&a[i]);
89
92
float32x4_t vreg_b = vld1q_f32 (&b[i]);
90
93
float32x4_t vreg_d = vsubq_f32 (vreg_a, vreg_b);
91
- vreg_dsum = vfmaq_f32 (vreg_dsum, vreg_d, vreg_d);
94
+ vreg_dsum = vfmaq_f32 (vreg_dsum, vreg_d, vreg_d);
92
95
}
93
96
94
97
float dsum = vaddvq_f32 (vreg_dsum);
95
98
for (size_t i = n_rounded; i < n; ++i) {
96
- float d = a[i] - b[i];
97
- dsum += d * d;
99
+ float d = a[i] - b[i];
100
+ dsum += d * d;
98
101
}
99
102
100
103
return dsum;
101
104
}
102
105
103
- template <>
106
+ template <>
104
107
inline float euclidean_distance_squared<distance_comp_l2, float , ::std::int8_t >(
105
- ::std::int8_t const * a, ::std::int8_t const * b, size_t n) {
106
-
108
+ ::std::int8_t const * a, ::std::int8_t const * b, size_t n)
109
+ {
107
110
size_t n_rounded = n - (n % 16 );
108
- float dsum = 0 .f ;
111
+ float dsum = 0 .f ;
109
112
110
113
if (n_rounded > 0 ) {
111
114
float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32 (0 .f );
@@ -114,11 +117,11 @@ inline float euclidean_distance_squared<distance_comp_l2, float, ::std::int8_t>(
114
117
float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;
115
118
116
119
for (size_t i = 0 ; i < n_rounded; i += 16 ) {
117
- int8x16_t vreg_a = vld1q_s8 (&a[i]);
120
+ int8x16_t vreg_a = vld1q_s8 (&a[i]);
118
121
int16x8_t vreg_a_s16_0 = vmovl_s8 (vget_low_s8 (vreg_a));
119
122
int16x8_t vreg_a_s16_1 = vmovl_s8 (vget_high_s8 (vreg_a));
120
123
121
- int8x16_t vreg_b = vld1q_s8 (&b[i]);
124
+ int8x16_t vreg_b = vld1q_s8 (&b[i]);
122
125
int16x8_t vreg_b_s16_0 = vmovl_s8 (vget_low_s8 (vreg_b));
123
126
int16x8_t vreg_b_s16_1 = vmovl_s8 (vget_high_s8 (vreg_b));
124
127
@@ -140,23 +143,23 @@ inline float euclidean_distance_squared<distance_comp_l2, float, ::std::int8_t>(
140
143
vreg_dsum_fp32_2 = vaddq_f32 (vreg_dsum_fp32_2, vreg_dsum_fp32_3);
141
144
vreg_dsum_fp32_0 = vaddq_f32 (vreg_dsum_fp32_0, vreg_dsum_fp32_2);
142
145
143
- dsum = vaddvq_f32 (vreg_dsum_fp32_0); // faddp
146
+ dsum = vaddvq_f32 (vreg_dsum_fp32_0); // faddp
144
147
}
145
148
146
149
for (size_t i = n_rounded; i < n; ++i) {
147
- float d = a[i] - b[i];
148
- dsum += d * d; // [nvc++] faddp, [clang] fadda, [gcc] vecsum+fadda
150
+ float d = a[i] - b[i];
151
+ dsum += d * d; // [nvc++] faddp, [clang] fadda, [gcc] vecsum+fadda
149
152
}
150
153
151
154
return dsum;
152
155
}
153
156
154
- template <>
157
+ template <>
155
158
inline float euclidean_distance_squared<distance_comp_l2, float , ::std::uint8_t >(
156
- ::std::uint8_t const * a, ::std::uint8_t const * b, size_t n) {
157
-
159
+ ::std::uint8_t const * a, ::std::uint8_t const * b, size_t n)
160
+ {
158
161
size_t n_rounded = n - (n % 16 );
159
- float dsum = 0 .f ;
162
+ float dsum = 0 .f ;
160
163
161
164
if (n_rounded > 0 ) {
162
165
float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32 (0 .f );
@@ -165,17 +168,17 @@ inline float euclidean_distance_squared<distance_comp_l2, float, ::std::uint8_t>
165
168
float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;
166
169
167
170
for (size_t i = 0 ; i < n_rounded; i += 16 ) {
168
- uint8x16_t vreg_a = vld1q_u8 (&a[i]);
169
- uint16x8_t vreg_a_u16_0 = vmovl_u8 (vget_low_u8 (vreg_a));
170
- uint16x8_t vreg_a_u16_1 = vmovl_u8 (vget_high_u8 (vreg_a));
171
+ uint8x16_t vreg_a = vld1q_u8 (&a[i]);
172
+ uint16x8_t vreg_a_u16_0 = vmovl_u8 (vget_low_u8 (vreg_a));
173
+ uint16x8_t vreg_a_u16_1 = vmovl_u8 (vget_high_u8 (vreg_a));
171
174
float32x4_t vreg_a_fp32_0 = vcvtq_f32_u32 (vmovl_u16 (vget_low_u16 (vreg_a_u16_0)));
172
175
float32x4_t vreg_a_fp32_1 = vcvtq_f32_u32 (vmovl_u16 (vget_high_u16 (vreg_a_u16_0)));
173
176
float32x4_t vreg_a_fp32_2 = vcvtq_f32_u32 (vmovl_u16 (vget_low_u16 (vreg_a_u16_1)));
174
177
float32x4_t vreg_a_fp32_3 = vcvtq_f32_u32 (vmovl_u16 (vget_high_u16 (vreg_a_u16_1)));
175
178
176
- uint8x16_t vreg_b = vld1q_u8 (&b[i]);
177
- uint16x8_t vreg_b_u16_0 = vmovl_u8 (vget_low_u8 (vreg_b));
178
- uint16x8_t vreg_b_u16_1 = vmovl_u8 (vget_high_u8 (vreg_b));
179
+ uint8x16_t vreg_b = vld1q_u8 (&b[i]);
180
+ uint16x8_t vreg_b_u16_0 = vmovl_u8 (vget_low_u8 (vreg_b));
181
+ uint16x8_t vreg_b_u16_1 = vmovl_u8 (vget_high_u8 (vreg_b));
179
182
float32x4_t vreg_b_fp32_0 = vcvtq_f32_u32 (vmovl_u16 (vget_low_u16 (vreg_b_u16_0)));
180
183
float32x4_t vreg_b_fp32_1 = vcvtq_f32_u32 (vmovl_u16 (vget_high_u16 (vreg_b_u16_0)));
181
184
float32x4_t vreg_b_fp32_2 = vcvtq_f32_u32 (vmovl_u16 (vget_low_u16 (vreg_b_u16_1)));
@@ -196,45 +199,46 @@ inline float euclidean_distance_squared<distance_comp_l2, float, ::std::uint8_t>
196
199
vreg_dsum_fp32_2 = vaddq_f32 (vreg_dsum_fp32_2, vreg_dsum_fp32_3);
197
200
vreg_dsum_fp32_0 = vaddq_f32 (vreg_dsum_fp32_0, vreg_dsum_fp32_2);
198
201
199
- dsum = vaddvq_f32 (vreg_dsum_fp32_0); // faddp
202
+ dsum = vaddvq_f32 (vreg_dsum_fp32_0); // faddp
200
203
}
201
204
202
205
for (size_t i = n_rounded; i < n; ++i) {
203
- float d = a[i] - b[i];
204
- dsum += d * d; // [nvc++] faddp, [clang] fadda, [gcc] vecsum+fadda
206
+ float d = a[i] - b[i];
207
+ dsum += d * d; // [nvc++] faddp, [clang] fadda, [gcc] vecsum+fadda
205
208
}
206
209
207
210
return dsum;
208
211
}
209
212
210
- template <>
211
- inline float euclidean_distance_squared<distance_comp_inner, float , float >(
212
- float const * a, float const * b, size_t n) {
213
-
213
+ template <>
214
+ inline float euclidean_distance_squared<distance_comp_inner, float , float >(float const * a,
215
+ float const * b,
216
+ size_t n)
217
+ {
214
218
size_t n_rounded = n - (n % 4 );
215
219
216
220
float32x4_t vreg_dsum = vdupq_n_f32 (0 .f );
217
221
for (size_t i = 0 ; i < n_rounded; i += 4 ) {
218
222
float32x4_t vreg_a = vld1q_f32 (&a[i]);
219
223
float32x4_t vreg_b = vld1q_f32 (&b[i]);
220
- vreg_a = vnegq_f32 (vreg_a);
221
- vreg_dsum = vfmaq_f32 (vreg_dsum, vreg_a, vreg_b);
224
+ vreg_a = vnegq_f32 (vreg_a);
225
+ vreg_dsum = vfmaq_f32 (vreg_dsum, vreg_a, vreg_b);
222
226
}
223
227
224
228
float dsum = vaddvq_f32 (vreg_dsum);
225
229
for (size_t i = n_rounded; i < n; ++i) {
226
- dsum += -a[i] * b[i];
230
+ dsum += -a[i] * b[i];
227
231
}
228
232
229
233
return dsum;
230
234
}
231
235
232
- template <>
236
+ template <>
233
237
inline float euclidean_distance_squared<distance_comp_inner, float , ::std::int8_t >(
234
- ::std::int8_t const * a, ::std::int8_t const * b, size_t n) {
235
-
238
+ ::std::int8_t const * a, ::std::int8_t const * b, size_t n)
239
+ {
236
240
size_t n_rounded = n - (n % 16 );
237
- float dsum = 0 .f ;
241
+ float dsum = 0 .f ;
238
242
239
243
if (n_rounded > 0 ) {
240
244
float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32 (0 .f );
@@ -243,11 +247,11 @@ inline float euclidean_distance_squared<distance_comp_inner, float, ::std::int8_
243
247
float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;
244
248
245
249
for (size_t i = 0 ; i < n_rounded; i += 16 ) {
246
- int8x16_t vreg_a = vld1q_s8 (&a[i]);
250
+ int8x16_t vreg_a = vld1q_s8 (&a[i]);
247
251
int16x8_t vreg_a_s16_0 = vmovl_s8 (vget_low_s8 (vreg_a));
248
252
int16x8_t vreg_a_s16_1 = vmovl_s8 (vget_high_s8 (vreg_a));
249
253
250
- int8x16_t vreg_b = vld1q_s8 (&b[i]);
254
+ int8x16_t vreg_b = vld1q_s8 (&b[i]);
251
255
int16x8_t vreg_b_s16_0 = vmovl_s8 (vget_low_s8 (vreg_b));
252
256
int16x8_t vreg_b_s16_1 = vmovl_s8 (vget_high_s8 (vreg_b));
253
257
@@ -269,20 +273,22 @@ inline float euclidean_distance_squared<distance_comp_inner, float, ::std::int8_
269
273
vreg_dsum_fp32_2 = vaddq_f32 (vreg_dsum_fp32_2, vreg_dsum_fp32_3);
270
274
vreg_dsum_fp32_0 = vaddq_f32 (vreg_dsum_fp32_0, vreg_dsum_fp32_2);
271
275
272
- dsum = vaddvq_f32 (vreg_dsum_fp32_0); // faddp
276
+ dsum = vaddvq_f32 (vreg_dsum_fp32_0); // faddp
273
277
}
274
278
275
279
for (size_t i = n_rounded; i < n; ++i) {
276
- dsum += -a[i] * b[i];
280
+ dsum += -a[i] * b[i];
277
281
}
278
282
279
283
return dsum;
280
284
}
281
285
282
- template <>
283
- inline float euclidean_distance_squared<distance_comp_inner, float , ::std::uint8_t >(::std::uint8_t const * a, ::std::uint8_t const * b, size_t n) {
286
+ template <>
287
+ inline float euclidean_distance_squared<distance_comp_inner, float , ::std::uint8_t >(
288
+ ::std::uint8_t const * a, ::std::uint8_t const * b, size_t n)
289
+ {
284
290
size_t n_rounded = n - (n % 16 );
285
- float dsum = 0 .f ;
291
+ float dsum = 0 .f ;
286
292
287
293
if (n_rounded > 0 ) {
288
294
float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32 (0 .f );
@@ -291,11 +297,11 @@ inline float euclidean_distance_squared<distance_comp_inner, float, ::std::uint8
291
297
float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;
292
298
293
299
for (size_t i = 0 ; i < n_rounded; i += 16 ) {
294
- uint8x16_t vreg_a = vld1q_u8 (&a[i]);
300
+ uint8x16_t vreg_a = vld1q_u8 (&a[i]);
295
301
uint16x8_t vreg_a_u16_0 = vmovl_u8 (vget_low_u8 (vreg_a));
296
302
uint16x8_t vreg_a_u16_1 = vmovl_u8 (vget_high_u8 (vreg_a));
297
303
298
- uint8x16_t vreg_b = vld1q_u8 (&b[i]);
304
+ uint8x16_t vreg_b = vld1q_u8 (&b[i]);
299
305
uint16x8_t vreg_b_u16_0 = vmovl_u8 (vget_low_u8 (vreg_b));
300
306
uint16x8_t vreg_b_u16_1 = vmovl_u8 (vget_high_u8 (vreg_b));
301
307
@@ -317,17 +323,17 @@ inline float euclidean_distance_squared<distance_comp_inner, float, ::std::uint8
317
323
vreg_dsum_fp32_2 = vaddq_f32 (vreg_dsum_fp32_2, vreg_dsum_fp32_3);
318
324
vreg_dsum_fp32_0 = vaddq_f32 (vreg_dsum_fp32_0, vreg_dsum_fp32_2);
319
325
320
- dsum = vaddvq_f32 (vreg_dsum_fp32_0); // faddp
326
+ dsum = vaddvq_f32 (vreg_dsum_fp32_0); // faddp
321
327
}
322
328
323
329
for (size_t i = n_rounded; i < n; ++i) {
324
- dsum += -a[i] * b[i];
330
+ dsum += -a[i] * b[i];
325
331
}
326
332
327
333
return dsum;
328
334
}
329
335
330
- #endif // defined(__arm__) || defined(__aarch64__)
336
+ #endif // defined(__arm__) || defined(__aarch64__)
331
337
332
338
// -----------------------------------------------------------------------------
333
339
// Refine kernel
@@ -421,7 +427,7 @@ template <typename DC, typename IdxT, typename DataT, typename DistanceT, typena
421
427
distance = std::numeric_limits<DistanceT>::max ();
422
428
} else {
423
429
const DataT* row = dataset.data_handle () + dim * id;
424
- distance = euclidean_distance_squared<DC, DistanceT, DataT>(query, row, dim);
430
+ distance = euclidean_distance_squared<DC, DistanceT, DataT>(query, row, dim);
425
431
}
426
432
refined_pairs[tid][j] = std::make_tuple (distance, id);
427
433
}
0 commit comments