Skip to content

Commit 734751a

Browse files
algoriddlemeta-codesync[bot]
authored andcommitted
Extract PQ4 kernels to includable headers in impl/pq_4bit/ (#4868)
Summary: Pull Request resolved: #4868 Move kernel templates from .cpp anonymous namespaces into includable headers, parameterized on SIMDLevel SL. No behavior change — existing .cpp files include the headers and instantiate with defaults. New headers: - kernels_simd256.h: multi-BB kernel (from search_1.cpp) + single-BB QBS 256-bit kernel (from search_qbs.cpp non-AVX512 path) - kernels_simd512.h: AVX512 nq1/nqx kernels + dispatcher (from search_qbs.cpp) - decompose_qbs.h: unified kernel_accumulate_block<NQ, SL> that replaces #ifndef __AVX512F__ with if constexpr on SL, plus QBS decomposition logic Template param order: <int NQ, SIMDLevel SL, class ResultHandler, class Scaler> to enable ergonomic SL propagation via kernel_accumulate_block<Q1, SL>(...). ~900 lines moved (code motion), ~100 lines changed. Pure refactor. Reviewed By: mdouze, mnorris11 Differential Revision: D95392155 fbshipit-source-id: 4020ff66847152aada7271629b05f636f1bc3dc3
1 parent 1e4544a commit 734751a

6 files changed

Lines changed: 864 additions & 801 deletions

File tree

faiss/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,9 @@ set(FAISS_HEADERS
263263
impl/lattice_Zn.h
264264
impl/platform_macros.h
265265
impl/fast_scan/pq4_fast_scan.h
266+
impl/fast_scan/decompose_qbs.h
267+
impl/fast_scan/kernels_simd256.h
268+
impl/fast_scan/kernels_simd512.h
266269
impl/residual_quantizer_encode_steps.h
267270
impl/simd_dispatch.h
268271
impl/fast_scan/simd_result_handlers.h
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#pragma once
9+
10+
#include <cassert>
11+
12+
#include <faiss/impl/FaissAssert.h>
13+
#include <faiss/impl/fast_scan/kernels_simd256.h>
14+
#include <faiss/impl/fast_scan/kernels_simd512.h>
15+
#include <faiss/impl/fast_scan/simd_result_handlers.h>
16+
17+
namespace faiss {
18+
19+
using namespace simd_result_handlers;
20+
21+
/*
22+
* Unified kernel: selects 256-bit vs 512-bit path based on
23+
* compile-time __AVX512F__ guard.
24+
*/
25+
template <int NQ, class ResultHandler, class Scaler>
26+
void kernel_accumulate_block(
27+
int nsq,
28+
const uint8_t* codes,
29+
const uint8_t* LUT,
30+
ResultHandler& res,
31+
const Scaler& scaler) {
32+
#ifdef __AVX512F__
33+
pq4_kernel_qbs_512<NQ>(nsq, codes, LUT, res, scaler);
34+
#else
35+
pq4_kernel_qbs_256<NQ>(nsq, codes, LUT, res, scaler);
36+
#endif
37+
}
38+
39+
// handle at most 4 blocks of queries
40+
template <int QBS, class ResultHandler, class Scaler>
41+
void accumulate_q_4step(
42+
size_t ntotal2,
43+
int nsq,
44+
const uint8_t* codes,
45+
const uint8_t* LUT0,
46+
ResultHandler& res,
47+
const Scaler& scaler,
48+
size_t block_stride) {
49+
constexpr int Q1 = QBS & 15;
50+
constexpr int Q2 = (QBS >> 4) & 15;
51+
constexpr int Q3 = (QBS >> 8) & 15;
52+
constexpr int Q4 = (QBS >> 12) & 15;
53+
constexpr int SQ = Q1 + Q2 + Q3 + Q4;
54+
55+
for (size_t j0 = 0; j0 < ntotal2; j0 += 32) {
56+
FixedStorageHandler<SQ, 2> res2;
57+
const uint8_t* LUT = LUT0;
58+
kernel_accumulate_block<Q1>(nsq, codes, LUT, res2, scaler);
59+
LUT += Q1 * nsq * 16;
60+
if (Q2 > 0) {
61+
res2.set_block_origin(Q1, 0);
62+
kernel_accumulate_block<Q2>(nsq, codes, LUT, res2, scaler);
63+
LUT += Q2 * nsq * 16;
64+
}
65+
if (Q3 > 0) {
66+
res2.set_block_origin(Q1 + Q2, 0);
67+
kernel_accumulate_block<Q3>(nsq, codes, LUT, res2, scaler);
68+
LUT += Q3 * nsq * 16;
69+
}
70+
if (Q4 > 0) {
71+
res2.set_block_origin(Q1 + Q2 + Q3, 0);
72+
kernel_accumulate_block<Q4>(nsq, codes, LUT, res2, scaler);
73+
}
74+
res.set_block_origin(0, j0);
75+
res2.to_other_handler(res);
76+
codes += block_stride;
77+
}
78+
}
79+
80+
template <int NQ, class ResultHandler, class Scaler>
81+
void kernel_accumulate_block_loop(
82+
size_t ntotal2,
83+
int nsq,
84+
const uint8_t* codes,
85+
const uint8_t* LUT,
86+
ResultHandler& res,
87+
const Scaler& scaler,
88+
size_t block_stride) {
89+
for (size_t j0 = 0; j0 < ntotal2; j0 += 32) {
90+
res.set_block_origin(0, j0);
91+
kernel_accumulate_block<NQ, ResultHandler>(
92+
nsq, codes, LUT, res, scaler);
93+
codes += block_stride;
94+
}
95+
}
96+
97+
// non-template version of accumulate kernel -- dispatches dynamically
98+
template <class ResultHandler, class Scaler>
99+
void accumulate(
100+
int nq,
101+
size_t ntotal2,
102+
int nsq,
103+
const uint8_t* codes,
104+
const uint8_t* LUT,
105+
ResultHandler& res,
106+
const Scaler& scaler,
107+
size_t block_stride) {
108+
assert(nsq % 2 == 0);
109+
assert(is_aligned_pointer(LUT));
110+
111+
#define DISPATCH(NQ) \
112+
case NQ: \
113+
kernel_accumulate_block_loop<NQ, ResultHandler>( \
114+
ntotal2, nsq, codes, LUT, res, scaler, block_stride); \
115+
return
116+
117+
switch (nq) {
118+
DISPATCH(1);
119+
DISPATCH(2);
120+
DISPATCH(3);
121+
DISPATCH(4);
122+
}
123+
FAISS_THROW_FMT("accumulate nq=%d not instantiated", nq);
124+
125+
#undef DISPATCH
126+
}
127+
128+
template <class ResultHandler, class Scaler>
129+
void pq4_accumulate_loop_qbs_fixed_scaler(
130+
int qbs,
131+
size_t ntotal2,
132+
int nsq,
133+
const uint8_t* codes,
134+
const uint8_t* LUT0,
135+
ResultHandler& res,
136+
const Scaler& scaler,
137+
size_t block_stride = 0) {
138+
assert(nsq % 2 == 0);
139+
assert(is_aligned_pointer(codes));
140+
assert(is_aligned_pointer(LUT0));
141+
142+
// try out optimized versions
143+
switch (qbs) {
144+
#define DISPATCH(QBS) \
145+
case QBS: \
146+
accumulate_q_4step<QBS>( \
147+
ntotal2, nsq, codes, LUT0, res, scaler, block_stride); \
148+
return;
149+
DISPATCH(0x3333); // 12
150+
DISPATCH(0x2333); // 11
151+
DISPATCH(0x2233); // 10
152+
DISPATCH(0x333); // 9
153+
DISPATCH(0x2223); // 9
154+
DISPATCH(0x233); // 8
155+
DISPATCH(0x1223); // 8
156+
DISPATCH(0x223); // 7
157+
DISPATCH(0x34); // 7
158+
DISPATCH(0x133); // 7
159+
DISPATCH(0x6); // 6
160+
DISPATCH(0x33); // 6
161+
DISPATCH(0x123); // 6
162+
DISPATCH(0x222); // 6
163+
DISPATCH(0x23); // 5
164+
DISPATCH(0x5); // 5
165+
DISPATCH(0x13); // 4
166+
DISPATCH(0x22); // 4
167+
DISPATCH(0x4); // 4
168+
DISPATCH(0x3); // 3
169+
DISPATCH(0x21); // 3
170+
DISPATCH(0x2); // 2
171+
DISPATCH(0x1); // 1
172+
#undef DISPATCH
173+
}
174+
175+
// default implementation where qbs is not known at compile time
176+
for (size_t j0 = 0; j0 < ntotal2; j0 += 32) {
177+
const uint8_t* LUT = LUT0;
178+
int qi = qbs;
179+
int i0 = 0;
180+
while (qi) {
181+
int nq = qi & 15;
182+
qi >>= 4;
183+
res.set_block_origin(i0, j0);
184+
#define DISPATCH(NQ) \
185+
case NQ: \
186+
kernel_accumulate_block<NQ, ResultHandler>( \
187+
nsq, codes, LUT, res, scaler); \
188+
break
189+
switch (nq) {
190+
DISPATCH(1);
191+
DISPATCH(2);
192+
DISPATCH(3);
193+
DISPATCH(4);
194+
#undef DISPATCH
195+
default:
196+
FAISS_THROW_FMT("accumulate nq=%d not instantiated", nq);
197+
}
198+
i0 += nq;
199+
LUT += nq * nsq * 16;
200+
}
201+
codes += block_stride;
202+
}
203+
}
204+
205+
} // namespace faiss

0 commit comments

Comments
 (0)