Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions faiss/MetricType.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ namespace faiss {
/// (brute-force) indices supporting additional metric types for vector
/// comparison.
///
/// NOTE: when adding or removing values, update metric_type_from_int() below.
/// NOTE: when adding or removing values, update metric_type_from_int()
/// and metric_type_count() below.
enum MetricType {
METRIC_INNER_PRODUCT = 0, ///< maximum inner product search
METRIC_L2 = 1, ///< squared L2 search
METRIC_L1, ///< L1 (aka cityblock)
METRIC_Linf, ///< infinity distance
METRIC_Lp, ///< L_p distance, p is given by a faiss::Index
/// metric_arg
METRIC_INNER_PRODUCT, ///< maximum inner product search
METRIC_L2, ///< squared L2 search
METRIC_L1, ///< L1 (aka cityblock)
METRIC_Linf, ///< infinity distance
METRIC_Lp, ///< L_p distance, p is given by a faiss::Index
/// metric_arg

/// some additional metrics defined in scipy.spatial.distance
METRIC_Canberra = 20,
Expand Down Expand Up @@ -68,6 +69,12 @@ inline MetricType metric_type_from_int(int x) {
return static_cast<MetricType>(x);
}

/// Count of entries in the MetricType enum.
constexpr size_t metric_type_count() {
return (METRIC_Lp - METRIC_INNER_PRODUCT) + 1 +
(METRIC_GOWER - METRIC_Canberra) + 1;
}

} // namespace faiss

#endif
1 change: 1 addition & 0 deletions faiss/impl/AdditiveQuantizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ struct AdditiveQuantizer : Quantizer {
ST_norm_lsq2x4, ///< use a 2x4 bits lsq as norm quantizer (for fast
///< scan)
ST_norm_rq2x4, ///< use a 2x4 bits rq as norm quantizer (for fast scan)
ST_count
};

AdditiveQuantizer(
Expand Down
1 change: 1 addition & 0 deletions faiss/impl/ScalarQuantizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct ScalarQuantizer : Quantizer {
QT_bf16,
QT_8bit_direct_signed, ///< fast indexing of signed int8s ranging from
///< [-128 to 127]
QT_count
};

QuantizerType qtype = QT_8bit;
Expand Down
39 changes: 39 additions & 0 deletions faiss/impl/index_read.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,32 @@

namespace faiss {

namespace {
size_t max_deserialization_iters = 0;
} // namespace

size_t get_max_deserialization_iterations() {
return max_deserialization_iters;
}

void set_max_deserialization_iterations(size_t value) {
max_deserialization_iters = value;
}

#define FAISS_CHECK_DESERIALIZATION_ITERATIONS(val, field_name) \
do { \
auto limit_ = get_max_deserialization_iterations(); \
if (limit_ > 0) { \
FAISS_THROW_IF_NOT_FMT( \
static_cast<size_t>(val) <= limit_, \
"%s=%zd exceeds max_deserialization_iterations " \
"limit %zd", \
field_name, \
static_cast<size_t>(val), \
limit_); \
} \
} while (0)

/*************************************************************
* Mmap-ing and viewing facilities
**************************************************************/
Expand Down Expand Up @@ -371,6 +397,7 @@ std::unique_ptr<InvertedLists> read_InvertedLists_up(
} else if (h == fourcc("ilpn") && !(io_flags & IO_FLAG_SKIP_IVF_DATA)) {
size_t nlist, code_size, n_levels;
READ1(nlist);
FAISS_CHECK_DESERIALIZATION_ITERATIONS(nlist, "ilpn nlist");
READ1(code_size);
READ1(n_levels);
auto ailp = std::make_unique<ArrayInvertedListsPanorama>(
Expand Down Expand Up @@ -400,6 +427,7 @@ std::unique_ptr<InvertedLists> read_InvertedLists_up(
} else if (h == fourcc("ilar") && !(io_flags & IO_FLAG_SKIP_IVF_DATA)) {
auto ails = std::make_unique<ArrayInvertedLists>(0, 0);
READ1(ails->nlist);
FAISS_CHECK_DESERIALIZATION_ITERATIONS(ails->nlist, "ilar nlist");
READ1(ails->code_size);
ails->ids.resize(ails->nlist);
ails->codes.resize(ails->nlist);
Expand Down Expand Up @@ -430,6 +458,7 @@ std::unique_ptr<InvertedLists> read_InvertedLists_up(
int h2 = (io_flags & 0xffff0000) | (fourcc("il__") & 0x0000ffff);
size_t nlist, code_size;
READ1(nlist);
FAISS_CHECK_DESERIALIZATION_ITERATIONS(nlist, "ilar skip nlist");
READ1(code_size);
std::vector<size_t> sizes(nlist);
read_ArrayInvertedLists_sizes(f, sizes);
Expand Down Expand Up @@ -557,6 +586,7 @@ static void read_ProductAdditiveQuantizer(
paq.nsplits > 0,
"invalid ProductAdditiveQuantizer nsplits %zd (must be > 0)",
paq.nsplits);
FAISS_CHECK_DESERIALIZATION_ITERATIONS(paq.nsplits, "nsplits");
}

static void read_ProductResidualQuantizer(
Expand Down Expand Up @@ -735,6 +765,8 @@ static void read_HNSW(HNSW& hnsw, IOReader* f) {
static void read_NSG(NSG& nsg, IOReader* f) {
READ1(nsg.ntotal);
READ1(nsg.R);
FAISS_CHECK_DESERIALIZATION_ITERATIONS(nsg.ntotal, "nsg.ntotal");
FAISS_CHECK_DESERIALIZATION_ITERATIONS(nsg.R, "nsg.R");
FAISS_THROW_IF_NOT_FMT(nsg.R > 0, "invalid NSG R %d (must be > 0)", nsg.R);
READ1(nsg.L);
READ1(nsg.C);
Expand Down Expand Up @@ -858,6 +890,7 @@ void read_ivf_header(
std::vector<std::vector<idx_t>>* ids) {
read_index_header(*ivf, f);
READ1(ivf->nlist);
FAISS_CHECK_DESERIALIZATION_ITERATIONS(ivf->nlist, "nlist");
READ1(ivf->nprobe);
ivf->quantizer = read_index(f);
ivf->own_fields = true;
Expand Down Expand Up @@ -1368,6 +1401,8 @@ std::unique_ptr<Index> read_index_up(IOReader* f, int io_flags) {
nt >= 0,
"invalid VectorTransform chain length %d (must be >= 0)",
nt);
FAISS_CHECK_DESERIALIZATION_ITERATIONS(
nt, "VectorTransform chain length");
for (int i = 0; i < nt; i++) {
ixpt->chain.push_back(read_VectorTransform(f));
}
Expand Down Expand Up @@ -1898,6 +1933,7 @@ static void read_binary_ivf_header(
std::vector<std::vector<idx_t>>* ids = nullptr) {
read_index_binary_header(ivf, f);
READ1(ivf.nlist);
FAISS_CHECK_DESERIALIZATION_ITERATIONS(ivf.nlist, "nlist");
READ1(ivf.nprobe);
ivf.quantizer = read_index_binary(f);
ivf.own_fields = true;
Expand All @@ -1915,6 +1951,7 @@ static void read_binary_hash_invlists(
IOReader* f) {
size_t sz;
READ1(sz);
FAISS_CHECK_DESERIALIZATION_ITERATIONS(sz, "binary hash invlists sz");
int il_nbit = 0;
READ1(il_nbit);
FAISS_THROW_IF_NOT_FMT(
Expand Down Expand Up @@ -1965,6 +2002,7 @@ static void read_binary_multi_hash_map(
size_t sz;
READ1(id_bits);
READ1(sz);
FAISS_CHECK_DESERIALIZATION_ITERATIONS(sz, "multi hash map sz");
std::vector<uint8_t> buf;
READVECTOR(buf);
size_t nbit = add_no_overflow(
Expand Down Expand Up @@ -2063,6 +2101,7 @@ std::unique_ptr<IndexBinary> read_index_binary_up(IOReader* f, int io_flags) {
idxmh->nhash > 0,
"invalid IndexBinaryMultiHash nhash %d (must be > 0)",
idxmh->nhash);
FAISS_CHECK_DESERIALIZATION_ITERATIONS(idxmh->nhash, "nhash");
READ1(idxmh->nflip);
idxmh->maps.resize(idxmh->nhash);
for (int i = 0; i < idxmh->nhash; i++) {
Expand Down
11 changes: 11 additions & 0 deletions faiss/index_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,17 @@ std::unique_ptr<InvertedLists> read_InvertedLists_up(
IOReader* reader,
int io_flags = 0);

// Returns the current deserialization-iteration cap.
// When nonzero, deserialization rejects loop-driving fields (nlist,
// nsplits, VT chain length, nhash, etc.) that exceed this value.
// Default: 0 (no limit).
size_t get_max_deserialization_iterations();

// Sets the deserialization-iteration cap.
// NOT thread-safe: set before any concurrent deserialization calls
// and do not modify while deserialization is in progress on other threads.
void set_max_deserialization_iterations(size_t value);

} // namespace faiss

#endif
7 changes: 4 additions & 3 deletions faiss/invlists/DirectMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ inline uint64_t lo_offset(uint64_t lo) {
*/
struct DirectMap {
enum Type {
NoMap = 0, // default
Array = 1, // sequential ids (only for add, no add_with_ids)
Hashtable = 2 // arbitrary ids
NoMap, // default
Array, // sequential ids (only for add, no add_with_ids)
Hashtable, // arbitrary ids
DMT_count
};
Type type;

Expand Down
Loading