Skip to content

Commit 6c70444

Browse files
scsiguy and meta-codesync[bot]
authored and committed
Validate code_size during deserialization to prevent oversized allocations (#5151)
Summary:

Pull Request resolved: #5151

Several index types read code_size directly from the serialized stream, independently of the quantizer parameters that determine its correct value. When the stored code_size is corrupt but ntotal is 0, the existing consistency check (codes.size() == ntotal * code_size) passes trivially. A subsequent search then allocates (code_size * sizeof(float)) bytes in GenericFlatCodesDistanceComputer, which can trigger an OOM exception.

Two layers of protection:

1. Cross-validate the deserialized code_size against the quantizer-derived value for all index types that read code_size from the stream: IndexResidualQuantizer, IndexLocalSearchQuantizer, IndexProductResidualQuantizer, IndexProductLocalSearchQuantizer, IndexIVFAdditiveQuantizer, IndexIVFScalarQuantizer, IndexLSH, and Index2Layer. The quantizer code_size is computed from validated parameters via set_derived_values() and is always authoritative.

2. For IndexLattice, where code_size is derived from constructor parameters (scale_nbit, lattice_nbit, nsq) rather than read from the stream, validate that code_size does not exceed the uncompressed vector size (d * sizeof(float)). IndexLattice is a lossy compressor, so its code_size must always be smaller than the uncompressed representation. A corrupt scale_nbit can overflow the total_nbit computation, producing a code_size that wraps to a huge value; this bound catches that before any allocation is attempted.

Reviewed By: mnorris11

Differential Revision: D102360605

fbshipit-source-id: 1f0e7262a0e0e4566d7813b7da1bf0102e6fd9bf
1 parent 9dbb81c commit 6c70444

2 files changed

Lines changed: 283 additions & 0 deletions

File tree

faiss/impl/index_read.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,18 @@ static void validate_aq_dimension_match(
855855
idx_d);
856856
}
857857

858+
static void validate_code_size_match(
859+
size_t stored,
860+
size_t expected,
861+
const char* index_type) {
862+
FAISS_THROW_IF_NOT_FMT(
863+
stored == expected,
864+
"%s code_size mismatch: stored %zd vs derived %zd",
865+
index_type,
866+
stored,
867+
expected);
868+
}
869+
858870
static void read_ResidualQuantizer(
859871
ResidualQuantizer& rq,
860872
IOReader* f,
@@ -1493,6 +1505,10 @@ std::unique_ptr<Index> read_index_up(IOReader* f, int io_flags) {
14931505
READVECTOR(idxl->thresholds);
14941506
int code_size_i;
14951507
READ1(code_size_i);
1508+
FAISS_THROW_IF_NOT_FMT(
1509+
code_size_i >= 0,
1510+
"IndexLSH invalid code_size %d (must be >= 0)",
1511+
code_size_i);
14961512
idxl->code_size = code_size_i;
14971513
if (h == fourcc("IxHE")) {
14981514
FAISS_THROW_IF_NOT_FMT(
@@ -1503,6 +1519,8 @@ std::unique_ptr<Index> read_index_up(IOReader* f, int io_flags) {
15031519
// leak
15041520
idxl->code_size *= 8;
15051521
}
1522+
validate_code_size_match(
1523+
idxl->code_size, (idxl->nbits + 7) / 8, "IndexLSH");
15061524
{
15071525
// Read, dereference, discard.
15081526
auto sub_vt = read_VectorTransform_up(f);
@@ -1550,6 +1568,8 @@ std::unique_ptr<Index> read_index_up(IOReader* f, int io_flags) {
15501568
validate_aq_dimension_match(
15511569
idxr->rq, idxr->d, "IndexResidualQuantizer");
15521570
READ1(idxr->code_size);
1571+
validate_code_size_match(
1572+
idxr->code_size, idxr->rq.code_size, "IndexResidualQuantizer");
15531573
read_vector(idxr->codes, f);
15541574
FAISS_THROW_IF_NOT(
15551575
idxr->codes.size() == idxr->ntotal * idxr->code_size);
@@ -1561,6 +1581,10 @@ std::unique_ptr<Index> read_index_up(IOReader* f, int io_flags) {
15611581
validate_aq_dimension_match(
15621582
idxr->lsq, idxr->d, "IndexLocalSearchQuantizer");
15631583
READ1(idxr->code_size);
1584+
validate_code_size_match(
1585+
idxr->code_size,
1586+
idxr->lsq.code_size,
1587+
"IndexLocalSearchQuantizer");
15641588
read_vector(idxr->codes, f);
15651589
FAISS_THROW_IF_NOT(
15661590
idxr->codes.size() == idxr->ntotal * idxr->code_size);
@@ -1572,6 +1596,10 @@ std::unique_ptr<Index> read_index_up(IOReader* f, int io_flags) {
15721596
validate_aq_dimension_match(
15731597
idxpr->prq, idxpr->d, "IndexProductResidualQuantizer");
15741598
READ1(idxpr->code_size);
1599+
validate_code_size_match(
1600+
idxpr->code_size,
1601+
idxpr->prq.code_size,
1602+
"IndexProductResidualQuantizer");
15751603
read_vector(idxpr->codes, f);
15761604
FAISS_THROW_IF_NOT(
15771605
idxpr->codes.size() == idxpr->ntotal * idxpr->code_size);
@@ -1583,6 +1611,10 @@ std::unique_ptr<Index> read_index_up(IOReader* f, int io_flags) {
15831611
validate_aq_dimension_match(
15841612
idxpl->plsq, idxpl->d, "IndexProductLocalSearchQuantizer");
15851613
READ1(idxpl->code_size);
1614+
validate_code_size_match(
1615+
idxpl->code_size,
1616+
idxpl->plsq.code_size,
1617+
"IndexProductLocalSearchQuantizer");
15861618
read_vector(idxpl->codes, f);
15871619
FAISS_THROW_IF_NOT(
15881620
idxpl->codes.size() == idxpl->ntotal * idxpl->code_size);
@@ -1847,6 +1879,27 @@ std::unique_ptr<Index> read_index_up(IOReader* f, int io_flags) {
18471879
nsq,
18481880
dsq);
18491881
auto idxl = std::make_unique<IndexLattice>(d, nsq, scale_nbit, r2);
1882+
// IndexLattice is a lossy compressor: code_size should be
1883+
// smaller than the uncompressed vector (d floats). A corrupt
1884+
// scale_nbit can overflow the total_nbit computation, producing
1885+
// a code_size that wraps to a huge value.
1886+
{
1887+
size_t max_code_size = mul_no_overflow(
1888+
static_cast<size_t>(d),
1889+
sizeof(float),
1890+
"IndexLattice uncompressed vector size");
1891+
FAISS_THROW_IF_NOT_FMT(
1892+
idxl->code_size <= max_code_size,
1893+
"IndexLattice code_size %zd exceeds uncompressed "
1894+
"vector size %zd (likely corrupt scale_nbit=%d, "
1895+
"d=%d, nsq=%d, r2=%d)",
1896+
idxl->code_size,
1897+
max_code_size,
1898+
scale_nbit,
1899+
d,
1900+
nsq,
1901+
r2);
1902+
}
18501903
read_index_header(*idxl, f);
18511904
READVECTOR(idxl->trained);
18521905
idx = std::move(idxl);
@@ -1856,6 +1909,8 @@ std::unique_ptr<Index> read_index_up(IOReader* f, int io_flags) {
18561909
read_ivf_header(ivsc.get(), f, &ids);
18571910
read_ScalarQuantizer(&ivsc->sq, f, *ivsc);
18581911
READ1(ivsc->code_size);
1912+
validate_code_size_match(
1913+
ivsc->code_size, ivsc->sq.code_size, "IndexIVFScalarQuantizer");
18591914
ArrayInvertedLists* ail = set_array_invlist(ivsc.get(), ids);
18601915
for (size_t i = 0; i < ivsc->nlist; i++)
18611916
READVECTOR(ail->codes[i]);
@@ -1865,6 +1920,8 @@ std::unique_ptr<Index> read_index_up(IOReader* f, int io_flags) {
18651920
read_ivf_header(ivsc.get(), f);
18661921
read_ScalarQuantizer(&ivsc->sq, f, *ivsc);
18671922
READ1(ivsc->code_size);
1923+
validate_code_size_match(
1924+
ivsc->code_size, ivsc->sq.code_size, "IndexIVFScalarQuantizer");
18681925
if (h == fourcc("IwSQ")) {
18691926
ivsc->by_residual = true;
18701927
} else {
@@ -1903,6 +1960,10 @@ std::unique_ptr<Index> read_index_up(IOReader* f, int io_flags) {
19031960
}
19041961
validate_aq_dimension_match(
19051962
*iva->aq, iva->d, "IndexIVFAdditiveQuantizer");
1963+
validate_code_size_match(
1964+
iva->code_size,
1965+
iva->aq->code_size,
1966+
"IndexIVFAdditiveQuantizer");
19061967
READ1(iva->by_residual);
19071968
READ1(iva->use_precomputed_table);
19081969
read_InvertedLists(*iva, f, io_flags);
@@ -2068,6 +2129,14 @@ std::unique_ptr<Index> read_index_up(IOReader* f, int io_flags) {
20682129
READ1(idxp->code_size_1);
20692130
READ1(idxp->code_size_2);
20702131
READ1(idxp->code_size);
2132+
validate_code_size_match(
2133+
idxp->code_size_2,
2134+
idxp->pq.code_size,
2135+
"Index2Layer code_size_2");
2136+
validate_code_size_match(
2137+
idxp->code_size,
2138+
idxp->code_size_1 + idxp->code_size_2,
2139+
"Index2Layer");
20712140
read_vector(idxp->codes, f);
20722141
idx = std::move(idxp);
20732142
} else if (

tests/test_read_index_deserialize.cpp

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,21 @@
1515

1616
#include <faiss/Index.h>
1717
#include <faiss/Index2Layer.h>
18+
#include <faiss/IndexAdditiveQuantizer.h>
1819
#include <faiss/IndexAdditiveQuantizerFastScan.h>
1920
#include <faiss/IndexBinary.h>
2021
#include <faiss/IndexBinaryHNSW.h>
2122
#include <faiss/IndexBinaryIVF.h>
2223
#include <faiss/IndexFlat.h>
2324
#include <faiss/IndexHNSW.h>
25+
#include <faiss/IndexIVFAdditiveQuantizer.h>
2426
#include <faiss/IndexIVFAdditiveQuantizerFastScan.h>
2527
#include <faiss/IndexIVFFlat.h>
2628
#include <faiss/IndexIVFIndependentQuantizer.h>
2729
#include <faiss/IndexIVFPQ.h>
2830
#include <faiss/IndexIVFPQR.h>
2931
#include <faiss/IndexRaBitQFastScan.h>
32+
#include <faiss/IndexScalarQuantizer.h>
3033
#include <faiss/VectorTransform.h>
3134
#include <faiss/impl/FaissException.h>
3235
#include <faiss/impl/ScalarQuantizer.h>
@@ -3584,6 +3587,217 @@ TEST(ReadIndexDeserialize, IndexRQFastScanAQDimensionMismatch) {
35843587
expect_read_throws_with(buf, "does not match index d");
35853588
}
35863589

3590+
// ============================================================
3591+
// code_size cross-validation against quantizer-derived values.
3592+
//
3593+
// Several index types read code_size from the serialized stream
3594+
// independently of the quantizer parameters. A corrupt code_size
3595+
// that passes the codes.size() == ntotal * code_size check (e.g.
3596+
// when ntotal == 0) can cause excessive allocations in
3597+
// GenericFlatCodesDistanceComputer during search.
3598+
// ============================================================
3599+
3600+
// Locate the byte offset of code_size in a serialized index by
3601+
// diffing two serializations: one with the real code_size and one
3602+
// with a probe value. This avoids false matches from other fields
3603+
// that happen to share the same numeric value.
3604+
static ssize_t find_code_size_offset(
3605+
Index* index,
3606+
size_t* code_size_ptr,
3607+
size_t real_cs) {
3608+
VectorIOWriter w1;
3609+
write_index(index, &w1);
3610+
3611+
size_t probe_cs = real_cs ^ 0xDEAD;
3612+
*code_size_ptr = probe_cs;
3613+
VectorIOWriter w2;
3614+
write_index(index, &w2);
3615+
*code_size_ptr = real_cs;
3616+
3617+
EXPECT_EQ(w1.data.size(), w2.data.size());
3618+
if (w1.data.size() != w2.data.size()) {
3619+
return -1;
3620+
}
3621+
3622+
ssize_t offset = -1;
3623+
for (size_t i = 0; i + sizeof(size_t) <= w1.data.size(); i++) {
3624+
if (w1.data[i] != w2.data[i]) {
3625+
size_t v1, v2;
3626+
memcpy(&v1, w1.data.data() + i, sizeof(size_t));
3627+
memcpy(&v2, w2.data.data() + i, sizeof(size_t));
3628+
if (v1 == real_cs && v2 == probe_cs) {
3629+
offset = i;
3630+
}
3631+
i += sizeof(size_t) - 1;
3632+
}
3633+
}
3634+
return offset;
3635+
}
3636+
3637+
// Serialize a valid index, locate the exact byte offset of its
3638+
// code_size field via double-serialization diffing, corrupt it,
3639+
// and verify deserialization rejects it.
3640+
static void corrupt_code_size_and_expect_throw(
3641+
Index* index,
3642+
size_t* code_size_ptr,
3643+
const std::string& expected_substr) {
3644+
size_t real_cs = *code_size_ptr;
3645+
ssize_t offset = find_code_size_offset(index, code_size_ptr, real_cs);
3646+
ASSERT_GE(offset, 0) << "could not locate code_size field in "
3647+
<< "serialized data via double-serialization diff";
3648+
3649+
VectorIOWriter writer;
3650+
write_index(index, &writer);
3651+
3652+
size_t corrupt_cs = real_cs + 999;
3653+
memcpy(writer.data.data() + offset, &corrupt_cs, sizeof(size_t));
3654+
3655+
VectorIOReader reader;
3656+
reader.data = writer.data;
3657+
try {
3658+
auto idx = std::unique_ptr<Index>(read_index(&reader));
3659+
FAIL() << "expected FaissException containing '" << expected_substr
3660+
<< "'";
3661+
} catch (const FaissException& e) {
3662+
EXPECT_NE(
3663+
std::string(e.what()).find(expected_substr), std::string::npos)
3664+
<< "expected '" << expected_substr << "' in: " << e.what();
3665+
}
3666+
}
3667+
3668+
/// Builds a reproducible block of pseudo-random floats for training/adding.
///
/// @param d     vector dimension
/// @param n     number of vectors
/// @param seed  seed for the std::mt19937 generator (default 42)
/// @return d * n values drawn from a default-constructed
///         std::uniform_real_distribution<float>
static std::vector<float> make_random_data(int d, int n, int seed = 42) {
    // Widen before multiplying: evaluating d * n in int can overflow
    // (undefined behaviour) before the result is converted to a size.
    const size_t count = static_cast<size_t>(d) * static_cast<size_t>(n);
    std::vector<float> data(count);
    std::mt19937 rng(seed);
    std::uniform_real_distribution<float> dist;
    for (auto& v : data) {
        v = dist(rng);
    }
    return data;
}
3677+
3678+
TEST(ReadIndexDeserialize, ResidualQuantizerCodeSizeMismatch) {
    // A trained IndexResidualQuantizer whose serialized code_size has been
    // tampered with must be rejected when read back.
    constexpr int dim = 8;
    constexpr int n_train = 256;
    const std::vector<float> train_set = make_random_data(dim, n_train);

    IndexResidualQuantizer rq_index(dim, 2, 4);
    rq_index.train(n_train, train_set.data());

    corrupt_code_size_and_expect_throw(
            &rq_index, &rq_index.code_size, "code_size mismatch");
}
3686+
3687+
TEST(ReadIndexDeserialize, LocalSearchQuantizerCodeSizeMismatch) {
    // Same scenario for IndexLocalSearchQuantizer: a patched code_size in
    // the stream has to trigger the mismatch error.
    constexpr int dim = 8;
    constexpr int n_train = 256;
    const std::vector<float> train_set = make_random_data(dim, n_train);

    IndexLocalSearchQuantizer lsq_index(dim, 2, 4);
    lsq_index.train(n_train, train_set.data());

    corrupt_code_size_and_expect_throw(
            &lsq_index, &lsq_index.code_size, "code_size mismatch");
}
3695+
3696+
TEST(ReadIndexDeserialize, ProductResidualQuantizerCodeSizeMismatch) {
    // IndexProductResidualQuantizer variant of the corrupt-code_size check.
    constexpr int dim = 16;
    constexpr int n_train = 512;
    const std::vector<float> train_set = make_random_data(dim, n_train);

    IndexProductResidualQuantizer prq_index(dim, 2, 4, 8);
    prq_index.train(n_train, train_set.data());

    corrupt_code_size_and_expect_throw(
            &prq_index, &prq_index.code_size, "code_size mismatch");
}
3704+
3705+
TEST(ReadIndexDeserialize, ProductLocalSearchQuantizerCodeSizeMismatch) {
    // IndexProductLocalSearchQuantizer variant of the corrupt-code_size
    // check.
    constexpr int dim = 8;
    constexpr int n_train = 256;
    const std::vector<float> train_set = make_random_data(dim, n_train);

    IndexProductLocalSearchQuantizer plsq_index(dim, 2, 2, 4);
    plsq_index.train(n_train, train_set.data());

    corrupt_code_size_and_expect_throw(
            &plsq_index, &plsq_index.code_size, "code_size mismatch");
}
3713+
3714+
// IndexLSH code_size is validated against (nbits + 7) / 8. Use a
3715+
// crafted payload with a mismatched code_size to exercise the check.
3716+
TEST(ReadIndexDeserialize, LSHCodeSizeMismatch) {
3717+
std::vector<uint8_t> buf;
3718+
push_fourcc(buf, "IxHe"); // new format IndexLSH
3719+
push_index_header(buf, 8, 0);
3720+
push_val<int>(buf, 64); // nbits
3721+
push_val<bool>(buf, false); // rotate_data
3722+
push_val<bool>(buf, false); // train_thresholds
3723+
push_vector<float>(buf, {}); // thresholds
3724+
push_val<int>(buf, 99); // code_size = 99 (should be 8)
3725+
// rrot: RandomRotationMatrix
3726+
push_fourcc(buf, "rrot");
3727+
push_val<int>(buf, 8); // d_in
3728+
push_val<int>(buf, 64); // d_out
3729+
push_val<bool>(buf, false); // is_trained
3730+
push_val<bool>(buf, false); // have_bias
3731+
push_vector<float>(buf, {}); // A
3732+
push_vector<float>(buf, {}); // b
3733+
push_vector<uint8_t>(buf, {}); // codes
3734+
3735+
expect_read_throws_with(buf, "code_size mismatch");
3736+
}
3737+
3738+
TEST(ReadIndexDeserialize, IVFScalarQuantizerCodeSizeMismatch) {
    // With QT_fp16 the code_size is d * 2 = 16, a value distinctive enough
    // to be located reliably in the serialized bytes.
    constexpr int dim = 8;
    constexpr int n_train = 256;
    constexpr int n_lists = 4;
    const std::vector<float> train_set = make_random_data(dim, n_train);

    IndexFlatL2 coarse(dim);
    IndexIVFScalarQuantizer ivf_sq(
            &coarse, dim, n_lists, ScalarQuantizer::QT_fp16);
    ivf_sq.own_fields = false;
    ivf_sq.train(n_train, train_set.data());

    // Either validate_code_size_match or the InvertedLists code_size
    // consistency check may fire first; both reject the corrupt stream, so
    // only the common "code_size" substring is required.
    corrupt_code_size_and_expect_throw(
            &ivf_sq, &ivf_sq.code_size, "code_size");
}
3751+
3752+
TEST(ReadIndexDeserialize, IVFAdditiveQuantizerCodeSizeMismatch) {
    // IVF additive-quantizer case, exercised through
    // IndexIVFResidualQuantizer.
    constexpr int dim = 16;
    constexpr int n_train = 512;
    constexpr int n_lists = 4;
    const std::vector<float> train_set = make_random_data(dim, n_train);

    IndexFlatL2 coarse(dim);
    IndexIVFResidualQuantizer ivf_rq(&coarse, dim, n_lists, 2, 8);
    ivf_rq.own_fields = false;
    ivf_rq.train(n_train, train_set.data());

    corrupt_code_size_and_expect_throw(
            &ivf_rq, &ivf_rq.code_size, "code_size");
}
3761+
3762+
TEST(ReadIndexDeserialize, Index2LayerCodeSize2Mismatch) {
    // Corrupts Index2Layer::code_size_2 in the serialized stream and
    // expects the reader to report a code_size mismatch.
    constexpr int dim = 8;
    constexpr int n_vecs = 256;
    constexpr int n_lists = 4;
    const std::vector<float> vectors = make_random_data(dim, n_vecs);

    // Ownership of the coarse quantizer is handed over to the index.
    Index2Layer two_layer(
            std::make_unique<IndexFlatL2>(dim).release(), n_lists, 4);
    two_layer.q1.own_fields = true;

    two_layer.train(n_vecs, vectors.data());
    two_layer.add(n_vecs, vectors.data());

    corrupt_code_size_and_expect_throw(
            &two_layer, &two_layer.code_size_2, "code_size mismatch");
}
3773+
3774+
TEST(ReadIndexDeserialize, Index2LayerCodeSizeSumMismatch) {
    // Corrupts Index2Layer::code_size (the combined size) in the serialized
    // stream and expects the reader to report a code_size mismatch.
    constexpr int dim = 8;
    constexpr int n_vecs = 256;
    constexpr int n_lists = 4;
    const std::vector<float> vectors = make_random_data(dim, n_vecs);

    // Ownership of the coarse quantizer is handed over to the index.
    Index2Layer two_layer(
            std::make_unique<IndexFlatL2>(dim).release(), n_lists, 4);
    two_layer.q1.own_fields = true;

    two_layer.train(n_vecs, vectors.data());
    two_layer.add(n_vecs, vectors.data());

    corrupt_code_size_and_expect_throw(
            &two_layer, &two_layer.code_size, "code_size mismatch");
}
3785+
3786+
// IndexLattice code_size is derived from scale_nbit, lattice_nbit,
3787+
// and nsq. A corrupt scale_nbit (e.g. negative) causes integer
3788+
// overflow in the total_nbit → code_size computation, producing
3789+
// a huge code_size that is rejected at deserialization.
3790+
TEST(ReadIndexDeserialize, IndexLatticeCodeSizeTooLarge) {
3791+
std::vector<uint8_t> buf;
3792+
push_fourcc(buf, "IxLa");
3793+
push_val<int>(buf, 4); // d
3794+
push_val<int>(buf, 2); // nsq (dsq = 4/2 = 2, power of 2 >= 2)
3795+
push_val<int>(buf, -100); // scale_nbit (corrupt → overflows code_size)
3796+
push_val<int>(buf, 1); // r2
3797+
3798+
expect_read_throws_with(buf, "code_size");
3799+
}
3800+
35873801
// ============================================================
35883802
// SVS fourcc rejection / deserialization safety (Group F: T262015608)
35893803
// ============================================================

0 commit comments

Comments
 (0)