Skip to content

Commit 45a7921

Browse files
committed
Speed up polysemous training with AVX-512
Add AVX-512 implementations of the compute_cost and cost_update hot loops for both ReproduceWithHammingObjective and ReproduceDistancesObjective. The vectorized paths use 512-bit packed double FMA, masked blends for branchless swap handling, and a portable popcnt_512 helper that uses _mm512_popcnt_epi64 when AVX512VPOPCNTDQ is available or falls back to a nibble-lookup approach. Dispatch is guarded by COMPILE_SIMD_AVX512 and the SIMD dynamic dispatch level, falling back to the existing scalar code with zero overhead on non-AVX-512 systems. Benchmarks of the training phase on SIFT1M (bench_polysemous_sift1m.py) show ~1.09x speedup over the scalar path on Sapphire Rapids. Signed-off-by: Mulugeta Mammo <[email protected]>
1 parent ca87f41 commit 45a7921

5 files changed

Lines changed: 381 additions & 11 deletions

File tree

benchs/bench_polysemous_sift1m.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,39 @@
66

77
from __future__ import print_function
88

9+
import time
10+
import numpy as np
911
import faiss
1012
from datasets import load_sift1M, evaluate
1113

14+
NUM_TRAIN_RUNS = 5
1215

1316
print("load data")
1417
xb, xq, xt, gt = load_sift1M()
1518
nq, d = xq.shape
1619

17-
# index with 16 subquantizers, 8 bit each
18-
index = faiss.IndexPQ(d, 16, 8)
19-
index.do_polysemous_training = True
20-
index.verbose = True
20+
train_times = []
21+
for run in range(NUM_TRAIN_RUNS):
22+
index = faiss.IndexPQ(d, 16, 8)
23+
index.do_polysemous_training = True
24+
index.verbose = (run == 0)
2125

22-
print("train")
26+
print("train run %d/%d" % (run + 1, NUM_TRAIN_RUNS))
2327

24-
index.train(xt)
28+
t0 = time.time()
29+
index.train(xt)
30+
t1 = time.time()
31+
elapsed = t1 - t0
32+
train_times.append(elapsed)
33+
print(" Training time: %.2f s" % elapsed)
2534

26-
print("add vectors to index")
35+
times = np.array(train_times)
36+
print("\nTraining time over %d runs: "
37+
"median %.2f s, mean %.2f s, std %.2f s, min %.2f s, max %.2f s"
38+
% (NUM_TRAIN_RUNS, np.median(times), np.mean(times),
39+
np.std(times), np.min(times), np.max(times)))
40+
41+
print("\nadd vectors to index")
2742

2843
index.add(xb)
2944

@@ -42,3 +57,4 @@
4257
index.polysemous_ht = ht
4358
t, r = evaluate(index, xq, gt, 1)
4459
print("\t %7.3f ms per query, R@1 %.4f" % (t, r[1]))
60+

faiss/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ set(FAISS_SIMD_AVX512_SRC
2626
impl/fast_scan/impl-avx512.cpp
2727
impl/hnsw/avx512.cpp
2828
impl/pq_code_distance/avx512.cpp
29+
impl/polysemous_training/avx512.cpp
2930
impl/scalar_quantizer/sq-avx512.cpp
3031
impl/binary_hamming/avx512.cpp
3132
utils/simd_impl/distances_avx512.cpp
@@ -282,6 +283,7 @@ set(FAISS_HEADERS
282283
impl/PanoramaStats.h
283284
impl/PdxLayout.h
284285
impl/PolysemousTraining.h
286+
impl/polysemous_training/avx512.h
285287
impl/ProductQuantizer-inl.h
286288
impl/ProductQuantizer.h
287289
impl/Quantizer.h

faiss/impl/PolysemousTraining.cpp

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <cstring>
1919
#include <memory>
2020

21+
#include <faiss/impl/polysemous_training/avx512.h>
2122
#include <faiss/impl/simd_dispatch.h>
2223
#include <faiss/utils/distances.h>
2324
#include <faiss/utils/hamming.h>
@@ -168,17 +169,17 @@ static inline int hamming_dis(uint64_t a, uint64_t b) {
168169
return popcount64(a ^ b);
169170
}
170171

172+
static inline double sqr(double x) {
173+
return x * x;
174+
}
175+
171176
namespace {
172177

173178
/// optimize permutation to reproduce a distance table with Hamming distances
174179
struct ReproduceWithHammingObjective : PermutationObjective {
175180
int nbits;
176181
double dis_weight_factor;
177182

178-
static double sqr(double x) {
179-
return x * x;
180-
}
181-
182183
// weighting of distances: it is more important to reproduce small
183184
// distances well
184185
double dis_weight(double x) const {
@@ -190,6 +191,13 @@ struct ReproduceWithHammingObjective : PermutationObjective {
190191

191192
// cost = quadratic difference between actual distance and Hamming distance
192193
double compute_cost(const int* perm) const override {
194+
#ifdef COMPILE_SIMD_AVX512
195+
if (SIMDConfig::level == SIMDLevel::AVX512 ||
196+
SIMDConfig::level == SIMDLevel::AVX512_SPR) {
197+
return polysemous_avx512::hamming_compute_cost_avx512(
198+
n, perm, target_dis.data(), weights.data());
199+
}
200+
#endif
193201
double cost = 0;
194202
for (int i = 0; i < n; i++) {
195203
for (int j = 0; j < n; j++) {
@@ -205,6 +213,13 @@ struct ReproduceWithHammingObjective : PermutationObjective {
205213
// what would the cost update be if iw and jw were swapped?
206214
// computed in O(n) instead of O(n^2) for the full re-computation
207215
double cost_update(const int* perm, int iw, int jw) const override {
216+
#ifdef COMPILE_SIMD_AVX512
217+
if (SIMDConfig::level == SIMDLevel::AVX512 ||
218+
SIMDConfig::level == SIMDLevel::AVX512_SPR) {
219+
return polysemous_avx512::hamming_cost_update_avx512(
220+
n, perm, iw, jw, target_dis.data(), weights.data());
221+
}
222+
#endif
208223
double delta_cost = 0;
209224

210225
for (int i = 0; i < n; i++) {
@@ -308,6 +323,12 @@ double ReproduceDistancesObjective::get_source_dis(int i, int j) const {
308323

309324
// cost = quadratic difference between actual distance and Hamming distance
310325
double ReproduceDistancesObjective::compute_cost(const int* perm) const {
326+
#ifdef COMPILE_SIMD_AVX512
327+
if (SIMDConfig::level == SIMDLevel::AVX512 ||
328+
SIMDConfig::level == SIMDLevel::AVX512_SPR) {
329+
return polysemous_avx512::distances_compute_cost_avx512(*this, perm);
330+
}
331+
#endif
311332
double cost = 0;
312333
for (int i = 0; i < n; i++) {
313334
for (int j = 0; j < n; j++) {
@@ -324,6 +345,13 @@ double ReproduceDistancesObjective::compute_cost(const int* perm) const {
324345
// computed in O(n) instead of O(n^2) for the full re-computation
325346
double ReproduceDistancesObjective::cost_update(const int* perm, int iw, int jw)
326347
const {
348+
#ifdef COMPILE_SIMD_AVX512
349+
if (SIMDConfig::level == SIMDLevel::AVX512 ||
350+
SIMDConfig::level == SIMDLevel::AVX512_SPR) {
351+
return polysemous_avx512::distances_cost_update_avx512(
352+
*this, perm, iw, jw);
353+
}
354+
#endif
327355
double delta_cost = 0;
328356
for (int i = 0; i < n; i++) {
329357
if (i == iw) {

0 commit comments

Comments
 (0)