diff --git a/faiss/impl/scalar_quantizer/sq-avx512.cpp b/faiss/impl/scalar_quantizer/sq-avx512.cpp index f85f8bbb8c..07d994f712 100644 --- a/faiss/impl/scalar_quantizer/sq-avx512.cpp +++ b/faiss/impl/scalar_quantizer/sq-avx512.cpp @@ -382,24 +382,54 @@ struct DistanceComputerByte int compute_code_distance(const uint8_t* code1, const uint8_t* code2) const { - // compute 16 lanes of 32-bit products (16-bytes) at once for + // compute 32 lanes of 16-bit products (32-bytes) at once for // the supported metrics __m512i accu = _mm512_setzero_si512(); - constexpr int kLanes = 16; - for (int i = 0; i < d; i += kLanes) { - __m128i c1 = _mm_loadu_si128((__m128i*)(code1 + i)); - __m128i c2 = _mm_loadu_si128((__m128i*)(code2 + i)); - __m512i c1i = _mm512_cvtepu8_epi32(c1); - __m512i c2i = _mm512_cvtepu8_epi32(c2); - - __m512i v; + constexpr int kLanes = 32; + int i = 0; + for (; i + kLanes <= d; i += kLanes) { + __m256i c1 = _mm256_loadu_epi8(code1 + i); + __m256i c2 = _mm256_loadu_epi8(code2 + i); + __m512i c1i16 = _mm512_cvtepu8_epi16(c1); + __m512i c2i16 = _mm512_cvtepu8_epi16(c2); +#ifdef __AVX512VNNI__ if (Sim::metric_type == METRIC_INNER_PRODUCT) { - v = _mm512_mullo_epi32(c1i, c2i); + accu = _mm512_dpwssd_epi32(accu, c1i16, c2i16); } else { - __m512i diff = _mm512_sub_epi32(c1i, c2i); - v = _mm512_mullo_epi32(diff, diff); + __m512i diff = _mm512_sub_epi16(c1i16, c2i16); + accu = _mm512_dpwssd_epi32(accu, diff, diff); } - accu = _mm512_add_epi32(accu, v); +#else + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + accu = _mm512_add_epi32(accu, _mm512_madd_epi16(c1i16, c2i16)); + } else { + __m512i diff = _mm512_sub_epi16(c1i16, c2i16); + accu = _mm512_add_epi32(accu, _mm512_madd_epi16(diff, diff)); + } +#endif + } + // tail handling for dimensions not divisible by 32 + if (i < d) { + __mmask32 mask = (__mmask32)((1ULL << (d - i)) - 1ULL); + __m256i c1 = _mm256_maskz_loadu_epi8(mask, code1 + i); + __m256i c2 = _mm256_maskz_loadu_epi8(mask, code2 + i); + __m512i c1i16 = _mm512_cvtepu8_epi16(c1); + __m512i c2i16 = _mm512_cvtepu8_epi16(c2); +#ifdef __AVX512VNNI__ + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + accu = _mm512_dpwssd_epi32(accu, c1i16, c2i16); + } else { + __m512i diff = _mm512_sub_epi16(c1i16, c2i16); + accu = _mm512_dpwssd_epi32(accu, diff, diff); + } +#else + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + accu = _mm512_add_epi32(accu, _mm512_madd_epi16(c1i16, c2i16)); + } else { + __m512i diff = _mm512_sub_epi16(c1i16, c2i16); + accu = _mm512_add_epi32(accu, _mm512_madd_epi16(diff, diff)); + } +#endif } return _mm512_reduce_add_epi32(accu); }