Skip to content

Commit 9d56749

Browse files
algoriddle and meta-codesync[bot]
authored and committed
Wire 512-bit QBS kernels into fast scan DD dispatch (facebookresearch#5075)
Summary: Pull Request resolved: facebookresearch#5075 In DD mode, the QBS (bbs=32) accumulate path always used 256-bit kernels, even in the AVX512 per-ISA TU. The 512-bit kernels in kernels_simd512.h were dead because bare simdlib aliases resolve to _tpl<NONE> in DD mode, and 512-bit NONE types don't exist (empty primary templates). Fix: add function-local using declarations in both 512-bit kernel functions to bind types to explicit AVX512/AVX2 levels. Create accumulate_loops_512.h with FixedStorage512 (a non-virtual intermediate handler that bridges the AVX2→NONE type gap via storeu/loadu at the handler boundary) and the 512-bit QBS accumulate loop. Wire it into dispatching.h's ScannerMixIn behind an `if constexpr` check that THE_LEVEL_TO_DISPATCH is AVX512 or AVX512_SPR. Reviewed By: mdouze Differential Revision: D100151879 fbshipit-source-id: b801f897f2d061a8448842f42edcdeb3a447eafd
1 parent 9cbc8da commit 9d56749

5 files changed

Lines changed: 274 additions & 14 deletions

File tree

faiss/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ set(FAISS_HEADERS
277277
impl/lattice_Zn.h
278278
impl/platform_macros.h
279279
impl/fast_scan/accumulate_loops.h
280+
impl/fast_scan/accumulate_loops_512.h
280281
impl/fast_scan/dispatching.h
281282
impl/fast_scan/fast_scan.h
282283
impl/fast_scan/decompose_qbs.h

faiss/impl/fast_scan/accumulate_loops.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
* - accumulate_q_4step_256 / pq4_accumulate_loop_qbs_fixed_scaler_256
1818
* (QBS path, bbs == 32, 256-bit kernel only)
1919
*
20-
* The QBS helpers use pq4_kernel_qbs_256 exclusively (not decompose_qbs.h)
21-
* because decompose_qbs.h includes kernels_simd512.h which uses 512-bit
22-
* types that are empty primary templates when SINGLE_SIMD_LEVEL=NONE
23-
* (DD mode). SL-parameterizing the 512-bit kernels is future work.
20+
* The QBS helpers here use pq4_kernel_qbs_256 exclusively (not
21+
* decompose_qbs.h) because decompose_qbs.h includes kernels_simd512.h
22+
* whose 512-bit types need explicit SIMD levels. The 512-bit QBS path
23+
* lives in accumulate_loops_512.h, used by the AVX512 per-ISA TU.
2424
*
2525
* All functions live in `namespace faiss` (not anonymous) so they can be
2626
* shared by both the per-SIMD TU dispatcher (dispatching.h) and the old
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#pragma once
9+
10+
/**
11+
* @file accumulate_loops_512.h
12+
* @brief 512-bit QBS accumulation loop for AVX512 per-ISA TUs.
13+
*
14+
* Mirrors accumulate_loops.h's QBS path but uses pq4_kernel_qbs_512
15+
* (from kernels_simd512.h) instead of pq4_kernel_qbs_256.
16+
*
17+
* The 512-bit kernels produce simd16uint16_tpl<AVX2> results (via
18+
* combine4x2). The virtual SIMDResultHandler::handle() expects
19+
* simd16uint16_tpl<NONE> in DD mode. FixedStorage512 bridges this gap:
20+
* it stores AVX2-level results internally, then converts to the handler's
21+
* level via storeu/loadu in to_other_handler().
22+
*
23+
* Only included from the AVX512 per-ISA TU (impl-avx512.cpp) via
24+
* dispatching.h's conditional include.
25+
*/
26+
27+
#if defined(COMPILE_SIMD_AVX512) && defined(__AVX512F__)
28+
29+
#include <cassert>
30+
31+
#include <faiss/impl/FaissAssert.h>
32+
#include <faiss/impl/fast_scan/accumulate_loops.h>
33+
#include <faiss/impl/fast_scan/kernels_simd512.h>
34+
#include <faiss/impl/fast_scan/simd_result_handlers.h>
35+
36+
namespace faiss {
37+
38+
using namespace simd_result_handlers;
39+
40+
/***************************************************************
41+
* FixedStorage512: non-virtual intermediate result storage
42+
* for 512-bit kernels.
43+
*
44+
* Does NOT inherit from SIMDResultHandler — the virtual handle()
45+
* signature is pinned to simd16uint16_tpl<NONE> in DD mode, but
46+
* 512-bit kernels produce simd16uint16_tpl<AVX2>. By avoiding
47+
* inheritance, handle() can accept AVX2-level types directly.
48+
*
49+
* The conversion to the outer handler's type happens in
50+
* to_other_handler() via a store-to-memory roundtrip.
51+
***************************************************************/
52+
53+
template <int NQ, int BB>
54+
struct FixedStorage512 {
55+
using simd16uint16_avx2 = simd16uint16_tpl<SIMDLevel::AVX2>;
56+
57+
simd16uint16_avx2 dis[NQ][BB];
58+
int i0 = 0;
59+
60+
void handle(
61+
size_t q,
62+
size_t b,
63+
simd16uint16_avx2 d0,
64+
simd16uint16_avx2 d1) {
65+
dis[q + i0][2 * b] = d0;
66+
dis[q + i0][2 * b + 1] = d1;
67+
}
68+
69+
void set_block_origin(size_t i0_in, size_t) {
70+
this->i0 = i0_in;
71+
}
72+
73+
template <class OtherResultHandler>
74+
void to_other_handler(OtherResultHandler& other) const {
75+
using handler_simd16 = simd16uint16_tpl<SINGLE_SIMD_LEVEL_256>;
76+
for (int q = 0; q < NQ; q++) {
77+
for (int b = 0; b < BB; b += 2) {
78+
// Convert AVX2 → handler level (NONE in DD mode)
79+
ALIGNED(32) uint16_t buf0[16], buf1[16];
80+
dis[q][b].storeu(buf0);
81+
dis[q][b + 1].storeu(buf1);
82+
handler_simd16 h0, h1;
83+
h0.loadu(buf0);
84+
h1.loadu(buf1);
85+
other.handle(q, b / 2, h0, h1);
86+
}
87+
}
88+
}
89+
};
90+
91+
/***************************************************************
92+
* QBS path: 512-bit kernel variants
93+
***************************************************************/
94+
95+
template <int QBS, class ResultHandler, class Scaler>
96+
void accumulate_q_4step_512(
97+
size_t ntotal2,
98+
int nsq,
99+
const uint8_t* codes,
100+
const uint8_t* LUT0,
101+
ResultHandler& res,
102+
const Scaler& scaler,
103+
size_t block_stride) {
104+
constexpr int Q1 = QBS & 15;
105+
constexpr int Q2 = (QBS >> 4) & 15;
106+
constexpr int Q3 = (QBS >> 8) & 15;
107+
constexpr int Q4 = (QBS >> 12) & 15;
108+
constexpr int SQ = Q1 + Q2 + Q3 + Q4;
109+
110+
for_each_block<32>(ntotal2, codes, block_stride, res, [&](size_t) {
111+
FixedStorage512<SQ, 2> res2;
112+
const uint8_t* LUT = LUT0;
113+
pq4_kernel_qbs_512<Q1>(nsq, codes, LUT, res2, scaler);
114+
LUT += Q1 * nsq * 16;
115+
if (Q2 > 0) {
116+
res2.set_block_origin(Q1, 0);
117+
pq4_kernel_qbs_512<Q2>(nsq, codes, LUT, res2, scaler);
118+
LUT += Q2 * nsq * 16;
119+
}
120+
if (Q3 > 0) {
121+
res2.set_block_origin(Q1 + Q2, 0);
122+
pq4_kernel_qbs_512<Q3>(nsq, codes, LUT, res2, scaler);
123+
LUT += Q3 * nsq * 16;
124+
}
125+
if (Q4 > 0) {
126+
res2.set_block_origin(Q1 + Q2 + Q3, 0);
127+
pq4_kernel_qbs_512<Q4>(nsq, codes, LUT, res2, scaler);
128+
}
129+
res2.to_other_handler(res);
130+
});
131+
}
132+
133+
template <class ResultHandler, class Scaler>
134+
void pq4_accumulate_loop_qbs_fixed_scaler_512(
135+
int qbs,
136+
size_t ntotal2,
137+
int nsq,
138+
const uint8_t* codes,
139+
const uint8_t* LUT0,
140+
ResultHandler& res,
141+
const Scaler& scaler,
142+
size_t block_stride) {
143+
assert(nsq % 2 == 0);
144+
assert(is_aligned_pointer(codes));
145+
assert(is_aligned_pointer(LUT0));
146+
147+
switch (qbs) {
148+
#define FAISS_QBS512_DISPATCH(QBS) \
149+
case QBS: \
150+
accumulate_q_4step_512<QBS>( \
151+
ntotal2, nsq, codes, LUT0, res, scaler, block_stride); \
152+
return;
153+
FAISS_QBS512_DISPATCH(0x3333); // 12
154+
FAISS_QBS512_DISPATCH(0x2333); // 11
155+
FAISS_QBS512_DISPATCH(0x2233); // 10
156+
FAISS_QBS512_DISPATCH(0x333); // 9
157+
FAISS_QBS512_DISPATCH(0x2223); // 9
158+
FAISS_QBS512_DISPATCH(0x233); // 8
159+
FAISS_QBS512_DISPATCH(0x1223); // 8
160+
FAISS_QBS512_DISPATCH(0x223); // 7
161+
FAISS_QBS512_DISPATCH(0x34); // 7
162+
FAISS_QBS512_DISPATCH(0x133); // 7
163+
FAISS_QBS512_DISPATCH(0x6); // 6
164+
FAISS_QBS512_DISPATCH(0x33); // 6
165+
FAISS_QBS512_DISPATCH(0x123); // 6
166+
FAISS_QBS512_DISPATCH(0x222); // 6
167+
FAISS_QBS512_DISPATCH(0x23); // 5
168+
FAISS_QBS512_DISPATCH(0x5); // 5
169+
FAISS_QBS512_DISPATCH(0x13); // 4
170+
FAISS_QBS512_DISPATCH(0x22); // 4
171+
FAISS_QBS512_DISPATCH(0x4); // 4
172+
FAISS_QBS512_DISPATCH(0x3); // 3
173+
FAISS_QBS512_DISPATCH(0x21); // 3
174+
FAISS_QBS512_DISPATCH(0x2); // 2
175+
FAISS_QBS512_DISPATCH(0x1); // 1
176+
#undef FAISS_QBS512_DISPATCH
177+
}
178+
179+
// Fallback for unknown QBS values: use 256-bit path with NONE-level
180+
// scalers for type compatibility. This is rare — pq4_preferred_qbs()
181+
// covers all values above.
182+
if constexpr (Scaler::nscale == 0) {
183+
DummyScaler<> scaler_none;
184+
pq4_accumulate_loop_qbs_fixed_scaler_256(
185+
qbs, ntotal2, nsq, codes, LUT0, res, scaler_none, block_stride);
186+
} else {
187+
NormTableScaler<> scaler_none(scaler.scale_int);
188+
pq4_accumulate_loop_qbs_fixed_scaler_256(
189+
qbs, ntotal2, nsq, codes, LUT0, res, scaler_none, block_stride);
190+
}
191+
}
192+
193+
} // namespace faiss
194+
195+
#endif // COMPILE_SIMD_AVX512 && __AVX512F__

faiss/impl/fast_scan/dispatching.h

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@
2121
* #include <faiss/impl/fast_scan/dispatching.h>
2222
*
2323
* Kernel helpers come from accumulate_loops.h (search_1 multi-BB path
24-
* and QBS 256-bit path). The QBS helpers use pq4_kernel_qbs_256 only
25-
* because decompose_qbs.h pulls in 512-bit types that fail with
26-
* SINGLE_SIMD_LEVEL=NONE in DD mode.
24+
* and QBS 256-bit path) and accumulate_loops_512.h (QBS 512-bit path,
25+
* AVX512 TU only).
2726
*/
2827

2928
#ifndef THE_LEVEL_TO_DISPATCH
@@ -35,6 +34,10 @@
3534
#include <faiss/impl/fast_scan/accumulate_loops.h>
3635
#include <faiss/impl/fast_scan/fast_scan.h>
3736

37+
#if defined(COMPILE_SIMD_AVX512) && defined(__AVX512F__)
38+
#include <faiss/impl/fast_scan/accumulate_loops_512.h>
39+
#endif
40+
3841
namespace faiss {
3942

4043
using namespace simd_result_handlers;
@@ -101,14 +104,62 @@ struct ScannerMixIn : FastScanCodeScanner {
101104
const uint8_t* LUT,
102105
int pq2x4_scale,
103106
size_t block_stride) override {
104-
if (pq2x4_scale) {
105-
NormTableScaler<> scaler(pq2x4_scale);
106-
pq4_accumulate_loop_qbs_fixed_scaler_256(
107-
qbs, nb, nsq, codes, LUT, handler_, scaler, block_stride);
107+
#if defined(COMPILE_SIMD_AVX512) && defined(__AVX512F__)
108+
constexpr bool use_avx512_qbs =
109+
(THE_LEVEL_TO_DISPATCH == SIMDLevel::AVX512 ||
110+
THE_LEVEL_TO_DISPATCH == SIMDLevel::AVX512_SPR);
111+
#else
112+
constexpr bool use_avx512_qbs = false;
113+
#endif
114+
if constexpr (use_avx512_qbs) {
115+
// Use 512-bit QBS kernels with properly-leveled scalers.
116+
if (pq2x4_scale) {
117+
NormTableScaler<THE_LEVEL_TO_DISPATCH> scaler(pq2x4_scale);
118+
pq4_accumulate_loop_qbs_fixed_scaler_512(
119+
qbs,
120+
nb,
121+
nsq,
122+
codes,
123+
LUT,
124+
handler_,
125+
scaler,
126+
block_stride);
127+
} else {
128+
DummyScaler<THE_LEVEL_TO_DISPATCH> dummy;
129+
pq4_accumulate_loop_qbs_fixed_scaler_512(
130+
qbs,
131+
nb,
132+
nsq,
133+
codes,
134+
LUT,
135+
handler_,
136+
dummy,
137+
block_stride);
138+
}
108139
} else {
109-
DummyScaler<> dummy;
110-
pq4_accumulate_loop_qbs_fixed_scaler_256(
111-
qbs, nb, nsq, codes, LUT, handler_, dummy, block_stride);
140+
if (pq2x4_scale) {
141+
NormTableScaler<> scaler(pq2x4_scale);
142+
pq4_accumulate_loop_qbs_fixed_scaler_256(
143+
qbs,
144+
nb,
145+
nsq,
146+
codes,
147+
LUT,
148+
handler_,
149+
scaler,
150+
block_stride);
151+
} else {
152+
DummyScaler<> dummy;
153+
pq4_accumulate_loop_qbs_fixed_scaler_256(
154+
qbs,
155+
nb,
156+
nsq,
157+
codes,
158+
LUT,
159+
handler_,
160+
dummy,
161+
block_stride);
162+
}
112163
}
113164
}
114165
};

faiss/impl/fast_scan/kernels_simd512.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ void kernel_accumulate_block_avx512_nq1(
3030
const uint8_t* LUT,
3131
ResultHandler& res,
3232
const Scaler& scaler) {
33+
// Explicit SIMD levels for DD mode where bare aliases resolve to NONE
34+
// (512-bit NONE types don't exist — empty primary templates).
35+
using simd32uint16 = simd32uint16_tpl<SIMDLevel::AVX512>;
36+
using simd64uint8 = simd64uint8_tpl<SIMDLevel::AVX512>;
37+
using simd16uint16 = simd16uint16_tpl<SIMDLevel::AVX2>;
38+
using simd32uint8 = simd32uint8_tpl<SIMDLevel::AVX2>;
39+
3340
// NQ is kept in order to match the similarity to baseline function
3441
constexpr int NQ = 1;
3542
// distance accumulators. We can accept more for NQ=1
@@ -291,6 +298,12 @@ void kernel_accumulate_block_avx512_nqx(
291298
const uint8_t* LUT,
292299
ResultHandler& res,
293300
const Scaler& scaler) {
301+
// Explicit SIMD levels for DD mode (see nq1 variant for explanation).
302+
using simd32uint16 = simd32uint16_tpl<SIMDLevel::AVX512>;
303+
using simd64uint8 = simd64uint8_tpl<SIMDLevel::AVX512>;
304+
using simd16uint16 = simd16uint16_tpl<SIMDLevel::AVX2>;
305+
using simd32uint8 = simd32uint8_tpl<SIMDLevel::AVX2>;
306+
294307
// dummy alloc to keep the windows compiler happy
295308
constexpr int NQA = NQ > 0 ? NQ : 1;
296309
// distance accumulators

0 commit comments

Comments
 (0)