diff --git a/faiss/IndexIVFRaBitQFastScan.cpp b/faiss/IndexIVFRaBitQFastScan.cpp index 113a034672..a6c1a3f400 100644 --- a/faiss/IndexIVFRaBitQFastScan.cpp +++ b/faiss/IndexIVFRaBitQFastScan.cpp @@ -276,8 +276,10 @@ void IndexIVFRaBitQFastScan::compute_residual_LUT( const float* residual, QueryFactorsData& query_factors, float* lut_out, + uint8_t qb_param, + bool centered_param, const float* original_query) const { - FAISS_THROW_IF_NOT(qb > 0 && qb <= 8); + FAISS_THROW_IF_NOT(qb_param > 0 && qb_param <= 8); std::vector rotated_q(d); std::vector rotated_qq(d); @@ -287,8 +289,8 @@ void IndexIVFRaBitQFastScan::compute_residual_LUT( residual, d, nullptr, - qb, - centered, + qb_param, + centered_param, metric_type, rotated_q, rotated_qq); @@ -305,8 +307,8 @@ void IndexIVFRaBitQFastScan::compute_residual_LUT( query_factors.rotated_q = rotated_q; } - if (centered) { - const float max_code_value = (1 << qb) - 1; + if (centered_param) { + const float max_code_value = (1 << qb_param) - 1; for (size_t m = 0; m < M; m++) { const size_t dim_start = m * 4; @@ -372,15 +374,24 @@ void IndexIVFRaBitQFastScan::search_preassigned( FAISS_THROW_IF_NOT_MSG(!stats, "stats not supported for this index"); size_t cur_nprobe = this->nprobe; + uint8_t used_qb = qb; + bool used_centered = centered; if (params) { FAISS_THROW_IF_NOT(params->max_codes == 0); cur_nprobe = params->nprobe; + if (auto rparams = + dynamic_cast(params)) { + used_qb = rparams->qb; + used_centered = rparams->centered; + } } std::vector query_factors_storage(n * cur_nprobe); FastScanDistancePostProcessing context; context.query_factors = query_factors_storage.data(); context.nprobe = cur_nprobe; + context.qb = used_qb; + context.centered = used_centered; const CoarseQuantized cq = {cur_nprobe, centroid_dis, assign}; search_dispatch_implem(n, x, k, distances, labels, cq, context, params); @@ -396,6 +407,10 @@ void IndexIVFRaBitQFastScan::compute_LUT( FAISS_THROW_IF_NOT(is_trained); FAISS_THROW_IF_NOT(by_residual); + // Use 
overridden qb/centered from context if provided, else index defaults + const uint8_t used_qb = context.qb > 0 ? context.qb : qb; + const bool used_centered = context.qb > 0 ? context.centered : centered; + size_t cq_nprobe = cq.nprobe; size_t dim12 = 16 * M; @@ -424,6 +439,8 @@ void IndexIVFRaBitQFastScan::compute_LUT( xij, query_factors_data, dis_tables.get() + ij * dim12, + used_qb, + used_centered, x + i * d); // Store query factors using compact indexing (ij directly) @@ -624,6 +641,8 @@ struct IVFRaBitQFastScanScanner : InvertedListScanner { context = FastScanDistancePostProcessing{}; context.query_factors = &query_factors; context.nprobe = 1; + context.qb = qb; + context.centered = centered; index.compute_LUT_uint8( 1, xi, cq, dis_tables, biases, &normalizers[0], context); diff --git a/faiss/IndexIVFRaBitQFastScan.h b/faiss/IndexIVFRaBitQFastScan.h index 0e19f53a31..97db8de512 100644 --- a/faiss/IndexIVFRaBitQFastScan.h +++ b/faiss/IndexIVFRaBitQFastScan.h @@ -119,6 +119,8 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan { const float* residual, QueryFactorsData& query_factors, float* lut_out, + uint8_t qb_param, + bool centered_param, const float* original_query = nullptr) const; /// Decode FastScan code to RaBitQ residual vector with explicit @@ -204,8 +206,7 @@ IVFRaBitQHeapHandler::IVFRaBitQHeapHandler( storage_size(idx->compute_per_vector_storage_size()), packed_block_size(((idx->M2 + 1) / 2) * idx->bbs), full_block_size(idx->get_block_stride()), - packer(idx->get_CodePacker()), - unpack_buf(idx->code_size) { + unpack_buf((idx->d + 7) / 8) { current_list_no = 0; probe_indices.clear(); for (int64_t q = 0; q < static_cast(nq); q++) { @@ -265,8 +266,9 @@ void IVFRaBitQHeapHandler::handle( (idx_base / index->bbs) * full_block_size + packed_block_size; // Cache index fields used in the inner loop. - const bool centered = index->centered; - const size_t qb = index->qb; + // Use overridden qb/centered from context if provided, else index defaults. 
+ const bool centered = context->qb > 0 ? context->centered : index->centered; + const size_t qb = context->qb > 0 ? context->qb : index->qb; const size_t d = index->d; #ifndef NDEBUG @@ -394,8 +396,13 @@ float IVFRaBitQHeapHandler::compute_full_multibit_distance( const size_t storage_idx_val = global_q * cached_nprobe + probe_rank; const auto& query_factors = context->query_factors[storage_idx_val]; - // Unpack PQ4-interleaved sign bits for this vector into a linear buffer. - packer->unpack_1(this->list_codes_ptr, local_offset, unpack_buf.data()); + rabitq_utils::unpack_sign_bits_from_packed( + this->list_codes_ptr, + index->bbs, + index->M2, + local_offset, + full_block_size, + unpack_buf.data()); return rabitq_utils::compute_full_multibit_distance( unpack_buf.data(), diff --git a/faiss/IndexRaBitQFastScan.h b/faiss/IndexRaBitQFastScan.h index af588c87c4..7f318a8e9d 100644 --- a/faiss/IndexRaBitQFastScan.h +++ b/faiss/IndexRaBitQFastScan.h @@ -151,10 +151,7 @@ struct RaBitQHeapHandler const size_t storage_size; const size_t packed_block_size; const size_t full_block_size; - std::unique_ptr packer; // cached for unpack in hot path - // Handler-local scratch reused across refinements. This assumes a handler - // instance is confined to one search slice and not entered concurrently. 
- std::vector unpack_buf; // reusable buffer for unpack_1 + std::vector unpack_buf; // sign-bit scratch reused across refinements; assumes the handler is not entered concurrently // Use float-based comparator for heap operations using Cfloat = typename std::conditional< @@ -182,8 +179,7 @@ struct RaBitQHeapHandler storage_size(index->compute_per_vector_storage_size()), packed_block_size(((index->M2 + 1) / 2) * index->bbs), full_block_size(index->get_block_stride()), - packer(index->get_CodePacker()), - unpack_buf(index->code_size) { + unpack_buf((index->d + 7) / 8) { #pragma omp parallel for if (nq > 100) for (int64_t q = 0; q < static_cast(nq); q++) { float* heap_dis = heap_distances + q * k; @@ -331,9 +327,13 @@ struct RaBitQHeapHandler const rabitq_utils::QueryFactorsData& query_factors = context->query_factors[q]; - // Reuse pre-allocated unpack_buf to avoid per-refinement heap - // allocation. - packer->unpack_1(rabitq_index->codes.get(), db_idx, unpack_buf.data()); + rabitq_utils::unpack_sign_bits_from_packed( + rabitq_index->codes.get(), + rabitq_index->bbs, + rabitq_index->M2, + db_idx, + full_block_size, + unpack_buf.data()); const uint8_t* sign_bits = unpack_buf.data(); return rabitq_utils::compute_full_multibit_distance( diff --git a/faiss/impl/RaBitQUtils.h b/faiss/impl/RaBitQUtils.h index 1eb2ebf04c..e3cc4b9d39 100644 --- a/faiss/impl/RaBitQUtils.h +++ b/faiss/impl/RaBitQUtils.h @@ -380,6 +380,41 @@ inline T* get_block_aux_ptr( (vec_pos % bbs) * storage_size; } +/// Extract one vector's sign bits from PQ4-interleaved blocks into flat byte packing: output byte k combines the two 4-bit codes of sub-quantizer pair k (low nibble first), an odd nsq adds a final byte holding a single low nibble, so (nsq + 1) / 2 bytes of sign_bits_out are written. +/// offset is the vector index relative to block: offset / bbs selects a block of block_stride bytes and offset % bbs the slot inside it. Like CodePackerRaBitQ::unpack_1 but sign-bits-only and with the +/// vector's in-block address hoisted out of the per-SQ loop. NOTE(review): the slot address only distinguishes offset & 15 and offset > 15, so a block is presumably expected to hold a single 32-vector group (bbs == 32) - confirm against the packer. +inline void unpack_sign_bits_from_packed( + const uint8_t* block, + size_t bbs, + size_t nsq, + size_t offset, + size_t block_stride, + uint8_t* sign_bits_out) { + block += (offset / bbs) * block_stride; + offset = offset % bbs; + + const bool nibble_high = offset > 15; + const size_t vid = offset & 15; + const size_t in_group_addr = + (vid < 8) ? 
(vid << 1) : (((vid - 8) << 1) + 1); + + const size_t num_pairs = nsq / 2; + for (size_t k = 0; k < num_pairs; k++) { + const size_t base = k * bbs; + const uint8_t raw_even = block[base + in_group_addr]; + const uint8_t raw_odd = block[base + in_group_addr + 16]; + + const uint8_t nib0 = nibble_high ? (raw_even >> 4) : (raw_even & 0xF); + const uint8_t nib1 = nibble_high ? (raw_odd >> 4) : (raw_odd & 0xF); + sign_bits_out[k] = nib0 | (nib1 << 4); + } + + if (nsq & 1) { + const uint8_t raw = block[num_pairs * bbs + in_group_addr]; + sign_bits_out[num_pairs] = nibble_high ? (raw >> 4) : (raw & 0xF); + } +} + /** Compute per-vector auxiliary storage size. * * @param nb_bits number of quantization bits (1 = sign-bit only) diff --git a/faiss/impl/fast_scan/FastScanDistancePostProcessing.h b/faiss/impl/fast_scan/FastScanDistancePostProcessing.h index 9a09a2165c..72b5b27f6a 100644 --- a/faiss/impl/fast_scan/FastScanDistancePostProcessing.h +++ b/faiss/impl/fast_scan/FastScanDistancePostProcessing.h @@ -35,6 +35,14 @@ struct FastScanDistancePostProcessing { /// Set to 0 to use index->nprobe as fallback. size_t nprobe = 0; + /// RaBitQ query quantization bits override. + /// Set to 0 to use the index default (index->qb). + uint8_t qb = 0; + + /// RaBitQ centered scalar quantizer override. + /// Only used when qb > 0 (i.e., when params are overridden). 
+ bool centered = false; + /// Default constructor - no processing FastScanDistancePostProcessing() = default; diff --git a/faiss/impl/fast_scan/rabitq_result_handler.h b/faiss/impl/fast_scan/rabitq_result_handler.h index 7ea9149680..3cd29163a4 100644 --- a/faiss/impl/fast_scan/rabitq_result_handler.h +++ b/faiss/impl/fast_scan/rabitq_result_handler.h @@ -10,7 +10,6 @@ #include #include -#include #include #include #include @@ -58,10 +57,7 @@ struct IVFRaBitQHeapHandler : ResultHandlerCompare { const size_t storage_size; const size_t packed_block_size; const size_t full_block_size; - std::unique_ptr packer; // cached for unpack in hot path - // Handler-local scratch reused across refinements. This assumes a handler - // instance is confined to one search slice and not entered concurrently. - std::vector unpack_buf; // reusable buffer for unpack_1 + std::vector unpack_buf; // sign-bit scratch reused across refinements; assumes the handler is not entered concurrently // Cached per-list values (set in set_list_context, avoid recomputing in // handle) diff --git a/tests/test_rabitq_fastscan.py b/tests/test_rabitq_fastscan.py index 14645094a5..347dfa0abd 100644 --- a/tests/test_rabitq_fastscan.py +++ b/tests/test_rabitq_fastscan.py @@ -1070,6 +1070,59 @@ def test_ivf_factory_with_batch_size(self): self.assertEqual(D.shape, (ds.nq, 5)) +class TestRaBitQFastScanSearchParams(unittest.TestCase): + """Test that the qb override in IVFRaBitQSearchParameters is respected (centered is never set here, so it is not exercised).""" + + def test_higher_qb_improves_recall(self): + """Search with qb=4 should give better recall than qb=1.""" + d = 64 + nlist = 16 + nprobe = 4 + k = 10 + ds = datasets.SyntheticDataset(d, 5000, 5000, 50) + + # Ground truth with flat index + index_flat = faiss.IndexFlatL2(d) + index_flat.add(ds.get_database()) + _, I_gt = index_flat.search(ds.get_queries(), k) + + # Build IVF RaBitQ FastScan index with default qb=8 + quantizer = faiss.IndexFlat(d, faiss.METRIC_L2) + index = faiss.IndexIVFRaBitQFastScan( + quantizer, d, nlist, faiss.METRIC_L2, 32, True + ) + index.nprobe = nprobe + 
index.train(ds.get_train()) + index.add(ds.get_database()) + + # Search with qb=1 (coarse quantization) + params_qb1 = faiss.IVFRaBitQSearchParameters() + params_qb1.nprobe = nprobe + params_qb1.qb = 1 + _, I_qb1 = index.search(ds.get_queries(), k, params=params_qb1) + + # Search with qb=4 (finer quantization) + params_qb4 = faiss.IVFRaBitQSearchParameters() + params_qb4.nprobe = nprobe + params_qb4.qb = 4 + _, I_qb4 = index.search(ds.get_queries(), k, params=params_qb4) + + # Compute recall@k + recall_qb1 = np.mean([ + len(np.intersect1d(I_qb1[i], I_gt[i])) / k + for i in range(ds.nq) + ]) + recall_qb4 = np.mean([ + len(np.intersect1d(I_qb4[i], I_gt[i])) / k + for i in range(ds.nq) + ]) + + self.assertGreater( + recall_qb4, recall_qb1, + f"qb=4 recall ({recall_qb4:.3f}) should be higher " + f"than qb=1 recall ({recall_qb1:.3f})" + ) + if __name__ == "__main__": unittest.main()