Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 23 additions & 42 deletions faiss/Index2Layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@
#include <cstdint>
#include <cstdio>

#ifdef __SSE3__
#include <immintrin.h>
#endif

#include <algorithm>

#include <faiss/IndexIVFPQ.h>
Expand Down Expand Up @@ -177,34 +173,26 @@ struct DistanceXPQ4 : Distance2Level {
}

float operator()(idx_t i) override {
#ifdef __SSE3__
const uint8_t* code = storage.codes.data() + i * storage.code_size;
idx_t key = 0;
memcpy(&key, code, storage.code_size_1);
code += storage.code_size_1;

// walking pointers
const float* qa = q;
const __m128* l1_t = (const __m128*)(pq_l1_tab + d * key);
const __m128* pq_l2_t = (const __m128*)pq_l2_tab;
__m128 accu = _mm_setzero_ps();
const float* l1 = pq_l1_tab + d * key;
const float* l2 = pq_l2_tab;
float accu = 0;

for (int m = 0; m < M; m++) {
__m128 qi = _mm_loadu_ps(qa);
__m128 recons = _mm_add_ps(l1_t[m], pq_l2_t[*code++]);
__m128 diff = _mm_sub_ps(qi, recons);
accu = _mm_add_ps(accu, _mm_mul_ps(diff, diff));
pq_l2_t += 256;
for (int j = 0; j < 4; j++) {
float diff = qa[j] - (l1[m * 4 + j] + l2[*code * 4 + j]);
accu += diff * diff;
}
code++;
l2 += 256 * 4;
qa += 4;
}

accu = _mm_hadd_ps(accu, accu);
accu = _mm_hadd_ps(accu, accu);
return _mm_cvtss_f32(accu);
#else
(void)i;
FAISS_THROW_MSG("not implemented for non-x64 platforms");
#endif
return accu;
}
};

Expand All @@ -229,42 +217,36 @@ struct Distance2xXPQ4 : Distance2Level {
int64_t key01 = 0;
memcpy(&key01, code, storage.code_size_1);
code += storage.code_size_1;
#ifdef __SSE3__

// walking pointers
const float* qa = q;
const __m128* pq_l1_t = (const __m128*)pq_l1_tab;
const __m128* pq_l2_t = (const __m128*)pq_l2_tab;
__m128 accu = _mm_setzero_ps();
const float* l1 = pq_l1_tab;
const float* l2 = pq_l2_tab;
float accu = 0;

for (int mi_m = 0; mi_m < 2; mi_m++) {
int64_t l1_idx = key01 & (((int64_t)1 << mi_nbits) - 1);
const __m128* pq_l1 = pq_l1_t + M_2 * l1_idx;
const float* l1_sub = l1 + M_2 * l1_idx * 4;

for (int m = 0; m < M_2; m++) {
__m128 qi = _mm_loadu_ps(qa);
__m128 recons = _mm_add_ps(pq_l1[m], pq_l2_t[*code++]);
__m128 diff = _mm_sub_ps(qi, recons);
accu = _mm_add_ps(accu, _mm_mul_ps(diff, diff));
pq_l2_t += 256;
for (int j = 0; j < 4; j++) {
float diff =
qa[j] - (l1_sub[m * 4 + j] + l2[*code * 4 + j]);
accu += diff * diff;
}
code++;
l2 += 256 * 4;
qa += 4;
}
pq_l1_t += M_2 << mi_nbits;
l1 += (M_2 << mi_nbits) * 4;
key01 >>= mi_nbits;
}
accu = _mm_hadd_ps(accu, accu);
accu = _mm_hadd_ps(accu, accu);
return _mm_cvtss_f32(accu);
#else
FAISS_THROW_MSG("not implemented for non-x64 platforms");
#endif
return accu;
}
};

} // namespace

DistanceComputer* Index2Layer::get_distance_computer() const {
#ifdef __SSE3__
const MultiIndexQuantizer* mi =
dynamic_cast<MultiIndexQuantizer*>(q1.quantizer);

Expand All @@ -277,7 +259,6 @@ DistanceComputer* Index2Layer::get_distance_computer() const {
if (fl && pq.dsub == 4) {
return new DistanceXPQ4(*this);
}
#endif

return Index::get_distance_computer();
}
Expand Down
Loading