Skip to content

Commit db91a6c

Browse files
committed
Add VPOPCNTDQ-based RaBitQ kernel for Sapphire Rapids+
Add SIMDLevel::AVX512_SPR specializations of bitwise_and_dot_product, bitwise_xor_dot_product, and popcount that use vpopcntq instead of the shuffle-based lookup-table popcount. They are used on Sapphire Rapids and later (SPR+) when built with FAISS_OPT_LEVEL=avx512_spr or FAISS_OPT_LEVEL=dd; AVX-512-only builds are unchanged. Recall is bit-identical. Speedup: ~1.07x end-to-end on IndexIVFRaBitQ, ~1.4-1.6x on flat RaBitQ scans (kernel-isolated). Measured on Sapphire Rapids with benchs/bench_rabitq.py at d=256/512/768/1024. Signed-off-by: Mulugeta Mammo <mulugeta.mammo@intel.com>
1 parent abdd37b commit db91a6c

6 files changed

Lines changed: 526 additions & 58 deletions

File tree

benchs/bench_rabitq.py

Lines changed: 123 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,42 @@
1010
# NOTEBOOK_NUMBER: N7030784 (685760243832285)
1111

1212
""":py"""
13+
import statistics
1314
import timeit
1415
from collections import defaultdict
1516

1617
import faiss
1718
from faiss.contrib.datasets import SyntheticDataset
1819

1920
""":py"""
20-
ds: SyntheticDataset = SyntheticDataset(256, 1_000_000, 1_000_000, 10_000)
21+
# Dimensions to sweep. The rabitq SIMD kernel
22+
# (bitwise_and_dot_product / bitwise_xor_dot_product / popcount) selects
23+
# its widest tier based on size = d / 8 bytes:
24+
# d=256 -> size=32 -> 256-bit ymm only (no 512-bit work)
25+
# d=512 -> size=64 -> 512-bit zmm, 1 iteration per bit-plane
26+
# d=768 -> size=96 -> 512-bit zmm + 256-bit tail
27+
# d=1024 -> size=128 -> 512-bit zmm only, 2 iterations per bit-plane
28+
# Sweeping these is useful for verifying the AVX512_SPR (vpopcntdq)
29+
# specialization in faiss/utils/simd_impl/rabitq_avx512_spr.cpp and for
30+
# profiling perf-record annotations across SIMD-width tiers.
31+
DIMENSIONS = [256, 512, 768, 1024]
2132
nlist: int = 1000
2233
qb: int = 8
34+
# Number of independent timing samples to take per (index, k, nprobe)
35+
# combination. Each sample is itself an average over `trials=10` calls
36+
# inside timeit, so total searches per row = ITERATIONS * 10. Using 3
37+
# samples is enough to flag whether differences across dimensions are
38+
# noise or real, while keeping the bench cheap.
39+
ITERATIONS: int = 3
2340
# This will contain <"index name", ([recalls],[speeds],[labels ("k=..." or "k=..., nprobe=...")])>
2441
recall_speed_data = defaultdict(lambda: [[], [], []])
2542
# This will contain <"index name", ([recalls],[memory for this index])>
2643
recall_memory_data = defaultdict(lambda: [[], []])
2744

45+
# Set when entering each per-d block below; used by helpers that close
46+
# over the active dataset.
47+
ds: SyntheticDataset = None # type: ignore
48+
2849
""":py"""
2950
# Helpers
3051

@@ -62,6 +83,32 @@ def compute_recall(ground_truth_I, predicted_I):
6283
return recall
6384

6485

86+
def repeated_trials(trials_fn, *args, n=ITERATIONS, **kwargs):
87+
"""Run a trials function n times and return the list of per-iteration
88+
average speeds (each in ms). Each call to trials_fn is itself an
89+
average over multiple back-to-back searches, so the returned list
90+
contains n independent samples of that average.
91+
"""
92+
return [trials_fn(*args, **kwargs) for _ in range(n)]
93+
94+
95+
def summarize(samples):
96+
"""Return (mean, median, stdev) over a list of timing samples in ms.
97+
stdev is the sample standard deviation (n-1); returns 0.0 for n==1
98+
since stdev is undefined.
99+
"""
100+
mean = statistics.mean(samples)
101+
median = statistics.median(samples)
102+
stdev = statistics.stdev(samples) if len(samples) > 1 else 0.0
103+
return mean, median, stdev
104+
105+
106+
def fmt_speed(samples):
107+
"""Format a list of timing samples as 'mean=X median=Y stdev=Z'."""
108+
mean, median, stdev = summarize(samples)
109+
return f"mean={mean:.1f}ms median={median:.1f}ms stdev={stdev:.2f}ms"
110+
111+
65112
def create_index(ds, factory_string):
66113
index = faiss.index_factory(ds.d, factory_string)
67114
index.train(ds.get_train())
@@ -73,13 +120,14 @@ def create_index(ds, factory_string):
73120
def handle_index(prefix, index, ds, mem, k):
74121
gt_I = ds.get_groundtruth(k)
75122
_, I_res = index.search(ds.get_queries(), k)
76-
avg_speed = trials(index, ds.get_queries(), k)
123+
speed_samples = repeated_trials(trials, index, ds.get_queries(), k)
124+
mean_speed, _, _ = summarize(speed_samples)
77125
recall = compute_recall(gt_I, I_res)
78126
print(
79-
f"{prefix} recall@{k}: {recall}. Average speed: {avg_speed:.1f}ms. Memory: {mem/1e6:.3f}MB"
127+
f"{prefix} recall@{k}: {recall}. Speed: {fmt_speed(speed_samples)}. Memory: {mem/1e6:.3f}MB"
80128
)
81129
recall_speed_data[prefix][0].append(recall)
82-
recall_speed_data[prefix][1].append(avg_speed)
130+
recall_speed_data[prefix][1].append(mean_speed)
83131
recall_speed_data[prefix][2].append(f"k={k}")
84132
recall_memory_data[prefix][0].append(recall)
85133
recall_memory_data[prefix][1].append(mem)
@@ -91,13 +139,16 @@ def handle_ivf_index(prefix, index, ds, mem, k, params):
91139
for nprobe in 4, 16, 32:
92140
params.nprobe = nprobe
93141
_, I_res = faiss.search_with_parameters(index, ds.get_queries(), k, params)
94-
avg_speed = trials_ivf(index, ds.get_queries(), k, params)
142+
speed_samples = repeated_trials(
143+
trials_ivf, index, ds.get_queries(), k, params
144+
)
145+
mean_speed, _, _ = summarize(speed_samples)
95146
recall = compute_recall(gt_I, I_res)
96147
print(
97-
f"{prefix} nprobe={nprobe}: recall@{k}: {recall}. Average speed: {avg_speed:.1f}ms. Memory: {mem/1e6:.3f}MB"
148+
f"{prefix} nprobe={nprobe}: recall@{k}: {recall}. Speed: {fmt_speed(speed_samples)}. Memory: {mem/1e6:.3f}MB"
98149
)
99150
recall_speed_data[prefix][0].append(recall)
100-
recall_speed_data[prefix][1].append(avg_speed)
151+
recall_speed_data[prefix][1].append(mean_speed)
101152
recall_speed_data[prefix][2].append(f"k={k}, nprobe={nprobe}")
102153
recall_memory_data[prefix][0].append(recall)
103154
recall_memory_data[prefix][1].append(mem)
@@ -106,7 +157,7 @@ def handle_ivf_index(prefix, index, ds, mem, k, params):
106157
# pyre-ignore
107158
def vary_k_nprobe_measuring_recall_and_memory(prefix, index, ds, mem):
108159
classname = type(index).__name__
109-
for k in 1, 10, 100:
160+
for k in (100,):
110161
if classname in [
111162
"IndexRaBitQ",
112163
"IndexPQFastScan",
@@ -131,51 +182,69 @@ def vary_k_nprobe_measuring_recall_and_memory(prefix, index, ds, mem):
131182
handle_ivf_index(prefix, index, ds, mem, k, params)
132183

133184
""":py '605360559215064'"""
134-
# IndexRaBitQ
135-
136-
fac_s = "RaBitQ"
137-
non_ivf_rbq = faiss.index_factory(ds.d, fac_s)
138-
non_ivf_rbq.qb = qb
139-
non_ivf_rbq.train(ds.get_train())
140-
non_ivf_rbq.add(ds.get_database())
141-
mem = non_ivf_rbq.code_size * non_ivf_rbq.ntotal
142-
143-
vary_k_nprobe_measuring_recall_and_memory(fac_s, non_ivf_rbq, ds, mem)
144-
145-
del non_ivf_rbq
146-
147-
""":py '3928150077498381'"""
148-
# IndexIVFRaBitQ with no random rotation
149-
150-
fac_s = f"IVF{nlist},RaBitQ"
151-
rbq1 = faiss.index_factory(ds.d, fac_s)
152-
rbq1.qb = qb
153-
rbq1.train(ds.get_train())
154-
rbq1.add(ds.get_database())
155-
mem = rbq1.code_size * rbq1.ntotal
156-
157-
vary_k_nprobe_measuring_recall_and_memory(fac_s, rbq1, ds, mem)
158-
159-
del rbq1
160-
161-
""":py '1484145352968190'"""
162-
# IndexIVFRaBitQ with random rotation
163-
164-
fac_s = f"IVF{nlist},RaBitQ"
165-
rbq2 = faiss.index_factory(ds.d, fac_s)
166-
rbq2.qb = qb
167-
rrot = faiss.RandomRotationMatrix(ds.d, ds.d)
168-
rrot.init(123)
169-
index_pt = faiss.IndexPreTransform(rrot, rbq2)
170-
index_pt.train(ds.get_train())
171-
index_pt.add(ds.get_database())
172-
mem = rbq2.code_size * index_pt.ntotal
173-
174-
vary_k_nprobe_measuring_recall_and_memory(fac_s + "_RROT", index_pt, ds, mem)
185+
# RaBitQ kernels swept across dimensions. Each iteration rebuilds the
186+
# dataset and the three rabitq index variants. Suffix _d{d} on the
187+
# result key keeps the per-dimension series distinct in the plots.
188+
189+
for d in DIMENSIONS:
190+
print(f"\n========== d={d} ==========")
191+
# Dataset sized to keep the full 4-dimension sweep under ~10 minutes.
192+
# nq=1k is enough for stable timeit averages across 10 trials; nb=200k
193+
# keeps groundtruth (brute-force knn over xb) tractable at d=1024;
194+
# nt=100k still satisfies the IVF k-means training-points floor for
195+
# nlist=1000 (39 × 1000 = 39k minimum).
196+
ds = SyntheticDataset(d, 100_000, 200_000, 1_000)
197+
198+
# IndexRaBitQ
199+
fac_s = "RaBitQ"
200+
non_ivf_rbq = faiss.index_factory(ds.d, fac_s)
201+
non_ivf_rbq.qb = qb
202+
non_ivf_rbq.train(ds.get_train())
203+
non_ivf_rbq.add(ds.get_database())
204+
mem = non_ivf_rbq.code_size * non_ivf_rbq.ntotal
205+
206+
vary_k_nprobe_measuring_recall_and_memory(f"{fac_s}_d{d}", non_ivf_rbq, ds, mem)
207+
208+
del non_ivf_rbq
209+
210+
# IndexIVFRaBitQ with no random rotation
211+
fac_s = f"IVF{nlist},RaBitQ"
212+
rbq1 = faiss.index_factory(ds.d, fac_s)
213+
rbq1.qb = qb
214+
rbq1.train(ds.get_train())
215+
rbq1.add(ds.get_database())
216+
mem = rbq1.code_size * rbq1.ntotal
217+
218+
vary_k_nprobe_measuring_recall_and_memory(f"{fac_s}_d{d}", rbq1, ds, mem)
219+
220+
del rbq1
221+
222+
# IndexIVFRaBitQ with random rotation
223+
fac_s = f"IVF{nlist},RaBitQ"
224+
rbq2 = faiss.index_factory(ds.d, fac_s)
225+
rbq2.qb = qb
226+
rrot = faiss.RandomRotationMatrix(ds.d, ds.d)
227+
rrot.init(123)
228+
index_pt = faiss.IndexPreTransform(rrot, rbq2)
229+
index_pt.train(ds.get_train())
230+
index_pt.add(ds.get_database())
231+
mem = rbq2.code_size * index_pt.ntotal
232+
233+
vary_k_nprobe_measuring_recall_and_memory(
234+
f"{fac_s}_RROT_d{d}", index_pt, ds, mem
235+
)
175236

176-
del index_pt
237+
del index_pt
177238

178239
""":py '644702398382829'"""
240+
# Non-rabitq baselines (SQ, PQfs, HNSW) below. These don't exercise the
241+
# rabitq SIMD kernels, so we don't sweep dimensions for them; instead
242+
# we pick one dimension and build them once. Change BASELINE_D if you
243+
# want a different working point, or comment out the cells below if
244+
# you only care about the rabitq sweep.
245+
BASELINE_D = 256
246+
ds = SyntheticDataset(BASELINE_D, 100_000, 200_000, 1_000)
247+
179248
# IndexScalarQuantizer
180249

181250
for M in [4, 6, 8]:
@@ -270,7 +339,7 @@ def vary_k_nprobe_measuring_recall_and_memory(prefix, index, ds, mem):
270339
speeds,
271340
linestyle=" ",
272341
marker="o",
273-
color=colors[i],
342+
color=colors[i % len(colors)],
274343
label=key,
275344
markersize=15,
276345
)
@@ -311,15 +380,15 @@ def vary_k_nprobe_measuring_recall_and_memory(prefix, index, ds, mem):
311380
mems,
312381
linestyle=" ",
313382
marker="o",
314-
color=colors[i],
383+
color=colors[i % len(colors)],
315384
label=key,
316385
markersize=10,
317386
)
318387

319388
texts = []
320389
if i == 0:
321-
texts.append(plt.text(recalls[0], mems[0], "RaBitQ"))
322-
texts.append(plt.text(recalls[1], mems[1], "RaBitQ"))
390+
for j in range(min(2, len(recalls))):
391+
texts.append(plt.text(recalls[j], mems[j], "RaBitQ"))
323392
adjust_text(
324393
texts,
325394
arrowprops=dict(arrowstyle="-", color="black", lw=0.5),

faiss/CMakeLists.txt

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ set(FAISS_SIMD_AVX512_SRC
3434
utils/simd_impl/rabitq_avx512.cpp
3535
utils/simd_impl/super_kmeans_kernels_avx512.cpp
3636
)
37+
# AVX-512 sources that additionally require AVX512_VPOPCNTDQ
38+
# (Sapphire Rapids and later). Compiled only into faiss_avx512_spr
39+
# and the DD faiss target (the latter with an extra per-file -mavx512vpopcntdq
40+
# flag). Not compiled into faiss_avx512 — that target stops at the
41+
# baseline AVX-512F/CD/VL/DQ/BW feature set.
42+
set(FAISS_SIMD_AVX512_SPR_SRC
43+
utils/simd_impl/rabitq_avx512_spr.cpp
44+
)
3745
set(FAISS_SIMD_NEON_SRC
3846
impl/fast_scan/impl-neon.cpp
3947
impl/scalar_quantizer/sq-neon.cpp
@@ -61,7 +69,7 @@ set(FAISS_SIMD_RVV_SRC
6169
)
6270
# Select SIMD sources based on target architecture
6371
if(CMAKE_SYSTEM_PROCESSOR MATCHES "(x86_64|amd64|AMD64)")
64-
set(FAISS_SIMD_SRC ${FAISS_SIMD_AVX2_SRC} ${FAISS_SIMD_AVX512_SRC})
72+
set(FAISS_SIMD_SRC ${FAISS_SIMD_AVX2_SRC} ${FAISS_SIMD_AVX512_SRC} ${FAISS_SIMD_AVX512_SPR_SRC})
6573
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64|arm64|ARM64)")
6674
set(FAISS_SIMD_SRC ${FAISS_SIMD_NEON_SRC} ${FAISS_SIMD_SVE_SRC})
6775
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "(riscv64|riscv)")
@@ -467,7 +475,7 @@ else()
467475
# we need bigobj for the swig wrapper
468476
add_compile_options(/bigobj)
469477
endif()
470-
target_sources(faiss_avx512_spr PRIVATE ${FAISS_SIMD_AVX2_SRC} ${FAISS_SIMD_AVX512_SRC})
478+
target_sources(faiss_avx512_spr PRIVATE ${FAISS_SIMD_AVX2_SRC} ${FAISS_SIMD_AVX512_SRC} ${FAISS_SIMD_AVX512_SPR_SRC})
471479
target_compile_definitions(faiss_avx512_spr PRIVATE COMPILE_SIMD_AVX2 COMPILE_SIMD_AVX512 COMPILE_SIMD_AVX512_SPR )
472480

473481
add_library(faiss_sve ${FAISS_SRC})
@@ -525,6 +533,14 @@ if(FAISS_OPT_LEVEL STREQUAL "dd")
525533
PROPERTIES COMPILE_OPTIONS
526534
"-mavx512f;-mavx512cd;-mavx512vl;-mavx512dq;-mavx512bw;-mfma;-mf16c;-mpopcnt"
527535
)
536+
# SPR-only sources additionally require AVX512_VPOPCNTDQ.
537+
# vpopcntq is the whole point of these files, so the extra flag
538+
# is mandatory; without it the intrinsics fail to compile.
539+
set_source_files_properties(${FAISS_SIMD_AVX512_SPR_SRC}
540+
TARGET_DIRECTORY faiss
541+
PROPERTIES COMPILE_OPTIONS
542+
"-mavx512f;-mavx512cd;-mavx512vl;-mavx512dq;-mavx512bw;-mavx512vpopcntdq;-mfma;-mf16c;-mpopcnt"
543+
)
528544
else()
529545
# Per-file SIMD flags (MSVC)
530546
add_compile_options(/bigobj)
@@ -536,6 +552,15 @@ if(FAISS_OPT_LEVEL STREQUAL "dd")
536552
TARGET_DIRECTORY faiss
537553
PROPERTIES COMPILE_OPTIONS "/arch:AVX512"
538554
)
555+
# MSVC has no per-feature flag for VPOPCNTDQ, and /arch:AVX512 itself
556+
# only covers the AVX-512F/CD/BW/DQ/VL subsets. MSVC does not gate
557+
# intrinsics on /arch, so the vpopcntq intrinsics still compile. (Newer MSVC supports __isa_available checks at
558+
# runtime; the SPR specialization is gated by COMPILE_SIMD_AVX512_SPR
559+
# at the source level, so this is safe.)
560+
set_source_files_properties(${FAISS_SIMD_AVX512_SPR_SRC}
561+
TARGET_DIRECTORY faiss
562+
PROPERTIES COMPILE_OPTIONS "/arch:AVX512"
563+
)
539564
endif()
540565
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64|arm64|ARM64)")
541566
# ARM NEON is always available on aarch64, no special compiler flags needed

faiss/impl/RaBitQUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ float compute_full_multibit_distance(
321321
size_t d,
322322
size_t ex_bits,
323323
MetricType metric_type) {
324-
return with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A0>(
324+
return with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A0_SPR>(
325325
[&]<SIMDLevel SL>() {
326326
return compute_full_multibit_distance<SL>(
327327
sign_bits,

faiss/impl/RaBitQuantizer.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,13 @@ FlatCodesDistanceComputer* RaBitQuantizer::get_distance_computer(
551551
// Dispatch on SIMDLevel once here so the distance computer methods
552552
// call the SIMD-specialized rabitq functions directly (no per-call
553553
// with_simd_level overhead).
554-
return with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A0>(
554+
//
555+
// Use A0_SPR (which includes AVX512_SPR) so that on Sapphire Rapids
556+
// and later x86 microarchitectures the VPOPCNTDQ-based RaBitQ
557+
// specialization in rabitq_avx512_spr.cpp is selected. On AVX-512
558+
// CPUs without VPOPCNTDQ, dispatch falls through to the AVX512
559+
// specialization in rabitq_avx512.cpp.
560+
return with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A0_SPR>(
555561
[&]<SIMDLevel SL>() -> FlatCodesDistanceComputer* {
556562
if (qb == 0) {
557563
auto dc =

faiss/impl/simd_dispatch.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ constexpr int AVAILABLE_SIMD_LEVELS_AVX2_NEON = AVAILABLE_SIMD_LEVELS_NONE |
3636
constexpr int AVAILABLE_SIMD_LEVELS_A0 = AVAILABLE_SIMD_LEVELS_AVX2_NEON |
3737
(1 << int(SIMDLevel::AVX512)) | (1 << int(SIMDLevel::RISCV_RVV));
3838

39+
// A0_SPR: same as A0 + AVX512_SPR (for functions with a dedicated SPR
40+
// specialization on top of an AVX512 fallback). Currently used by the
41+
// RaBitQ popcount kernels, which use VPOPCNTDQ on SPR+.
42+
constexpr int AVAILABLE_SIMD_LEVELS_A0_SPR =
43+
AVAILABLE_SIMD_LEVELS_A0 | (1 << int(SIMDLevel::AVX512_SPR));
44+
3945
// A1: same + ARM_SVE (for functions with dedicated SVE implementations)
4046
constexpr int AVAILABLE_SIMD_LEVELS_A1 =
4147
AVAILABLE_SIMD_LEVELS_A0 | (1 << int(SIMDLevel::ARM_SVE));

0 commit comments

Comments
 (0)