Skip to content

Commit 6130f30

Browse files
committed
refactor
- Remove the standalone sra_krl/ library; inline its kernels as C++ template specializations under impl/pq_code_distance/
- Replace #ifdef __aarch64__ with COMPILE_SIMD_ARM_NEON and SIMDLevel::ARM_NEON templates
- Convert .c files to .cpp and drop the LANGUAGE CXX workarounds
- No algorithmic changes; benchmarks are consistent with prior results
1 parent f851c54 commit 6130f30

13 files changed

Lines changed: 2680 additions & 58 deletions

faiss/CMakeLists.txt

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ set(FAISS_SIMD_NEON_SRC
2525
impl/fast_scan/impl-neon.cpp
2626
impl/scalar_quantizer/sq-neon.cpp
2727
impl/approx_topk/neon.cpp
28+
impl/pq_code_distance/pq_code_distance-neon.cpp
29+
impl/ProductQuantizer-neon.cpp
2830
utils/simd_impl/distances_aarch64.cpp
31+
utils/simd_impl/distances_neon.cpp
32+
utils/simd_impl/matrix_transpose_neon.cpp
2933
)
3034
set(FAISS_SIMD_SVE_SRC
3135
impl/pq_code_distance/pq_code_distance-sve.cpp
@@ -431,7 +435,12 @@ if(NOT WIN32)
431435
endif()
432436
endif()
433437
target_sources(faiss_sve PRIVATE ${FAISS_SIMD_NEON_SRC} ${FAISS_SIMD_SVE_SRC})
434-
target_compile_definitions(faiss_sve PRIVATE COMPILE_SIMD_ARM_NEON COMPILE_SIMD_ARM_SVE)
438+
target_compile_definitions(faiss_sve PUBLIC COMPILE_SIMD_ARM_NEON COMPILE_SIMD_ARM_SVE)
439+
# ProductQuantizer-neon.cpp uses vdotq_s32/vdotq_u32 which require +dotprod
440+
set_source_files_properties(impl/ProductQuantizer-neon.cpp
441+
TARGET_DIRECTORY faiss_sve
442+
PROPERTIES COMPILE_OPTIONS "-march=armv8.2-a+dotprod"
443+
)
435444

436445
# =============================================================================
437446
# Dynamic Dispatch Mode
@@ -462,12 +471,17 @@ if(FAISS_OPT_LEVEL STREQUAL "dd")
462471
)
463472
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64|arm64|ARM64)")
464473
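# ARM NEON is always available on aarch64, so NEON sources generally need no special compiler flags (ProductQuantizer-neon.cpp is the exception: it requires +dotprod, set below)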
465-
target_compile_definitions(faiss PRIVATE COMPILE_SIMD_ARM_NEON COMPILE_SIMD_ARM_SVE)
474+
target_compile_definitions(faiss PUBLIC COMPILE_SIMD_ARM_NEON COMPILE_SIMD_ARM_SVE)
466475
# Per-file SVE flags (NEON sources need no flags on aarch64, except ProductQuantizer-neon.cpp below)
467476
set_source_files_properties(${FAISS_SIMD_SVE_SRC}
468477
TARGET_DIRECTORY faiss
469478
PROPERTIES COMPILE_OPTIONS "-march=armv8.2-a+sve"
470479
)
480+
# ProductQuantizer-neon.cpp uses vdotq_s32/vdotq_u32 which require +dotprod
481+
set_source_files_properties(impl/ProductQuantizer-neon.cpp
482+
TARGET_DIRECTORY faiss
483+
PROPERTIES COMPILE_OPTIONS "-march=armv8.2-a+dotprod"
484+
)
471485
endif()
472486
endif()
473487
endif()
@@ -476,8 +490,15 @@ endif()
476490
# and NEON source files are always compiled into the main faiss target.
477491
# (On x86, AVX2/AVX512 are optional and only compiled per opt_level above.)
478492
if(CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64|arm64|ARM64)")
479-
target_compile_definitions(faiss PRIVATE COMPILE_SIMD_ARM_NEON)
493+
target_compile_definitions(faiss PUBLIC COMPILE_SIMD_ARM_NEON)
480494
target_sources(faiss PRIVATE ${FAISS_SIMD_NEON_SRC})
495+
# ProductQuantizer-neon.cpp uses vdotq_s32/vdotq_u32 which require +dotprod
496+
if(NOT FAISS_OPT_LEVEL STREQUAL "dd")
497+
set_source_files_properties(impl/ProductQuantizer-neon.cpp
498+
TARGET_DIRECTORY faiss
499+
PROPERTIES COMPILE_OPTIONS "-march=armv8.2-a+dotprod"
500+
)
501+
endif()
481502
endif()
482503

483504
if(FAISS_ENABLE_SVS)

faiss/IndexIVFPQ.cpp

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,8 @@ struct IVFPQScannerT : QueryTables {
816816
}
817817

818818
float dis0;
819+
mutable std::vector<float> dis_buffer;
820+
mutable std::vector<size_t> idx_buffer;
819821

820822
void init_list(idx_t list_no, float coarse_dis_in, int mode) {
821823
this->key = list_no;
@@ -931,6 +933,59 @@ struct IVFPQScannerT : QueryTables {
931933
}
932934
}
933935

936+
/// Batch version using pq_code_distance_batch for ARM_NEON optimization
937+
template <class SearchResultType>
938+
void scan_list_with_table_batch(
939+
size_t ncode,
940+
const uint8_t* codes,
941+
SearchResultType& res) const {
942+
if (dis_buffer.size() < ncode) {
943+
dis_buffer.resize(ncode);
944+
}
945+
946+
pq_code_distance_batch(
947+
pq.M, pq.nbits, ncode, codes, sim_table, dis_buffer.data(), dis0);
948+
949+
for (size_t j = 0; j < ncode; j++) {
950+
if (!res.skip_entry(j)) {
951+
res.add(j, dis_buffer[j]);
952+
}
953+
}
954+
}
955+
956+
/// Batch version with IDSelector pre-filtering for ARM_NEON optimization.
957+
template <class SearchResultType>
958+
void scan_list_with_table_batch_sel(
959+
size_t ncode,
960+
const uint8_t* codes,
961+
SearchResultType& res) const {
962+
// Build index list of passing entries
963+
if (idx_buffer.size() < ncode) {
964+
idx_buffer.resize(ncode);
965+
}
966+
size_t npass = 0;
967+
for (size_t i = 0; i < ncode; i++) {
968+
if (!res.skip_entry(i)) {
969+
idx_buffer[npass++] = i;
970+
}
971+
}
972+
if (npass == 0) {
973+
return;
974+
}
975+
if (dis_buffer.size() < npass) {
976+
dis_buffer.resize(npass);
977+
}
978+
// Compute distances only for passing entries
979+
for (size_t j = 0; j < npass; j++) {
980+
dis_buffer[j] = pq_code_distance_single(
981+
pq.M, pq.nbits, sim_table,
982+
codes + idx_buffer[j] * pq.code_size) + dis0;
983+
}
984+
for (size_t j = 0; j < npass; j++) {
985+
res.add(idx_buffer[j], dis_buffer[j]);
986+
}
987+
}
988+
934989
/// tables are not precomputed, but pointers are provided to the
935990
/// relevant X_c|x_r tables
936991
template <class SearchResultType>
@@ -1239,7 +1294,21 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQCodeDist>,
12391294
assert(precompute_mode == 2);
12401295
this->scan_list_polysemous(ncode, codes, res);
12411296
} else if (precompute_mode == 2) {
1242-
this->scan_list_with_table(ncode, codes, res);
1297+
if constexpr (PQCodeDist::simd_level == SIMDLevel::ARM_NEON) {
1298+
if (this->pq.nbits == 8) {
1299+
if constexpr (use_sel) {
1300+
// ARM_NEON + IDSelector: pre-filter then batch compute
1301+
this->scan_list_with_table_batch_sel(ncode, codes, res);
1302+
} else {
1303+
// ARM_NEON, no IDSelector: full batch compute
1304+
this->scan_list_with_table_batch(ncode, codes, res);
1305+
}
1306+
} else {
1307+
this->scan_list_with_table(ncode, codes, res);
1308+
}
1309+
} else {
1310+
this->scan_list_with_table(ncode, codes, res);
1311+
}
12431312
} else if (precompute_mode == 1) {
12441313
this->scan_list_with_pointer(ncode, codes, res);
12451314
} else if (precompute_mode == 0) {

0 commit comments

Comments
 (0)