Skip to content

Commit 8431e04

Browse files
alibeklfcmeta-codesync[bot]
authored andcommitted
Replace squared-distance IP with direct dot-product in multi-bit RaBitQ (#4877)
Summary: Pull Request resolved: #4877 Refactor the multi-bit RaBitQ inner product computation from the squared-distance-then-convert approach to a direct dot-product formulation. Before: IP = -0.5 * (||q-c||² + (||r||² - ||x||²) + (-2·||r||/ipnorm)·ex_ip - ||q||²) After: IP = <q,c> + <c,r> + (||r||/ipnorm)·ex_ip Both are mathematically equivalent to <q, x>. The new form is simpler and makes the code structurally immune to the D95166460 bug class where the degenerate case (x ≈ centroid) required complex metric-specific branching that was easy to get wrong. Changes: - Per-document factors (compute_ex_factors): IP branch now computes f_add_ex = <c, r> and f_rescale_ex = ||r||/ipnorm (positive, no -2 factor) - Degenerate case simplified to metric-agnostic f_add_ex=0, f_rescale_ex=0 - Core distance function (compute_full_multibit_distance): accepts a single qr_base parameter instead of qr_to_c_L2sqr + qr_norm_L2sqr, eliminates the -0.5*(dist - qr_norm_L2sqr) IP post-processing - Added q_dot_c field to QueryFactorsData, computed at all query-setup sites - L2 path, 1-bit path, and SIMD kernels are completely unchanged - All four index types covered: IndexRaBitQ, IndexIVFRaBitQ, IndexRaBitQFastScan, IndexIVFRaBitQFastScan Breaking change: serialized multi-bit IP indexes must be re-encoded. L2 indexes are unaffected. Reviewed By: ddrcoder, latham-meta Differential Revision: D95419974 fbshipit-source-id: 3a5a33b5a3065d172a2f57578f3a78bdf562b87c
1 parent 471ddad commit 8431e04

7 files changed

Lines changed: 81 additions & 40 deletions

File tree

faiss/IndexIVFRaBitQFastScan.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,10 +283,11 @@ void IndexIVFRaBitQFastScan::compute_residual_LUT(
283283
rotated_q,
284284
rotated_qq);
285285

286-
// Override query norm for inner product if original query is provided
287286
if (metric_type == MetricType::METRIC_INNER_PRODUCT &&
288287
original_query != nullptr) {
289288
query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(original_query, d);
289+
query_factors.q_dot_c = query_factors.qr_norm_L2sqr -
290+
fvec_inner_product(original_query, residual, d);
290291
}
291292

292293
const size_t ex_bits = rabitq.nb_bits - 1;
@@ -813,8 +814,9 @@ float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::
813814
ex_code,
814815
ex_fac,
815816
query_factors.rotated_q.data(),
816-
query_factors.qr_to_c_L2sqr,
817-
query_factors.qr_norm_L2sqr,
817+
(index->metric_type == MetricType::METRIC_INNER_PRODUCT)
818+
? query_factors.q_dot_c
819+
: query_factors.qr_to_c_L2sqr,
818820
dim,
819821
ex_bits,
820822
index->metric_type);

faiss/IndexRaBitQFastScan.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -751,8 +751,9 @@ float RaBitQHeapHandler<C, with_id_map>::compute_full_multibit_distance(
751751
ex_code,
752752
ex_fac,
753753
query_factors.rotated_q.data(),
754-
query_factors.qr_to_c_L2sqr,
755-
query_factors.qr_norm_L2sqr,
754+
(rabitq_index->metric_type == MetricType::METRIC_INNER_PRODUCT)
755+
? query_factors.q_dot_c
756+
: query_factors.qr_to_c_L2sqr,
756757
dim,
757758
ex_bits,
758759
rabitq_index->metric_type);

faiss/impl/RaBitQUtils.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,12 @@ QueryFactorsData compute_query_factors(
244244

245245
// Compute query norm for inner product metric
246246
query_factors.qr_norm_L2sqr = 0.0f;
247+
query_factors.q_dot_c = 0.0f;
247248
if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
248249
query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(query, d);
250+
if (centroid != nullptr) {
251+
query_factors.q_dot_c = fvec_inner_product(query, centroid, d);
252+
}
249253
}
250254

251255
return query_factors;
@@ -307,8 +311,7 @@ float compute_full_multibit_distance(
307311
const uint8_t* ex_code,
308312
const ExtraBitsFactors& ex_fac,
309313
const float* rotated_q,
310-
float qr_to_c_L2sqr,
311-
float qr_norm_L2sqr,
314+
float qr_base,
312315
size_t d,
313316
size_t ex_bits,
314317
MetricType metric_type) {
@@ -317,11 +320,9 @@ float compute_full_multibit_distance(
317320
float ex_ip = rabitq::multibit::compute_inner_product(
318321
sign_bits, ex_code, rotated_q, d, ex_bits, cb);
319322

320-
float dist = qr_to_c_L2sqr + ex_fac.f_add_ex + ex_fac.f_rescale_ex * ex_ip;
323+
float dist = qr_base + ex_fac.f_add_ex + ex_fac.f_rescale_ex * ex_ip;
321324

322-
if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
323-
dist = -0.5f * (dist - qr_norm_L2sqr);
324-
} else {
325+
if (metric_type == MetricType::METRIC_L2) {
325326
dist = std::max(0.0f, dist);
326327
}
327328

faiss/impl/RaBitQUtils.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ struct QueryFactorsData {
7070

7171
float qr_to_c_L2sqr = 0;
7272
float qr_norm_L2sqr = 0;
73+
float q_dot_c = 0; // <query, centroid> for IP metric; 0 for L2
7374

7475
float int_dot_scale = 1;
7576

@@ -320,8 +321,7 @@ inline int extract_code_inline(
320321
* @param ex_code packed ex-bit codes
321322
* @param ex_fac ex-bit factors (f_add_ex, f_rescale_ex)
322323
* @param rotated_q rotated query vector
323-
* @param qr_to_c_L2sqr precomputed ||query_rotated - centroid||^2
324-
* @param qr_norm_L2sqr precomputed ||query_rotated||^2 (0 for L2 metric)
324+
* @param qr_base precomputed base term: ||q-c||^2 for L2, <q,c> for IP
325325
* @param d dimensionality
326326
* @param ex_bits number of extra bits (nb_bits - 1)
327327
* @param metric_type distance metric (L2 or Inner Product)
@@ -332,8 +332,7 @@ float compute_full_multibit_distance(
332332
const uint8_t* ex_code,
333333
const ExtraBitsFactors& ex_fac,
334334
const float* rotated_q,
335-
float qr_to_c_L2sqr,
336-
float qr_norm_L2sqr,
335+
float qr_base,
337336
size_t d,
338337
size_t ex_bits,
339338
MetricType metric_type);

faiss/impl/RaBitQuantizer.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -314,13 +314,15 @@ float RaBitQDistanceComputerNotQ::distance_to_code_full(const uint8_t* code) {
314314
ex_code + (d * ex_bits + 7) / 8);
315315

316316
// Call shared utility directly with rotated_q pointer
317+
float qr_base = (metric_type == MetricType::METRIC_INNER_PRODUCT)
318+
? query_fac.q_dot_c
319+
: query_fac.qr_to_c_L2sqr;
317320
return rabitq_utils::compute_full_multibit_distance(
318321
binary_data,
319322
ex_code,
320323
*ex_fac,
321324
rotated_q.data(),
322-
query_fac.qr_to_c_L2sqr,
323-
query_fac.qr_norm_L2sqr,
325+
qr_base,
324326
d,
325327
ex_bits,
326328
metric_type);
@@ -366,6 +368,8 @@ void RaBitQDistanceComputerNotQ::set_query(const float* x) {
366368
if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
367369
// precompute if needed
368370
query_fac.qr_norm_L2sqr = fvec_norm_L2sqr(x, d);
371+
query_fac.q_dot_c =
372+
centroid ? fvec_inner_product(x, centroid, d) : 0.0f;
369373
}
370374
}
371375

@@ -480,13 +484,15 @@ float RaBitQDistanceComputerQ::distance_to_code_full(const uint8_t* code) {
480484
ex_code + (d * ex_bits + 7) / 8);
481485

482486
// Call shared utility directly with rotated_q pointer
487+
float qr_base = (metric_type == MetricType::METRIC_INNER_PRODUCT)
488+
? query_fac.q_dot_c
489+
: query_fac.qr_to_c_L2sqr;
483490
return rabitq_utils::compute_full_multibit_distance(
484491
binary_data,
485492
ex_code,
486493
*ex_fac,
487494
rotated_q.data(),
488-
query_fac.qr_to_c_L2sqr,
489-
query_fac.qr_norm_L2sqr,
495+
qr_base,
490496
d,
491497
ex_bits,
492498
metric_type);

faiss/impl/RaBitQuantizerMultiBit.cpp

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -216,25 +216,13 @@ void compute_ex_factors(
216216
ex_factors.f_add_ex = l2_sqr;
217217
ex_factors.f_rescale_ex = ipnorm_inv * -2.0f * norm;
218218
} else {
219-
// For IP: Need to compute ||x||^2 for the correction term
220-
// The distance formula expects f_add_ex = ||x - c||^2 - ||x||^2
221-
// to match the 1-bit formula's or_minus_c_l2sqr for IP metric
222-
//
223-
// Reconstruct ||x||^2 from residual and centroid:
224-
// x = residual + centroid, so ||x||^2 = ||residual + centroid||^2
225-
float or_L2sqr = 0;
226-
if (centroid != nullptr) {
227-
for (size_t i = 0; i < d; i++) {
228-
float x_val = residual[i] + centroid[i];
229-
or_L2sqr += x_val * x_val;
230-
}
231-
} else {
232-
// If no centroid, x = residual
233-
or_L2sqr = l2_sqr;
234-
}
235-
236-
ex_factors.f_add_ex = l2_sqr - or_L2sqr;
237-
ex_factors.f_rescale_ex = ipnorm_inv * -2.0f * norm;
219+
// For IP: direct dot-product formulation
220+
// f_add_ex = <c, r> (dot product of centroid and residual)
221+
// f_rescale_ex = ||r|| / ipnorm (positive scaling)
222+
float c_dot_r =
223+
centroid ? fvec_inner_product(residual, centroid, d) : 0.0f;
224+
ex_factors.f_add_ex = c_dot_r;
225+
ex_factors.f_rescale_ex = ipnorm_inv * norm;
238226
}
239227
}
240228

@@ -276,12 +264,14 @@ void quantize_ex_bits(
276264
float norm_sqr = fvec_norm_L2sqr(residual, d);
277265
float norm = std::sqrt(norm_sqr);
278266

279-
// Handle degenerate case
267+
// Handle degenerate case: residual is (near-)zero, meaning x ≈ centroid.
268+
// For both L2 and IP, f_add_ex and f_rescale_ex are trivially zero:
269+
// L2: ||r||² ≈ 0, IP: <c,r> ≈ 0 and ||r||/ipnorm ≈ 0
280270
if (norm < 1e-10f) {
281271
size_t code_size = (d * ex_bits + 7) / 8;
282272
memset(ex_code, 0, code_size);
283273
ex_factors.f_add_ex = 0.0f;
284-
ex_factors.f_rescale_ex = 1.0f;
274+
ex_factors.f_rescale_ex = 0.0f;
285275
return;
286276
}
287277

tests/test_rabitq.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,48 @@ def test_factory_end_to_end(self):
965965
self.assertEqual(D_ivf.shape, (ds.nq, 5))
966966
self.assertTrue(np.all(I_ivf >= 0))
967967

968+
def test_degenerate_centroid_distance(self):
969+
"""Test that a doc identical to its centroid gets correct distances.
970+
971+
When a vector's residual (x - centroid) has near-zero norm,
972+
quantize_ex_bits hits a degenerate early-return path. A previous
973+
bug set f_rescale_ex=1 (should be 0) and f_add_ex=0 (should be
974+
metric-dependent), causing wildly wrong multi-bit distances.
975+
"""
976+
d = 128
977+
nlist = 16
978+
rs = np.random.RandomState(42)
979+
xt = rs.randn(2000, d).astype("float32")
980+
xt /= np.linalg.norm(xt, axis=1, keepdims=True)
981+
982+
query = rs.randn(1, d).astype("float32")
983+
query /= np.linalg.norm(query)
984+
985+
for metric in [faiss.METRIC_L2, faiss.METRIC_INNER_PRODUCT]:
986+
for nb_bits in [2, 3, 4]:
987+
with self.subTest(metric=metric, nb_bits=nb_bits):
988+
quantizer = faiss.IndexFlat(d, metric)
989+
index = faiss.IndexIVFRaBitQ(
990+
quantizer, d, nlist, metric, True, nb_bits
991+
)
992+
index.train(xt)
993+
index.nprobe = nlist # exhaustive
994+
995+
# Add the centroid itself as the only document
996+
centroid = index.quantizer.reconstruct(0).reshape(1, d)
997+
index.add(centroid)
998+
999+
D, I = index.search(query, 1)
1000+
1001+
if metric == faiss.METRIC_INNER_PRODUCT:
1002+
true_dist = (query @ centroid.T).item()
1003+
else:
1004+
true_dist = float(np.sum((query - centroid) ** 2))
1005+
1006+
np.testing.assert_allclose(
1007+
D[0, 0], true_dist, atol=0.15,
1008+
err_msg=f"nb_bits={nb_bits}")
1009+
9681010

9691011
class TestRaBitQStats(unittest.TestCase):
9701012
"""Test RaBitQStats tracking for multi-bit two-stage search."""

0 commit comments

Comments
 (0)