1010
1111try :
1212 from faiss .contrib .datasets_fb import \
13- DatasetSIFT1M , DatasetDeep1B , DatasetBigANN
13+ DatasetSIFT1M , DatasetDeep1B , DatasetBigANN , DatasetGlove
1414except ImportError :
1515 from faiss .contrib .datasets import \
16- DatasetSIFT1M , DatasetDeep1B , DatasetBigANN
16+ DatasetSIFT1M , DatasetDeep1B , DatasetBigANN , DatasetGlove , DatasetDBpedia1536_1M , DatasetDBpedia3072_1M
1717
1818
19- def eval_codec (q , xq , xb , gt ):
# Command-line aliases accepted for the TurboQuant and RaBitQ benchmarks.
TURBOQUANT_OPTIONS = {"tq", "turboquant"}
RABITQ_OPTIONS = {"rbq", "rabitq"}
21+
22+
def get_metric_type(ds):
    """Map the dataset's metric string to the corresponding faiss metric enum.

    "IP" -> faiss.METRIC_INNER_PRODUCT, "L2" -> faiss.METRIC_L2.
    Raises RuntimeError for any other metric string.
    """
    metric_map = {
        "IP": faiss.METRIC_INNER_PRODUCT,
        "L2": faiss.METRIC_L2,
    }
    try:
        return metric_map[ds.metric]
    except KeyError:
        raise RuntimeError(f"unsupported dataset metric {ds.metric}") from None
29+
30+
def get_training_vectors(ds, xb, maxtrain):
    """Return up to *maxtrain* training vectors for dataset *ds*.

    Falls back to the first *maxtrain* database vectors of *xb* when the
    dataset does not provide a dedicated training set.
    """
    try:
        train = ds.get_train(maxtrain=maxtrain)
    except NotImplementedError:
        print("No training set: training on database")
        train = xb[:maxtrain]
    return train
37+
38+
def encode(codec, x):
    """Encode vectors *x* with *codec*, whichever codec API it exposes.

    Quantizer objects expose the compute_codes/decode pair; index objects
    expose sa_encode/sa_decode. Raises TypeError for anything else.
    """
    is_quantizer = hasattr(codec, "compute_codes") and hasattr(codec, "decode")
    if is_quantizer:
        return codec.compute_codes(x)
    is_index = hasattr(codec, "sa_encode") and hasattr(codec, "sa_decode")
    if is_index:
        return codec.sa_encode(x)
    raise TypeError(f"unsupported codec type {type(codec).__name__}")
45+
46+
def decode(codec, codes):
    """Decode *codes* back to vectors with *codec* (counterpart of encode).

    Uses decode for quantizer-style codecs, sa_decode for index-style ones;
    raises TypeError when neither API is present.
    """
    is_quantizer = hasattr(codec, "compute_codes") and hasattr(codec, "decode")
    if is_quantizer:
        return codec.decode(codes)
    is_index = hasattr(codec, "sa_encode") and hasattr(codec, "sa_decode")
    if is_index:
        return codec.sa_decode(codes)
    raise TypeError(f"unsupported codec type {type(codec).__name__}")
53+
54+
def get_code_size(codec):
    """Best-effort code size in bytes per vector; None when not reported.

    Quantizers expose a code_size attribute, indexes a sa_code_size()
    method; either is coerced to int.
    """
    for attr, is_method in (("code_size", False), ("sa_code_size", True)):
        if hasattr(codec, attr):
            value = getattr(codec, attr)
            return int(value() if is_method else value)
    return None
61+
62+
def eval_codec(q, xq, xb, gt, metric_type):
    """Evaluate codec *q*: reconstruction error and symmetric 1-recall@1.

    Encodes/decodes the database *xb* and the queries *xq*, then runs an
    exhaustive 1-NN search between the two decoded sets under *metric_type*
    and compares against the ground truth *gt*. Prints one summary line.

    Fix over the previous revision: the recall denominator is now taken
    from xq itself instead of the module-level global ``nq``, so the
    function no longer depends on hidden global state (and stays correct
    when called with transformed queries such as the OPQ-rotated xq2).
    """
    t0 = time.time()
    codes = encode(q, xb)
    t1 = time.time()
    xb_decoded = decode(q, codes)
    # mean squared reconstruction error over the database
    recons_err = ((xb - xb_decoded) ** 2).sum() / xb.shape[0]
    # for compatibility with the codec benchmarks
    err_compat = np.linalg.norm(xb - xb_decoded, axis=1).mean()
    xq_decoded = decode(q, encode(q, xq))
    # symmetric search: decoded queries against the decoded database
    _, I = faiss.knn(xq_decoded, xb_decoded, 1, metric=metric_type)
    recall = (I[:, 0] == gt[:, 0]).sum() / xq.shape[0]
    code_size = get_code_size(q)
    code_size_s = (
        f" code_size: {code_size}B/vector"
        if code_size is not None
        else ""
    )
    print(
        f"\tencode time: {t1 - t0:.3f} reconstruction error: {recons_err:.3f} "
        f"recall@1: {recall:.4f} recons_err_compat {err_compat:.3f}"
        f"{code_size_s}")
3384
3485
35- def eval_quantizer (q , xq , xb , gt , xt , variants = None ):
86+ def eval_quantizer (q , xq , xb , gt , xt , metric_type , variants = None ):
3687 if variants is None :
3788 variants = [(None , None )]
3889 t0 = time .time ()
@@ -53,20 +104,41 @@ def eval_quantizer(q, xq, xb, gt, xt, variants=None):
53104 getattr (q , name ) # make sure field exists
54105 setattr (q , name , val )
55106
56- eval_codec (q , xq , xb , gt )
107+ eval_codec (q , xq , xb , gt , metric_type )
57108
58109
todo = sys.argv[1:]

# Optional leading argument selects the benchmark dataset (default sift1M);
# anything unrecognized is left in todo for the codec-spec parsing below.
dataset_name = "sift1M"
known_datasets = (
    "sift1M", "deep1M", "bigann1M", "glove",
    "dbpedia-1536-1M", "dbpedia-3072-1M",
)
if todo and todo[0] in known_datasets:
    dataset_name = todo.pop(0)

if dataset_name == "deep1M":
    ds = DatasetDeep1B(10**6)
elif dataset_name == "bigann1M":
    ds = DatasetBigANN(nb_M=1)
elif dataset_name == "glove":
    ds = DatasetGlove()
elif dataset_name == "dbpedia-1536-1M":
    ds = DatasetDBpedia1536_1M()
elif dataset_name == "dbpedia-3072-1M":
    ds = DatasetDBpedia3072_1M()
else:
    ds = DatasetSIFT1M()

# Codec size parameters, filled in by the argument parsing below.
M = None
nsplits = None
Msub = None
nbits = None
70142if len (todo ) > 0 :
71143 if todo [0 ].count ("x" ) == 1 :
72144 M , nbits = [int (x ) for x in todo [0 ].split ("x" )]
@@ -75,15 +147,42 @@ def eval_quantizer(q, xq, xb, gt, xt, variants=None):
75147 nsplits , Msub , nbits = [int (x ) for x in todo [0 ].split ("x" )]
76148 M = nsplits * Msub
77149 del todo [0 ]
150+ elif todo [0 ].isdigit ():
151+ nbits = int (todo [0 ])
152+ del todo [0 ]
153+
selected_options = set(todo)

# A bit specification is mandatory for every benchmark.
if nbits is None:
    raise RuntimeError(
        "expected a codec bit specification: Mxnbits, nsplitsxMsubxnbits, "
        "or plain nbits for turboquant/rabitq"
    )

# Per-codec supported nbits ranges.
if selected_options & TURBOQUANT_OPTIONS and (nbits < 1 or nbits > 8):
    raise RuntimeError("TurboQuant supports nbits in [1, 8]")

if selected_options & RABITQ_OPTIONS and (nbits < 1 or nbits > 9):
    raise RuntimeError("RaBitQ supports nbits in [1, 9]")

# The multi-codebook quantizer benchmarks additionally require M.
needs_M = {"pq", "opq", "rq", "rq_lut", "lsq", "lsq-gpu"}
if M is None and selected_options & needs_M:
    raise RuntimeError("expected Mxnbits for pq/opq/rq/rq_lut/lsq benchmarks")

if M is None and selected_options & {"prq", "plsq"}:
    raise RuntimeError("expected nsplitsxMsubxnbits for prq/plsq benchmarks")
78173
# At least 100 training vectors per centroid, never fewer than 1e5 overall.
maxtrain = max(100 << nbits, 10**5)
spec = f"{nbits}-bit" if M is None else f"{M}x{nbits}"
print(f"eval on {dataset_name} {spec} maxtrain={maxtrain}")

xq = ds.get_queries()
xb = ds.get_database()
gt = ds.get_groundtruth()
metric_type = get_metric_type(ds)

xt = get_training_vectors(ds, xb, maxtrain=maxtrain)

nb, d = xb.shape
nq, d = xq.shape
@@ -97,12 +196,12 @@ def eval_quantizer(q, xq, xb, gt, xt, variants=None):
97196 ngpus = faiss .get_num_gpus ()
98197 lsq .icm_encoder_factory = faiss .GpuIcmEncoderFactory (ngpus )
99198 lsq .verbose = True
100- eval_quantizer (lsq , xb , xt , 'lsq-gpu' )
199+ eval_quantizer (lsq , xq , xb , gt , xt , metric_type )
101200
if 'pq' in todo:
    # Plain product quantizer baseline.
    print("===== PQ")
    pq = faiss.ProductQuantizer(d, M, nbits)
    eval_quantizer(pq, xq, xb, gt, xt, metric_type)
106205
107206if 'opq' in todo :
108207 d2 = ((d + M - 1 ) // M ) * M
@@ -114,19 +213,19 @@ def eval_quantizer(q, xq, xb, gt, xt, variants=None):
114213 xt2 = opq .apply (xt )
115214 pq = faiss .ProductQuantizer (d2 , M , nbits )
116215 print ("===== PQ" )
117- eval_quantizer (pq , xq2 , xb2 , gt , xt2 )
216+ eval_quantizer (pq , xq2 , xb2 , gt , xt2 , metric_type )
118217
if 'prq' in todo:
    # Product-residual quantizer, swept over encoding beam sizes.
    print(f"===== PRQ{nsplits}x{Msub}x{nbits}")
    prq = faiss.ProductResidualQuantizer(d, nsplits, Msub, nbits)
    beam_sizes = (1, 2, 4, 8, 16, 32)
    variants = [("max_beam_size", b) for b in beam_sizes]
    eval_quantizer(prq, xq, xb, gt, xt, metric_type, variants=variants)
124223
if 'plsq' in todo:
    # Product local-search quantizer, swept over ILS iteration counts.
    print(f"===== PLSQ{nsplits}x{Msub}x{nbits}")
    plsq = faiss.ProductLocalSearchQuantizer(d, nsplits, Msub, nbits)
    ils_iters = (2, 3, 4, 8, 16)
    variants = [("encode_ils_iters", it) for it in ils_iters]
    eval_quantizer(plsq, xq, xb, gt, xt, metric_type, variants=variants)
130229
131230if 'rq' in todo :
132231 print ("===== RQ" )
@@ -136,7 +235,7 @@ def eval_quantizer(q, xq, xb, gt, xt, variants=None):
136235 # rq.train_type = faiss.ResidualQuantizer.Train_default
137236 # rq.verbose = True
138237 variants = [("max_beam_size" , i ) for i in (1 , 2 , 4 , 8 , 16 , 32 )]
139- eval_quantizer (rq , xq , xb , gt , xt , variants = variants )
238+ eval_quantizer (rq , xq , xb , gt , xt , metric_type , variants = variants )
140239
141240if 'rq_lut' in todo :
142241 print ("===== RQ" )
@@ -148,10 +247,34 @@ def eval_quantizer(q, xq, xb, gt, xt, variants=None):
148247 # rq.train_type = faiss.ResidualQuantizer.Train_default
149248 # rq.verbose = True
150249 variants = [("max_beam_size" , i ) for i in (1 , 2 , 4 , 8 , 16 , 32 , 64 )]
151- eval_quantizer (rq , xq , xb , gt , xt , variants = variants )
250+ eval_quantizer (rq , xq , xb , gt , xt , metric_type , variants = variants )
152251
if 'lsq' in todo:
    # Local-search quantizer (CPU), swept over ILS iteration counts.
    print("===== LSQ")
    lsq = faiss.LocalSearchQuantizer(d, M, nbits)
    ils_iters = (2, 3, 4, 8, 16)
    variants = [("encode_ils_iters", it) for it in ils_iters]
    eval_quantizer(lsq, xq, xb, gt, xt, metric_type, variants=variants)
257+
if selected_options & TURBOQUANT_OPTIONS:
    print("===== TurboQuant")
    # Per-vector norms are not stored for the glove dataset.
    store_norm = dataset_name != "glove"
    has_index = hasattr(faiss, "IndexTurboQuantMSE")
    has_quantizer = hasattr(faiss, "TurboQuantizer")
    if not (has_index or has_quantizer):
        raise RuntimeError(
            "TurboQuant is not available in this faiss Python build. "
            "Rebuild the Python bindings so TurboQuant symbols are exported."
        )
    if has_index:
        tq = faiss.IndexTurboQuantMSE(d, nbits, metric_type, 12345, store_norm)
    else:
        tq = faiss.TurboQuantizer(d, nbits, 12345, store_norm)
    eval_quantizer(tq, xq, xb, gt, xt, metric_type)
271+
if selected_options & RABITQ_OPTIONS:
    print("===== RaBitQ")
    index_cls = getattr(faiss, "IndexRaBitQ", None)
    if index_cls is None:
        raise RuntimeError(
            "RaBitQ is not available in this faiss Python build. "
            "Rebuild the Python bindings so RaBitQ symbols are exported."
        )
    rbq = index_cls(d, metric_type, nbits)
    eval_quantizer(rbq, xq, xb, gt, xt, metric_type)
0 commit comments