|
8 | 8 | #include <faiss/IndexHNSW.h> |
9 | 9 |
|
10 | 10 | #include <omp.h> |
| 11 | +#include <atomic> |
11 | 12 | #include <cinttypes> |
12 | 13 | #include <cstdio> |
13 | 14 | #include <cstdlib> |
|
26 | 27 | #include <faiss/IndexIVFPQ.h> |
27 | 28 | #include <faiss/impl/AuxIndexStructures.h> |
28 | 29 | #include <faiss/impl/FaissAssert.h> |
| 30 | +#include <faiss/impl/FaissException.h> |
29 | 31 | #include <faiss/impl/ResultHandler.h> |
30 | 32 | #include <faiss/impl/VisitedTable.h> |
31 | 33 | #include <faiss/impl/hnsw/MinimaxHeap.h> |
@@ -264,29 +266,48 @@ void hnsw_search( |
264 | 266 |
|
265 | 267 | for (idx_t i0 = 0; i0 < n; i0 += check_period) { |
266 | 268 | idx_t i1 = std::min(i0 + check_period, n); |
| 269 | + std::exception_ptr ex; |
| 270 | + std::atomic<bool> interrupt{false}; |
267 | 271 |
|
268 | 272 | #pragma omp parallel if (i1 - i0 > 1) |
269 | 273 | { |
270 | | - VisitedTable vt(index->ntotal, hnsw.use_visited_hashset); |
271 | | - typename BlockResultHandler::SingleResultHandler res(bres); |
272 | | - |
273 | | - std::unique_ptr<DistanceComputer> dis( |
274 | | - storage_distance_computer(index->storage)); |
| 274 | + std::unique_ptr<VisitedTable> vt; |
| 275 | + std::unique_ptr<typename BlockResultHandler::SingleResultHandler> |
| 276 | + res; |
| 277 | + std::unique_ptr<DistanceComputer> dis; |
| 278 | + try { |
| 279 | + vt = std::make_unique<VisitedTable>( |
| 280 | + index->ntotal, hnsw.use_visited_hashset); |
| 281 | + res = std::make_unique< |
| 282 | + typename BlockResultHandler::SingleResultHandler>(bres); |
| 283 | + dis.reset(storage_distance_computer(index->storage)); |
| 284 | + } catch (...) { |
| 285 | + omp_capture_exception(ex, [&] { interrupt = true; }); |
| 286 | + } |
275 | 287 |
|
276 | 288 | #pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided) |
277 | 289 | for (idx_t i = i0; i < i1; i++) { |
278 | | - res.begin(i); |
279 | | - dis->set_query(x + i * index->d); |
280 | | - |
281 | | - HNSWStats stats = hnsw.search(*dis, index, res, vt, params); |
282 | | - n1 += stats.n1; |
283 | | - n2 += stats.n2; |
284 | | - ndis += stats.ndis; |
285 | | - nhops += stats.nhops; |
286 | | - res.end(); |
287 | | - vt.advance(); |
| 290 | + if (interrupt.load(std::memory_order_relaxed)) { |
| 291 | + continue; |
| 292 | + } |
| 293 | + try { |
| 294 | + res->begin(i); |
| 295 | + dis->set_query(x + i * index->d); |
| 296 | + |
| 297 | + HNSWStats stats = |
| 298 | + hnsw.search(*dis, index, *res, *vt, params); |
| 299 | + n1 += stats.n1; |
| 300 | + n2 += stats.n2; |
| 301 | + ndis += stats.ndis; |
| 302 | + nhops += stats.nhops; |
| 303 | + res->end(); |
| 304 | + vt->advance(); |
| 305 | + } catch (...) { |
| 306 | + omp_capture_exception(ex, [&] { interrupt = true; }); |
| 307 | + } |
288 | 308 | } |
289 | 309 | } |
| 310 | + omp_rethrow_if_exception(ex); |
290 | 311 | InterruptCallback::check(); |
291 | 312 | } |
292 | 313 |
|
@@ -441,37 +462,54 @@ void IndexHNSW::search_level_0( |
441 | 462 | using RH = HeapBlockResultHandler<HNSW::C>; |
442 | 463 | RH bres(n, distances, labels, k); |
443 | 464 |
|
| 465 | + std::exception_ptr ex; |
| 466 | + std::atomic<bool> interrupt{false}; |
444 | 467 | #pragma omp parallel |
445 | 468 | { |
446 | | - std::unique_ptr<DistanceComputer> qdis( |
447 | | - storage_distance_computer(storage)); |
| 469 | + std::unique_ptr<DistanceComputer> qdis; |
448 | 470 | HNSWStats search_stats; |
449 | | - VisitedTable vt(hnsw_ntotal, hnsw.use_visited_hashset); |
450 | | - RH::SingleResultHandler res(bres); |
| 471 | + std::unique_ptr<VisitedTable> vt; |
| 472 | + std::unique_ptr<RH::SingleResultHandler> res; |
| 473 | + try { |
| 474 | + qdis.reset(storage_distance_computer(storage)); |
| 475 | + vt = std::make_unique<VisitedTable>( |
| 476 | + hnsw_ntotal, hnsw.use_visited_hashset); |
| 477 | + res = std::make_unique<RH::SingleResultHandler>(bres); |
| 478 | + } catch (...) { |
| 479 | + omp_capture_exception(ex, [&] { interrupt = true; }); |
| 480 | + } |
451 | 481 |
|
452 | 482 | #pragma omp for |
453 | 483 | for (idx_t i = 0; i < n; i++) { |
454 | | - res.begin(i); |
455 | | - qdis->set_query(x + i * d); |
456 | | - |
457 | | - hnsw.search_level_0( |
458 | | - *qdis.get(), |
459 | | - res, |
460 | | - nprobe, |
461 | | - nearest + i * nprobe, |
462 | | - nearest_d + i * nprobe, |
463 | | - search_type, |
464 | | - search_stats, |
465 | | - vt, |
466 | | - params); |
467 | | - res.end(); |
468 | | - vt.advance(); |
| 484 | + if (interrupt.load(std::memory_order_relaxed)) { |
| 485 | + continue; |
| 486 | + } |
| 487 | + try { |
| 488 | + res->begin(i); |
| 489 | + qdis->set_query(x + i * d); |
| 490 | + |
| 491 | + hnsw.search_level_0( |
| 492 | + *qdis.get(), |
| 493 | + *res, |
| 494 | + nprobe, |
| 495 | + nearest + i * nprobe, |
| 496 | + nearest_d + i * nprobe, |
| 497 | + search_type, |
| 498 | + search_stats, |
| 499 | + *vt, |
| 500 | + params); |
| 501 | + res->end(); |
| 502 | + vt->advance(); |
| 503 | + } catch (...) { |
| 504 | + omp_capture_exception(ex, [&] { interrupt = true; }); |
| 505 | + } |
469 | 506 | } |
470 | 507 | #pragma omp critical |
471 | 508 | { |
472 | 509 | hnsw_stats.combine(search_stats); |
473 | 510 | } |
474 | 511 | } |
| 512 | + omp_rethrow_if_exception(ex); |
475 | 513 | if (is_similarity_metric(this->metric_type)) { |
476 | 514 | // we need to revert the negated distances |
477 | 515 | #pragma omp parallel for |
@@ -883,73 +921,88 @@ void IndexHNSW2Level::search( |
883 | 921 | labels, |
884 | 922 | false); |
885 | 923 |
|
| 924 | + std::exception_ptr ex; |
| 925 | + std::atomic<bool> interrupt{false}; |
886 | 926 | #pragma omp parallel |
887 | 927 | { |
888 | 928 | // visited table (not hash set) for tri-state flags. |
889 | | - VisitedTable vt(ntotal, /*use_hashset=*/false); |
890 | | - std::unique_ptr<DistanceComputer> dis( |
891 | | - storage_distance_computer(storage)); |
892 | | - |
| 929 | + std::unique_ptr<VisitedTable> vt; |
| 930 | + std::unique_ptr<DistanceComputer> dis; |
893 | 931 | constexpr int candidates_size = 1; |
894 | | - MinimaxHeap candidates(candidates_size); |
| 932 | + std::unique_ptr<MinimaxHeap> candidates; |
| 933 | + try { |
| 934 | + vt = std::make_unique<VisitedTable>( |
| 935 | + ntotal, /*use_hashset=*/false); |
| 936 | + dis.reset(storage_distance_computer(storage)); |
| 937 | + candidates = std::make_unique<MinimaxHeap>(candidates_size); |
| 938 | + } catch (...) { |
| 939 | + omp_capture_exception(ex, [&] { interrupt = true; }); |
| 940 | + } |
895 | 941 |
|
896 | 942 | #pragma omp for reduction(+ : n1, n2, ndis, nhops) |
897 | 943 | for (idx_t i = 0; i < n; i++) { |
898 | | - idx_t* idxi = labels + i * k; |
899 | | - float* simi = distances + i * k; |
900 | | - dis->set_query(x + i * d); |
901 | | - |
902 | | - // mark all inverted list elements as visited |
903 | | - |
904 | | - for (size_t j = 0; j < nprobe; j++) { |
905 | | - idx_t key = coarse_assign[j + i * nprobe]; |
906 | | - if (key < 0) { |
907 | | - break; |
908 | | - } |
909 | | - size_t list_length = index_ivfpq->get_list_size(key); |
910 | | - const idx_t* ids = index_ivfpq->invlists->get_ids(key); |
| 944 | + if (interrupt.load(std::memory_order_relaxed)) { |
| 945 | + continue; |
| 946 | + } |
| 947 | + try { |
| 948 | + idx_t* idxi = labels + i * k; |
| 949 | + float* simi = distances + i * k; |
| 950 | + dis->set_query(x + i * d); |
| 951 | + |
| 952 | + // mark all inverted list elements as visited |
| 953 | + for (size_t j = 0; j < nprobe; j++) { |
| 954 | + idx_t key = coarse_assign[j + i * nprobe]; |
| 955 | + if (key < 0) { |
| 956 | + break; |
| 957 | + } |
| 958 | + size_t list_length = index_ivfpq->get_list_size(key); |
| 959 | + const idx_t* ids = index_ivfpq->invlists->get_ids(key); |
911 | 960 |
|
912 | | - for (size_t jj = 0; jj < list_length; jj++) { |
913 | | - vt.set(ids[jj]); |
| 961 | + for (size_t jj = 0; jj < list_length; jj++) { |
| 962 | + vt->set(ids[jj]); |
| 963 | + } |
914 | 964 | } |
915 | | - } |
916 | 965 |
|
917 | | - candidates.clear(); |
| 966 | + candidates->clear(); |
918 | 967 |
|
919 | | - for (int j = 0; j < k; j++) { |
920 | | - if (idxi[j] < 0) { |
921 | | - break; |
| 968 | + for (int j = 0; j < k; j++) { |
| 969 | + if (idxi[j] < 0) { |
| 970 | + break; |
| 971 | + } |
| 972 | + candidates->push( |
| 973 | + static_cast<storage_idx_t>(idxi[j]), simi[j]); |
922 | 974 | } |
923 | | - candidates.push( |
924 | | - static_cast<storage_idx_t>(idxi[j]), simi[j]); |
925 | | - } |
926 | | - |
927 | | - // reorder from sorted to heap |
928 | | - maxheap_heapify(k, simi, idxi, simi, idxi, k); |
929 | | - |
930 | | - HNSWStats search_stats; |
931 | | - search_from_candidates_2( |
932 | | - hnsw, |
933 | | - *dis, |
934 | | - static_cast<int>(k), |
935 | | - idxi, |
936 | | - simi, |
937 | | - candidates, |
938 | | - vt, |
939 | | - search_stats, |
940 | | - 0, |
941 | | - static_cast<int>(k)); |
942 | | - n1 += search_stats.n1; |
943 | | - n2 += search_stats.n2; |
944 | | - ndis += search_stats.ndis; |
945 | | - nhops += search_stats.nhops; |
946 | 975 |
|
947 | | - vt.advance(); |
948 | | - vt.advance(); |
| 976 | + // reorder from sorted to heap |
| 977 | + maxheap_heapify(k, simi, idxi, simi, idxi, k); |
949 | 978 |
|
950 | | - maxheap_reorder(k, simi, idxi); |
| 979 | + HNSWStats search_stats; |
| 980 | + search_from_candidates_2( |
| 981 | + hnsw, |
| 982 | + *dis, |
| 983 | + k, |
| 984 | + idxi, |
| 985 | + simi, |
| 986 | + *candidates, |
| 987 | + *vt, |
| 988 | + search_stats, |
| 989 | + 0, |
| 990 | + k); |
| 991 | + n1 += search_stats.n1; |
| 992 | + n2 += search_stats.n2; |
| 993 | + ndis += search_stats.ndis; |
| 994 | + nhops += search_stats.nhops; |
| 995 | + |
| 996 | + vt->advance(); |
| 997 | + vt->advance(); |
| 998 | + |
| 999 | + maxheap_reorder(k, simi, idxi); |
| 1000 | + } catch (...) { |
| 1001 | + omp_capture_exception(ex, [&] { interrupt = true; }); |
| 1002 | + } |
951 | 1003 | } |
952 | 1004 | } |
| 1005 | + omp_rethrow_if_exception(ex); |
953 | 1006 |
|
954 | 1007 | hnsw_stats.combine({n1, n2, ndis, nhops}); |
955 | 1008 | } |
|
0 commit comments