Skip to content

Commit b19e9ba

Browse files
algoriddlemeta-codesync[bot]
authored andcommitted
Add FastScanCodeScanner dispatch boundary with per-SIMD TUs
Summary: Add `FastScanCodeScanner`, a virtual base that bundles handler + kernel behind the SIMD dispatch boundary. In DD mode, `SINGLE_SIMD_LEVEL = NONE` so the existing fast scan code path uses emulated SIMD types. The new scanner provides per-SIMD translation units (AVX2, AVX512, ARM_NEON) compiled with the correct ISA flags, and a factory function (`make_fast_scan_knn_scanner`) that uses `DISPATCH_SIMDLevel` to select the right TU at runtime. This follows the proven `THE_LEVEL_TO_DISPATCH` pattern from the scalar quantizer per-SIMD TUs (`sq-dispatch.h`). Each per-SIMD TU includes `dispatching.h` which provides: - `ScannerMixIn<Handler>`: wraps a concrete handler and calls accumulation kernels (both search_1 multi-BB and QBS paths) - Factory specialization `make_fast_scan_scanner_impl<SL>()` with combinatorial dispatch over `is_max × with_id_map × handler_type` (SingleResultHandler for k=1, HeapHandler for k≤20, ReservoirHandler for k>20) New files: - `impl/fast_scan/dispatching.h` — dispatch template header - `impl/fast_scan/impl-avx2.cpp` — AVX2 per-SIMD TU - `impl/fast_scan/impl-avx512.cpp` — AVX512 per-SIMD TU - `impl/fast_scan/impl-neon.cpp` — ARM NEON TU (with ARM_SVE forwarding) Modified files: - `impl/fast_scan/pq4_fast_scan.h` — FastScanCodeScanner base + factory decl - `impl/fast_scan/pq4_fast_scan.cpp` — NONE specialization + dispatch wrapper - `xplat.bzl` / `CMakeLists.txt` — register SIMD files and header Note: RaBitQ handler is not wired through FastScanCodeScanner in this diff. That comes in later diffs when callers are switched. Differential Revision: D95950483
1 parent eddea2a commit b19e9ba

10 files changed

Lines changed: 458 additions & 63 deletions

faiss/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,19 @@
99
# Architecture-specific: only include files for the current build architecture
1010
# =============================================================================
1111
set(FAISS_SIMD_AVX2_SRC
12+
impl/fast_scan/impl-avx2.cpp
1213
impl/pq_code_distance/pq_code_distance-avx2.cpp
1314
impl/scalar_quantizer/sq-avx2.cpp
1415
utils/simd_impl/distances_avx2.cpp
1516
)
1617
set(FAISS_SIMD_AVX512_SRC
18+
impl/fast_scan/impl-avx512.cpp
1719
impl/pq_code_distance/pq_code_distance-avx512.cpp
1820
impl/scalar_quantizer/sq-avx512.cpp
1921
utils/simd_impl/distances_avx512.cpp
2022
)
2123
set(FAISS_SIMD_NEON_SRC
24+
impl/fast_scan/impl-neon.cpp
2225
impl/scalar_quantizer/sq-neon.cpp
2326
utils/simd_impl/distances_aarch64.cpp
2427
)
@@ -262,6 +265,8 @@ set(FAISS_HEADERS
262265
impl/kmeans1d.h
263266
impl/lattice_Zn.h
264267
impl/platform_macros.h
268+
impl/fast_scan/accumulate_loops.h
269+
impl/fast_scan/dispatching.h
265270
impl/fast_scan/pq4_fast_scan.h
266271
impl/fast_scan/decompose_qbs.h
267272
impl/fast_scan/kernels_simd256.h
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#pragma once
9+
10+
/**
11+
* @file accumulate_loops.h
12+
* @brief Shared accumulation loop helpers for fast-scan search paths.
13+
*
14+
* Contains the search_1 multi-BB accumulation loop (bbs > 32):
15+
* - accumulate_fixed_blocks / pq4_accumulate_loop_fixed_scaler
16+
*
17+
* The QBS path (bbs == 32) is in decompose_qbs.h.
18+
*
19+
* All functions live in `namespace faiss` (not anonymous) so they can be
20+
* shared by both the per-SIMD TU dispatcher (dispatching.h) and the old
21+
* free-function search paths (pq4_fast_scan_search_1.cpp).
22+
*
23+
* The QBS helpers here always use pq4_kernel_qbs_256 (never 512-bit).
24+
* This is required for the per-SIMD DD TUs where SINGLE_SIMD_LEVEL=NONE
25+
* leaves 512-bit types empty. The old pq4_fast_scan_search_qbs.cpp
26+
* continues to use decompose_qbs.h which includes both 256 and 512 paths.
27+
*/
28+
29+
#include <cassert>
30+
31+
#include <faiss/impl/FaissAssert.h>
32+
#include <faiss/impl/fast_scan/LookupTableScaler.h>
33+
#include <faiss/impl/fast_scan/kernels_simd256.h>
34+
#include <faiss/impl/fast_scan/simd_result_handlers.h>
35+
36+
namespace faiss {
37+
38+
using namespace simd_result_handlers;
39+
40+
/***************************************************************
41+
* Search_1 path helpers (multi-BB kernel, bbs > 32)
42+
***************************************************************/
43+
44+
template <int NQ, int BB, class ResultHandler, class Scaler>
45+
void accumulate_fixed_blocks(
46+
size_t nb,
47+
int nsq,
48+
const uint8_t* codes,
49+
const uint8_t* LUT,
50+
ResultHandler& res,
51+
const Scaler& scaler,
52+
size_t block_stride) {
53+
constexpr int bbs = 32 * BB;
54+
for (size_t j0 = 0; j0 < nb; j0 += bbs) {
55+
FixedStorageHandler<NQ, 2 * BB> res2;
56+
kernel_accumulate_block<NQ, BB>(nsq, codes, LUT, res2, scaler);
57+
res.set_block_origin(0, j0);
58+
res2.to_other_handler(res);
59+
codes += block_stride;
60+
}
61+
}
62+
63+
template <class ResultHandler, class Scaler>
64+
void pq4_accumulate_loop_fixed_scaler(
65+
int nq,
66+
size_t nb,
67+
int bbs,
68+
int nsq,
69+
const uint8_t* codes,
70+
const uint8_t* LUT,
71+
ResultHandler& res,
72+
const Scaler& scaler,
73+
size_t block_stride) {
74+
FAISS_THROW_IF_NOT(is_aligned_pointer(codes));
75+
FAISS_THROW_IF_NOT(is_aligned_pointer(LUT));
76+
FAISS_THROW_IF_NOT(bbs % 32 == 0);
77+
FAISS_THROW_IF_NOT(nb % bbs == 0);
78+
79+
#define FAISS_ACCLOOP_DISPATCH(NQ, BB) \
80+
case NQ * 1000 + BB: \
81+
accumulate_fixed_blocks<NQ, BB>( \
82+
nb, nsq, codes, LUT, res, scaler, block_stride); \
83+
break
84+
85+
switch (nq * 1000 + bbs / 32) {
86+
FAISS_ACCLOOP_DISPATCH(1, 1);
87+
FAISS_ACCLOOP_DISPATCH(1, 2);
88+
FAISS_ACCLOOP_DISPATCH(1, 3);
89+
FAISS_ACCLOOP_DISPATCH(1, 4);
90+
FAISS_ACCLOOP_DISPATCH(1, 5);
91+
FAISS_ACCLOOP_DISPATCH(2, 1);
92+
FAISS_ACCLOOP_DISPATCH(2, 2);
93+
FAISS_ACCLOOP_DISPATCH(3, 1);
94+
FAISS_ACCLOOP_DISPATCH(4, 1);
95+
default:
96+
FAISS_THROW_FMT("nq=%d bbs=%d not instantiated", nq, bbs);
97+
}
98+
#undef FAISS_ACCLOOP_DISPATCH
99+
}
100+
101+
} // namespace faiss

faiss/impl/fast_scan/decompose_qbs.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@ namespace faiss {
1919
using namespace simd_result_handlers;
2020

2121
/*
22-
* Unified kernel: selects 256-bit vs 512-bit path based on
23-
* compile-time __AVX512F__ guard.
22+
* Unified kernel: selects 256-bit vs 512-bit path.
23+
*
24+
* In static AVX512 mode: SINGLE_SIMD_LEVEL == AVX512, uses 512-bit kernel.
25+
* In DD mode AVX512 TU: __AVX512F__ is defined (compiler flags) but
26+
* SINGLE_SIMD_LEVEL == NONE (handlers use emulated types), so we fall
27+
* through to the 256-bit kernel. This is correct and intentional.
2428
*/
2529
template <int NQ, class ResultHandler, class Scaler>
2630
void kernel_accumulate_block(
@@ -30,7 +34,13 @@ void kernel_accumulate_block(
3034
ResultHandler& res,
3135
const Scaler& scaler) {
3236
#ifdef __AVX512F__
33-
pq4_kernel_qbs_512<NQ>(nsq, codes, LUT, res, scaler);
37+
if constexpr (
38+
SINGLE_SIMD_LEVEL == SIMDLevel::AVX512 ||
39+
SINGLE_SIMD_LEVEL == SIMDLevel::AVX512_SPR) {
40+
pq4_kernel_qbs_512<NQ>(nsq, codes, LUT, res, scaler);
41+
} else {
42+
pq4_kernel_qbs_256<NQ>(nsq, codes, LUT, res, scaler);
43+
}
3444
#else
3545
pq4_kernel_qbs_256<NQ>(nsq, codes, LUT, res, scaler);
3646
#endif

faiss/impl/fast_scan/dispatching.h

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#pragma once
9+
10+
/**
11+
* @file dispatching.h
12+
* @brief Per-SIMD TU dispatch template for fast scan.
13+
*
14+
* This header is included once per SIMD TU with THE_LEVEL_TO_DISPATCH
15+
* set to the desired SIMDLevel. It provides:
16+
* - ScannerMixIn: wraps a handler + calls kernel at the TU's SIMD level
17+
* - make_fast_scan_scanner_impl<SL>: factory specialization
18+
*
19+
* Usage (in a per-SIMD .cpp file):
20+
* #define THE_LEVEL_TO_DISPATCH SIMDLevel::AVX2
21+
* #include <faiss/impl/fast_scan/dispatching.h>
22+
*
23+
* Kernel helpers come from accumulate_loops.h (search_1 multi-BB path)
24+
* and decompose_qbs.h (QBS path, with if-constexpr guard for 512-bit).
25+
*/
26+
27+
#ifndef THE_LEVEL_TO_DISPATCH
28+
#error "Define THE_LEVEL_TO_DISPATCH before including this header"
29+
#endif
30+
31+
#include <memory>
32+
33+
#include <faiss/impl/fast_scan/accumulate_loops.h>
34+
#include <faiss/impl/fast_scan/decompose_qbs.h>
35+
#include <faiss/impl/fast_scan/pq4_fast_scan.h>
36+
37+
namespace faiss {
38+
39+
using namespace simd_result_handlers;
40+
41+
/***************************************************************
42+
* ScannerMixIn: wraps a concrete handler + calls accumulation
43+
* kernels. Lives behind the virtual FastScanCodeScanner interface
44+
* so callers don't need to know the handler type.
45+
***************************************************************/
46+
47+
template <class Handler>
48+
struct ScannerMixIn : FastScanCodeScanner {
49+
Handler handler_;
50+
51+
template <typename... Args>
52+
explicit ScannerMixIn(Args&&... args)
53+
: handler_(std::forward<Args>(args)...) {}
54+
55+
SIMDResultHandlerToFloat* handler() override {
56+
return &handler_;
57+
}
58+
59+
void accumulate_loop(
60+
int nq,
61+
size_t nb,
62+
int bbs,
63+
int nsq,
64+
const uint8_t* codes,
65+
const uint8_t* LUT,
66+
int pq2x4_scale,
67+
size_t block_stride) override {
68+
if (pq2x4_scale) {
69+
NormTableScaler<> scaler(pq2x4_scale);
70+
pq4_accumulate_loop_fixed_scaler(
71+
nq,
72+
nb,
73+
bbs,
74+
nsq,
75+
codes,
76+
LUT,
77+
handler_,
78+
scaler,
79+
block_stride);
80+
} else {
81+
DummyScaler<> dummy;
82+
pq4_accumulate_loop_fixed_scaler(
83+
nq,
84+
nb,
85+
bbs,
86+
nsq,
87+
codes,
88+
LUT,
89+
handler_,
90+
dummy,
91+
block_stride);
92+
}
93+
}
94+
95+
void accumulate_loop_qbs(
96+
int qbs,
97+
size_t nb,
98+
int nsq,
99+
const uint8_t* codes,
100+
const uint8_t* LUT,
101+
int pq2x4_scale,
102+
size_t block_stride) override {
103+
if (pq2x4_scale) {
104+
NormTableScaler<> scaler(pq2x4_scale);
105+
pq4_accumulate_loop_qbs_fixed_scaler(
106+
qbs, nb, nsq, codes, LUT, handler_, scaler, block_stride);
107+
} else {
108+
DummyScaler<> dummy;
109+
pq4_accumulate_loop_qbs_fixed_scaler(
110+
qbs, nb, nsq, codes, LUT, handler_, dummy, block_stride);
111+
}
112+
}
113+
};
114+
115+
/***************************************************************
116+
* Factory specialization for this SIMD level.
117+
*
118+
* Combinatorial dispatch: is_max × with_id_map × handler type
119+
* k == 1: SingleResultHandler
120+
* impl even: HeapHandler
121+
* impl odd: ReservoirHandler (capacity = 2*k)
122+
***************************************************************/
123+
124+
template <>
125+
std::unique_ptr<FastScanCodeScanner> make_fast_scan_scanner_impl<
126+
THE_LEVEL_TO_DISPATCH>(
127+
bool is_max,
128+
int impl,
129+
size_t nq,
130+
size_t ntotal,
131+
int64_t k,
132+
float* distances,
133+
int64_t* ids,
134+
const IDSelector* sel,
135+
bool with_id_map) {
136+
// Helper lambda: given comparator C and with_id_map W, select handler
137+
auto make = [&]<class C, bool W>() -> std::unique_ptr<FastScanCodeScanner> {
138+
if (k == 1) {
139+
using H = SingleResultHandler<C, W>;
140+
return std::make_unique<ScannerMixIn<H>>(
141+
nq, ntotal, distances, ids, sel);
142+
} else if (impl % 2 == 0) {
143+
using H = HeapHandler<C, W>;
144+
return std::make_unique<ScannerMixIn<H>>(
145+
nq, ntotal, k, distances, ids, sel);
146+
} else {
147+
using H = ReservoirHandler<C, W>;
148+
return std::make_unique<ScannerMixIn<H>>(
149+
nq, ntotal, size_t(k), size_t(2 * k), distances, ids, sel);
150+
}
151+
};
152+
153+
if (is_max) {
154+
if (with_id_map) {
155+
return make.template operator()<CMax<uint16_t, int64_t>, true>();
156+
} else {
157+
return make.template operator()<CMax<uint16_t, int>, false>();
158+
}
159+
} else {
160+
if (with_id_map) {
161+
return make.template operator()<CMin<uint16_t, int64_t>, true>();
162+
} else {
163+
return make.template operator()<CMin<uint16_t, int>, false>();
164+
}
165+
}
166+
}
167+
168+
} // namespace faiss

faiss/impl/fast_scan/impl-avx2.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#ifdef COMPILE_SIMD_AVX2
9+
10+
#define THE_LEVEL_TO_DISPATCH SIMDLevel::AVX2
11+
#include <faiss/impl/fast_scan/dispatching.h> // IWYU pragma: keep
12+
13+
#endif // COMPILE_SIMD_AVX2
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#ifdef COMPILE_SIMD_AVX512
9+
10+
#define THE_LEVEL_TO_DISPATCH SIMDLevel::AVX512
11+
#include <faiss/impl/fast_scan/dispatching.h> // IWYU pragma: keep
12+
13+
#endif // COMPILE_SIMD_AVX512

faiss/impl/fast_scan/impl-neon.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#ifdef COMPILE_SIMD_ARM_NEON
9+
10+
#define THE_LEVEL_TO_DISPATCH SIMDLevel::ARM_NEON
11+
#include <faiss/impl/fast_scan/dispatching.h> // IWYU pragma: keep
12+
13+
// ARM_SVE: forward to ARM_NEON implementation until a dedicated SVE
14+
// specialization is written (same pattern as scalar_quantizer/sq-neon.cpp).
15+
#ifdef COMPILE_SIMD_ARM_SVE
16+
17+
namespace faiss {
18+
19+
template <>
20+
std::unique_ptr<FastScanCodeScanner> make_fast_scan_scanner_impl<
21+
SIMDLevel::ARM_SVE>(
22+
bool is_max,
23+
int impl,
24+
size_t nq,
25+
size_t ntotal,
26+
int64_t k,
27+
float* distances,
28+
int64_t* ids,
29+
const IDSelector* sel,
30+
bool with_id_map) {
31+
return make_fast_scan_scanner_impl<SIMDLevel::ARM_NEON>(
32+
is_max, impl, nq, ntotal, k, distances, ids, sel, with_id_map);
33+
}
34+
35+
} // namespace faiss
36+
37+
#endif // COMPILE_SIMD_ARM_SVE
38+
39+
#endif // COMPILE_SIMD_ARM_NEON

0 commit comments

Comments
 (0)