Skip to content

Commit 2fdaa78

Browse files
authored
Merge branch 'main' into export-D95911440
2 parents ce9cd6f + d8c1a97 commit 2fdaa78

42 files changed

Lines changed: 2166 additions & 1665 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

faiss/CMakeLists.txt

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -9,16 +9,19 @@
99
# Architecture-specific: only include files for the current build architecture
1010
# =============================================================================
1111
set(FAISS_SIMD_AVX2_SRC
12+
impl/fast_scan/impl-avx2.cpp
1213
impl/pq_code_distance/pq_code_distance-avx2.cpp
1314
impl/scalar_quantizer/sq-avx2.cpp
1415
utils/simd_impl/distances_avx2.cpp
1516
)
1617
set(FAISS_SIMD_AVX512_SRC
18+
impl/fast_scan/impl-avx512.cpp
1719
impl/pq_code_distance/pq_code_distance-avx512.cpp
1820
impl/scalar_quantizer/sq-avx512.cpp
1921
utils/simd_impl/distances_avx512.cpp
2022
)
2123
set(FAISS_SIMD_NEON_SRC
24+
impl/fast_scan/impl-neon.cpp
2225
impl/scalar_quantizer/sq-neon.cpp
2326
utils/simd_impl/distances_aarch64.cpp
2427
)
@@ -117,9 +120,7 @@ set(FAISS_SRC
117120
impl/kmeans1d.cpp
118121
impl/lattice_Zn.cpp
119122
impl/mapped_io.cpp
120-
impl/fast_scan/pq4_fast_scan.cpp
121-
impl/fast_scan/pq4_fast_scan_search_1.cpp
122-
impl/fast_scan/pq4_fast_scan_search_qbs.cpp
123+
impl/fast_scan/fast_scan.cpp
123124
impl/residual_quantizer_encode_steps.cpp
124125
impl/zerocopy_io.cpp
125126
impl/NNDescent.cpp
@@ -262,10 +263,13 @@ set(FAISS_HEADERS
262263
impl/kmeans1d.h
263264
impl/lattice_Zn.h
264265
impl/platform_macros.h
265-
impl/fast_scan/pq4_fast_scan.h
266+
impl/fast_scan/accumulate_loops.h
267+
impl/fast_scan/dispatching.h
268+
impl/fast_scan/fast_scan.h
266269
impl/fast_scan/decompose_qbs.h
267270
impl/fast_scan/kernels_simd256.h
268271
impl/fast_scan/kernels_simd512.h
272+
impl/fast_scan/rabitq_dispatching.h
269273
impl/fast_scan/rabitq_result_handler.h
270274
impl/residual_quantizer_encode_steps.h
271275
impl/simd_dispatch.h
@@ -285,7 +289,6 @@ set(FAISS_HEADERS
285289
utils/WorkerThread.h
286290
utils/distances.h
287291
utils/distances_dispatch.h
288-
utils/extra_distances-inl.h
289292
utils/extra_distances.h
290293
utils/fp16-fp16c.h
291294
utils/fp16-inl.h
@@ -327,6 +330,7 @@ set(FAISS_HEADERS
327330
utils/hamming_distance/avx512-inl.h
328331
utils/simd_impl/distances_autovec-inl.h
329332
utils/simd_impl/distances_simdlib256.h
333+
utils/simd_impl/IVFFlatScanner-inl.h
330334
utils/simd_impl/distances_sse-inl.h
331335
)
332336

faiss/IndexAdditiveQuantizerFastScan.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
#include <faiss/impl/LocalSearchQuantizer.h>
1515
#include <faiss/impl/ResidualQuantizer.h>
1616
#include <faiss/impl/fast_scan/FastScanDistancePostProcessing.h>
17-
#include <faiss/impl/fast_scan/pq4_fast_scan.h>
17+
#include <faiss/impl/fast_scan/fast_scan.h>
1818
#include <faiss/utils/quantize_lut.h>
1919
#include <faiss/utils/utils.h>
2020

faiss/IndexFastScan.cpp

Lines changed: 27 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -16,16 +16,14 @@
1616
#include <faiss/impl/IDSelector.h>
1717
#include <faiss/impl/RaBitQUtils.h>
1818
#include <faiss/impl/fast_scan/FastScanDistancePostProcessing.h>
19-
#include <faiss/impl/fast_scan/pq4_fast_scan.h>
19+
#include <faiss/impl/fast_scan/fast_scan.h>
2020
#include <faiss/impl/fast_scan/simd_result_handlers.h>
2121
#include <faiss/utils/hamming.h>
2222
#include <faiss/utils/quantize_lut.h>
2323
#include <faiss/utils/utils.h>
2424

2525
namespace faiss {
2626

27-
using namespace simd_result_handlers;
28-
2927
inline size_t roundup(size_t a, size_t b) {
3028
return (a + b - 1) / b * b;
3129
}
@@ -211,43 +209,18 @@ void estimators_from_tables_generic(
211209

212210
} // anonymous namespace
213211

214-
// Default implementation of make_knn_handler with centralized fallback logic
215-
SIMDResultHandlerToFloat* IndexFastScan::make_knn_handler(
212+
std::unique_ptr<FastScanCodeScanner> IndexFastScan::make_knn_scanner(
216213
bool is_max,
217-
int impl,
218214
idx_t n,
219215
idx_t k,
220216
size_t ntotal,
221217
float* distances,
222218
idx_t* labels,
223219
const IDSelector* sel,
220+
int impl,
224221
const FastScanDistancePostProcessing&) const {
225-
// Create default handlers based on k and impl
226-
if (is_max) {
227-
using HeapHC = HeapHandler<CMax<uint16_t, int>, false>;
228-
using ReservoirHC = ReservoirHandler<CMax<uint16_t, int>, false>;
229-
using SingleResultHC = SingleResultHandler<CMax<uint16_t, int>, false>;
230-
231-
if (k == 1) {
232-
return new SingleResultHC(n, ntotal, distances, labels, sel);
233-
} else if (impl % 2 == 0) {
234-
return new HeapHC(n, ntotal, k, distances, labels, sel);
235-
} else {
236-
return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel);
237-
}
238-
} else {
239-
using HeapHC = HeapHandler<CMin<uint16_t, int>, false>;
240-
using ReservoirHC = ReservoirHandler<CMin<uint16_t, int>, false>;
241-
using SingleResultHC = SingleResultHandler<CMin<uint16_t, int>, false>;
242-
243-
if (k == 1) {
244-
return new SingleResultHC(n, ntotal, distances, labels, sel);
245-
} else if (impl % 2 == 0) {
246-
return new HeapHC(n, ntotal, k, distances, labels, sel);
247-
} else {
248-
return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel);
249-
}
250-
}
222+
return make_fast_scan_knn_scanner(
223+
is_max, impl, n, ntotal, k, distances, labels, sel);
251224
}
252225

253226
using namespace quantize_lut;
@@ -468,7 +441,6 @@ void IndexFastScan::search_implem_12(
468441
idx_t* labels,
469442
int impl,
470443
const FastScanDistancePostProcessing& context) const {
471-
using RH = ResultHandlerCompare<C, false>;
472444
FAISS_THROW_IF_NOT(bbs == 32);
473445

474446
// handle qbs2 blocking by recursive call
@@ -519,36 +491,26 @@ void IndexFastScan::search_implem_12(
519491
pq4_pack_LUT_qbs(qbs, M2, quantized_dis_tables.get(), LUT.get());
520492
FAISS_THROW_IF_NOT(LUT_nq == n);
521493

522-
std::unique_ptr<RH> handler(
523-
static_cast<RH*>(make_knn_handler(
524-
C::is_max,
525-
impl,
526-
n,
527-
k,
528-
ntotal,
529-
distances,
530-
labels,
531-
nullptr,
532-
context)));
533-
534-
handler->disable = bool(skip & 2);
535-
handler->normalizers = normalizers.get();
536-
537-
if (skip & 4) {
538-
// pass
539-
} else {
540-
pq4_accumulate_loop_qbs(
494+
auto scanner = make_knn_scanner(
495+
C::is_max, n, k, ntotal, distances, labels, nullptr, impl, context);
496+
auto* rh = scanner->handler();
497+
rh->normalizers = normalizers.get();
498+
// Note: skip & 2 previously set handler->disable (run kernel,
499+
// discard results). Through the scanner path, skip & 2 now skips
500+
// the kernel entirely (same as skip & 4), since disable is not
501+
// accessible through the SIMDResultHandlerToFloat* interface.
502+
if (!(skip & (2 | 4))) {
503+
scanner->accumulate_loop_qbs(
541504
qbs,
542505
ntotal2,
543506
M2,
544507
codes.get(),
545508
LUT.get(),
546-
*handler.get(),
547509
context.pq2x4_scale,
548510
get_block_stride());
549511
}
550512
if (!(skip & 8)) {
551-
handler->end();
513+
rh->end();
552514
}
553515
}
554516

@@ -563,7 +525,6 @@ void IndexFastScan::search_implem_14(
563525
idx_t* labels,
564526
int impl,
565527
const FastScanDistancePostProcessing& context) const {
566-
using RH = ResultHandlerCompare<C, false>;
567528
FAISS_THROW_IF_NOT(bbs % 32 == 0);
568529

569530
int qbs2 = qbs == 0 ? 4 : qbs;
@@ -603,36 +564,27 @@ void IndexFastScan::search_implem_14(
603564
AlignedTable<uint8_t> LUT(n * dim12);
604565
pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get());
605566

606-
std::unique_ptr<RH> handler(
607-
static_cast<RH*>(make_knn_handler(
608-
C::is_max,
609-
impl,
610-
n,
611-
k,
612-
ntotal,
613-
distances,
614-
labels,
615-
nullptr,
616-
context)));
617-
handler->disable = bool(skip & 2);
618-
handler->normalizers = normalizers.get();
619-
620-
if (skip & 4) {
621-
// pass
622-
} else {
623-
pq4_accumulate_loop(
567+
auto scanner = make_knn_scanner(
568+
C::is_max, n, k, ntotal, distances, labels, nullptr, impl, context);
569+
auto* rh = scanner->handler();
570+
rh->normalizers = normalizers.get();
571+
// Note: skip & 2 previously set handler->disable (run kernel,
572+
// discard results). Through the scanner path, skip & 2 now skips
573+
// the kernel entirely (same as skip & 4), since disable is not
574+
// accessible through the SIMDResultHandlerToFloat* interface.
575+
if (!(skip & (2 | 4))) {
576+
scanner->accumulate_loop(
624577
n,
625578
ntotal2,
626579
bbs,
627580
M2,
628581
codes.get(),
629582
LUT.get(),
630-
*handler.get(),
631583
context.pq2x4_scale,
632584
get_block_stride());
633585
}
634586
if (!(skip & 8)) {
635-
handler->end();
587+
rh->end();
636588
}
637589
}
638590

faiss/IndexFastScan.h

Lines changed: 10 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -7,15 +7,17 @@
77

88
#pragma once
99

10+
#include <memory>
11+
1012
#include <faiss/Index.h>
1113
#include <faiss/impl/fast_scan/FastScanDistancePostProcessing.h>
14+
#include <faiss/impl/fast_scan/fast_scan.h>
1215
#include <faiss/utils/AlignedTable.h>
1316

1417
namespace faiss {
1518

1619
struct CodePacker;
1720
struct IDSelector;
18-
struct SIMDResultHandlerToFloat;
1921

2022
/** Fast scan version of IndexPQ and IndexAQ. Works for 4-bit PQ and AQ for now.
2123
*
@@ -121,33 +123,22 @@ struct IndexFastScan : Index {
121123
const float* x,
122124
const FastScanDistancePostProcessing& context) const = 0;
123125

124-
/** Create a KNN handler for this index type
125-
*
126-
* This method can be overridden by derived classes to provide
127-
* specialized handlers (e.g., RaBitQHeapHandler for RaBitQ indexes).
128-
* Base implementation creates standard handlers based on k and impl.
126+
/** Create a SIMD-dispatched scanner for knn search.
129127
*
130-
* @param is_max whether to use CMax comparator (true) or CMin (false)
131-
* @param impl implementation number
132-
* @param n number of queries
133-
* @param k number of neighbors to find
134-
* @param ntotal total number of vectors in database
135-
* @param distances output distances array
136-
* @param labels output labels array
137-
* @param sel optional ID selector
138-
* @param context processing context for distance post-processing
139-
* @return pointer to created handler (never returns nullptr)
128+
* Returns a FastScanCodeScanner that bundles handler + accumulation
129+
* kernel behind the SIMD dispatch boundary.
130+
* The scanner's accumulate methods dispatch to the optimal SIMD level.
140131
*/
141-
virtual SIMDResultHandlerToFloat* make_knn_handler(
132+
virtual std::unique_ptr<FastScanCodeScanner> make_knn_scanner(
142133
bool is_max,
143-
int impl,
144134
idx_t n,
145135
idx_t k,
146136
size_t ntotal,
147137
float* distances,
148138
idx_t* labels,
149139
const IDSelector* sel,
150-
const FastScanDistancePostProcessing& context) const;
140+
int impl = 0,
141+
const FastScanDistancePostProcessing& context = {}) const;
151142

152143
// called by search function
153144
void compute_quantized_LUT(

faiss/IndexIVFAdditiveQuantizerFastScan.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
#include <faiss/impl/AuxIndexStructures.h>
1616
#include <faiss/impl/FaissAssert.h>
1717
#include <faiss/impl/fast_scan/FastScanDistancePostProcessing.h>
18-
#include <faiss/impl/fast_scan/pq4_fast_scan.h>
18+
#include <faiss/impl/fast_scan/fast_scan.h>
1919
#include <faiss/impl/simd_dispatch.h>
2020
#include <faiss/invlists/BlockInvertedLists.h>
2121
#include <faiss/utils/distances.h>

0 commit comments

Comments
 (0)