Skip to content

Commit 9d820c1

Browse files
committed
Fix prob for kmeans||.
1 parent e45ee04 commit 9d820c1

File tree

2 files changed

+25
-34
lines changed

2 files changed

+25
-34
lines changed

src/gpu/kmeans/kmeans_init.cu

+23-33
Original file line numberDiff line numberDiff line change
@@ -235,29 +235,17 @@ struct PairWiseDistanceOp {
235235

236236
kernel::construct_distance_pairs_kernel<<<
237237
dim3(GpuInfo::ins().blocks(32), div_roundup(_centroids.rows(), 16)),
238-
dim3(16, 16)>>>(
238+
dim3(16, 16)>>>( // FIXME: Tune this.
239239
distance_pairs_.k_param(),
240240
data_dot_.k_param(),
241241
centroids_dot_.k_param());
242242
243243
CUDA_CHECK(cudaGetLastError());
244-
std::cout << std::endl;
245-
std::cout << "in distance op" << std::endl;
246-
std::cout << distance_pairs_ << std::endl;
247244
248245
cublasHandle_t handle = GpuInfo::ins().cublas_handle();
249246
250247
T alpha = -2.0;
251248
T beta = 1.0;
252-
std::cout << "data.shape: " << _data.rows() << ", " << _data.cols() <<
253-
"\tcentroids.shape: " << _centroids.rows() << ", " << _centroids.cols() <<
254-
"\tdp.shape: " << distance_pairs_.rows() << ", " << distance_pairs_.cols() <<
255-
std::endl;
256-
257-
std::cout << _data << std::endl;
258-
std::cout << _centroids << std::endl;
259-
260-
std::cout << _centroids.dev_ptr() << std::endl;
261249
262250
Blas::gemm(
263251
handle,
@@ -270,12 +258,11 @@ struct PairWiseDistanceOp {
270258
&beta,
271259
distance_pairs_.dev_ptr(), distance_pairs_.rows());
272260
273-
std::cout << distance_pairs_ << std::endl;
274-
std::cout << "return" << std::endl;
275261
return distance_pairs_;
276262
}
277263
};
278264
265+
279266
template <typename T>
280267
KmMatrix<T> KmeansLlInit<T>::probability(
281268
KmMatrix<T>& _data, KmMatrix<T>& _centroids) {
@@ -301,13 +288,19 @@ KmMatrix<T> KmeansLlInit<T>::probability(
301288
302289
CUDA_CHECK(cudaGetLastError());
303290
291+
std::cout << min_distances << std::endl;
292+
304293
T cost = SumOp<T>().sum(min_distances);
294+
std::cout << "cost: " << cost << std::endl;
305295
306-
// Re-use min_distances to store prob
307296
MulOp<T> mul_op;
308-
mul_op.mul(min_distances, min_distances, 1 / cost * over_sample_ * k_);
309297
310-
return min_distances;
298+
KmMatrix<T> prob (min_distances.rows(), 1);
299+
mul_op.mul(prob, min_distances, (over_sample_ * k_ * 1) / cost);
300+
301+
std::cout << prob << std::endl;
302+
303+
return prob;
311304
}
312305
313306
@@ -316,11 +309,7 @@ KmMatrix<T> KmeansLlInit<T>::sample_centroids(KmMatrix<T>& _data, KmMatrix<T>& _
316309
317310
KmMatrix<T> distances (1, _data.rows());
318311
319-
T potential = SumOp<T>().sum(_prob);
320-
321-
MulOp<T>().mul(_prob, _prob, 1 / potential);
322-
323-
312+
// FIXME: Keep generator out.
324313
Generator<T> uniform_dist(_data.rows());
325314
KmMatrix<T> thresholds = uniform_dist.generate();
326315
@@ -357,19 +346,19 @@ KmMatrix<T> KmeansLlInit<T>::sample_centroids(KmMatrix<T>& _data, KmMatrix<T>& _
357346
T prob_x = prob_ptr[idx];
358347
return prob_x > thresh;
359348
});
360-
349+
std::cout << std::endl;
361350
return new_centroids;
362351
}
363352
364353
template <typename T>
365354
KmMatrix<T>
366-
KmeansLlInit<T>::operator()(KmMatrix<T>& _data, size_t k) {
355+
KmeansLlInit<T>::operator()(KmMatrix<T>& _data, size_t _k) {
367356
368357
if (seed_ < 0) {
369358
std::random_device rd;
370359
seed_ = rd();
371360
}
372-
k_ = k;
361+
k_ = _k;
373362
374363
std::mt19937 generator(0);
375364
@@ -386,14 +375,15 @@ KmeansLlInit<T>::operator()(KmMatrix<T>& _data, size_t k) {
386375
KmMatrix<T> prob = probability(_data, centroids);
387376
388377
T cost = SumOp<T>().sum(prob);
389-
// FIXME
390-
// for (size_t i = 0; i < std::log(cost); ++i) {
391-
for (size_t i = 0; i < 1; ++i) {
392-
std::cout << "looping" << std::endl;
393-
KmMatrix<T> new_centroids = sample_centroids(_data, centroids);
394-
centroids = stack(centroids, new_centroids, KmMatrixDim::ROW);
378+
379+
for (size_t i = 0; i < std::log(cost); ++i) {
395380
prob = probability(_data, centroids);
381+
KmMatrix<T> new_centroids = sample_centroids(_data, prob);
382+
new_centroids.set_name ("new centroids");
383+
std::cout << new_centroids << std::endl;
396384
centroids = stack(centroids, new_centroids, KmMatrixDim::ROW);
385+
centroids.set_name ("centroids");
386+
std::cout << centroids << std::endl;
397387
}
398388
399389
if (centroids.rows() < k_) {
@@ -407,7 +397,7 @@ KmeansLlInit<T>::operator()(KmMatrix<T>& _data, size_t k) {
407397
408398
#define INSTANTIATE(T) \
409399
template KmMatrix<T> KmeansLlInit<T>::operator()( \
410-
KmMatrix<T>& data, size_t k); \
400+
KmMatrix<T>& _data, size_t _k); \
411401
template KmMatrix<T> KmeansLlInit<T>::probability(KmMatrix<T>& data, \
412402
KmMatrix<T>& centroids); \
413403
template KmMatrix<T> KmeansLlInit<T>::sample_centroids( \

src/gpu/kmeans/kmeans_init.cuh

+2-1
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ struct KmeansLlInit : public KmeansInitBase<T> {
159159
KmMatrix<T> probability(KmMatrix<T>& data, KmMatrix<T>& centroids);
160160

161161
public:
162-
KmeansLlInit () : over_sample_ (2.0), seed_ (0), k_(0) {
162+
KmeansLlInit (T _over_sample=2.0) :
163+
over_sample_ (_over_sample), seed_ (0), k_(0) {
163164
data_dot_.set_name ("data_dot");
164165
distance_pairs_.set_name ("distance pairs");
165166
}

0 commit comments

Comments
 (0)