Skip to content

Commit 3f127ee

Browse files
alibeklfc and meta-codesync[bot]
authored and committed
Thread qb/centered search params through FastScan LUT and handler (#5095)
Summary:
Pull Request resolved: #5095

D100399519 added IVFRaBitQSearchParameters support to the FastScan scanner but only patched the distance_to_code fallback path. The main search path (LUT construction and SIMD distance correction in handle()) still read qb/centered from the index, ignoring the search params override.

This diff completes the fix by:
1. Adding qb/centered fields to the FastScanDistancePostProcessing context
2. Threading them through compute_LUT → compute_residual_LUT
3. Reading them from the context in the handler's handle() method
4. Extracting them from IVFRaBitQSearchParameters in search_preassigned

Reviewed By: mnorris11

Differential Revision: D100674751

fbshipit-source-id: 72139ed1311f9422f8701933bf0464c411764d05
1 parent 61ed7dc commit 3f127ee

File tree

4 files changed

+90
-7
lines changed

4 files changed

+90
-7
lines changed

faiss/IndexIVFRaBitQFastScan.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,10 @@ void IndexIVFRaBitQFastScan::compute_residual_LUT(
276276
const float* residual,
277277
QueryFactorsData& query_factors,
278278
float* lut_out,
279+
uint8_t qb_param,
280+
bool centered_param,
279281
const float* original_query) const {
280-
FAISS_THROW_IF_NOT(qb > 0 && qb <= 8);
282+
FAISS_THROW_IF_NOT(qb_param > 0 && qb_param <= 8);
281283

282284
std::vector<float> rotated_q(d);
283285
std::vector<uint8_t> rotated_qq(d);
@@ -287,8 +289,8 @@ void IndexIVFRaBitQFastScan::compute_residual_LUT(
287289
residual,
288290
d,
289291
nullptr,
290-
qb,
291-
centered,
292+
qb_param,
293+
centered_param,
292294
metric_type,
293295
rotated_q,
294296
rotated_qq);
@@ -305,8 +307,8 @@ void IndexIVFRaBitQFastScan::compute_residual_LUT(
305307
query_factors.rotated_q = rotated_q;
306308
}
307309

308-
if (centered) {
309-
const float max_code_value = (1 << qb) - 1;
310+
if (centered_param) {
311+
const float max_code_value = (1 << qb_param) - 1;
310312

311313
for (size_t m = 0; m < M; m++) {
312314
const size_t dim_start = m * 4;
@@ -372,15 +374,24 @@ void IndexIVFRaBitQFastScan::search_preassigned(
372374
FAISS_THROW_IF_NOT_MSG(!stats, "stats not supported for this index");
373375

374376
size_t cur_nprobe = this->nprobe;
377+
uint8_t used_qb = qb;
378+
bool used_centered = centered;
375379
if (params) {
376380
FAISS_THROW_IF_NOT(params->max_codes == 0);
377381
cur_nprobe = params->nprobe;
382+
if (auto rparams =
383+
dynamic_cast<const IVFRaBitQSearchParameters*>(params)) {
384+
used_qb = rparams->qb;
385+
used_centered = rparams->centered;
386+
}
378387
}
379388

380389
std::vector<QueryFactorsData> query_factors_storage(n * cur_nprobe);
381390
FastScanDistancePostProcessing context;
382391
context.query_factors = query_factors_storage.data();
383392
context.nprobe = cur_nprobe;
393+
context.qb = used_qb;
394+
context.centered = used_centered;
384395

385396
const CoarseQuantized cq = {cur_nprobe, centroid_dis, assign};
386397
search_dispatch_implem(n, x, k, distances, labels, cq, context, params);
@@ -396,6 +407,10 @@ void IndexIVFRaBitQFastScan::compute_LUT(
396407
FAISS_THROW_IF_NOT(is_trained);
397408
FAISS_THROW_IF_NOT(by_residual);
398409

410+
// Use overridden qb/centered from context if provided, else index defaults
411+
const uint8_t used_qb = context.qb > 0 ? context.qb : qb;
412+
const bool used_centered = context.qb > 0 ? context.centered : centered;
413+
399414
size_t cq_nprobe = cq.nprobe;
400415

401416
size_t dim12 = 16 * M;
@@ -424,6 +439,8 @@ void IndexIVFRaBitQFastScan::compute_LUT(
424439
xij,
425440
query_factors_data,
426441
dis_tables.get() + ij * dim12,
442+
used_qb,
443+
used_centered,
427444
x + i * d);
428445

429446
// Store query factors using compact indexing (ij directly)
@@ -624,6 +641,8 @@ struct IVFRaBitQFastScanScanner : InvertedListScanner {
624641
context = FastScanDistancePostProcessing{};
625642
context.query_factors = &query_factors;
626643
context.nprobe = 1;
644+
context.qb = qb;
645+
context.centered = centered;
627646

628647
index.compute_LUT_uint8(
629648
1, xi, cq, dis_tables, biases, &normalizers[0], context);

faiss/IndexIVFRaBitQFastScan.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
119119
const float* residual,
120120
QueryFactorsData& query_factors,
121121
float* lut_out,
122+
uint8_t qb_param,
123+
bool centered_param,
122124
const float* original_query = nullptr) const;
123125

124126
/// Decode FastScan code to RaBitQ residual vector with explicit
@@ -265,8 +267,9 @@ void IVFRaBitQHeapHandler<C, SL>::handle(
265267
(idx_base / index->bbs) * full_block_size + packed_block_size;
266268

267269
// Cache index fields used in the inner loop.
268-
const bool centered = index->centered;
269-
const size_t qb = index->qb;
270+
// Use overridden qb/centered from context if provided, else index defaults.
271+
const bool centered = context->qb > 0 ? context->centered : index->centered;
272+
const size_t qb = context->qb > 0 ? context->qb : index->qb;
270273
const size_t d = index->d;
271274

272275
#ifndef NDEBUG

faiss/impl/fast_scan/FastScanDistancePostProcessing.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@ struct FastScanDistancePostProcessing {
3535
/// Set to 0 to use index->nprobe as fallback.
3636
size_t nprobe = 0;
3737

38+
/// RaBitQ query quantization bits override.
39+
/// Set to 0 to use the index default (index->qb).
40+
uint8_t qb = 0;
41+
42+
/// RaBitQ centered scalar quantizer override.
43+
/// Only used when qb > 0 (i.e., when params are overridden).
44+
bool centered = false;
45+
3846
/// Default constructor - no processing
3947
FastScanDistancePostProcessing() = default;
4048

tests/test_rabitq_fastscan.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,59 @@ def test_ivf_factory_with_batch_size(self):
10701070
self.assertEqual(D.shape, (ds.nq, 5))
10711071

10721072

1073+
class TestRaBitQFastScanSearchParams(unittest.TestCase):
1074+
"""Test that IVFRaBitQSearchParameters qb/centered are respected."""
1075+
1076+
def test_higher_qb_improves_recall(self):
1077+
"""Search with qb=4 should give better recall than qb=1."""
1078+
d = 64
1079+
nlist = 16
1080+
nprobe = 4
1081+
k = 10
1082+
ds = datasets.SyntheticDataset(d, 5000, 5000, 50)
1083+
1084+
# Ground truth with flat index
1085+
index_flat = faiss.IndexFlatL2(d)
1086+
index_flat.add(ds.get_database())
1087+
_, I_gt = index_flat.search(ds.get_queries(), k)
1088+
1089+
# Build IVF RaBitQ FastScan index with default qb=8
1090+
quantizer = faiss.IndexFlat(d, faiss.METRIC_L2)
1091+
index = faiss.IndexIVFRaBitQFastScan(
1092+
quantizer, d, nlist, faiss.METRIC_L2, 32, True
1093+
)
1094+
index.nprobe = nprobe
1095+
index.train(ds.get_train())
1096+
index.add(ds.get_database())
1097+
1098+
# Search with qb=1 (coarse quantization)
1099+
params_qb1 = faiss.IVFRaBitQSearchParameters()
1100+
params_qb1.nprobe = nprobe
1101+
params_qb1.qb = 1
1102+
_, I_qb1 = index.search(ds.get_queries(), k, params=params_qb1)
1103+
1104+
# Search with qb=4 (finer quantization)
1105+
params_qb4 = faiss.IVFRaBitQSearchParameters()
1106+
params_qb4.nprobe = nprobe
1107+
params_qb4.qb = 4
1108+
_, I_qb4 = index.search(ds.get_queries(), k, params=params_qb4)
1109+
1110+
# Compute recall@k
1111+
recall_qb1 = np.mean([
1112+
len(np.intersect1d(I_qb1[i], I_gt[i])) / k
1113+
for i in range(ds.nq)
1114+
])
1115+
recall_qb4 = np.mean([
1116+
len(np.intersect1d(I_qb4[i], I_gt[i])) / k
1117+
for i in range(ds.nq)
1118+
])
1119+
1120+
self.assertGreater(
1121+
recall_qb4, recall_qb1,
1122+
f"qb=4 recall ({recall_qb4:.3f}) should be higher "
1123+
f"than qb=1 recall ({recall_qb1:.3f})"
1124+
)
1125+
10731126

10741127
if __name__ == "__main__":
10751128
unittest.main()

0 commit comments

Comments
 (0)