@@ -1247,17 +1247,17 @@ SIMSIMD_PUBLIC void simsimd_intersect_u32_sve2(simsimd_u32_t const* a, simsimd_u
1247
1247
*results = c;
1248
1248
}
1249
1249
1250
- SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_sve2 ( //
1251
- simsimd_u16_t const * a, simsimd_u16_t const * b, //
1252
- simsimd_bf16_t const * a_weights, simsimd_bf16_t const * b_weights, //
1253
- simsimd_size_t a_length, simsimd_size_t b_length, //
1250
+ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_sve2 ( //
1251
+ simsimd_u16_t const * a, simsimd_u16_t const * b, //
1252
+ simsimd_i16_t const * a_weights, simsimd_i16_t const * b_weights, //
1253
+ simsimd_size_t a_length, simsimd_size_t b_length, //
1254
1254
simsimd_distance_t * results) {
1255
1255
1256
1256
// A single SVE lane is 128 bits wide, so one lane fits 8 values.
1257
1257
simsimd_size_t const register_size = svcnth ();
1258
1258
simsimd_size_t const lanes_count = register_size / 8 ;
1259
1259
simsimd_size_t a_idx = 0 , b_idx = 0 ;
1260
- svfloat32_t product_vec = svdupq_n_f32 ( 0 . f , 0 . f , 0 . f , 0 . f );
1260
+ svint64_t product_vec = svdupq_n_s64 ( 0 , 0 );
1261
1261
simsimd_size_t intersection_size = 0 ;
1262
1262
1263
1263
while (a_idx < a_length && b_idx < b_length) {
@@ -1303,12 +1303,12 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_sve2( //
1303
1303
simsimd_u64_t b_step = svcntp_b16 (b_progress, b_mask);
1304
1304
1305
1305
// Compare `a_vec` with each lane of `b_vec`
1306
- svbfloat16_t a_weights_vec = svld1_bf16 (a_progress, a_weights + a_idx);
1307
- svbfloat16_t b_weights_vec = svld1_bf16 (b_progress, b_weights + b_idx);
1306
+ svint16_t a_weights_vec = svld1_s16 (a_progress, a_weights + a_idx);
1307
+ svint16_t b_weights_vec = svld1_s16 (b_progress, b_weights + b_idx);
1308
1308
for (simsimd_size_t i = 0 ; i < lanes_count; i++) {
1309
1309
svbool_t equal_mask = svmatch_u16 (a_progress, a_vec, b_vec);
1310
- svbfloat16_t b_equal_weights_vec = svsel_bf16 (equal_mask, b_weights_vec, svdup_n_bf16 (0 .f ));
1311
- product_vec = svbfdot_f32 (product_vec, a_weights_vec, b_equal_weights_vec);
1310
+ svint16_t b_equal_weights_vec = svsel_s16 (equal_mask, b_weights_vec, svdup_n_s16 (0 .f ));
1311
+ product_vec = svdot_s64 (product_vec, a_weights_vec, b_equal_weights_vec);
1312
1312
b_vec = svext_u16 (b_vec, b_vec, 8 );
1313
1313
intersection_size += svcntp_b16 (svptrue_b16 (), equal_mask);
1314
1314
}
@@ -1318,20 +1318,29 @@ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_sve2( //
1318
1318
b_idx += b_step;
1319
1319
}
1320
1320
results[0 ] = (simsimd_distance_t )intersection_size;
1321
- results[1 ] = svaddv_f32 ( svptrue_b32 (), product_vec);
1321
+ results[1 ] = svaddv_s64 ( svptrue_b64 (), product_vec);
1322
1322
}
1323
1323
1324
- SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_sve2 ( //
1325
- simsimd_u16_t const * a, simsimd_u16_t const * b, //
1326
- simsimd_i16_t const * a_weights, simsimd_i16_t const * b_weights, //
1327
- simsimd_size_t a_length, simsimd_size_t b_length, //
1324
+ #pragma clang attribute pop
1325
+ #pragma GCC pop_options
1326
+ #endif // SIMSIMD_TARGET_SVE2
1327
+
1328
+ #if SIMSIMD_TARGET_SVE2 && SIMSIMD_TARGET_SVE_BF16
1329
+ #pragma GCC push_options
1330
+ #pragma GCC target("arch=armv8.6-a+sve+sve2+bf16")
1331
+ #pragma clang attribute push(__attribute__((target("arch=armv8.6-a+sve+sve2+bf16"))), apply_to = function)
1332
+
1333
+ SIMSIMD_PUBLIC void simsimd_spdot_weights_u16_sve2 ( //
1334
+ simsimd_u16_t const * a, simsimd_u16_t const * b, //
1335
+ simsimd_bf16_t const * a_weights, simsimd_bf16_t const * b_weights, //
1336
+ simsimd_size_t a_length, simsimd_size_t b_length, //
1328
1337
simsimd_distance_t * results) {
1329
1338
1330
1339
// A single SVE lane is 128 bits wide, so one lane fits 8 values.
1331
1340
simsimd_size_t const register_size = svcnth ();
1332
1341
simsimd_size_t const lanes_count = register_size / 8 ;
1333
1342
simsimd_size_t a_idx = 0 , b_idx = 0 ;
1334
- svint64_t product_vec = svdupq_n_s64 ( 0 , 0 );
1343
+ svfloat32_t product_vec = svdupq_n_f32 ( 0 . f , 0 . f , 0 . f , 0 . f );
1335
1344
simsimd_size_t intersection_size = 0 ;
1336
1345
1337
1346
while (a_idx < a_length && b_idx < b_length) {
@@ -1377,12 +1386,15 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_sve2( //
1377
1386
simsimd_u64_t b_step = svcntp_b16 (b_progress, b_mask);
1378
1387
1379
1388
// Compare `a_vec` with each lane of `b_vec`
1380
- svbfloat16_t a_weights_vec = svld1_s16 (a_progress, a_weights + a_idx);
1381
- svbfloat16_t b_weights_vec = svld1_s16 (b_progress, b_weights + b_idx);
1389
+ svbfloat16_t a_weights_vec = svld1_bf16 (a_progress, a_weights + a_idx);
1390
+ svbfloat16_t b_weights_vec = svld1_bf16 (b_progress, b_weights + b_idx);
1382
1391
for (simsimd_size_t i = 0 ; i < lanes_count; i++) {
1383
1392
svbool_t equal_mask = svmatch_u16 (a_progress, a_vec, b_vec);
1384
- svbfloat16_t b_equal_weights_vec = svsel_s16 (equal_mask, b_weights_vec, svdup_n_bf16 (0 .f ));
1385
- product_vec = svdot_s64 (product_vec, a_weights_vec, b_equal_weights_vec);
1393
+ // ! The `svsel_bf16` intrinsic is broken in many compilers, not returning the correct type.
1394
+ // ! So we reinterpret floats as integers and apply `svsel_s16`.
1395
+ svint16_t b_equal_weights_vec =
1396
+ svsel_s16 (equal_mask, svreinterpret_s16_bs16 (b_weights_vec), svdup_n_s16 (0 ));
1397
+ product_vec = svbfdot_f32 (product_vec, a_weights_vec, svreinterpret_bf16_s16 (b_equal_weights_vec));
1386
1398
b_vec = svext_u16 (b_vec, b_vec, 8 );
1387
1399
intersection_size += svcntp_b16 (svptrue_b16 (), equal_mask);
1388
1400
}
@@ -1392,12 +1404,12 @@ SIMSIMD_PUBLIC void simsimd_spdot_counts_u16_sve2( //
1392
1404
b_idx += b_step;
1393
1405
}
1394
1406
results[0 ] = (simsimd_distance_t )intersection_size;
1395
- results[1 ] = svaddv_s64 ( svptrue_b64 (), product_vec);
1407
+ results[1 ] = svaddv_f32 ( svptrue_b32 (), product_vec);
1396
1408
}
1397
1409
1398
1410
#pragma clang attribute pop
1399
1411
#pragma GCC pop_options
1400
- #endif // SIMSIMD_TARGET_SVE2
1412
+ #endif // SIMSIMD_TARGET_SVE2 && SIMSIMD_TARGET_SVE_BF16
1401
1413
#endif // SIMSIMD_TARGET_ARM
1402
1414
1403
1415
#ifdef __cplusplus
0 commit comments