1010# NOTEBOOK_NUMBER: N7030784 (685760243832285)
1111
1212""":py"""
13+ import statistics
1314import timeit
1415from collections import defaultdict
1516
1617import faiss
1718from faiss .contrib .datasets import SyntheticDataset
1819
1920""":py"""
# Dimensions to sweep. The rabitq SIMD kernel
# (bitwise_and_dot_product / bitwise_xor_dot_product / popcount) selects
# its widest tier based on size = d / 8 bytes:
#   d=256  -> size=32  -> 256-bit ymm only (no 512-bit work)
#   d=512  -> size=64  -> 512-bit zmm, 1 iteration per bit-plane
#   d=768  -> size=96  -> 512-bit zmm + 256-bit tail
#   d=1024 -> size=128 -> 512-bit zmm only, 2 iterations per bit-plane
# Sweeping these is useful for verifying the AVX512_SPR (vpopcntdq)
# specialization in faiss/utils/simd_impl/rabitq_avx512_spr.cpp and for
# profiling perf-record annotations across SIMD-width tiers.
DIMENSIONS = [256, 512, 768, 1024]
nlist: int = 1000
qb: int = 8
# Number of independent timing samples to take per (index, k, nprobe)
# combination. Each sample is itself an average over `trials=10` calls
# inside timeit, so total searches per row = ITERATIONS * 10. Using 3
# samples is enough to flag whether differences across dimensions are
# noise or real, while keeping the bench cheap.
ITERATIONS: int = 3
# This will contain <"index name", ([recalls],[speeds],[labels (the k)])>
recall_speed_data = defaultdict(lambda: [[], [], []])
# This will contain <"index name", ([recalls],[memory for this index])>
recall_memory_data = defaultdict(lambda: [[], []])

# Placeholder for the active dataset. It is rebound to a real
# SyntheticDataset at the start of each per-d block below; used by helpers
# that close over the active dataset. The annotation is a string so that it
# can honestly say "may be None" without being evaluated at import time.
ds: "SyntheticDataset | None" = None
48+
2849""":py"""
2950# Helpers
3051
@@ -62,6 +83,32 @@ def compute_recall(ground_truth_I, predicted_I):
6283 return recall
6384
6485
def repeated_trials(trials_fn, *args, n=None, **kwargs):
    """Call ``trials_fn(*args, **kwargs)`` ``n`` times and collect the results.

    Each call to ``trials_fn`` is expected to return one timing sample
    (the callers in this file format these samples as milliseconds), so
    the returned list holds ``n`` independent samples.

    Args:
        trials_fn: Callable producing one timing sample per call.
        *args: Positional arguments forwarded unchanged to ``trials_fn``.
        n: Number of samples to take. ``None`` (the default) means "use
            the module-level ``ITERATIONS``", looked up at call time so
            that re-running the configuration cell with a new value takes
            effect without redefining this helper.
        **kwargs: Keyword arguments forwarded unchanged to ``trials_fn``.

    Returns:
        A list with one entry per call, in call order. Empty when ``n``
        is 0 or negative.
    """
    if n is None:
        # Resolved lazily: a default of `n=ITERATIONS` in the signature
        # would freeze whatever value ITERATIONS had when this def ran.
        n = ITERATIONS
    return [trials_fn(*args, **kwargs) for _ in range(n)]
93+
94+
def summarize(samples):
    """Return ``(mean, median, stdev)`` for a collection of timing samples.

    Args:
        samples: Any iterable of numbers (list, tuple, generator, ...).
            It is materialized once so that one-shot iterables work too.

    Returns:
        A 3-tuple ``(mean, median, stdev)``. ``stdev`` is the sample
        standard deviation (n-1 denominator); it is reported as ``0.0``
        when there is a single sample, because it is undefined there.

    Raises:
        statistics.StatisticsError: If ``samples`` is empty.
    """
    # Several passes are made over the data (mean, median, stdev, len),
    # so snapshot it into a list first.
    values = list(samples)
    mean = statistics.mean(values)
    median = statistics.median(values)
    stdev = statistics.stdev(values) if len(values) > 1 else 0.0
    return mean, median, stdev
104+
105+
def fmt_speed(samples):
    """Format timing samples as ``'mean=X ms median=Y ms stdev=Z ms'``.

    The statistics come from ``summarize(samples)``; mean and median are
    printed with one decimal place and stdev with two.
    """
    mean, median, stdev = summarize(samples)
    return f"mean={mean:.1f} ms median={median:.1f} ms stdev={stdev:.2f} ms"
110+
111+
65112def create_index (ds , factory_string ):
66113 index = faiss .index_factory (ds .d , factory_string )
67114 index .train (ds .get_train ())
@@ -73,13 +120,14 @@ def create_index(ds, factory_string):
73120def handle_index (prefix , index , ds , mem , k ):
74121 gt_I = ds .get_groundtruth (k )
75122 _ , I_res = index .search (ds .get_queries (), k )
76- avg_speed = trials (index , ds .get_queries (), k )
123+ speed_samples = repeated_trials (trials , index , ds .get_queries (), k )
124+ mean_speed , _ , _ = summarize (speed_samples )
77125 recall = compute_recall (gt_I , I_res )
78126 print (
79- f"{ prefix } recall@{ k } : { recall } . Average speed : { avg_speed :.1f } ms . Memory: { mem / 1e6 :.3f} MB"
127+ f"{ prefix } recall@{ k } : { recall } . Speed : { fmt_speed ( speed_samples ) } . Memory: { mem / 1e6 :.3f} MB"
80128 )
81129 recall_speed_data [prefix ][0 ].append (recall )
82- recall_speed_data [prefix ][1 ].append (avg_speed )
130+ recall_speed_data [prefix ][1 ].append (mean_speed )
83131 recall_speed_data [prefix ][2 ].append (f"k={ k } " )
84132 recall_memory_data [prefix ][0 ].append (recall )
85133 recall_memory_data [prefix ][1 ].append (mem )
@@ -91,13 +139,16 @@ def handle_ivf_index(prefix, index, ds, mem, k, params):
91139 for nprobe in 4 , 16 , 32 :
92140 params .nprobe = nprobe
93141 _ , I_res = faiss .search_with_parameters (index , ds .get_queries (), k , params )
94- avg_speed = trials_ivf (index , ds .get_queries (), k , params )
142+ speed_samples = repeated_trials (
143+ trials_ivf , index , ds .get_queries (), k , params
144+ )
145+ mean_speed , _ , _ = summarize (speed_samples )
95146 recall = compute_recall (gt_I , I_res )
96147 print (
97- f"{ prefix } nprobe={ nprobe } : recall@{ k } : { recall } . Average speed : { avg_speed :.1f } ms . Memory: { mem / 1e6 :.3f} MB"
148+ f"{ prefix } nprobe={ nprobe } : recall@{ k } : { recall } . Speed : { fmt_speed ( speed_samples ) } . Memory: { mem / 1e6 :.3f} MB"
98149 )
99150 recall_speed_data [prefix ][0 ].append (recall )
100- recall_speed_data [prefix ][1 ].append (avg_speed )
151+ recall_speed_data [prefix ][1 ].append (mean_speed )
101152 recall_speed_data [prefix ][2 ].append (f"k={ k } , nprobe={ nprobe } " )
102153 recall_memory_data [prefix ][0 ].append (recall )
103154 recall_memory_data [prefix ][1 ].append (mem )
@@ -106,7 +157,7 @@ def handle_ivf_index(prefix, index, ds, mem, k, params):
106157# pyre-ignore
107158def vary_k_nprobe_measuring_recall_and_memory (prefix , index , ds , mem ):
108159 classname = type (index ).__name__
109- for k in 1 , 10 , 100 :
160+ for k in ( 100 ,) :
110161 if classname in [
111162 "IndexRaBitQ" ,
112163 "IndexPQFastScan" ,
@@ -131,51 +182,69 @@ def vary_k_nprobe_measuring_recall_and_memory(prefix, index, ds, mem):
131182 handle_ivf_index (prefix , index , ds , mem , k , params )
132183
133184""":py '605360559215064'"""
134- # IndexRaBitQ
135-
136- fac_s = "RaBitQ"
137- non_ivf_rbq = faiss .index_factory (ds .d , fac_s )
138- non_ivf_rbq .qb = qb
139- non_ivf_rbq .train (ds .get_train ())
140- non_ivf_rbq .add (ds .get_database ())
141- mem = non_ivf_rbq .code_size * non_ivf_rbq .ntotal
142-
143- vary_k_nprobe_measuring_recall_and_memory (fac_s , non_ivf_rbq , ds , mem )
144-
145- del non_ivf_rbq
146-
147- """:py '3928150077498381'"""
148- # IndexIVFRaBitQ with no random rotation
149-
150- fac_s = f"IVF{ nlist } ,RaBitQ"
151- rbq1 = faiss .index_factory (ds .d , fac_s )
152- rbq1 .qb = qb
153- rbq1 .train (ds .get_train ())
154- rbq1 .add (ds .get_database ())
155- mem = rbq1 .code_size * rbq1 .ntotal
156-
157- vary_k_nprobe_measuring_recall_and_memory (fac_s , rbq1 , ds , mem )
158-
159- del rbq1
160-
161- """:py '1484145352968190'"""
162- # IndexIVFRaBitQ with random rotation
163-
164- fac_s = f"IVF{ nlist } ,RaBitQ"
165- rbq2 = faiss .index_factory (ds .d , fac_s )
166- rbq2 .qb = qb
167- rrot = faiss .RandomRotationMatrix (ds .d , ds .d )
168- rrot .init (123 )
169- index_pt = faiss .IndexPreTransform (rrot , rbq2 )
170- index_pt .train (ds .get_train ())
171- index_pt .add (ds .get_database ())
172- mem = rbq2 .code_size * index_pt .ntotal
173-
174- vary_k_nprobe_measuring_recall_and_memory (fac_s + "_RROT" , index_pt , ds , mem )
# RaBitQ kernels swept across dimensions. Each iteration rebuilds the
# dataset and the three rabitq index variants. Suffix _d{d} on the
# result key keeps the per-dimension series distinct in the plots.

for d in DIMENSIONS:
    print(f"\n========== d={d} ==========")
    # Dataset sized to keep the full 4-dimension sweep under ~10 minutes.
    # nq=1k is enough for stable timeit averages across 10 trials; nb=200k
    # keeps groundtruth (brute-force knn over xb) tractable at d=1024;
    # nt=100k still satisfies the IVF k-means training-points floor for
    # nlist=1000 (39 × 1000 = 39k minimum).
    ds = SyntheticDataset(d, 100_000, 200_000, 1_000)

    # IndexRaBitQ
    fac_s = "RaBitQ"
    non_ivf_rbq = faiss.index_factory(ds.d, fac_s)
    non_ivf_rbq.qb = qb
    non_ivf_rbq.train(ds.get_train())
    non_ivf_rbq.add(ds.get_database())
    mem = non_ivf_rbq.code_size * non_ivf_rbq.ntotal

    vary_k_nprobe_measuring_recall_and_memory(f"{fac_s}_d{d}", non_ivf_rbq, ds, mem)

    del non_ivf_rbq

    # IndexIVFRaBitQ with no random rotation
    fac_s = f"IVF{nlist},RaBitQ"
    rbq1 = faiss.index_factory(ds.d, fac_s)
    rbq1.qb = qb
    rbq1.train(ds.get_train())
    rbq1.add(ds.get_database())
    mem = rbq1.code_size * rbq1.ntotal

    vary_k_nprobe_measuring_recall_and_memory(f"{fac_s}_d{d}", rbq1, ds, mem)

    del rbq1

    # IndexIVFRaBitQ with random rotation
    fac_s = f"IVF{nlist},RaBitQ"
    rbq2 = faiss.index_factory(ds.d, fac_s)
    rbq2.qb = qb
    rrot = faiss.RandomRotationMatrix(ds.d, ds.d)
    rrot.init(123)
    index_pt = faiss.IndexPreTransform(rrot, rbq2)
    index_pt.train(ds.get_train())
    index_pt.add(ds.get_database())
    # Memory is accounted on the wrapped IVF index's codes only.
    mem = rbq2.code_size * index_pt.ntotal

    vary_k_nprobe_measuring_recall_and_memory(
        f"{fac_s}_RROT_d{d}", index_pt, ds, mem
    )

    # Unbind the wrapper AND the objects it was built from. Deleting only
    # index_pt would leave rbq2 / rrot bound, keeping this dimension's IVF
    # index alive while the next dimension's dataset and indexes are built.
    del index_pt, rbq2, rrot
177238
178239""":py '644702398382829'"""
# Non-rabitq baselines (SQ, PQfs, HNSW) below. These don't exercise the
# rabitq SIMD kernels, so we don't sweep dimensions for them; instead
# we pick one dimension and build them once. Change BASELINE_D if you
# want a different working point, or comment out the cells below if
# you only care about the rabitq sweep.
BASELINE_D = 256
# Rebind the module-level dataset so the baseline cells that follow read
# this one. Arguments are (d, nt, nb, nq), the same sizes as the sweep.
ds = SyntheticDataset(BASELINE_D, 100_000, 200_000, 1_000)
247+
179248# IndexScalarQuantizer
180249
181250for M in [4 , 6 , 8 ]:
@@ -270,7 +339,7 @@ def vary_k_nprobe_measuring_recall_and_memory(prefix, index, ds, mem):
270339 speeds ,
271340 linestyle = " " ,
272341 marker = "o" ,
273- color = colors [i ],
342+ color = colors [i % len ( colors ) ],
274343 label = key ,
275344 markersize = 15 ,
276345 )
@@ -311,15 +380,15 @@ def vary_k_nprobe_measuring_recall_and_memory(prefix, index, ds, mem):
311380 mems ,
312381 linestyle = " " ,
313382 marker = "o" ,
314- color = colors [i ],
383+ color = colors [i % len ( colors ) ],
315384 label = key ,
316385 markersize = 10 ,
317386 )
318387
319388 texts = []
320389 if i == 0 :
321- texts . append ( plt . text ( recalls [ 0 ], mems [ 0 ], "RaBitQ" ))
322- texts .append (plt .text (recalls [1 ], mems [1 ], "RaBitQ" ))
390+ for j in range ( min ( 2 , len ( recalls ))):
391+ texts .append (plt .text (recalls [j ], mems [j ], "RaBitQ" ))
323392 adjust_text (
324393 texts ,
325394 arrowprops = dict (arrowstyle = "-" , color = "black" , lw = 0.5 ),
0 commit comments