Skip to content

Commit 24b7861

Browse files
committed
Add random kmeans init class.
1 parent 7ab8bf6 commit 24b7861

File tree

3 files changed

+83
-11
lines changed

3 files changed

+83
-11
lines changed

src/gpu/kmeans/kmeans_init.cu

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
#include <thrust/device_vector.h>
77

8-
#include <cub/device/device_select.cuh>
98
#include <cub/device/device_histogram.cuh>
109

1110
#include <random>
@@ -250,10 +249,28 @@ KmMatrix<T> GreedyRecluster<T>::recluster(KmMatrix<T>& _centroids, size_t _k) {
250249
251250
} // namespace detail
252251
252+
253+
/* ============ KmeansRandomInit Class member functions ============ */
254+
255+
template <typename T>
256+
KmMatrix<T> KmeansRandomInit<T>::operator()(KmMatrix<T>& _data, size_t _k) {
257+
258+
KmMatrix<T> dist = generator_impl_->generate(_k);
259+
MulOp<T>().mul(dist, dist, (T)_data.rows() - 1);
260+
261+
dist = dist.reshape(dist.cols(), dist.rows());
262+
263+
KmMatrix<T> centroids = _data.rows(dist);
264+
265+
return centroids;
266+
}
253267
254268
255269
/* ============== KmeansLlInit Class member functions ============== */
256270
271+
// Although the paper suggested calculating the probability independently,
272+
// but due to zero distance between a point and itself, the already selected
273+
// points will have very low probability.
257274
template <typename T, template <class> class ReclusterPolicy >
258275
KmMatrix<T> KmeansLlInit<T, ReclusterPolicy>::probability(
259276
KmMatrix<T>& _data, KmMatrix<T>& _centroids) {
@@ -360,17 +377,18 @@ KmeansLlInit<T, ReclusterPolicy>::operator()(KmMatrix<T>& _data, size_t _k) {
360377
361378
T cost = SumOp<T>().sum(prob);
362379
363-
size_t max_iter = std::max(T(MAX_ITER), std::log(cost));
380+
size_t max_iter = std::max((size_t)(MAX_ITER),
381+
(size_t)std::ceil(std::log(cost)));
364382
for (size_t i = 0; i < max_iter; ++i) {
365383
prob = probability(_data, centroids);
366384
KmMatrix<T> new_centroids = sample_centroids(_data, prob);
367385
centroids = stack(centroids, new_centroids, KmMatrixDim::ROW);
368386
}
369387
370388
if (centroids.rows() < _k) {
371-
// FIXME: When n_centroids < k
372-
// Get random selection in?
373-
M_ERROR("Not implemented.");
389+
KmMatrix<T> new_centroids = KmeansRandomInit<T>(generator_)(_data,
390+
_k - centroids.rows());
391+
centroids = stack(centroids, new_centroids, KmMatrixDim::ROW);
374392
}
375393
376394
centroids = ReclusterPolicy<T>::recluster(centroids, k_);
@@ -385,9 +403,12 @@ KmeansLlInit<T, ReclusterPolicy>::operator()(KmMatrix<T>& _data, size_t _k) {
385403
KmMatrix<T>& data, KmMatrix<T>& centroids); \
386404
template KmMatrix<T> KmeansLlInit<T>::sample_centroids( \
387405
KmMatrix<T>& data, KmMatrix<T>& centroids); \
406+
template KmMatrix<T> KmeansRandomInit<T>::operator()( \
407+
KmMatrix<T>& _data, size_t _k);
388408
389409
INSTANTIATE(float)
390410
INSTANTIATE(double)
411+
INSTANTIATE(int)
391412
392413
#undef INSTANTIATE
393414
@@ -403,15 +424,14 @@ namespace detail {
403424
KmMatrix<T>& _centroids_dot, \
404425
KmMatrix<T>& _distance_pairs); \
405426
template KmMatrix<T> PairWiseDistanceOp<T>::operator()( \
406-
KmMatrix<T>& _data, \
407-
KmMatrix<T>& _centroids); \
427+
KmMatrix<T>& _data, KmMatrix<T>& _centroids); \
408428
409429
INSTANTIATE(float)
410430
INSTANTIATE(double)
431+
INSTANTIATE(int)
411432
412433
#undef INSTANTIATE
413434
}
414-
// FIXME: int is not supported due to random kernel
415435
416436
} // namespace Kmeans
417437
} // namespace H2O4GPU

src/gpu/kmeans/kmeans_init.cuh

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,38 @@ class KmeansInitBase {
7373
virtual KmMatrix<T> operator()(KmMatrix<T>& data, size_t k) = 0;
7474
};
7575

76+
/*
77+
* Random initialization.
78+
* @tparam Numeric data type.
79+
*/
80+
template <typename T>
81+
class KmeansRandomInit : public KmeansInitBase<T> {
82+
private:
83+
int seed_;
84+
std::unique_ptr<GeneratorBase<T>> generator_impl_;
85+
86+
public:
87+
/*
88+
* @param seed Random seed for generating centroids.
89+
*/
90+
KmeansRandomInit(size_t _seed) :
91+
seed_(_seed), generator_impl_ (new UniformGenerator<T>) {}
92+
93+
/*
94+
* @param gen Unique pointer to Random generator for generating centroids.
95+
*/
96+
KmeansRandomInit(std::unique_ptr<GeneratorBase<T>>& _gen) :
97+
generator_impl_(std::move(_gen)) {}
98+
99+
virtual ~KmeansRandomInit() override {}
100+
101+
/*
102+
* @param data Data points stored in row major matrix.
103+
* @param k Number of centroids.
104+
*/
105+
virtual KmMatrix<T> operator()(KmMatrix<T>& data, size_t k) override;
106+
};
107+
76108
/*
77109
* Each instance of KmeansLlInit corresponds to one dataset, if a new data set
78110
* is used, users need to create a new instance.
@@ -82,7 +114,7 @@ class KmeansInitBase {
82114
* Scalable K-Means++
83115
* </a>
84116
*
85-
* @tparam Data type, supported types are float and double.
117+
* @tparam Numeric data type.
86118
*/
87119
template <
88120
typename T,
@@ -107,8 +139,8 @@ struct KmeansLlInit : public KmeansInitBase<T> {
107139

108140
KmMatrix<T> probability(KmMatrix<T>& data, KmMatrix<T>& centroids);
109141
public:
110-
// sample_centroids/recluster should not be part of the interface, but
111-
// following error is generated when put in private section:
142+
// sample_centroids should not be part of the interface, but following error
143+
// is generated when put in private section:
112144
// The enclosing parent function ("sample_centroids") for an extended
113145
// __device__ lambda cannot have private or protected access within its class
114146
KmMatrix<T> sample_centroids(KmMatrix<T>& data, KmMatrix<T>& centroids);

tests/cpp/gpu/kmeans/test_kmeans_init.cu

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,26 @@ struct GeneratorMock : GeneratorBase<T> {
3636
}
3737
};
3838

39+
TEST(KmeansRandom, Init) {
40+
thrust::host_vector<float> h_data (20);
41+
for (size_t i = 0; i < 20; ++i) {
42+
h_data[i] = i * 2;
43+
}
44+
KmMatrix<float> data (h_data, 4, 5);
45+
std::unique_ptr<GeneratorBase<float>> gen (new GeneratorMock<float>());
46+
KmeansRandomInit<float> init (gen);
47+
48+
auto res = init(data, 2);
49+
50+
std::vector<float> h_sol =
51+
{
52+
30, 32, 34, 36, 38,
53+
0, 2, 4, 6, 8
54+
};
55+
KmMatrix<float> sol (h_sol, 2, 5);
56+
ASSERT_TRUE(sol == res);
57+
}
58+
3959
// r --gtest_filter=KmeansLL.PairWiseDistance
4060
TEST(KmeansLL, PairWiseDistance) {
4161

0 commit comments

Comments
 (0)