Skip to content

Commit 0bd934e

Browse files
committed
Add Sapphire Rapids optimizations for ScalarQuantizer (L2, IP)
Adds an AVX512_SPR specialization path for ScalarQuantizer that uses Sapphire Rapids-specific instructions for byte-code distance computation on QT_8bit_direct and QT_8bit_direct_signed. Inner product (8-bit codes): Replaces the AVX512 path that processes 16 bytes per iteration via cvtepu8_epi32 + mullo_epi32 with a VNNI loop that processes 64 bytes per iteration using _mm512_dpbusd_epi32. VNNI computes unsigned*signed dot products, so the standard bias trick is used to bridge unsigned*unsigned: subtract 128 from code2, run dpbusd, then add the 128 * sum(code1) correction. A scalar tail handles d % 64. For QT_8bit_direct_signed (storage = value + 128), the same VNNI loop runs and an additional closed-form correction is applied: (a-128) * (b-128) = a*b - 128*(a+b) + 16384 sum(a) and sum(b) are accumulated cheaply via _mm512_sad_epu8 (one PSADBW per 64-byte iteration). L2 (8-bit codes): Replaces the 16-bytes-per-iter cvtepu8_epi32 + sub + mullo_epi32 path with a 16-bit pipeline: load 64 bytes, zero-extend to 16-bit lanes via _mm512_cvtepu8_epi16, subtract in 16-bit, square-and-accumulate to 32-bit with _mm512_madd_epi16. Squared differences of two uint8_t values fit in 16 bits (max 255^2 = 65025), so the widened representation is safe. Falls through to a 32-byte step and a scalar tail for arbitrary d. The same kernel is bit-exact for the signed variant: (a - 128) - (b - 128) == a - b, so no correction is needed. Signed-off-by: Mulugeta Mammo <[email protected]>
1 parent 9d5491a commit 0bd934e

7 files changed

Lines changed: 446 additions & 8 deletions

File tree

faiss/CMakeLists.txt

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ set(FAISS_SIMD_AVX512_SRC
3434
utils/simd_impl/rabitq_avx512.cpp
3535
utils/simd_impl/super_kmeans_kernels_avx512.cpp
3636
)
37+
set(FAISS_SIMD_AVX512_SPR_SRC
38+
impl/scalar_quantizer/sq-avx512-spr.cpp
39+
)
3740
set(FAISS_SIMD_NEON_SRC
3841
impl/fast_scan/impl-neon.cpp
3942
impl/scalar_quantizer/sq-neon.cpp
@@ -61,7 +64,7 @@ set(FAISS_SIMD_RVV_SRC
6164
)
6265
# Select SIMD sources based on target architecture
6366
if(CMAKE_SYSTEM_PROCESSOR MATCHES "(x86_64|amd64|AMD64)")
64-
set(FAISS_SIMD_SRC ${FAISS_SIMD_AVX2_SRC} ${FAISS_SIMD_AVX512_SRC})
67+
set(FAISS_SIMD_SRC ${FAISS_SIMD_AVX2_SRC} ${FAISS_SIMD_AVX512_SRC} ${FAISS_SIMD_AVX512_SPR_SRC})
6568
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64|arm64|ARM64)")
6669
set(FAISS_SIMD_SRC ${FAISS_SIMD_NEON_SRC} ${FAISS_SIMD_SVE_SRC})
6770
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "(riscv64|riscv)")
@@ -461,13 +464,13 @@ endif()
461464
if(NOT WIN32)
462465
# Architecture mode to support AVX512 extensions available since Intel(R) Sapphire Rapids.
463466
# Ref: https://networkbuilders.intel.com/solutionslibrary/intel-avx-512-fp16-instruction-set-for-intel-xeon-processor-based-products-technology-guide
464-
target_compile_options(faiss_avx512_spr PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-mavx2 -mfma -mf16c -mavx512f -mavx512cd -mavx512vl -mavx512dq -mavx512bw -mavx512vpopcntdq -mpopcnt -mavx512fp16 -mavx512bf16 ${FAISS_BMI2_FLAGS}>)
467+
target_compile_options(faiss_avx512_spr PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-mavx2 -mfma -mf16c -mavx512f -mavx512cd -mavx512vl -mavx512dq -mavx512bw -mavx512vnni -mavx512vpopcntdq -mpopcnt -mavx512fp16 -mavx512bf16 ${FAISS_BMI2_FLAGS}>)
465468
else()
466469
target_compile_options(faiss_avx512_spr PRIVATE $<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
467470
# we need bigobj for the swig wrapper
468471
add_compile_options(/bigobj)
469472
endif()
470-
target_sources(faiss_avx512_spr PRIVATE ${FAISS_SIMD_AVX2_SRC} ${FAISS_SIMD_AVX512_SRC})
473+
target_sources(faiss_avx512_spr PRIVATE ${FAISS_SIMD_AVX2_SRC} ${FAISS_SIMD_AVX512_SRC} ${FAISS_SIMD_AVX512_SPR_SRC})
471474
target_compile_definitions(faiss_avx512_spr PRIVATE COMPILE_SIMD_AVX2 COMPILE_SIMD_AVX512 COMPILE_SIMD_AVX512_SPR )
472475

473476
add_library(faiss_sve ${FAISS_SRC})
@@ -525,6 +528,11 @@ if(FAISS_OPT_LEVEL STREQUAL "dd")
525528
PROPERTIES COMPILE_OPTIONS
526529
"-mavx512f;-mavx512cd;-mavx512vl;-mavx512dq;-mavx512bw;-mfma;-mf16c;-mpopcnt"
527530
)
531+
set_source_files_properties(${FAISS_SIMD_AVX512_SPR_SRC}
532+
TARGET_DIRECTORY faiss
533+
PROPERTIES COMPILE_OPTIONS
534+
"-mavx512f;-mavx512cd;-mavx512vl;-mavx512dq;-mavx512bw;-mavx512vnni;-mavx512vpopcntdq;-mfma;-mf16c;-mpopcnt;-mavx512fp16;-mavx512bf16"
535+
)
528536
else()
529537
# Per-file SIMD flags (MSVC)
530538
add_compile_options(/bigobj)

faiss/impl/ScalarQuantizer.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ void ScalarQuantizer::train(size_t n, const float* x) {
154154
}
155155

156156
ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
157-
return with_simd_level([&]<SIMDLevel SL>() -> SQuantizer* {
157+
return with_simd_level_spr([&]<SIMDLevel SL>() -> SQuantizer* {
158158
if constexpr (SL != SIMDLevel::NONE) {
159159
auto* q = scalar_quantizer::sq_select_quantizer<SL>(
160160
qtype, d, trained);
@@ -197,7 +197,7 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
197197
ScalarQuantizer::SQDistanceComputer* ScalarQuantizer::get_distance_computer(
198198
MetricType metric) const {
199199
FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
200-
return with_simd_level([&]<SIMDLevel SL>() -> SQDistanceComputer* {
200+
return with_simd_level_spr([&]<SIMDLevel SL>() -> SQDistanceComputer* {
201201
if constexpr (SL != SIMDLevel::NONE) {
202202
auto* dc = scalar_quantizer::sq_select_distance_computer<SL>(
203203
metric, qtype, d, trained);
@@ -216,7 +216,7 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
216216
bool store_pairs,
217217
const IDSelector* sel,
218218
bool by_residual) const {
219-
return with_simd_level([&]<SIMDLevel SL>() -> InvertedListScanner* {
219+
return with_simd_level_spr([&]<SIMDLevel SL>() -> InvertedListScanner* {
220220
if constexpr (SL != SIMDLevel::NONE) {
221221
auto* s = scalar_quantizer::sq_select_InvertedListScanner<SL>(
222222
qtype,

faiss/impl/scalar_quantizer/distance_computers.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ struct DCTemplate<Quantizer, Similarity, SIMDLevel::NONE> : SQDistanceComputer {
7777
template <class Similarity, SIMDLevel SL>
7878
struct DistanceComputerByte : SQDistanceComputer {};
7979

80+
// Byte-domain distance computer for QT_8bit_direct_signed (storage is
81+
// value+128). Only specialized for AVX512_SPR; other levels fall back to
82+
// the float-domain DCTemplate path via the dispatch logic.
83+
template <class Similarity, SIMDLevel SL>
84+
struct DistanceComputerByteSigned : SQDistanceComputer {};
85+
8086
template <class Similarity>
8187
struct DistanceComputerByte<Similarity, SIMDLevel::NONE> : SQDistanceComputer {
8288
using Sim = Similarity;

0 commit comments

Comments
 (0)