Skip to content

Commit d1c432f

Browse files
algoriddle authored and meta-codesync[bot] committed
Convert rabitq_simd.h to runtime SIMD dispatch (facebookresearch#4912)
Summary:
Pull Request resolved: facebookresearch#4912

Convert RaBitQ SIMD primitives from compile-time #ifdef cascades to the DD (dynamic dispatch) template specialization pattern.

- Replace inline SIMD functions in rabitq_simd.h with template<SIMDLevel> declarations and NONE (scalar) specializations
- Create rabitq_avx2.cpp with AVX2+SSE4.1 specializations for bitwise_and_dot_product, bitwise_xor_dot_product, popcount, and compute_inner_product (including BMI2 bitplane kernel)
- Create rabitq_avx512.cpp with AVX512 specializations (falls back to AVX2 bitplane kernel for compute_inner_product with ex_bits >= 2)
- Add AVX512_SPR forwarding to AVX512 for all four functions
- Create rabitq_neon.cpp forwarding ARM_NEON/ARM_SVE to NONE
- Wrap calls in RaBitQuantizer.cpp and RaBitQUtils.cpp with with_simd_level dispatch
- Add -mbmi2 to AVX2 and AVX512 SIMD flags (required by _pext_u64 in bitplane kernel; BMI2 is present on all Haswell+ CPUs)
- Register new SIMD files in xplat.bzl and CMakeLists.txt

Reviewed By: mdouze

Differential Revision: D96298743

fbshipit-source-id: 1941bf099664ba2d3d1c1245d5c06d1775614955
1 parent 6bca961 commit d1c432f

File tree

8 files changed

+1275
-821
lines changed

8 files changed

+1275
-821
lines changed

faiss/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,20 +15,23 @@ set(FAISS_SIMD_AVX2_SRC
1515
impl/approx_topk/avx2.cpp
1616
utils/simd_impl/distances_avx2.cpp
1717
utils/distances_fused/simdlib_based.cpp
18+
utils/simd_impl/rabitq_avx2.cpp
1819
)
1920
set(FAISS_SIMD_AVX512_SRC
2021
impl/fast_scan/impl-avx512.cpp
2122
impl/pq_code_distance/pq_code_distance-avx512.cpp
2223
impl/scalar_quantizer/sq-avx512.cpp
2324
utils/simd_impl/distances_avx512.cpp
2425
utils/distances_fused/avx512.cpp
26+
utils/simd_impl/rabitq_avx512.cpp
2527
)
2628
set(FAISS_SIMD_NEON_SRC
2729
impl/fast_scan/impl-neon.cpp
2830
impl/scalar_quantizer/sq-neon.cpp
2931
impl/approx_topk/neon.cpp
3032
utils/simd_impl/distances_aarch64.cpp
3133
utils/distances_fused/simdlib_based_neon.cpp
34+
utils/simd_impl/rabitq_neon.cpp
3235
)
3336
set(FAISS_SIMD_SVE_SRC
3437
impl/pq_code_distance/pq_code_distance-sve.cpp

faiss/impl/RaBitQUtils.cpp

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
#include <faiss/impl/RaBitQUtils.h>
99

1010
#include <faiss/impl/FaissAssert.h>
11+
#include <faiss/impl/simd_dispatch.h>
1112
#include <faiss/utils/distances.h>
1213
#include <faiss/utils/rabitq_simd.h>
1314
#include <algorithm>
@@ -306,6 +307,9 @@ size_t compute_per_vector_storage_size(size_t nb_bits, size_t d) {
306307
}
307308
}
308309

310+
// Non-template wrapper with dynamic dispatch (one dispatch per call).
311+
// The hot path in RaBitQuantizer dispatches once at distance computer
312+
// construction, so per-vector dispatch only affects this utility path.
309313
float compute_full_multibit_distance(
310314
const uint8_t* sign_bits,
311315
const uint8_t* ex_code,
@@ -315,18 +319,18 @@ float compute_full_multibit_distance(
315319
size_t d,
316320
size_t ex_bits,
317321
MetricType metric_type) {
318-
const float cb = -(static_cast<float>(1 << ex_bits) - 0.5f);
319-
320-
float ex_ip = rabitq::multibit::compute_inner_product(
321-
sign_bits, ex_code, rotated_q, d, ex_bits, cb);
322-
323-
float dist = qr_base + ex_fac.f_add_ex + ex_fac.f_rescale_ex * ex_ip;
324-
325-
if (metric_type == MetricType::METRIC_L2) {
326-
dist = std::max(0.0f, dist);
327-
}
328-
329-
return dist;
322+
return with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A0>(
323+
[&]<SIMDLevel SL>() {
324+
return compute_full_multibit_distance<SL>(
325+
sign_bits,
326+
ex_code,
327+
ex_fac,
328+
rotated_q,
329+
qr_base,
330+
d,
331+
ex_bits,
332+
metric_type);
333+
});
330334
}
331335

332336
void populate_block_aux_from_flat_storage(

faiss/impl/RaBitQUtils.h

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,9 @@
1010
#include <faiss/MetricType.h>
1111
#include <faiss/impl/platform_macros.h>
1212
#include <faiss/utils/AlignedTable.h>
13+
#include <faiss/utils/rabitq_simd.h>
14+
#include <faiss/utils/simd_levels.h>
15+
#include <algorithm>
1316
#include <cstddef>
1417
#include <cstdint>
1518
#include <cstring>
@@ -337,6 +340,33 @@ float compute_full_multibit_distance(
337340
size_t ex_bits,
338341
MetricType metric_type);
339342

343+
// SIMDLevel-templatized version — avoids per-call dynamic dispatch.
344+
// Inline so it can be used from templatized distance computers without
345+
// needing explicit instantiations in per-SIMD TUs.
346+
// Returns qr_base + ex_fac.f_add_ex + ex_fac.f_rescale_ex * ex_ip, where
// ex_ip is the value returned by
// rabitq::multibit::compute_inner_product<SL>(...). For METRIC_L2 the
// result is clamped so that it is never negative.
//
// SL          selects which compute_inner_product specialization is called.
// sign_bits, ex_code, rotated_q, d, ex_bits
//             forwarded unchanged to compute_inner_product<SL>.
// ex_fac      supplies the additive (f_add_ex) and multiplicative
//             (f_rescale_ex) terms applied to the inner product.
// qr_base     constant term added to the result.
// metric_type only METRIC_L2 changes the result (clamp at zero).
template <SIMDLevel SL>
347+
inline float compute_full_multibit_distance(
348+
const uint8_t* sign_bits,
349+
const uint8_t* ex_code,
350+
const ExtraBitsFactors& ex_fac,
351+
const float* rotated_q,
352+
float qr_base,
353+
size_t d,
354+
size_t ex_bits,
355+
MetricType metric_type) {
356+
// cb = -(2^ex_bits - 0.5); handed to compute_inner_product as its last
// argument.
// NOTE(review): `1 << ex_bits` shifts an int literal, so this presumably
// relies on ex_bits being small (well below 31) — confirm against callers.
const float cb = -(static_cast<float>(1 << ex_bits) - 0.5f);
357+
358+
float ex_ip = rabitq::multibit::compute_inner_product<SL>(
359+
sign_bits, ex_code, rotated_q, d, ex_bits, cb);
360+
361+
// Combine the constant term, the additive factor and the rescaled inner
// product.
float dist = qr_base + ex_fac.f_add_ex + ex_fac.f_rescale_ex * ex_ip;
362+
363+
// For METRIC_L2 a negative result is replaced by 0; every other metric
// type returns the value as computed.
if (metric_type == MetricType::METRIC_L2) {
364+
dist = std::max(0.0f, dist);
365+
}
366+
367+
return dist;
368+
}
369+
340370
/** Compute pointer to a vector's auxiliary data within block layout. */
341371
template <typename T>
342372
inline T* get_block_aux_ptr(

0 commit comments

Comments
 (0)