// NOTE(review): the original SOURCE is a unified diff of
// cpp/include/raft/neighbors/detail/refine_host-inl.hpp whose angle-bracketed
// tokens (#include targets, template parameter lists, explicit-specialization
// argument lists) were stripped by the paste.  Reconstructed below as plain
// C++ with those tokens restored and the concrete defects fixed (see inline
// notes).  The diff header and hunk markers are dropped as part of the
// reconstruction.

#if defined(__arm__) || defined(__aarch64__)
#include <arm_neon.h>
#endif

#include <cstddef>
#include <cstdint>

namespace raft::neighbors::detail {

// -----------------------------------------------------------------------------
// Generic implementation
// -----------------------------------------------------------------------------

/**
 * Pairwise reduction of two dense vectors, manually unrolled so the compiler
 * can auto-vectorize it.  `DC::eval` supplies the per-element term (squared
 * difference for L2, negated product for inner product), so despite the name
 * this also backs the inner-product path.
 *
 * @param a first vector, `n` elements
 * @param b second vector, `n` elements
 * @param n element count
 * @return  sum over i of DC::eval(a[i], b[i])
 */
template <typename DC, typename DistanceT, typename DataT>
DistanceT euclidean_distance_squared_generic(DataT const* a, DataT const* b, size_t n)
{
  // vector register capacity in elements
  size_t constexpr vreg_len = (128 / 8) / sizeof(DistanceT);
  // unroll factor = vector register capacity * number of ports
  size_t constexpr unroll_factor = vreg_len * 4;

  // unroll_factor is a power of two, so ~(unroll_factor - 1) rounds n down to
  // a multiple of it.  (The original masked with 0xFFFFFFFF ^ ..., which
  // silently truncated n to 32 bits on 64-bit size_t.)
  size_t n_rounded = n & ~(unroll_factor - 1);
  DistanceT distance[unroll_factor] = {0};

  for (size_t i = 0; i < n_rounded; i += unroll_factor) {
    for (size_t j = 0; j < unroll_factor; ++j) {
      distance[j] += DC::template eval<DistanceT>(a[i + j], b[i + j]);
    }
  }

  // Tail elements: accumulate into lane (i - n_rounded), which is always
  // < unroll_factor.  The original indexed distance[i], writing out of bounds
  // whenever n >= unroll_factor and n % unroll_factor != 0.
  for (size_t i = n_rounded; i < n; ++i) {
    distance[i - n_rounded] += DC::template eval<DistanceT>(a[i], b[i]);
  }

  // Horizontal reduction of the per-lane partial sums.
  for (size_t i = 1; i < unroll_factor; ++i) {
    distance[0] += distance[i];
  }

  return distance[0];
}

// -----------------------------------------------------------------------------
// NEON implementation
// -----------------------------------------------------------------------------

// Tag types selecting the per-element operation.  Only forward-declared here:
// their definitions (static template `eval`) live elsewhere in this file.
struct distance_comp_l2;
struct distance_comp_inner;

// Fallback for any (DC, DistanceT, DataT) combination without a hand-written
// SIMD specialization below.
template <typename DC, typename DistanceT, typename DataT>
DistanceT euclidean_distance_squared(DataT const* a, DataT const* b, size_t n)
{
  return euclidean_distance_squared_generic<DC, DistanceT, DataT>(a, b, n);
}

#if defined(__aarch64__)
// NOTE(review): guarded on AArch64 only.  vfmaq_f32 / vaddvq_f32 used below
// are A64 intrinsics and do not exist on 32-bit ARM, which the original
// (__arm__ || __aarch64__) guard would have tried (and failed) to compile.

/// L2 (squared) distance, fp32 input: 4-wide FMA accumulation + scalar tail.
template <>
inline float euclidean_distance_squared<distance_comp_l2, float, float>(
  float const* a, float const* b, size_t n)
{
  size_t n_rounded = n - (n % 4);

  float32x4_t vreg_dsum = vdupq_n_f32(0.f);
  for (size_t i = 0; i < n_rounded; i += 4) {
    float32x4_t vreg_a = vld1q_f32(&a[i]);
    float32x4_t vreg_b = vld1q_f32(&b[i]);
    float32x4_t vreg_d = vsubq_f32(vreg_a, vreg_b);
    vreg_dsum          = vfmaq_f32(vreg_dsum, vreg_d, vreg_d);
  }

  float dsum = vaddvq_f32(vreg_dsum);  // horizontal add (faddp)
  for (size_t i = n_rounded; i < n; ++i) {
    float d = a[i] - b[i];
    dsum += d * d;
  }

  return dsum;
}

/// L2 (squared) distance, int8 input: difference taken in s16 (cannot
/// overflow: |a - b| <= 255), then widened to fp32 for the FMA accumulators.
template <>
inline float euclidean_distance_squared<distance_comp_l2, float, ::std::int8_t>(
  ::std::int8_t const* a, ::std::int8_t const* b, size_t n)
{
  size_t n_rounded = n - (n % 16);
  float dsum       = 0.f;

  if (n_rounded > 0) {
    // Four independent accumulators hide FMA latency.
    float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
    float32x4_t vreg_dsum_fp32_1 = vreg_dsum_fp32_0;
    float32x4_t vreg_dsum_fp32_2 = vreg_dsum_fp32_0;
    float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;

    for (size_t i = 0; i < n_rounded; i += 16) {
      int8x16_t vreg_a      = vld1q_s8(&a[i]);
      int16x8_t vreg_a_s16_0 = vmovl_s8(vget_low_s8(vreg_a));
      int16x8_t vreg_a_s16_1 = vmovl_s8(vget_high_s8(vreg_a));

      int8x16_t vreg_b      = vld1q_s8(&b[i]);
      int16x8_t vreg_b_s16_0 = vmovl_s8(vget_low_s8(vreg_b));
      int16x8_t vreg_b_s16_1 = vmovl_s8(vget_high_s8(vreg_b));

      int16x8_t vreg_d_s16_0 = vsubq_s16(vreg_a_s16_0, vreg_b_s16_0);
      int16x8_t vreg_d_s16_1 = vsubq_s16(vreg_a_s16_1, vreg_b_s16_1);

      float32x4_t vreg_d_fp32_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vreg_d_s16_0)));
      float32x4_t vreg_d_fp32_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vreg_d_s16_0)));
      float32x4_t vreg_d_fp32_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vreg_d_s16_1)));
      float32x4_t vreg_d_fp32_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vreg_d_s16_1)));

      vreg_dsum_fp32_0 = vfmaq_f32(vreg_dsum_fp32_0, vreg_d_fp32_0, vreg_d_fp32_0);
      vreg_dsum_fp32_1 = vfmaq_f32(vreg_dsum_fp32_1, vreg_d_fp32_1, vreg_d_fp32_1);
      vreg_dsum_fp32_2 = vfmaq_f32(vreg_dsum_fp32_2, vreg_d_fp32_2, vreg_d_fp32_2);
      vreg_dsum_fp32_3 = vfmaq_f32(vreg_dsum_fp32_3, vreg_d_fp32_3, vreg_d_fp32_3);
    }

    vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_1);
    vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
    vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);

    dsum = vaddvq_f32(vreg_dsum_fp32_0);  // faddp
  }

  for (size_t i = n_rounded; i < n; ++i) {
    float d = a[i] - b[i];
    dsum += d * d;  // [nvc++] faddp, [clang] fadda, [gcc] vecsum+fadda
  }

  return dsum;
}

/// L2 (squared) distance, uint8 input: both operands widened to fp32 BEFORE
/// subtracting (unsigned integer subtraction would wrap).
template <>
inline float euclidean_distance_squared<distance_comp_l2, float, ::std::uint8_t>(
  ::std::uint8_t const* a, ::std::uint8_t const* b, size_t n)
{
  size_t n_rounded = n - (n % 16);
  float dsum       = 0.f;

  if (n_rounded > 0) {
    float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
    float32x4_t vreg_dsum_fp32_1 = vreg_dsum_fp32_0;
    float32x4_t vreg_dsum_fp32_2 = vreg_dsum_fp32_0;
    float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;

    for (size_t i = 0; i < n_rounded; i += 16) {
      uint8x16_t vreg_a       = vld1q_u8(&a[i]);
      uint16x8_t vreg_a_u16_0 = vmovl_u8(vget_low_u8(vreg_a));
      uint16x8_t vreg_a_u16_1 = vmovl_u8(vget_high_u8(vreg_a));
      float32x4_t vreg_a_fp32_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_a_u16_0)));
      float32x4_t vreg_a_fp32_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_a_u16_0)));
      float32x4_t vreg_a_fp32_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_a_u16_1)));
      float32x4_t vreg_a_fp32_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_a_u16_1)));

      uint8x16_t vreg_b       = vld1q_u8(&b[i]);
      uint16x8_t vreg_b_u16_0 = vmovl_u8(vget_low_u8(vreg_b));
      uint16x8_t vreg_b_u16_1 = vmovl_u8(vget_high_u8(vreg_b));
      float32x4_t vreg_b_fp32_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_b_u16_0)));
      float32x4_t vreg_b_fp32_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_b_u16_0)));
      float32x4_t vreg_b_fp32_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_b_u16_1)));
      float32x4_t vreg_b_fp32_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_b_u16_1)));

      float32x4_t vreg_d_fp32_0 = vsubq_f32(vreg_a_fp32_0, vreg_b_fp32_0);
      float32x4_t vreg_d_fp32_1 = vsubq_f32(vreg_a_fp32_1, vreg_b_fp32_1);
      float32x4_t vreg_d_fp32_2 = vsubq_f32(vreg_a_fp32_2, vreg_b_fp32_2);
      float32x4_t vreg_d_fp32_3 = vsubq_f32(vreg_a_fp32_3, vreg_b_fp32_3);

      vreg_dsum_fp32_0 = vfmaq_f32(vreg_dsum_fp32_0, vreg_d_fp32_0, vreg_d_fp32_0);
      vreg_dsum_fp32_1 = vfmaq_f32(vreg_dsum_fp32_1, vreg_d_fp32_1, vreg_d_fp32_1);
      vreg_dsum_fp32_2 = vfmaq_f32(vreg_dsum_fp32_2, vreg_d_fp32_2, vreg_d_fp32_2);
      vreg_dsum_fp32_3 = vfmaq_f32(vreg_dsum_fp32_3, vreg_d_fp32_3, vreg_d_fp32_3);
    }

    vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_1);
    vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
    vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);

    dsum = vaddvq_f32(vreg_dsum_fp32_0);  // faddp
  }

  for (size_t i = n_rounded; i < n; ++i) {
    float d = a[i] - b[i];
    dsum += d * d;  // [nvc++] faddp, [clang] fadda, [gcc] vecsum+fadda
  }

  return dsum;
}

/// Negated inner product, fp32 input (sum of -a[i]*b[i], matching
/// distance_comp_inner's per-element term).
template <>
inline float euclidean_distance_squared<distance_comp_inner, float, float>(
  float const* a, float const* b, size_t n)
{
  size_t n_rounded = n - (n % 4);

  float32x4_t vreg_dsum = vdupq_n_f32(0.f);
  for (size_t i = 0; i < n_rounded; i += 4) {
    float32x4_t vreg_a = vld1q_f32(&a[i]);
    float32x4_t vreg_b = vld1q_f32(&b[i]);
    vreg_a             = vnegq_f32(vreg_a);
    vreg_dsum          = vfmaq_f32(vreg_dsum, vreg_a, vreg_b);
  }

  float dsum = vaddvq_f32(vreg_dsum);
  for (size_t i = n_rounded; i < n; ++i) {
    dsum += -a[i] * b[i];
  }

  return dsum;
}

/// Negated inner product, int8 input: products computed in s16 (cannot
/// overflow: |a*b| <= 16384), widened to fp32 and SUBTRACTED from the
/// accumulators to realize the negation.  The dead `#if 1 / #else` experiment
/// marked "WILL BE REMOVED BEFORE MERGE" in the original has been removed,
/// keeping the live branch.
template <>
inline float euclidean_distance_squared<distance_comp_inner, float, ::std::int8_t>(
  ::std::int8_t const* a, ::std::int8_t const* b, size_t n)
{
  size_t n_rounded = n - (n % 16);
  float dsum       = 0.f;

  if (n_rounded > 0) {
    float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
    float32x4_t vreg_dsum_fp32_1 = vreg_dsum_fp32_0;
    float32x4_t vreg_dsum_fp32_2 = vreg_dsum_fp32_0;
    float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;

    for (size_t i = 0; i < n_rounded; i += 16) {
      int8x16_t vreg_a      = vld1q_s8(&a[i]);
      int16x8_t vreg_a_s16_0 = vmovl_s8(vget_low_s8(vreg_a));
      int16x8_t vreg_a_s16_1 = vmovl_s8(vget_high_s8(vreg_a));

      int8x16_t vreg_b      = vld1q_s8(&b[i]);
      int16x8_t vreg_b_s16_0 = vmovl_s8(vget_low_s8(vreg_b));
      int16x8_t vreg_b_s16_1 = vmovl_s8(vget_high_s8(vreg_b));

      vreg_a_s16_0 = vmulq_s16(vreg_a_s16_0, vreg_b_s16_0);
      vreg_a_s16_1 = vmulq_s16(vreg_a_s16_1, vreg_b_s16_1);

      float32x4_t vreg_res_fp32_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vreg_a_s16_0)));
      float32x4_t vreg_res_fp32_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vreg_a_s16_0)));
      float32x4_t vreg_res_fp32_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vreg_a_s16_1)));
      float32x4_t vreg_res_fp32_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vreg_a_s16_1)));

      vreg_dsum_fp32_0 = vsubq_f32(vreg_dsum_fp32_0, vreg_res_fp32_0);
      vreg_dsum_fp32_1 = vsubq_f32(vreg_dsum_fp32_1, vreg_res_fp32_1);
      vreg_dsum_fp32_2 = vsubq_f32(vreg_dsum_fp32_2, vreg_res_fp32_2);
      vreg_dsum_fp32_3 = vsubq_f32(vreg_dsum_fp32_3, vreg_res_fp32_3);
    }

    vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_1);
    vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
    vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);

    dsum = vaddvq_f32(vreg_dsum_fp32_0);  // faddp
  }

  for (size_t i = n_rounded; i < n; ++i) {
    dsum += -a[i] * b[i];
  }

  return dsum;
}

/// Negated inner product, uint8 input: products computed in u16 (cannot
/// overflow: a*b <= 65025), widened to fp32 and subtracted from the
/// accumulators.
template <>
inline float euclidean_distance_squared<distance_comp_inner, float, ::std::uint8_t>(
  ::std::uint8_t const* a, ::std::uint8_t const* b, size_t n)
{
  size_t n_rounded = n - (n % 16);
  float dsum       = 0.f;

  if (n_rounded > 0) {
    float32x4_t vreg_dsum_fp32_0 = vdupq_n_f32(0.f);
    float32x4_t vreg_dsum_fp32_1 = vreg_dsum_fp32_0;
    float32x4_t vreg_dsum_fp32_2 = vreg_dsum_fp32_0;
    float32x4_t vreg_dsum_fp32_3 = vreg_dsum_fp32_0;

    for (size_t i = 0; i < n_rounded; i += 16) {
      uint8x16_t vreg_a       = vld1q_u8(&a[i]);
      uint16x8_t vreg_a_u16_0 = vmovl_u8(vget_low_u8(vreg_a));
      uint16x8_t vreg_a_u16_1 = vmovl_u8(vget_high_u8(vreg_a));

      uint8x16_t vreg_b       = vld1q_u8(&b[i]);
      uint16x8_t vreg_b_u16_0 = vmovl_u8(vget_low_u8(vreg_b));
      uint16x8_t vreg_b_u16_1 = vmovl_u8(vget_high_u8(vreg_b));

      vreg_a_u16_0 = vmulq_u16(vreg_a_u16_0, vreg_b_u16_0);
      vreg_a_u16_1 = vmulq_u16(vreg_a_u16_1, vreg_b_u16_1);

      float32x4_t vreg_res_fp32_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_a_u16_0)));
      float32x4_t vreg_res_fp32_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_a_u16_0)));
      float32x4_t vreg_res_fp32_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vreg_a_u16_1)));
      float32x4_t vreg_res_fp32_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vreg_a_u16_1)));

      vreg_dsum_fp32_0 = vsubq_f32(vreg_dsum_fp32_0, vreg_res_fp32_0);
      vreg_dsum_fp32_1 = vsubq_f32(vreg_dsum_fp32_1, vreg_res_fp32_1);
      vreg_dsum_fp32_2 = vsubq_f32(vreg_dsum_fp32_2, vreg_res_fp32_2);
      vreg_dsum_fp32_3 = vsubq_f32(vreg_dsum_fp32_3, vreg_res_fp32_3);
    }

    vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_1);
    vreg_dsum_fp32_2 = vaddq_f32(vreg_dsum_fp32_2, vreg_dsum_fp32_3);
    vreg_dsum_fp32_0 = vaddq_f32(vreg_dsum_fp32_0, vreg_dsum_fp32_2);

    dsum = vaddvq_f32(vreg_dsum_fp32_0);  // faddp
  }

  for (size_t i = n_rounded; i < n; ++i) {
    dsum += -a[i] * b[i];
  }

  return dsum;
}

#endif  // defined(__aarch64__)

// NOTE(review): the final hunk of the diff edits refine_host_impl (not fully
// visible here — the diff elides most of the function), replacing its
// per-element loop
//     for (size_t k = 0; k < dim; k++) {
//       distance += DC::template eval<DistanceT>(query[k], row[k]);
//     }
// with a call to the dispatcher above:
//     distance = euclidean_distance_squared<DC, DistanceT>(query, row, dim);
// That fragment is left to the full file and not reconstructed here.

}  // namespace raft::neighbors::detail