Skip to content

Commit dfce6e9

Browse files
scsiguymeta-codesync[bot]
authored andcommitted
Fix OMP exception safety in IndexHNSW search (#5133)
Summary: Pull Request resolved: #5133 Exceptions thrown inside OpenMP parallel regions in IndexHNSW search methods call std::terminate because the OpenMP specification does not allow exceptions to escape parallel regions. This is triggered by corrupt serialized index data that causes allocation failures (e.g. VisitedTable construction with absurdly large ntotal) or other errors in DistanceComputer, but can occur with any exception thrown during the HNSW search loop. Wrap the OpenMP parallel bodies in hnsw_search(), IndexHNSW::search_level_0(), and IndexHNSW2Level::search() with per-thread try/catch blocks that capture exceptions via std::exception_ptr and re-throw them on the main thread after the parallel region completes. This follows the same pattern established in IndexFlatCodes and IndexNNDescent/IndexNSG. VisitedTable, SingleResultHandler, and DistanceComputer are moved from stack to heap allocation (std::unique_ptr) because the constructors are now inside try/catch blocks. If a stack-allocated object's constructor throws, the object never comes into existence, but OpenMP requires all threads to participate in the worksharing loop that follows — any reference to the uninitialized object would be undefined behavior. With std::unique_ptr, the variable starts as nullptr before the try block, remains nullptr if construction fails, and the interrupt flag ensures all loop iterations skip via continue without dereferencing it. Reviewed By: mnorris11 Differential Revision: D101638521 fbshipit-source-id: 3920be5e41b1bff0490a14e2571b3dbeace64397
1 parent c627334 commit dfce6e9

2 files changed

Lines changed: 474 additions & 85 deletions

File tree

faiss/IndexHNSW.cpp

Lines changed: 138 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <faiss/IndexHNSW.h>
99

1010
#include <omp.h>
11+
#include <atomic>
1112
#include <cinttypes>
1213
#include <cstdio>
1314
#include <cstdlib>
@@ -26,6 +27,7 @@
2627
#include <faiss/IndexIVFPQ.h>
2728
#include <faiss/impl/AuxIndexStructures.h>
2829
#include <faiss/impl/FaissAssert.h>
30+
#include <faiss/impl/FaissException.h>
2931
#include <faiss/impl/ResultHandler.h>
3032
#include <faiss/impl/VisitedTable.h>
3133
#include <faiss/impl/hnsw/MinimaxHeap.h>
@@ -264,29 +266,48 @@ void hnsw_search(
264266

265267
for (idx_t i0 = 0; i0 < n; i0 += check_period) {
266268
idx_t i1 = std::min(i0 + check_period, n);
269+
std::exception_ptr ex;
270+
std::atomic<bool> interrupt{false};
267271

268272
#pragma omp parallel if (i1 - i0 > 1)
269273
{
270-
VisitedTable vt(index->ntotal, hnsw.use_visited_hashset);
271-
typename BlockResultHandler::SingleResultHandler res(bres);
272-
273-
std::unique_ptr<DistanceComputer> dis(
274-
storage_distance_computer(index->storage));
274+
std::unique_ptr<VisitedTable> vt;
275+
std::unique_ptr<typename BlockResultHandler::SingleResultHandler>
276+
res;
277+
std::unique_ptr<DistanceComputer> dis;
278+
try {
279+
vt = std::make_unique<VisitedTable>(
280+
index->ntotal, hnsw.use_visited_hashset);
281+
res = std::make_unique<
282+
typename BlockResultHandler::SingleResultHandler>(bres);
283+
dis.reset(storage_distance_computer(index->storage));
284+
} catch (...) {
285+
omp_capture_exception(ex, [&] { interrupt = true; });
286+
}
275287

276288
#pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided)
277289
for (idx_t i = i0; i < i1; i++) {
278-
res.begin(i);
279-
dis->set_query(x + i * index->d);
280-
281-
HNSWStats stats = hnsw.search(*dis, index, res, vt, params);
282-
n1 += stats.n1;
283-
n2 += stats.n2;
284-
ndis += stats.ndis;
285-
nhops += stats.nhops;
286-
res.end();
287-
vt.advance();
290+
if (interrupt.load(std::memory_order_relaxed)) {
291+
continue;
292+
}
293+
try {
294+
res->begin(i);
295+
dis->set_query(x + i * index->d);
296+
297+
HNSWStats stats =
298+
hnsw.search(*dis, index, *res, *vt, params);
299+
n1 += stats.n1;
300+
n2 += stats.n2;
301+
ndis += stats.ndis;
302+
nhops += stats.nhops;
303+
res->end();
304+
vt->advance();
305+
} catch (...) {
306+
omp_capture_exception(ex, [&] { interrupt = true; });
307+
}
288308
}
289309
}
310+
omp_rethrow_if_exception(ex);
290311
InterruptCallback::check();
291312
}
292313

@@ -441,37 +462,54 @@ void IndexHNSW::search_level_0(
441462
using RH = HeapBlockResultHandler<HNSW::C>;
442463
RH bres(n, distances, labels, k);
443464

465+
std::exception_ptr ex;
466+
std::atomic<bool> interrupt{false};
444467
#pragma omp parallel
445468
{
446-
std::unique_ptr<DistanceComputer> qdis(
447-
storage_distance_computer(storage));
469+
std::unique_ptr<DistanceComputer> qdis;
448470
HNSWStats search_stats;
449-
VisitedTable vt(hnsw_ntotal, hnsw.use_visited_hashset);
450-
RH::SingleResultHandler res(bres);
471+
std::unique_ptr<VisitedTable> vt;
472+
std::unique_ptr<RH::SingleResultHandler> res;
473+
try {
474+
qdis.reset(storage_distance_computer(storage));
475+
vt = std::make_unique<VisitedTable>(
476+
hnsw_ntotal, hnsw.use_visited_hashset);
477+
res = std::make_unique<RH::SingleResultHandler>(bres);
478+
} catch (...) {
479+
omp_capture_exception(ex, [&] { interrupt = true; });
480+
}
451481

452482
#pragma omp for
453483
for (idx_t i = 0; i < n; i++) {
454-
res.begin(i);
455-
qdis->set_query(x + i * d);
456-
457-
hnsw.search_level_0(
458-
*qdis.get(),
459-
res,
460-
nprobe,
461-
nearest + i * nprobe,
462-
nearest_d + i * nprobe,
463-
search_type,
464-
search_stats,
465-
vt,
466-
params);
467-
res.end();
468-
vt.advance();
484+
if (interrupt.load(std::memory_order_relaxed)) {
485+
continue;
486+
}
487+
try {
488+
res->begin(i);
489+
qdis->set_query(x + i * d);
490+
491+
hnsw.search_level_0(
492+
*qdis.get(),
493+
*res,
494+
nprobe,
495+
nearest + i * nprobe,
496+
nearest_d + i * nprobe,
497+
search_type,
498+
search_stats,
499+
*vt,
500+
params);
501+
res->end();
502+
vt->advance();
503+
} catch (...) {
504+
omp_capture_exception(ex, [&] { interrupt = true; });
505+
}
469506
}
470507
#pragma omp critical
471508
{
472509
hnsw_stats.combine(search_stats);
473510
}
474511
}
512+
omp_rethrow_if_exception(ex);
475513
if (is_similarity_metric(this->metric_type)) {
476514
// we need to revert the negated distances
477515
#pragma omp parallel for
@@ -883,73 +921,88 @@ void IndexHNSW2Level::search(
883921
labels,
884922
false);
885923

924+
std::exception_ptr ex;
925+
std::atomic<bool> interrupt{false};
886926
#pragma omp parallel
887927
{
888928
// visited table (not hash set) for tri-state flags.
889-
VisitedTable vt(ntotal, /*use_hashset=*/false);
890-
std::unique_ptr<DistanceComputer> dis(
891-
storage_distance_computer(storage));
892-
929+
std::unique_ptr<VisitedTable> vt;
930+
std::unique_ptr<DistanceComputer> dis;
893931
constexpr int candidates_size = 1;
894-
MinimaxHeap candidates(candidates_size);
932+
std::unique_ptr<MinimaxHeap> candidates;
933+
try {
934+
vt = std::make_unique<VisitedTable>(
935+
ntotal, /*use_hashset=*/false);
936+
dis.reset(storage_distance_computer(storage));
937+
candidates = std::make_unique<MinimaxHeap>(candidates_size);
938+
} catch (...) {
939+
omp_capture_exception(ex, [&] { interrupt = true; });
940+
}
895941

896942
#pragma omp for reduction(+ : n1, n2, ndis, nhops)
897943
for (idx_t i = 0; i < n; i++) {
898-
idx_t* idxi = labels + i * k;
899-
float* simi = distances + i * k;
900-
dis->set_query(x + i * d);
901-
902-
// mark all inverted list elements as visited
903-
904-
for (size_t j = 0; j < nprobe; j++) {
905-
idx_t key = coarse_assign[j + i * nprobe];
906-
if (key < 0) {
907-
break;
908-
}
909-
size_t list_length = index_ivfpq->get_list_size(key);
910-
const idx_t* ids = index_ivfpq->invlists->get_ids(key);
944+
if (interrupt.load(std::memory_order_relaxed)) {
945+
continue;
946+
}
947+
try {
948+
idx_t* idxi = labels + i * k;
949+
float* simi = distances + i * k;
950+
dis->set_query(x + i * d);
951+
952+
// mark all inverted list elements as visited
953+
for (size_t j = 0; j < nprobe; j++) {
954+
idx_t key = coarse_assign[j + i * nprobe];
955+
if (key < 0) {
956+
break;
957+
}
958+
size_t list_length = index_ivfpq->get_list_size(key);
959+
const idx_t* ids = index_ivfpq->invlists->get_ids(key);
911960

912-
for (size_t jj = 0; jj < list_length; jj++) {
913-
vt.set(ids[jj]);
961+
for (size_t jj = 0; jj < list_length; jj++) {
962+
vt->set(ids[jj]);
963+
}
914964
}
915-
}
916965

917-
candidates.clear();
966+
candidates->clear();
918967

919-
for (int j = 0; j < k; j++) {
920-
if (idxi[j] < 0) {
921-
break;
968+
for (int j = 0; j < k; j++) {
969+
if (idxi[j] < 0) {
970+
break;
971+
}
972+
candidates->push(
973+
static_cast<storage_idx_t>(idxi[j]), simi[j]);
922974
}
923-
candidates.push(
924-
static_cast<storage_idx_t>(idxi[j]), simi[j]);
925-
}
926-
927-
// reorder from sorted to heap
928-
maxheap_heapify(k, simi, idxi, simi, idxi, k);
929-
930-
HNSWStats search_stats;
931-
search_from_candidates_2(
932-
hnsw,
933-
*dis,
934-
static_cast<int>(k),
935-
idxi,
936-
simi,
937-
candidates,
938-
vt,
939-
search_stats,
940-
0,
941-
static_cast<int>(k));
942-
n1 += search_stats.n1;
943-
n2 += search_stats.n2;
944-
ndis += search_stats.ndis;
945-
nhops += search_stats.nhops;
946975

947-
vt.advance();
948-
vt.advance();
976+
// reorder from sorted to heap
977+
maxheap_heapify(k, simi, idxi, simi, idxi, k);
949978

950-
maxheap_reorder(k, simi, idxi);
979+
HNSWStats search_stats;
980+
search_from_candidates_2(
981+
hnsw,
982+
*dis,
983+
k,
984+
idxi,
985+
simi,
986+
*candidates,
987+
*vt,
988+
search_stats,
989+
0,
990+
k);
991+
n1 += search_stats.n1;
992+
n2 += search_stats.n2;
993+
ndis += search_stats.ndis;
994+
nhops += search_stats.nhops;
995+
996+
vt->advance();
997+
vt->advance();
998+
999+
maxheap_reorder(k, simi, idxi);
1000+
} catch (...) {
1001+
omp_capture_exception(ex, [&] { interrupt = true; });
1002+
}
9511003
}
9521004
}
1005+
omp_rethrow_if_exception(ex);
9531006

9541007
hnsw_stats.combine({n1, n2, ndis, nhops});
9551008
}

0 commit comments

Comments
 (0)