Skip to content

Optimize euclidean distance in raft refine phase #2574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
330 changes: 327 additions & 3 deletions cpp/include/raft/neighbors/detail/refine_host-inl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,334 @@

#include <algorithm>

#if defined(__arm__) || defined(__aarch64__)
#include <arm_neon.h>
#endif

namespace raft::neighbors::detail {

// -----------------------------------------------------------------------------
// Generic implementation
// -----------------------------------------------------------------------------

/**
 * @brief Generic (auto-vectorizable) accumulation of DC::eval over two arrays.
 *
 * Maintains `unroll_factor` independent partial sums so the compiler can
 * vectorize and pipeline the inner loop, then reduces them at the end.
 *
 * @tparam DC        distance-component policy; must provide
 *                   `template <typename T> static T eval(DataT, DataT)`
 * @param a first input array of length @p n
 * @param b second input array of length @p n
 * @param n number of elements
 * @return sum over i of DC::eval(a[i], b[i])
 */
template <typename DC, typename DistanceT, typename DataT>
DistanceT euclidean_distance_squared_generic(DataT const* a, DataT const* b, size_t n) {
  // Vector register capacity in elements (assuming 128-bit registers).
  size_t constexpr vreg_len = (128 / 8) / sizeof(DistanceT);
  // Unroll factor = vector register capacity * assumed number of FMA ports.
  size_t constexpr unroll_factor = vreg_len * 4;

  // unroll_factor is a power of two, so this rounds n down to a multiple of it.
  // NOTE: mask must be taken in size_t width; the previous `0xFFFFFFFF ^ ...`
  // mask silently truncated n for arrays larger than 2^32 elements.
  size_t const n_rounded = n & ~(unroll_factor - 1);
  DistanceT distance[unroll_factor] = {0};

  for (size_t i = 0; i < n_rounded; i += unroll_factor) {
    for (size_t j = 0; j < unroll_factor; ++j) {
      distance[j] += DC::template eval<DistanceT>(a[i + j], b[i + j]);
    }
  }

  // Tail of fewer than unroll_factor elements. Index by the tail offset
  // (i - n_rounded), which is always < unroll_factor; the previous code
  // indexed `distance[i]`, overflowing the accumulator array whenever
  // n > unroll_factor and n was not an exact multiple of it.
  for (size_t i = n_rounded; i < n; ++i) {
    distance[i - n_rounded] += DC::template eval<DistanceT>(a[i], b[i]);
  }

  // Horizontal reduction of the partial sums.
  for (size_t i = 1; i < unroll_factor; ++i) {
    distance[0] += distance[i];
  }

  return distance[0];
}

// -----------------------------------------------------------------------------
// NEON implementation
// -----------------------------------------------------------------------------

struct distance_comp_l2;
struct distance_comp_inner;

// fallback
/**
 * @brief Primary dispatch point for squared-distance / inner-product
 * accumulation. Architecture-specific specializations (NEON, below) take over
 * for supported type combinations; every other combination falls through to
 * the generic unrolled implementation.
 */
template <typename DC, typename DistanceT, typename DataT>
DistanceT euclidean_distance_squared(DataT const* a, DataT const* b, size_t n)
{
  return euclidean_distance_squared_generic<DC, DistanceT, DataT>(a, b, n);
}

#if defined(__arm__) || defined(__aarch64__)

/**
 * @brief NEON L2 (squared euclidean) accumulation for float inputs.
 *
 * Processes 4 floats per iteration with a fused multiply-add, then handles
 * the remainder with scalar code.
 *
 * Fix: loop counters were `int`, which truncates `n` for inputs larger than
 * INT_MAX elements and mixes signed/unsigned comparison in the tail loop;
 * all counters are now size_t.
 */
template <>
inline float euclidean_distance_squared<distance_comp_l2, float, float>(
  float const* a, float const* b, size_t n) {

  // Round n down to a multiple of the 4-lane NEON float register.
  size_t const n_rounded = n - (n % 4);

  float32x4_t vreg_dsum = vdupq_n_f32(0.f);
  for (size_t i = 0; i < n_rounded; i += 4) {
    float32x4_t vreg_a = vld1q_f32(&a[i]);
    float32x4_t vreg_b = vld1q_f32(&b[i]);
    float32x4_t vreg_d = vsubq_f32(vreg_a, vreg_b);
    // Per-lane dsum += d * d (fused multiply-add).
    vreg_dsum = vfmaq_f32(vreg_dsum, vreg_d, vreg_d);
  }

  // Horizontal add across the four lanes, then scalar tail.
  float dsum = vaddvq_f32(vreg_dsum);
  for (size_t i = n_rounded; i < n; ++i) {
    float d = a[i] - b[i];
    dsum += d * d;
  }

  return dsum;
}

/**
 * @brief NEON L2 (squared euclidean) accumulation for int8 inputs.
 *
 * Widens int8 -> int16, subtracts in int16 (cannot overflow: the difference of
 * two int8 values fits in int16), then widens to fp32 and accumulates with
 * FMA. Four independent accumulators hide FMA latency.
 *
 * Fix: loop counters were `int` (truncation for n > INT_MAX, signed/unsigned
 * comparison in the tail loop); all counters are now size_t.
 */
template <>
inline float euclidean_distance_squared<distance_comp_l2, float, ::std::int8_t>(
  ::std::int8_t const* a, ::std::int8_t const* b, size_t n) {

  // Round n down to a multiple of the 16-lane int8 register.
  size_t const n_rounded = n - (n % 16);
  float dsum = 0.f;

  if (n_rounded > 0) {
    float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
    float32x4_t vreg_dsum_fp32_1 = vreg_dsum_fp32_0;
    float32x4_t vreg_dsum_fp32_2 = vreg_dsum_fp32_0;
    float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;

    for (size_t i = 0; i < n_rounded; i += 16) {
      // Load 16 int8 and widen each half to int16.
      int8x16_t vreg_a = vld1q_s8(&a[i]);
      int16x8_t vreg_a_s16_0 = vmovl_s8(vget_low_s8(vreg_a));
      int16x8_t vreg_a_s16_1 = vmovl_s8(vget_high_s8(vreg_a));

      int8x16_t vreg_b = vld1q_s8(&b[i]);
      int16x8_t vreg_b_s16_0 = vmovl_s8(vget_low_s8(vreg_b));
      int16x8_t vreg_b_s16_1 = vmovl_s8(vget_high_s8(vreg_b));

      // int16 difference is exact for int8 operands.
      int16x8_t vreg_d_s16_0 = vsubq_s16(vreg_a_s16_0, vreg_b_s16_0);
      int16x8_t vreg_d_s16_1 = vsubq_s16(vreg_a_s16_1, vreg_b_s16_1);

      // Widen to int32 then convert to fp32 (4 lanes each).
      float32x4_t vreg_d_fp32_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vreg_d_s16_0)));
      float32x4_t vreg_d_fp32_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vreg_d_s16_0)));
      float32x4_t vreg_d_fp32_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vreg_d_s16_1)));
      float32x4_t vreg_d_fp32_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vreg_d_s16_1)));

      vreg_dsum_fp32_0 = vfmaq_f32(vreg_dsum_fp32_0, vreg_d_fp32_0, vreg_d_fp32_0);
      vreg_dsum_fp32_1 = vfmaq_f32(vreg_dsum_fp32_1, vreg_d_fp32_1, vreg_d_fp32_1);
      vreg_dsum_fp32_2 = vfmaq_f32(vreg_dsum_fp32_2, vreg_d_fp32_2, vreg_d_fp32_2);
      vreg_dsum_fp32_3 = vfmaq_f32(vreg_dsum_fp32_3, vreg_d_fp32_3, vreg_d_fp32_3);
    }

    // Pairwise reduce the four accumulators, then across lanes.
    vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_1);
    vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
    vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);

    dsum = vaddvq_f32(vreg_dsum_fp32_0);  // faddp
  }

  // Scalar tail for the remaining (< 16) elements.
  for (size_t i = n_rounded; i < n; ++i) {
    float d = a[i] - b[i];
    dsum += d * d;  // [nvc++] faddp, [clang] fadda, [gcc] vecsum+fadda
  }

  return dsum;
}

/**
 * @brief NEON L2 (squared euclidean) accumulation for uint8 inputs.
 *
 * Widens uint8 -> uint16 -> fp32 on both operands, subtracts in fp32 (uint16
 * subtraction would wrap for a < b), and accumulates with FMA across four
 * independent accumulators.
 *
 * Fix: loop counters were `int` (truncation for n > INT_MAX, signed/unsigned
 * comparison in the tail loop); all counters are now size_t.
 */
template <>
inline float euclidean_distance_squared<distance_comp_l2, float, ::std::uint8_t>(
  ::std::uint8_t const* a, ::std::uint8_t const* b, size_t n) {

  // Round n down to a multiple of the 16-lane uint8 register.
  size_t const n_rounded = n - (n % 16);
  float dsum = 0.f;

  if (n_rounded > 0) {
    float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
    float32x4_t vreg_dsum_fp32_1 = vreg_dsum_fp32_0;
    float32x4_t vreg_dsum_fp32_2 = vreg_dsum_fp32_0;
    float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;

    for (size_t i = 0; i < n_rounded; i += 16) {
      // Widen a: uint8 -> uint16 -> uint32 -> fp32 (4 registers of 4 lanes).
      uint8x16_t vreg_a = vld1q_u8(&a[i]);
      uint16x8_t vreg_a_u16_0 = vmovl_u8(vget_low_u8(vreg_a));
      uint16x8_t vreg_a_u16_1 = vmovl_u8(vget_high_u8(vreg_a));
      float32x4_t vreg_a_fp32_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_a_u16_0)));
      float32x4_t vreg_a_fp32_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_a_u16_0)));
      float32x4_t vreg_a_fp32_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_a_u16_1)));
      float32x4_t vreg_a_fp32_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_a_u16_1)));

      // Widen b the same way.
      uint8x16_t vreg_b = vld1q_u8(&b[i]);
      uint16x8_t vreg_b_u16_0 = vmovl_u8(vget_low_u8(vreg_b));
      uint16x8_t vreg_b_u16_1 = vmovl_u8(vget_high_u8(vreg_b));
      float32x4_t vreg_b_fp32_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_b_u16_0)));
      float32x4_t vreg_b_fp32_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_b_u16_0)));
      float32x4_t vreg_b_fp32_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_b_u16_1)));
      float32x4_t vreg_b_fp32_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_b_u16_1)));

      // Difference in fp32 (safe for either sign), then FMA-accumulate d*d.
      float32x4_t vreg_d_fp32_0 = vsubq_f32(vreg_a_fp32_0, vreg_b_fp32_0);
      float32x4_t vreg_d_fp32_1 = vsubq_f32(vreg_a_fp32_1, vreg_b_fp32_1);
      float32x4_t vreg_d_fp32_2 = vsubq_f32(vreg_a_fp32_2, vreg_b_fp32_2);
      float32x4_t vreg_d_fp32_3 = vsubq_f32(vreg_a_fp32_3, vreg_b_fp32_3);

      vreg_dsum_fp32_0 = vfmaq_f32(vreg_dsum_fp32_0, vreg_d_fp32_0, vreg_d_fp32_0);
      vreg_dsum_fp32_1 = vfmaq_f32(vreg_dsum_fp32_1, vreg_d_fp32_1, vreg_d_fp32_1);
      vreg_dsum_fp32_2 = vfmaq_f32(vreg_dsum_fp32_2, vreg_d_fp32_2, vreg_d_fp32_2);
      vreg_dsum_fp32_3 = vfmaq_f32(vreg_dsum_fp32_3, vreg_d_fp32_3, vreg_d_fp32_3);
    }

    // Pairwise reduce the four accumulators, then across lanes.
    vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_1);
    vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
    vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);

    dsum = vaddvq_f32(vreg_dsum_fp32_0);  // faddp
  }

  // Scalar tail for the remaining (< 16) elements.
  for (size_t i = n_rounded; i < n; ++i) {
    float d = a[i] - b[i];
    dsum += d * d;  // [nvc++] faddp, [clang] fadda, [gcc] vecsum+fadda
  }

  return dsum;
}

/**
 * @brief NEON negated inner-product accumulation for float inputs.
 *
 * Computes sum(-a[i] * b[i]); the negation makes a larger inner product a
 * smaller "distance", matching the comparator used by the refine kernel.
 *
 * Fix: loop counters were `int` (truncation for n > INT_MAX, signed/unsigned
 * comparison in the tail loop); all counters are now size_t.
 */
template <>
inline float euclidean_distance_squared<distance_comp_inner, float, float>(
  float const* a, float const* b, size_t n) {

  // Round n down to a multiple of the 4-lane NEON float register.
  size_t const n_rounded = n - (n % 4);

  float32x4_t vreg_dsum = vdupq_n_f32(0.f);
  for (size_t i = 0; i < n_rounded; i += 4) {
    float32x4_t vreg_a = vld1q_f32(&a[i]);
    float32x4_t vreg_b = vld1q_f32(&b[i]);
    // Negate a so the FMA accumulates -(a * b).
    vreg_a = vnegq_f32(vreg_a);
    vreg_dsum = vfmaq_f32(vreg_dsum, vreg_a, vreg_b);
  }

  // Horizontal add across lanes, then scalar tail.
  float dsum = vaddvq_f32(vreg_dsum);
  for (size_t i = n_rounded; i < n; ++i) {
    dsum += -a[i] * b[i];
  }

  return dsum;
}

/**
 * @brief NEON negated inner-product accumulation for int8 inputs.
 *
 * Widens int8 -> int16, multiplies in int16 (exact: |a*b| <= 128*128 fits in
 * int16? No — the product of two int8 values can reach 16384, which fits in
 * int16's positive range; the minimum is -128*127 = -16256, also in range),
 * widens the products to fp32, and subtracts them from the accumulators to
 * realize the negation.
 *
 * Fixes: removed the `#if 1 / #else` dead alternative that was explicitly
 * marked "WILL BE REMOVED BEFORE MERGE"; loop counters changed from `int` to
 * size_t (truncation for n > INT_MAX, signed/unsigned tail comparison).
 */
template <>
inline float euclidean_distance_squared<distance_comp_inner, float, ::std::int8_t>(
  ::std::int8_t const* a, ::std::int8_t const* b, size_t n) {

  // Round n down to a multiple of the 16-lane int8 register.
  size_t const n_rounded = n - (n % 16);
  float dsum = 0.f;

  if (n_rounded > 0) {
    // Four independent accumulators hide arithmetic latency.
    float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
    float32x4_t vreg_dsum_fp32_1 = vreg_dsum_fp32_0;
    float32x4_t vreg_dsum_fp32_2 = vreg_dsum_fp32_0;
    float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;

    for (size_t i = 0; i < n_rounded; i += 16) {
      // Load 16 int8 and widen each half to int16.
      int8x16_t vreg_a = vld1q_s8(&a[i]);
      int16x8_t vreg_a_s16_0 = vmovl_s8(vget_low_s8(vreg_a));
      int16x8_t vreg_a_s16_1 = vmovl_s8(vget_high_s8(vreg_a));

      int8x16_t vreg_b = vld1q_s8(&b[i]);
      int16x8_t vreg_b_s16_0 = vmovl_s8(vget_low_s8(vreg_b));
      int16x8_t vreg_b_s16_1 = vmovl_s8(vget_high_s8(vreg_b));

      // int16 product of int8 operands is exact (range [-16256, 16384]... the
      // max magnitude 128*128 = 16384 fits in int16).
      vreg_a_s16_0 = vmulq_s16(vreg_a_s16_0, vreg_b_s16_0);
      vreg_a_s16_1 = vmulq_s16(vreg_a_s16_1, vreg_b_s16_1);

      // Widen products to int32 and convert to fp32.
      float32x4_t vreg_res_fp32_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vreg_a_s16_0)));
      float32x4_t vreg_res_fp32_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vreg_a_s16_0)));
      float32x4_t vreg_res_fp32_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vreg_a_s16_1)));
      float32x4_t vreg_res_fp32_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vreg_a_s16_1)));

      // Subtract to accumulate the negated inner product.
      vreg_dsum_fp32_0 = vsubq_f32(vreg_dsum_fp32_0, vreg_res_fp32_0);
      vreg_dsum_fp32_1 = vsubq_f32(vreg_dsum_fp32_1, vreg_res_fp32_1);
      vreg_dsum_fp32_2 = vsubq_f32(vreg_dsum_fp32_2, vreg_res_fp32_2);
      vreg_dsum_fp32_3 = vsubq_f32(vreg_dsum_fp32_3, vreg_res_fp32_3);
    }

    // Pairwise reduce the four accumulators, then across lanes.
    vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_1);
    vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
    vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);

    dsum = vaddvq_f32(vreg_dsum_fp32_0);  // faddp
  }

  // Scalar tail for the remaining (< 16) elements.
  for (size_t i = n_rounded; i < n; ++i) {
    dsum += -a[i] * b[i];
  }

  return dsum;
}

/**
 * @brief NEON negated inner-product accumulation for uint8 inputs.
 *
 * Widens uint8 -> uint16, multiplies in uint16 (exact: 255*255 = 65025 fits
 * in uint16), widens products to fp32, and subtracts them from the
 * accumulators to realize the negation.
 *
 * Fix: loop counters were `int` (truncation for n > INT_MAX, signed/unsigned
 * comparison in the tail loop); all counters are now size_t.
 */
template <>
inline float euclidean_distance_squared<distance_comp_inner, float, ::std::uint8_t>(
  ::std::uint8_t const* a, ::std::uint8_t const* b, size_t n) {
  // Round n down to a multiple of the 16-lane uint8 register.
  size_t const n_rounded = n - (n % 16);
  float dsum = 0.f;

  if (n_rounded > 0) {
    // Four independent accumulators hide arithmetic latency.
    float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
    float32x4_t vreg_dsum_fp32_1 = vreg_dsum_fp32_0;
    float32x4_t vreg_dsum_fp32_2 = vreg_dsum_fp32_0;
    float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;

    for (size_t i = 0; i < n_rounded; i += 16) {
      // Load 16 uint8 and widen each half to uint16.
      uint8x16_t vreg_a = vld1q_u8(&a[i]);
      uint16x8_t vreg_a_u16_0 = vmovl_u8(vget_low_u8(vreg_a));
      uint16x8_t vreg_a_u16_1 = vmovl_u8(vget_high_u8(vreg_a));

      uint8x16_t vreg_b = vld1q_u8(&b[i]);
      uint16x8_t vreg_b_u16_0 = vmovl_u8(vget_low_u8(vreg_b));
      uint16x8_t vreg_b_u16_1 = vmovl_u8(vget_high_u8(vreg_b));

      // uint16 product of uint8 operands is exact (max 255*255 = 65025).
      vreg_a_u16_0 = vmulq_u16(vreg_a_u16_0, vreg_b_u16_0);
      vreg_a_u16_1 = vmulq_u16(vreg_a_u16_1, vreg_b_u16_1);

      // Widen products to uint32 and convert to fp32.
      float32x4_t vreg_res_fp32_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_a_u16_0)));
      float32x4_t vreg_res_fp32_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_a_u16_0)));
      float32x4_t vreg_res_fp32_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_a_u16_1)));
      float32x4_t vreg_res_fp32_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_a_u16_1)));

      // Subtract to accumulate the negated inner product.
      vreg_dsum_fp32_0 = vsubq_f32(vreg_dsum_fp32_0, vreg_res_fp32_0);
      vreg_dsum_fp32_1 = vsubq_f32(vreg_dsum_fp32_1, vreg_res_fp32_1);
      vreg_dsum_fp32_2 = vsubq_f32(vreg_dsum_fp32_2, vreg_res_fp32_2);
      vreg_dsum_fp32_3 = vsubq_f32(vreg_dsum_fp32_3, vreg_res_fp32_3);
    }

    // Pairwise reduce the four accumulators, then across lanes.
    vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_1);
    vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
    vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);

    dsum = vaddvq_f32(vreg_dsum_fp32_0);  // faddp
  }

  // Scalar tail for the remaining (< 16) elements.
  for (size_t i = n_rounded; i < n; ++i) {
    dsum += -a[i] * b[i];
  }

  return dsum;
}

#endif // defined(__arm__) || defined(__aarch64__)

// -----------------------------------------------------------------------------
// Refine kernel
// -----------------------------------------------------------------------------

template <typename DC, typename IdxT, typename DataT, typename DistanceT, typename ExtentsT>
[[gnu::optimize(3), gnu::optimize("tree-vectorize")]] void refine_host_impl(
raft::host_matrix_view<const DataT, ExtentsT, row_major> dataset,
Expand Down Expand Up @@ -112,9 +438,7 @@ template <typename DC, typename IdxT, typename DataT, typename DistanceT, typena
distance = std::numeric_limits<DistanceT>::max();
} else {
const DataT* row = dataset.data_handle() + dim * id;
for (size_t k = 0; k < dim; k++) {
distance += DC::template eval<DistanceT>(query[k], row[k]);
}
distance = euclidean_distance_squared<DC, DistanceT, DataT>(query, row, dim);
}
refined_pairs[j] = std::make_tuple(distance, id);
}
Expand Down