Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 55 additions & 53 deletions faiss/impl/LocalSearchQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -600,73 +600,75 @@ void LocalSearchQuantizer::icm_encode_step(
FAISS_THROW_IF_NOT(M != 0 && K != 0);
FAISS_THROW_IF_NOT(binaries != nullptr);

// Resolve SIMD level once, not per iteration of the n × n_iters × M loop.
with_simd_level_256bit([&]<SIMDLevel SL>() {
#pragma omp parallel for schedule(dynamic)
for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
std::vector<float> objs(K);

for (size_t iter = 0; iter < n_iters; iter++) {
// condition on the m-th subcode
for (size_t m = 0; m < M; m++) {
// copy
auto u = unaries + m * n * K + i * K;
for (size_t code = 0; code < K; code++) {
objs[code] = u[code];
}
for (int64_t i = 0; i < static_cast<int64_t>(n); i++) {
std::vector<float> objs(K);

// compute objective function by adding unary
// and binary terms together
for (size_t other_m = 0; other_m < M; other_m++) {
if (other_m == m) {
continue;
for (size_t iter = 0; iter < n_iters; iter++) {
// condition on the m-th subcode
for (size_t m = 0; m < M; m++) {
// copy
auto u = unaries + m * n * K + i * K;
for (size_t code = 0; code < K; code++) {
objs[code] = u[code];
}

// compute objective function by adding unary
// and binary terms together
for (size_t other_m = 0; other_m < M; other_m++) {
if (other_m == m) {
continue;
}

#ifdef COMPILE_SIMD_AVX2
// TODO: add platform-independent compiler-independent
// prefetch utilities.
if (other_m + 1 < M) {
// do a single prefetch
int32_t code2 = codes[i * M + other_m + 1];
// for (int32_t code = 0; code < K; code += 64) {
int32_t code = 0;
{
size_t binary_idx = (other_m + 1) * M * K * K +
m * K * K + code2 * K + code;
_mm_prefetch(
(const char*)(binaries + binary_idx),
_MM_HINT_T0);
// TODO: add platform-independent compiler-independent
// prefetch utilities.
if (other_m + 1 < M) {
// do a single prefetch
int32_t code2 = codes[i * M + other_m + 1];
// for (int32_t code = 0; code < K; code += 64) {
int32_t code = 0;
{
size_t binary_idx = (other_m + 1) * M * K * K +
m * K * K + code2 * K + code;
_mm_prefetch(
(const char*)(binaries + binary_idx),
_MM_HINT_T0);
}
}
}
#endif

for (size_t code = 0; code < K; code++) {
int32_t code2 = codes[i * M + other_m];
size_t binary_idx = other_m * M * K * K + m * K * K +
code2 * K + code;
// binaries[m, other_m, code, code2].
// It is symmetric over (m <-> other_m)
// and (code <-> code2).
// So, replace the op with
// binaries[other_m, m, code2, code].
objs[code] += binaries[binary_idx];
for (size_t code = 0; code < K; code++) {
int32_t code2 = codes[i * M + other_m];
size_t binary_idx = other_m * M * K * K +
m * K * K + code2 * K + code;
// binaries[m, other_m, code, code2].
// It is symmetric over (m <-> other_m)
// and (code <-> code2).
// So, replace the op with
// binaries[other_m, m, code2, code].
objs[code] += binaries[binary_idx];
}
}
}

// find the optimal value of the m-th subcode
float best_obj = HUGE_VALF;
int32_t best_code = 0;
// find the optimal value of the m-th subcode
float best_obj = HUGE_VALF;
int32_t best_code = 0;

// find one using SIMD. The following operation is similar
// to the search of the smallest element in objs
using C = CMax<float, int>;
HeapWithBuckets<C, 16, 1>::addn(
K, objs.data(), 1, &best_obj, &best_code);
// find one using SIMD. The following operation is similar
// to the search of the smallest element in objs
HeapWithBucketsCMaxFloat<16, 1, SL>::addn(
K, objs.data(), 1, &best_obj, &best_code);

// done
codes[i * M + m] = best_code;
// done
codes[i * M + m] = best_code;

} // loop M
} // loop M
}
}
}
});
}
void LocalSearchQuantizer::perturb_codes(
int32_t* codes,
Expand Down
61 changes: 61 additions & 0 deletions faiss/impl/approx_topk/approx_topk.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,4 +212,65 @@ struct HeapWithBuckets<CMax<float, int>, NBUCKETS, N> {
}
};

// -----------------------------------------------------------------------
// approx_topk_by_mode: consolidates the mode switch + dispatch pattern
// used by residual_quantizer_encode_steps.cpp and other callers.
// -----------------------------------------------------------------------

// SL-parameterized version for callers that have already resolved the
// SIMD level (e.g., inside a with_simd_level_256bit lambda).
template <SIMDLevel SL>
inline void approx_topk_by_mode(
ApproxTopK_mode_t mode,
uint32_t beam_size,
uint32_t n_per_beam,
const float* distances,
uint32_t k,
float* bh_val,
int32_t* bh_ids) {
using C = CMax<float, int>;
auto approx = [&]<uint32_t NB, uint32_t ND>() {
HeapWithBucketsCMaxFloat<NB, ND, SL>::bs_addn(
beam_size, n_per_beam, distances, k, bh_val, bh_ids);
};
switch (mode) {
case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B8_D3:
approx.template operator()<8, 3>();
break;
case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B8_D2:
approx.template operator()<8, 2>();
break;
case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B16_D2:
approx.template operator()<16, 2>();
break;
case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B32_D2:
approx.template operator()<32, 2>();
break;
default:
heap_addn<C>(
k,
bh_val,
bh_ids,
distances,
nullptr,
beam_size * n_per_beam);
break;
}
}

// Non-SL wrapper that dispatches via with_simd_level_256bit.
inline void approx_topk_by_mode(
ApproxTopK_mode_t mode,
uint32_t beam_size,
uint32_t n_per_beam,
const float* distances,
uint32_t k,
float* bh_val,
int32_t* bh_ids) {
with_simd_level_256bit([&]<SIMDLevel SL>() {
approx_topk_by_mode<SL>(
mode, beam_size, n_per_beam, distances, k, bh_val, bh_ids);
});
}

} // namespace faiss
Loading
Loading