Skip to content

Commit 8932716

Browse files
scsiguymeta-codesync[bot]
authored andcommitted
Validate SVS storage_kind via shared helper at all deserialization read sites (#5204)
Summary: Pull Request resolved: #5204 read_index_up reads the SVS storage_kind field from two distinct fourcc paths — the IndexSVSVamana family (ILVQ/ISVL/ISVD/ISV2) and the IndexSVSIVF family (ISIQ/ISIL/ISID). The Vamana branch read the value into a plain int, range-checked it against [0, SVS_count), and then cast to SVSStorageKind, so a malformed payload was rejected at the deserialization boundary with a clear error. The IVF branch read the value directly into the SVSStorageKind enum field with no range check. A corrupt or maliciously constructed payload could store an out-of-range enum value and continue reading several more fields; the bad value was only noticed later from to_svs_storage_kind() inside IndexSVSIVF::deserialize_impl, after the SVS runtime load had been entered, producing a less actionable error from a deeper site. Replace the open-coded Vamana check and the missing IVF check with a single static read_svs_storage_kind(IOReader*) helper in index_read.cpp. The helper performs the four-byte read and the [0, SVS_count) validation in one place and returns the validated enum, so each call site collapses to a one-line assignment. Future SVS read sites cannot diverge in their validation strategy without going around this helper. The helper preserves the existing "invalid SVS storage_kind=N (must be in [0, M))" error message, so no externally-visible error string changes. Reviewed By: mnorris11 Differential Revision: D104481541 fbshipit-source-id: 7e67d84b93ed142fcac39336c6bd6fdeda48cdb5
1 parent 6bd749e commit 8932716

2 files changed

Lines changed: 105 additions & 18 deletions

File tree

faiss/impl/index_read.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,24 @@ namespace faiss {
8686
namespace {
8787
size_t deserialization_loop_limit_ = 0;
8888
size_t deserialization_vector_byte_limit_ = uint64_t{1} << 40; // 1 TB
89+
90+
#ifdef FAISS_ENABLE_SVS
91+
// Read and validate an SVSStorageKind from the stream. Centralizes the
92+
// [0, SVS_count) range check so every SVS read site rejects out-of-range
93+
// values uniformly at the deserialization boundary, instead of letting
94+
// to_svs_storage_kind() surface the failure later from inside an SVS
95+
// runtime load.
96+
SVSStorageKind read_svs_storage_kind(IOReader* f) {
97+
int sk;
98+
READ1(sk);
99+
FAISS_THROW_IF_NOT_FMT(
100+
sk >= 0 && sk < static_cast<int>(SVS_count),
101+
"invalid SVS storage_kind=%d (must be in [0, %d))",
102+
sk,
103+
static_cast<int>(SVS_count));
104+
return static_cast<SVSStorageKind>(sk);
105+
}
106+
#endif // FAISS_ENABLE_SVS
89107
} // namespace
90108

91109
size_t get_deserialization_loop_limit() {
@@ -2556,14 +2574,7 @@ std::unique_ptr<Index> read_index_up(IOReader* f, int io_flags) {
25562574
READ1(svs->prune_to);
25572575
READ1(svs->use_full_search_history);
25582576

2559-
int sk;
2560-
READ1(sk);
2561-
FAISS_THROW_IF_NOT_FMT(
2562-
sk >= 0 && sk < static_cast<int>(SVS_count),
2563-
"invalid SVS storage_kind=%d (must be in [0, %d))",
2564-
sk,
2565-
static_cast<int>(SVS_count));
2566-
svs->storage_kind = static_cast<SVSStorageKind>(sk);
2577+
svs->storage_kind = read_svs_storage_kind(f);
25672578

25682579
if (h == fourcc("ISVL")) {
25692580
auto* leanvec = dynamic_cast<IndexSVSVamanaLeanVec*>(svs.get());
@@ -2636,7 +2647,7 @@ std::unique_ptr<Index> read_index_up(IOReader* f, int io_flags) {
26362647
READ1(svs_ivf->k_reorder);
26372648
READ1(svs_ivf->num_threads);
26382649
READ1(svs_ivf->intra_query_threads);
2639-
READ1(svs_ivf->storage_kind);
2650+
svs_ivf->storage_kind = read_svs_storage_kind(f);
26402651
READ1(svs_ivf->is_static);
26412652
if (h == fourcc("ISIL")) {
26422653
auto* leanvec = dynamic_cast<IndexSVSIVFLeanVec*>(svs_ivf.get());

tests/test_read_index_deserialize.cpp

Lines changed: 85 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3908,11 +3908,12 @@ TEST(ReadIndexDeserialize,
39083908

39093909
#include <faiss/svs/IndexSVSVamana.h>
39103910

3911-
// An invalid storage_kind value should be rejected at deserialization time
3912-
// with a FaissException, not abort via FAISS_ASSERT in to_svs_storage_kind().
3913-
TEST(ReadIndexDeserialize, SVSVamanaInvalidStorageKind) {
3914-
std::vector<uint8_t> buf;
3915-
push_fourcc(buf, "ISVD");
3911+
namespace {
3912+
3913+
/// Append the IndexSVSVamana on-disk fields up to and including storage_kind.
3914+
/// Caller may then push initialized=true (to trigger the SVS deserialize_impl
3915+
/// path) or any other trailing fields needed to reach the validation site.
3916+
void push_svs_vamana_prefix(std::vector<uint8_t>& buf, int storage_kind) {
39163917
push_index_header(buf, 8, 0);
39173918
push_val<size_t>(buf, 32); // graph_max_degree
39183919
push_val<float>(buf, 1.2f); // alpha
@@ -3922,10 +3923,85 @@ TEST(ReadIndexDeserialize, SVSVamanaInvalidStorageKind) {
39223923
push_val<size_t>(buf, 750); // max_candidate_pool_size
39233924
push_val<size_t>(buf, 28); // prune_to
39243925
push_val<bool>(buf, false); // use_full_search_history
3925-
push_val<int>(
3926-
buf,
3927-
static_cast<int>(SVS_count)); // storage_kind — first invalid value
3928-
push_val<bool>(buf, true); // initialized
3926+
push_val<int>(buf, storage_kind);
3927+
}
3928+
3929+
/// Append the IndexSVSIVF on-disk fields up to and including storage_kind.
3930+
/// Used to exercise the validation guard at the ISIQ/ISIL/ISID read sites.
3931+
void push_svs_ivf_prefix(std::vector<uint8_t>& buf, int storage_kind) {
3932+
push_index_header(buf, 8, 0);
3933+
push_val<size_t>(buf, 8); // num_centroids
3934+
push_val<size_t>(buf, 64); // minibatch_size
3935+
push_val<size_t>(buf, 1); // num_iterations
3936+
push_val<bool>(buf, false); // is_hierarchical
3937+
push_val<float>(buf, 0.1f); // training_fraction
3938+
push_val<size_t>(buf, 0); // hierarchical_level1_clusters
3939+
push_val<size_t>(buf, 0); // seed
3940+
push_val<size_t>(buf, 1); // n_probes
3941+
push_val<float>(buf, 1.0f); // k_reorder (float, not size_t)
3942+
push_val<size_t>(buf, 1); // num_threads
3943+
push_val<size_t>(buf, 1); // intra_query_threads
3944+
push_val<int>(buf, storage_kind);
3945+
}
3946+
3947+
} // namespace
3948+
3949+
// An invalid storage_kind value should be rejected at deserialization time
3950+
// with a FaissException, not abort via FAISS_ASSERT in to_svs_storage_kind().
3951+
TEST(ReadIndexDeserialize, SVSVamanaInvalidStorageKind) {
3952+
std::vector<uint8_t> buf;
3953+
push_fourcc(buf, "ISVD");
3954+
push_svs_vamana_prefix(buf, /*storage_kind=*/static_cast<int>(SVS_count));
3955+
push_val<bool>(buf, true); // initialized
3956+
3957+
expect_read_throws_with(buf, "storage_kind");
3958+
}
3959+
3960+
// The Vamana validator must reject negative storage_kind values too — the
3961+
// shared helper takes a signed int because READ1 reads four bytes that could
3962+
// be either sign.
3963+
TEST(ReadIndexDeserialize, SVSVamanaNegativeStorageKind) {
3964+
std::vector<uint8_t> buf;
3965+
push_fourcc(buf, "ISVD");
3966+
push_svs_vamana_prefix(buf, /*storage_kind=*/-1);
3967+
push_val<bool>(buf, true);
3968+
3969+
expect_read_throws_with(buf, "storage_kind");
3970+
}
3971+
3972+
// IVF Vamana flavours (ISIQ = IndexSVSIVFLVQ, ISIL = IndexSVSIVFLeanVec,
3973+
// ISID = IndexSVSIVF) all share the same storage_kind read site and must
3974+
// reject out-of-range values at the deserialization boundary, mirroring the
3975+
// Vamana branch above. Without the shared validator the bad value would
3976+
// propagate into IndexSVSIVF::deserialize_impl and only get rejected from
3977+
// to_svs_storage_kind() after several allocations and an SVS-runtime call.
3978+
TEST(ReadIndexDeserialize, SVSIVFInvalidStorageKind) {
3979+
std::vector<uint8_t> buf;
3980+
push_fourcc(buf, "ISID");
3981+
push_svs_ivf_prefix(buf, /*storage_kind=*/static_cast<int>(SVS_count));
3982+
push_val<bool>(buf, true); // is_static
3983+
push_val<bool>(buf, true); // initialized
3984+
3985+
expect_read_throws_with(buf, "storage_kind");
3986+
}
3987+
3988+
TEST(ReadIndexDeserialize, SVSIVFLVQInvalidStorageKind) {
3989+
std::vector<uint8_t> buf;
3990+
push_fourcc(buf, "ISIQ");
3991+
push_svs_ivf_prefix(buf, /*storage_kind=*/-1);
3992+
push_val<bool>(buf, true);
3993+
push_val<bool>(buf, true);
3994+
3995+
expect_read_throws_with(buf, "storage_kind");
3996+
}
3997+
3998+
TEST(ReadIndexDeserialize, SVSIVFLeanVecInvalidStorageKind) {
3999+
std::vector<uint8_t> buf;
4000+
push_fourcc(buf, "ISIL");
4001+
push_svs_ivf_prefix(buf, /*storage_kind=*/static_cast<int>(SVS_count) + 7);
4002+
push_val<bool>(buf, true); // is_static
4003+
push_val<size_t>(buf, 8); // leanvec_d (only on ISIL path)
4004+
push_val<bool>(buf, true); // initialized
39294005

39304006
expect_read_throws_with(buf, "storage_kind");
39314007
}

0 commit comments

Comments
 (0)