Skip to content

Commit ac6c671

Browse files
ltamasifacebook-github-bot
authored andcommitted
Add a couple of convenience methods for converting embeddings (#13329)
Summary: Pull Request resolved: #13329 The patch adds two convenience methods `ConvertFloatsToSlice` and `ConvertSliceToFloats` that can be used to convert embeddings from a contiguous range of floats to a RocksDB `Slice` or vice versa. The methods are added to the public API so they can be utilized by applications as well. Reviewed By: jowlyzhang Differential Revision: D68581494 fbshipit-source-id: 2207fa3e668a6546b7de6d8ab78be2ba9f2ffd8c
1 parent a8bd6a3 commit ac6c671

File tree

3 files changed

+43
-25
lines changed

3 files changed

+43
-25
lines changed

include/rocksdb/utilities/secondary_index_faiss.h

+21
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <string>
1111

1212
#include "rocksdb/rocksdb_namespace.h"
13+
#include "rocksdb/slice.h"
1314
#include "rocksdb/utilities/secondary_index.h"
1415

1516
namespace faiss {
@@ -29,4 +30,24 @@ namespace ROCKSDB_NAMESPACE {
2930
std::unique_ptr<SecondaryIndex> NewFaissIVFIndex(
3031
std::unique_ptr<faiss::IndexIVF>&& index, std::string primary_column_name);
3132

33+
// Helper methods to convert embeddings from a span of floats to Slice or vice
34+
// versa
35+
36+
// Convert the given span of floats of size dim to a Slice.
37+
// PRE: embedding points to a contiguous span of floats of size dim
38+
inline Slice ConvertFloatsToSlice(const float* embedding, size_t dim) {
39+
return Slice(reinterpret_cast<const char*>(embedding), dim * sizeof(float));
40+
}
41+
42+
// Convert the given Slice to a span of floats of size dim.
43+
// PRE: embedding.size() == dim * sizeof(float)
44+
// Returns nullptr if the precondition is violated.
45+
inline const float* ConvertSliceToFloats(const Slice& embedding, size_t dim) {
46+
if (embedding.size() != dim * sizeof(float)) {
47+
return nullptr;
48+
}
49+
50+
return reinterpret_cast<const float*>(embedding.data());
51+
}
52+
3253
} // namespace ROCKSDB_NAMESPACE

utilities/secondary_index/faiss_ivf_index.cc

+14-9
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ class FaissIVFIndex::KNNIterator : public Iterator {
7070
pos_ = 0;
7171
keys_.clear();
7272

73-
if (target.size() != index_->d * sizeof(float)) {
73+
const float* const embedding = ConvertSliceToFloats(target, index_->d);
74+
if (!embedding) {
7475
status_ = Status::InvalidArgument(
7576
"Incorrectly sized vector passed to FaissIVFIndex");
7677
return;
@@ -83,8 +84,8 @@ class FaissIVFIndex::KNNIterator : public Iterator {
8384
constexpr faiss::idx_t n = 1;
8485

8586
try {
86-
index_->search(n, reinterpret_cast<const float*>(target.data()), k_,
87-
distances_.data(), labels_.data(), &params);
87+
index_->search(n, embedding, k_, distances_.data(), labels_.data(),
88+
&params);
8889
} catch (const std::exception& e) {
8990
status_ = Status::InvalidArgument(e.what());
9091
}
@@ -364,7 +365,9 @@ Status FaissIVFIndex::UpdatePrimaryColumnValue(
364365
const {
365366
assert(updated_column_value);
366367

367-
if (primary_column_value.size() != index_->d * sizeof(float)) {
368+
const float* const embedding =
369+
ConvertSliceToFloats(primary_column_value, index_->d);
370+
if (!embedding) {
368371
return Status::InvalidArgument(
369372
"Incorrectly sized vector passed to FaissIVFIndex");
370373
}
@@ -373,8 +376,7 @@ Status FaissIVFIndex::UpdatePrimaryColumnValue(
373376
faiss::idx_t label = -1;
374377

375378
try {
376-
index_->quantizer->assign(
377-
n, reinterpret_cast<const float*>(primary_column_value.data()), &label);
379+
index_->quantizer->assign(n, embedding, &label);
378380
} catch (const std::exception& e) {
379381
return Status::InvalidArgument(e.what());
380382
}
@@ -420,13 +422,16 @@ Status FaissIVFIndex::GetSecondaryValue(
420422
assert(label < index_->nlist);
421423

422424
constexpr faiss::idx_t n = 1;
425+
426+
const float* const embedding =
427+
ConvertSliceToFloats(original_column_value, index_->d);
428+
assert(embedding);
429+
423430
constexpr faiss::idx_t* xids = nullptr;
424431
std::string code_str;
425432

426433
try {
427-
index_->add_core(
428-
n, reinterpret_cast<const float*>(original_column_value.data()), xids,
429-
&label, &code_str);
434+
index_->add_core(n, embedding, xids, &label, &code_str);
430435
} catch (const std::exception& e) {
431436
return Status::InvalidArgument(e.what());
432437
}

utilities/secondary_index/faiss_ivf_index_test.cc

+8-16
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ TEST(FaissIVFIndexTest, Basic) {
7474
cfh1, primary_key,
7575
WideColumns{
7676
{primary_column_name,
77-
Slice(reinterpret_cast<const char*>(embeddings.data() + i * dim),
78-
dim * sizeof(float))}}));
77+
ConvertFloatsToSlice(embeddings.data() + i * dim, dim)}}));
7978
}
8079

8180
ASSERT_OK(txn->Commit());
@@ -102,10 +101,8 @@ TEST(FaissIVFIndexTest, Basic) {
102101

103102
// Since we use IndexIVFFlat, there is no fine quantization, so the code
104103
// is actually just the original embedding
105-
ASSERT_EQ(
106-
it->value(),
107-
Slice(reinterpret_cast<const char*>(embeddings.data() + id * dim),
108-
dim * sizeof(float)));
104+
ASSERT_EQ(it->value(),
105+
ConvertFloatsToSlice(embeddings.data() + id * dim, dim));
109106

110107
++num_found;
111108
}
@@ -159,9 +156,7 @@ TEST(FaissIVFIndexTest, Basic) {
159156
// Search for a vector from the original set; we expect to find the vector
160157
// itself as the closest match, since we're performing an exhaustive search
161158
{
162-
it->Seek(
163-
Slice(reinterpret_cast<const char*>(embeddings.data() + id * dim),
164-
dim * sizeof(float)));
159+
it->Seek(ConvertFloatsToSlice(embeddings.data() + id * dim, dim));
165160
ASSERT_TRUE(it->Valid());
166161
ASSERT_OK(it->status());
167162
ASSERT_EQ(get_id(), id);
@@ -225,8 +220,7 @@ TEST(FaissIVFIndexTest, Basic) {
225220
ASSERT_FALSE(it->Valid());
226221
ASSERT_TRUE(it->status().IsNotSupported());
227222

228-
it->SeekForPrev(Slice(reinterpret_cast<const char*>(embeddings.data()),
229-
dim * sizeof(float)));
223+
it->SeekForPrev(ConvertFloatsToSlice(embeddings.data(), dim));
230224
ASSERT_FALSE(it->Valid());
231225
ASSERT_TRUE(it->status().IsNotSupported());
232226

@@ -354,8 +348,7 @@ TEST(FaissIVFIndexTest, Compare) {
354348

355349
const std::string primary_key = std::to_string(i);
356350
ASSERT_OK(db->Put(WriteOptions(), cfh1, primary_key,
357-
Slice(reinterpret_cast<const char*>(embedding),
358-
dim * sizeof(float))));
351+
ConvertFloatsToSlice(embedding, dim)));
359352
}
360353
}
361354

@@ -413,9 +406,8 @@ TEST(FaissIVFIndexTest, Compare) {
413406
}
414407

415408
size_t num_found = 0;
416-
for (it->Seek(Slice(reinterpret_cast<const char*>(embedding),
417-
dim * sizeof(float)));
418-
it->Valid(); it->Next()) {
409+
for (it->Seek(ConvertFloatsToSlice(embedding, dim)); it->Valid();
410+
it->Next()) {
419411
const faiss::idx_t id = get_id();
420412
ASSERT_GE(id, 0);
421413
ASSERT_LT(id, num_db);

0 commit comments

Comments
 (0)