Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@
# Architecture-specific: only include files for the current build architecture
# =============================================================================
set(FAISS_SIMD_AVX2_SRC
impl/fast_scan/impl-avx2.cpp
impl/pq_code_distance/pq_code_distance-avx2.cpp
impl/scalar_quantizer/sq-avx2.cpp
utils/simd_impl/distances_avx2.cpp
)
set(FAISS_SIMD_AVX512_SRC
impl/fast_scan/impl-avx512.cpp
impl/pq_code_distance/pq_code_distance-avx512.cpp
impl/scalar_quantizer/sq-avx512.cpp
utils/simd_impl/distances_avx512.cpp
)
set(FAISS_SIMD_NEON_SRC
impl/fast_scan/impl-neon.cpp
impl/scalar_quantizer/sq-neon.cpp
utils/simd_impl/distances_aarch64.cpp
)
Expand Down Expand Up @@ -117,7 +120,7 @@ set(FAISS_SRC
impl/kmeans1d.cpp
impl/lattice_Zn.cpp
impl/mapped_io.cpp
impl/fast_scan/pq4_fast_scan.cpp
impl/fast_scan/fast_scan.cpp
impl/fast_scan/pq4_fast_scan_search_1.cpp
impl/fast_scan/pq4_fast_scan_search_qbs.cpp
impl/residual_quantizer_encode_steps.cpp
Expand Down Expand Up @@ -262,10 +265,13 @@ set(FAISS_HEADERS
impl/kmeans1d.h
impl/lattice_Zn.h
impl/platform_macros.h
impl/fast_scan/pq4_fast_scan.h
impl/fast_scan/accumulate_loops.h
impl/fast_scan/dispatching.h
impl/fast_scan/fast_scan.h
impl/fast_scan/decompose_qbs.h
impl/fast_scan/kernels_simd256.h
impl/fast_scan/kernels_simd512.h
impl/fast_scan/rabitq_dispatching.h
impl/fast_scan/rabitq_result_handler.h
impl/residual_quantizer_encode_steps.h
impl/simd_dispatch.h
Expand Down
2 changes: 1 addition & 1 deletion faiss/IndexAdditiveQuantizerFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include <faiss/impl/LocalSearchQuantizer.h>
#include <faiss/impl/ResidualQuantizer.h>
#include <faiss/impl/fast_scan/FastScanDistancePostProcessing.h>
#include <faiss/impl/fast_scan/pq4_fast_scan.h>
#include <faiss/impl/fast_scan/fast_scan.h>
#include <faiss/utils/quantize_lut.h>
#include <faiss/utils/utils.h>

Expand Down
102 changes: 27 additions & 75 deletions faiss/IndexFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@
#include <faiss/impl/IDSelector.h>
#include <faiss/impl/RaBitQUtils.h>
#include <faiss/impl/fast_scan/FastScanDistancePostProcessing.h>
#include <faiss/impl/fast_scan/pq4_fast_scan.h>
#include <faiss/impl/fast_scan/fast_scan.h>
#include <faiss/impl/fast_scan/simd_result_handlers.h>
#include <faiss/utils/hamming.h>
#include <faiss/utils/quantize_lut.h>
#include <faiss/utils/utils.h>

namespace faiss {

using namespace simd_result_handlers;

inline size_t roundup(size_t a, size_t b) {
return (a + b - 1) / b * b;
}
Expand Down Expand Up @@ -211,43 +209,18 @@ void estimators_from_tables_generic(

} // anonymous namespace

// Default implementation of make_knn_handler with centralized fallback logic
SIMDResultHandlerToFloat* IndexFastScan::make_knn_handler(
std::unique_ptr<FastScanCodeScanner> IndexFastScan::make_knn_scanner(
bool is_max,
int impl,
idx_t n,
idx_t k,
size_t ntotal,
float* distances,
idx_t* labels,
const IDSelector* sel,
int impl,
const FastScanDistancePostProcessing&) const {
// Create default handlers based on k and impl
if (is_max) {
using HeapHC = HeapHandler<CMax<uint16_t, int>, false>;
using ReservoirHC = ReservoirHandler<CMax<uint16_t, int>, false>;
using SingleResultHC = SingleResultHandler<CMax<uint16_t, int>, false>;

if (k == 1) {
return new SingleResultHC(n, ntotal, distances, labels, sel);
} else if (impl % 2 == 0) {
return new HeapHC(n, ntotal, k, distances, labels, sel);
} else {
return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel);
}
} else {
using HeapHC = HeapHandler<CMin<uint16_t, int>, false>;
using ReservoirHC = ReservoirHandler<CMin<uint16_t, int>, false>;
using SingleResultHC = SingleResultHandler<CMin<uint16_t, int>, false>;

if (k == 1) {
return new SingleResultHC(n, ntotal, distances, labels, sel);
} else if (impl % 2 == 0) {
return new HeapHC(n, ntotal, k, distances, labels, sel);
} else {
return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel);
}
}
return make_fast_scan_knn_scanner(
is_max, impl, n, ntotal, k, distances, labels, sel);
}

using namespace quantize_lut;
Expand Down Expand Up @@ -468,7 +441,6 @@ void IndexFastScan::search_implem_12(
idx_t* labels,
int impl,
const FastScanDistancePostProcessing& context) const {
using RH = ResultHandlerCompare<C, false>;
FAISS_THROW_IF_NOT(bbs == 32);

// handle qbs2 blocking by recursive call
Expand Down Expand Up @@ -519,36 +491,26 @@ void IndexFastScan::search_implem_12(
pq4_pack_LUT_qbs(qbs, M2, quantized_dis_tables.get(), LUT.get());
FAISS_THROW_IF_NOT(LUT_nq == n);

std::unique_ptr<RH> handler(
static_cast<RH*>(make_knn_handler(
C::is_max,
impl,
n,
k,
ntotal,
distances,
labels,
nullptr,
context)));

handler->disable = bool(skip & 2);
handler->normalizers = normalizers.get();

if (skip & 4) {
// pass
} else {
pq4_accumulate_loop_qbs(
auto scanner = make_knn_scanner(
C::is_max, n, k, ntotal, distances, labels, nullptr, impl, context);
auto* rh = scanner->handler();
rh->normalizers = normalizers.get();
// Note: skip & 2 previously set handler->disable (run kernel,
// discard results). Through the scanner path, skip & 2 now skips
// the kernel entirely (same as skip & 4), since disable is not
// accessible through the SIMDResultHandlerToFloat* interface.
if (!(skip & (2 | 4))) {
scanner->accumulate_loop_qbs(
qbs,
ntotal2,
M2,
codes.get(),
LUT.get(),
*handler.get(),
context.pq2x4_scale,
get_block_stride());
}
if (!(skip & 8)) {
handler->end();
rh->end();
}
}

Expand All @@ -563,7 +525,6 @@ void IndexFastScan::search_implem_14(
idx_t* labels,
int impl,
const FastScanDistancePostProcessing& context) const {
using RH = ResultHandlerCompare<C, false>;
FAISS_THROW_IF_NOT(bbs % 32 == 0);

int qbs2 = qbs == 0 ? 4 : qbs;
Expand Down Expand Up @@ -603,36 +564,27 @@ void IndexFastScan::search_implem_14(
AlignedTable<uint8_t> LUT(n * dim12);
pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get());

std::unique_ptr<RH> handler(
static_cast<RH*>(make_knn_handler(
C::is_max,
impl,
n,
k,
ntotal,
distances,
labels,
nullptr,
context)));
handler->disable = bool(skip & 2);
handler->normalizers = normalizers.get();

if (skip & 4) {
// pass
} else {
pq4_accumulate_loop(
auto scanner = make_knn_scanner(
C::is_max, n, k, ntotal, distances, labels, nullptr, impl, context);
auto* rh = scanner->handler();
rh->normalizers = normalizers.get();
// Note: skip & 2 previously set handler->disable (run kernel,
// discard results). Through the scanner path, skip & 2 now skips
// the kernel entirely (same as skip & 4), since disable is not
// accessible through the SIMDResultHandlerToFloat* interface.
if (!(skip & (2 | 4))) {
scanner->accumulate_loop(
n,
ntotal2,
bbs,
M2,
codes.get(),
LUT.get(),
*handler.get(),
context.pq2x4_scale,
get_block_stride());
}
if (!(skip & 8)) {
handler->end();
rh->end();
}
}

Expand Down
29 changes: 10 additions & 19 deletions faiss/IndexFastScan.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@

#pragma once

#include <memory>

#include <faiss/Index.h>
#include <faiss/impl/fast_scan/FastScanDistancePostProcessing.h>
#include <faiss/impl/fast_scan/fast_scan.h>
#include <faiss/utils/AlignedTable.h>

namespace faiss {

struct CodePacker;
struct IDSelector;
struct SIMDResultHandlerToFloat;

/** Fast scan version of IndexPQ and IndexAQ. Works for 4-bit PQ and AQ for now.
*
Expand Down Expand Up @@ -121,33 +123,22 @@ struct IndexFastScan : Index {
const float* x,
const FastScanDistancePostProcessing& context) const = 0;

/** Create a KNN handler for this index type
*
* This method can be overridden by derived classes to provide
* specialized handlers (e.g., RaBitQHeapHandler for RaBitQ indexes).
* Base implementation creates standard handlers based on k and impl.
/** Create a SIMD-dispatched scanner for knn search.
*
* @param is_max whether to use CMax comparator (true) or CMin (false)
* @param impl implementation number
* @param n number of queries
* @param k number of neighbors to find
* @param ntotal total number of vectors in database
* @param distances output distances array
* @param labels output labels array
* @param sel optional ID selector
* @param context processing context for distance post-processing
* @return pointer to created handler (never returns nullptr)
* Returns a FastScanCodeScanner that bundles handler + accumulation
* kernel behind the SIMD dispatch boundary.
* The scanner's accumulate methods dispatch to the optimal SIMD level.
*/
virtual SIMDResultHandlerToFloat* make_knn_handler(
virtual std::unique_ptr<FastScanCodeScanner> make_knn_scanner(
bool is_max,
int impl,
idx_t n,
idx_t k,
size_t ntotal,
float* distances,
idx_t* labels,
const IDSelector* sel,
const FastScanDistancePostProcessing& context) const;
int impl = 0,
const FastScanDistancePostProcessing& context = {}) const;

// called by search function
void compute_quantized_LUT(
Expand Down
2 changes: 1 addition & 1 deletion faiss/IndexIVFAdditiveQuantizerFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/fast_scan/FastScanDistancePostProcessing.h>
#include <faiss/impl/fast_scan/pq4_fast_scan.h>
#include <faiss/impl/fast_scan/fast_scan.h>
#include <faiss/impl/simd_dispatch.h>
#include <faiss/invlists/BlockInvertedLists.h>
#include <faiss/utils/distances.h>
Expand Down
Loading
Loading