Skip to content

Commit 61f5b5e

Browse files
committed
Add: Reusable reductions
1 parent b55f013 commit 61f5b5e

File tree

10 files changed

+1397
-224
lines changed

10 files changed

+1397
-224
lines changed

include/simsimd/binary.h

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@
136136

137137
#include "types.h"
138138

139+
#include "reduce.h"
140+
139141
#if defined(__cplusplus)
140142
extern "C" {
141143
#endif
@@ -377,21 +379,6 @@ SIMSIMD_INTERNAL void simsimd_jaccard_b128_finalize_serial(
377379
#pragma GCC target("arch=armv8-a+simd")
378380
#pragma clang attribute push(__attribute__((target("arch=armv8-a+simd"))), apply_to = function)
379381

380-
SIMSIMD_INTERNAL simsimd_u32_t _simsimd_reduce_u8x16_neon(uint8x16_t vec) {
381-
// Split the vector into two halves and widen to `uint16x8_t`
382-
uint16x8_t low_u16x8 = vmovl_u8(vget_low_u8(vec)); // widen lower 8 elements
383-
uint16x8_t high_u16x8 = vmovl_u8(vget_high_u8(vec)); // widen upper 8 elements
384-
385-
// Sum the widened halves
386-
uint16x8_t sum_u16x8 = vaddq_u16(low_u16x8, high_u16x8);
387-
388-
// Now reduce the `uint16x8_t` to a single `simsimd_u32_t`
389-
uint32x4_t sum_u32x4 = vpaddlq_u16(sum_u16x8); // pairwise add into 32-bit integers
390-
uint64x2_t sum_u64x2 = vpaddlq_u32(sum_u32x4); // pairwise add into 64-bit integers
391-
simsimd_u32_t final_sum = vaddvq_u64(sum_u64x2); // final horizontal add to 32-bit result
392-
return final_sum;
393-
}
394-
395382
SIMSIMD_PUBLIC void simsimd_hamming_b8_neon(simsimd_b8_t const *a, simsimd_b8_t const *b, simsimd_size_t n_words,
396383
simsimd_u32_t *result) {
397384
simsimd_u32_t differences = 0;
@@ -408,7 +395,7 @@ SIMSIMD_PUBLIC void simsimd_hamming_b8_neon(simsimd_b8_t const *a, simsimd_b8_t
408395
uint8x16_t xor_popcount_u8x16 = vcntq_u8(veorq_u8(a_u8x16, b_u8x16));
409396
popcount_u8x16 = vaddq_u8(popcount_u8x16, xor_popcount_u8x16);
410397
}
411-
differences += _simsimd_reduce_u8x16_neon(popcount_u8x16);
398+
differences += _simsimd_reduce_add_u8x16_neon(popcount_u8x16);
412399
}
413400
// Handle the tail
414401
for (; i != n_words; ++i) differences += simsimd_popcount_b8(a[i] ^ b[i]);
@@ -432,8 +419,8 @@ SIMSIMD_PUBLIC void simsimd_jaccard_b8_neon(simsimd_b8_t const *a, simsimd_b8_t
432419
intersection_popcount_u8x16 = vaddq_u8(intersection_popcount_u8x16, vcntq_u8(vandq_u8(a_u8x16, b_u8x16)));
433420
union_popcount_u8x16 = vaddq_u8(union_popcount_u8x16, vcntq_u8(vorrq_u8(a_u8x16, b_u8x16)));
434421
}
435-
intersection_count += _simsimd_reduce_u8x16_neon(intersection_popcount_u8x16);
436-
union_count += _simsimd_reduce_u8x16_neon(union_popcount_u8x16);
422+
intersection_count += _simsimd_reduce_add_u8x16_neon(intersection_popcount_u8x16);
423+
union_count += _simsimd_reduce_add_u8x16_neon(union_popcount_u8x16);
437424
}
438425
// Handle the tail
439426
for (; i != n_words; ++i)

include/simsimd/curved.h

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -617,22 +617,6 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_neon(simsimd_f16_t const *a, simsimd
617617
*result = _simsimd_sqrt_f32_neon(sum);
618618
}
619619

620-
SIMSIMD_INTERNAL simsimd_f32_t _simsimd_reduce_f16x8_neon(float16x8_t vec) {
621-
// Split the 8-element vector into two 4-element vectors
622-
float16x4_t low = vget_low_f16(vec); // Lower 4 elements
623-
float16x4_t high = vget_high_f16(vec); // Upper 4 elements
624-
625-
// Add the lower and upper parts
626-
float16x4_t sum = vadd_f16(low, high);
627-
628-
// Perform pairwise addition to reduce 4 elements to 2, then to 1
629-
sum = vpadd_f16(sum, sum); // First reduction: 4 -> 2
630-
sum = vpadd_f16(sum, sum); // Second reduction: 2 -> 1
631-
632-
// Convert the remaining half-precision value to single-precision and return
633-
return vgetq_lane_f32(vcvt_f32_f16(sum), 0);
634-
}
635-
636620
SIMSIMD_INTERNAL float16x8x2_t _simsimd_partial_load_f16x8x2_neon(simsimd_f16c_t const *x, simsimd_size_t n) {
637621
union {
638622
float16x8x2_t vecs;
@@ -689,8 +673,8 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f16c_neon(simsimd_f16c_t const *a, simsimd_
689673
}
690674

691675
simsimd_f32c_t cb_j;
692-
cb_j.real = _simsimd_reduce_f16x8_neon(cb_j_real_f16x8);
693-
cb_j.imag = _simsimd_reduce_f16x8_neon(cb_j_imag_f16x8);
676+
cb_j.real = _simsimd_reduce_add_f16x8_neon(cb_j_real_f16x8);
677+
cb_j.imag = _simsimd_reduce_add_f16x8_neon(cb_j_imag_f16x8);
694678
sum_real += a_i.real * cb_j.real - a_i.imag * cb_j.imag;
695679
sum_imag += a_i.real * cb_j.imag + a_i.imag * cb_j.real;
696680
}
@@ -903,15 +887,15 @@ SIMSIMD_PUBLIC void simsimd_bilinear_f16_haswell(simsimd_f16_t const *a, simsimd
903887
}
904888

905889
// Handle the tail of every row
906-
simsimd_f32_t sum = _simsimd_reduce_f32x8_haswell(sum_f32x8);
890+
simsimd_f32_t sum = _simsimd_reduce_add_f32x8_haswell(sum_f32x8);
907891
simsimd_size_t const tail_length = n % 8;
908892
simsimd_size_t const tail_start = n - tail_length;
909893
if (tail_length) {
910894
for (simsimd_size_t i = 0; i != n; ++i) {
911895
simsimd_f32_t a_i = _mm256_cvtss_f32(_mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(a + i))));
912896
__m256 b_f32x8 = _simsimd_partial_load_f16x8_haswell(b + tail_start, tail_length);
913897
__m256 c_f32x8 = _simsimd_partial_load_f16x8_haswell(c + i * n + tail_start, tail_length);
914-
simsimd_f32_t cb_j = _simsimd_reduce_f32x8_haswell(_mm256_mul_ps(b_f32x8, c_f32x8));
898+
simsimd_f32_t cb_j = _simsimd_reduce_add_f32x8_haswell(_mm256_mul_ps(b_f32x8, c_f32x8));
915899
sum += a_i * cb_j;
916900
}
917901
}
@@ -938,7 +922,7 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_haswell(simsimd_f16_t const *a, sims
938922
}
939923

940924
// Handle the tail of every row
941-
simsimd_f32_t sum = _simsimd_reduce_f32x8_haswell(sum_f32x8);
925+
simsimd_f32_t sum = _simsimd_reduce_add_f32x8_haswell(sum_f32x8);
942926
simsimd_size_t const tail_length = n % 8;
943927
simsimd_size_t const tail_start = n - tail_length;
944928
if (tail_length) {
@@ -950,7 +934,7 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_haswell(simsimd_f16_t const *a, sims
950934
_simsimd_partial_load_f16x8_haswell(a + tail_start, tail_length),
951935
_simsimd_partial_load_f16x8_haswell(b + tail_start, tail_length));
952936
__m256 c_f32x8 = _simsimd_partial_load_f16x8_haswell(c + i * n + tail_start, tail_length);
953-
simsimd_f32_t cdiff_j = _simsimd_reduce_f32x8_haswell(_mm256_mul_ps(diff_j_f32x8, c_f32x8));
937+
simsimd_f32_t cdiff_j = _simsimd_reduce_add_f32x8_haswell(_mm256_mul_ps(diff_j_f32x8, c_f32x8));
954938
sum += diff_i * cdiff_j;
955939
}
956940
}
@@ -976,7 +960,7 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_haswell(simsimd_bf16_t const *a, simsi
976960
}
977961

978962
// Handle the tail of every row
979-
simsimd_f32_t sum = _simsimd_reduce_f32x8_haswell(sum_f32x8);
963+
simsimd_f32_t sum = _simsimd_reduce_add_f32x8_haswell(sum_f32x8);
980964
simsimd_size_t const tail_length = n % 8;
981965
simsimd_size_t const tail_start = n - tail_length;
982966
if (tail_length) {
@@ -987,7 +971,7 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_haswell(simsimd_bf16_t const *a, simsi
987971
_simsimd_partial_load_bf16x8_haswell(b + tail_start, tail_length));
988972
__m256 c_f32x8 = _simsimd_bf16x8_to_f32x8_haswell( //
989973
_simsimd_partial_load_bf16x8_haswell(c + i * n + tail_start, tail_length));
990-
simsimd_f32_t cb_j = _simsimd_reduce_f32x8_haswell(_mm256_mul_ps(b_f32x8, c_f32x8));
974+
simsimd_f32_t cb_j = _simsimd_reduce_add_f32x8_haswell(_mm256_mul_ps(b_f32x8, c_f32x8));
991975
sum += a_i * cb_j;
992976
}
993977
}
@@ -1017,7 +1001,7 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_haswell(simsimd_bf16_t const *a, si
10171001
}
10181002

10191003
// Handle the tail of every row
1020-
simsimd_f32_t sum = _simsimd_reduce_f32x8_haswell(sum_f32x8);
1004+
simsimd_f32_t sum = _simsimd_reduce_add_f32x8_haswell(sum_f32x8);
10211005
simsimd_size_t const tail_length = n % 8;
10221006
simsimd_size_t const tail_start = n - tail_length;
10231007
if (tail_length) {
@@ -1031,7 +1015,7 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_haswell(simsimd_bf16_t const *a, si
10311015
_simsimd_bf16x8_to_f32x8_haswell(_simsimd_partial_load_bf16x8_haswell(b + tail_start, tail_length)));
10321016
__m256 c_f32x8 = _simsimd_bf16x8_to_f32x8_haswell(
10331017
_simsimd_partial_load_bf16x8_haswell(c + i * n + tail_start, tail_length));
1034-
simsimd_f32_t cdiff_j = _simsimd_reduce_f32x8_haswell(_mm256_mul_ps(diff_j_f32x8, c_f32x8));
1018+
simsimd_f32_t cdiff_j = _simsimd_reduce_add_f32x8_haswell(_mm256_mul_ps(diff_j_f32x8, c_f32x8));
10351019
sum += diff_i * cdiff_j;
10361020
}
10371021
}
@@ -1541,8 +1525,8 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16c_genoa(simsimd_bf16c_t const *a, simsi
15411525
j += 16;
15421526
if (j < n) goto simsimd_bilinear_bf16c_skylake_cycle;
15431527
// Horizontal sums are the expensive part of the computation:
1544-
simsimd_f64_t const cb_j_real = _simsimd_reduce_f32x16_skylake(cb_j_real_f32x16);
1545-
simsimd_f64_t const cb_j_imag = _simsimd_reduce_f32x16_skylake(cb_j_imag_f32x16);
1528+
simsimd_f64_t const cb_j_real = _simsimd_reduce_add_f32x16_skylake(cb_j_real_f32x16);
1529+
simsimd_f64_t const cb_j_imag = _simsimd_reduce_add_f32x16_skylake(cb_j_imag_f32x16);
15461530
sum_real += a_i_real * cb_j_real - a_i_imag * cb_j_imag;
15471531
sum_imag += a_i_real * cb_j_imag + a_i_imag * cb_j_real;
15481532
}

0 commit comments

Comments (0)