Skip to content

Commit 6a8669a

Browse files
committed
Fix prob for kmeans||.
1 parent e45ee04 commit 6a8669a

File tree

2 files changed

+24
-28
lines changed

2 files changed

+24
-28
lines changed

src/gpu/kmeans/kmeans_init.cu

+22-27
Original file line number | Diff line number | Diff line change
@@ -235,29 +235,17 @@ struct PairWiseDistanceOp {
235235

236236
kernel::construct_distance_pairs_kernel<<<
237237
dim3(GpuInfo::ins().blocks(32), div_roundup(_centroids.rows(), 16)),
238-
dim3(16, 16)>>>(
238+
dim3(16, 16)>>>( // FIXME: Tune this.
239239
distance_pairs_.k_param(),
240240
data_dot_.k_param(),
241241
centroids_dot_.k_param());
242242
243243
CUDA_CHECK(cudaGetLastError());
244-
std::cout << std::endl;
245-
std::cout << "in distance op" << std::endl;
246-
std::cout << distance_pairs_ << std::endl;
247244
248245
cublasHandle_t handle = GpuInfo::ins().cublas_handle();
249246
250247
T alpha = -2.0;
251248
T beta = 1.0;
252-
std::cout << "data.shape: " << _data.rows() << ", " << _data.cols() <<
253-
"\tcentroids.shape: " << _centroids.rows() << ", " << _centroids.cols() <<
254-
"\tdp.shape: " << distance_pairs_.rows() << ", " << distance_pairs_.cols() <<
255-
std::endl;
256-
257-
std::cout << _data << std::endl;
258-
std::cout << _centroids << std::endl;
259-
260-
std::cout << _centroids.dev_ptr() << std::endl;
261249
262250
Blas::gemm(
263251
handle,
@@ -270,12 +258,11 @@ struct PairWiseDistanceOp {
270258
&beta,
271259
distance_pairs_.dev_ptr(), distance_pairs_.rows());
272260
273-
std::cout << distance_pairs_ << std::endl;
274-
std::cout << "return" << std::endl;
275261
return distance_pairs_;
276262
}
277263
};
278264
265+
279266
template <typename T>
280267
KmMatrix<T> KmeansLlInit<T>::probability(
281268
KmMatrix<T>& _data, KmMatrix<T>& _centroids) {
@@ -301,13 +288,20 @@ KmMatrix<T> KmeansLlInit<T>::probability(
301288
302289
CUDA_CHECK(cudaGetLastError());
303290
291+
std::cout << min_distances << std::endl;
292+
304293
T cost = SumOp<T>().sum(min_distances);
294+
std::cout << "cost: " << cost << std::endl;
305295
306296
// Re-use min_distances to store prob
307297
MulOp<T> mul_op;
308-
mul_op.mul(min_distances, min_distances, 1 / cost * over_sample_ * k_);
309298
310-
return min_distances;
299+
KmMatrix<T> prob (min_distances.rows(), 1);
300+
mul_op.mul(prob, min_distances, (over_sample_ * k_ * 1) / cost);
301+
302+
std::cout << prob << std::endl;
303+
304+
return prob;
311305
}
312306
313307
@@ -357,19 +351,19 @@ KmMatrix<T> KmeansLlInit<T>::sample_centroids(KmMatrix<T>& _data, KmMatrix<T>& _
357351
T prob_x = prob_ptr[idx];
358352
return prob_x > thresh;
359353
});
360-
354+
std::cout << std::endl;
361355
return new_centroids;
362356
}
363357
364358
template <typename T>
365359
KmMatrix<T>
366-
KmeansLlInit<T>::operator()(KmMatrix<T>& _data, size_t k) {
360+
KmeansLlInit<T>::operator()(KmMatrix<T>& _data, size_t _k) {
367361
368362
if (seed_ < 0) {
369363
std::random_device rd;
370364
seed_ = rd();
371365
}
372-
k_ = k;
366+
k_ = _k;
373367
374368
std::mt19937 generator(0);
375369
@@ -386,14 +380,15 @@ KmeansLlInit<T>::operator()(KmMatrix<T>& _data, size_t k) {
386380
KmMatrix<T> prob = probability(_data, centroids);
387381
388382
T cost = SumOp<T>().sum(prob);
389-
// FIXME
390-
// for (size_t i = 0; i < std::log(cost); ++i) {
391-
for (size_t i = 0; i < 1; ++i) {
392-
std::cout << "looping" << std::endl;
393-
KmMatrix<T> new_centroids = sample_centroids(_data, centroids);
394-
centroids = stack(centroids, new_centroids, KmMatrixDim::ROW);
383+
384+
for (size_t i = 0; i < std::log(cost); ++i) {
395385
prob = probability(_data, centroids);
386+
KmMatrix<T> new_centroids = sample_centroids(_data, prob);
387+
new_centroids.set_name ("new centroids");
388+
std::cout << new_centroids << std::endl;
396389
centroids = stack(centroids, new_centroids, KmMatrixDim::ROW);
390+
centroids.set_name ("centroids");
391+
std::cout << centroids << std::endl;
397392
}
398393
399394
if (centroids.rows() < k_) {
@@ -407,7 +402,7 @@ KmeansLlInit<T>::operator()(KmMatrix<T>& _data, size_t k) {
407402
408403
#define INSTANTIATE(T) \
409404
template KmMatrix<T> KmeansLlInit<T>::operator()( \
410-
KmMatrix<T>& data, size_t k); \
405+
KmMatrix<T>& _data, size_t _k); \
411406
template KmMatrix<T> KmeansLlInit<T>::probability(KmMatrix<T>& data, \
412407
KmMatrix<T>& centroids); \
413408
template KmMatrix<T> KmeansLlInit<T>::sample_centroids( \

src/gpu/kmeans/kmeans_init.cuh

+2-1
Original file line number | Diff line number | Diff line change
@@ -159,7 +159,8 @@ struct KmeansLlInit : public KmeansInitBase<T> {
159159
KmMatrix<T> probability(KmMatrix<T>& data, KmMatrix<T>& centroids);
160160

161161
public:
162-
KmeansLlInit () : over_sample_ (2.0), seed_ (0), k_(0) {
162+
KmeansLlInit (T _over_sample=2.0) :
163+
over_sample_ (_over_sample), seed_ (0), k_(0) {
163164
data_dot_.set_name ("data_dot");
164165
distance_pairs_.set_name ("distance pairs");
165166
}

0 commit comments

Comments
 (0)