Skip to content

Commit eddea2a

Browse files
algoriddle authored and
meta-codesync[bot] committed
Extract RaBitQ result handler to impl/fast_scan/ (#4895)
Summary: Pull Request resolved: #4895. Move IVFRaBitQHeapHandler from a nested class inside IndexIVFRaBitQFastScan to a standalone template in impl/fast_scan/rabitq_result_handler.h. Add an SL template parameter for future DD (dynamic dispatch) support. A using alias in the index class preserves source compatibility. Pure refactor — no behavior change.
Reviewed By: mdouze
Differential Revision: D95950482
fbshipit-source-id: 9e5d101e8392e5b0fc53c6f475d4ea46b69ba5ff
1 parent 9d6b2e7 commit eddea2a

4 files changed

Lines changed: 130 additions & 96 deletions

File tree

faiss/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -266,6 +266,7 @@ set(FAISS_HEADERS
266266
impl/fast_scan/decompose_qbs.h
267267
impl/fast_scan/kernels_simd256.h
268268
impl/fast_scan/kernels_simd512.h
269+
impl/fast_scan/rabitq_result_handler.h
269270
impl/residual_quantizer_encode_steps.h
270271
impl/simd_dispatch.h
271272
impl/fast_scan/simd_result_handlers.h

faiss/IndexIVFRaBitQFastScan.cpp

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -563,11 +563,11 @@ SIMDResultHandlerToFloat* IndexIVFRaBitQFastScan::make_knn_handler(
563563
}
564564

565565
/*********************************************************
566-
* IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler implementation
566+
* simd_result_handlers::IVFRaBitQHeapHandler implementation
567567
*********************************************************/
568568

569-
template <class C>
570-
IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::IVFRaBitQHeapHandler(
569+
template <class C, SIMDLevel SL>
570+
simd_result_handlers::IVFRaBitQHeapHandler<C, SL>::IVFRaBitQHeapHandler(
571571
const IndexIVFRaBitQFastScan* idx,
572572
size_t nq_val,
573573
size_t k_val,
@@ -601,8 +601,8 @@ IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::IVFRaBitQHeapHandler(
601601
}
602602
}
603603

604-
template <class C>
605-
void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(
604+
template <class C, SIMDLevel SL>
605+
void simd_result_handlers::IVFRaBitQHeapHandler<C, SL>::handle(
606606
size_t q,
607607
size_t b,
608608
simd16uint16 d0,
@@ -748,23 +748,23 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(
748748
rabitq_stats.n_multibit_evaluations += local_multibit_evaluations;
749749
}
750750

751-
template <class C>
752-
void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::set_list_context(
751+
template <class C, SIMDLevel SL>
752+
void simd_result_handlers::IVFRaBitQHeapHandler<C, SL>::set_list_context(
753753
size_t list_no,
754754
const std::vector<int>& probe_map) {
755755
current_list_no = list_no;
756756
probe_indices = probe_map;
757757
list_codes_ptr = index->invlists->get_codes(list_no);
758758
}
759759

760-
template <class C>
761-
void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::begin(
760+
template <class C, SIMDLevel SL>
761+
void simd_result_handlers::IVFRaBitQHeapHandler<C, SL>::begin(
762762
const float* norms) {
763763
this->normalizers = norms;
764764
}
765765

766-
template <class C>
767-
void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::end() {
766+
template <class C, SIMDLevel SL>
767+
void simd_result_handlers::IVFRaBitQHeapHandler<C, SL>::end() {
768768
#pragma omp parallel for
769769
for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
770770
float* heap_dis = heap_distances + q * k;
@@ -773,8 +773,8 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::end() {
773773
}
774774
}
775775

776-
template <class C>
777-
float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::
776+
template <class C, SIMDLevel SL>
777+
float simd_result_handlers::IVFRaBitQHeapHandler<C, SL>::
778778
compute_full_multibit_distance(
779779
size_t /*db_idx*/,
780780
size_t local_q,

faiss/IndexIVFRaBitQFastScan.h

Lines changed: 3 additions & 83 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
#include <faiss/impl/RaBitQStats.h>
1717
#include <faiss/impl/RaBitQUtils.h>
1818
#include <faiss/impl/RaBitQuantizer.h>
19-
#include <faiss/impl/fast_scan/simd_result_handlers.h>
19+
#include <faiss/impl/fast_scan/rabitq_result_handler.h>
2020
#include <faiss/utils/AlignedTable.h>
2121
#include <faiss/utils/Heap.h>
2222

@@ -167,89 +167,9 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
167167
const IDSelector* sel = nullptr,
168168
const IVFSearchParameters* params = nullptr) const override;
169169

170-
/** SIMD result handler for IndexIVFRaBitQFastScan that applies
171-
* RaBitQ-specific distance corrections during batch processing.
172-
*
173-
* This handler processes batches of 32 distance computations from SIMD
174-
* kernels, applies RaBitQ distance formula adjustments (factors and
175-
* normalizers), and immediately updates result heaps. This eliminates the
176-
* need for post-processing and provides significant performance benefits.
177-
*
178-
* Key optimizations:
179-
* - Direct heap integration with no intermediate result storage
180-
* - Batch-level computation of normalizers and query factors
181-
* - Specialized handling for both centered and non-centered quantization
182-
* modes
183-
* - Efficient inner product metric corrections
184-
* - Uses runtime boolean for multi-bit mode
185-
*
186-
* @tparam C Comparator type (CMin/CMax) for heap operations
187-
*/
170+
/// RaBitQ-specific result handler (defined in impl/fast_scan/)
188171
template <class C>
189-
struct IVFRaBitQHeapHandler
190-
: simd_result_handlers::ResultHandlerCompare<C, true> {
191-
const IndexIVFRaBitQFastScan* index;
192-
float* heap_distances; // [nq * k]
193-
int64_t* heap_labels; // [nq * k]
194-
const size_t nq, k;
195-
size_t current_list_no = 0;
196-
const uint8_t* list_codes_ptr = nullptr; // raw block data for list
197-
std::vector<int>
198-
probe_indices; // probe index for each query in current batch
199-
const FastScanDistancePostProcessing*
200-
context; // Processing context with query factors
201-
const bool is_multibit; // Whether to use multi-bit two-stage search
202-
size_t nup = 0; // Number of heap updates
203-
204-
// Cached block-layout constants (invariant for handler lifetime)
205-
const size_t storage_size;
206-
const size_t packed_block_size;
207-
const size_t full_block_size;
208-
std::unique_ptr<CodePacker> packer; // cached for unpack in hot path
209-
210-
// Use float-based comparator for heap operations
211-
using Cfloat = typename std::conditional<
212-
C::is_max,
213-
CMax<float, int64_t>,
214-
CMin<float, int64_t>>::type;
215-
216-
IVFRaBitQHeapHandler(
217-
const IndexIVFRaBitQFastScan* idx,
218-
size_t nq_val,
219-
size_t k_val,
220-
float* distances,
221-
int64_t* labels,
222-
const FastScanDistancePostProcessing* ctx = nullptr,
223-
bool multibit = false);
224-
225-
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1)
226-
override;
227-
228-
/// Override base class virtual method to receive context information
229-
void set_list_context(size_t list_no, const std::vector<int>& probe_map)
230-
override;
231-
232-
void begin(const float* norms) override;
233-
234-
void end() override;
235-
236-
size_t num_updates() override {
237-
return nup;
238-
}
239-
240-
private:
241-
/// Compute full multi-bit distance for a candidate vector (multi-bit
242-
/// only)
243-
/// @param db_idx Global database vector index
244-
/// @param local_q Batch-local query index (for probe_indices access)
245-
/// @param global_q Global query index (for storage indexing)
246-
/// @param local_offset Offset within the current inverted list
247-
float compute_full_multibit_distance(
248-
size_t /*db_idx*/,
249-
size_t local_q,
250-
size_t global_q,
251-
size_t local_offset) const;
252-
};
172+
using IVFRaBitQHeapHandler = simd_result_handlers::IVFRaBitQHeapHandler<C>;
253173
};
254174

255175
} // namespace faiss
faiss/impl/fast_scan/rabitq_result_handler.h

Lines changed: 113 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#pragma once
9+
10+
#include <memory>
11+
#include <vector>
12+
13+
#include <faiss/impl/CodePacker.h>
14+
#include <faiss/impl/fast_scan/FastScanDistancePostProcessing.h>
15+
#include <faiss/impl/fast_scan/simd_result_handlers.h>
16+
#include <faiss/utils/Heap.h>
17+
18+
namespace faiss {
19+
20+
// Forward declaration — full definition needed only in implementation
21+
struct IndexIVFRaBitQFastScan;
22+
23+
namespace simd_result_handlers {
24+
25+
/** SIMD result handler for IndexIVFRaBitQFastScan that applies
26+
* RaBitQ-specific distance corrections during batch processing.
27+
*
28+
* This handler processes batches of 32 distance computations from SIMD
29+
* kernels, applies RaBitQ distance formula adjustments (factors and
30+
* normalizers), and immediately updates result heaps. This eliminates the
31+
* need for post-processing and provides significant performance benefits.
32+
*
33+
* Key optimizations:
34+
* - Direct heap integration with no intermediate result storage
35+
* - Batch-level computation of normalizers and query factors
36+
* - Specialized handling for both centered and non-centered quantization
37+
* modes
38+
* - Efficient inner product metric corrections
39+
* - Uses runtime boolean for multi-bit mode
40+
*
41+
* @tparam C Comparator type (CMin/CMax) for heap operations
42+
* @tparam SL SIMD level for dynamic dispatch
43+
*/
44+
template <class C, SIMDLevel SL = SINGLE_SIMD_LEVEL_256>
45+
struct IVFRaBitQHeapHandler : ResultHandlerCompare<C, true, SL> {
46+
using RHC = ResultHandlerCompare<C, true, SL>;
47+
using typename RHC::simd16uint16;
48+
49+
const IndexIVFRaBitQFastScan* index;
50+
float* heap_distances; // [nq * k]
51+
int64_t* heap_labels; // [nq * k]
52+
const size_t nq, k;
53+
size_t current_list_no = 0;
54+
const uint8_t* list_codes_ptr = nullptr; // raw block data for list
55+
std::vector<int>
56+
probe_indices; // probe index for each query in current batch
57+
const FastScanDistancePostProcessing*
58+
context; // Processing context with query factors
59+
const bool is_multibit; // Whether to use multi-bit two-stage search
60+
size_t nup = 0; // Number of heap updates
61+
62+
// Cached block-layout constants (invariant for handler lifetime)
63+
const size_t storage_size;
64+
const size_t packed_block_size;
65+
const size_t full_block_size;
66+
std::unique_ptr<CodePacker> packer; // cached for unpack in hot path
67+
68+
// Use float-based comparator for heap operations
69+
using Cfloat = typename std::conditional<
70+
C::is_max,
71+
CMax<float, int64_t>,
72+
CMin<float, int64_t>>::type;
73+
74+
IVFRaBitQHeapHandler(
75+
const IndexIVFRaBitQFastScan* idx,
76+
size_t nq_val,
77+
size_t k_val,
78+
float* distances,
79+
int64_t* labels,
80+
const FastScanDistancePostProcessing* ctx = nullptr,
81+
bool multibit = false);
82+
83+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) override;
84+
85+
/// Override base class virtual method to receive context information
86+
void set_list_context(size_t list_no, const std::vector<int>& probe_map)
87+
override;
88+
89+
void begin(const float* norms) override;
90+
91+
void end() override;
92+
93+
size_t num_updates() override {
94+
return nup;
95+
}
96+
97+
private:
98+
/// Compute full multi-bit distance for a candidate vector (multi-bit
99+
/// only)
100+
/// @param db_idx Global database vector index
101+
/// @param local_q Batch-local query index (for probe_indices access)
102+
/// @param global_q Global query index (for storage indexing)
103+
/// @param local_offset Offset within the current inverted list
104+
float compute_full_multibit_distance(
105+
size_t /*db_idx*/,
106+
size_t local_q,
107+
size_t global_q,
108+
size_t local_offset) const;
109+
};
110+
111+
} // namespace simd_result_handlers
112+
113+
} // namespace faiss

0 commit comments

Comments
 (0)