Skip to content

Commit e7315d2

Browse files
scsiguymeta-codesync[bot]
authored andcommitted
Support limits on index deserialization loops — useful for tests (#4902)
Summary: Pull Request resolved: #4902 When enabled, throw a FAISS exception when index deserialization loops driven by read data fields (nlist, nsplits, VT chain length, nhash, etc.) exceed a configured limit. This can be set by input validation tests to prevent wasting time on pathological, OOM inducting, inputs. Default: 0 (no limit). Reviewed By: mnorris11 Differential Revision: D96016580
1 parent eddea2a commit e7315d2

2 files changed

Lines changed: 50 additions & 0 deletions

File tree

faiss/impl/index_read.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,32 @@
7878

7979
namespace faiss {
8080

81+
namespace {
82+
size_t deserialization_loop_limit_ = 0;
83+
} // namespace
84+
85+
size_t get_deserialization_loop_limit() {
86+
return deserialization_loop_limit_;
87+
}
88+
89+
void set_deserialization_loop_limit(size_t value) {
90+
deserialization_loop_limit_ = value;
91+
}
92+
93+
#define FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(val, field_name) \
94+
do { \
95+
auto limit_ = get_deserialization_loop_limit(); \
96+
if (limit_ > 0) { \
97+
FAISS_THROW_IF_NOT_FMT( \
98+
static_cast<size_t>(val) <= limit_, \
99+
"%s=%zd exceeds deserialization_loop_limit" \
100+
" of %zd", \
101+
field_name, \
102+
static_cast<size_t>(val), \
103+
limit_); \
104+
} \
105+
} while (0)
106+
81107
/*************************************************************
82108
* Mmap-ing and viewing facilities
83109
**************************************************************/
@@ -371,6 +397,7 @@ std::unique_ptr<InvertedLists> read_InvertedLists_up(
371397
} else if (h == fourcc("ilpn") && !(io_flags & IO_FLAG_SKIP_IVF_DATA)) {
372398
size_t nlist, code_size, n_levels;
373399
READ1(nlist);
400+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(nlist, "ilpn nlist");
374401
READ1(code_size);
375402
READ1(n_levels);
376403
auto ailp = std::make_unique<ArrayInvertedListsPanorama>(
@@ -400,6 +427,7 @@ std::unique_ptr<InvertedLists> read_InvertedLists_up(
400427
} else if (h == fourcc("ilar") && !(io_flags & IO_FLAG_SKIP_IVF_DATA)) {
401428
auto ails = std::make_unique<ArrayInvertedLists>(0, 0);
402429
READ1(ails->nlist);
430+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(ails->nlist, "ilar nlist");
403431
READ1(ails->code_size);
404432
ails->ids.resize(ails->nlist);
405433
ails->codes.resize(ails->nlist);
@@ -430,6 +458,7 @@ std::unique_ptr<InvertedLists> read_InvertedLists_up(
430458
int h2 = (io_flags & 0xffff0000) | (fourcc("il__") & 0x0000ffff);
431459
size_t nlist, code_size;
432460
READ1(nlist);
461+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(nlist, "ilar skip nlist");
433462
READ1(code_size);
434463
std::vector<size_t> sizes(nlist);
435464
read_ArrayInvertedLists_sizes(f, sizes);
@@ -557,6 +586,7 @@ static void read_ProductAdditiveQuantizer(
557586
paq.nsplits > 0,
558587
"invalid ProductAdditiveQuantizer nsplits %zd (must be > 0)",
559588
paq.nsplits);
589+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(paq.nsplits, "nsplits");
560590
}
561591

562592
static void read_ProductResidualQuantizer(
@@ -735,6 +765,8 @@ static void read_HNSW(HNSW& hnsw, IOReader* f) {
735765
static void read_NSG(NSG& nsg, IOReader* f) {
736766
READ1(nsg.ntotal);
737767
READ1(nsg.R);
768+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(nsg.ntotal, "nsg.ntotal");
769+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(nsg.R, "nsg.R");
738770
FAISS_THROW_IF_NOT_FMT(nsg.R > 0, "invalid NSG R %d (must be > 0)", nsg.R);
739771
READ1(nsg.L);
740772
READ1(nsg.C);
@@ -858,6 +890,7 @@ void read_ivf_header(
858890
std::vector<std::vector<idx_t>>* ids) {
859891
read_index_header(*ivf, f);
860892
READ1(ivf->nlist);
893+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(ivf->nlist, "nlist");
861894
READ1(ivf->nprobe);
862895
ivf->quantizer = read_index(f);
863896
ivf->own_fields = true;
@@ -1368,6 +1401,8 @@ std::unique_ptr<Index> read_index_up(IOReader* f, int io_flags) {
13681401
nt >= 0,
13691402
"invalid VectorTransform chain length %d (must be >= 0)",
13701403
nt);
1404+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(
1405+
nt, "VectorTransform chain length");
13711406
for (int i = 0; i < nt; i++) {
13721407
ixpt->chain.push_back(read_VectorTransform(f));
13731408
}
@@ -1898,6 +1933,7 @@ static void read_binary_ivf_header(
18981933
std::vector<std::vector<idx_t>>* ids = nullptr) {
18991934
read_index_binary_header(ivf, f);
19001935
READ1(ivf.nlist);
1936+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(ivf.nlist, "nlist");
19011937
READ1(ivf.nprobe);
19021938
ivf.quantizer = read_index_binary(f);
19031939
ivf.own_fields = true;
@@ -1915,6 +1951,7 @@ static void read_binary_hash_invlists(
19151951
IOReader* f) {
19161952
size_t sz;
19171953
READ1(sz);
1954+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(sz, "binary hash invlists sz");
19181955
int il_nbit = 0;
19191956
READ1(il_nbit);
19201957
FAISS_THROW_IF_NOT_FMT(
@@ -1965,6 +2002,7 @@ static void read_binary_multi_hash_map(
19652002
size_t sz;
19662003
READ1(id_bits);
19672004
READ1(sz);
2005+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(sz, "multi hash map sz");
19682006
std::vector<uint8_t> buf;
19692007
READVECTOR(buf);
19702008
size_t nbit = add_no_overflow(
@@ -2063,6 +2101,7 @@ std::unique_ptr<IndexBinary> read_index_binary_up(IOReader* f, int io_flags) {
20632101
idxmh->nhash > 0,
20642102
"invalid IndexBinaryMultiHash nhash %d (must be > 0)",
20652103
idxmh->nhash);
2104+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(idxmh->nhash, "nhash");
20662105
READ1(idxmh->nflip);
20672106
idxmh->maps.resize(idxmh->nhash);
20682107
for (int i = 0; i < idxmh->nhash; i++) {

faiss/index_io.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,17 @@ std::unique_ptr<InvertedLists> read_InvertedLists_up(
113113
IOReader* reader,
114114
int io_flags = 0);
115115

116+
// Returns the current deserialization loop limit.
117+
// When nonzero, deserialization rejects loop-driving fields (nlist,
118+
// nsplits, VT chain length, nhash, etc.) that exceed this value.
119+
// Default: 0 (no limit).
120+
size_t get_deserialization_loop_limit();
121+
122+
// Sets the deserialization loop limit.
123+
// NOT thread-safe: set before any concurrent deserialization calls
124+
// and do not modify while deserialization is in progress on other threads.
125+
void set_deserialization_loop_limit(size_t value);
126+
116127
} // namespace faiss
117128

118129
#endif

0 commit comments

Comments
 (0)