5
5
6
6
#include < thrust/device_vector.h>
7
7
8
- #include < cub/device/device_select.cuh>
9
8
#include < cub/device/device_histogram.cuh>
10
9
11
10
#include < random>
@@ -250,10 +249,28 @@ KmMatrix<T> GreedyRecluster<T>::recluster(KmMatrix<T>& _centroids, size_t _k) {
250
249
251
250
} // namespace detail
252
251
252
+
253
+ /* ============ KmeansRandomInit Class member functions ============ */
254
+
255
+ template <typename T>
256
+ KmMatrix<T> KmeansRandomInit<T>::operator ()(KmMatrix<T>& _data, size_t _k) {
257
+
258
+ KmMatrix<T> dist = generator_impl_->generate (_k);
259
+ MulOp<T>().mul (dist, dist, (T)_data.rows () - 1 );
260
+
261
+ dist = dist.reshape (dist.cols (), dist.rows ());
262
+
263
+ KmMatrix<T> centroids = _data.rows (dist);
264
+
265
+ return centroids;
266
+ }
253
267
254
268
255
269
/* ============== KmeansLlInit Class member functions ============== */
256
270
271
+ // Although the paper suggested calculating the probability independently,
272
+ // but due to zero distance between a point and itself, the already selected
273
+ // points will have very low probability.
257
274
template <typename T, template <class > class ReclusterPolicy >
258
275
KmMatrix<T> KmeansLlInit<T, ReclusterPolicy>::probability(
259
276
KmMatrix<T>& _data, KmMatrix<T>& _centroids) {
@@ -360,17 +377,18 @@ KmeansLlInit<T, ReclusterPolicy>::operator()(KmMatrix<T>& _data, size_t _k) {
360
377
361
378
T cost = SumOp<T>().sum (prob);
362
379
363
- size_t max_iter = std::max (T (MAX_ITER), std::log (cost));
380
+ size_t max_iter = std::max ((size_t )(MAX_ITER),
381
+ (size_t )std::ceil (std::log (cost)));
364
382
for (size_t i = 0 ; i < max_iter; ++i) {
365
383
prob = probability (_data, centroids);
366
384
KmMatrix<T> new_centroids = sample_centroids (_data, prob);
367
385
centroids = stack (centroids, new_centroids, KmMatrixDim::ROW);
368
386
}
369
387
370
388
if (centroids.rows () < _k) {
371
- // FIXME: When n_centroids < k
372
- // Get random selection in?
373
- M_ERROR ( " Not implemented. " );
389
+ KmMatrix<T> new_centroids = KmeansRandomInit<T>(generator_)(_data,
390
+ _k - centroids. rows ());
391
+ centroids = stack (centroids, new_centroids, KmMatrixDim::ROW );
374
392
}
375
393
376
394
centroids = ReclusterPolicy<T>::recluster (centroids, k_);
@@ -385,9 +403,12 @@ KmeansLlInit<T, ReclusterPolicy>::operator()(KmMatrix<T>& _data, size_t _k) {
385
403
KmMatrix<T>& data, KmMatrix<T>& centroids); \
386
404
template KmMatrix<T> KmeansLlInit<T>::sample_centroids( \
387
405
KmMatrix<T>& data, KmMatrix<T>& centroids); \
406
+ template KmMatrix<T> KmeansRandomInit<T>::operator ()( \
407
+ KmMatrix<T>& _data, size_t _k);
388
408
389
409
INSTANTIATE (float )
390
410
INSTANTIATE(double )
411
+ INSTANTIATE(int )
391
412
392
413
#undef INSTANTIATE
393
414
@@ -403,15 +424,14 @@ namespace detail {
403
424
KmMatrix<T>& _centroids_dot, \
404
425
KmMatrix<T>& _distance_pairs); \
405
426
template KmMatrix<T> PairWiseDistanceOp<T>::operator ()( \
406
- KmMatrix<T>& _data, \
407
- KmMatrix<T>& _centroids); \
427
+ KmMatrix<T>& _data, KmMatrix<T>& _centroids); \
408
428
409
429
INSTANTIATE (float )
410
430
INSTANTIATE(double )
431
+ INSTANTIATE(int )
411
432
412
433
#undef INSTANTIATE
413
434
}
414
- // FIXME: int is not supported due to random kernel
415
435
416
436
} // namespace Kmeans
417
437
} // namespace H2O4GPU
0 commit comments