Skip to content

Commit f1a22e8

Browse files
authored
use thread_pool to speedup kmeans (#707)
Signed-off-by: LHT129 <[email protected]>
1 parent cce585e commit f1a22e8

File tree

3 files changed

+75
-22
lines changed

3 files changed

+75
-22
lines changed

src/impl/kmeans_cluster.cpp

Lines changed: 67 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
#include "kmeans_cluster.h"
1717

1818
#include <cblas.h>
19+
#include <omp.h>
1920

2021
#include <random>
2122

@@ -24,7 +25,12 @@
2425

2526
namespace vsag {
2627

27-
KMeansCluster::KMeansCluster(int32_t dim, Allocator* allocator) : dim_(dim), allocator_(allocator) {
28+
KMeansCluster::KMeansCluster(int32_t dim, Allocator* allocator, SafeThreadPoolPtr thread_pool)
29+
: dim_(dim), allocator_(allocator), thread_pool_(std::move(thread_pool)) {
30+
if (thread_pool_ == nullptr) {
31+
this->thread_pool_ = SafeThreadPool::FactoryDefaultThreadPool();
32+
// this->thread_pool_->SetPoolSize(10);
33+
}
2834
}
2935

3036
KMeansCluster::~KMeansCluster() {
@@ -62,12 +68,25 @@ KMeansCluster::Run(uint32_t k, const float* datas, uint64_t count, int iter) {
6268

6369
Vector<int> labels(count, -1, this->allocator_);
6470
bool have_empty = false;
71+
std::vector<std::mutex> mutexes(k);
6572
for (int it = 0; it < iter; ++it) {
66-
bool has_converged = true;
73+
std::atomic<bool> has_converged = true;
74+
std::vector<std::future<void>> futures;
6775

68-
for (int64_t i = 0; i < k; ++i) {
69-
y_sqr[i] = FP32ComputeIP(k_centroids_ + i * dim_, k_centroids_ + i * dim_, dim_);
76+
auto compute_ip_func = [&](uint64_t start, uint64_t end) -> void {
77+
for (uint64_t i = start; i < end; ++i) {
78+
y_sqr[i] = FP32ComputeIP(k_centroids_ + i * dim_, k_centroids_ + i * dim_, dim_);
79+
}
80+
};
81+
auto bs = 1024;
82+
for (uint64_t i = 0; i < static_cast<uint64_t>(k); i += bs) {
83+
futures.emplace_back(thread_pool_->GeneralEnqueue(
84+
compute_ip_func, i, std::min(i + bs, static_cast<uint64_t>(k))));
85+
}
86+
for (auto& future : futures) {
87+
future.wait();
7088
}
89+
futures.clear();
7190

7291
cblas_sgemm(CblasColMajor,
7392
CblasTrans,
@@ -84,34 +103,61 @@ KMeansCluster::Run(uint32_t k, const float* datas, uint64_t count, int iter) {
84103
distances,
85104
static_cast<blasint>(k));
86105

87-
for (uint64_t i = 0; i < count; ++i) {
88-
cblas_saxpy(static_cast<blasint>(k), 1.0, y_sqr, 1, distances + i * k, 1);
89-
auto* min_elem = std::min_element(distances + i * k, distances + i * k + k);
90-
auto min_index = std::distance(distances + i * k, min_elem);
91-
if (min_index != labels[i]) {
92-
labels[i] = static_cast<int>(min_index);
93-
has_converged = false;
106+
// Assign labels to each data point
107+
auto assign_labels_func = [&](uint64_t start, uint64_t end) {
108+
omp_set_num_threads(1);
109+
for (uint64_t i = start; i < end; ++i) {
110+
cblas_saxpy(static_cast<blasint>(k), 1.0, y_sqr, 1, distances + i * k, 1);
111+
auto* min_elem = std::min_element(distances + i * k, distances + i * k + k);
112+
auto min_index = std::distance(distances + i * k, min_elem);
113+
if (min_index != labels[i]) {
114+
labels[i] = static_cast<int>(min_index);
115+
has_converged.store(false);
116+
}
94117
}
118+
};
119+
for (uint64_t i = 0; i < count; i += bs) {
120+
futures.emplace_back(
121+
thread_pool_->GeneralEnqueue(assign_labels_func, i, std::min(i + bs, count)));
122+
}
123+
for (auto& future : futures) {
124+
future.wait();
95125
}
126+
futures.clear();
96127

97-
if (has_converged and not have_empty) {
128+
if (has_converged.load() and not have_empty) {
98129
break;
99130
}
100131

101132
// Update centroids
102133
Vector<int> counts(k, 0, allocator_);
103134
Vector<float> new_centroids(static_cast<uint64_t>(k) * dim_, 0.0F, allocator_);
104135
have_empty = false;
105-
for (uint64_t i = 0; i < count; ++i) {
106-
uint32_t label = labels[i];
107-
counts[label]++;
108-
cblas_saxpy(dim_,
109-
1.0F,
110-
datas + i * dim_,
111-
1,
112-
new_centroids.data() + label * static_cast<uint64_t>(dim_),
113-
1);
136+
137+
auto update_centroids_func = [&](uint64_t start, uint64_t end) {
138+
omp_set_num_threads(1);
139+
for (uint64_t i = start; i < end; ++i) {
140+
uint32_t label = labels[i];
141+
{
142+
std::lock_guard<std::mutex> lock(mutexes[label]);
143+
counts[label]++;
144+
cblas_saxpy(dim_,
145+
1.0F,
146+
datas + i * dim_,
147+
1,
148+
new_centroids.data() + label * static_cast<uint64_t>(dim_),
149+
1);
150+
}
151+
}
152+
};
153+
for (uint64_t i = 0; i < count; i += bs) {
154+
futures.emplace_back(
155+
thread_pool_->GeneralEnqueue(update_centroids_func, i, std::min(i + bs, count)));
156+
}
157+
for (auto& future : futures) {
158+
future.wait();
114159
}
160+
futures.clear();
115161

116162
for (int j = 0; j < k; ++j) {
117163
if (counts[j] > 0) {

src/impl/kmeans_cluster.h

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,14 +15,17 @@
1515

1616
#pragma once
1717

18+
#include "safe_thread_pool.h"
1819
#include "typing.h"
1920
#include "vsag/allocator.h"
2021

2122
namespace vsag {
2223

2324
class KMeansCluster {
2425
public:
25-
explicit KMeansCluster(int32_t dim, Allocator* allocator);
26+
explicit KMeansCluster(int32_t dim,
27+
Allocator* allocator,
28+
SafeThreadPoolPtr thread_pool = nullptr);
2629

2730
~KMeansCluster();
2831

@@ -35,6 +38,8 @@ class KMeansCluster {
3538
private:
3639
Allocator* const allocator_{nullptr};
3740

41+
SafeThreadPoolPtr thread_pool_{nullptr};
42+
3843
const int32_t dim_{0};
3944
};
4045

src/safe_thread_pool.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -86,4 +86,6 @@ class SafeThreadPool : public ThreadPool {
8686
bool owner_{false};
8787
};
8888

89+
using SafeThreadPoolPtr = std::shared_ptr<SafeThreadPool>;
90+
8991
} // namespace vsag

0 commit comments

Comments
 (0)