Skip to content

Commit ce9cd6f

Browse files
authored
Merge branch 'main' into export-D95911440
2 parents fbc766e + 5b83ec6 commit ce9cd6f

9 files changed

Lines changed: 78 additions & 20 deletions

File tree

faiss/IndexIVFPQ.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ static std::unique_ptr<float[]> compute_residuals(
149149
const idx_t* list_nos) {
150150
size_t d = quantizer->d;
151151
std::unique_ptr<float[]> residuals(new float[n * d]);
152-
// TODO: parallelize?
152+
// Parallelize with OpenMP (each iteration is independent)
153+
#pragma omp parallel for if (n > 1000)
153154
for (size_t i = 0; i < n; i++) {
154155
if (list_nos[i] < 0)
155156
memset(residuals.get() + i * d, 0, sizeof(float) * d);

faiss/MetricType.h

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,15 @@ namespace faiss {
2424
/// (brute-force) indices supporting additional metric types for vector
2525
/// comparison.
2626
///
27-
/// NOTE: when adding or removing values, update metric_type_from_int() below.
27+
/// NOTE: when adding or removing values, update metric_type_from_int()
28+
/// and metric_type_count() below.
2829
enum MetricType {
29-
METRIC_INNER_PRODUCT = 0, ///< maximum inner product search
30-
METRIC_L2 = 1, ///< squared L2 search
31-
METRIC_L1, ///< L1 (aka cityblock)
32-
METRIC_Linf, ///< infinity distance
33-
METRIC_Lp, ///< L_p distance, p is given by a faiss::Index
34-
/// metric_arg
30+
METRIC_INNER_PRODUCT, ///< maximum inner product search
31+
METRIC_L2, ///< squared L2 search
32+
METRIC_L1, ///< L1 (aka cityblock)
33+
METRIC_Linf, ///< infinity distance
34+
METRIC_Lp, ///< L_p distance, p is given by a faiss::Index
35+
/// metric_arg
3536

3637
/// some additional metrics defined in scipy.spatial.distance
3738
METRIC_Canberra = 20,
@@ -68,6 +69,12 @@ inline MetricType metric_type_from_int(int x) {
6869
return static_cast<MetricType>(x);
6970
}
7071

72+
/// Count of entries in the MetricType enum.
73+
constexpr size_t metric_type_count() {
74+
return (METRIC_Lp - METRIC_INNER_PRODUCT) + 1 +
75+
(METRIC_GOWER - METRIC_Canberra) + 1;
76+
}
77+
7178
} // namespace faiss
7279

7380
#endif

faiss/gpu/GpuIndexIVFFlat.cu

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -210,14 +210,6 @@ void GpuIndexIVFFlat::train(idx_t n, const float* x) {
210210

211211
if (this->is_trained) {
212212
FAISS_ASSERT(index_);
213-
if (should_use_cuvs(config_)) {
214-
// copy the IVF centroids to the cuVS index
215-
// in case it has been reset. This is because `reset` clears the
216-
// cuVS index and its centroids.
217-
// TODO: change this once the coarse quantizer is separated from
218-
// cuVS index
219-
updateQuantizer();
220-
};
221213
return;
222214
}
223215

faiss/gpu/impl/CuvsIVFFlat.cu

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,12 @@ void CuvsIVFFlat::reserveMemory(idx_t numVecs) {
7979
}
8080

8181
void CuvsIVFFlat::reset() {
82-
cuvs_index.reset();
82+
if (cuvs_index != nullptr) {
83+
const raft::device_resources& raft_handle =
84+
resources_->getRaftHandleCurrentDevice();
85+
cuvs::neighbors::ivf_flat::helpers::reset_index(
86+
raft_handle, cuvs_index.get());
87+
}
8388
}
8489

8590
void CuvsIVFFlat::setCuvsIndex(

faiss/impl/AdditiveQuantizer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ struct AdditiveQuantizer : Quantizer {
8383
ST_norm_lsq2x4, ///< use a 2x4 bits lsq as norm quantizer (for fast
8484
///< scan)
8585
ST_norm_rq2x4, ///< use a 2x4 bits rq as norm quantizer (for fast scan)
86+
ST_count
8687
};
8788

8889
AdditiveQuantizer(

faiss/impl/ScalarQuantizer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ struct ScalarQuantizer : Quantizer {
3333
QT_bf16,
3434
QT_8bit_direct_signed, ///< fast indexing of signed int8s ranging from
3535
///< [-128 to 127]
36+
QT_count
3637
};
3738

3839
QuantizerType qtype = QT_8bit;

faiss/impl/index_read.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,32 @@
7878

7979
namespace faiss {
8080

81+
namespace {
82+
size_t deserialization_loop_limit_ = 0;
83+
} // namespace
84+
85+
size_t get_deserialization_loop_limit() {
86+
return deserialization_loop_limit_;
87+
}
88+
89+
void set_deserialization_loop_limit(size_t value) {
90+
deserialization_loop_limit_ = value;
91+
}
92+
93+
#define FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(val, field_name) \
94+
do { \
95+
auto limit_ = get_deserialization_loop_limit(); \
96+
if (limit_ > 0) { \
97+
FAISS_THROW_IF_NOT_FMT( \
98+
static_cast<size_t>(val) <= limit_, \
99+
"%s=%zd exceeds deserialization_loop_limit" \
100+
" of %zd", \
101+
field_name, \
102+
static_cast<size_t>(val), \
103+
limit_); \
104+
} \
105+
} while (0)
106+
81107
/*************************************************************
82108
* Mmap-ing and viewing facilities
83109
**************************************************************/
@@ -371,6 +397,7 @@ std::unique_ptr<InvertedLists> read_InvertedLists_up(
371397
} else if (h == fourcc("ilpn") && !(io_flags & IO_FLAG_SKIP_IVF_DATA)) {
372398
size_t nlist, code_size, n_levels;
373399
READ1(nlist);
400+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(nlist, "ilpn nlist");
374401
READ1(code_size);
375402
READ1(n_levels);
376403
auto ailp = std::make_unique<ArrayInvertedListsPanorama>(
@@ -400,6 +427,7 @@ std::unique_ptr<InvertedLists> read_InvertedLists_up(
400427
} else if (h == fourcc("ilar") && !(io_flags & IO_FLAG_SKIP_IVF_DATA)) {
401428
auto ails = std::make_unique<ArrayInvertedLists>(0, 0);
402429
READ1(ails->nlist);
430+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(ails->nlist, "ilar nlist");
403431
READ1(ails->code_size);
404432
ails->ids.resize(ails->nlist);
405433
ails->codes.resize(ails->nlist);
@@ -430,6 +458,7 @@ std::unique_ptr<InvertedLists> read_InvertedLists_up(
430458
int h2 = (io_flags & 0xffff0000) | (fourcc("il__") & 0x0000ffff);
431459
size_t nlist, code_size;
432460
READ1(nlist);
461+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(nlist, "ilar skip nlist");
433462
READ1(code_size);
434463
std::vector<size_t> sizes(nlist);
435464
read_ArrayInvertedLists_sizes(f, sizes);
@@ -557,6 +586,7 @@ static void read_ProductAdditiveQuantizer(
557586
paq.nsplits > 0,
558587
"invalid ProductAdditiveQuantizer nsplits %zd (must be > 0)",
559588
paq.nsplits);
589+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(paq.nsplits, "nsplits");
560590
}
561591

562592
static void read_ProductResidualQuantizer(
@@ -735,6 +765,8 @@ static void read_HNSW(HNSW& hnsw, IOReader* f) {
735765
static void read_NSG(NSG& nsg, IOReader* f) {
736766
READ1(nsg.ntotal);
737767
READ1(nsg.R);
768+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(nsg.ntotal, "nsg.ntotal");
769+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(nsg.R, "nsg.R");
738770
FAISS_THROW_IF_NOT_FMT(nsg.R > 0, "invalid NSG R %d (must be > 0)", nsg.R);
739771
READ1(nsg.L);
740772
READ1(nsg.C);
@@ -858,6 +890,7 @@ void read_ivf_header(
858890
std::vector<std::vector<idx_t>>* ids) {
859891
read_index_header(*ivf, f);
860892
READ1(ivf->nlist);
893+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(ivf->nlist, "nlist");
861894
READ1(ivf->nprobe);
862895
ivf->quantizer = read_index(f);
863896
ivf->own_fields = true;
@@ -1368,6 +1401,8 @@ std::unique_ptr<Index> read_index_up(IOReader* f, int io_flags) {
13681401
nt >= 0,
13691402
"invalid VectorTransform chain length %d (must be >= 0)",
13701403
nt);
1404+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(
1405+
nt, "VectorTransform chain length");
13711406
for (int i = 0; i < nt; i++) {
13721407
ixpt->chain.push_back(read_VectorTransform(f));
13731408
}
@@ -1898,6 +1933,7 @@ static void read_binary_ivf_header(
18981933
std::vector<std::vector<idx_t>>* ids = nullptr) {
18991934
read_index_binary_header(ivf, f);
19001935
READ1(ivf.nlist);
1936+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(ivf.nlist, "nlist");
19011937
READ1(ivf.nprobe);
19021938
ivf.quantizer = read_index_binary(f);
19031939
ivf.own_fields = true;
@@ -1915,6 +1951,7 @@ static void read_binary_hash_invlists(
19151951
IOReader* f) {
19161952
size_t sz;
19171953
READ1(sz);
1954+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(sz, "binary hash invlists sz");
19181955
int il_nbit = 0;
19191956
READ1(il_nbit);
19201957
FAISS_THROW_IF_NOT_FMT(
@@ -1965,6 +2002,7 @@ static void read_binary_multi_hash_map(
19652002
size_t sz;
19662003
READ1(id_bits);
19672004
READ1(sz);
2005+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(sz, "multi hash map sz");
19682006
std::vector<uint8_t> buf;
19692007
READVECTOR(buf);
19702008
size_t nbit = add_no_overflow(
@@ -2063,6 +2101,7 @@ std::unique_ptr<IndexBinary> read_index_binary_up(IOReader* f, int io_flags) {
20632101
idxmh->nhash > 0,
20642102
"invalid IndexBinaryMultiHash nhash %d (must be > 0)",
20652103
idxmh->nhash);
2104+
FAISS_CHECK_DESERIALIZATION_LOOP_LIMIT(idxmh->nhash, "nhash");
20662105
READ1(idxmh->nflip);
20672106
idxmh->maps.resize(idxmh->nhash);
20682107
for (int i = 0; i < idxmh->nhash; i++) {

faiss/index_io.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,17 @@ std::unique_ptr<InvertedLists> read_InvertedLists_up(
113113
IOReader* reader,
114114
int io_flags = 0);
115115

116+
// Returns the current deserialization loop limit.
117+
// When nonzero, deserialization rejects loop-driving fields (nlist,
118+
// nsplits, VT chain length, nhash, etc.) that exceed this value.
119+
// Default: 0 (no limit).
120+
size_t get_deserialization_loop_limit();
121+
122+
// Sets the deserialization loop limit.
123+
// NOT thread-safe: set before any concurrent deserialization calls
124+
// and do not modify while deserialization is in progress on other threads.
125+
void set_deserialization_loop_limit(size_t value);
126+
116127
} // namespace faiss
117128

118129
#endif

faiss/invlists/DirectMap.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,10 @@ inline uint64_t lo_offset(uint64_t lo) {
3737
*/
3838
struct DirectMap {
3939
enum Type {
40-
NoMap = 0, // default
41-
Array = 1, // sequential ids (only for add, no add_with_ids)
42-
Hashtable = 2 // arbitrary ids
40+
NoMap, // default
41+
Array, // sequential ids (only for add, no add_with_ids)
42+
Hashtable, // arbitrary ids
43+
DMT_count
4344
};
4445
Type type;
4546

0 commit comments

Comments
 (0)