Skip to content

Commit fb1e864

Browse files
committed
Make: +bf16 flag for sparse SVE2
1 parent 8af8dc4 commit fb1e864

File tree

2 files changed

+39
-25
lines changed

2 files changed

+39
-25
lines changed

include/simsimd/simsimd.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,8 +1044,10 @@ SIMSIMD_PUBLIC void simsimd_find_metric_punned( //
10441044
if (viable & simsimd_cap_sve2_k)
10451045
switch (kind) {
10461046
case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u16_sve2, *c = simsimd_cap_sve2_k; return;
1047-
case simsimd_spdot_counts_u16_k: *m = (m_t)&simsimd_spdot_counts_u16_sve2, *c = simsimd_cap_sve2_k; return;
1048-
case simsimd_spdot_weights_u16_k:
1047+
case simsimd_metric_spdot_counts_k:
1048+
*m = (m_t)&simsimd_spdot_counts_u16_sve2, *c = simsimd_cap_sve2_k;
1049+
return;
1050+
case simsimd_metric_spdot_weights_k:
10491051
*m = (m_t)&simsimd_spdot_weights_u16_sve2, *c = simsimd_cap_sve2_k;
10501052
return;
10511053
default: break;
@@ -1062,10 +1064,10 @@ SIMSIMD_PUBLIC void simsimd_find_metric_punned( //
10621064
if (viable & simsimd_cap_turin_k)
10631065
switch (kind) {
10641066
case simsimd_metric_intersect_k: *m = (m_t)&simsimd_intersect_u16_turin, *c = simsimd_cap_turin_k; return;
1065-
case simsimd_spdot_counts_u16_k:
1067+
case simsimd_metric_spdot_counts_k:
10661068
*m = (m_t)&simsimd_spdot_counts_u16_turin, *c = simsimd_cap_turin_k;
10671069
return;
1068-
case simsimd_spdot_weights_u16_k:
1070+
case simsimd_metric_spdot_weights_k:
10691071
*m = (m_t)&simsimd_spdot_weights_u16_turin, *c = simsimd_cap_turin_k;
10701072
return;
10711073
default: break;

include/simsimd/sparse.h

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,17 +1247,17 @@ SIMSIMD_PUBLIC void simsimd_intersect_u32_sve2(simsimd_u32_t const* a, simsimd_u
12471247
*results = c;
12481248
}
12491249

1250-
SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_sve2( //
1251-
simsimd_u16_t const* a, simsimd_u16_t const* b, //
1252-
simsimd_bf16_t const* a_weights, simsimd_bf16_t const* b_weights, //
1253-
simsimd_size_t a_length, simsimd_size_t b_length, //
1250+
SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_sve2( //
1251+
simsimd_u16_t const* a, simsimd_u16_t const* b, //
1252+
simsimd_i16_t const* a_weights, simsimd_i16_t const* b_weights, //
1253+
simsimd_size_t a_length, simsimd_size_t b_length, //
12541254
simsimd_distance_t* results) {
12551255

12561256
// A single SVE lane is 128 bits wide, so one lane fits 8 values.
12571257
simsimd_size_t const register_size = svcnth();
12581258
simsimd_size_t const lanes_count = register_size / 8;
12591259
simsimd_size_t a_idx = 0, b_idx = 0;
1260-
svfloat32_t product_vec = svdupq_n_f32(0.f, 0.f, 0.f, 0.f);
1260+
svint64_t product_vec = svdupq_n_s64(0, 0);
12611261
simsimd_size_t intersection_size = 0;
12621262

12631263
while (a_idx < a_length && b_idx < b_length) {
@@ -1303,12 +1303,12 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_sve2( //
13031303
simsimd_u64_t b_step = svcntp_b16(b_progress, b_mask);
13041304

13051305
// Compare `a_vec` with each lane of `b_vec`
1306-
svbfloat16_t a_weights_vec = svld1_bf16(a_progress, a_weights + a_idx);
1307-
svbfloat16_t b_weights_vec = svld1_bf16(b_progress, b_weights + b_idx);
1306+
svint16_t a_weights_vec = svld1_s16(a_progress, a_weights + a_idx);
1307+
svint16_t b_weights_vec = svld1_s16(b_progress, b_weights + b_idx);
13081308
for (simsimd_size_t i = 0; i < lanes_count; i++) {
13091309
svbool_t equal_mask = svmatch_u16(a_progress, a_vec, b_vec);
1310-
svbfloat16_t b_equal_weights_vec = svsel_bf16(equal_mask, b_weights_vec, svdup_n_bf16(0.f));
1311-
product_vec = svbfdot_f32(product_vec, a_weights_vec, b_equal_weights_vec);
1310+
svint16_t b_equal_weights_vec = svsel_s16(equal_mask, b_weights_vec, svdup_n_s16(0.f));
1311+
product_vec = svdot_s64(product_vec, a_weights_vec, b_equal_weights_vec);
13121312
b_vec = svext_u16(b_vec, b_vec, 8);
13131313
intersection_size += svcntp_b16(svptrue_b16(), equal_mask);
13141314
}
@@ -1318,20 +1318,29 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_sve2( //
13181318
b_idx += b_step;
13191319
}
13201320
results[0] = (simsimd_distance_t)intersection_size;
1321-
results[1] = svaddv_f32(svptrue_b32(), product_vec);
1321+
results[1] = svaddv_s64(svptrue_b64(), product_vec);
13221322
}
13231323

1324-
SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_sve2( //
1325-
simsimd_u16_t const* a, simsimd_u16_t const* b, //
1326-
simsimd_i16_t const* a_weights, simsimd_i16_t const* b_weights, //
1327-
simsimd_size_t a_length, simsimd_size_t b_length, //
1324+
#pragma clang attribute pop
1325+
#pragma GCC pop_options
1326+
#endif // SIMSIMD_TARGET_SVE2
1327+
1328+
#if SIMSIMD_TARGET_SVE2 && SIMSIMD_TARGET_SVE_BF16
1329+
#pragma GCC push_options
1330+
#pragma GCC target("arch=armv8.6-a+sve+sve2+bf16")
1331+
#pragma clang attribute push(__attribute__((target("arch=armv8.6-a+sve+sve2+bf16"))), apply_to = function)
1332+
1333+
SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_sve2( //
1334+
simsimd_u16_t const* a, simsimd_u16_t const* b, //
1335+
simsimd_bf16_t const* a_weights, simsimd_bf16_t const* b_weights, //
1336+
simsimd_size_t a_length, simsimd_size_t b_length, //
13281337
simsimd_distance_t* results) {
13291338

13301339
// A single SVE lane is 128 bits wide, so one lane fits 8 values.
13311340
simsimd_size_t const register_size = svcnth();
13321341
simsimd_size_t const lanes_count = register_size / 8;
13331342
simsimd_size_t a_idx = 0, b_idx = 0;
1334-
svint64_t product_vec = svdupq_n_s64(0, 0);
1343+
svfloat32_t product_vec = svdupq_n_f32(0.f, 0.f, 0.f, 0.f);
13351344
simsimd_size_t intersection_size = 0;
13361345

13371346
while (a_idx < a_length && b_idx < b_length) {
@@ -1377,12 +1386,15 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_sve2( //
13771386
simsimd_u64_t b_step = svcntp_b16(b_progress, b_mask);
13781387

13791388
// Compare `a_vec` with each lane of `b_vec`
1380-
svbfloat16_t a_weights_vec = svld1_s16(a_progress, a_weights + a_idx);
1381-
svbfloat16_t b_weights_vec = svld1_s16(b_progress, b_weights + b_idx);
1389+
svbfloat16_t a_weights_vec = svld1_bf16(a_progress, a_weights + a_idx);
1390+
svbfloat16_t b_weights_vec = svld1_bf16(b_progress, b_weights + b_idx);
13821391
for (simsimd_size_t i = 0; i < lanes_count; i++) {
13831392
svbool_t equal_mask = svmatch_u16(a_progress, a_vec, b_vec);
1384-
svbfloat16_t b_equal_weights_vec = svsel_s16(equal_mask, b_weights_vec, svdup_n_bf16(0.f));
1385-
product_vec = svdot_s64(product_vec, a_weights_vec, b_equal_weights_vec);
1393+
//! The `svsel_bf16` intrinsic is broken in many compilers, not returning the correct type.
1394+
//! So we reinterprete floats as integers and apply `svsel_s16`.
1395+
svint16_t b_equal_weights_vec =
1396+
svsel_s16(equal_mask, svreinterpret_s16_bs16(b_weights_vec), svdup_n_s16(0));
1397+
product_vec = svbfdot_f32(product_vec, a_weights_vec, svreinterpret_bf16_s16(b_equal_weights_vec));
13861398
b_vec = svext_u16(b_vec, b_vec, 8);
13871399
intersection_size += svcntp_b16(svptrue_b16(), equal_mask);
13881400
}
@@ -1392,12 +1404,12 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_sve2( //
13921404
b_idx += b_step;
13931405
}
13941406
results[0] = (simsimd_distance_t)intersection_size;
1395-
results[1] = svaddv_s64(svptrue_b64(), product_vec);
1407+
results[1] = svaddv_f32(svptrue_b32(), product_vec);
13961408
}
13971409

13981410
#pragma clang attribute pop
13991411
#pragma GCC pop_options
1400-
#endif // SIMSIMD_TARGET_SVE2
1412+
#endif // SIMSIMD_TARGET_SVE2 && SIMSIMD_TARGET_SVE_BF16
14011413
#endif // SIMSIMD_TARGET_ARM
14021414

14031415
#ifdef __cplusplus

0 commit comments

Comments
 (0)