Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ set(FAISS_SIMD_AVX2_SRC
impl/pq_code_distance/pq_code_distance-avx2.cpp
impl/scalar_quantizer/sq-avx2.cpp
impl/approx_topk/avx2.cpp
impl/binary_hamming/avx2.cpp
utils/simd_impl/distances_avx2.cpp
utils/simd_impl/hamming_avx2.cpp
utils/simd_impl/partitioning_avx2.cpp
utils/distances_fused/simdlib_based.cpp
utils/simd_impl/rabitq_avx2.cpp
Expand All @@ -32,7 +34,9 @@ set(FAISS_SIMD_NEON_SRC
impl/fast_scan/impl-neon.cpp
impl/scalar_quantizer/sq-neon.cpp
impl/approx_topk/neon.cpp
impl/binary_hamming/neon.cpp
utils/simd_impl/distances_aarch64.cpp
utils/simd_impl/hamming_neon.cpp
utils/simd_impl/partitioning_neon.cpp
utils/distances_fused/simdlib_based_neon.cpp
utils/simd_impl/rabitq_neon.cpp
Expand Down Expand Up @@ -342,6 +346,7 @@ set(FAISS_HEADERS
utils/hamming_distance/avx512-inl.h
utils/simd_impl/distances_autovec-inl.h
utils/simd_impl/distances_simdlib256.h
utils/simd_impl/hamming_impl.h
utils/simd_impl/exhaustive_L2sqr_blas_cmax.h
utils/simd_impl/IVFFlatScanner-inl.h
utils/simd_impl/distances_sse-inl.h
Expand Down
51 changes: 12 additions & 39 deletions faiss/IndexBinaryHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,16 @@
#include <faiss/utils/hamming.h>
#include <faiss/utils/random.h>

#include <faiss/impl/simd_dispatch.h>

#include <random>

// Scalar (NONE) fallback for dynamic dispatch
#define THE_SIMD_LEVEL SIMDLevel::NONE
// NOLINTNEXTLINE(facebook-hte-InlineHeader)
#include <faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h>
#undef THE_SIMD_LEVEL

namespace faiss {

/**************************************************************
Expand Down Expand Up @@ -280,50 +288,15 @@ void IndexBinaryHNSW::reconstruct(idx_t key, uint8_t* recons) const {
storage->reconstruct(key, recons);
}

namespace {

/// DistanceComputer adapter that evaluates Hamming distances between a query
/// code and the codes held by an IndexBinaryFlat.
///
/// Codes are addressed as `b + i * code_size`, i.e. entry `i` starts
/// `i * code_size` bytes after the beginning of the storage buffer.
///
/// @tparam HammingComputer  type providing `set(code, code_size)` and
///                          `hamming(code)`.
template <class HammingComputer>
struct FlatHammingDis : DistanceComputer {
    const int code_size;
    const uint8_t* b;
    HammingComputer hc;

    /// Distance between the current query (see set_query) and stored code i.
    float operator()(idx_t i) override {
        return hc.hamming(b + i * code_size);
    }

    /// Distance between stored codes i and j; uses a temporary
    /// HammingComputerDefault so the current query in `hc` is left untouched.
    float symmetric_dis(idx_t i, idx_t j) override {
        return HammingComputerDefault(b + j * code_size, code_size)
                .hamming(b + i * code_size);
    }

    explicit FlatHammingDis(const IndexBinaryFlat& storage)
            : code_size(storage.code_size), b(storage.xb.data()), hc() {}

    // NOTE: Pointers are cast from float in order to reuse the floating-point
    // DistanceComputer. The buffer really holds `code_size` bytes of binary
    // code; a named cast is used so constness is preserved.
    void set_query(const float* x) override {
        hc.set(reinterpret_cast<const uint8_t*>(x), code_size);
    }
};

struct BuildDistanceComputer {
using T = DistanceComputer*;
template <class HammingComputer>
DistanceComputer* f(IndexBinaryFlat* flat_storage) {
return new FlatHammingDis<HammingComputer>(*flat_storage);
}
};

} // namespace

// Builds a DistanceComputer over this index's storage.
// The storage must be an IndexBinaryFlat: the dynamic_cast result is checked
// and FAISS_THROW_IF_NOT_MSG reports the message below when it is null.
DistanceComputer* IndexBinaryHNSW::get_distance_computer() const {
IndexBinaryFlat* flat_storage = dynamic_cast<IndexBinaryFlat*>(storage);
FAISS_THROW_IF_NOT_MSG(
flat_storage != nullptr,
"IndexBinaryHNSW requires IndexBinaryFlat storage");
// NOTE(review): two return statements follow; as written the second one
// (with_simd_level_256bit) is unreachable. This looks like the pre- and
// post-refactor versions interleaved -- confirm which one is intended.
BuildDistanceComputer bd;
return dispatch_HammingComputer(code_size, bd, flat_storage);
return with_simd_level_256bit([&]<SIMDLevel SL>() {
return make_binary_hnsw_distance_computer_dispatch<SL>(
code_size, flat_storage);
});
}

/**************************************************************
Expand Down
238 changes: 28 additions & 210 deletions faiss/IndexBinaryHash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>

#include <faiss/impl/simd_dispatch.h>

// Scalar (NONE) fallback for dynamic dispatch
#define THE_SIMD_LEVEL SIMDLevel::NONE
// NOLINTNEXTLINE(facebook-hte-InlineHeader)
#include <faiss/impl/binary_hamming/IndexBinaryHash_impl.h>
#undef THE_SIMD_LEVEL

namespace faiss {

void IndexBinaryHash::InvertedList::add(
Expand Down Expand Up @@ -64,139 +72,8 @@ void IndexBinaryHash::add_with_ids(
ntotal += n;
}

namespace {

/** Enumerate all bit vectors of size nbit with up to maxflip 1s
 * test in P127257851 P127258235
 *
 * Usage pattern: `x` starts at 0 (no bit flipped); read `x`, then call
 * next() to advance, until next() returns false:
 *
 *     FlipEnumerator fe(nbit, maxflip);
 *     do { use(fe.x); } while (fe.next());
 *
 * Patterns are produced by increasing number of set bits (nflip). The
 * enumeration stops once nflip reaches maxflip or nbit, whichever comes
 * first, so a maxflip larger than nbit is treated like maxflip == nbit.
 */
struct FlipEnumerator {
    int nbit, nflip, maxflip;
    uint64_t mask, x;

    FlipEnumerator(int nbit_, int maxflip_)
            : nbit(nbit_), nflip(0), maxflip(maxflip_), mask(0), x(0) {}

    /// Advance `x` to the next pattern.
    /// @return false when every pattern has been produced, true otherwise.
    bool next() {
        if (x == mask) {
            // All patterns with nflip set bits are done (the last one has
            // its bits packed at the low end, i.e. equals mask).
            // Stop at maxflip, and also at nbit: going further would make
            // the shift count (nbit - nflip) below negative, which is
            // undefined behaviour.
            if (nflip >= maxflip || nflip >= nbit) {
                return false;
            }
            // increase Hamming radius
            nflip++;
            mask = ((static_cast<uint64_t>(1) << nflip) - 1);
            // first pattern of the new radius: all set bits at the high end
            x = mask << (nbit - nflip);
            return true;
        }

        const int lowest = __builtin_ctzll(x);

        if (lowest > 0) {
            // move the lowest set bit one position down
            x ^= static_cast<uint64_t>(3) << (lowest - 1);
        } else {
            // nb of LSB 1s
            const int n1 = __builtin_ctzll(~x);
            // clear them
            x &= (static_cast<uint64_t>(-1) << n1);
            // x != 0 here: x == mask was handled above
            const int n2 = __builtin_ctzll(x);
            // move the next set bit down by one and re-pack the cleared
            // n1 bits right below it
            x ^= ((static_cast<uint64_t>(1) << (n1 + 2)) - 1) << (n2 - n1 - 1);
        }
        return true;
    }
};

/// Result sink for range search: forwards (distance, id) pairs to the
/// per-query RangeQueryResult when the distance is strictly below `radius`.
struct RangeSearchResults {
    int radius;
    RangeQueryResult& qres;

    void add(float dis, idx_t id) {
        // keep only results strictly inside the radius
        if (!(dis < radius)) {
            return;
        }
        qres.add(dis, id);
    }
};

/// Result sink for k-NN search backed by a caller-provided heap
/// (heap_sim / heap_ids, k entries) ordered with CMax<int, idx_t>.
struct KnnSearchResults {
    // heap params
    idx_t k;
    int32_t* heap_sim;
    idx_t* heap_ids;

    using C = CMax<int, idx_t>;

    /// Replaces the heap top with (dis, id) when dis is smaller than the
    /// value currently stored at heap_sim[0].
    void add(float dis, idx_t id) {
        const bool improves = dis < heap_sim[0];
        if (improves) {
            heap_replace_top<C>(k, heap_sim, heap_ids, dis, id);
        }
    }
};

/// Searches one query `q` in an IndexBinaryHash.
///
/// The first index.b bits of the query form its hash; every hash obtained by
/// flipping up to index.nflip of those bits is looked up in index.invlists,
/// and each code of a matching list is compared to the query with
/// HammingComputer and passed to `res`.
///
/// Counters updated: n0 (lists found but empty), nlist (non-empty lists
/// scanned), ndis (distances computed).
template <class HammingComputer, class SearchResults>
void search_single_query_template(
        const IndexBinaryHash& index,
        const uint8_t* q,
        SearchResults& res,
        size_t& n0,
        size_t& nlist,
        size_t& ndis) {
    const size_t code_size = index.code_size;
    BitstringReader br(q, code_size);
    const uint64_t qhash = br.read(index.b);
    HammingComputer hc(q, code_size);
    FlipEnumerator fe(index.b, index.nflip);

    // visit every hash bucket within nflip bit flips of the query hash
    do {
        const auto it = index.invlists.find(qhash ^ fe.x);

        if (it != index.invlists.end()) {
            const IndexBinaryHash::InvertedList& il = it->second;
            const size_t nv = il.ids.size();

            if (nv == 0) {
                n0++;
            } else {
                const uint8_t* codes = il.vecs.data();
                for (size_t i = 0; i < nv; i++) {
                    const int dis = hc.hamming(codes + i * code_size);
                    res.add(dis, il.ids[i]);
                }
                ndis += nv;
                nlist++;
            }
        }
    } while (fe.next());
}

/// Functor for dispatch_HammingComputer: receives the search arguments as
/// pointers, dereferences them and forwards to
/// search_single_query_template for the selected HammingComputer.
struct Run_search_single_query {
    using T = void;

    template <class HammingComputer, class... Args>
    void f(Args*... ptrs) {
        search_single_query_template<HammingComputer>(*ptrs...);
    }
};

template <class SearchResults>
void search_single_query(
const IndexBinaryHash& index,
const uint8_t* q,
SearchResults& res,
size_t& n0,
size_t& nlist,
size_t& ndis) {
Run_search_single_query r;
dispatch_HammingComputer(
index.code_size, r, &index, &q, &res, &n0, &nlist, &ndis);
}

} // anonymous namespace
// search_single_query_template and helpers are now in
// impl/binary_hamming/IndexBinaryHash_impl.h (compiled per-ISA)

void IndexBinaryHash::range_search(
idx_t n,
Expand All @@ -215,10 +92,12 @@ void IndexBinaryHash::range_search(
#pragma omp for
for (idx_t i = 0; i < n; i++) { // loop queries
RangeQueryResult& qres = pres.new_result(i);
RangeSearchResults res = {radius, qres};
const uint8_t* q = x + i * code_size;

search_single_query(*this, q, res, n0, nlist, ndis);
with_simd_level_256bit([&]<SIMDLevel SL>() {
binary_hash_range_search_dispatch<SL>(
*this, q, radius, qres, n0, nlist, ndis);
});
}
pres.finalize();
}
Expand Down Expand Up @@ -248,10 +127,12 @@ void IndexBinaryHash::search(
idx_t* idxi = labels + k * i;

heap_heapify<HeapForL2>(k, simi, idxi);
KnnSearchResults res = {k, simi, idxi};
const uint8_t* q = x + i * code_size;

search_single_query(*this, q, res, n0, nlist, ndis);
with_simd_level_256bit([&]<SIMDLevel SL>() {
binary_hash_knn_search_dispatch<SL>(
*this, q, k, simi, idxi, n0, nlist, ndis);
});

heap_reorder<HeapForL2>(k, simi, idxi);
}
Expand Down Expand Up @@ -324,75 +205,8 @@ void IndexBinaryMultiHash::add(idx_t n, const uint8_t* x) {
ntotal += n;
}

namespace {

/// Computes the Hamming distance between query `q` and every code of the
/// flat index whose id appears in `shortlist`, and reports each
/// (distance, id) pair to `res`.
template <class HammingComputer, class SearchResults>
static void verify_shortlist(
        const IndexBinaryFlat* index,
        const uint8_t* q,
        const std::unordered_set<idx_t>& shortlist,
        SearchResults& res) {
    const size_t code_size = index->code_size;
    HammingComputer hc(q, code_size);
    const uint8_t* const base = index->xb.data();

    for (const auto id : shortlist) {
        const int dis = hc.hamming(base + id * code_size);
        res.add(dis, id);
    }
}

/// Functor for dispatch_HammingComputer: forwards its arguments (taken by
/// value) to verify_shortlist for the selected HammingComputer.
struct Run_verify_shortlist {
    using T = void;

    template <class HammingComputer, class... Args>
    T f(Args... forwarded) {
        verify_shortlist<HammingComputer>(forwarded...);
    }
};

template <class SearchResults>
void search_1_query_multihash(
const IndexBinaryMultiHash& index,
const uint8_t* xi,
SearchResults& res,
size_t& n0,
size_t& nlist,
size_t& ndis) {
std::unordered_set<idx_t> shortlist;
int b = index.b;

BitstringReader br(xi, index.code_size);
for (int h = 0; h < index.nhash; h++) {
uint64_t qhash = br.read(b);
const IndexBinaryMultiHash::Map& map = index.maps[h];

FlipEnumerator fe(index.b, index.nflip);
// loop over neighbors that are at most at nflip bits
do {
uint64_t hash = qhash ^ fe.x;
auto it = map.find(hash);

if (it != map.end()) {
const std::vector<idx_t>& v = it->second;
for (auto i : v) {
shortlist.insert(i);
}
nlist++;
} else {
n0++;
}
} while (fe.next());
}
ndis += shortlist.size();

// verify shortlist
Run_verify_shortlist r;
dispatch_HammingComputer(
index.code_size, r, index.storage, xi, shortlist, res);
}

} // anonymous namespace
// verify_shortlist and search_1_query_multihash are now in
// impl/binary_hamming/IndexBinaryHash_impl.h (compiled per-ISA)

void IndexBinaryMultiHash::range_search(
idx_t n,
Expand All @@ -411,10 +225,12 @@ void IndexBinaryMultiHash::range_search(
#pragma omp for
for (idx_t i = 0; i < n; i++) { // loop queries
RangeQueryResult& qres = pres.new_result(i);
RangeSearchResults res = {radius, qres};
const uint8_t* q = x + i * code_size;

search_1_query_multihash(*this, q, res, n0, nlist, ndis);
with_simd_level_256bit([&]<SIMDLevel SL>() {
binary_multihash_range_search_dispatch<SL>(
*this, q, radius, qres, n0, nlist, ndis);
});
}
pres.finalize();
}
Expand Down Expand Up @@ -444,10 +260,12 @@ void IndexBinaryMultiHash::search(
idx_t* idxi = labels + k * i;

heap_heapify<HeapForL2>(k, simi, idxi);
KnnSearchResults res = {k, simi, idxi};
const uint8_t* q = x + i * code_size;

search_1_query_multihash(*this, q, res, n0, nlist, ndis);
with_simd_level_256bit([&]<SIMDLevel SL>() {
binary_multihash_knn_search_dispatch<SL>(
*this, q, k, simi, idxi, n0, nlist, ndis);
});

heap_reorder<HeapForL2>(k, simi, idxi);
}
Expand Down
Loading
Loading