Skip to content

Commit d28354e

Browse files
algoriddle authored and facebook-github-bot committed
Add defaulted SIMDLevel template parameter to handler and scaler types (#4867)
Summary: Templatize the result handler hierarchy and scaler types on SIMDLevel SL, defaulted to SINGLE_SIMD_LEVEL_256. This allows per-SIMD TUs to instantiate handlers and scalers with explicit SIMD levels (e.g., AVX2) for native dispatch. Result handlers: ResultHandlerCompare, SingleResultHandler, HeapHandler, ReservoirHandler, RangeHandler, PartialRangeHandler — all gain an SL parameter. Scalers: DummyScaler is templatized on SL. 512-bit methods use SL directly (removing the #ifdef __AVX512F__ guard — safe because template bodies are only instantiated when called). NormTableScaler stays non-template (public API). FixedStorageHandler: add SL parameter, remove SIMDResultHandler base class (never used polymorphically), remove final/virtual. Pure refactor. All existing callers use defaults and compile unchanged. Differential Revision: D95392149
1 parent 8d8268c commit d28354e

14 files changed

Lines changed: 142 additions & 152 deletions

faiss/IndexAdditiveQuantizerFastScan.cpp

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313
#include <faiss/impl/FaissAssert.h>
1414
#include <faiss/impl/FastScanDistancePostProcessing.h>
1515
#include <faiss/impl/LocalSearchQuantizer.h>
16-
#include <faiss/impl/LookupTableScaler.h>
1716
#include <faiss/impl/ResidualQuantizer.h>
1817
#include <faiss/impl/pq4_fast_scan.h>
1918
#include <faiss/utils/quantize_lut.h>
@@ -202,9 +201,8 @@ void IndexAdditiveQuantizerFastScan::search(
202201
return;
203202
}
204203

205-
NormTableScaler scaler(norm_scale);
206204
FastScanDistancePostProcessing context;
207-
context.norm_scaler = &scaler;
205+
context.pq2x4_scale = norm_scale;
208206
if (metric_type == METRIC_L2) {
209207
search_dispatch_implem<true>(n, x, k, distances, labels, context);
210208
} else {

faiss/IndexFastScan.cpp

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@
1515
#include <faiss/impl/FaissAssert.h>
1616
#include <faiss/impl/FastScanDistancePostProcessing.h>
1717
#include <faiss/impl/IDSelector.h>
18-
#include <faiss/impl/LookupTableScaler.h>
1918
#include <faiss/impl/RaBitQUtils.h>
2019
#include <faiss/impl/pq4_fast_scan.h>
2120
#include <faiss/impl/simd_result_handlers.h>
@@ -187,18 +186,18 @@ void estimators_from_tables_generic(
187186
BitstringReader bsr(codes + j * index.code_size, index.code_size);
188187
accu_t dis = 0;
189188
const dis_t* dt = dis_table;
190-
int nscale = context.norm_scaler ? context.norm_scaler->nscale : 0;
189+
int nscale = context.pq2x4_scale ? 2 : 0;
191190

192191
for (size_t m = 0; m < index.M - nscale; m++) {
193192
uint64_t c = bsr.read(index.nbits);
194193
dis += dt[c];
195194
dt += index.ksub;
196195
}
197196

198-
if (nscale && context.norm_scaler) {
197+
if (nscale) {
199198
for (size_t m = 0; m < nscale; m++) {
200199
uint64_t c = bsr.read(index.nbits);
201-
dis += context.norm_scaler->scale_one(dt[c]);
200+
dis += dt[c] * context.pq2x4_scale;
202201
dt += index.ksub;
203202
}
204203
}
@@ -545,7 +544,7 @@ void IndexFastScan::search_implem_12(
545544
codes.get(),
546545
LUT.get(),
547546
*handler.get(),
548-
context.norm_scaler,
547+
context.pq2x4_scale,
549548
get_block_stride());
550549
}
551550
if (!(skip & 8)) {
@@ -629,7 +628,7 @@ void IndexFastScan::search_implem_14(
629628
codes.get(),
630629
LUT.get(),
631630
*handler.get(),
632-
context.norm_scaler,
631+
context.pq2x4_scale,
633632
get_block_stride());
634633
}
635634
if (!(skip & 8)) {

faiss/IndexFastScan.h

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414
namespace faiss {
1515

1616
struct CodePacker;
17-
struct NormTableScaler;
1817
struct IDSelector;
1918
struct SIMDResultHandlerToFloat;
2019

faiss/IndexIVFAdditiveQuantizerFastScan.cpp

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@
1515
#include <faiss/impl/AuxIndexStructures.h>
1616
#include <faiss/impl/FaissAssert.h>
1717
#include <faiss/impl/FastScanDistancePostProcessing.h>
18-
#include <faiss/impl/LookupTableScaler.h>
1918
#include <faiss/impl/pq4_fast_scan.h>
2019
#include <faiss/impl/simd_dispatch.h>
2120
#include <faiss/invlists/BlockInvertedLists.h>
@@ -317,9 +316,8 @@ void IndexIVFAdditiveQuantizerFastScan::search(
317316
return;
318317
}
319318

320-
NormTableScaler scaler(norm_scale);
321319
FastScanDistancePostProcessing context;
322-
context.norm_scaler = &scaler;
320+
context.pq2x4_scale = norm_scale;
323321
IndexIVFFastScan::CoarseQuantized cq{nprobe};
324322
search_dispatch_implem(n, x, k, distances, labels, cq, context);
325323
}

faiss/IndexIVFFastScan.cpp

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@
1818
#include <faiss/impl/AuxIndexStructures.h>
1919
#include <faiss/impl/FaissAssert.h>
2020
#include <faiss/impl/FastScanDistancePostProcessing.h>
21-
#include <faiss/impl/LookupTableScaler.h>
2221
#include <faiss/impl/RaBitQUtils.h>
2322
#include <faiss/impl/pq4_fast_scan.h>
2423
#include <faiss/impl/simd_result_handlers.h>
@@ -239,7 +238,7 @@ void estimators_from_tables_generic(
239238
int64_t* heap_ids,
240239
const FastScanDistancePostProcessing& context) {
241240
using accu_t = typename C::T;
242-
size_t nscale = context.norm_scaler ? context.norm_scaler->nscale : 0;
241+
size_t nscale = context.pq2x4_scale ? 2 : 0;
243242
for (size_t j = 0; j < ncodes; ++j) {
244243
BitstringReader bsr(codes + j * index.code_size, index.code_size);
245244
accu_t dis = bias;
@@ -251,10 +250,10 @@ void estimators_from_tables_generic(
251250
dt += index.ksub;
252251
}
253252

254-
if (context.norm_scaler) {
253+
if (nscale) {
255254
for (size_t m = 0; m < nscale; m++) {
256255
uint64_t c = bsr.read(index.nbits);
257-
dis += context.norm_scaler->scale_one(dt[c]);
256+
dis += dt[c] * context.pq2x4_scale;
258257
dt += index.ksub;
259258
}
260259
}
@@ -1031,7 +1030,7 @@ void IndexIVFFastScan::search_implem_10(
10311030
codes.get(),
10321031
LUT,
10331032
handler,
1034-
context.norm_scaler,
1033+
context.pq2x4_scale,
10351034
get_block_stride());
10361035

10371036
ndis += ls;
@@ -1183,7 +1182,7 @@ void IndexIVFFastScan::search_implem_12(
11831182
codes.get(),
11841183
LUT.get(),
11851184
handler,
1186-
context.norm_scaler,
1185+
context.pq2x4_scale,
11871186
get_block_stride());
11881187
// prepare for next loop
11891188
i0 = i1;
@@ -1407,7 +1406,7 @@ void IndexIVFFastScan::search_implem_14(
14071406
codes.get(),
14081407
LUT.get(),
14091408
*handler.get(),
1410-
context.norm_scaler,
1409+
context.pq2x4_scale,
14111410
get_block_stride());
14121411
}
14131412

faiss/IndexIVFFastScan.h

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313

1414
namespace faiss {
1515

16-
struct NormTableScaler;
1716
struct SIMDResultHandlerToFloat;
1817
struct Quantizer;
1918

faiss/IndexIVFPQFastScan.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -419,7 +419,7 @@ struct IVFPQFastScanScanner : InvertedListScanner {
419419
codes,
420420
LUT,
421421
*handler,
422-
nullptr,
422+
0,
423423
index.get_block_stride());
424424

425425
// The handler is for the results of this iteration.

faiss/IndexIVFRaBitQFastScan.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -938,7 +938,7 @@ struct IVFRaBitQFastScanScanner : InvertedListScanner {
938938
codes,
939939
LUT,
940940
*handler,
941-
nullptr,
941+
0,
942942
index.get_block_stride());
943943

944944
// Combine results across iterations

faiss/impl/FastScanDistancePostProcessing.h

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -11,9 +11,6 @@
1111

1212
namespace faiss {
1313

14-
// Forward declarations
15-
struct NormTableScaler;
16-
1714
namespace rabitq_utils {
1815
struct QueryFactorsData;
1916
}
@@ -22,8 +19,10 @@ struct QueryFactorsData;
2219
* Simple context object that holds processors for FastScan operations.
2320
* */
2421
struct FastScanDistancePostProcessing {
25-
/// Norm scaling processor for Additive Quantizers (nullptr if not needed)
26-
const NormTableScaler* norm_scaler = nullptr;
22+
/// Norm scaling processor for Additive Quantizers.
23+
/// The scale is encoded in a 2x4 bit PQ table, then scaled by this int.
24+
/// Set to 0 if unused.
25+
int pq2x4_scale = 0;
2726

2827
/// Query factors data pointer for RaBitQ (nullptr if not needed)
2928
/// This pointer should point to the beginning of the relevant
@@ -41,7 +40,7 @@ struct FastScanDistancePostProcessing {
4140

4241
/// Check if norm scaling is enabled
4342
bool has_norm_scaling() const {
44-
return norm_scaler != nullptr;
43+
return pq2x4_scale != 0;
4544
}
4645

4746
/// Check if query factors processing is enabled

faiss/impl/LookupTableScaler.h

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,14 @@
2020
namespace faiss {
2121

2222
/// no-op handler
23+
template <SIMDLevel SL = SINGLE_SIMD_LEVEL>
2324
struct DummyScaler {
2425
static constexpr int nscale = 0;
26+
static constexpr SIMDLevel SL256 = simd256_level_selector<SL>::value;
27+
using simd32uint8 = simd32uint8_tpl<SL256>;
28+
using simd16uint16 = simd16uint16_tpl<SL256>;
29+
using simd64uint8 = simd64uint8_tpl<SL>;
30+
using simd32uint16 = simd32uint16_tpl<SL>;
2531

2632
inline simd32uint8 lookup(const simd32uint8&, const simd32uint8&) const {
2733
FAISS_THROW_MSG("DummyScaler::lookup should not be called.");
@@ -38,7 +44,6 @@ struct DummyScaler {
3844
return simd16uint16(0);
3945
}
4046

41-
#ifdef __AVX512F__
4247
inline simd64uint8 lookup(const simd64uint8&, const simd64uint8&) const {
4348
FAISS_THROW_MSG("DummyScaler::lookup should not be called.");
4449
return simd64uint8(0);
@@ -53,7 +58,6 @@ struct DummyScaler {
5358
FAISS_THROW_MSG("DummyScaler::scale_hi should not be called.");
5459
return simd32uint16(0);
5560
}
56-
#endif
5761

5862
template <class dist_t>
5963
inline dist_t scale_one(const dist_t&) const {

0 commit comments

Comments
 (0)