16
16
#include " kmeans_cluster.h"
17
17
18
18
#include < cblas.h>
19
+ #include < omp.h>
19
20
20
21
#include < random>
21
22
24
25
25
26
namespace vsag {
26
27
27
- KMeansCluster::KMeansCluster (int32_t dim, Allocator* allocator) : dim_(dim), allocator_(allocator) {
28
+ KMeansCluster::KMeansCluster (int32_t dim, Allocator* allocator, SafeThreadPoolPtr thread_pool)
29
+ : dim_(dim), allocator_(allocator), thread_pool_(std::move(thread_pool)) {
30
+ if (thread_pool_ == nullptr ) {
31
+ this ->thread_pool_ = SafeThreadPool::FactoryDefaultThreadPool ();
32
+ // this->thread_pool_->SetPoolSize(10);
33
+ }
28
34
}
29
35
30
36
KMeansCluster::~KMeansCluster () {
@@ -62,12 +68,25 @@ KMeansCluster::Run(uint32_t k, const float* datas, uint64_t count, int iter) {
62
68
63
69
Vector<int > labels (count, -1 , this ->allocator_ );
64
70
bool have_empty = false ;
71
+ std::vector<std::mutex> mutexes (k);
65
72
for (int it = 0 ; it < iter; ++it) {
66
- bool has_converged = true ;
73
+ std::atomic<bool > has_converged = true ;
74
+ std::vector<std::future<void >> futures;
67
75
68
- for (int64_t i = 0 ; i < k; ++i) {
69
- y_sqr[i] = FP32ComputeIP (k_centroids_ + i * dim_, k_centroids_ + i * dim_, dim_);
76
+ auto compute_ip_func = [&](uint64_t start, uint64_t end) -> void {
77
+ for (uint64_t i = start; i < end; ++i) {
78
+ y_sqr[i] = FP32ComputeIP (k_centroids_ + i * dim_, k_centroids_ + i * dim_, dim_);
79
+ }
80
+ };
81
+ auto bs = 1024 ;
82
+ for (uint64_t i = 0 ; i < static_cast <uint64_t >(k); i += bs) {
83
+ futures.emplace_back (thread_pool_->GeneralEnqueue (
84
+ compute_ip_func, i, std::min (i + bs, static_cast <uint64_t >(k))));
85
+ }
86
+ for (auto & future : futures) {
87
+ future.wait ();
70
88
}
89
+ futures.clear ();
71
90
72
91
cblas_sgemm (CblasColMajor,
73
92
CblasTrans,
@@ -84,34 +103,61 @@ KMeansCluster::Run(uint32_t k, const float* datas, uint64_t count, int iter) {
84
103
distances,
85
104
static_cast <blasint>(k));
86
105
87
- for (uint64_t i = 0 ; i < count; ++i) {
88
- cblas_saxpy (static_cast <blasint>(k), 1.0 , y_sqr, 1 , distances + i * k, 1 );
89
- auto * min_elem = std::min_element (distances + i * k, distances + i * k + k);
90
- auto min_index = std::distance (distances + i * k, min_elem);
91
- if (min_index != labels[i]) {
92
- labels[i] = static_cast <int >(min_index);
93
- has_converged = false ;
106
+ // Assign labels to each data point
107
+ auto assign_labels_func = [&](uint64_t start, uint64_t end) {
108
+ omp_set_num_threads (1 );
109
+ for (uint64_t i = start; i < end; ++i) {
110
+ cblas_saxpy (static_cast <blasint>(k), 1.0 , y_sqr, 1 , distances + i * k, 1 );
111
+ auto * min_elem = std::min_element (distances + i * k, distances + i * k + k);
112
+ auto min_index = std::distance (distances + i * k, min_elem);
113
+ if (min_index != labels[i]) {
114
+ labels[i] = static_cast <int >(min_index);
115
+ has_converged.store (false );
116
+ }
94
117
}
118
+ };
119
+ for (uint64_t i = 0 ; i < count; i += bs) {
120
+ futures.emplace_back (
121
+ thread_pool_->GeneralEnqueue (assign_labels_func, i, std::min (i + bs, count)));
122
+ }
123
+ for (auto & future : futures) {
124
+ future.wait ();
95
125
}
126
+ futures.clear ();
96
127
97
- if (has_converged and not have_empty) {
128
+ if (has_converged. load () and not have_empty) {
98
129
break ;
99
130
}
100
131
101
132
// Update centroids
102
133
Vector<int > counts (k, 0 , allocator_);
103
134
Vector<float > new_centroids (static_cast <uint64_t >(k) * dim_, 0 .0F , allocator_);
104
135
have_empty = false ;
105
- for (uint64_t i = 0 ; i < count; ++i) {
106
- uint32_t label = labels[i];
107
- counts[label]++;
108
- cblas_saxpy (dim_,
109
- 1 .0F ,
110
- datas + i * dim_,
111
- 1 ,
112
- new_centroids.data () + label * static_cast <uint64_t >(dim_),
113
- 1 );
136
+
137
+ auto update_centroids_func = [&](uint64_t start, uint64_t end) {
138
+ omp_set_num_threads (1 );
139
+ for (uint64_t i = start; i < end; ++i) {
140
+ uint32_t label = labels[i];
141
+ {
142
+ std::lock_guard<std::mutex> lock (mutexes[label]);
143
+ counts[label]++;
144
+ cblas_saxpy (dim_,
145
+ 1 .0F ,
146
+ datas + i * dim_,
147
+ 1 ,
148
+ new_centroids.data () + label * static_cast <uint64_t >(dim_),
149
+ 1 );
150
+ }
151
+ }
152
+ };
153
+ for (uint64_t i = 0 ; i < count; i += bs) {
154
+ futures.emplace_back (
155
+ thread_pool_->GeneralEnqueue (update_centroids_func, i, std::min (i + bs, count)));
156
+ }
157
+ for (auto & future : futures) {
158
+ future.wait ();
114
159
}
160
+ futures.clear ();
115
161
116
162
for (int j = 0 ; j < k; ++j) {
117
163
if (counts[j] > 0 ) {
0 commit comments