|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * |
| 4 | + * This source code is licensed under the MIT license found in the |
| 5 | + * LICENSE file in the root directory of this source tree. |
| 6 | + */ |
| 7 | + |
| 8 | +#pragma once |
| 9 | + |
| 10 | +#include <cassert> |
| 11 | + |
| 12 | +#include <faiss/impl/FaissAssert.h> |
| 13 | +#include <faiss/impl/fast_scan/kernels_simd256.h> |
| 14 | +#include <faiss/impl/fast_scan/kernels_simd512.h> |
| 15 | +#include <faiss/impl/fast_scan/simd_result_handlers.h> |
| 16 | + |
| 17 | +namespace faiss { |
| 18 | + |
| 19 | +using namespace simd_result_handlers; |
| 20 | + |
| 21 | +/* |
| 22 | + * Unified kernel: selects 256-bit vs 512-bit path based on |
| 23 | + * compile-time __AVX512F__ guard. |
| 24 | + */ |
| 25 | +template <int NQ, class ResultHandler, class Scaler> |
| 26 | +void kernel_accumulate_block( |
| 27 | + int nsq, |
| 28 | + const uint8_t* codes, |
| 29 | + const uint8_t* LUT, |
| 30 | + ResultHandler& res, |
| 31 | + const Scaler& scaler) { |
| 32 | +#ifdef __AVX512F__ |
| 33 | + pq4_kernel_qbs_512<NQ>(nsq, codes, LUT, res, scaler); |
| 34 | +#else |
| 35 | + pq4_kernel_qbs_256<NQ>(nsq, codes, LUT, res, scaler); |
| 36 | +#endif |
| 37 | +} |
| 38 | + |
| 39 | +// handle at most 4 blocks of queries |
| 40 | +template <int QBS, class ResultHandler, class Scaler> |
| 41 | +void accumulate_q_4step( |
| 42 | + size_t ntotal2, |
| 43 | + int nsq, |
| 44 | + const uint8_t* codes, |
| 45 | + const uint8_t* LUT0, |
| 46 | + ResultHandler& res, |
| 47 | + const Scaler& scaler, |
| 48 | + size_t block_stride) { |
| 49 | + constexpr int Q1 = QBS & 15; |
| 50 | + constexpr int Q2 = (QBS >> 4) & 15; |
| 51 | + constexpr int Q3 = (QBS >> 8) & 15; |
| 52 | + constexpr int Q4 = (QBS >> 12) & 15; |
| 53 | + constexpr int SQ = Q1 + Q2 + Q3 + Q4; |
| 54 | + |
| 55 | + for (size_t j0 = 0; j0 < ntotal2; j0 += 32) { |
| 56 | + FixedStorageHandler<SQ, 2> res2; |
| 57 | + const uint8_t* LUT = LUT0; |
| 58 | + kernel_accumulate_block<Q1>(nsq, codes, LUT, res2, scaler); |
| 59 | + LUT += Q1 * nsq * 16; |
| 60 | + if (Q2 > 0) { |
| 61 | + res2.set_block_origin(Q1, 0); |
| 62 | + kernel_accumulate_block<Q2>(nsq, codes, LUT, res2, scaler); |
| 63 | + LUT += Q2 * nsq * 16; |
| 64 | + } |
| 65 | + if (Q3 > 0) { |
| 66 | + res2.set_block_origin(Q1 + Q2, 0); |
| 67 | + kernel_accumulate_block<Q3>(nsq, codes, LUT, res2, scaler); |
| 68 | + LUT += Q3 * nsq * 16; |
| 69 | + } |
| 70 | + if (Q4 > 0) { |
| 71 | + res2.set_block_origin(Q1 + Q2 + Q3, 0); |
| 72 | + kernel_accumulate_block<Q4>(nsq, codes, LUT, res2, scaler); |
| 73 | + } |
| 74 | + res.set_block_origin(0, j0); |
| 75 | + res2.to_other_handler(res); |
| 76 | + codes += block_stride; |
| 77 | + } |
| 78 | +} |
| 79 | + |
| 80 | +template <int NQ, class ResultHandler, class Scaler> |
| 81 | +void kernel_accumulate_block_loop( |
| 82 | + size_t ntotal2, |
| 83 | + int nsq, |
| 84 | + const uint8_t* codes, |
| 85 | + const uint8_t* LUT, |
| 86 | + ResultHandler& res, |
| 87 | + const Scaler& scaler, |
| 88 | + size_t block_stride) { |
| 89 | + for (size_t j0 = 0; j0 < ntotal2; j0 += 32) { |
| 90 | + res.set_block_origin(0, j0); |
| 91 | + kernel_accumulate_block<NQ, ResultHandler>( |
| 92 | + nsq, codes, LUT, res, scaler); |
| 93 | + codes += block_stride; |
| 94 | + } |
| 95 | +} |
| 96 | + |
| 97 | +// non-template version of accumulate kernel -- dispatches dynamically |
| 98 | +template <class ResultHandler, class Scaler> |
| 99 | +void accumulate( |
| 100 | + int nq, |
| 101 | + size_t ntotal2, |
| 102 | + int nsq, |
| 103 | + const uint8_t* codes, |
| 104 | + const uint8_t* LUT, |
| 105 | + ResultHandler& res, |
| 106 | + const Scaler& scaler, |
| 107 | + size_t block_stride) { |
| 108 | + assert(nsq % 2 == 0); |
| 109 | + assert(is_aligned_pointer(LUT)); |
| 110 | + |
| 111 | +#define DISPATCH(NQ) \ |
| 112 | + case NQ: \ |
| 113 | + kernel_accumulate_block_loop<NQ, ResultHandler>( \ |
| 114 | + ntotal2, nsq, codes, LUT, res, scaler, block_stride); \ |
| 115 | + return |
| 116 | + |
| 117 | + switch (nq) { |
| 118 | + DISPATCH(1); |
| 119 | + DISPATCH(2); |
| 120 | + DISPATCH(3); |
| 121 | + DISPATCH(4); |
| 122 | + } |
| 123 | + FAISS_THROW_FMT("accumulate nq=%d not instantiated", nq); |
| 124 | + |
| 125 | +#undef DISPATCH |
| 126 | +} |
| 127 | + |
| 128 | +template <class ResultHandler, class Scaler> |
| 129 | +void pq4_accumulate_loop_qbs_fixed_scaler( |
| 130 | + int qbs, |
| 131 | + size_t ntotal2, |
| 132 | + int nsq, |
| 133 | + const uint8_t* codes, |
| 134 | + const uint8_t* LUT0, |
| 135 | + ResultHandler& res, |
| 136 | + const Scaler& scaler, |
| 137 | + size_t block_stride = 0) { |
| 138 | + assert(nsq % 2 == 0); |
| 139 | + assert(is_aligned_pointer(codes)); |
| 140 | + assert(is_aligned_pointer(LUT0)); |
| 141 | + |
| 142 | + // try out optimized versions |
| 143 | + switch (qbs) { |
| 144 | +#define DISPATCH(QBS) \ |
| 145 | + case QBS: \ |
| 146 | + accumulate_q_4step<QBS>( \ |
| 147 | + ntotal2, nsq, codes, LUT0, res, scaler, block_stride); \ |
| 148 | + return; |
| 149 | + DISPATCH(0x3333); // 12 |
| 150 | + DISPATCH(0x2333); // 11 |
| 151 | + DISPATCH(0x2233); // 10 |
| 152 | + DISPATCH(0x333); // 9 |
| 153 | + DISPATCH(0x2223); // 9 |
| 154 | + DISPATCH(0x233); // 8 |
| 155 | + DISPATCH(0x1223); // 8 |
| 156 | + DISPATCH(0x223); // 7 |
| 157 | + DISPATCH(0x34); // 7 |
| 158 | + DISPATCH(0x133); // 7 |
| 159 | + DISPATCH(0x6); // 6 |
| 160 | + DISPATCH(0x33); // 6 |
| 161 | + DISPATCH(0x123); // 6 |
| 162 | + DISPATCH(0x222); // 6 |
| 163 | + DISPATCH(0x23); // 5 |
| 164 | + DISPATCH(0x5); // 5 |
| 165 | + DISPATCH(0x13); // 4 |
| 166 | + DISPATCH(0x22); // 4 |
| 167 | + DISPATCH(0x4); // 4 |
| 168 | + DISPATCH(0x3); // 3 |
| 169 | + DISPATCH(0x21); // 3 |
| 170 | + DISPATCH(0x2); // 2 |
| 171 | + DISPATCH(0x1); // 1 |
| 172 | +#undef DISPATCH |
| 173 | + } |
| 174 | + |
| 175 | + // default implementation where qbs is not known at compile time |
| 176 | + for (size_t j0 = 0; j0 < ntotal2; j0 += 32) { |
| 177 | + const uint8_t* LUT = LUT0; |
| 178 | + int qi = qbs; |
| 179 | + int i0 = 0; |
| 180 | + while (qi) { |
| 181 | + int nq = qi & 15; |
| 182 | + qi >>= 4; |
| 183 | + res.set_block_origin(i0, j0); |
| 184 | +#define DISPATCH(NQ) \ |
| 185 | + case NQ: \ |
| 186 | + kernel_accumulate_block<NQ, ResultHandler>( \ |
| 187 | + nsq, codes, LUT, res, scaler); \ |
| 188 | + break |
| 189 | + switch (nq) { |
| 190 | + DISPATCH(1); |
| 191 | + DISPATCH(2); |
| 192 | + DISPATCH(3); |
| 193 | + DISPATCH(4); |
| 194 | +#undef DISPATCH |
| 195 | + default: |
| 196 | + FAISS_THROW_FMT("accumulate nq=%d not instantiated", nq); |
| 197 | + } |
| 198 | + i0 += nq; |
| 199 | + LUT += nq * nsq * 16; |
| 200 | + } |
| 201 | + codes += block_stride; |
| 202 | + } |
| 203 | +} |
| 204 | + |
| 205 | +} // namespace faiss |
0 commit comments