Skip to content

Commit 9d6b2e7

Browse files
scsiguy authored and meta-codesync[bot] committed
Additional input validation for index deserialization (#4899)
Summary: Pull Request resolved: #4899 Add bounds checks when reading index data from untrusted byte streams. Five new FAISS_THROW_IF_NOT_FMT guards reject invalid values early during deserialization: - ProductAdditiveQuantizer: nsplits must be > 0 - ScalarQuantizer: qtype must be within the valid QuantizerType enum range - NSG: R (max out-degree) must be > 0 - IndexPreTransform: VectorTransform chain length must be >= 0 - IndexBinaryMultiHash: nhash must be > 0 Each check includes a descriptive error message with the offending value. Without these checks, invalid data could cause undefined behavior such as zero-size allocations, out-of-range enum casts, or negative loop bounds. Reviewed By: mnorris11 Differential Revision: D95968069 fbshipit-source-id: b9a3e88a01d3427614fa0e027fb233f11ee8cf2e
1 parent 734751a commit 9d6b2e7

2 files changed

Lines changed: 110 additions & 1 deletion

File tree

faiss/impl/index_read.cpp

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -553,6 +553,10 @@ static void read_ProductAdditiveQuantizer(
553553
IOReader* f) {
554554
read_AdditiveQuantizer(paq, f);
555555
READ1(paq.nsplits);
556+
FAISS_THROW_IF_NOT_FMT(
557+
paq.nsplits > 0,
558+
"invalid ProductAdditiveQuantizer nsplits %zd (must be > 0)",
559+
paq.nsplits);
556560
}
557561

558562
static void read_ProductResidualQuantizer(
@@ -581,7 +585,14 @@ static void read_ProductLocalSearchQuantizer(
581585
}
582586

583587
void read_ScalarQuantizer(ScalarQuantizer* ivsc, IOReader* f) {
584-
READ1(ivsc->qtype);
588+
int qtype_int;
589+
READ1(qtype_int);
590+
FAISS_THROW_IF_NOT_FMT(
591+
qtype_int >= ScalarQuantizer::QT_8bit &&
592+
qtype_int <= ScalarQuantizer::QT_8bit_direct_signed,
593+
"invalid ScalarQuantizer qtype %d",
594+
qtype_int);
595+
ivsc->qtype = static_cast<ScalarQuantizer::QuantizerType>(qtype_int);
585596
READ1(ivsc->rangestat);
586597
READ1(ivsc->rangestat_arg);
587598
READ1(ivsc->d);
@@ -724,6 +735,7 @@ static void read_HNSW(HNSW& hnsw, IOReader* f) {
724735
static void read_NSG(NSG& nsg, IOReader* f) {
725736
READ1(nsg.ntotal);
726737
READ1(nsg.R);
738+
FAISS_THROW_IF_NOT_FMT(nsg.R > 0, "invalid NSG R %d (must be > 0)", nsg.R);
727739
READ1(nsg.L);
728740
READ1(nsg.C);
729741
READ1(nsg.search_L);
@@ -1352,6 +1364,10 @@ std::unique_ptr<Index> read_index_up(IOReader* f, int io_flags) {
13521364
} else {
13531365
READ1(nt);
13541366
}
1367+
FAISS_THROW_IF_NOT_FMT(
1368+
nt >= 0,
1369+
"invalid VectorTransform chain length %d (must be >= 0)",
1370+
nt);
13551371
for (int i = 0; i < nt; i++) {
13561372
ixpt->chain.push_back(read_VectorTransform(f));
13571373
}
@@ -2043,6 +2059,10 @@ std::unique_ptr<IndexBinary> read_index_binary_up(IOReader* f, int io_flags) {
20432059
idxmh->own_fields = true;
20442060
READ1(idxmh->b);
20452061
READ1(idxmh->nhash);
2062+
FAISS_THROW_IF_NOT_FMT(
2063+
idxmh->nhash > 0,
2064+
"invalid IndexBinaryMultiHash nhash %d (must be > 0)",
2065+
idxmh->nhash);
20462066
READ1(idxmh->nflip);
20472067
idxmh->maps.resize(idxmh->nhash);
20482068
for (int i = 0; i < idxmh->nhash; i++) {

tests/test_read_index_deserialize.cpp

Lines changed: 89 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -440,6 +440,95 @@ TEST(ReadIndexDeserialize, BinaryHashEmptyInvlistBuffer) {
440440
expect_binary_read_throws_with(buf, "binary hash invlists");
441441
}
442442

443+
// -----------------------------------------------------------------------
444+
// Test: NSG with a negative R (R=-1) triggers the R > 0 validation.
445+
// -----------------------------------------------------------------------
446+
TEST(ReadIndexDeserialize, NSGNegativeR) {
447+
// "INSf" format: fourcc + index_header + GK + build_type +
448+
// nndescent_S/R/L/iter + read_NSG(ntotal, R, ...)
449+
std::vector<uint8_t> buf;
450+
push_fourcc(buf, "INSf");
451+
push_index_header(buf, /*d=*/4, /*ntotal=*/0);
452+
push_val<int>(buf, 0); // GK
453+
push_val<int>(buf, 0); // build_type
454+
push_val<int>(buf, 10); // nndescent_S
455+
push_val<int>(buf, 10); // nndescent_R
456+
push_val<int>(buf, 10); // nndescent_L
457+
push_val<int>(buf, 1); // nndescent_iter
458+
// read_NSG fields:
459+
push_val<int>(buf, 0); // ntotal
460+
push_val<int>(buf, -1); // R = -1 (invalid)
461+
462+
expect_read_throws_with(buf, "invalid NSG R");
463+
}
464+
465+
// -----------------------------------------------------------------------
466+
// Test: ScalarQuantizer with out-of-range qtype throws.
467+
// -----------------------------------------------------------------------
468+
TEST(ReadIndexDeserialize, ScalarQuantizerInvalidQtype) {
469+
// "IxSQ" format: fourcc + index_header + read_ScalarQuantizer(qtype, ...)
470+
std::vector<uint8_t> buf;
471+
push_fourcc(buf, "IxSQ");
472+
push_index_header(buf, /*d=*/4, /*ntotal=*/0);
473+
// ScalarQuantizer fields:
474+
push_val<int>(buf, 99); // qtype = 99 (out of range)
475+
476+
expect_read_throws_with(buf, "qtype");
477+
}
478+
479+
// -----------------------------------------------------------------------
480+
// Test: ProductAdditiveQuantizer with nsplits=0 throws.
481+
// -----------------------------------------------------------------------
482+
TEST(ReadIndexDeserialize, ProductAdditiveQuantizerZeroNsplits) {
483+
// "IxPR" format: fourcc + index_header +
484+
// read_ProductResidualQuantizer(read_ProductAdditiveQuantizer(
485+
// read_AdditiveQuantizer(...) + nsplits))
486+
std::vector<uint8_t> buf;
487+
push_fourcc(buf, "IxPR");
488+
push_index_header(buf, /*d=*/4, /*ntotal=*/0);
489+
// AdditiveQuantizer fields:
490+
push_val<size_t>(buf, 4); // d
491+
push_val<size_t>(buf, 1); // M
492+
push_vector<size_t>(buf, {8}); // nbits (1 element matching M=1)
493+
push_val<bool>(buf, true); // is_trained
494+
push_vector<float>(buf, {}); // codebooks (empty)
495+
push_val<int>(buf, 0); // search_type = ST_decompress
496+
push_val<float>(buf, 0.0f); // norm_min
497+
push_val<float>(buf, 1.0f); // norm_max
498+
// ProductAdditiveQuantizer field:
499+
push_val<size_t>(buf, 0); // nsplits = 0 (invalid)
500+
501+
expect_read_throws_with(buf, "nsplits");
502+
}
503+
504+
// -----------------------------------------------------------------------
505+
// Test: PreTransform with negative chain length throws.
506+
// -----------------------------------------------------------------------
507+
TEST(ReadIndexDeserialize, PreTransformNegativeChainLength) {
508+
// "IxPT" format: fourcc + index_header + nt + VT chain + nested index
509+
std::vector<uint8_t> buf;
510+
push_fourcc(buf, "IxPT");
511+
push_index_header(buf, /*d=*/4, /*ntotal=*/0);
512+
push_val<int>(buf, -1); // nt = -1 (invalid)
513+
514+
expect_read_throws_with(buf, "chain length");
515+
}
516+
517+
// -----------------------------------------------------------------------
518+
// Test: IndexBinaryMultiHash with nhash=0 throws.
519+
// -----------------------------------------------------------------------
520+
TEST(ReadIndexDeserialize, BinaryMultiHashZeroNhash) {
521+
std::vector<uint8_t> buf;
522+
push_fourcc(buf, "IBHm");
523+
push_binary_index_header(buf, /*d=*/16, /*ntotal=*/0);
524+
// Nested IBxF storage (ntotal=0 matches outer)
525+
push_minimal_binary_flat(buf, /*d=*/16);
526+
push_val<int>(buf, 4); // b
527+
push_val<int>(buf, 0); // nhash = 0 (invalid)
528+
529+
expect_binary_read_throws_with(buf, "nhash");
530+
}
531+
443532
// -----------------------------------------------------------------------
444533
// Test: IndexBinaryHash with b=0 triggers the b > 0 validation.
445534
// Without this check, BitstringReader::read(0) would silently produce

0 commit comments

Comments
 (0)