Skip to content

Commit 6b9279c

Browse files
Michael Norris authored and facebook-github-bot committed
Add QT_0bit to ScalarQuantizer for centroid-only IVF distance (facebookresearch#5079)
Summary: Adds a new ScalarQuantizer::QT_0bit type that encodes 0 bits per component, enabling IndexIVFScalarQuantizer to operate in centroid-only distance mode (code_size=0). In this mode, distance_to_code() returns the coarse distance from the quantizer, no per-vector data is stored, and reconstruction returns the centroid vector. This is useful for IVF configurations where only query-to-centroid distances are needed (e.g., Unicorn's SQ0 use case).

Changes:
- Add QT_0bit enum value to ScalarQuantizer::QuantizerType
- Add IVFCoarseDistanceScanner that returns coarse_dis from set_list()
- Handle code_size=0 in encode/decode/reconstruct/add paths
- Add 'SQ0' to index_factory
- Force by_residual=false for QT_0bit
- Guard memcpy in ArrayInvertedLists for code_size=0
- Handle QT_0bit in index_read.cpp validation switch
- Add tests for L2, IP, and index_factory

---

Differential Revision: D100348052
1 parent aa3ce37 commit 6b9279c

File tree

9 files changed

+245
-5
lines changed

9 files changed

+245
-5
lines changed

faiss/IndexScalarQuantizer.cpp

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ IndexIVFScalarQuantizer::IndexIVFScalarQuantizer(
133133
invlists->code_size = code_size;
134134
}
135135
is_trained = false;
136+
if (qtype == ScalarQuantizer::QT_0bit) {
137+
by_residual = false;
138+
is_trained = true; // no training needed
139+
}
136140
}
137141

138142
IndexIVFScalarQuantizer::IndexIVFScalarQuantizer() : IndexIVF() {
@@ -156,6 +160,19 @@ void IndexIVFScalarQuantizer::encode_vectors(
156160
const idx_t* list_nos,
157161
uint8_t* codes,
158162
bool include_listnos) const {
163+
if (sq.code_size == 0) {
164+
// QT_0bit: nothing to encode, but handle coarse codes if needed
165+
if (include_listnos) {
166+
size_t coarse_size = coarse_code_size();
167+
for (idx_t i = 0; i < n; i++) {
168+
int64_t list_no = list_nos[i];
169+
if (list_no >= 0) {
170+
encode_listno(list_no, codes + i * coarse_size);
171+
}
172+
}
173+
}
174+
return;
175+
}
159176
std::unique_ptr<ScalarQuantizer::SQuantizer> squant(sq.select_quantizer());
160177
size_t coarse_size = include_listnos ? coarse_code_size() : 0;
161178
memset(codes, 0, (code_size + coarse_size) * n);
@@ -186,14 +203,46 @@ void IndexIVFScalarQuantizer::encode_vectors(
186203
void IndexIVFScalarQuantizer::decode_vectors(
187204
idx_t n,
188205
const uint8_t* codes,
189-
const idx_t*,
206+
const idx_t* list_nos,
190207
float* x) const {
208+
if (sq.code_size == 0) {
209+
// QT_0bit: reconstruct centroids if list_nos provided
210+
if (list_nos) {
211+
for (idx_t i = 0; i < n; i++) {
212+
quantizer->reconstruct(list_nos[i], x + i * d);
213+
}
214+
} else {
215+
memset(x, 0, sizeof(float) * d * n);
216+
}
217+
return;
218+
}
191219
FAISS_THROW_IF_NOT(is_trained);
192-
return sq.decode(codes, x, n);
220+
sq.decode(codes, x, n);
221+
if (by_residual) {
222+
FAISS_THROW_IF_NOT_MSG(
223+
list_nos, "decode_vectors with by_residual requires list_nos");
224+
#pragma omp parallel for if (n > 1000)
225+
for (idx_t i = 0; i < n; i++) {
226+
std::vector<float> centroid(d);
227+
quantizer->reconstruct(list_nos[i], centroid.data());
228+
for (size_t j = 0; j < static_cast<size_t>(d); j++) {
229+
x[i * d + j] += centroid[j];
230+
}
231+
}
232+
}
193233
}
194234

195235
void IndexIVFScalarQuantizer::sa_decode(idx_t n, const uint8_t* codes, float* x)
196236
const {
237+
if (sq.code_size == 0) {
238+
size_t coarse_size = coarse_code_size();
239+
for (idx_t i = 0; i < n; i++) {
240+
const uint8_t* code = codes + i * coarse_size;
241+
int64_t list_no = decode_listno(code);
242+
quantizer->reconstruct(list_no, x + i * d);
243+
}
244+
return;
245+
}
197246
std::unique_ptr<ScalarQuantizer::SQuantizer> squant(sq.select_quantizer());
198247
size_t coarse_size = coarse_code_size();
199248

@@ -224,6 +273,23 @@ void IndexIVFScalarQuantizer::add_core(
224273
const idx_t* coarse_idx,
225274
void* inverted_list_context) {
226275
FAISS_THROW_IF_NOT(is_trained);
276+
if (sq.code_size == 0) {
277+
// QT_0bit: just add IDs with empty codes
278+
uint8_t dummy_code = 0;
279+
DirectMapAdd dm_add(direct_map, n, xids);
280+
for (idx_t i = 0; i < n; i++) {
281+
int64_t list_no = coarse_idx[i];
282+
if (list_no >= 0) {
283+
int64_t id = xids ? xids[i] : ntotal + i;
284+
size_t ofs = invlists->add_entry(list_no, id, &dummy_code);
285+
dm_add.add(i, list_no, ofs);
286+
} else {
287+
dm_add.add(i, -1, 0);
288+
}
289+
}
290+
ntotal += n;
291+
return;
292+
}
227293

228294
std::unique_ptr<ScalarQuantizer::SQuantizer> squant(sq.select_quantizer());
229295

@@ -277,6 +343,11 @@ void IndexIVFScalarQuantizer::reconstruct_from_offset(
277343
int64_t list_no,
278344
int64_t offset,
279345
float* recons) const {
346+
if (sq.code_size == 0) {
347+
// QT_0bit: reconstruct from centroid
348+
quantizer->reconstruct(list_no, recons);
349+
return;
350+
}
280351
const uint8_t* code = invlists->get_single_code(list_no, offset);
281352

282353
if (by_residual) {

faiss/impl/ScalarQuantizer.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ void ScalarQuantizer::set_derived_sizes() {
6262
code_size = d * 2;
6363
bits = 16;
6464
break;
65+
case QT_0bit:
66+
code_size = 0;
67+
bits = 0;
68+
break;
6569
default:
6670
break;
6771
}
@@ -71,6 +75,10 @@ void ScalarQuantizer::train(size_t n, const float* x) {
7175
using scalar_quantizer::train_NonUniform;
7276
using scalar_quantizer::train_Uniform;
7377

78+
if (qtype == QT_0bit) {
79+
return; // nothing to train for centroid-only mode
80+
}
81+
7482
int bit_per_dim = qtype == QT_4bit_uniform ? 4
7583
: qtype == QT_4bit ? 4
7684
: qtype == QT_6bit ? 6
@@ -128,6 +136,9 @@ ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
128136

129137
void ScalarQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
130138
const {
139+
if (code_size == 0) {
140+
return; // QT_0bit: nothing to encode
141+
}
131142
std::unique_ptr<SQuantizer> squant(select_quantizer());
132143

133144
memset(codes, 0, code_size * n);
@@ -138,6 +149,10 @@ void ScalarQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
138149
}
139150

140151
void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
152+
if (code_size == 0) {
153+
memset(x, 0, sizeof(float) * d * n);
154+
return; // QT_0bit: no per-vector data, zero-fill
155+
}
141156
std::unique_ptr<SQuantizer> squant(select_quantizer());
142157

143158
#pragma omp parallel for

faiss/impl/ScalarQuantizer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ struct ScalarQuantizer : Quantizer {
3333
QT_bf16,
3434
QT_8bit_direct_signed, ///< fast indexing of signed int8s ranging from
3535
///< [-128 to 127]
36+
QT_0bit, ///< 0 bits per component, centroid-only distance (for IVF)
3637
QT_count
3738
};
3839

faiss/impl/index_read.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,7 @@ void read_ScalarQuantizer(
903903
case ScalarQuantizer::QT_bf16:
904904
case ScalarQuantizer::QT_8bit_direct:
905905
case ScalarQuantizer::QT_8bit_direct_signed:
906+
case ScalarQuantizer::QT_0bit:
906907
case ScalarQuantizer::QT_count:
907908
expected = 0;
908909
break;

faiss/impl/scalar_quantizer/scanners.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,32 @@ InvertedListScanner* sq_select_InvertedListScanner(
159159
const IDSelector* sel,
160160
bool by_residual);
161161

162+
/// Scanner for QT_0bit / centroid-only distance: always returns the
163+
/// coarse distance that was set via set_list().
164+
struct IVFCoarseDistanceScanner : InvertedListScanner {
165+
float coarse_dis = 0;
166+
167+
IVFCoarseDistanceScanner(
168+
bool is_similarity,
169+
bool store_pairs,
170+
const IDSelector* sel)
171+
: InvertedListScanner(store_pairs, sel) {
172+
code_size = 0;
173+
keep_max = is_similarity;
174+
}
175+
176+
void set_query(const float* /*query_vector*/) override {}
177+
178+
void set_list(idx_t list_no_in, float coarse_dis_in) override {
179+
this->list_no = list_no_in;
180+
this->coarse_dis = coarse_dis_in;
181+
}
182+
183+
float distance_to_code(const uint8_t* /*code*/) const override {
184+
return coarse_dis;
185+
}
186+
};
187+
162188
} // namespace scalar_quantizer
163189

164190
} // namespace faiss

faiss/impl/scalar_quantizer/sq-dispatch.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ ScalarQuantizer::SQuantizer* sq_select_quantizer<THE_LEVEL_TO_DISPATCH>(
8585
return new Quantizer8bitDirect<SL>(d, trained);
8686
case ScalarQuantizer::QT_8bit_direct_signed:
8787
return new Quantizer8bitDirectSigned<SL>(d, trained);
88+
case ScalarQuantizer::QT_0bit:
89+
FAISS_THROW_MSG(
90+
"QT_0bit does not support standalone quantization, use IndexIVFScalarQuantizer");
8891
default:
8992
FAISS_THROW_MSG("unknown qtype");
9093
}
@@ -175,6 +178,9 @@ SQDistanceComputer* select_distance_computer_body(
175178
case ScalarQuantizer::QT_8bit_direct_signed:
176179
return new DCTemplate<Quantizer8bitDirectSigned<SL2>, Sim, SL2>(
177180
d, trained);
181+
case ScalarQuantizer::QT_0bit:
182+
FAISS_THROW_MSG(
183+
"QT_0bit does not support standalone distance computation, use IndexIVFScalarQuantizer");
178184
default:
179185
FAISS_THROW_MSG("unknown qtype");
180186
}
@@ -309,6 +315,9 @@ InvertedListScanner* sq_select_InvertedListScanner<THE_LEVEL_TO_DISPATCH>(
309315
Quantizer8bitDirectSigned<SL2>,
310316
Similarity,
311317
SL2>>();
318+
case ScalarQuantizer::QT_0bit:
319+
return new IVFCoarseDistanceScanner(
320+
Similarity::metric_type != METRIC_L2, store_pairs, sel);
312321
default:
313322
FAISS_THROW_MSG("unknown qtype");
314323
}

faiss/index_factory.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,10 @@ std::map<std::string, ScalarQuantizer::QuantizerType> sq_types = {
154154
{"SQbf16", ScalarQuantizer::QT_bf16},
155155
{"SQ8_direct_signed", ScalarQuantizer::QT_8bit_direct_signed},
156156
{"SQ8_direct", ScalarQuantizer::QT_8bit_direct},
157+
{"SQ0", ScalarQuantizer::QT_0bit},
157158
};
158159
const std::string sq_pattern =
159-
"(SQ4|SQ8|SQ6|SQfp16|SQbf16|SQ8_direct_signed|SQ8_direct)";
160+
"(SQ0|SQ4|SQ8|SQ6|SQfp16|SQbf16|SQ8_direct_signed|SQ8_direct)";
160161

161162
std::map<std::string, AdditiveQuantizer::Search_type_t> aq_search_type = {
162163
{"_Nfloat", AdditiveQuantizer::ST_norm_float},

faiss/invlists/InvertedLists.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,9 @@ size_t ArrayInvertedLists::add_entries(
289289
ids[list_no].resize(o + n_entry);
290290
memcpy(&ids[list_no][o], ids_in, sizeof(ids_in[0]) * n_entry);
291291
codes[list_no].resize((o + n_entry) * code_size);
292-
memcpy(&codes[list_no][o * code_size], code, code_size * n_entry);
292+
if (code_size > 0) {
293+
memcpy(&codes[list_no][o * code_size], code, code_size * n_entry);
294+
}
293295
return o;
294296
}
295297

@@ -328,7 +330,11 @@ void ArrayInvertedLists::update_entries(
328330
assert(list_no < nlist);
329331
assert(n_entry + offset <= ids[list_no].size());
330332
memcpy(&ids[list_no][offset], ids_in, sizeof(ids_in[0]) * n_entry);
331-
memcpy(&codes[list_no][offset * code_size], codes_in, code_size * n_entry);
333+
if (code_size > 0) {
334+
memcpy(&codes[list_no][offset * code_size],
335+
codes_in,
336+
code_size * n_entry);
337+
}
332338
}
333339

334340
void ArrayInvertedLists::permute_invlists(const idx_t* map) {

tests/test_scalar_quantizer.cpp

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,15 @@
77

88
#include <gtest/gtest.h>
99

10+
#include <cmath>
11+
#include <memory>
1012
#include <vector>
1113

14+
#include <faiss/IndexFlat.h>
15+
#include <faiss/IndexIVFFlat.h>
16+
#include <faiss/IndexScalarQuantizer.h>
1217
#include <faiss/impl/ScalarQuantizer.h>
18+
#include <faiss/index_factory.h>
1319

1420
TEST(ScalarQuantizer, RSQuantilesClamping) {
1521
int d = 8;
@@ -95,3 +101,107 @@ TEST(ScalarQuantizer, RSQuantilesSmallDataset) {
95101

96102
ASSERT_NO_THROW(sq.train(n, x.data()));
97103
}
104+
105+
// QT_0bit with an L2 coarse quantizer: the best distance returned by
// search() for each query must coincide with one of that query's
// query-to-centroid distances.
TEST(TestSQ0bit, CoarseOnlySearch) {
    const int d = 64;
    const int nlist = 8;
    const int nb = 1000;
    const int nq = 10;
    const int k = 5;

    // Database values are drawn first, query values second.
    std::vector<float> xb(nb * d);
    std::vector<float> xq(nq * d);
    for (float& v : xb) {
        v = drand48();
    }
    for (float& v : xq) {
        v = drand48();
    }

    faiss::IndexFlatL2 quantizer(d);
    faiss::IndexIVFScalarQuantizer index(
            &quantizer,
            d,
            nlist,
            faiss::ScalarQuantizer::QT_0bit,
            faiss::METRIC_L2,
            false);
    EXPECT_EQ(index.code_size, 0);
    EXPECT_FALSE(index.by_residual);

    index.train(nb, xb.data());
    index.add(nb, xb.data());
    EXPECT_EQ(index.ntotal, nb);

    // nprobe == nlist: every inverted list is visited during search.
    index.nprobe = nlist;
    std::vector<float> distances(nq * k);
    std::vector<faiss::idx_t> labels(nq * k);
    index.search(nq, xq.data(), k, distances.data(), labels.data());

    // The first hit of every query must carry a valid (non-negative) label.
    for (int q = 0; q < nq; q++) {
        EXPECT_GE(labels[q * k], 0);
    }

    // Reference values: distances from each query to all nlist centroids,
    // obtained by searching the coarse quantizer directly.
    std::vector<float> coarse_dis(nq * nlist);
    std::vector<faiss::idx_t> coarse_ids(nq * nlist);
    quantizer.search(
            nq, xq.data(), nlist, coarse_dis.data(), coarse_ids.data());

    for (int q = 0; q < nq; q++) {
        const float top_dis = distances[q * k];
        const float* row = coarse_dis.data() + q * nlist;
        bool matched = false;
        for (int j = 0; j < nlist && !matched; j++) {
            matched = std::abs(top_dis - row[j]) < 1e-5;
        }
        EXPECT_TRUE(matched)
                << "IVF distance " << top_dis
                << " not found in coarse distances for query " << q;
    }
}
163+
164+
// The factory string "IVF8,SQ0" must produce an IndexIVFScalarQuantizer
// whose scalar quantizer is QT_0bit and whose code size is zero.
TEST(TestSQ0bit, IndexFactory) {
    int d = 32;
    std::unique_ptr<faiss::Index> index(faiss::index_factory(d, "IVF8,SQ0"));
    ASSERT_NE(index, nullptr);
    auto* ivfsq = dynamic_cast<faiss::IndexIVFScalarQuantizer*>(index.get());
    // Fatal assertion on purpose: the checks below dereference ivfsq, so a
    // failed cast must stop the test here instead of crashing the binary.
    ASSERT_NE(ivfsq, nullptr);
    EXPECT_EQ(ivfsq->sq.qtype, faiss::ScalarQuantizer::QT_0bit);
    EXPECT_EQ(ivfsq->code_size, 0);
}
173+
174+
// QT_0bit smoke test with an inner-product coarse quantizer: after train()
// and add(), searching with nprobe == nlist must give every query a valid
// first label.
TEST(TestSQ0bit, InnerProduct) {
    const int d = 64;
    const int nlist = 4;
    const int nb = 500;
    const int nq = 5;
    const int k = 3;

    // Database values are drawn first, query values second.
    std::vector<float> xb(nb * d);
    std::vector<float> xq(nq * d);
    for (float& v : xb) {
        v = drand48();
    }
    for (float& v : xq) {
        v = drand48();
    }

    faiss::IndexFlatIP quantizer(d);
    faiss::IndexIVFScalarQuantizer index(
            &quantizer,
            d,
            nlist,
            faiss::ScalarQuantizer::QT_0bit,
            faiss::METRIC_INNER_PRODUCT,
            false);

    index.train(nb, xb.data());
    index.add(nb, xb.data());

    index.nprobe = nlist;
    std::vector<float> distances(nq * k);
    std::vector<faiss::idx_t> labels(nq * k);
    index.search(nq, xq.data(), k, distances.data(), labels.data());

    // Only the top hit of each query is checked.
    for (int q = 0; q < nq; q++) {
        EXPECT_GE(labels[q * k], 0);
    }
}

0 commit comments

Comments
 (0)