Skip to content

Commit 9b67422

Browse files
algoriddlefacebook-github-bot
authored andcommitted
Extract PQ4 kernels to includable headers in impl/pq_4bit/ (#4868)
Summary: Move kernel templates from .cpp anonymous namespaces into includable headers, parameterized on SIMDLevel SL. No behavior change — existing .cpp files include the headers and instantiate with defaults. New headers: - kernels_simd256.h: multi-BB kernel (from search_1.cpp) + single-BB QBS 256-bit kernel (from search_qbs.cpp non-AVX512 path) - kernels_simd512.h: AVX512 nq1/nqx kernels + dispatcher (from search_qbs.cpp) - decompose_qbs.h: unified kernel_accumulate_block<NQ, SL> that replaces #ifndef __AVX512F__ with if constexpr on SL, plus QBS decomposition logic Template param order: <int NQ, SIMDLevel SL, class ResultHandler, class Scaler> to enable ergonomic SL propagation via kernel_accumulate_block<Q1, SL>(...). ~900 lines moved (code motion), ~100 lines changed. Pure refactor. Reviewed By: mdouze, mnorris11 Differential Revision: D95392155
1 parent d28354e commit 9b67422

6 files changed

Lines changed: 864 additions & 801 deletions

File tree

faiss/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,9 @@ set(FAISS_HEADERS
263263
impl/lattice_Zn.h
264264
impl/platform_macros.h
265265
impl/pq4_fast_scan.h
266+
impl/pq_4bit/decompose_qbs.h
267+
impl/pq_4bit/kernels_simd256.h
268+
impl/pq_4bit/kernels_simd512.h
266269
impl/residual_quantizer_encode_steps.h
267270
impl/simd_dispatch.h
268271
impl/simd_result_handlers.h

faiss/impl/pq4_fast_scan_search_1.cpp

Lines changed: 1 addition & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <faiss/impl/FaissAssert.h>
1111
#include <faiss/impl/LookupTableScaler.h>
12+
#include <faiss/impl/pq_4bit/kernels_simd256.h>
1213
#include <faiss/impl/simd_result_handlers.h>
1314

1415
namespace faiss {
@@ -21,101 +22,6 @@ using namespace simd_result_handlers;
2122

2223
namespace {
2324

24-
/*
25-
* The computation kernel
26-
* It accumulates results for NQ queries and BB * 32 database elements
27-
* writes results in a ResultHandler
28-
*/
29-
30-
template <int NQ, int BB, class ResultHandler, class Scaler>
31-
void kernel_accumulate_block(
32-
int nsq,
33-
const uint8_t* codes,
34-
const uint8_t* LUT,
35-
ResultHandler& res,
36-
const Scaler& scaler) {
37-
// distance accumulators
38-
simd16uint16 accu[NQ][BB][4];
39-
40-
for (int q = 0; q < NQ; q++) {
41-
for (int b = 0; b < BB; b++) {
42-
accu[q][b][0].clear();
43-
accu[q][b][1].clear();
44-
accu[q][b][2].clear();
45-
accu[q][b][3].clear();
46-
}
47-
}
48-
49-
for (int sq = 0; sq < nsq - scaler.nscale; sq += 2) {
50-
simd32uint8 lut_cache[NQ];
51-
for (int q = 0; q < NQ; q++) {
52-
lut_cache[q] = simd32uint8(LUT);
53-
LUT += 32;
54-
}
55-
56-
for (int b = 0; b < BB; b++) {
57-
simd32uint8 c = simd32uint8(codes);
58-
codes += 32;
59-
simd32uint8 mask(15);
60-
simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
61-
simd32uint8 clo = c & mask;
62-
63-
for (int q = 0; q < NQ; q++) {
64-
simd32uint8 lut = lut_cache[q];
65-
simd32uint8 res0 = lut.lookup_2_lanes(clo);
66-
simd32uint8 res1 = lut.lookup_2_lanes(chi);
67-
68-
accu[q][b][0] += simd16uint16(res0);
69-
accu[q][b][1] += simd16uint16(res0) >> 8;
70-
71-
accu[q][b][2] += simd16uint16(res1);
72-
accu[q][b][3] += simd16uint16(res1) >> 8;
73-
}
74-
}
75-
}
76-
77-
for (int sq = 0; sq < scaler.nscale; sq += 2) {
78-
simd32uint8 lut_cache[NQ];
79-
for (int q = 0; q < NQ; q++) {
80-
lut_cache[q] = simd32uint8(LUT);
81-
LUT += 32;
82-
}
83-
84-
for (int b = 0; b < BB; b++) {
85-
simd32uint8 c = simd32uint8(codes);
86-
codes += 32;
87-
simd32uint8 mask(15);
88-
simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
89-
simd32uint8 clo = c & mask;
90-
91-
for (int q = 0; q < NQ; q++) {
92-
simd32uint8 lut = lut_cache[q];
93-
94-
simd32uint8 res0 = scaler.lookup(lut, clo);
95-
accu[q][b][0] += scaler.scale_lo(res0); // handle vectors 0..7
96-
accu[q][b][1] += scaler.scale_hi(res0); // handle vectors 8..15
97-
98-
simd32uint8 res1 = scaler.lookup(lut, chi);
99-
accu[q][b][2] += scaler.scale_lo(res1); // handle vectors 16..23
100-
accu[q][b][3] +=
101-
scaler.scale_hi(res1); // handle vectors 24..31
102-
}
103-
}
104-
}
105-
106-
for (int q = 0; q < NQ; q++) {
107-
for (int b = 0; b < BB; b++) {
108-
accu[q][b][0] -= accu[q][b][1] << 8;
109-
simd16uint16 dis0 = combine2x2(accu[q][b][0], accu[q][b][1]);
110-
111-
accu[q][b][2] -= accu[q][b][3] << 8;
112-
simd16uint16 dis1 = combine2x2(accu[q][b][2], accu[q][b][3]);
113-
114-
res.handle(q, b, dis0, dis1);
115-
}
116-
}
117-
}
118-
11925
template <int NQ, int BB, class ResultHandler, class Scaler>
12026
void accumulate_fixed_blocks(
12127
size_t nb,

0 commit comments

Comments
 (0)