Skip to content

Commit 1ea755a

Browse files
alibeklfcfacebook-github-bot
authored and committed
Add explicit template specialization declarations for MSVC linking (facebookresearch#4992)
Summary: Add explicit specialization declarations for all SIMD-templated distance functions in `distances.h`. C++ [temp.expl.spec]/7 requires that an explicit specialization be declared before the first use that would cause an implicit instantiation of the primary template, in every translation unit in which such a use occurs. GCC/Clang are lenient about this ordering, but MSVC is not — without these declarations, the linker emits LNK2001 (unresolved external symbol) for the specializations defined in the `_avx2` translation units. The macro `FAISS_DECLARE_DISTANCES_SPECIALIZATIONS(SL)` declares the specializations of all 17 distance functions (18 declarations, since `fvec_add` has two overloads) for a given `SIMDLevel`, and is expanded for `SIMDLevel::NONE` and `SIMDLevel::AVX2`. This is a no-op on GCC/Clang (declarations are redundant but harmless). Differential Revision: D98232371
1 parent 8550f23 commit 1ea755a

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

faiss/utils/distances.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,4 +600,67 @@ int fvec_madd_and_argmin(
600600
const float* b,
601601
float* c);
602602

603+
/* Explicit specialization declarations for all SIMD-templated distance
604+
functions. C++ [temp.expl.spec]/7 requires that each translation unit see
605+
these before any use that would implicitly instantiate them. GCC/Clang are
606+
lenient about this, but MSVC is not — without these declarations the
607+
linker emits LNK2001 for the specializations defined in the _avx2
608+
translation units. */
609+
610+
// clang-format off
611+
// Expands to one `template <>` explicit-specialization declaration for each
// SIMD-templated function listed in the macro body, with SL substituted as
// the template argument. fvec_add is declared twice, once per overload
// (vector + vector and vector + scalar). SL is expected to be a SIMDLevel
// enumerator, as in the two expansions that follow the definition.
// NOTE: do not place `//` comments inside the macro body; the trailing
// backslash would splice the next declaration into the comment.
#define FAISS_DECLARE_DISTANCES_SPECIALIZATIONS(SL) \
612+
template <> float fvec_L2sqr<SL>( \
613+
const float* x, const float* y, size_t d); \
614+
template <> float fvec_inner_product<SL>( \
615+
const float* x, const float* y, size_t d); \
616+
template <> float fvec_L1<SL>( \
617+
const float* x, const float* y, size_t d); \
618+
template <> float fvec_Linf<SL>( \
619+
const float* x, const float* y, size_t d); \
620+
template <> void fvec_inner_product_batch_4<SL>( \
621+
const float* x, const float* y0, const float* y1, \
622+
const float* y2, const float* y3, const size_t d, \
623+
float& dis0, float& dis1, float& dis2, float& dis3); \
624+
template <> void fvec_L2sqr_batch_4<SL>( \
625+
const float* x, const float* y0, const float* y1, \
626+
const float* y2, const float* y3, const size_t d, \
627+
float& dis0, float& dis1, float& dis2, float& dis3); \
628+
template <> void fvec_inner_products_ny<SL>( \
629+
float* ip, const float* x, const float* y, \
630+
size_t d, size_t ny); \
631+
template <> void fvec_L2sqr_ny<SL>( \
632+
float* dis, const float* x, const float* y, \
633+
size_t d, size_t ny); \
634+
template <> void fvec_L2sqr_ny_transposed<SL>( \
635+
float* dis, const float* x, const float* y, \
636+
const float* y_sqlen, size_t d, size_t d_offset, size_t ny); \
637+
template <> size_t fvec_L2sqr_ny_nearest<SL>( \
638+
float* distances_tmp_buffer, const float* x, \
639+
const float* y, size_t d, size_t ny); \
640+
template <> size_t fvec_L2sqr_ny_nearest_y_transposed<SL>( \
641+
float* distances_tmp_buffer, const float* x, \
642+
const float* y, const float* y_sqlen, \
643+
size_t d, size_t d_offset, size_t ny); \
644+
template <> float fvec_norm_L2sqr<SL>(const float* x, size_t d); \
645+
template <> void fvec_add<SL>( \
646+
size_t d, const float* a, const float* b, float* c); \
647+
template <> void fvec_add<SL>( \
648+
size_t d, const float* a, float b, float* c); \
649+
template <> void fvec_sub<SL>( \
650+
size_t d, const float* a, const float* b, float* c); \
651+
template <> void compute_PQ_dis_tables_dsub2<SL>( \
652+
size_t d, size_t ksub, const float* centroids, \
653+
size_t nx, const float* x, bool is_inner_product, \
654+
float* dis_tables); \
655+
template <> void fvec_madd<SL>( \
656+
size_t n, const float* a, float bf, const float* b, float* c); \
657+
template <> int fvec_madd_and_argmin<SL>( \
658+
size_t n, const float* a, float bf, const float* b, float* c);
659+
660+
// Declare the specializations for the two SIMD levels NONE and AVX2.
FAISS_DECLARE_DISTANCES_SPECIALIZATIONS(SIMDLevel::NONE)
661+
FAISS_DECLARE_DISTANCES_SPECIALIZATIONS(SIMDLevel::AVX2)
662+
663+
// The macro is a local helper only; undefine it so it does not leak to
// files that include this header.
#undef FAISS_DECLARE_DISTANCES_SPECIALIZATIONS
664+
// clang-format on
665+
603666
} // namespace faiss

0 commit comments

Comments
 (0)