Skip to content

Commit 5361fb5

Browse files
authored
Merge branch 'main' into avx512_vnni_opt
2 parents bf176d5 + 5679d3a commit 5361fb5

5 files changed

Lines changed: 189 additions & 80 deletions

File tree

faiss/IndexHNSW.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1042,7 +1042,7 @@ void IndexHNSWCagra::search(
10421042
std::vector<storage_idx_t> nearest(n);
10431043
std::vector<float> nearest_d(n);
10441044

1045-
#pragma omp for
1045+
#pragma omp parallel for
10461046
for (idx_t i = 0; i < n; i++) {
10471047
std::unique_ptr<DistanceComputer> dis(
10481048
storage_distance_computer(this->storage));

faiss/IndexIVF.cpp

Lines changed: 92 additions & 76 deletions
Original file line number | Diff line number | Diff line change
@@ -828,103 +828,119 @@ void IndexIVF::range_search_preassigned(
828828

829829
#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis)
830830
{
831-
RangeSearchPartialResult pres(result);
832-
std::unique_ptr<InvertedListScanner> scanner(
833-
get_InvertedListScanner(store_pairs, sel, params));
834-
FAISS_THROW_IF_NOT(scanner.get());
835-
all_pres[omp_get_thread_num()] = &pres;
836-
837-
// prepare the list scanning function
838-
839-
auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult& qres) {
840-
idx_t key = keys[i * cur_nprobe + ik]; /* select the list */
841-
if (key < 0) {
842-
return;
843-
}
844-
FAISS_THROW_IF_NOT_FMT(
845-
key < (idx_t)nlist,
846-
"Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
847-
key,
848-
ik,
849-
nlist);
850-
851-
if (invlists->is_empty(key, inverted_list_context)) {
852-
return;
853-
}
831+
try {
832+
RangeSearchPartialResult pres(result);
833+
std::unique_ptr<InvertedListScanner> scanner(
834+
get_InvertedListScanner(store_pairs, sel, params));
835+
FAISS_THROW_IF_NOT(scanner.get());
836+
all_pres[omp_get_thread_num()] = &pres;
854837

855-
try {
856-
size_t list_size = 0;
857-
scanner->set_list(key, coarse_dis[i * cur_nprobe + ik]);
858-
if (invlists->use_iterator) {
859-
std::unique_ptr<InvertedListsIterator> it(
860-
invlists->get_iterator(key, inverted_list_context));
838+
// prepare the list scanning function
861839

862-
scanner->iterate_codes_range(
863-
it.get(), radius, qres, list_size);
864-
} else {
865-
InvertedLists::ScopedCodes scodes(invlists, key);
866-
InvertedLists::ScopedIds ids(invlists, key);
867-
list_size = invlists->list_size(key);
840+
auto scan_list_func = [&](size_t i,
841+
size_t ik,
842+
RangeQueryResult& qres) {
843+
try {
844+
idx_t key = keys[i * cur_nprobe + ik]; /* select the list */
845+
if (key < 0) {
846+
return;
847+
}
848+
849+
FAISS_THROW_IF_NOT_FMT(
850+
key < (idx_t)nlist,
851+
"Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
852+
key,
853+
ik,
854+
nlist);
868855

869-
scanner->scan_codes_range(
870-
list_size, scodes.get(), ids.get(), radius, qres);
856+
if (invlists->is_empty(key, inverted_list_context)) {
857+
return;
858+
}
859+
860+
size_t list_size = 0;
861+
scanner->set_list(key, coarse_dis[i * cur_nprobe + ik]);
862+
if (invlists->use_iterator) {
863+
std::unique_ptr<InvertedListsIterator> it(
864+
invlists->get_iterator(
865+
key, inverted_list_context));
866+
867+
scanner->iterate_codes_range(
868+
it.get(), radius, qres, list_size);
869+
} else {
870+
InvertedLists::ScopedCodes scodes(invlists, key);
871+
InvertedLists::ScopedIds ids(invlists, key);
872+
list_size = invlists->list_size(key);
873+
874+
scanner->scan_codes_range(
875+
list_size,
876+
scodes.get(),
877+
ids.get(),
878+
radius,
879+
qres);
880+
}
881+
nlistv++;
882+
ndis += list_size;
883+
} catch (const std::exception& e) {
884+
std::lock_guard<std::mutex> lock(exception_mutex);
885+
exception_string = demangle_cpp_symbol(typeid(e).name()) +
886+
" " + e.what();
887+
interrupt = true;
871888
}
872-
nlistv++;
873-
ndis += list_size;
874-
} catch (const std::exception& e) {
875-
std::lock_guard<std::mutex> lock(exception_mutex);
876-
exception_string =
877-
demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
878-
interrupt = true;
879-
}
880-
};
889+
};
881890

882-
if (parallel_mode == 0) {
891+
if (parallel_mode == 0) {
883892
#pragma omp for
884-
for (idx_t i = 0; i < nx; i++) {
885-
scanner->set_query(x + i * d);
893+
for (idx_t i = 0; i < nx; i++) {
894+
scanner->set_query(x + i * d);
886895

887-
RangeQueryResult& qres = pres.new_result(i);
896+
RangeQueryResult& qres = pres.new_result(i);
888897

889-
for (idx_t ik = 0; ik < cur_nprobe; ik++) {
890-
scan_list_func(i, ik, qres);
898+
for (idx_t ik = 0; ik < cur_nprobe; ik++) {
899+
scan_list_func(i, ik, qres);
900+
}
891901
}
892-
}
893902

894-
} else if (parallel_mode == 1) {
895-
for (idx_t i = 0; i < nx; i++) {
896-
scanner->set_query(x + i * d);
903+
} else if (parallel_mode == 1) {
904+
for (idx_t i = 0; i < nx; i++) {
905+
scanner->set_query(x + i * d);
897906

898-
RangeQueryResult& qres = pres.new_result(i);
907+
RangeQueryResult& qres = pres.new_result(i);
899908

900909
#pragma omp for schedule(dynamic)
901-
for (int64_t ik = 0; ik < cur_nprobe; ik++) {
902-
scan_list_func(i, ik, qres);
910+
for (int64_t ik = 0; ik < cur_nprobe; ik++) {
911+
scan_list_func(i, ik, qres);
912+
}
903913
}
904-
}
905-
} else if (parallel_mode == 2) {
906-
RangeQueryResult* qres = nullptr;
914+
} else if (parallel_mode == 2) {
915+
RangeQueryResult* qres = nullptr;
907916

908917
#pragma omp for schedule(dynamic)
909-
for (idx_t iik = 0; iik < nx * (idx_t)cur_nprobe; iik++) {
910-
idx_t i = iik / (idx_t)cur_nprobe;
911-
idx_t ik = iik % (idx_t)cur_nprobe;
912-
if (qres == nullptr || qres->qno != i) {
913-
qres = &pres.new_result(i);
914-
scanner->set_query(x + i * d);
918+
for (idx_t iik = 0; iik < nx * (idx_t)cur_nprobe; iik++) {
919+
idx_t i = iik / (idx_t)cur_nprobe;
920+
idx_t ik = iik % (idx_t)cur_nprobe;
921+
if (qres == nullptr || qres->qno != i) {
922+
qres = &pres.new_result(i);
923+
scanner->set_query(x + i * d);
924+
}
925+
scan_list_func(i, ik, *qres);
915926
}
916-
scan_list_func(i, ik, *qres);
927+
} else {
928+
FAISS_THROW_FMT(
929+
"parallel_mode %d not supported\n", parallel_mode);
917930
}
918-
} else {
919-
FAISS_THROW_FMT("parallel_mode %d not supported\n", parallel_mode);
920-
}
921-
if (parallel_mode == 0) {
922-
pres.finalize();
923-
} else {
931+
if (parallel_mode == 0) {
932+
pres.finalize();
933+
} else {
924934
#pragma omp barrier
925935
#pragma omp single
926-
RangeSearchPartialResult::merge(all_pres, false);
936+
RangeSearchPartialResult::merge(all_pres, false);
927937
#pragma omp barrier
938+
}
939+
} catch (const std::exception& e) {
940+
std::lock_guard<std::mutex> lock(exception_mutex);
941+
exception_string =
942+
demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
943+
interrupt = true;
928944
}
929945
}
930946

faiss/impl/NSG.cpp

Lines changed: 16 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
#include <faiss/impl/NSG.h>
99

1010
#include <algorithm>
11+
#include <limits>
1112
#include <memory>
1213
#include <mutex>
1314
#include <stack>
@@ -113,7 +114,6 @@ using namespace nsg;
113114
NSG::NSG(int R_in) : R(R_in), rng(0x0903) {
114115
L = R + 32;
115116
C = R + 100;
116-
srand(0x1998);
117117
}
118118

119119
void NSG::search(
@@ -179,7 +179,7 @@ void NSG::build(
179179
is_built = true;
180180

181181
if (verbose) {
182-
int max = 0, min = 1e6;
182+
int max = 0, min = std::numeric_limits<int>::max();
183183
double avg = 0;
184184

185185
for (int i = 0; i < n; i++) {
@@ -265,7 +265,7 @@ void NSG::search_on_graph(
265265
continue;
266266
}
267267

268-
init_ids[i] = id;
268+
init_ids[num_ids] = id;
269269
vt.set(id);
270270
num_ids += 1;
271271
}
@@ -397,10 +397,23 @@ void NSG::sync_prune(
397397

398398
std::vector<Node> result;
399399

400+
if (pool.empty()) {
401+
for (int i = 0; i < R; i++) {
402+
graph.at(q, i).id = EMPTY_ID;
403+
}
404+
return;
405+
}
406+
400407
int start = 0;
401408
if (pool[start].id == q) {
402409
start++;
403410
}
411+
if (start >= static_cast<int>(pool.size())) {
412+
for (int i = 0; i < R; i++) {
413+
graph.at(q, i).id = EMPTY_ID;
414+
}
415+
return;
416+
}
404417
result.push_back(pool[start]);
405418

406419
while (result.size() < static_cast<size_t>(R) &&

tests/test_NSG_compressed_graph.cpp

Lines changed: 43 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,9 @@
55
* LICENSE file in the root directory of this source tree.
66
*/
77

8+
#include <faiss/IndexFlat.h>
89
#include <faiss/IndexNSG.h>
10+
#include <faiss/impl/VisitedTable.h>
911
#include <faiss/utils/hamming.h>
1012
#include <faiss/utils/random.h>
1113
#include <gtest/gtest.h>
@@ -83,3 +85,44 @@ TEST(NSGCompressed, test_compressed) {
8385
EXPECT_EQ(Iref, I);
8486
EXPECT_EQ(Dref, D);
8587
}
88+
89+
// Regression test for sync_prune out-of-bounds bug.
90+
//
91+
// With ntotal=1 and L=1, search_on_graph produces pool = [{id:0}].
92+
// sync_prune(q=0): pool[0].id == q → start++ → start == pool.size().
93+
// Old code: pool[start] is out-of-bounds → undefined behavior.
94+
// Fix: guard returns early and fills graph row with EMPTY_ID.
95+
//
96+
// Calls NSG::build() directly (bypassing IndexNSG::check_knn_graph)
97+
// to reach the edge case with ntotal=1.
98+
TEST(NSGBugs, SyncPruneSingleNode) {
99+
constexpr int d = 4;
100+
constexpr int R = 1;
101+
102+
faiss::IndexFlat storage(d);
103+
float vec[] = {1.0f, 2.0f, 3.0f, 4.0f};
104+
storage.add(1, vec);
105+
106+
faiss::idx_t knn_data[] = {-1};
107+
faiss::nsg::Graph<faiss::idx_t> knn_graph(knn_data, 1, 1);
108+
109+
faiss::NSG nsg_obj(R);
110+
nsg_obj.L = 1;
111+
112+
// Old code crashes here. Fixed code handles it.
113+
ASSERT_NO_THROW(nsg_obj.build(&storage, 1, knn_graph, false));
114+
EXPECT_TRUE(nsg_obj.is_built);
115+
EXPECT_EQ(nsg_obj.enterpoint, 0);
116+
117+
// Search returns the only node
118+
nsg_obj.search_L = 1;
119+
faiss::VisitedTable vt(1);
120+
auto dis = std::unique_ptr<faiss::DistanceComputer>(
121+
faiss::nsg::storage_distance_computer(&storage));
122+
dis->set_query(vec);
123+
124+
faiss::idx_t label = -1;
125+
float distance = -1;
126+
nsg_obj.search(*dis, 1, &label, &distance, vt);
127+
EXPECT_EQ(label, 0);
128+
}

tests/test_ivf_index.cpp

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
#include <omp.h>
99
#include <algorithm>
1010
#include <cstddef>
11+
#include <limits>
1112
#include <map>
1213
#include <random>
1314
#include <set>
@@ -16,6 +17,7 @@
1617

1718
#include <faiss/IndexFlat.h>
1819
#include <faiss/IndexIVFFlat.h>
20+
#include <faiss/impl/AuxIndexStructures.h>
1921
#include <faiss/impl/FaissAssert.h>
2022

2123
namespace {
@@ -292,3 +294,38 @@ TEST(IVF, search_preassigned_out_of_range_key) {
292294
false),
293295
faiss::FaissException);
294296
}
297+
298+
// Test: range_search_preassigned with out-of-range keys throws a catchable
299+
// FaissException instead of calling std::terminate from an uncaught
300+
// exception inside the OpenMP parallel region.
301+
TEST(IVF, range_search_preassigned_out_of_range_key) {
302+
int d = 4;
303+
int nlist = 2;
304+
faiss::IndexFlatL2 quantizer(d);
305+
faiss::IndexIVFFlat idx(&quantizer, d, nlist);
306+
idx.own_fields = false;
307+
308+
std::vector<float> train_data(nlist * d, 0.0f);
309+
for (int i = 0; i < nlist * d; i++) {
310+
train_data[i] = static_cast<float>(i);
311+
}
312+
idx.train(nlist, train_data.data());
313+
idx.add(nlist, train_data.data());
314+
315+
std::vector<float> xq(d, 1.0f);
316+
faiss::RangeSearchResult result(1);
317+
318+
faiss::idx_t bad_key = nlist; // out of range
319+
float coarse_dis = 0.0f;
320+
321+
EXPECT_THROW(
322+
idx.range_search_preassigned(
323+
1,
324+
xq.data(),
325+
std::numeric_limits<float>::max(),
326+
&bad_key,
327+
&coarse_dis,
328+
&result,
329+
false),
330+
faiss::FaissException);
331+
}

0 commit comments

Comments
 (0)