diff --git a/faiss/impl/LocalSearchQuantizer.cpp b/faiss/impl/LocalSearchQuantizer.cpp index 8753b4c106..42f08b38cb 100644 --- a/faiss/impl/LocalSearchQuantizer.cpp +++ b/faiss/impl/LocalSearchQuantizer.cpp @@ -600,73 +600,75 @@ void LocalSearchQuantizer::icm_encode_step( FAISS_THROW_IF_NOT(M != 0 && K != 0); FAISS_THROW_IF_NOT(binaries != nullptr); + // Resolve SIMD level once, not per iteration of the n × n_iters × M loop. + with_simd_level_256bit([&]() { #pragma omp parallel for schedule(dynamic) - for (int64_t i = 0; i < static_cast(n); i++) { - std::vector objs(K); - - for (size_t iter = 0; iter < n_iters; iter++) { - // condition on the m-th subcode - for (size_t m = 0; m < M; m++) { - // copy - auto u = unaries + m * n * K + i * K; - for (size_t code = 0; code < K; code++) { - objs[code] = u[code]; - } + for (int64_t i = 0; i < static_cast(n); i++) { + std::vector objs(K); - // compute objective function by adding unary - // and binary terms together - for (size_t other_m = 0; other_m < M; other_m++) { - if (other_m == m) { - continue; + for (size_t iter = 0; iter < n_iters; iter++) { + // condition on the m-th subcode + for (size_t m = 0; m < M; m++) { + // copy + auto u = unaries + m * n * K + i * K; + for (size_t code = 0; code < K; code++) { + objs[code] = u[code]; } + // compute objective function by adding unary + // and binary terms together + for (size_t other_m = 0; other_m < M; other_m++) { + if (other_m == m) { + continue; + } + #ifdef COMPILE_SIMD_AVX2 - // TODO: add platform-independent compiler-independent - // prefetch utilities. - if (other_m + 1 < M) { - // do a single prefetch - int32_t code2 = codes[i * M + other_m + 1]; - // for (int32_t code = 0; code < K; code += 64) { - int32_t code = 0; - { - size_t binary_idx = (other_m + 1) * M * K * K + - m * K * K + code2 * K + code; - _mm_prefetch( - (const char*)(binaries + binary_idx), - _MM_HINT_T0); + // TODO: add platform-independent compiler-independent + // prefetch utilities. + if (other_m + 1 < M) { + // do a single prefetch + int32_t code2 = codes[i * M + other_m + 1]; + // for (int32_t code = 0; code < K; code += 64) { + int32_t code = 0; + { + size_t binary_idx = (other_m + 1) * M * K * K + + m * K * K + code2 * K + code; + _mm_prefetch( + (const char*)(binaries + binary_idx), + _MM_HINT_T0); + } } - } #endif - for (size_t code = 0; code < K; code++) { - int32_t code2 = codes[i * M + other_m]; - size_t binary_idx = other_m * M * K * K + m * K * K + - code2 * K + code; - // binaries[m, other_m, code, code2]. - // It is symmetric over (m <-> other_m) - // and (code <-> code2). - // So, replace the op with - // binaries[other_m, m, code2, code]. - objs[code] += binaries[binary_idx]; + for (size_t code = 0; code < K; code++) { + int32_t code2 = codes[i * M + other_m]; + size_t binary_idx = other_m * M * K * K + + m * K * K + code2 * K + code; + // binaries[m, other_m, code, code2]. + // It is symmetric over (m <-> other_m) + // and (code <-> code2). + // So, replace the op with + // binaries[other_m, m, code2, code]. + objs[code] += binaries[binary_idx]; + } } - } - // find the optimal value of the m-th subcode - float best_obj = HUGE_VALF; - int32_t best_code = 0; + // find the optimal value of the m-th subcode + float best_obj = HUGE_VALF; + int32_t best_code = 0; - // find one using SIMD. The following operation is similar - // to the search of the smallest element in objs - using C = CMax; - HeapWithBuckets::addn( - K, objs.data(), 1, &best_obj, &best_code); + // find one using SIMD. The following operation is similar + // to the search of the smallest element in objs + HeapWithBucketsCMaxFloat<16, 1, SL>::addn( + K, objs.data(), 1, &best_obj, &best_code); - // done - codes[i * M + m] = best_code; + // done + codes[i * M + m] = best_code; - } // loop M + } // loop M + } } - } + }); } void LocalSearchQuantizer::perturb_codes( int32_t* codes, diff --git a/faiss/impl/approx_topk/approx_topk.h b/faiss/impl/approx_topk/approx_topk.h index 5d75502789..26c0a08844 100644 --- a/faiss/impl/approx_topk/approx_topk.h +++ b/faiss/impl/approx_topk/approx_topk.h @@ -212,4 +212,65 @@ struct HeapWithBuckets, NBUCKETS, N> { } }; +// ----------------------------------------------------------------------- +// approx_topk_by_mode: consolidates the mode switch + dispatch pattern +// used by residual_quantizer_encode_steps.cpp and other callers. +// ----------------------------------------------------------------------- + +// SL-parameterized version for callers that have already resolved the +// SIMD level (e.g., inside a with_simd_level_256bit lambda). +template +inline void approx_topk_by_mode( + ApproxTopK_mode_t mode, + uint32_t beam_size, + uint32_t n_per_beam, + const float* distances, + uint32_t k, + float* bh_val, + int32_t* bh_ids) { + using C = CMax; + auto approx = [&]() { + HeapWithBucketsCMaxFloat::bs_addn( + beam_size, n_per_beam, distances, k, bh_val, bh_ids); + }; + switch (mode) { + case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B8_D3: + approx.template operator()<8, 3>(); + break; + case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B8_D2: + approx.template operator()<8, 2>(); + break; + case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B16_D2: + approx.template operator()<16, 2>(); + break; + case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B32_D2: + approx.template operator()<32, 2>(); + break; + default: + heap_addn( + k, + bh_val, + bh_ids, + distances, + nullptr, + beam_size * n_per_beam); + break; + } +} + +// Non-SL wrapper that dispatches via with_simd_level_256bit. +inline void approx_topk_by_mode( + ApproxTopK_mode_t mode, + uint32_t beam_size, + uint32_t n_per_beam, + const float* distances, + uint32_t k, + float* bh_val, + int32_t* bh_ids) { + with_simd_level_256bit([&]() { + approx_topk_by_mode( + mode, beam_size, n_per_beam, distances, k, bh_val, bh_ids); + }); +} + } // namespace faiss diff --git a/faiss/impl/residual_quantizer_encode_steps.cpp b/faiss/impl/residual_quantizer_encode_steps.cpp index ac0fe85e6c..85310a61a5 100644 --- a/faiss/impl/residual_quantizer_encode_steps.cpp +++ b/faiss/impl/residual_quantizer_encode_steps.cpp @@ -92,111 +92,96 @@ void beam_search_encode_step( } InterruptCallback::check(); + // Resolve SIMD level once, not per iteration of the n-parallel loop. + with_simd_level_256bit([&]() { #pragma omp parallel for if (n > 100) - for (int64_t i = 0; i < static_cast(n); i++) { - const int32_t* codes_i = codes + i * m * beam_size; - int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size; - const float* residuals_i = residuals + i * d * beam_size; - float* new_residuals_i = new_residuals + i * d * new_beam_size; - - float* new_distances_i = new_distances + i * new_beam_size; - using C = CMax; - - if (assign_index) { - const float* cent_distances_i = - cent_distances.data() + i * beam_size * new_beam_size; - const idx_t* cent_ids_i = - cent_ids.data() + i * beam_size * new_beam_size; - - // here we could be a tad more efficient by merging sorted arrays - for (size_t j = 0; j < new_beam_size; j++) { - new_distances_i[j] = C::neutral(); - } - std::vector perm(new_beam_size, -1); - heap_addn( - new_beam_size, - new_distances_i, - perm.data(), - cent_distances_i, - nullptr, - beam_size * new_beam_size); - heap_reorder(new_beam_size, new_distances_i, perm.data()); - - for (size_t j = 0; j < new_beam_size; j++) { - int js = perm[j] / new_beam_size; - int ls = cent_ids_i[perm[j]]; - if (m > 0) { - memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m); + for (int64_t i = 0; i < static_cast(n); i++) { + const int32_t* codes_i = codes + i * m * beam_size; + int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size; + const float* residuals_i = residuals + i * d * beam_size; + float* new_residuals_i = new_residuals + i * d * new_beam_size; + + float* new_distances_i = new_distances + i * new_beam_size; + using C = CMax; + + if (assign_index) { + const float* cent_distances_i = + cent_distances.data() + i * beam_size * new_beam_size; + const idx_t* cent_ids_i = + cent_ids.data() + i * beam_size * new_beam_size; + + // here we could be a tad more efficient by merging sorted + // arrays + for (size_t j = 0; j < new_beam_size; j++) { + new_distances_i[j] = C::neutral(); + } + std::vector perm(new_beam_size, -1); + heap_addn( + new_beam_size, + new_distances_i, + perm.data(), + cent_distances_i, + nullptr, + beam_size * new_beam_size); + heap_reorder(new_beam_size, new_distances_i, perm.data()); + + for (size_t j = 0; j < new_beam_size; j++) { + int js = perm[j] / new_beam_size; + int ls = cent_ids_i[perm[j]]; + if (m > 0) { + memcpy(new_codes_i, + codes_i + js * m, + sizeof(*codes) * m); + } + new_codes_i[m] = ls; + new_codes_i += m + 1; + fvec_sub( + d, + residuals_i + js * d, + cent + ls * d, + new_residuals_i); + new_residuals_i += d; } - new_codes_i[m] = ls; - new_codes_i += m + 1; - fvec_sub( - d, - residuals_i + js * d, - cent + ls * d, - new_residuals_i); - new_residuals_i += d; - } - } else { - const float* cent_distances_i = - cent_distances.data() + i * beam_size * K; - // then we have to select the best results - for (size_t j = 0; j < new_beam_size; j++) { - new_distances_i[j] = C::neutral(); - } - std::vector perm(new_beam_size, -1); + } else { + const float* cent_distances_i = + cent_distances.data() + i * beam_size * K; + // then we have to select the best results + for (size_t j = 0; j < new_beam_size; j++) { + new_distances_i[j] = C::neutral(); + } + std::vector perm(new_beam_size, -1); - auto approx = [&]() { - HeapWithBuckets::bs_addn( + approx_topk_by_mode( + approx_topk_mode, beam_size, K, cent_distances_i, new_beam_size, new_distances_i, perm.data()); - }; - switch (approx_topk_mode) { - case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B8_D3: - approx.template operator()<8, 3>(); - break; - case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B8_D2: - approx.template operator()<8, 2>(); - break; - case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B16_D2: - approx.template operator()<16, 2>(); - break; - case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B32_D2: - approx.template operator()<32, 2>(); - break; - default: - heap_addn( - new_beam_size, - new_distances_i, - perm.data(), - cent_distances_i, - nullptr, - beam_size * K); - } - heap_reorder(new_beam_size, new_distances_i, perm.data()); - - for (size_t j = 0; j < new_beam_size; j++) { - int js = perm[j] / K; - int ls = perm[j] % K; - if (m > 0) { - memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m); + heap_reorder(new_beam_size, new_distances_i, perm.data()); + + for (size_t j = 0; j < new_beam_size; j++) { + int js = perm[j] / K; + int ls = perm[j] % K; + if (m > 0) { + memcpy(new_codes_i, + codes_i + js * m, + sizeof(*codes) * m); + } + new_codes_i[m] = ls; + new_codes_i += m + 1; + fvec_sub( + d, + residuals_i + js * d, + cent + ls * d, + new_residuals_i); + new_residuals_i += d; } - new_codes_i[m] = ls; - new_codes_i += m + 1; - fvec_sub( - d, - residuals_i + js * d, - cent + ls * d, - new_residuals_i); - new_residuals_i += d; } } - } + }); } // exposed in the faiss namespace @@ -380,20 +365,21 @@ void beam_search_encode_step_tab( { FAISS_THROW_IF_NOT(ldc >= K); + // Resolve SIMD level once, not per iteration of the n-parallel loop. + with_simd_level_256bit([&]() { #pragma omp parallel for if (n > 100) schedule(dynamic) - for (int64_t i = 0; i < static_cast(n); i++) { - std::vector cent_distances(beam_size * K); - std::vector cd_common(K); + for (int64_t i = 0; i < static_cast(n); i++) { + std::vector cent_distances(beam_size * K); + std::vector cd_common(K); - const int32_t* codes_i = codes + i * m * beam_size; - const float* query_cp_i = query_cp + i * ldqc; - const float* distances_i = distances + i * beam_size; + const int32_t* codes_i = codes + i * m * beam_size; + const float* query_cp_i = query_cp + i * ldqc; + const float* distances_i = distances + i * beam_size; - for (size_t k = 0; k < K; k++) { - cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k]; - } + for (size_t k = 0; k < K; k++) { + cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k]; + } - with_simd_level_256bit([&]() { if constexpr (SL == SIMDLevel::NONE) { compute_cent_distances_baseline( K, @@ -419,64 +405,40 @@ void beam_search_encode_step_tab( cd_common.data(), cent_distances.data()); } - }); - using C = CMax; - int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size; - float* new_distances_i = new_distances + i * new_beam_size; - const float* cent_distances_i = cent_distances.data(); + using C = CMax; + int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size; + float* new_distances_i = new_distances + i * new_beam_size; - // then we have to select the best results - for (size_t j = 0; j < new_beam_size; j++) { - new_distances_i[j] = C::neutral(); - } - std::vector perm(new_beam_size, -1); + const float* cent_distances_i = cent_distances.data(); + + // then we have to select the best results + for (size_t j = 0; j < new_beam_size; j++) { + new_distances_i[j] = C::neutral(); + } + std::vector perm(new_beam_size, -1); - auto approx = [&]() { - HeapWithBuckets::bs_addn( + approx_topk_by_mode( + approx_topk_mode, beam_size, K, cent_distances_i, new_beam_size, new_distances_i, perm.data()); - }; - switch (approx_topk_mode) { - case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B8_D3: - approx.template operator()<8, 3>(); - break; - case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B8_D2: - approx.template operator()<8, 2>(); - break; - case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B16_D2: - approx.template operator()<16, 2>(); - break; - case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B32_D2: - approx.template operator()<32, 2>(); - break; - default: - heap_addn( - new_beam_size, - new_distances_i, - perm.data(), - cent_distances_i, - nullptr, - beam_size * K); - break; - } - - heap_reorder(new_beam_size, new_distances_i, perm.data()); + heap_reorder(new_beam_size, new_distances_i, perm.data()); - for (size_t j = 0; j < new_beam_size; j++) { - int js = perm[j] / K; - int ls = perm[j] % K; - if (m > 0) { - memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m); + for (size_t j = 0; j < new_beam_size; j++) { + int js = perm[j] / K; + int ls = perm[j] % K; + if (m > 0) { + memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m); + } + new_codes_i[m] = ls; + new_codes_i += m + 1; } - new_codes_i[m] = ls; - new_codes_i += m + 1; } - } + }); } /********************************************************************