Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 43 additions & 13 deletions faiss/impl/scalar_quantizer/sq-avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,24 +382,54 @@ struct DistanceComputerByte<Similarity, SIMDLevel::AVX512>

int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
const {
// compute 16 lanes of 32-bit products (16-bytes) at once for
// compute 32 lanes of 16-bit products (32-bytes) at once for
// the supported metrics
__m512i accu = _mm512_setzero_si512();
constexpr int kLanes = 16;
for (int i = 0; i < d; i += kLanes) {
__m128i c1 = _mm_loadu_si128((__m128i*)(code1 + i));
__m128i c2 = _mm_loadu_si128((__m128i*)(code2 + i));
__m512i c1i = _mm512_cvtepu8_epi32(c1);
__m512i c2i = _mm512_cvtepu8_epi32(c2);

__m512i v;
constexpr int kLanes = 32;
int i = 0;
for (; i + kLanes <= d; i += kLanes) {
__m256i c1 = _mm256_loadu_epi8(code1 + i);
__m256i c2 = _mm256_loadu_epi8(code2 + i);
__m512i c1i16 = _mm512_cvtepu8_epi16(c1);
__m512i c2i16 = _mm512_cvtepu8_epi16(c2);
#ifdef __AVX512VNNI__
if (Sim::metric_type == METRIC_INNER_PRODUCT) {
v = _mm512_mullo_epi32(c1i, c2i);
accu = _mm512_dpwssd_epi32(accu, c1i16, c2i16);
} else {
__m512i diff = _mm512_sub_epi32(c1i, c2i);
v = _mm512_mullo_epi32(diff, diff);
__m512i diff = _mm512_sub_epi16(c1i16, c2i16);
accu = _mm512_dpwssd_epi32(accu, diff, diff);
}
accu = _mm512_add_epi32(accu, v);
#else
if (Sim::metric_type == METRIC_INNER_PRODUCT) {
accu = _mm512_add_epi32(accu, _mm512_madd_epi16(c1i16, c2i16));
} else {
__m512i diff = _mm512_sub_epi16(c1i16, c2i16);
accu = _mm512_add_epi32(accu, _mm512_madd_epi16(diff, diff));
}
#endif
}
// tail handling for dimensions not divisible by 32
if (i < d) {
__mmask32 mask = (__mmask32)((1ULL << (d - i)) - 1ULL);
__m256i c1 = _mm256_maskz_loadu_epi8(mask, code1 + i);
__m256i c2 = _mm256_maskz_loadu_epi8(mask, code2 + i);
__m512i c1i16 = _mm512_cvtepu8_epi16(c1);
__m512i c2i16 = _mm512_cvtepu8_epi16(c2);
#ifdef __AVX512VNNI__
if (Sim::metric_type == METRIC_INNER_PRODUCT) {
accu = _mm512_dpwssd_epi32(accu, c1i16, c2i16);
} else {
__m512i diff = _mm512_sub_epi16(c1i16, c2i16);
accu = _mm512_dpwssd_epi32(accu, diff, diff);
}
#else
if (Sim::metric_type == METRIC_INNER_PRODUCT) {
accu = _mm512_add_epi32(accu, _mm512_madd_epi16(c1i16, c2i16));
} else {
__m512i diff = _mm512_sub_epi16(c1i16, c2i16);
accu = _mm512_add_epi32(accu, _mm512_madd_epi16(diff, diff));
}
#endif
}
return _mm512_reduce_add_epi32(accu);
}
Expand Down