Skip to content

Commit 1cb3d46

Browse files
ahuber21 and meta-codesync[bot]
authored and committed
feat(svs): LeanVec OOD support (facebookresearch#4773)
Summary: This PR will add LeanVec OOD (out-of-distribution) support. Pull Request resolved: facebookresearch#4773 Reviewed By: alibeklfc Differential Revision: D94607022 Pulled By: mnorris11 fbshipit-source-id: df1b023919b6a2753fdddcd7d5490475f8630d36
1 parent d2f8d35 commit 1cb3d46

10 files changed

Lines changed: 123 additions & 21 deletions

faiss/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ if(FAISS_OPT_LEVEL STREQUAL "dd")
460460
endif()
461461

462462
if(FAISS_ENABLE_SVS)
463-
find_package(svs_runtime REQUIRED)
463+
find_package(svs_runtime 0.2.0 REQUIRED)
464464

465465
target_link_libraries(faiss PUBLIC svs::svs_runtime)
466466
target_link_libraries(faiss_avx2 PUBLIC svs::svs_runtime)

faiss/Index.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ void Index::train(idx_t /*n*/, const float* /*x*/) {
2424
// does nothing by default
2525
}
2626

27+
void Index::train(
28+
idx_t /*n*/,
29+
const float* /*x*/,
30+
idx_t /*n_train_q*/,
31+
const float* /*xq_train*/) {
32+
// does nothing by default
33+
}
34+
2735
void Index::range_search(
2836
idx_t,
2937
const float*,

faiss/Index.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,20 @@ struct Index {
131131
*/
132132
virtual void train(idx_t n, const float* x);
133133

134+
/** Perform training on a representative set of vectors and a representative
135+
* set of queries
136+
*
137+
* @param n nb of training vectors
138+
* @param x training vectors, size n * d
139+
* @param n_train_q nb of training queries
140+
* @param xq_train training queries, size n_train_q * d
141+
*/
142+
virtual void train(
143+
idx_t n,
144+
const float* x,
145+
idx_t n_train_q,
146+
const float* xq_train);
147+
134148
virtual void train_ex(idx_t n, const void* x, NumericType numeric_type) {
135149
if (numeric_type == NumericType::Float32) {
136150
train(n, static_cast<const float*>(x));

faiss/python/class_wrappers.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -332,21 +332,57 @@ def replacement_assign(self, x, k, labels=None):
332332
self.assign_c(n, swig_ptr(x), swig_ptr(labels), k)
333333
return labels
334334

335-
def replacement_train(self, x, numeric_type = faiss.Float32):
335+
def replacement_train(
336+
self, x, *, numeric_type=faiss.Float32, xq_train=None
337+
):
336338
"""Trains the index on a representative set of vectors.
337339
The index must be trained before vectors can be added to it.
340+
Optionally accepts numeric_type to specify the type of
341+
input vectors.
342+
Optionally accepts a set of training query vectors for
343+
out-of-distribution training.
338344
339345
Parameters
340346
----------
341347
x : array_like
342-
Query vectors, shape (n, d) where d is appropriate for the index.
348+
Training vectors, shape (n, d) where d is appropriate
349+
for the index. `dtype` must be float32.
350+
numeric_type : type
351+
Numeric type of the input vectors.
352+
xq_train : array_like, optional
353+
Training query vectors, shape (n_train_q, d) where
354+
d is appropriate for the index.
343355
`dtype` must be float32.
344356
"""
357+
# Prepare training data
345358
n, d = x.shape
346359
assert d == self.d
347360
x = np.ascontiguousarray(x, dtype=_numeric_to_str(numeric_type))
361+
362+
# Prepare training queries if provided
363+
n_train_q, train_q = 0, None
364+
if xq_train is not None:
365+
if numeric_type != faiss.Float32:
366+
raise TypeError(
367+
"xq_train is only supported for numeric_type faiss.Float32"
368+
)
369+
n_train_q, d_train = xq_train.shape
370+
assert d_train == self.d
371+
train_q = swig_ptr(
372+
np.ascontiguousarray(
373+
xq_train,
374+
dtype=_numeric_to_str(numeric_type),
375+
)
376+
)
377+
378+
# Dispatch to train_c / train_ex
348379
if numeric_type == faiss.Float32:
349-
self.train_c(n, swig_ptr(x))
380+
if train_q is not None:
381+
self.train_c(
382+
n, swig_ptr(x), n_train_q, train_q
383+
)
384+
else:
385+
self.train_c(n, swig_ptr(x))
350386
else:
351387
self.train_ex(n, swig_ptr(x), numeric_type)
352388

faiss/svs/IndexSVSFaissUtils.h

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,6 @@
4545
// create svs_runtime as alias for svs::runtime::FAISS_SVS_RUNTIME_VERSION
4646
SVS_RUNTIME_CREATE_API_ALIAS(svs_runtime, FAISS_SVS_RUNTIME_VERSION);
4747

48-
// SVS forward declarations
49-
namespace svs {
50-
namespace runtime {
51-
inline namespace v0 {
52-
struct FlatIndex;
53-
struct VamanaIndex;
54-
struct DynamicVamanaIndex;
55-
struct LeanVecTrainingData;
56-
} // namespace v0
57-
} // namespace runtime
58-
} // namespace svs
59-
6048
namespace faiss {
6149

6250
inline svs_runtime::MetricType to_svs_metric(faiss::MetricType metric) {
@@ -119,8 +107,8 @@ struct InputBufferConverter {
119107
std::vector<T> buffer;
120108
};
121109

122-
// Specialization for reinterpret cast when types are integral and have the same
123-
// size
110+
// Specialization for reinterpret cast when types are integral and have
111+
// the same size
124112
template <typename T, typename U>
125113
struct InputBufferConverter<
126114
T,
@@ -178,8 +166,8 @@ struct OutputBufferConverter {
178166
std::vector<T> buffer;
179167
};
180168

181-
// Specialization for reinterpret cast when types are integral and have the same
182-
// size
169+
// Specialization for reinterpret cast when types are integral and have
170+
// the same size
183171
template <typename T, typename U>
184172
struct OutputBufferConverter<
185173
T,

faiss/svs/IndexSVSFlat.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
#include <faiss/Index.h>
2727
#include <faiss/svs/IndexSVSFaissUtils.h>
2828

29+
#include <svs/runtime/flat_index.h>
30+
2931
#include <iostream>
3032

3133
namespace faiss {

faiss/svs/IndexSVSVamana.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <faiss/svs/IndexSVSFaissUtils.h>
2828

2929
#include <svs/runtime/api_defs.h>
30+
#include <svs/runtime/dynamic_vamana_index.h>
3031

3132
#include <iostream>
3233

faiss/svs/IndexSVSVamanaLeanVec.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,14 @@ void IndexSVSVamanaLeanVec::add(idx_t n, const float* x) {
6666
}
6767

6868
void IndexSVSVamanaLeanVec::train(idx_t n, const float* x) {
69+
train(n, x, 0, nullptr);
70+
}
71+
72+
void IndexSVSVamanaLeanVec::train(
73+
idx_t n,
74+
const float* x,
75+
idx_t n_train_q,
76+
const float* queries) {
6977
FAISS_THROW_IF_MSG(
7078
training_data || impl, "Index already trained or contains data.");
7179

@@ -74,7 +82,7 @@ void IndexSVSVamanaLeanVec::train(idx_t n, const float* x) {
7482
"LVQ/LeanVec support not available on this platform or build");
7583

7684
auto status = svs_runtime::LeanVecTrainingData::build(
77-
&training_data, d, n, x, leanvec_d);
85+
&training_data, d, n, x, n_train_q, queries, leanvec_d);
7886
if (!status.ok()) {
7987
FAISS_THROW_MSG(status.message());
8088
}

faiss/svs/IndexSVSVamanaLeanVec.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,17 @@ struct IndexSVSVamanaLeanVec : IndexSVSVamana {
4141

4242
void add(idx_t n, const float* x) override;
4343

44+
/* Default train assumes in-distribution data */
4445
void train(idx_t n, const float* x) override;
4546

47+
/* Generic train with out-of-distribution parameters.
48+
* Out-of-distribution (OOD) means database vectors and queries _can_ be
49+
* sampled from different distributions (e.g., cross-modal). More details in
50+
* the original publication, arXiv:2312.16335.
51+
*/
52+
void train(idx_t n, const float* x, idx_t n_train_q, const float* xq_train)
53+
override;
54+
4655
void serialize_training_data(std::ostream& out) const;
4756
void deserialize_training_data(std::istream& in);
4857

tests/test_svs_py.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,5 +580,41 @@ def _create_instance(self):
580580
return idx
581581

582582

583+
@unittest.skipIf(_SKIP_SVS_LL, _SKIP_SVS_LL_REASON)
584+
class TestSVSLeanVecOOD(unittest.TestCase):
585+
"""Test out-of-distribution training for LeanVec SVS indices"""
586+
587+
def setUp(self):
588+
self.d = 256
589+
self.idx = faiss.IndexSVSVamanaLeanVec(
590+
self.d, 64, faiss.METRIC_INNER_PRODUCT, 64, faiss.SVS_LeanVec4x8
591+
)
592+
self.idx.alpha = 0.95
593+
594+
self.x = np.random.rand(1000, self.d).astype("float32")
595+
self.tq = np.random.rand(1000, self.d).astype("float32")
596+
597+
def test_svs_leanvec_ood_training(self):
598+
self.assertIsNone(self.idx.training_data)
599+
self.idx.train(self.x, xq_train=self.tq)
600+
self.assertIsNotNone(self.idx.training_data)
601+
602+
def test_svs_leanvec_ood_training_smaller(self):
603+
self.idx.train(self.x, xq_train=self.tq[:500])
604+
605+
def test_svs_leanvec_ood_training_wrong_dim(self):
606+
wrong_dim = np.random.rand(1000, self.d + 1).astype("float32")
607+
with self.assertRaises(AssertionError):
608+
self.idx.train(self.x, xq_train=wrong_dim)
609+
610+
def test_svs_leanvec_ood_training_wrong_type(self):
611+
with self.assertRaises(TypeError):
612+
self.idx.train(
613+
self.x,
614+
xq_train=self.tq,
615+
numeric_type=faiss.Float16,
616+
)
617+
618+
583619
if __name__ == '__main__':
584620
unittest.main()

0 commit comments

Comments
 (0)