Skip to content

Commit fb88db8

Browse files
authored
add param use_knowhere_build_pool to bypass threadpoll for growing index (zilliztech#1126)
Signed-off-by: xianliang.li <xianliang.li@zilliz.com>
1 parent b823dc9 commit fb88db8

22 files changed

+136
-89
lines changed

include/knowhere/comp/thread_pool.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,4 +330,33 @@ WaitAllSuccess(std::vector<folly::Future<T>>& futures) {
330330
return Status::success;
331331
}
332332

333+
// This class is used to wrap the thread pool and the inline executor
334+
// If use_pool is true, the function will be pushed to the thread pool
335+
// If use_pool is false, the function will be executed directly
336+
class ThreadPoolWrapper {
337+
public:
338+
ThreadPoolWrapper(const std::shared_ptr<ThreadPool>& pool, bool use_pool = true)
339+
: pool_(pool), use_pool_(use_pool) {
340+
}
341+
342+
template <typename Func, typename... Args>
343+
auto
344+
push(Func&& func, Args&&... args) {
345+
if (use_pool_) {
346+
return pool_->push(std::forward<Func>(func), std::forward<Args>(args)...);
347+
} else {
348+
// If the pool is not used, the function will be executed within the current thread directly
349+
return folly::makeSemiFuture()
350+
.via(&folly::InlineExecutor::instance())
351+
.then([func = std::forward<Func>(func), &args...](auto&&) mutable {
352+
return func(std::forward<Args>(args)...);
353+
});
354+
}
355+
}
356+
357+
private:
358+
std::shared_ptr<ThreadPool> pool_;
359+
bool use_pool_;
360+
};
361+
333362
} // namespace knowhere

include/knowhere/index/index.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,22 +140,22 @@ class Index {
140140
}
141141

142142
Status
143-
Build(const DataSetPtr dataset, const Json& json);
143+
Build(const DataSetPtr dataset, const Json& json, bool use_knowhere_build_pool = true);
144144

145145
#ifdef KNOWHERE_WITH_CARDINAL
146146
const std::shared_ptr<Interrupt>
147147
BuildAsync(const DataSetPtr dataset, const Json& json,
148148
const std::chrono::seconds timeout = std::chrono::seconds::max());
149149
#else
150150
const std::shared_ptr<Interrupt>
151-
BuildAsync(const DataSetPtr dataset, const Json& json);
151+
BuildAsync(const DataSetPtr dataset, const Json& json, bool use_knowhere_build_pool = true);
152152
#endif
153153

154154
Status
155-
Train(const DataSetPtr dataset, const Json& json);
155+
Train(const DataSetPtr dataset, const Json& json, bool use_knowhere_build_pool = true);
156156

157157
Status
158-
Add(const DataSetPtr dataset, const Json& json);
158+
Add(const DataSetPtr dataset, const Json& json, bool use_knowhere_build_pool = true);
159159

160160
expected<DataSetPtr>
161161
Search(const DataSetPtr dataset, const Json& json, const BitsetView& bitset) const;

include/knowhere/index/index_node.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ class IndexNode : public Object {
6868
* shared, so `shared_ptr` is used instead.
6969
*/
7070
virtual Status
71-
Build(const DataSetPtr dataset, std::shared_ptr<Config> cfg) {
72-
RETURN_IF_ERROR(Train(dataset, cfg));
73-
return Add(dataset, std::move(cfg));
71+
Build(const DataSetPtr dataset, std::shared_ptr<Config> cfg, bool use_knowhere_build_pool) {
72+
RETURN_IF_ERROR(Train(dataset, cfg, use_knowhere_build_pool));
73+
return Add(dataset, std::move(cfg), use_knowhere_build_pool);
7474
}
7575

7676
/*
@@ -98,7 +98,7 @@ class IndexNode : public Object {
9898
* shared, so `shared_ptr` is used instead.
9999
*/
100100
virtual Status
101-
Train(const DataSetPtr dataset, std::shared_ptr<Config> cfg) = 0;
101+
Train(const DataSetPtr dataset, std::shared_ptr<Config> cfg, bool use_knowhere_build_pool) = 0;
102102

103103
/**
104104
* @brief Adds data to the trained index.
@@ -118,7 +118,7 @@ class IndexNode : public Object {
118118
* shared, so `shared_ptr` is used instead.
119119
*/
120120
virtual Status
121-
Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg) = 0;
121+
Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg, bool use_knowhere_build_pool) = 0;
122122

123123
/**
124124
* @brief Performs a search operation on the index.

include/knowhere/index/index_node_data_mock_wrapper.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@ class IndexNodeDataMockWrapper : public IndexNode {
2727
}
2828

2929
Status
30-
Build(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override;
30+
Build(const DataSetPtr dataset, std::shared_ptr<Config> cfg, bool use_knowhere_build_pool) override;
3131

3232
Status
33-
Train(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override;
33+
Train(const DataSetPtr dataset, std::shared_ptr<Config> cfg, bool use_knowhere_build_pool) override;
3434

3535
Status
36-
Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override;
36+
Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg, bool use_knowhere_build_pool) override;
3737

3838
expected<DataSetPtr>
3939
Search(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const override;

include/knowhere/index/index_node_thread_pool_wrapper.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ class IndexNodeThreadPoolWrapper : public IndexNode {
2424
IndexNodeThreadPoolWrapper(std::unique_ptr<IndexNode> index_node, std::shared_ptr<ThreadPool> thread_pool);
2525

2626
Status
27-
Train(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override {
28-
return index_node_->Train(dataset, std::move(cfg));
27+
Train(const DataSetPtr dataset, std::shared_ptr<Config> cfg, bool use_knowhere_build_pool) override {
28+
return index_node_->Train(dataset, std::move(cfg), use_knowhere_build_pool);
2929
}
3030

3131
Status
32-
Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override {
33-
return index_node_->Add(dataset, std::move(cfg));
32+
Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg, bool use_knowhere_build_pool) override {
33+
return index_node_->Add(dataset, std::move(cfg), use_knowhere_build_pool);
3434
}
3535

3636
expected<DataSetPtr>

src/index/data_view_dense_index/data_view_dense_index.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ class DataViewIndexBase {
7878
virtual ~DataViewIndexBase(){};
7979

8080
virtual void
81-
Train(idx_t n, const void* __restrict x) = 0;
81+
Train(idx_t n, const void* __restrict x, bool use_knowhere_build_pool) = 0;
8282

8383
virtual void
84-
Add(idx_t n, const void* __restrict x, const float* __restrict norms_) = 0;
84+
Add(idx_t n, const void* __restrict x, const float* __restrict norms_, bool use_knowhere_build_pool) = 0;
8585

8686
virtual void
8787
Search(const idx_t n, const void* __restrict x, const idx_t k, float* __restrict distances,
@@ -174,10 +174,11 @@ class DataViewIndexFlat : public DataViewIndexBase {
174174
this->ntotal_.store(0);
175175
}
176176
void
177-
Train(idx_t n, const void* x) override {
177+
Train(idx_t n, const void* x, bool use_knowhere_build_pool) override {
178178
if (quant_data_ != nullptr) {
179-
auto build_pool = ThreadPool::GetGlobalBuildThreadPool();
180-
auto task = build_pool
179+
auto build_pool_wrapper =
180+
std::make_shared<ThreadPoolWrapper>(ThreadPool::GetGlobalBuildThreadPool(), use_knowhere_build_pool);
181+
auto task = build_pool_wrapper
181182
->push([&] {
182183
std::unique_ptr<ThreadPool::ScopedBuildOmpSetter> setter;
183184
if (build_thread_num_.has_value()) {
@@ -197,10 +198,11 @@ class DataViewIndexFlat : public DataViewIndexBase {
197198
}
198199

199200
void
200-
Add(idx_t n, const void* x, const float* __restrict in_norms) override {
201+
Add(idx_t n, const void* x, const float* __restrict in_norms, bool use_knowhere_build_pool) override {
201202
if (quant_data_ != nullptr) {
202-
auto build_pool = ThreadPool::GetGlobalBuildThreadPool();
203-
auto task = build_pool
203+
auto build_pool_wrapper =
204+
std::make_shared<ThreadPoolWrapper>(ThreadPool::GetGlobalBuildThreadPool(), use_knowhere_build_pool);
205+
auto task = build_pool_wrapper
204206
->push([&] {
205207
std::unique_ptr<ThreadPool::ScopedBuildOmpSetter> setter;
206208
if (build_thread_num_.has_value()) {

src/index/data_view_dense_index/index_node_with_data_view_refiner.h

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ class IndexNodeWithDataViewRefiner : public IndexNode {
4343
}
4444

4545
Status
46-
Train(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override;
46+
Train(const DataSetPtr dataset, std::shared_ptr<Config> cfg, bool use_knowhere_build_pool) override;
4747

4848
Status
49-
Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override;
49+
Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg, bool use_knowhere_build_pool) override;
5050

5151
expected<DataSetPtr>
5252
Search(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const override;
@@ -306,7 +306,8 @@ ConvertToBaseIndexFp32DataSet(const DataSetPtr& src, bool is_cosine = false,
306306

307307
template <typename DataType, typename BaseIndexNode>
308308
Status
309-
IndexNodeWithDataViewRefiner<DataType, BaseIndexNode>::Train(const DataSetPtr dataset, std::shared_ptr<Config> cfg) {
309+
IndexNodeWithDataViewRefiner<DataType, BaseIndexNode>::Train(const DataSetPtr dataset, std::shared_ptr<Config> cfg,
310+
bool use_knowhere_build_pool) {
310311
BaseConfig& base_cfg = static_cast<BaseConfig&>(*cfg);
311312
this->is_cosine_ = IsMetricType(base_cfg.metric_type.value(), knowhere::metric::COSINE);
312313
auto dim = dataset->GetDim();
@@ -326,19 +327,20 @@ IndexNodeWithDataViewRefiner<DataType, BaseIndexNode>::Train(const DataSetPtr da
326327
refine_offset_index_ = std::make_unique<DataViewIndexFlat>(
327328
dim, datatype_v<DataType>, refine_metric, this->view_data_op_, is_cosine_, refine_type, build_thread_num);
328329
try {
329-
refine_offset_index_->Train(train_rows, data);
330+
refine_offset_index_->Train(train_rows, data, use_knowhere_build_pool);
330331
} catch (const std::exception& e) {
331332
LOG_KNOWHERE_WARNING_ << "data view index inner error: " << e.what();
332333
return Status::internal_error;
333334
}
334-
return base_index_->Train(
335-
fp32_train_ds,
336-
cfg); // train not need base_index_lock_, all add and search will fail if train not called before
335+
return base_index_->Train(fp32_train_ds, cfg,
336+
use_knowhere_build_pool); // train not need base_index_lock_, all add and search will
337+
// fail if train not called before
337338
}
338339

339340
template <typename DataType, typename BaseIndexNode>
340341
Status
341-
IndexNodeWithDataViewRefiner<DataType, BaseIndexNode>::Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg) {
342+
IndexNodeWithDataViewRefiner<DataType, BaseIndexNode>::Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg,
343+
bool use_knowhere_build_pool) {
342344
auto rows = dataset->GetRows();
343345
auto dim = dataset->GetDim();
344346
auto data = (const DataType*)dataset->GetTensor();
@@ -349,14 +351,14 @@ IndexNodeWithDataViewRefiner<DataType, BaseIndexNode>::Add(const DataSetPtr data
349351
auto [fp32_base_ds, norms] =
350352
ConvertToBaseIndexFp32DataSet<DataType>(dataset, is_cosine_, blk_i, blk_size, base_index_->Dim());
351353
try {
352-
refine_offset_index_->Add(blk_size, data + blk_i * dim, norms.data());
354+
refine_offset_index_->Add(blk_size, data + blk_i * dim, norms.data(), use_knowhere_build_pool);
353355
} catch (const std::exception& e) {
354356
LOG_KNOWHERE_WARNING_ << "data view index inner error: " << e.what();
355357
return Status::internal_error;
356358
}
357359
{
358360
FairWriteLockGuard guard(*this->base_index_lock_);
359-
add_stat = base_index_->Add(fp32_base_ds, cfg);
361+
add_stat = base_index_->Add(fp32_base_ds, cfg, use_knowhere_build_pool);
360362
}
361363

362364
if (add_stat != Status::success) {

src/index/diskann/diskann.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,15 @@ class DiskANNIndexNode : public IndexNode {
4646
}
4747

4848
Status
49-
Build(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override;
49+
Build(const DataSetPtr dataset, std::shared_ptr<Config> cfg, bool use_knowhere_build_pool) override;
5050

5151
Status
52-
Train(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override {
52+
Train(const DataSetPtr dataset, std::shared_ptr<Config> cfg, bool use_knowhere_build_pool) override {
5353
return Status::not_implemented;
5454
}
5555

5656
Status
57-
Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override {
57+
Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg, bool use_knowhere_build_pool) override {
5858
return Status::not_implemented;
5959
}
6060

@@ -303,7 +303,7 @@ CheckMetric(const std::string& diskann_metric) {
303303

304304
template <typename DataType>
305305
Status
306-
DiskANNIndexNode<DataType>::Build(const DataSetPtr dataset, std::shared_ptr<Config> cfg) {
306+
DiskANNIndexNode<DataType>::Build(const DataSetPtr dataset, std::shared_ptr<Config> cfg, bool use_knowhere_build_pool) {
307307
assert(file_manager_ != nullptr);
308308
std::lock_guard<std::mutex> lock(preparation_lock_);
309309
auto build_conf = static_cast<const DiskANNConfig&>(*cfg);

src/index/flat/flat.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class FlatIndexNode : public IndexNode {
4040
}
4141

4242
Status
43-
Train(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override {
43+
Train(const DataSetPtr dataset, std::shared_ptr<Config> cfg, bool use_knowhere_build_pool) override {
4444
const FlatConfig& f_cfg = static_cast<const FlatConfig&>(*cfg);
4545

4646
auto metric = Str2FaissMetricType(f_cfg.metric_type.value());
@@ -59,7 +59,7 @@ class FlatIndexNode : public IndexNode {
5959
}
6060

6161
Status
62-
Add(const DataSetPtr dataset, std::shared_ptr<Config>) override {
62+
Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg, bool use_knowhere_build_pool) override {
6363
auto x = dataset->GetTensor();
6464
auto n = dataset->GetRows();
6565
index_->add(n, (const DataType*)x);

src/index/gpu/flat_gpu/flat_gpu.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class GpuFlatIndexNode : public IndexNode {
2828
}
2929

3030
Status
31-
Train(const DataSetPtr dataset, const Config& cfg) override {
31+
Train(const DataSetPtr dataset, const Config& cfg, bool use_knowhere_build_pool) override {
3232
const GpuFlatConfig& f_cfg = static_cast<const GpuFlatConfig&>(cfg);
3333
auto metric = Str2FaissMetricType(f_cfg.metric_type);
3434
if (!metric.has_value()) {
@@ -40,7 +40,7 @@ class GpuFlatIndexNode : public IndexNode {
4040
}
4141

4242
Status
43-
Add(const DataSetPtr dataset, const Config& cfg) override {
43+
Add(const DataSetPtr dataset, const Config& cfg, bool use_knowhere_build_pool) override {
4444
const void* x = dataset->GetTensor();
4545
const int64_t n = dataset->GetRows();
4646
try {

0 commit comments

Comments
 (0)