@@ -490,6 +490,15 @@ convert_rows_to_fp32(const void* const __restrict src_in, float* const __restric
490490 }
491491 }
492492 return true ;
493+ } else if (src_data_format == DataFormatEnum::bin1) {
494+ const knowhere::bin1* const src = reinterpret_cast <const knowhere::bin1*>(src_in);
495+ auto uint8_dim = (dim + 7 ) / 8 ;
496+ for (size_t i = 0 ; i < nrows; i++) {
497+ for (size_t j = 0 ; j < uint8_dim; ++j) {
498+ dst[i * dim + j] = (float )(src[offsets[i] * uint8_dim + j]);
499+ }
500+ }
501+ return true ;
493502 } else {
494503 // unknown
495504 return false ;
@@ -524,6 +533,23 @@ convert_rows_to_fp32(const void* const __restrict src_in, float* const __restric
524533 dst[i] = (float )(src[i + start_row * dim]);
525534 }
526535 return true ;
536+ } else if (src_data_format == DataFormatEnum::bin1) {
537+ // NOTE: This is a little bit weird conversion. The source (`src_in`) is a uint8_t byte stream,
538+ // where each query_row has ((dim + 7) / 8) * 8 bits, and the total is nrows * ((dim + 7) / 8) * 8 bits.
539+ // But the final format required is nrows * dim * 32 bits (float).
540+ // There are actually two conversions happening here:
541+ // 1. Each uint8_t value must be converted to float (in `BinarySQDistanceComputerWrapper::set_query`
542+ // and `ScalarQuantizer::compute_codes`), it will be converted back to uint8_t). [same as int8]
543+ // 2. Each row must occupy dim * 32 bits of space, even if not all bits are filled;
544+ // this is required by the convention set in `ScalarQuantizer::compute_codes`.
545+ const knowhere::bin1* const src = reinterpret_cast <const knowhere::bin1*>(src_in);
546+ auto uint8_dim = (dim + 7 ) / 8 ;
547+ for (size_t i = 0 ; i < nrows; i++) {
548+ for (size_t j = 0 ; j < uint8_dim; j++) {
549+ dst[i * dim + j] = (float )(src[(start_row + i) * uint8_dim + j]);
550+ }
551+ }
552+ return true ;
527553 } else {
528554 // unknown
529555 return false ;
@@ -561,6 +587,16 @@ convert_rows_from_fp32(const float* const __restrict src, void* const __restrict
561587 dst[i + start_row * dim] = (knowhere::int8)src[i];
562588 }
563589 return true ;
590+ } else if (dst_data_format == DataFormatEnum::bin1) {
591+ knowhere::bin1* const dst = reinterpret_cast <knowhere::bin1*>(dst_in);
592+ auto uint8_dim = (dim + 7 ) / 8 ;
593+ for (size_t i = 0 ; i < nrows * uint8_dim; i++) {
594+ KNOWHERE_THROW_IF_NOT_MSG (src[i] >= std::numeric_limits<knowhere::bin1>::min () &&
595+ src[i] <= std::numeric_limits<knowhere::bin1>::max (),
596+ " convert float to bin1(uint8_t) overflow" );
597+ dst[i + start_row * uint8_dim] = (knowhere::bin1)src[i];
598+ }
599+ return true ;
564600 } else {
565601 // unknown
566602 return false ;
@@ -578,6 +614,8 @@ convert_ds_to_float(const DataSetPtr& src, DataFormatEnum data_format) {
578614 return ConvertFromDataTypeIfNeeded<knowhere::bf16 >(src);
579615 } else if (data_format == DataFormatEnum::int8) {
580616 return ConvertFromDataTypeIfNeeded<knowhere::int8>(src);
617+ } else if (data_format == DataFormatEnum::bin1) {
618+ return ConvertFromDataTypeIfNeeded<knowhere::bin1>(src);
581619 }
582620 return nullptr ;
583621}
@@ -675,6 +713,8 @@ get_index_data_format(const faiss::Index* index) {
675713 return DataFormatEnum::fp16;
676714 } else if (index_sq->sq .qtype == faiss::ScalarQuantizer::QT_8bit_direct_signed) {
677715 return DataFormatEnum::int8;
716+ } else if (index_sq->sq .qtype == faiss::ScalarQuantizer::QT_1bit_direct) {
717+ return DataFormatEnum::bin1;
678718 } else {
679719 return std::nullopt ;
680720 }
@@ -1182,6 +1222,24 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
11821222 }
11831223 }
11841224 return GenResultDataSet (rows, dim, std::move (data));
1225+ } else if (data_format == DataFormatEnum::bin1) {
1226+ auto uint8_dim = (dim + 7 ) / 8 ;
1227+ auto data = std::make_unique<knowhere::bin1[]>(uint8_dim * rows);
1228+ // faiss produces fp32 data format, we need some other format.
1229+ // Let's create a temporary fp32 buffer for this.
1230+ auto tmp = std::make_unique<float []>(uint8_dim);
1231+ for (int64_t i = 0 ; i < rows; i++) {
1232+ const int64_t id = ids[i];
1233+ assert (id >= 0 && id < Count ());
1234+ if (!get_vector (id, tmp.get ())) {
1235+ return expected<DataSetPtr>::Err (Status::invalid_index_error,
1236+ " index inner error, cannot proceed with GetVectorByIds" );
1237+ }
1238+ if (!convert_rows_from_fp32 (tmp.get (), data.get (), data_format, i, 1 , dim)) {
1239+ return expected<DataSetPtr>::Err (Status::invalid_args, " Unsupported data format" );
1240+ }
1241+ }
1242+ return GenResultDataSet (rows, dim, std::move (data));
11851243 } else {
11861244 return expected<DataSetPtr>::Err (Status::invalid_args, " Unsupported data format" );
11871245 }
@@ -1785,7 +1843,8 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
17851843 bool
17861844 is_ann_iterator_supported () const {
17871845 if (data_format != DataFormatEnum::fp32 && data_format != DataFormatEnum::fp16 &&
1788- data_format != DataFormatEnum::bf16 && data_format != DataFormatEnum::int8) {
1846+ data_format != DataFormatEnum::bf16 && data_format != DataFormatEnum::int8 &&
1847+ data_format != DataFormatEnum::bin1) {
17891848 return false ;
17901849 }
17911850 return true ;
@@ -1848,6 +1907,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
18481907 case DataFormatEnum::fp16:
18491908 case DataFormatEnum::bf16 :
18501909 case DataFormatEnum::int8:
1910+ case DataFormatEnum::bin1:
18511911 convert_rows_to_fp32 (data, cur_query.get (), data_format, i, 1 , dim);
18521912 break ;
18531913 default :
@@ -1925,42 +1985,55 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode {
19251985
19261986 // create an index
19271987 const bool is_cosine = IsMetricType (hnsw_cfg.metric_type .value (), metric::COSINE);
1988+ const bool is_binary = data_format == DataFormatEnum::bin1;
19281989
19291990 std::unique_ptr<faiss::IndexHNSW> hnsw_index;
19301991 auto train_index = [&](const float * data, const int i, const int64_t rows) {
1931- if (is_cosine) {
1932- if (data_format == DataFormatEnum::fp32) {
1933- hnsw_index = std::make_unique<faiss::IndexHNSWFlatCosine>(dim, hnsw_cfg.M .value ());
1934- } else if (data_format == DataFormatEnum::fp16) {
1935- hnsw_index = std::make_unique<faiss::IndexHNSWSQCosine>(dim, faiss::ScalarQuantizer::QT_fp16,
1936- hnsw_cfg.M .value ());
1937- } else if (data_format == DataFormatEnum::bf16 ) {
1938- hnsw_index = std::make_unique<faiss::IndexHNSWSQCosine>(dim, faiss::ScalarQuantizer::QT_bf16,
1939- hnsw_cfg.M .value ());
1940- } else if (data_format == DataFormatEnum::int8) {
1941- hnsw_index = std::make_unique<faiss::IndexHNSWSQCosine>(
1942- dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, hnsw_cfg.M .value ());
1992+ if (is_binary) {
1993+ if (metric.value () == faiss::MetricType::METRIC_Hamming ||
1994+ metric.value () == faiss::MetricType::METRIC_Jaccard) {
1995+ hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(dim, faiss::ScalarQuantizer::QT_1bit_direct,
1996+ hnsw_cfg.M .value (), metric.value ());
19431997 } else {
1944- LOG_KNOWHERE_ERROR_ << " Unsupported metric type : " << hnsw_cfg.metric_type .value ();
1998+ LOG_KNOWHERE_ERROR_ << " Unsupported metric for binary data : " << hnsw_cfg.metric_type .value ();
19451999 return Status::invalid_metric_type;
19462000 }
19472001 } else {
1948- if (data_format == DataFormatEnum::fp32) {
1949- hnsw_index = std::make_unique<faiss::IndexHNSWFlat>(dim, hnsw_cfg.M .value (), metric.value ());
1950- } else if (data_format == DataFormatEnum::fp16) {
1951- hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(dim, faiss::ScalarQuantizer::QT_fp16,
1952- hnsw_cfg.M .value (), metric.value ());
1953- } else if (data_format == DataFormatEnum::bf16 ) {
1954- hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(dim, faiss::ScalarQuantizer::QT_bf16,
1955- hnsw_cfg.M .value (), metric.value ());
1956- } else if (data_format == DataFormatEnum::int8) {
1957- hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(
1958- dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, hnsw_cfg.M .value (), metric.value ());
2002+ if (is_cosine) {
2003+ if (data_format == DataFormatEnum::fp32) {
2004+ hnsw_index = std::make_unique<faiss::IndexHNSWFlatCosine>(dim, hnsw_cfg.M .value ());
2005+ } else if (data_format == DataFormatEnum::fp16) {
2006+ hnsw_index = std::make_unique<faiss::IndexHNSWSQCosine>(dim, faiss::ScalarQuantizer::QT_fp16,
2007+ hnsw_cfg.M .value ());
2008+ } else if (data_format == DataFormatEnum::bf16 ) {
2009+ hnsw_index = std::make_unique<faiss::IndexHNSWSQCosine>(dim, faiss::ScalarQuantizer::QT_bf16,
2010+ hnsw_cfg.M .value ());
2011+ } else if (data_format == DataFormatEnum::int8) {
2012+ hnsw_index = std::make_unique<faiss::IndexHNSWSQCosine>(
2013+ dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, hnsw_cfg.M .value ());
2014+ } else {
2015+ LOG_KNOWHERE_ERROR_ << " Unsupported metric type: " << hnsw_cfg.metric_type .value ();
2016+ return Status::invalid_metric_type;
2017+ }
19592018 } else {
1960- LOG_KNOWHERE_ERROR_ << " Unsupported metric type: " << hnsw_cfg.metric_type .value ();
1961- return Status::invalid_metric_type;
2019+ if (data_format == DataFormatEnum::fp32) {
2020+ hnsw_index = std::make_unique<faiss::IndexHNSWFlat>(dim, hnsw_cfg.M .value (), metric.value ());
2021+ } else if (data_format == DataFormatEnum::fp16) {
2022+ hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(dim, faiss::ScalarQuantizer::QT_fp16,
2023+ hnsw_cfg.M .value (), metric.value ());
2024+ } else if (data_format == DataFormatEnum::bf16 ) {
2025+ hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(dim, faiss::ScalarQuantizer::QT_bf16,
2026+ hnsw_cfg.M .value (), metric.value ());
2027+ } else if (data_format == DataFormatEnum::int8) {
2028+ hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(
2029+ dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, hnsw_cfg.M .value (), metric.value ());
2030+ } else {
2031+ LOG_KNOWHERE_ERROR_ << " Unsupported metric type: " << hnsw_cfg.metric_type .value ();
2032+ return Status::invalid_metric_type;
2033+ }
19622034 }
19632035 }
2036+
19642037 hnsw_index->hnsw .efConstruction = hnsw_cfg.efConstruction .value ();
19652038 // train
19662039 LOG_KNOWHERE_INFO_ << " Training HNSW Index" ;
@@ -2947,6 +3020,8 @@ KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW, BaseFaissRegularIndexHNSWF
29473020 knowhere::feature::MMAP | knowhere::feature::MV)
29483021KNOWHERE_SIMPLE_REGISTER_DENSE_INT_GLOBAL(HNSW, BaseFaissRegularIndexHNSWFlatNodeTemplate,
29493022 knowhere::feature::MMAP | knowhere::feature::MV)
3023+ KNOWHERE_SIMPLE_REGISTER_DENSE_BIN_GLOBAL(HNSW, BaseFaissRegularIndexHNSWFlatNodeTemplate,
3024+ knowhere::feature::MMAP | knowhere::feature::MV)
29503025#endif
29513026
29523027KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL (HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate,
0 commit comments