Skip to content

Commit 8d8268c

Browse files
algoriddlemeta-codesync[bot]
authored andcommitted
Templatize simdlib types on SIMDLevel (#4866)
Summary: Pull Request resolved: #4866 Templatize all simd wrapper types (simd16uint16, simd32uint8, simd8float32, etc.) on SIMDLevel. This is the foundation for PQ4 fast scan Dynamic Dispatch. Primary templates are declared in simdlib.h. Each platform header provides explicit specializations: - simdlib_avx2.h: simd16uint16<AVX2>, simd32uint8<AVX2>, etc. - simdlib_avx512.h: simd32uint16<AVX512>, simd64uint8<AVX512>, etc. - simdlib_neon.h: simd16uint16<ARM_NEON>, etc. - simdlib_emulated.h: simd16uint16<NONE>, etc. (always included) - simdlib_ppc64.h: simd16uint16<NONE>, etc. (PPC-optimized scalar) SINGLE_SIMD_LEVEL (inline constexpr in simd_levels.h) resolves to NONE in DD mode and to the compiled-in level in static mode. SINGLE_SIMD_LEVEL_256 maps through simd256_level_selector for 256-bit types (AVX512->AVX2, SVE->NEON). Code without explicit SL context uses these. This is migration scaffolding — subsequent diffs will replace SINGLE_SIMD_LEVEL usages with proper SL dispatch. simd_result_handlers.h is no longer %include'd by SWIG (the templatized types are unparseable by SWIG). make_knn_handler methods are %ignore'd. The Python API does not use these internal SIMD handler types. Pre-existing bug fixes bundled with this refactor: - simdlib_avx512.h: simd512bit::bin() stack buffer overflow (char[257] -> char[513]) - simdlib_avx2.h: simd256bit constructor used aligned _mm256_load_si256 instead of unaligned _mm256_loadu_si256 - All platform headers: simd16uint16/simd32uint8 operator+=/operator-= returned by value instead of by reference Static builds: zero performance change. Template specializations produce identical layout, ABI, and codegen as the old plain structs. Reviewed By: mdouze Differential Revision: D95392150 fbshipit-source-id: 435b643f96a7e08d777796390066964d99295f63
1 parent 3e4c103 commit 8d8268c

30 files changed

Lines changed: 1527 additions & 1134 deletions

faiss/CMakeLists.txt

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -296,12 +296,13 @@ set(FAISS_HEADERS
296296
utils/rabitq_simd.h
297297
utils/random.h
298298
utils/sorting.h
299-
utils/simdlib.h
300-
utils/simdlib_avx2.h
301-
utils/simdlib_avx512.h
302-
utils/simdlib_emulated.h
303-
utils/simdlib_neon.h
304-
utils/simdlib_ppc64.h
299+
impl/simdlib/simdlib.h
300+
impl/simdlib/simdlib_dispatch.h
301+
impl/simdlib/simdlib_avx2.h
302+
impl/simdlib/simdlib_avx512.h
303+
impl/simdlib/simdlib_emulated.h
304+
impl/simdlib/simdlib_neon.h
305+
impl/simdlib/simdlib_ppc64.h
305306
utils/utils.h
306307
utils/simd_levels.h
307308
utils/distances_fused/avx512.h
@@ -459,6 +460,14 @@ if(FAISS_OPT_LEVEL STREQUAL "dd")
459460
endif()
460461
endif()
461462

463+
# NEON is mandatory on ARM64 — ensure COMPILE_SIMD_ARM_NEON is always defined
464+
# and NEON source files are always compiled into the main faiss target.
465+
# (On x86, AVX2/AVX512 are optional and only compiled per opt_level above.)
466+
if(CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64|arm64|ARM64)")
467+
target_compile_definitions(faiss PRIVATE COMPILE_SIMD_ARM_NEON)
468+
target_sources(faiss PRIVATE ${FAISS_SIMD_NEON_SRC})
469+
endif()
470+
462471
if(FAISS_ENABLE_SVS)
463472
find_package(svs_runtime 0.2.0 REQUIRED)
464473

faiss/IndexIVFPQFastScan.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515

1616
#include <faiss/impl/AuxIndexStructures.h>
1717
#include <faiss/impl/FaissAssert.h>
18+
#include <faiss/impl/simdlib/simdlib_dispatch.h>
1819
#include <faiss/utils/Heap.h>
1920
#include <faiss/utils/distances.h>
2021
#include <faiss/utils/extra_distances.h>
21-
#include <faiss/utils/simdlib.h>
2222

2323
#include <faiss/invlists/BlockInvertedLists.h>
2424

@@ -349,7 +349,7 @@ struct IVFPQFastScanScanner : InvertedListScanner {
349349
const float* x = index.by_residual ? residual.data() : this->xi;
350350
float accu = 0;
351351
// implemented for all vector distances, although only L2 and IP are
352-
// suppored by FastScan
352+
// supported by FastScan
353353
with_VectorDistance(pq.dsub, index.metric_type, 0.0, [&](auto vd) {
354354
int m;
355355
for (m = 0; m + 1 < pq.M; m += 2) {

faiss/IndexRaBitQFastScan.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
#include <faiss/impl/RaBitQUtils.h>
1717
#include <faiss/impl/RaBitQuantizer.h>
1818
#include <faiss/impl/simd_result_handlers.h>
19+
#include <faiss/impl/simdlib/simdlib_dispatch.h>
1920
#include <faiss/utils/Heap.h>
20-
#include <faiss/utils/simdlib.h>
2121

2222
namespace faiss {
2323

faiss/impl/LookupTableScaler.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#include <cstdint>
1111
#include <cstdlib>
1212

13-
#include <faiss/utils/simdlib.h>
13+
#include <faiss/impl/simdlib/simdlib_dispatch.h>
1414

1515
/*******************************************
1616
* The Scaler objects are used to specialize the handling of the

faiss/impl/residual_quantizer_encode_steps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
#include <faiss/impl/FaissAssert.h>
1212
#include <faiss/impl/ResidualQuantizer.h>
1313
#include <faiss/impl/simd_dispatch.h>
14+
#include <faiss/impl/simdlib/simdlib_dispatch.h>
1415
#include <faiss/utils/Heap.h>
1516
#include <faiss/utils/distances.h>
16-
#include <faiss/utils/simdlib.h>
1717
#include <faiss/utils/utils.h>
1818

1919
#include <faiss/utils/approx_topk/approx_topk.h>

faiss/impl/scalar_quantizer/codecs.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
#pragma once
99

1010
#include <faiss/impl/ScalarQuantizer.h>
11+
#include <faiss/impl/simdlib/simdlib_dispatch.h>
1112
#include <faiss/utils/simd_levels.h>
12-
#include <faiss/utils/simdlib.h>
1313

1414
namespace faiss {
1515

faiss/impl/scalar_quantizer/distance_computers.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
#pragma once
99

1010
#include <faiss/impl/ScalarQuantizer.h>
11+
#include <faiss/impl/simdlib/simdlib_dispatch.h>
1112
#include <faiss/utils/simd_levels.h>
12-
#include <faiss/utils/simdlib.h>
1313

1414
namespace faiss {
1515

faiss/impl/scalar_quantizer/quantizers.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
#pragma once
99

1010
#include <faiss/impl/ScalarQuantizer.h>
11+
#include <faiss/impl/simdlib/simdlib_dispatch.h>
1112
#include <faiss/utils/bf16.h>
1213
#include <faiss/utils/fp16.h>
1314
#include <faiss/utils/simd_levels.h>
14-
#include <faiss/utils/simdlib.h>
1515

1616
namespace faiss {
1717

faiss/impl/scalar_quantizer/scanners.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
#pragma once
1212

1313
#include <faiss/impl/ScalarQuantizer.h>
14+
#include <faiss/impl/simdlib/simdlib_dispatch.h>
1415
#include <faiss/utils/simd_levels.h>
15-
#include <faiss/utils/simdlib.h>
1616

1717
#include <faiss/impl/simd_dispatch.h>
1818

faiss/impl/scalar_quantizer/similarities.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
#pragma once
99

1010
#include <faiss/impl/ScalarQuantizer.h>
11+
#include <faiss/impl/simdlib/simdlib_dispatch.h>
1112
#include <faiss/utils/simd_levels.h>
12-
#include <faiss/utils/simdlib.h>
1313

1414
namespace faiss {
1515

0 commit comments

Comments
 (0)