7878
7979namespace faiss {
8080
81+ namespace {
82+ size_t deserialization_loop_limit_ = 0 ;
83+ } // namespace
84+
85+ size_t get_deserialization_loop_limit () {
86+ return deserialization_loop_limit_;
87+ }
88+
89+ void set_deserialization_loop_limit (size_t value) {
90+ deserialization_loop_limit_ = value;
91+ }
92+
93+ #define FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT (val, field_name ) \
94+ do { \
95+ auto limit_ = get_deserialization_loop_limit (); \
96+ if (limit_ > 0 ) { \
97+ FAISS_THROW_IF_NOT_FMT ( \
98+ static_cast <size_t >(val) <= limit_, \
99+ " %s=%zd exceeds deserialization_loop_limit" \
100+ " of %zd" , \
101+ field_name, \
102+ static_cast <size_t >(val), \
103+ limit_); \
104+ } \
105+ } while (0 )
106+
81107/* ************************************************************
82108 * Mmap-ing and viewing facilities
83109 **************************************************************/
@@ -371,6 +397,7 @@ std::unique_ptr<InvertedLists> read_InvertedLists_up(
371397 } else if (h == fourcc (" ilpn" ) && !(io_flags & IO_FLAG_SKIP_IVF_DATA)) {
372398 size_t nlist, code_size, n_levels;
373399 READ1 (nlist);
400+ FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT (nlist, " ilpn nlist" );
374401 READ1 (code_size);
375402 READ1 (n_levels);
376403 auto ailp = std::make_unique<ArrayInvertedListsPanorama>(
@@ -400,6 +427,7 @@ std::unique_ptr<InvertedLists> read_InvertedLists_up(
400427 } else if (h == fourcc (" ilar" ) && !(io_flags & IO_FLAG_SKIP_IVF_DATA)) {
401428 auto ails = std::make_unique<ArrayInvertedLists>(0 , 0 );
402429 READ1 (ails->nlist );
430+ FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT (ails->nlist , " ilar nlist" );
403431 READ1 (ails->code_size );
404432 ails->ids .resize (ails->nlist );
405433 ails->codes .resize (ails->nlist );
@@ -430,6 +458,7 @@ std::unique_ptr<InvertedLists> read_InvertedLists_up(
430458 int h2 = (io_flags & 0xffff0000 ) | (fourcc (" il__" ) & 0x0000ffff );
431459 size_t nlist, code_size;
432460 READ1 (nlist);
461+ FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT (nlist, " ilar skip nlist" );
433462 READ1 (code_size);
434463 std::vector<size_t > sizes (nlist);
435464 read_ArrayInvertedLists_sizes (f, sizes);
@@ -557,6 +586,7 @@ static void read_ProductAdditiveQuantizer(
557586 paq.nsplits > 0 ,
558587 " invalid ProductAdditiveQuantizer nsplits %zd (must be > 0)" ,
559588 paq.nsplits );
589+ FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT (paq.nsplits , " nsplits" );
560590}
561591
562592static void read_ProductResidualQuantizer (
@@ -735,6 +765,8 @@ static void read_HNSW(HNSW& hnsw, IOReader* f) {
735765static void read_NSG (NSG& nsg, IOReader* f) {
736766 READ1 (nsg.ntotal );
737767 READ1 (nsg.R );
768+ FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT (nsg.ntotal , " nsg.ntotal" );
769+ FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT (nsg.R , " nsg.R" );
738770 FAISS_THROW_IF_NOT_FMT (nsg.R > 0 , " invalid NSG R %d (must be > 0)" , nsg.R );
739771 READ1 (nsg.L );
740772 READ1 (nsg.C );
@@ -858,6 +890,7 @@ void read_ivf_header(
858890 std::vector<std::vector<idx_t >>* ids) {
859891 read_index_header (*ivf, f);
860892 READ1 (ivf->nlist );
893+ FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT (ivf->nlist , " nlist" );
861894 READ1 (ivf->nprobe );
862895 ivf->quantizer = read_index (f);
863896 ivf->own_fields = true ;
@@ -1368,6 +1401,8 @@ std::unique_ptr<Index> read_index_up(IOReader* f, int io_flags) {
13681401 nt >= 0 ,
13691402 " invalid VectorTransform chain length %d (must be >= 0)" ,
13701403 nt);
1404+ FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT (
1405+ nt, " VectorTransform chain length" );
13711406 for (int i = 0 ; i < nt; i++) {
13721407 ixpt->chain .push_back (read_VectorTransform (f));
13731408 }
@@ -1898,6 +1933,7 @@ static void read_binary_ivf_header(
18981933 std::vector<std::vector<idx_t >>* ids = nullptr ) {
18991934 read_index_binary_header (ivf, f);
19001935 READ1 (ivf.nlist );
1936+ FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT (ivf.nlist , " nlist" );
19011937 READ1 (ivf.nprobe );
19021938 ivf.quantizer = read_index_binary (f);
19031939 ivf.own_fields = true ;
@@ -1915,6 +1951,7 @@ static void read_binary_hash_invlists(
19151951 IOReader* f) {
19161952 size_t sz;
19171953 READ1 (sz);
1954+ FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT (sz, " binary hash invlists sz" );
19181955 int il_nbit = 0 ;
19191956 READ1 (il_nbit);
19201957 FAISS_THROW_IF_NOT_FMT (
@@ -1965,6 +2002,7 @@ static void read_binary_multi_hash_map(
19652002 size_t sz;
19662003 READ1 (id_bits);
19672004 READ1 (sz);
2005+ FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT (sz, " multi hash map sz" );
19682006 std::vector<uint8_t > buf;
19692007 READVECTOR (buf);
19702008 size_t nbit = add_no_overflow (
@@ -2063,6 +2101,7 @@ std::unique_ptr<IndexBinary> read_index_binary_up(IOReader* f, int io_flags) {
20632101 idxmh->nhash > 0 ,
20642102 " invalid IndexBinaryMultiHash nhash %d (must be > 0)" ,
20652103 idxmh->nhash );
2104+ FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT (idxmh->nhash , " nhash" );
20662105 READ1 (idxmh->nflip );
20672106 idxmh->maps .resize (idxmh->nhash );
20682107 for (int i = 0 ; i < idxmh->nhash ; i++) {
0 commit comments