Skip to content

Commit 836af24

Browse files
committed
Change code style with clang-format
1 parent 3a032ac commit 836af24

File tree

1 file changed

+63
-57
lines changed

1 file changed

+63
-57
lines changed

cpp/src/neighbors/refine/refine_host.hpp

+63-57
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,12 @@ namespace detail {
3838
// -----------------------------------------------------------------------------
3939

4040
template <typename DC, typename DistanceT, typename DataT>
41-
DistanceT euclidean_distance_squared_generic(DataT const* a, DataT const* b, size_t n) {
41+
DistanceT euclidean_distance_squared_generic(DataT const* a, DataT const* b, size_t n)
42+
{
4243
size_t constexpr max_vreg_len = 512 / (8 * sizeof(DistanceT));
4344

4445
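// max_vreg_len is a power of two, so the mask below rounds n down to a multiple of max_vreg_len.
// NOTE(review): 0xFFFFFFFF is a 32-bit literal; where size_t is 64 bits wide the mask also
// clears bits 32 and above of n -- confirm this is intended for very large n.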
45-
size_t n_rounded = n & (0xFFFFFFFF ^ (max_vreg_len - 1));
46+
size_t n_rounded = n & (0xFFFFFFFF ^ (max_vreg_len - 1));
4647
DistanceT distance[max_vreg_len] = {0};
4748

4849
for (size_t i = 0; i < n_rounded; i += max_vreg_len) {
@@ -70,42 +71,44 @@ struct distance_comp_l2;
7071
struct distance_comp_inner;
7172

7273
// fallback
// Primary template, used for any <DC, DistanceT, DataT> combination that has no
// explicit specialization. It forwards a, b and n unchanged to
// euclidean_distance_squared_generic with the same template arguments and
// returns that function's result.
73-
template<typename DC, typename DistanceT, typename DataT>
74-
DistanceT euclidean_distance_squared(DataT const* a, DataT const* b, size_t n) {
74+
template <typename DC, typename DistanceT, typename DataT>
75+
DistanceT euclidean_distance_squared(DataT const* a, DataT const* b, size_t n)
76+
{
7577
return euclidean_distance_squared_generic<DC, DistanceT, DataT>(a, b, n);
7678
}
7779

7880
#if defined(__arm__) || defined(__aarch64__)
7981

80-
// Specialization for distance_comp_l2 with float inputs and a float result,
// located inside the __arm__ / __aarch64__ preprocessor guard.
// Returns the sum over i in [0, n) of (a[i] - b[i])^2: the first n - (n % 4)
// elements are processed four at a time with NEON intrinsics, the remaining
// elements with a scalar loop.
template<>
81-
inline float euclidean_distance_squared<distance_comp_l2, float, float>(
82-
float const* a, float const* b, size_t n) {
83-
82+
template <>
83+
inline float euclidean_distance_squared<distance_comp_l2, float, float>(float const* a,
84+
float const* b,
85+
size_t n)
86+
{
8487
// Largest multiple of 4 that does not exceed n; the vector loop stops here.
size_t n_rounded = n - (n % 4);
8588

8689
// Four-lane accumulator, started at zero.
float32x4_t vreg_dsum = vdupq_n_f32(0.f);
8790
for (size_t i = 0; i < n_rounded; i += 4) {
8891
float32x4_t vreg_a = vld1q_f32(&a[i]);
8992
float32x4_t vreg_b = vld1q_f32(&b[i]);
9093
float32x4_t vreg_d = vsubq_f32(vreg_a, vreg_b);
91-
vreg_dsum = vfmaq_f32(vreg_dsum, vreg_d, vreg_d);
94+
// Per lane: accumulator += difference * difference.
vreg_dsum = vfmaq_f32(vreg_dsum, vreg_d, vreg_d);
9295
}
9396

9497
// Collapse the four lanes into one scalar before handling the tail.
float dsum = vaddvq_f32(vreg_dsum);
9598
// Scalar tail for the last n % 4 elements.
for (size_t i = n_rounded; i < n; ++i) {
96-
float d = a[i] - b[i];
97-
dsum += d * d;
99+
float d = a[i] - b[i];
100+
dsum += d * d;
98101
}
99102

100103
return dsum;
101104
}
102105

103-
template<>
106+
template <>
104107
inline float euclidean_distance_squared<distance_comp_l2, float, ::std::int8_t>(
105-
::std::int8_t const* a, ::std::int8_t const* b, size_t n) {
106-
108+
::std::int8_t const* a, ::std::int8_t const* b, size_t n)
109+
{
107110
size_t n_rounded = n - (n % 16);
108-
float dsum = 0.f;
111+
float dsum = 0.f;
109112

110113
if (n_rounded > 0) {
111114
float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
@@ -114,11 +117,11 @@ inline float euclidean_distance_squared<distance_comp_l2, float, ::std::int8_t>(
114117
float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;
115118

116119
for (size_t i = 0; i < n_rounded; i += 16) {
117-
int8x16_t vreg_a = vld1q_s8(&a[i]);
120+
int8x16_t vreg_a = vld1q_s8(&a[i]);
118121
int16x8_t vreg_a_s16_0 = vmovl_s8(vget_low_s8(vreg_a));
119122
int16x8_t vreg_a_s16_1 = vmovl_s8(vget_high_s8(vreg_a));
120123

121-
int8x16_t vreg_b = vld1q_s8(&b[i]);
124+
int8x16_t vreg_b = vld1q_s8(&b[i]);
122125
int16x8_t vreg_b_s16_0 = vmovl_s8(vget_low_s8(vreg_b));
123126
int16x8_t vreg_b_s16_1 = vmovl_s8(vget_high_s8(vreg_b));
124127

@@ -140,23 +143,23 @@ inline float euclidean_distance_squared<distance_comp_l2, float, ::std::int8_t>(
140143
vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
141144
vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);
142145

143-
dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
146+
dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
144147
}
145148

146149
for (size_t i = n_rounded; i < n; ++i) {
147-
float d = a[i] - b[i];
148-
dsum += d * d; // [nvc++] faddp, [clang] fadda, [gcc] vecsum+fadda
150+
float d = a[i] - b[i];
151+
dsum += d * d; // [nvc++] faddp, [clang] fadda, [gcc] vecsum+fadda
149152
}
150153

151154
return dsum;
152155
}
153156

154-
template<>
157+
template <>
155158
inline float euclidean_distance_squared<distance_comp_l2, float, ::std::uint8_t>(
156-
::std::uint8_t const* a, ::std::uint8_t const* b, size_t n) {
157-
159+
::std::uint8_t const* a, ::std::uint8_t const* b, size_t n)
160+
{
158161
size_t n_rounded = n - (n % 16);
159-
float dsum = 0.f;
162+
float dsum = 0.f;
160163

161164
if (n_rounded > 0) {
162165
float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
@@ -165,17 +168,17 @@ inline float euclidean_distance_squared<distance_comp_l2, float, ::std::uint8_t>
165168
float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;
166169

167170
for (size_t i = 0; i < n_rounded; i += 16) {
168-
uint8x16_t vreg_a = vld1q_u8(&a[i]);
169-
uint16x8_t vreg_a_u16_0 = vmovl_u8(vget_low_u8(vreg_a));
170-
uint16x8_t vreg_a_u16_1 = vmovl_u8(vget_high_u8(vreg_a));
171+
uint8x16_t vreg_a = vld1q_u8(&a[i]);
172+
uint16x8_t vreg_a_u16_0 = vmovl_u8(vget_low_u8(vreg_a));
173+
uint16x8_t vreg_a_u16_1 = vmovl_u8(vget_high_u8(vreg_a));
171174
float32x4_t vreg_a_fp32_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_a_u16_0)));
172175
float32x4_t vreg_a_fp32_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_a_u16_0)));
173176
float32x4_t vreg_a_fp32_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_a_u16_1)));
174177
float32x4_t vreg_a_fp32_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_a_u16_1)));
175178

176-
uint8x16_t vreg_b = vld1q_u8(&b[i]);
177-
uint16x8_t vreg_b_u16_0 = vmovl_u8(vget_low_u8(vreg_b));
178-
uint16x8_t vreg_b_u16_1 = vmovl_u8(vget_high_u8(vreg_b));
179+
uint8x16_t vreg_b = vld1q_u8(&b[i]);
180+
uint16x8_t vreg_b_u16_0 = vmovl_u8(vget_low_u8(vreg_b));
181+
uint16x8_t vreg_b_u16_1 = vmovl_u8(vget_high_u8(vreg_b));
179182
float32x4_t vreg_b_fp32_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_b_u16_0)));
180183
float32x4_t vreg_b_fp32_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_b_u16_0)));
181184
float32x4_t vreg_b_fp32_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_b_u16_1)));
@@ -196,45 +199,46 @@ inline float euclidean_distance_squared<distance_comp_l2, float, ::std::uint8_t>
196199
vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
197200
vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);
198201

199-
dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
202+
dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
200203
}
201204

202205
for (size_t i = n_rounded; i < n; ++i) {
203-
float d = a[i] - b[i];
204-
dsum += d * d; // [nvc++] faddp, [clang] fadda, [gcc] vecsum+fadda
206+
float d = a[i] - b[i];
207+
dsum += d * d; // [nvc++] faddp, [clang] fadda, [gcc] vecsum+fadda
205208
}
206209

207210
return dsum;
208211
}
209212

210-
// Specialization for distance_comp_inner with float inputs and a float result,
// located inside the __arm__ / __aarch64__ preprocessor guard.
// Despite the function name, this variant returns the negated inner product:
// the sum over i in [0, n) of -a[i] * b[i]. The first n - (n % 4) elements are
// processed four at a time with NEON intrinsics, the rest with a scalar loop.
template<>
211-
inline float euclidean_distance_squared<distance_comp_inner, float, float>(
212-
float const* a, float const* b, size_t n) {
213-
213+
template <>
214+
inline float euclidean_distance_squared<distance_comp_inner, float, float>(float const* a,
215+
float const* b,
216+
size_t n)
217+
{
214218
// Largest multiple of 4 that does not exceed n; the vector loop stops here.
size_t n_rounded = n - (n % 4);
215219

216220
// Four-lane accumulator, started at zero.
float32x4_t vreg_dsum = vdupq_n_f32(0.f);
217221
for (size_t i = 0; i < n_rounded; i += 4) {
218222
float32x4_t vreg_a = vld1q_f32(&a[i]);
219223
float32x4_t vreg_b = vld1q_f32(&b[i]);
220-
vreg_a = vnegq_f32(vreg_a);
221-
vreg_dsum = vfmaq_f32(vreg_dsum, vreg_a, vreg_b);
224+
// Negate a first so the multiply-accumulate below adds -a * b per lane.
vreg_a = vnegq_f32(vreg_a);
225+
vreg_dsum = vfmaq_f32(vreg_dsum, vreg_a, vreg_b);
222226
}
223227

224228
// Collapse the four lanes into one scalar before handling the tail.
float dsum = vaddvq_f32(vreg_dsum);
225229
// Scalar tail for the last n % 4 elements, using the same negated product.
for (size_t i = n_rounded; i < n; ++i) {
226-
dsum += -a[i] * b[i];
230+
dsum += -a[i] * b[i];
227231
}
228232

229233
return dsum;
230234
}
231235

232-
template<>
236+
template <>
233237
inline float euclidean_distance_squared<distance_comp_inner, float, ::std::int8_t>(
234-
::std::int8_t const* a, ::std::int8_t const* b, size_t n) {
235-
238+
::std::int8_t const* a, ::std::int8_t const* b, size_t n)
239+
{
236240
size_t n_rounded = n - (n % 16);
237-
float dsum = 0.f;
241+
float dsum = 0.f;
238242

239243
if (n_rounded > 0) {
240244
float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
@@ -243,11 +247,11 @@ inline float euclidean_distance_squared<distance_comp_inner, float, ::std::int8_
243247
float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;
244248

245249
for (size_t i = 0; i < n_rounded; i += 16) {
246-
int8x16_t vreg_a = vld1q_s8(&a[i]);
250+
int8x16_t vreg_a = vld1q_s8(&a[i]);
247251
int16x8_t vreg_a_s16_0 = vmovl_s8(vget_low_s8(vreg_a));
248252
int16x8_t vreg_a_s16_1 = vmovl_s8(vget_high_s8(vreg_a));
249253

250-
int8x16_t vreg_b = vld1q_s8(&b[i]);
254+
int8x16_t vreg_b = vld1q_s8(&b[i]);
251255
int16x8_t vreg_b_s16_0 = vmovl_s8(vget_low_s8(vreg_b));
252256
int16x8_t vreg_b_s16_1 = vmovl_s8(vget_high_s8(vreg_b));
253257

@@ -269,20 +273,22 @@ inline float euclidean_distance_squared<distance_comp_inner, float, ::std::int8_
269273
vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
270274
vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);
271275

272-
dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
276+
dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
273277
}
274278

275279
for (size_t i = n_rounded; i < n; ++i) {
276-
dsum += -a[i] * b[i];
280+
dsum += -a[i] * b[i];
277281
}
278282

279283
return dsum;
280284
}
281285

282-
template<>
283-
inline float euclidean_distance_squared<distance_comp_inner, float, ::std::uint8_t>(::std::uint8_t const* a, ::std::uint8_t const* b, size_t n) {
286+
template <>
287+
inline float euclidean_distance_squared<distance_comp_inner, float, ::std::uint8_t>(
288+
::std::uint8_t const* a, ::std::uint8_t const* b, size_t n)
289+
{
284290
size_t n_rounded = n - (n % 16);
285-
float dsum = 0.f;
291+
float dsum = 0.f;
286292

287293
if (n_rounded > 0) {
288294
float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
@@ -291,11 +297,11 @@ inline float euclidean_distance_squared<distance_comp_inner, float, ::std::uint8
291297
float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;
292298

293299
for (size_t i = 0; i < n_rounded; i += 16) {
294-
uint8x16_t vreg_a = vld1q_u8(&a[i]);
300+
uint8x16_t vreg_a = vld1q_u8(&a[i]);
295301
uint16x8_t vreg_a_u16_0 = vmovl_u8(vget_low_u8(vreg_a));
296302
uint16x8_t vreg_a_u16_1 = vmovl_u8(vget_high_u8(vreg_a));
297303

298-
uint8x16_t vreg_b = vld1q_u8(&b[i]);
304+
uint8x16_t vreg_b = vld1q_u8(&b[i]);
299305
uint16x8_t vreg_b_u16_0 = vmovl_u8(vget_low_u8(vreg_b));
300306
uint16x8_t vreg_b_u16_1 = vmovl_u8(vget_high_u8(vreg_b));
301307

@@ -317,17 +323,17 @@ inline float euclidean_distance_squared<distance_comp_inner, float, ::std::uint8
317323
vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
318324
vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);
319325

320-
dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
326+
dsum = vaddvq_f32(vreg_dsum_fp32_0); // faddp
321327
}
322328

323329
for (size_t i = n_rounded; i < n; ++i) {
324-
dsum += -a[i] * b[i];
330+
dsum += -a[i] * b[i];
325331
}
326332

327333
return dsum;
328334
}
329335

330-
#endif // defined(__arm__) || defined(__aarch64__)
336+
#endif // defined(__arm__) || defined(__aarch64__)
331337

332338
// -----------------------------------------------------------------------------
333339
// Refine kernel
@@ -421,7 +427,7 @@ template <typename DC, typename IdxT, typename DataT, typename DistanceT, typena
421427
distance = std::numeric_limits<DistanceT>::max();
422428
} else {
423429
const DataT* row = dataset.data_handle() + dim * id;
424-
distance = euclidean_distance_squared<DC, DistanceT, DataT>(query, row, dim);
430+
distance = euclidean_distance_squared<DC, DistanceT, DataT>(query, row, dim);
425431
}
426432
refined_pairs[tid][j] = std::make_tuple(distance, id);
427433
}

0 commit comments

Comments
 (0)