Skip to content

Commit 4453677

Browse files
feat: add binary vector support for HNSW with 1-bit-direct quantization (zilliztech#1347)
- Add QT_1bit_direct quantizer type for binary vectors
- Add BinarySQDistanceComputerWrapper for binary distance computation
- Implement binary data format conversion (bin1) in HNSW index

Signed-off-by: min.tian <min.tian.cn@gmail.com>
1 parent b780811 commit 4453677

15 files changed

+293
-37
lines changed

src/index/hnsw/faiss_hnsw.cc

Lines changed: 102 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,15 @@ convert_rows_to_fp32(const void* const __restrict src_in, float* const __restric
490490
}
491491
}
492492
return true;
493+
} else if (src_data_format == DataFormatEnum::bin1) {
494+
const knowhere::bin1* const src = reinterpret_cast<const knowhere::bin1*>(src_in);
495+
auto uint8_dim = (dim + 7) / 8;
496+
for (size_t i = 0; i < nrows; i++) {
497+
for (size_t j = 0; j < uint8_dim; ++j) {
498+
dst[i * dim + j] = (float)(src[offsets[i] * uint8_dim + j]);
499+
}
500+
}
501+
return true;
493502
} else {
494503
// unknown
495504
return false;
@@ -524,6 +533,23 @@ convert_rows_to_fp32(const void* const __restrict src_in, float* const __restric
524533
dst[i] = (float)(src[i + start_row * dim]);
525534
}
526535
return true;
536+
} else if (src_data_format == DataFormatEnum::bin1) {
537+
// NOTE: This conversion is somewhat unusual. The source (`src_in`) is a uint8_t byte stream,
538+
// where each query_row has ((dim + 7) / 8) * 8 bits, and the total is nrows * ((dim + 7) / 8) * 8 bits.
539+
// But the final format required is nrows * dim * 32 bits (float).
540+
// There are actually two conversions happening here:
541+
// 1. Each uint8_t value must be converted to float (in `BinarySQDistanceComputerWrapper::set_query`
542+
// and `ScalarQuantizer::compute_codes`, it is converted back to uint8_t). [same as int8]
543+
// 2. Each row must occupy dim * 32 bits of space, even if not all bits are filled;
544+
// this is required by the convention set in `ScalarQuantizer::compute_codes`.
545+
const knowhere::bin1* const src = reinterpret_cast<const knowhere::bin1*>(src_in);
546+
auto uint8_dim = (dim + 7) / 8;
547+
for (size_t i = 0; i < nrows; i++) {
548+
for (size_t j = 0; j < uint8_dim; j++) {
549+
dst[i * dim + j] = (float)(src[(start_row + i) * uint8_dim + j]);
550+
}
551+
}
552+
return true;
527553
} else {
528554
// unknown
529555
return false;
@@ -561,6 +587,16 @@ convert_rows_from_fp32(const float* const __restrict src, void* const __restrict
561587
dst[i + start_row * dim] = (knowhere::int8)src[i];
562588
}
563589
return true;
590+
} else if (dst_data_format == DataFormatEnum::bin1) {
591+
knowhere::bin1* const dst = reinterpret_cast<knowhere::bin1*>(dst_in);
592+
auto uint8_dim = (dim + 7) / 8;
593+
for (size_t i = 0; i < nrows * uint8_dim; i++) {
594+
KNOWHERE_THROW_IF_NOT_MSG(src[i] >= std::numeric_limits<knowhere::bin1>::min() &&
595+
src[i] <= std::numeric_limits<knowhere::bin1>::max(),
596+
"convert float to bin1(uint8_t) overflow");
597+
dst[i + start_row * uint8_dim] = (knowhere::bin1)src[i];
598+
}
599+
return true;
564600
} else {
565601
// unknown
566602
return false;
@@ -578,6 +614,8 @@ convert_ds_to_float(const DataSetPtr& src, DataFormatEnum data_format) {
578614
return ConvertFromDataTypeIfNeeded<knowhere::bf16>(src);
579615
} else if (data_format == DataFormatEnum::int8) {
580616
return ConvertFromDataTypeIfNeeded<knowhere::int8>(src);
617+
} else if (data_format == DataFormatEnum::bin1) {
618+
return ConvertFromDataTypeIfNeeded<knowhere::bin1>(src);
581619
}
582620
return nullptr;
583621
}
@@ -675,6 +713,8 @@ get_index_data_format(const faiss::Index* index) {
675713
return DataFormatEnum::fp16;
676714
} else if (index_sq->sq.qtype == faiss::ScalarQuantizer::QT_8bit_direct_signed) {
677715
return DataFormatEnum::int8;
716+
} else if (index_sq->sq.qtype == faiss::ScalarQuantizer::QT_1bit_direct) {
717+
return DataFormatEnum::bin1;
678718
} else {
679719
return std::nullopt;
680720
}
@@ -1182,6 +1222,24 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
11821222
}
11831223
}
11841224
return GenResultDataSet(rows, dim, std::move(data));
1225+
} else if (data_format == DataFormatEnum::bin1) {
1226+
auto uint8_dim = (dim + 7) / 8;
1227+
auto data = std::make_unique<knowhere::bin1[]>(uint8_dim * rows);
1228+
// faiss produces fp32 data format, we need some other format.
1229+
// Let's create a temporary fp32 buffer for this.
1230+
auto tmp = std::make_unique<float[]>(uint8_dim);
1231+
for (int64_t i = 0; i < rows; i++) {
1232+
const int64_t id = ids[i];
1233+
assert(id >= 0 && id < Count());
1234+
if (!get_vector(id, tmp.get())) {
1235+
return expected<DataSetPtr>::Err(Status::invalid_index_error,
1236+
"index inner error, cannot proceed with GetVectorByIds");
1237+
}
1238+
if (!convert_rows_from_fp32(tmp.get(), data.get(), data_format, i, 1, dim)) {
1239+
return expected<DataSetPtr>::Err(Status::invalid_args, "Unsupported data format");
1240+
}
1241+
}
1242+
return GenResultDataSet(rows, dim, std::move(data));
11851243
} else {
11861244
return expected<DataSetPtr>::Err(Status::invalid_args, "Unsupported data format");
11871245
}
@@ -1785,7 +1843,8 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
17851843
bool
17861844
is_ann_iterator_supported() const {
17871845
if (data_format != DataFormatEnum::fp32 && data_format != DataFormatEnum::fp16 &&
1788-
data_format != DataFormatEnum::bf16 && data_format != DataFormatEnum::int8) {
1846+
data_format != DataFormatEnum::bf16 && data_format != DataFormatEnum::int8 &&
1847+
data_format != DataFormatEnum::bin1) {
17891848
return false;
17901849
}
17911850
return true;
@@ -1848,6 +1907,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
18481907
case DataFormatEnum::fp16:
18491908
case DataFormatEnum::bf16:
18501909
case DataFormatEnum::int8:
1910+
case DataFormatEnum::bin1:
18511911
convert_rows_to_fp32(data, cur_query.get(), data_format, i, 1, dim);
18521912
break;
18531913
default:
@@ -1925,42 +1985,55 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode {
19251985

19261986
// create an index
19271987
const bool is_cosine = IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE);
1988+
const bool is_binary = data_format == DataFormatEnum::bin1;
19281989

19291990
std::unique_ptr<faiss::IndexHNSW> hnsw_index;
19301991
auto train_index = [&](const float* data, const int i, const int64_t rows) {
1931-
if (is_cosine) {
1932-
if (data_format == DataFormatEnum::fp32) {
1933-
hnsw_index = std::make_unique<faiss::IndexHNSWFlatCosine>(dim, hnsw_cfg.M.value());
1934-
} else if (data_format == DataFormatEnum::fp16) {
1935-
hnsw_index = std::make_unique<faiss::IndexHNSWSQCosine>(dim, faiss::ScalarQuantizer::QT_fp16,
1936-
hnsw_cfg.M.value());
1937-
} else if (data_format == DataFormatEnum::bf16) {
1938-
hnsw_index = std::make_unique<faiss::IndexHNSWSQCosine>(dim, faiss::ScalarQuantizer::QT_bf16,
1939-
hnsw_cfg.M.value());
1940-
} else if (data_format == DataFormatEnum::int8) {
1941-
hnsw_index = std::make_unique<faiss::IndexHNSWSQCosine>(
1942-
dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, hnsw_cfg.M.value());
1992+
if (is_binary) {
1993+
if (metric.value() == faiss::MetricType::METRIC_Hamming ||
1994+
metric.value() == faiss::MetricType::METRIC_Jaccard) {
1995+
hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(dim, faiss::ScalarQuantizer::QT_1bit_direct,
1996+
hnsw_cfg.M.value(), metric.value());
19431997
} else {
1944-
LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value();
1998+
LOG_KNOWHERE_ERROR_ << "Unsupported metric for binary data: " << hnsw_cfg.metric_type.value();
19451999
return Status::invalid_metric_type;
19462000
}
19472001
} else {
1948-
if (data_format == DataFormatEnum::fp32) {
1949-
hnsw_index = std::make_unique<faiss::IndexHNSWFlat>(dim, hnsw_cfg.M.value(), metric.value());
1950-
} else if (data_format == DataFormatEnum::fp16) {
1951-
hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(dim, faiss::ScalarQuantizer::QT_fp16,
1952-
hnsw_cfg.M.value(), metric.value());
1953-
} else if (data_format == DataFormatEnum::bf16) {
1954-
hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(dim, faiss::ScalarQuantizer::QT_bf16,
1955-
hnsw_cfg.M.value(), metric.value());
1956-
} else if (data_format == DataFormatEnum::int8) {
1957-
hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(
1958-
dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, hnsw_cfg.M.value(), metric.value());
2002+
if (is_cosine) {
2003+
if (data_format == DataFormatEnum::fp32) {
2004+
hnsw_index = std::make_unique<faiss::IndexHNSWFlatCosine>(dim, hnsw_cfg.M.value());
2005+
} else if (data_format == DataFormatEnum::fp16) {
2006+
hnsw_index = std::make_unique<faiss::IndexHNSWSQCosine>(dim, faiss::ScalarQuantizer::QT_fp16,
2007+
hnsw_cfg.M.value());
2008+
} else if (data_format == DataFormatEnum::bf16) {
2009+
hnsw_index = std::make_unique<faiss::IndexHNSWSQCosine>(dim, faiss::ScalarQuantizer::QT_bf16,
2010+
hnsw_cfg.M.value());
2011+
} else if (data_format == DataFormatEnum::int8) {
2012+
hnsw_index = std::make_unique<faiss::IndexHNSWSQCosine>(
2013+
dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, hnsw_cfg.M.value());
2014+
} else {
2015+
LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value();
2016+
return Status::invalid_metric_type;
2017+
}
19592018
} else {
1960-
LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value();
1961-
return Status::invalid_metric_type;
2019+
if (data_format == DataFormatEnum::fp32) {
2020+
hnsw_index = std::make_unique<faiss::IndexHNSWFlat>(dim, hnsw_cfg.M.value(), metric.value());
2021+
} else if (data_format == DataFormatEnum::fp16) {
2022+
hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(dim, faiss::ScalarQuantizer::QT_fp16,
2023+
hnsw_cfg.M.value(), metric.value());
2024+
} else if (data_format == DataFormatEnum::bf16) {
2025+
hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(dim, faiss::ScalarQuantizer::QT_bf16,
2026+
hnsw_cfg.M.value(), metric.value());
2027+
} else if (data_format == DataFormatEnum::int8) {
2028+
hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(
2029+
dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, hnsw_cfg.M.value(), metric.value());
2030+
} else {
2031+
LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value();
2032+
return Status::invalid_metric_type;
2033+
}
19622034
}
19632035
}
2036+
19642037
hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value();
19652038
// train
19662039
LOG_KNOWHERE_INFO_ << "Training HNSW Index";
@@ -2947,6 +3020,8 @@ KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW, BaseFaissRegularIndexHNSWF
29473020
knowhere::feature::MMAP | knowhere::feature::MV)
29483021
KNOWHERE_SIMPLE_REGISTER_DENSE_INT_GLOBAL(HNSW, BaseFaissRegularIndexHNSWFlatNodeTemplate,
29493022
knowhere::feature::MMAP | knowhere::feature::MV)
3023+
KNOWHERE_SIMPLE_REGISTER_DENSE_BIN_GLOBAL(HNSW, BaseFaissRegularIndexHNSWFlatNodeTemplate,
3024+
knowhere::feature::MMAP | knowhere::feature::MV)
29503025
#endif
29513026

29523027
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate,

tests/ut/test_get_vector.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,13 @@ TEST_CASE("Test Binary Get Vector By Ids", "[Binary GetVectorByIds]") {
4545
return json;
4646
};
4747

48-
#ifdef KNOWHERE_WITH_CARDINAL
4948
auto bin_hnsw_gen = [base_bin_gen]() {
5049
knowhere::Json json = base_bin_gen();
5150
json[knowhere::indexparam::HNSW_M] = 128;
5251
json[knowhere::indexparam::EFCONSTRUCTION] = 100;
5352
json[knowhere::indexparam::EF] = 64;
5453
return json;
5554
};
56-
#endif
5755

5856
auto bin_flat_gen = base_bin_gen;
5957

@@ -62,9 +60,7 @@ TEST_CASE("Test Binary Get Vector By Ids", "[Binary GetVectorByIds]") {
6260
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
6361
make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, bin_flat_gen),
6462
make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, bin_ivfflat_gen),
65-
#ifdef KNOWHERE_WITH_CARDINAL
6663
make_tuple(knowhere::IndexEnum::INDEX_HNSW, bin_hnsw_gen),
67-
#endif
6864
}));
6965
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::bin1>(name, version).value();
7066
auto cfg_json = gen().dump();

tests/ut/test_search.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -603,9 +603,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") {
603603
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
604604
make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen),
605605
make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen),
606-
#ifdef KNOWHERE_WITH_CARDINAL
607606
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
608-
#endif
609607
}));
610608
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::bin1>(name, version).value();
611609
auto cfg_json = gen().dump();

thirdparty/faiss/faiss/FaissHook.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,19 @@ sq_sel_quantizer_func_ptr sq_sel_quantizer = sq_select_quantizer_ref;
1717
sq_sel_inv_list_scanner_func_ptr sq_sel_inv_list_scanner =
1818
sq_select_inverted_list_scanner_ref;
1919

20+
// Note: The Hamming computer implementation is selected at compile time
21+
// based on the instruction set in `hamdis-inl.h`, not by runtime hook.
22+
sq_get_distance_computer_func_ptr sq_get_hamming_distance_computer =
23+
sq_get_hamming_distance_computer_ref;
24+
25+
// Note: The Jaccard distance computer uses `__builtin_popcount` for
26+
// computation. This function is efficiently implemented by the
27+
// compiler and automatically utilizes the best available instruction set.
28+
// Therefore, there is no need to manually adjust or hook the Jaccard computer
29+
// for different SIMD instruction sets.
30+
sq_get_distance_computer_func_ptr sq_get_jaccard_distance_computer =
31+
sq_get_jaccard_distance_computer_ref;
32+
2033
void sq_hook() {
2134
// SQ8 always hook best SIMD
2235
#ifdef __x86_64__

thirdparty/faiss/faiss/FaissHook.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ typedef InvertedListScanner* (*sq_sel_inv_list_scanner_func_ptr)(
3333
bool);
3434

3535
extern sq_get_distance_computer_func_ptr sq_get_distance_computer;
36+
extern sq_get_distance_computer_func_ptr sq_get_hamming_distance_computer;
37+
extern sq_get_distance_computer_func_ptr sq_get_jaccard_distance_computer;
3638
extern sq_sel_quantizer_func_ptr sq_sel_quantizer;
3739
extern sq_sel_inv_list_scanner_func_ptr sq_sel_inv_list_scanner;
3840
void sq_hook();

thirdparty/faiss/faiss/IndexScalarQuantizer.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ IndexScalarQuantizer::IndexScalarQuantizer(
3434
is_trained = qtype == ScalarQuantizer::QT_fp16 ||
3535
qtype == ScalarQuantizer::QT_8bit_direct ||
3636
qtype == ScalarQuantizer::QT_bf16 ||
37-
qtype == ScalarQuantizer::QT_8bit_direct_signed;
37+
qtype == ScalarQuantizer::QT_8bit_direct_signed ||
38+
qtype == ScalarQuantizer::QT_1bit_direct;
3839
code_size = sq.code_size;
3940
}
4041

thirdparty/faiss/faiss/MetricType.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ using idx_t = int64_t;
5252
/// this function is used to distinguish between min and max indexes since
5353
/// we need to support similarity and dis-similarity metrics in a flexible way
5454
constexpr bool is_similarity_metric(MetricType metric_type) {
55-
return ((metric_type == METRIC_INNER_PRODUCT) ||
56-
(metric_type == METRIC_Jaccard));
55+
return metric_type == METRIC_INNER_PRODUCT;
5756
}
5857

5958
} // namespace faiss

thirdparty/faiss/faiss/impl/ScalarQuantizer.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ void ScalarQuantizer::set_derived_sizes() {
9696
code_size = d * 2;
9797
bits = 16;
9898
break;
99+
case QT_1bit_direct:
100+
code_size = (d + 7) / 8;
101+
bits = 1;
102+
break;
99103
}
100104
}
101105

@@ -105,6 +109,7 @@ void ScalarQuantizer::train(size_t n, const float* x) {
105109
: qtype == QT_6bit ? 6
106110
: qtype == QT_8bit_uniform ? 8
107111
: qtype == QT_8bit ? 8
112+
: qtype == QT_1bit_direct ? 1
108113
: -1;
109114

110115
switch (qtype) {
@@ -134,6 +139,7 @@ void ScalarQuantizer::train(size_t n, const float* x) {
134139
case QT_8bit_direct:
135140
case QT_bf16:
136141
case QT_8bit_direct_signed:
142+
case QT_1bit_direct:
137143
// no training necessary
138144
break;
139145
}
@@ -164,8 +170,18 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
164170

165171
SQDistanceComputer* ScalarQuantizer::get_distance_computer(
166172
MetricType metric) const {
167-
FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
173+
FAISS_THROW_IF_NOT(
174+
metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT ||
175+
metric == METRIC_Hamming || metric == METRIC_Jaccard);
168176
/* use hook to decide use AVX512 or not */
177+
if (metric == METRIC_Hamming) {
178+
assert(qtype == QT_1bit_direct);
179+
return sq_get_hamming_distance_computer(metric, qtype, d, trained);
180+
}
181+
if (metric == METRIC_Jaccard) {
182+
assert(qtype == QT_1bit_direct);
183+
return sq_get_jaccard_distance_computer(metric, qtype, d, trained);
184+
}
169185
return sq_get_distance_computer(metric, qtype, d, trained);
170186
}
171187

thirdparty/faiss/faiss/impl/ScalarQuantizer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ struct ScalarQuantizer : Quantizer {
3535
QT_bf16,
3636
QT_8bit_direct_signed, ///< fast indexing of signed int8s ranging from
3737
///< [-128 to 127]
38+
QT_1bit_direct, ///< fast indexing of 1 bit per component
3839
};
3940

4041
QuantizerType qtype = QT_8bit;

0 commit comments

Comments
 (0)