Skip to content

Commit fa362d8

Browse files
synchronize with the facebookresearch/faiss#4926
Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
1 parent 9d905f2 commit fa362d8

7 files changed

Lines changed: 990 additions & 113 deletions

File tree

thirdparty/faiss/faiss/impl/fast_scan/accumulate_loops.h

Lines changed: 13 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -59,24 +59,12 @@ void accumulate_fixed_blocks(
5959
constexpr int bbs = 32 * BB;
6060
for (size_t j0 = 0; j0 < nb; j0 += bbs, codes += block_stride) {
6161
res.set_block_origin(0, j0);
62-
// skip computing distances if all vectors inside a block are filtered out
6362
if constexpr (has_sel_member_v<ResultHandler>) {
64-
if (res.sel != nullptr) {
65-
bool skip_flag = true;
66-
for (size_t jj = 0;
67-
jj < std::min<size_t>(bbs, res.ntotal - j0);
68-
jj++) {
69-
auto real_idx = res.adjust_id(0, jj);
70-
if (res.sel->is_member(real_idx)) {
71-
skip_flag = false;
72-
break;
73-
}
74-
}
75-
if (skip_flag) {
76-
continue;
77-
}
78-
}
63+
if (whether_all_vectors_filtered_out(
64+
res, std::min<size_t>(bbs, res.ntotal - j0)))
65+
continue;
7966
}
67+
8068
FixedStorageHandler<NQ, 2 * BB> res2;
8169
kernel_accumulate_block<NQ, BB>(nsq, codes, LUT, res2, scaler);
8270
res2.to_other_handler(res);
@@ -142,24 +130,12 @@ void accumulate_q_4step_256(
142130

143131
for (size_t j0 = 0; j0 < ntotal2; j0 += 32, codes += block_stride) {
144132
res.set_block_origin(0, j0);
145-
// skip computing distances if all vectors inside a block are filtered out
146133
if constexpr (has_sel_member_v<ResultHandler>) {
147-
if (res.sel != nullptr) {
148-
bool skip_flag = true;
149-
for (size_t jj = 0;
150-
jj < std::min<size_t>(32, ntotal2 - j0);
151-
jj++) {
152-
auto real_idx = res.adjust_id(0, jj);
153-
if (res.sel->is_member(real_idx)) {
154-
skip_flag = false;
155-
break;
156-
}
157-
}
158-
if (skip_flag) {
159-
continue;
160-
}
161-
}
134+
if (whether_all_vectors_filtered_out(
135+
res, std::min<size_t>(32, res.ntotal - j0)))
136+
continue;
162137
}
138+
163139
FixedStorageHandler<SQ, 2> res2;
164140
const uint8_t* LUT = LUT0;
165141
pq4_kernel_qbs_256<Q1>(nsq, codes, LUT, res2, scaler);
@@ -230,25 +206,13 @@ void pq4_accumulate_loop_qbs_fixed_scaler_256(
230206

231207
// Default: qbs not known at compile time
232208
for (size_t j0 = 0; j0 < ntotal2; j0 += 32, codes += block_stride) {
233-
// skip computing distances if all vectors inside a block are filtered out
209+
res.set_block_origin(0, j0);
234210
if constexpr (has_sel_member_v<ResultHandler>) {
235-
if (res.sel != nullptr) {
236-
res.set_block_origin(0, j0);
237-
bool skip_flag = true;
238-
for (size_t jj = 0;
239-
jj < std::min<size_t>(32, ntotal2 - j0);
240-
jj++) {
241-
auto real_idx = res.adjust_id(0, jj);
242-
if (res.sel->is_member(real_idx)) {
243-
skip_flag = false;
244-
break;
245-
}
246-
}
247-
if (skip_flag) {
248-
continue;
249-
}
250-
}
211+
if (whether_all_vectors_filtered_out(
212+
res, std::min<size_t>(32, res.ntotal - j0)))
213+
continue;
251214
}
215+
252216
const uint8_t* LUT = LUT0;
253217
int qi = qbs;
254218
int i0 = 0;

thirdparty/faiss/faiss/impl/fast_scan/decompose_qbs.h

Lines changed: 13 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -54,24 +54,12 @@ void accumulate_q_4step(
5454

5555
for (size_t j0 = 0; j0 < ntotal2; j0 += 32, codes += block_stride) {
5656
res.set_block_origin(0, j0);
57-
// skip computing distances if all vectors inside a block are filtered out
5857
if constexpr (has_sel_member_v<ResultHandler>) {
59-
if (res.sel != nullptr) {
60-
bool skip_flag = true;
61-
for (size_t jj = 0;
62-
jj < std::min<size_t>(32, ntotal2 - j0);
63-
jj++) {
64-
auto real_idx = res.adjust_id(0, jj);
65-
if (res.sel->is_member(real_idx)) {
66-
skip_flag = false;
67-
break;
68-
}
69-
}
70-
if (skip_flag) {
71-
continue;
72-
}
73-
}
58+
if (whether_all_vectors_filtered_out(
59+
res, std::min<size_t>(32, res.ntotal - j0)))
60+
continue;
7461
}
62+
7563
FixedStorageHandler<SQ, 2> res2;
7664
const uint8_t* LUT = LUT0;
7765
kernel_accumulate_block<Q1>(nsq, codes, LUT, res2, scaler);
@@ -105,24 +93,12 @@ void kernel_accumulate_block_loop(
10593
size_t block_stride) {
10694
for (size_t j0 = 0; j0 < ntotal2; j0 += 32, codes += block_stride) {
10795
res.set_block_origin(0, j0);
108-
// skip computing distances if all vectors inside a block are filtered out
10996
if constexpr (has_sel_member_v<ResultHandler>) {
110-
if (res.sel != nullptr) {
111-
bool skip_flag = true;
112-
for (size_t jj = 0;
113-
jj < std::min<size_t>(32, ntotal2 - j0);
114-
jj++) {
115-
auto real_idx = res.adjust_id(0, jj);
116-
if (res.sel->is_member(real_idx)) {
117-
skip_flag = false;
118-
break;
119-
}
120-
}
121-
if (skip_flag) {
122-
continue;
123-
}
124-
}
97+
if (whether_all_vectors_filtered_out(
98+
res, std::min<size_t>(32, res.ntotal - j0)))
99+
continue;
125100
}
101+
126102
kernel_accumulate_block<NQ, ResultHandler>(
127103
nsq, codes, LUT, res, scaler);
128104
}
@@ -208,25 +184,13 @@ void pq4_accumulate_loop_qbs_fixed_scaler(
208184

209185
// default implementation where qbs is not known at compile time
210186
for (size_t j0 = 0; j0 < ntotal2; j0 += 32, codes += block_stride) {
211-
// skip computing distances if all vectors inside a block are filtered out
187+
res.set_block_origin(0, j0);
212188
if constexpr (has_sel_member_v<ResultHandler>) {
213-
if (res.sel != nullptr) {
214-
res.set_block_origin(0, j0);
215-
bool skip_flag = true;
216-
for (size_t jj = 0;
217-
jj < std::min<size_t>(32, ntotal2 - j0);
218-
jj++) {
219-
auto real_idx = res.adjust_id(0, jj);
220-
if (res.sel->is_member(real_idx)) {
221-
skip_flag = false;
222-
break;
223-
}
224-
}
225-
if (skip_flag) {
226-
continue;
227-
}
228-
}
189+
if (whether_all_vectors_filtered_out(
190+
res, std::min<size_t>(32, res.ntotal - j0)))
191+
continue;
229192
}
193+
230194
const uint8_t* LUT = LUT0;
231195
int qi = qbs;
232196
int i0 = 0;

thirdparty/faiss/faiss/impl/fast_scan/simd_result_handlers.h

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,27 @@ struct has_sel_member<T, std::void_t<decltype(T::sel)>> : std::true_type {};
3737
template <typename T>
3838
inline constexpr bool has_sel_member_v = has_sel_member<T>::value;
3939

40+
/// Check if all vectors in a block are filtered out by the IDSelector.
41+
/// Returns true if the block should be skipped (all vectors filtered).
42+
/// Requires set_block_origin() to have been called before this.
43+
/// Compiles to nothing (returns false) when ResultHandler has no sel member.
///
/// @tparam ResultHandler handler type; when it exposes a `sel` member, it
///         must also provide `adjust_id(size_t, size_t)` and `sel` must
///         support `is_member()`.
/// @param res         result handler whose `sel` (if any) is consulted.
/// @param block_size  number of leading positions of the current block to
///                    test; positions `[0, block_size)` are probed in order.
/// @return true only if a non-null selector is installed and it rejects
///         every probed position; false if there is no `sel` member, if
///         `sel` is null, or as soon as one position is accepted.
///
/// Note: with a non-null selector and `block_size == 0` the loop body never
/// runs, so the (empty) block is reported as fully filtered.
///
/// NOTE(review): the call sites shown in this change compute `block_size`
/// as `std::min<size_t>(bbs, res.ntotal - j0)` in unsigned arithmetic; if
/// `j0` can exceed `res.ntotal` the subtraction wraps and the full block
/// width is probed -- confirm against the callers' loop bounds.
44+
template <class ResultHandler>
45+
inline bool whether_all_vectors_filtered_out(
46+
ResultHandler& res,
47+
size_t block_size) {
48+
// Resolved at compile time: handlers without a `sel` member skip the scan.
if constexpr (has_sel_member_v<ResultHandler>) {
49+
// A null selector means no filtering is requested.
if (res.sel != nullptr) {
50+
for (size_t jj = 0; jj < block_size; jj++) {
51+
// presumably adjust_id(0, jj) yields the id of position jj in the
// current block (relative to the origin set by set_block_origin) --
// TODO confirm against the handler definition.
if (res.sel->is_member(res.adjust_id(0, jj))) {
52+
// One accepted vector is enough: the block must be processed.
return false;
53+
}
54+
}
55+
// Every probed position was rejected by the selector.
return true;
56+
}
57+
}
58+
// No selector available or installed: never skip the block.
return false;
59+
}
60+
4061
} // namespace
4162

4263
struct SIMDResultHandler {
@@ -346,7 +367,7 @@ struct SingleResultHandler : ResultHandlerCompare<C, with_id_map, SL> {
346367
idis[q] = d;
347368
ids[q] = real_idx;
348369

349-
this->in_range_num += 1;
370+
this->in_range_num++;
350371
}
351372
}
352373
}
@@ -355,13 +376,13 @@ struct SingleResultHandler : ResultHandlerCompare<C, with_id_map, SL> {
355376
// find first non-zero
356377
int j = __builtin_ctz(lt_mask);
357378
lt_mask -= 1 << j;
379+
this->scan_cnt++;
358380
T d = d32tab[j];
359381
if (C::cmp(idis[q], d)) {
360-
this->scan_cnt++;
361382
idis[q] = d;
362383
ids[q] = this->adjust_id(b, j);
363384

364-
this->in_range_num += 1;
385+
this->in_range_num++;
365386
}
366387
}
367388
}
@@ -470,7 +491,7 @@ struct HeapHandler : ResultHandlerCompare<C, with_id_map, SL> {
470491
k, heap_dis, heap_ids, dis_for_j, real_idx);
471492
nup++;
472493

473-
this->in_range_num += 1;
494+
this->in_range_num++;
474495
}
475496
}
476497
}
@@ -479,14 +500,14 @@ struct HeapHandler : ResultHandlerCompare<C, with_id_map, SL> {
479500
// find first non-zero
480501
int j = __builtin_ctz(lt_mask);
481502
lt_mask -= 1 << j;
503+
this->scan_cnt++;
482504
T dis_for_j = d32tab[j];
483505
if (C::cmp(heap_dis[0], dis_for_j)) {
484-
this->scan_cnt++;
485506
int64_t idx = this->adjust_id(b, j);
486507
heap_replace_top<C>(k, heap_dis, heap_ids, dis_for_j, idx);
487508
nup++;
488509

489-
this->in_range_num += 1;
510+
this->in_range_num++;
490511
}
491512
}
492513
}
@@ -596,7 +617,7 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map, SL> {
596617
T dis_for_j = d32tab[j];
597618
res.add(dis_for_j, real_idx);
598619

599-
this->in_range_num += 1;
620+
this->in_range_num++;
600621
}
601622
}
602623
} else {
@@ -608,7 +629,7 @@ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map, SL> {
608629
this->scan_cnt++;
609630
res.add(dis_for_j, this->adjust_id(b, j));
610631

611-
this->in_range_num += 1;
632+
this->in_range_num++;
612633
}
613634
}
614635
}
@@ -730,7 +751,7 @@ struct RangeHandler : ResultHandlerCompare<C, with_id_map, SL> {
730751
n_per_query[q]++;
731752
triplets.push_back({idx_t(q + q0), real_idx, dis});
732753

733-
this->in_range_num += 1;
754+
this->in_range_num++;
734755
}
735756
}
736757
} else {
@@ -742,7 +763,7 @@ struct RangeHandler : ResultHandlerCompare<C, with_id_map, SL> {
742763
n_per_query[q]++;
743764
triplets.push_back({idx_t(q + q0), this->adjust_id(b, j), dis});
744765

745-
this->in_range_num += 1;
766+
this->in_range_num++;
746767
}
747768
}
748769
}
@@ -898,7 +919,7 @@ struct SingleQueryResultCollectHandler
898919
if (this->sel->is_member(real_idx)) {
899920
T dis = d32tab[j];
900921
collect.emplace_back(real_idx, dis);
901-
this->in_range_num += 1;
922+
this->in_range_num++;
902923
}
903924
}
904925
} else {
@@ -908,7 +929,7 @@ struct SingleQueryResultCollectHandler
908929
T dis = d32tab[j];
909930
int64_t idx = this->adjust_id(b, j);
910931
collect.emplace_back(idx, dis);
911-
this->in_range_num += 1;
932+
this->in_range_num++;
912933
}
913934
}
914935
}
@@ -921,6 +942,7 @@ struct SingleQueryResultCollectHandler
921942
collect[i].second = collect[i].second * one_a + b;
922943
}
923944
}
945+
this->scan_cnt = 0;
924946
}
925947
};
926948

@@ -946,9 +968,10 @@ void with_SIMDResultHandler(SIMDResultHandler& res, Lambda&& lambda) {
946968
lambda(*resh);
947969
} else if (auto resh = dynamic_cast<ReservoirHandler<C, W>*>(&res)) {
948970
lambda(*resh);
949-
} else if (auto resh =
950-
dynamic_cast<SingleQueryResultCollectHandler<C, W>*>(
951-
&res)) {
971+
} else if (
972+
auto resh =
973+
dynamic_cast<SingleQueryResultCollectHandler<C, W>*>(
974+
&res)) {
952975
lambda(*resh);
953976
} else { // generic handler -- will not be inlined
954977
FAISS_THROW_IF_NOT_FMT(

thirdparty/faiss/tests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ set(FAISS_TEST_SRC
4141
test_scalar_quantizer.cpp
4242
test_factory_tools.cpp
4343
test_custom_result_handler.cpp
44+
test_fastscan_filter.cpp
45+
test_single_query_collect_handler.cpp
4446
# These tests work in both static and DD modes (uniform SIMDConfig API)
4547
test_distances_simd.cpp
4648
test_simd_levels.cpp

0 commit comments

Comments
 (0)