Skip to content

Commit 7531d62

Browse files
committed
fix: static SIMD dispatch falls to scalar for avx512_spr/avx512/arm_sve builds
The static (non-DD) dispatch path in with_selected_simd_levels performs a single exact-match check against SINGLE_SIMD_LEVEL. When the compiled level is not in the available-levels mask, it falls directly to NONE (scalar) instead of trying lower SIMD levels. No predefined mask includes AVX512_SPR, and no AVX512_SPR template specializations exist, so static avx512_spr builds dispatch every SIMD-accelerated function to scalar. Static avx512 builds also regress to scalar for 256-bit operations (AVX2_NEON mask), and static arm_sve builds lose ARM_NEON fallback. Add a compile-time fallthrough chain mirroring the DD switch/fallthrough: x86: AVX512_SPR -> AVX512 -> AVX2 -> NONE ARM: ARM_SVE -> ARM_NEON -> NONE Fixes 9 broken (level x mask) combinations across distances, RaBitQ, scalar/product quantizer, HNSW, IndexFlat, IVF, and fused distances. Signed-off-by: Mulugeta Mammo <mulugeta.mammo@intel.com>
1 parent 364749e commit 7531d62

1 file changed

Lines changed: 24 additions & 3 deletions

File tree

faiss/impl/simd_dispatch.h

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,32 @@ inline auto with_selected_simd_levels(LambdaType&& action) {
106106
}
107107
#else // static dispatch
108108
// In static mode, SINGLE_SIMD_LEVEL is a constexpr resolved at compile
109-
// time. If the compiled level is not in the available set, fall through
110-
// to NONE (mirroring the DD fallthrough behavior). Only SINGLE_SIMD_LEVEL
111-
// and NONE have compiled specializations.
109+
// time. We mirror the DD fallthrough behavior at compile time:
110+
// x86: AVX512_SPR -> AVX512 -> AVX2 -> NONE
111+
// ARM: ARM_SVE -> ARM_NEON -> NONE
112+
// This ensures that e.g. an AVX512_SPR build can use AVX512
113+
// specializations when no AVX512_SPR-specific implementation exists,
114+
// and an AVX512 build can use AVX2 for 256-bit operations.
115+
112116
if constexpr (available_levels & (1 << int(SINGLE_SIMD_LEVEL))) {
117+
// Exact match — use the compiled-in level directly.
113118
return action.template operator()<SINGLE_SIMD_LEVEL>();
119+
} else if constexpr (
120+
SINGLE_SIMD_LEVEL == SIMDLevel::AVX512_SPR &&
121+
(available_levels & (1 << int(SIMDLevel::AVX512)))) {
122+
// AVX512_SPR -> AVX512 fallthrough.
123+
return action.template operator()<SIMDLevel::AVX512>();
124+
} else if constexpr (
125+
(SINGLE_SIMD_LEVEL == SIMDLevel::AVX512_SPR ||
126+
SINGLE_SIMD_LEVEL == SIMDLevel::AVX512) &&
127+
(available_levels & (1 << int(SIMDLevel::AVX2)))) {
128+
// AVX512_SPR/AVX512 -> AVX2 fallthrough
129+
return action.template operator()<SIMDLevel::AVX2>();
130+
} else if constexpr (
131+
SINGLE_SIMD_LEVEL == SIMDLevel::ARM_SVE &&
132+
(available_levels & (1 << int(SIMDLevel::ARM_NEON)))) {
133+
// ARM_SVE -> ARM_NEON fallthrough.
134+
return action.template operator()<SIMDLevel::ARM_NEON>();
114135
} else {
115136
return action.template operator()<SIMDLevel::NONE>();
116137
}

0 commit comments

Comments
 (0)