Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions faiss/IndexBinaryHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,25 +217,33 @@ void IndexBinaryHNSW::search(
using RH = HeapBlockResultHandler<HNSW::C>;
RH bres(n, distances_f, labels, k);

size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;

#pragma omp parallel
{
VisitedTable vt(ntotal);
std::unique_ptr<DistanceComputer> dis(get_distance_computer());
RH::SingleResultHandler res(bres);

#pragma omp for
#pragma omp for reduction(+ : n1, n2, ndis, nhops)
for (idx_t i = 0; i < n; i++) {
res.begin(i);
dis->set_query((float*)(x + i * code_size));
// Given that IndexBinaryHNSW is not an IndexHNSW, we pass nullptr
// as the index parameter. This state does not get used in the
// search function, as it is merely there to to enable Panorama
// execution for IndexHNSWFlatPanorama.
hnsw.search(*dis, nullptr, res, vt, params_in);
HNSWStats stats = hnsw.search(*dis, nullptr, res, vt, params_in);
n1 += stats.n1;
n2 += stats.n2;
ndis += stats.ndis;
nhops += stats.nhops;
res.end();
}
}

hnsw_stats.combine({n1, n2, ndis, nhops});

#pragma omp parallel for
for (int i = 0; i < n * k; ++i) {
distances[i] = std::round(distances_f[i]);
Expand Down Expand Up @@ -267,11 +275,9 @@ template <class HammingComputer>
struct FlatHammingDis : DistanceComputer {
const int code_size;
const uint8_t* b;
size_t ndis;
HammingComputer hc;

float operator()(idx_t i) override {
ndis++;
return hc.hamming(b + i * code_size);
}

Expand All @@ -281,21 +287,13 @@ struct FlatHammingDis : DistanceComputer {
}

explicit FlatHammingDis(const IndexBinaryFlat& storage)
: code_size(storage.code_size),
b(storage.xb.data()),
ndis(0),
hc() {}
: code_size(storage.code_size), b(storage.xb.data()), hc() {}

// NOTE: Pointers are cast from float in order to reuse the floating-point
// DistanceComputer.
void set_query(const float* x) override {
hc.set((uint8_t*)x, code_size);
}

~FlatHammingDis() override {
#pragma omp atomic
hnsw_stats.ndis += ndis;
}
};

struct BuildDistanceComputer {
Expand Down Expand Up @@ -405,6 +403,10 @@ void IndexBinaryHNSWCagra::search(

res.end();
}
#pragma omp critical
{
hnsw_stats.combine(search_stats);
}
}

#pragma omp parallel for
Expand Down
Loading