@@ -828,103 +828,119 @@ void IndexIVF::range_search_preassigned(
828828
829829#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis)
830830 {
831- RangeSearchPartialResult pres (result);
832- std::unique_ptr<InvertedListScanner> scanner (
833- get_InvertedListScanner (store_pairs, sel, params));
834- FAISS_THROW_IF_NOT (scanner.get ());
835- all_pres[omp_get_thread_num ()] = &pres;
836-
837- // prepare the list scanning function
838-
839- auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult& qres) {
840- idx_t key = keys[i * cur_nprobe + ik]; /* select the list */
841- if (key < 0 ) {
842- return ;
843- }
844- FAISS_THROW_IF_NOT_FMT (
845- key < (idx_t )nlist,
846- " Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n " ,
847- key,
848- ik,
849- nlist);
850-
851- if (invlists->is_empty (key, inverted_list_context)) {
852- return ;
853- }
831+ try {
832+ RangeSearchPartialResult pres (result);
833+ std::unique_ptr<InvertedListScanner> scanner (
834+ get_InvertedListScanner (store_pairs, sel, params));
835+ FAISS_THROW_IF_NOT (scanner.get ());
836+ all_pres[omp_get_thread_num ()] = &pres;
854837
855- try {
856- size_t list_size = 0 ;
857- scanner->set_list (key, coarse_dis[i * cur_nprobe + ik]);
858- if (invlists->use_iterator ) {
859- std::unique_ptr<InvertedListsIterator> it (
860- invlists->get_iterator (key, inverted_list_context));
838+ // prepare the list scanning function
861839
862- scanner->iterate_codes_range (
863- it.get (), radius, qres, list_size);
864- } else {
865- InvertedLists::ScopedCodes scodes (invlists, key);
866- InvertedLists::ScopedIds ids (invlists, key);
867- list_size = invlists->list_size (key);
840+ auto scan_list_func = [&](size_t i,
841+ size_t ik,
842+ RangeQueryResult& qres) {
843+ try {
844+ idx_t key = keys[i * cur_nprobe + ik]; /* select the list */
845+ if (key < 0 ) {
846+ return ;
847+ }
848+
849+ FAISS_THROW_IF_NOT_FMT (
850+ key < (idx_t )nlist,
851+ " Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n " ,
852+ key,
853+ ik,
854+ nlist);
868855
869- scanner->scan_codes_range (
870- list_size, scodes.get (), ids.get (), radius, qres);
856+ if (invlists->is_empty (key, inverted_list_context)) {
857+ return ;
858+ }
859+
860+ size_t list_size = 0 ;
861+ scanner->set_list (key, coarse_dis[i * cur_nprobe + ik]);
862+ if (invlists->use_iterator ) {
863+ std::unique_ptr<InvertedListsIterator> it (
864+ invlists->get_iterator (
865+ key, inverted_list_context));
866+
867+ scanner->iterate_codes_range (
868+ it.get (), radius, qres, list_size);
869+ } else {
870+ InvertedLists::ScopedCodes scodes (invlists, key);
871+ InvertedLists::ScopedIds ids (invlists, key);
872+ list_size = invlists->list_size (key);
873+
874+ scanner->scan_codes_range (
875+ list_size,
876+ scodes.get (),
877+ ids.get (),
878+ radius,
879+ qres);
880+ }
881+ nlistv++;
882+ ndis += list_size;
883+ } catch (const std::exception& e) {
884+ std::lock_guard<std::mutex> lock (exception_mutex);
885+ exception_string = demangle_cpp_symbol (typeid (e).name ()) +
886+ " " + e.what ();
887+ interrupt = true ;
871888 }
872- nlistv++;
873- ndis += list_size;
874- } catch (const std::exception& e) {
875- std::lock_guard<std::mutex> lock (exception_mutex);
876- exception_string =
877- demangle_cpp_symbol (typeid (e).name ()) + " " + e.what ();
878- interrupt = true ;
879- }
880- };
889+ };
881890
882- if (parallel_mode == 0 ) {
891+ if (parallel_mode == 0 ) {
883892#pragma omp for
884- for (idx_t i = 0 ; i < nx; i++) {
885- scanner->set_query (x + i * d);
893+ for (idx_t i = 0 ; i < nx; i++) {
894+ scanner->set_query (x + i * d);
886895
887- RangeQueryResult& qres = pres.new_result (i);
896+ RangeQueryResult& qres = pres.new_result (i);
888897
889- for (idx_t ik = 0 ; ik < cur_nprobe; ik++) {
890- scan_list_func (i, ik, qres);
898+ for (idx_t ik = 0 ; ik < cur_nprobe; ik++) {
899+ scan_list_func (i, ik, qres);
900+ }
891901 }
892- }
893902
894- } else if (parallel_mode == 1 ) {
895- for (idx_t i = 0 ; i < nx; i++) {
896- scanner->set_query (x + i * d);
903+ } else if (parallel_mode == 1 ) {
904+ for (idx_t i = 0 ; i < nx; i++) {
905+ scanner->set_query (x + i * d);
897906
898- RangeQueryResult& qres = pres.new_result (i);
907+ RangeQueryResult& qres = pres.new_result (i);
899908
900909#pragma omp for schedule(dynamic)
901- for (int64_t ik = 0 ; ik < cur_nprobe; ik++) {
902- scan_list_func (i, ik, qres);
910+ for (int64_t ik = 0 ; ik < cur_nprobe; ik++) {
911+ scan_list_func (i, ik, qres);
912+ }
903913 }
904- }
905- } else if (parallel_mode == 2 ) {
906- RangeQueryResult* qres = nullptr ;
914+ } else if (parallel_mode == 2 ) {
915+ RangeQueryResult* qres = nullptr ;
907916
908917#pragma omp for schedule(dynamic)
909- for (idx_t iik = 0 ; iik < nx * (idx_t )cur_nprobe; iik++) {
910- idx_t i = iik / (idx_t )cur_nprobe;
911- idx_t ik = iik % (idx_t )cur_nprobe;
912- if (qres == nullptr || qres->qno != i) {
913- qres = &pres.new_result (i);
914- scanner->set_query (x + i * d);
918+ for (idx_t iik = 0 ; iik < nx * (idx_t )cur_nprobe; iik++) {
919+ idx_t i = iik / (idx_t )cur_nprobe;
920+ idx_t ik = iik % (idx_t )cur_nprobe;
921+ if (qres == nullptr || qres->qno != i) {
922+ qres = &pres.new_result (i);
923+ scanner->set_query (x + i * d);
924+ }
925+ scan_list_func (i, ik, *qres);
915926 }
916- scan_list_func (i, ik, *qres);
927+ } else {
928+ FAISS_THROW_FMT (
929+ " parallel_mode %d not supported\n " , parallel_mode);
917930 }
918- } else {
919- FAISS_THROW_FMT (" parallel_mode %d not supported\n " , parallel_mode);
920- }
921- if (parallel_mode == 0 ) {
922- pres.finalize ();
923- } else {
931+ if (parallel_mode == 0 ) {
932+ pres.finalize ();
933+ } else {
924934#pragma omp barrier
925935#pragma omp single
926- RangeSearchPartialResult::merge (all_pres, false );
936+ RangeSearchPartialResult::merge (all_pres, false );
927937#pragma omp barrier
938+ }
939+ } catch (const std::exception& e) {
940+ std::lock_guard<std::mutex> lock (exception_mutex);
941+ exception_string =
942+ demangle_cpp_symbol (typeid (e).name ()) + " " + e.what ();
943+ interrupt = true ;
928944 }
929945 }
930946
0 commit comments