Skip to content

Commit a40bd22

Browse files
committed
Add TurboQuant index and benchmark support
1 parent acac823 commit a40bd22

16 files changed

+906
-25
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,5 @@ faiss/python/swigfaiss_sve.swig
3131
# Python package build outputs
3232
/dist/
3333
/*.egg-info/
34+
/build-conda/
35+
/benchs/sift1M/

benchs/__init__.py

Whitespace-only changes.

benchs/bench_quantizer.py

Lines changed: 147 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,80 @@
1010

1111
try:
1212
from faiss.contrib.datasets_fb import \
13-
DatasetSIFT1M, DatasetDeep1B, DatasetBigANN
13+
DatasetSIFT1M, DatasetDeep1B, DatasetBigANN, DatasetGlove
1414
except ImportError:
1515
from faiss.contrib.datasets import \
16-
DatasetSIFT1M, DatasetDeep1B, DatasetBigANN
16+
DatasetSIFT1M, DatasetDeep1B, DatasetBigANN, DatasetGlove, DatasetDBpedia1536_1M, DatasetDBpedia3072_1M
1717

1818

19-
def eval_codec(q, xq, xb, gt):
19+
TURBOQUANT_OPTIONS = {"turboquant", "tq"}
20+
RABITQ_OPTIONS = {"rabitq", "rbq"}
21+
22+
23+
def get_metric_type(ds):
24+
if ds.metric == "IP":
25+
return faiss.METRIC_INNER_PRODUCT
26+
if ds.metric == "L2":
27+
return faiss.METRIC_L2
28+
raise RuntimeError(f"unsupported dataset metric {ds.metric}")
29+
30+
31+
def get_training_vectors(ds, xb, maxtrain):
32+
try:
33+
return ds.get_train(maxtrain=maxtrain)
34+
except NotImplementedError:
35+
print("No training set: training on database")
36+
return xb[:maxtrain]
37+
38+
39+
def encode(codec, x):
40+
if hasattr(codec, "compute_codes") and hasattr(codec, "decode"):
41+
return codec.compute_codes(x)
42+
if hasattr(codec, "sa_encode") and hasattr(codec, "sa_decode"):
43+
return codec.sa_encode(x)
44+
raise TypeError(f"unsupported codec type {type(codec).__name__}")
45+
46+
47+
def decode(codec, codes):
48+
if hasattr(codec, "compute_codes") and hasattr(codec, "decode"):
49+
return codec.decode(codes)
50+
if hasattr(codec, "sa_encode") and hasattr(codec, "sa_decode"):
51+
return codec.sa_decode(codes)
52+
raise TypeError(f"unsupported codec type {type(codec).__name__}")
53+
54+
55+
def get_code_size(codec):
56+
if hasattr(codec, "code_size"):
57+
return int(codec.code_size)
58+
if hasattr(codec, "sa_code_size"):
59+
return int(codec.sa_code_size())
60+
return None
61+
62+
63+
def eval_codec(q, xq, xb, gt, metric_type):
2064
t0 = time.time()
21-
codes = q.compute_codes(xb)
65+
codes = encode(q, xb)
2266
t1 = time.time()
23-
xb_decoded = q.decode(codes)
67+
xb_decoded = decode(q, codes)
2468
recons_err = ((xb - xb_decoded) ** 2).sum() / xb.shape[0]
2569
# for compatibility with the codec benchmarks
2670
err_compat = np.linalg.norm(xb - xb_decoded, axis=1).mean()
27-
xq_decoded = q.decode(q.compute_codes(xq))
28-
D, I = faiss.knn(xq_decoded, xb_decoded, 1)
71+
xq_decoded = decode(q, encode(q, xq))
72+
D, I = faiss.knn(xq_decoded, xb_decoded, 1, metric=metric_type)
2973
recall = (I[:, 0] == gt[:, 0]).sum() / nq
74+
code_size = get_code_size(q)
75+
code_size_s = (
76+
f" code_size: {code_size} B/vector"
77+
if code_size is not None
78+
else ""
79+
)
3080
print(
3181
f"\tencode time: {t1 - t0:.3f} reconstruction error: {recons_err:.3f} "
32-
f"1-recall@1: {recall:.4f} recons_err_compat {err_compat:.3f}")
82+
f"recall@1: {recall:.4f} recons_err_compat {err_compat:.3f}"
83+
f"{code_size_s}")
3384

3485

35-
def eval_quantizer(q, xq, xb, gt, xt, variants=None):
86+
def eval_quantizer(q, xq, xb, gt, xt, metric_type, variants=None):
3687
if variants is None:
3788
variants = [(None, None)]
3889
t0 = time.time()
@@ -53,20 +104,41 @@ def eval_quantizer(q, xq, xb, gt, xt, variants=None):
53104
getattr(q, name) # make sure field exists
54105
setattr(q, name, val)
55106

56-
eval_codec(q, xq, xb, gt)
107+
eval_codec(q, xq, xb, gt, metric_type)
57108

58109

59110
todo = sys.argv[1:]
60111

61-
if len(todo) > 0 and "deep1M" in todo[0]:
62-
ds = DatasetDeep1B(10**6)
112+
dataset_name = "sift1M"
113+
if len(todo) > 0 and todo[0] in (
114+
"sift1M",
115+
"deep1M",
116+
"bigann1M",
117+
"glove",
118+
"dbpedia-1536-1M",
119+
"dbpedia-3072-1M",
120+
):
121+
dataset_name = todo[0]
63122
del todo[0]
64-
elif len(todo) > 0 and "bigann1M" in todo[0]:
123+
124+
if dataset_name == "deep1M":
125+
ds = DatasetDeep1B(10**6)
126+
elif dataset_name == "bigann1M":
65127
ds = DatasetBigANN(nb_M=1)
66-
del todo[0]
128+
elif dataset_name == "glove":
129+
ds = DatasetGlove()
130+
elif dataset_name == "dbpedia-1536-1M":
131+
ds = DatasetDBpedia1536_1M()
132+
elif dataset_name == "dbpedia-3072-1M":
133+
ds = DatasetDBpedia3072_1M()
67134
else:
68135
ds = DatasetSIFT1M()
69136

137+
M = None
138+
nsplits = None
139+
Msub = None
140+
nbits = None
141+
70142
if len(todo) > 0:
71143
if todo[0].count("x") == 1:
72144
M, nbits = [int(x) for x in todo[0].split("x")]
@@ -75,15 +147,42 @@ def eval_quantizer(q, xq, xb, gt, xt, variants=None):
75147
nsplits, Msub, nbits = [int(x) for x in todo[0].split("x")]
76148
M = nsplits * Msub
77149
del todo[0]
150+
elif todo[0].isdigit():
151+
nbits = int(todo[0])
152+
del todo[0]
153+
154+
selected_options = set(todo)
155+
156+
if nbits is None:
157+
raise RuntimeError(
158+
"expected a codec bit specification: Mxnbits, nsplitsxMsubxnbits, "
159+
"or plain nbits for turboquant/rabitq"
160+
)
161+
162+
if selected_options & TURBOQUANT_OPTIONS and not 1 <= nbits <= 8:
163+
raise RuntimeError("TurboQuant supports nbits in [1, 8]")
164+
165+
if selected_options & RABITQ_OPTIONS and not 1 <= nbits <= 9:
166+
raise RuntimeError("RaBitQ supports nbits in [1, 9]")
167+
168+
if M is None and selected_options & {"pq", "opq", "rq", "rq_lut", "lsq", "lsq-gpu"}:
169+
raise RuntimeError("expected Mxnbits for pq/opq/rq/rq_lut/lsq benchmarks")
170+
171+
if M is None and selected_options & {"prq", "plsq"}:
172+
raise RuntimeError("expected nsplitsxMsubxnbits for prq/plsq benchmarks")
78173

79174
maxtrain = max(100 << nbits, 10**5)
80-
print(f"eval on {M}x{nbits} maxtrain={maxtrain}")
175+
if M is None:
176+
print(f"eval on {dataset_name} {nbits}-bit maxtrain={maxtrain}")
177+
else:
178+
print(f"eval on {dataset_name} {M}x{nbits} maxtrain={maxtrain}")
81179

82180
xq = ds.get_queries()
83181
xb = ds.get_database()
84182
gt = ds.get_groundtruth()
183+
metric_type = get_metric_type(ds)
85184

86-
xt = ds.get_train(maxtrain=maxtrain)
185+
xt = get_training_vectors(ds, xb, maxtrain=maxtrain)
87186

88187
nb, d = xb.shape
89188
nq, d = xq.shape
@@ -97,12 +196,12 @@ def eval_quantizer(q, xq, xb, gt, xt, variants=None):
97196
ngpus = faiss.get_num_gpus()
98197
lsq.icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus)
99198
lsq.verbose = True
100-
eval_quantizer(lsq, xb, xt, 'lsq-gpu')
199+
eval_quantizer(lsq, xq, xb, gt, xt, metric_type)
101200

102201
if 'pq' in todo:
103202
pq = faiss.ProductQuantizer(d, M, nbits)
104203
print("===== PQ")
105-
eval_quantizer(pq, xq, xb, gt, xt)
204+
eval_quantizer(pq, xq, xb, gt, xt, metric_type)
106205

107206
if 'opq' in todo:
108207
d2 = ((d + M - 1) // M) * M
@@ -114,19 +213,19 @@ def eval_quantizer(q, xq, xb, gt, xt, variants=None):
114213
xt2 = opq.apply(xt)
115214
pq = faiss.ProductQuantizer(d2, M, nbits)
116215
print("===== PQ")
117-
eval_quantizer(pq, xq2, xb2, gt, xt2)
216+
eval_quantizer(pq, xq2, xb2, gt, xt2, metric_type)
118217

119218
if 'prq' in todo:
120219
print(f"===== PRQ{nsplits}x{Msub}x{nbits}")
121220
prq = faiss.ProductResidualQuantizer(d, nsplits, Msub, nbits)
122221
variants = [("max_beam_size", i) for i in (1, 2, 4, 8, 16, 32)]
123-
eval_quantizer(prq, xq, xb, gt, xt, variants=variants)
222+
eval_quantizer(prq, xq, xb, gt, xt, metric_type, variants=variants)
124223

125224
if 'plsq' in todo:
126225
print(f"===== PLSQ{nsplits}x{Msub}x{nbits}")
127226
plsq = faiss.ProductLocalSearchQuantizer(d, nsplits, Msub, nbits)
128227
variants = [("encode_ils_iters", i) for i in (2, 3, 4, 8, 16)]
129-
eval_quantizer(plsq, xq, xb, gt, xt, variants=variants)
228+
eval_quantizer(plsq, xq, xb, gt, xt, metric_type, variants=variants)
130229

131230
if 'rq' in todo:
132231
print("===== RQ")
@@ -136,7 +235,7 @@ def eval_quantizer(q, xq, xb, gt, xt, variants=None):
136235
# rq.train_type = faiss.ResidualQuantizer.Train_default
137236
# rq.verbose = True
138237
variants = [("max_beam_size", i) for i in (1, 2, 4, 8, 16, 32)]
139-
eval_quantizer(rq, xq, xb, gt, xt, variants=variants)
238+
eval_quantizer(rq, xq, xb, gt, xt, metric_type, variants=variants)
140239

141240
if 'rq_lut' in todo:
142241
print("===== RQ")
@@ -148,10 +247,34 @@ def eval_quantizer(q, xq, xb, gt, xt, variants=None):
148247
# rq.train_type = faiss.ResidualQuantizer.Train_default
149248
# rq.verbose = True
150249
variants = [("max_beam_size", i) for i in (1, 2, 4, 8, 16, 32, 64)]
151-
eval_quantizer(rq, xq, xb, gt, xt, variants=variants)
250+
eval_quantizer(rq, xq, xb, gt, xt, metric_type, variants=variants)
152251

153252
if 'lsq' in todo:
154253
print("===== LSQ")
155254
lsq = faiss.LocalSearchQuantizer(d, M, nbits)
156255
variants = [("encode_ils_iters", i) for i in (2, 3, 4, 8, 16)]
157-
eval_quantizer(lsq, xq, xb, gt, xt, variants=variants)
256+
eval_quantizer(lsq, xq, xb, gt, xt, metric_type, variants=variants)
257+
258+
if selected_options & TURBOQUANT_OPTIONS:
259+
print("===== TurboQuant")
260+
store_norm = dataset_name != "glove"
261+
if hasattr(faiss, "IndexTurboQuantMSE"):
262+
tq = faiss.IndexTurboQuantMSE(d, nbits, metric_type, 12345, store_norm)
263+
elif hasattr(faiss, "TurboQuantizer"):
264+
tq = faiss.TurboQuantizer(d, nbits, 12345, store_norm)
265+
else:
266+
raise RuntimeError(
267+
"TurboQuant is not available in this faiss Python build. "
268+
"Rebuild the Python bindings so TurboQuant symbols are exported."
269+
)
270+
eval_quantizer(tq, xq, xb, gt, xt, metric_type)
271+
272+
if selected_options & RABITQ_OPTIONS:
273+
print("===== RaBitQ")
274+
if not hasattr(faiss, "IndexRaBitQ"):
275+
raise RuntimeError(
276+
"RaBitQ is not available in this faiss Python build. "
277+
"Rebuild the Python bindings so RaBitQ symbols are exported."
278+
)
279+
rbq = faiss.IndexRaBitQ(d, metric_type, nbits)
280+
eval_quantizer(rbq, xq, xb, gt, xt, metric_type)

0 commit comments

Comments
 (0)