Skip to content

Commit 2956cc8

Browse files
authored
enhance: scann support search parameters 'ensure_topk_full' (zilliztech#1072)
Signed-off-by: cqy123456 <qianya.cheng@zilliz.com>
1 parent 368e43b commit 2956cc8

File tree

8 files changed

+131
-5
lines changed

8 files changed

+131
-5
lines changed

src/index/data_view_dense_index/data_view_index_config.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ AdaptToBaseIndexConfig(Config* cfg, PARAM_TYPE param_type, size_t dim) {
7474
auto reorder_k = int(base_cfg->k.value() * base_cfg->refine_ratio.value());
7575
base_cfg->k = reorder_k;
7676
base_cfg->reorder_k = reorder_k;
77+
base_cfg->ensure_topk_full = true;
7778
break;
7879
}
7980
case PARAM_TYPE::RANGE_SEARCH: {

src/index/ivf/ivf.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,21 @@ IvfIndexNode<DataType, IndexType>::Search(const DataSetPtr dataset, std::unique_
775775
faiss::IVFSearchParameters base_search_params;
776776
base_search_params.sel = id_selector;
777777
base_search_params.nprobe = nprobe;
778+
base_search_params.ensure_topk_full = ivf_cfg.ensure_topk_full.value();
779+
if (base_search_params.ensure_topk_full) {
780+
if (auto base_index_ptr = reinterpret_cast<faiss::IndexIVFPQFastScan*>(index_->base_index)) {
781+
auto nlist = base_index_ptr->nlist;
782+
base_search_params.nprobe = nlist;
783+
// use max_codes for early termination
784+
base_search_params.max_codes = (nprobe * 1.0 / nlist) * (index_->ntotal - bitset.count());
785+
base_search_params.max_lists_num = nprobe;
786+
} else {
787+
throw std::runtime_error("invalid base index type of scann base index");
788+
}
789+
} else {
790+
base_search_params.nprobe = nprobe;
791+
base_search_params.max_codes = 0;
792+
}
778793

779794
faiss::IndexScaNNSearchParameters scann_search_params;
780795
scann_search_params.base_index_params = &base_search_params;

tests/ut/test_data_view_index.cc

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,76 @@ TEST_CASE("Test SCANN with data view refiner", "[float metrics]") {
128128
}
129129
}
130130

131+
TEST_CASE("Ensure topk test", "[float metrics]") {
132+
using Catch::Approx;
133+
auto version = GenTestVersionList();
134+
if (!faiss::support_pq_fast_scan) {
135+
SKIP("pass scann test");
136+
}
137+
138+
const int64_t nb = 10000, nq = 10;
139+
auto metric = GENERATE(as<std::string>{}, knowhere::metric::COSINE, knowhere::metric::IP, knowhere::metric::L2);
140+
auto topk = nb;
141+
auto dim = GENERATE(as<int64_t>{}, 120);
142+
143+
auto base_gen = [=]() {
144+
knowhere::Json json;
145+
json[knowhere::meta::DIM] = dim;
146+
json[knowhere::meta::METRIC_TYPE] = metric;
147+
json[knowhere::meta::TOPK] = topk;
148+
json[knowhere::meta::RADIUS] = knowhere::IsMetricType(metric, knowhere::metric::L2) ? 10.0 : 0.99;
149+
json[knowhere::meta::RANGE_FILTER] = knowhere::IsMetricType(metric, knowhere::metric::L2) ? 0.0 : 1.01;
150+
return json;
151+
};
152+
153+
auto scann_gen = [base_gen, topk]() {
154+
knowhere::Json json = base_gen();
155+
json[knowhere::indexparam::NLIST] = 512;
156+
json[knowhere::indexparam::NPROBE] = 1;
157+
json[knowhere::indexparam::REFINE_RATIO] = 1.0;
158+
json[knowhere::indexparam::SUB_DIM] = 2;
159+
json[knowhere::indexparam::WITH_RAW_DATA] = true;
160+
json[knowhere::indexparam::ENSURE_TOPK_FULL] = true;
161+
return json;
162+
};
163+
164+
auto rand = GENERATE(1);
165+
const auto train_ds = GenDataSet(nb, dim, rand);
166+
const auto query_ds = GenDataSet(nq, dim, rand + 777);
167+
168+
const knowhere::Json conf = {
169+
{knowhere::meta::METRIC_TYPE, metric},
170+
{knowhere::meta::TOPK, topk},
171+
};
172+
knowhere::ViewDataOp data_view = [&train_ds, data_size = sizeof(float) * dim](size_t id) {
173+
auto data = train_ds->GetTensor();
174+
return data + data_size * id;
175+
};
176+
auto data_view_pack = knowhere::Pack(data_view);
177+
auto cfg_json = scann_gen().dump();
178+
knowhere::Json json = knowhere::Json::parse(cfg_json);
179+
180+
auto scann_with_dv_refiner =
181+
knowhere::IndexFactory::Instance()
182+
.Create<knowhere::fp32>(knowhere::IndexEnum::INDEX_FAISS_SCANN_DVR, version, data_view_pack)
183+
.value();
184+
185+
REQUIRE(scann_with_dv_refiner.Type() == knowhere::IndexEnum::INDEX_FAISS_SCANN_DVR);
186+
REQUIRE(scann_with_dv_refiner.Build(train_ds, json) == knowhere::Status::success);
187+
REQUIRE(scann_with_dv_refiner.Count() == nb);
188+
REQUIRE(scann_with_dv_refiner.Size() > 0);
189+
REQUIRE(scann_with_dv_refiner.HasRawData(metric) == false);
190+
REQUIRE(scann_with_dv_refiner.HasRawData(metric) ==
191+
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(knowhere::IndexEnum::INDEX_FAISS_SCANN_DVR, version,
192+
cfg_json));
193+
auto scann_with_dv_refiner_results = scann_with_dv_refiner.Search(query_ds, json, nullptr);
194+
auto res_ids = scann_with_dv_refiner_results.value()->GetIds();
195+
// check that every result slot is filled (no -1 ids) even with topk = nb and nprobe = 1
196+
for (auto i = 0; i < nq * topk; i++) {
197+
REQUIRE(res_ids[i] != -1);
198+
}
199+
}
200+
131201
template <typename DataType>
132202
void
133203
BaseTest(const knowhere::DataSetPtr train_ds, const knowhere::DataSetPtr query_ds, const int64_t k,

tests/ut/test_iterator.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
192192
json[knowhere::indexparam::NPROBE] = 14;
193193
json[knowhere::indexparam::REORDER_K] = 200;
194194
json[knowhere::indexparam::WITH_RAW_DATA] = true;
195+
json[knowhere::indexparam::ENSURE_TOPK_FULL] = false;
195196
return json;
196197
};
197198

thirdparty/faiss/faiss/IndexIVF.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ struct SearchParametersIVF : SearchParameters {
7777
///< to minimize code change, when users only use nprobe to search, this config does not take effect since we will first retrieve the nearest nprobe buckets
7878
///< it is a bit heavy to further retrieve more buckets
7979
///< therefore to make sure we get topk results, use nprobe=nlist and use max_codes to narrow down the search range
80+
size_t max_lists_num = 0; ///< select min{scanned number of (max_codes),
81+
///< scanned number of (max_lists_num)} to return.
8082
bool ensure_topk_full = false;
8183

8284
///< during IVF range search, if reach 'max_empty_result_buckets' num of

thirdparty/faiss/faiss/IndexIVFFastScan.cpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,6 @@ void IndexIVFFastScan::search_preassigned(
363363
IndexIVFStats* stats) const {
364364
size_t nprobe = this->nprobe;
365365
if (params) {
366-
FAISS_THROW_IF_NOT(params->max_codes == 0);
367366
nprobe = params->nprobe;
368367
}
369368

@@ -591,7 +590,7 @@ void IndexIVFFastScan::search_dispatch_implem(
591590
int impl = implem;
592591

593592
if (impl == 0) {
594-
if (bbs == 32) {
593+
if (bbs == 32 && !params->ensure_topk_full) {
595594
impl = 12;
596595
} else {
597596
impl = 10;
@@ -671,7 +670,7 @@ void IndexIVFFastScan::search_dispatch_implem(
671670
)
672671
);
673672
search_implem_10(
674-
n, x, *handler.get(), cq,
673+
n, x, k, *handler.get(), cq,
675674
&ndis, &nlist_visited, scaler, params);
676675
}
677676
// clang-format on
@@ -704,7 +703,7 @@ void IndexIVFFastScan::search_dispatch_implem(
704703
cq_i, &ndis, &nlist_visited, scaler, params);
705704
} else {
706705
search_implem_10(
707-
i1 - i0, x + i0 * d, *handler.get(),
706+
i1 - i0, x + i0 * d,k, *handler.get(),
708707
cq_i, &ndis, &nlist_visited, scaler, params);
709708
}
710709
// clang-format on
@@ -1021,12 +1020,27 @@ void IndexIVFFastScan::search_implem_2(
10211020
void IndexIVFFastScan::search_implem_10(
10221021
idx_t n,
10231022
const float* x,
1023+
idx_t k,
10241024
SIMDResultHandlerToFloat& handler,
10251025
const CoarseQuantized& cq,
10261026
size_t* ndis_out,
10271027
size_t* nlist_out,
10281028
const NormTableScaler* scaler,
10291029
const IVFSearchParameters* params) const {
1030+
// const size_t nprobe = params ? params->nprobe : this->nprobe;
1031+
const bool ensure_topk_full = params ? params->ensure_topk_full : false;
1032+
size_t max_codes = params ? params->max_codes : this->max_codes;
1033+
size_t max_lists_num = params ? params->max_lists_num : nlist;
1034+
FAISS_THROW_IF_NOT_MSG(
1035+
n == 1 || !ensure_topk_full,
1036+
"ensure_topk_full can't be true if queries number larger than 1.");
1037+
if (max_codes == 0) {
1038+
max_codes = std::numeric_limits<idx_t>::max();
1039+
}
1040+
if (max_lists_num == 0) {
1041+
max_lists_num = nlist;
1042+
}
1043+
10301044
size_t dim12 = ksub * M2;
10311045
AlignedTable<uint8_t> dis_tables;
10321046
AlignedTable<uint16_t> biases;
@@ -1051,6 +1065,11 @@ void IndexIVFFastScan::search_implem_10(
10511065
LUT = dis_tables.get() + i * dim12;
10521066
}
10531067
for (size_t j = 0; j < nprobe; j++) {
1068+
auto nscan = handler.count_scanned_rows();
1069+
if ((nscan >= max_codes || j >= max_lists_num) &&
1070+
(!ensure_topk_full || nscan >= (size_t)k)) {
1071+
break;
1072+
}
10541073
size_t ij = i * nprobe + j;
10551074
if (!single_LUT) {
10561075
LUT = dis_tables.get() + ij * dim12;

thirdparty/faiss/faiss/IndexIVFFastScan.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ struct IndexIVFFastScan : IndexIVF {
240240
void search_implem_10(
241241
idx_t n,
242242
const float* x,
243+
idx_t k,
243244
SIMDResultHandlerToFloat& handler,
244245
const CoarseQuantized& cq,
245246
size_t* ndis_out,

thirdparty/faiss/faiss/impl/simd_result_handlers.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,25 @@ struct SIMDResultHandlerToFloat : SIMDResultHandler {
6464
nullptr; // table of biases to add to each query (for IVF L2 search)
6565
const float* normalizers = nullptr; // size 2 * nq, to convert
6666

67+
size_t scan_cnt = 0; // scanned vector number (except filtered)
68+
6769
SIMDResultHandlerToFloat(size_t nq, size_t ntotal) : nq(nq), ntotal(ntotal) {}
6870

6971
virtual void begin(const float* norms) {
7072
normalizers = norms;
73+
scan_cnt = 0;
7174
}
7275

7376
// called at end of search to convert int16 distances to float, before
7477
// normalizers are deallocated
7578
virtual void end() {
7679
normalizers = nullptr;
80+
scan_cnt = 0;
81+
}
82+
83+
// Get the number of scanned vectors
84+
size_t count_scanned_rows() {
85+
return scan_cnt;
7786
}
7887
};
7988

@@ -293,6 +302,7 @@ struct SingleResultHandler : ResultHandlerCompare<C, with_id_map> {
293302
auto real_idx = this->adjust_id(b, j);
294303
lt_mask -= 1 << j;
295304
if (this->sel->is_member(real_idx)) {
305+
this->scan_cnt++;
296306
T d = d32tab[j];
297307
if (C::cmp(idis[q], d)) {
298308
idis[q] = d;
@@ -310,6 +320,7 @@ struct SingleResultHandler : ResultHandlerCompare<C, with_id_map> {
310320
lt_mask -= 1 << j;
311321
T d = d32tab[j];
312322
if (C::cmp(idis[q], d)) {
323+
this->scan_cnt++;
313324
idis[q] = d;
314325
ids[q] = this->adjust_id(b, j);
315326

@@ -329,6 +340,7 @@ struct SingleResultHandler : ResultHandlerCompare<C, with_id_map> {
329340
dis[q] = b + idis[q] * one_a;
330341
}
331342
}
343+
this->scan_cnt = 0;
332344
}
333345
};
334346

@@ -388,6 +400,7 @@ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
388400
auto real_idx = this->adjust_id(b, j);
389401
lt_mask -= 1 << j;
390402
if (this->sel->is_member(real_idx)) {
403+
this->scan_cnt++;
391404
T dis = d32tab[j];
392405
if (C::cmp(heap_dis[0], dis)) {
393406
heap_replace_top<C>(k, heap_dis, heap_ids, dis, real_idx);
@@ -404,6 +417,7 @@ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
404417
lt_mask -= 1 << j;
405418
T dis = d32tab[j];
406419
if (C::cmp(heap_dis[0], dis)) {
420+
this->scan_cnt++;
407421
int64_t idx = this->adjust_id(b, j);
408422
heap_replace_top<C>(k, heap_dis, heap_ids, dis, idx);
409423

@@ -431,6 +445,7 @@ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
431445
heap_ids[j] = heap_ids_in[j];
432446
}
433447
}
448+
this->scan_cnt = 0;
434449
}
435450
};
436451

@@ -500,7 +515,6 @@ struct SingleQueryResultCollectHandler : ResultHandlerCompare<C, with_id_map> {
500515
int64_t idx = this->adjust_id(b, j);
501516
collect.emplace_back(idx, dis);
502517
this->in_range_num += 1;
503-
504518
}
505519
}
506520
}
@@ -582,6 +596,7 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
582596
auto real_idx = this->adjust_id(b, j);
583597
lt_mask -= 1 << j;
584598
if (this->sel->is_member(real_idx)) {
599+
this->scan_cnt++;
585600
T dis = d32tab[j];
586601
res.add(dis, real_idx);
587602

@@ -595,6 +610,7 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
595610
int j = __builtin_ctz(lt_mask);
596611
lt_mask -= 1 << j;
597612
T dis = d32tab[j];
613+
this->scan_cnt++;
598614
res.add(dis, this->adjust_id(b, j));
599615

600616
this->in_range_num += 1;
@@ -639,6 +655,7 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
639655
// possibly add empty results
640656
heap_heapify<Cf>(n - res.i, heap_dis + res.i, heap_ids + res.i);
641657
}
658+
this->scan_cnt = 0;
642659
}
643660
};
644661

0 commit comments

Comments
 (0)