Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions faiss/IndexIVFRaBitQFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,10 @@ void IndexIVFRaBitQFastScan::compute_residual_LUT(
const float* residual,
QueryFactorsData& query_factors,
float* lut_out,
uint8_t qb_param,
bool centered_param,
const float* original_query) const {
FAISS_THROW_IF_NOT(qb > 0 && qb <= 8);
FAISS_THROW_IF_NOT(qb_param > 0 && qb_param <= 8);

std::vector<float> rotated_q(d);
std::vector<uint8_t> rotated_qq(d);
Expand All @@ -287,8 +289,8 @@ void IndexIVFRaBitQFastScan::compute_residual_LUT(
residual,
d,
nullptr,
qb,
centered,
qb_param,
centered_param,
metric_type,
rotated_q,
rotated_qq);
Expand All @@ -305,8 +307,8 @@ void IndexIVFRaBitQFastScan::compute_residual_LUT(
query_factors.rotated_q = rotated_q;
}

if (centered) {
const float max_code_value = (1 << qb) - 1;
if (centered_param) {
const float max_code_value = (1 << qb_param) - 1;

for (size_t m = 0; m < M; m++) {
const size_t dim_start = m * 4;
Expand Down Expand Up @@ -372,15 +374,24 @@ void IndexIVFRaBitQFastScan::search_preassigned(
FAISS_THROW_IF_NOT_MSG(!stats, "stats not supported for this index");

size_t cur_nprobe = this->nprobe;
uint8_t used_qb = qb;
bool used_centered = centered;
if (params) {
FAISS_THROW_IF_NOT(params->max_codes == 0);
cur_nprobe = params->nprobe;
if (auto rparams =
dynamic_cast<const IVFRaBitQSearchParameters*>(params)) {
used_qb = rparams->qb;
used_centered = rparams->centered;
}
}

std::vector<QueryFactorsData> query_factors_storage(n * cur_nprobe);
FastScanDistancePostProcessing context;
context.query_factors = query_factors_storage.data();
context.nprobe = cur_nprobe;
context.qb = used_qb;
context.centered = used_centered;

const CoarseQuantized cq = {cur_nprobe, centroid_dis, assign};
search_dispatch_implem(n, x, k, distances, labels, cq, context, params);
Expand All @@ -396,6 +407,10 @@ void IndexIVFRaBitQFastScan::compute_LUT(
FAISS_THROW_IF_NOT(is_trained);
FAISS_THROW_IF_NOT(by_residual);

// Use overridden qb/centered from context if provided, else index defaults
const uint8_t used_qb = context.qb > 0 ? context.qb : qb;
const bool used_centered = context.qb > 0 ? context.centered : centered;

size_t cq_nprobe = cq.nprobe;

size_t dim12 = 16 * M;
Expand Down Expand Up @@ -424,6 +439,8 @@ void IndexIVFRaBitQFastScan::compute_LUT(
xij,
query_factors_data,
dis_tables.get() + ij * dim12,
used_qb,
used_centered,
x + i * d);

// Store query factors using compact indexing (ij directly)
Expand Down Expand Up @@ -624,6 +641,8 @@ struct IVFRaBitQFastScanScanner : InvertedListScanner {
context = FastScanDistancePostProcessing{};
context.query_factors = &query_factors;
context.nprobe = 1;
context.qb = qb;
context.centered = centered;

index.compute_LUT_uint8(
1, xi, cq, dis_tables, biases, &normalizers[0], context);
Expand Down
19 changes: 13 additions & 6 deletions faiss/IndexIVFRaBitQFastScan.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
const float* residual,
QueryFactorsData& query_factors,
float* lut_out,
uint8_t qb_param,
bool centered_param,
const float* original_query = nullptr) const;

/// Decode FastScan code to RaBitQ residual vector with explicit
Expand Down Expand Up @@ -204,8 +206,7 @@ IVFRaBitQHeapHandler<C, SL>::IVFRaBitQHeapHandler(
storage_size(idx->compute_per_vector_storage_size()),
packed_block_size(((idx->M2 + 1) / 2) * idx->bbs),
full_block_size(idx->get_block_stride()),
packer(idx->get_CodePacker()),
unpack_buf(idx->code_size) {
unpack_buf((idx->d + 7) / 8) {
current_list_no = 0;
probe_indices.clear();
for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
Expand Down Expand Up @@ -265,8 +266,9 @@ void IVFRaBitQHeapHandler<C, SL>::handle(
(idx_base / index->bbs) * full_block_size + packed_block_size;

// Cache index fields used in the inner loop.
const bool centered = index->centered;
const size_t qb = index->qb;
// Use overridden qb/centered from context if provided, else index defaults.
const bool centered = context->qb > 0 ? context->centered : index->centered;
const size_t qb = context->qb > 0 ? context->qb : index->qb;
const size_t d = index->d;

#ifndef NDEBUG
Expand Down Expand Up @@ -394,8 +396,13 @@ float IVFRaBitQHeapHandler<C, SL>::compute_full_multibit_distance(
const size_t storage_idx_val = global_q * cached_nprobe + probe_rank;
const auto& query_factors = context->query_factors[storage_idx_val];

// Unpack PQ4-interleaved sign bits for this vector into a linear buffer.
packer->unpack_1(this->list_codes_ptr, local_offset, unpack_buf.data());
rabitq_utils::unpack_sign_bits_from_packed(
this->list_codes_ptr,
index->bbs,
index->M2,
local_offset,
full_block_size,
unpack_buf.data());

return rabitq_utils::compute_full_multibit_distance(
unpack_buf.data(),
Expand Down
18 changes: 9 additions & 9 deletions faiss/IndexRaBitQFastScan.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,7 @@ struct RaBitQHeapHandler
const size_t storage_size;
const size_t packed_block_size;
const size_t full_block_size;
std::unique_ptr<CodePacker> packer; // cached for unpack in hot path
// Handler-local scratch reused across refinements. This assumes a handler
// instance is confined to one search slice and not entered concurrently.
std::vector<uint8_t> unpack_buf; // reusable buffer for unpack_1
std::vector<uint8_t> unpack_buf; // sign bits scratch buffer

// Use float-based comparator for heap operations
using Cfloat = typename std::conditional<
Expand Down Expand Up @@ -182,8 +179,7 @@ struct RaBitQHeapHandler
storage_size(index->compute_per_vector_storage_size()),
packed_block_size(((index->M2 + 1) / 2) * index->bbs),
full_block_size(index->get_block_stride()),
packer(index->get_CodePacker()),
unpack_buf(index->code_size) {
unpack_buf((index->d + 7) / 8) {
#pragma omp parallel for if (nq > 100)
for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
float* heap_dis = heap_distances + q * k;
Expand Down Expand Up @@ -331,9 +327,13 @@ struct RaBitQHeapHandler
const rabitq_utils::QueryFactorsData& query_factors =
context->query_factors[q];

// Reuse pre-allocated unpack_buf to avoid per-refinement heap
// allocation.
packer->unpack_1(rabitq_index->codes.get(), db_idx, unpack_buf.data());
rabitq_utils::unpack_sign_bits_from_packed(
rabitq_index->codes.get(),
rabitq_index->bbs,
rabitq_index->M2,
db_idx,
full_block_size,
unpack_buf.data());
const uint8_t* sign_bits = unpack_buf.data();

return rabitq_utils::compute_full_multibit_distance(
Expand Down
35 changes: 35 additions & 0 deletions faiss/impl/RaBitQUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,41 @@ inline T* get_block_aux_ptr(
(vec_pos % bbs) * storage_size;
}

/// Gather one vector's 4-bit sub-quantizer codes out of a PQ4-interleaved
/// block and write them as flat bytes (two codes per byte, even-indexed
/// sub-quantizer in the low nibble).
///
/// Intended as a sign-bits-only counterpart of CodePackerRaBitQ::unpack_1:
/// the vector's position inside the block is resolved once, before the
/// loop over sub-quantizer pairs.
///
/// @param block          start of the packed code area
/// @param bbs            number of vectors per block; also the byte distance
///                       between consecutive sub-quantizer pairs
/// @param nsq            number of 4-bit sub-quantizers to extract
/// @param offset         index of the vector, counted from the first block
/// @param block_stride   byte distance between consecutive blocks
/// @param sign_bits_out  receives (nsq + 1) / 2 bytes; when nsq is odd the
///                       last byte carries a single code in its low nibble
///
/// NOTE(review): the nibble/address derivation below distinguishes only 32
/// in-block positions (16 byte addresses x 2 nibbles). Presumably this
/// matches the packer's layout -- confirm for bbs > 32.
inline void unpack_sign_bits_from_packed(
        const uint8_t* block,
        size_t bbs,
        size_t nsq,
        size_t offset,
        size_t block_stride,
        uint8_t* sign_bits_out) {
    // Locate the block holding the vector and its position inside it.
    const uint8_t* const src = block + (offset / bbs) * block_stride;
    const size_t pos = offset % bbs;

    // Positions above 15 live in the high nibble of each byte.
    const unsigned shift = (pos > 15) ? 4u : 0u;

    // Lanes 0..7 map to even byte addresses, lanes 8..15 to odd ones.
    const size_t lane = pos & 15;
    const size_t addr = (lane < 8) ? (lane * 2) : ((lane - 8) * 2 + 1);

    const auto nibble_at = [src, shift](size_t idx) -> uint8_t {
        return static_cast<uint8_t>((src[idx] >> shift) & 0xF);
    };

    // Each pair of sub-quantizers occupies bbs bytes; the odd member of the
    // pair sits 16 bytes after the even one.
    const size_t full_pairs = nsq / 2;
    for (size_t pair = 0; pair < full_pairs; ++pair) {
        const size_t even_idx = pair * bbs + addr;
        const uint8_t lo = nibble_at(even_idx);
        const uint8_t hi = nibble_at(even_idx + 16);
        sign_bits_out[pair] = static_cast<uint8_t>(lo | (hi << 4));
    }

    // A trailing unpaired sub-quantizer only fills the low nibble.
    if (nsq % 2 != 0) {
        sign_bits_out[full_pairs] = nibble_at(full_pairs * bbs + addr);
    }
}

/** Compute per-vector auxiliary storage size.
*
* @param nb_bits number of quantization bits (1 = sign-bit only)
Expand Down
8 changes: 8 additions & 0 deletions faiss/impl/fast_scan/FastScanDistancePostProcessing.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ struct FastScanDistancePostProcessing {
/// Set to 0 to use index->nprobe as fallback.
size_t nprobe = 0;

/// RaBitQ query quantization bits override.
/// Set to 0 to use the index default (index->qb).
uint8_t qb = 0;

/// RaBitQ centered scalar quantizer override.
/// Only used when qb > 0 (i.e., when params are overridden).
bool centered = false;

/// Default constructor - no processing
FastScanDistancePostProcessing() = default;

Expand Down
6 changes: 1 addition & 5 deletions faiss/impl/fast_scan/rabitq_result_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <memory>
#include <vector>

#include <faiss/impl/CodePacker.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/RaBitQStats.h>
#include <faiss/impl/RaBitQUtils.h>
Expand Down Expand Up @@ -58,10 +57,7 @@ struct IVFRaBitQHeapHandler : ResultHandlerCompare<C, true, SL> {
const size_t storage_size;
const size_t packed_block_size;
const size_t full_block_size;
std::unique_ptr<CodePacker> packer; // cached for unpack in hot path
// Handler-local scratch reused across refinements. This assumes a handler
// instance is confined to one search slice and not entered concurrently.
std::vector<uint8_t> unpack_buf; // reusable buffer for unpack_1
std::vector<uint8_t> unpack_buf; // sign bits scratch buffer

// Cached per-list values (set in set_list_context, avoid recomputing in
// handle)
Expand Down
53 changes: 53 additions & 0 deletions tests/test_rabitq_fastscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,59 @@ def test_ivf_factory_with_batch_size(self):
self.assertEqual(D.shape, (ds.nq, 5))


class TestRaBitQFastScanSearchParams(unittest.TestCase):
    """Check that qb passed via IVFRaBitQSearchParameters affects search."""

    def test_higher_qb_improves_recall(self):
        """Searching with qb=4 must reach a higher recall@k than qb=1."""
        d, nlist, nprobe, k = 64, 16, 4, 10
        ds = datasets.SyntheticDataset(d, 5000, 5000, 50)

        # Exact nearest neighbours from a brute-force L2 index.
        index_flat = faiss.IndexFlatL2(d)
        index_flat.add(ds.get_database())
        _, I_gt = index_flat.search(ds.get_queries(), k)

        # IVF RaBitQ FastScan index; the index-level qb is left at whatever
        # the constructor sets, only the per-search override is varied.
        quantizer = faiss.IndexFlat(d, faiss.METRIC_L2)
        index = faiss.IndexIVFRaBitQFastScan(
            quantizer, d, nlist, faiss.METRIC_L2, 32, True
        )
        index.nprobe = nprobe
        index.train(ds.get_train())
        index.add(ds.get_database())

        def search_with_qb(qb):
            # Run one search with the query quantization bits overridden.
            params = faiss.IVFRaBitQSearchParameters()
            params.nprobe = nprobe
            params.qb = qb
            _, found = index.search(ds.get_queries(), k, params=params)
            return found

        def recall_at_k(found):
            # Mean fraction of ground-truth ids recovered per query.
            per_query = [
                len(np.intersect1d(found[i], I_gt[i])) / k
                for i in range(ds.nq)
            ]
            return np.mean(per_query)

        I_qb1 = search_with_qb(1)  # coarse query quantization
        I_qb4 = search_with_qb(4)  # finer query quantization

        recall_qb1 = recall_at_k(I_qb1)
        recall_qb4 = recall_at_k(I_qb4)

        self.assertGreater(
            recall_qb4, recall_qb1,
            f"qb=4 recall ({recall_qb4:.3f}) should be higher "
            f"than qb=1 recall ({recall_qb1:.3f})"
        )


if __name__ == "__main__":
unittest.main()
Loading