Skip to content

Commit e45ee04

Browse files
committed
Construct kmeans|| based on KmMatrix.
Builds basic kmeans|| framework on top of KmMatrix. The algorithm is not working yet.
1 parent e0b0c47 commit e45ee04

File tree

7 files changed

+553
-293
lines changed

7 files changed

+553
-293
lines changed

src/gpu/kmeans/KmMatrix/GpuInfo.cuh

+1-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <stdlib.h>
1414
#include <stdio.h>
1515

16+
// Singleton class storing gpu info.
1617
class GpuInfo {
1718
private:
1819
int n_gpu_;
@@ -67,6 +68,4 @@ class GpuInfo {
6768

6869
};
6970

70-
// const GpuInfoImpl GpuInfo::impl = GpuInfoImpl();
71-
7271
#endif // GPU_INFO_HPP_

src/gpu/kmeans/KmMatrix/KmMatrix.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ KmMatrix<T> KmMatrix<T>::stack(KmMatrix<T> &_second,
250250

251251
template <typename T>
252252
std::ostream& operator<<(std::ostream& os, KmMatrix<T>& m) {
253-
std::cout << "matrix: " << m.name() << std::endl << "---" << std::endl;
253+
std::cout << "\nmatrix: " << m.name() << std::endl << "---" << std::endl;
254254
T * ptr = m.host_ptr();
255255
kParam<T> param = m.k_param();
256256
for (size_t i = 0; i < param.rows; ++i) {

src/gpu/kmeans/KmMatrix/blas.cuh

+102-15
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ namespace KMeans {
1717
namespace Blas {
1818
// LEVEL 1
1919
inline void axpy(cublasHandle_t handle, int n,
20-
const float *alpha,
21-
const float *x, int incx,
22-
float *y, int incy) {
20+
const float *alpha,
21+
const float *x, int incx,
22+
float *y, int incy) {
2323
CUBLAS_CHECK(cublasSaxpy(handle, n,
2424
alpha,
2525
x, incx,
@@ -50,19 +50,13 @@ inline void gemm(cublasHandle_t handle,
5050
float *C,
5151
int ldc) {
5252
CUBLAS_CHECK(cublasSgemm(handle,
53-
transa,
54-
transb,
55-
m,
56-
n,
57-
k,
53+
transa, transb,
54+
m, n, k,
5855
alpha, /* host or device pointer */
59-
A,
60-
lda,
61-
B,
62-
ldb,
56+
A, lda,
57+
B, ldb,
6358
beta, /* host or device pointer */
64-
C,
65-
ldc));}
59+
C, ldc));}
6660

6761
inline void gemm(cublasHandle_t handle,
6862
cublasOperation_t transa,
@@ -93,8 +87,101 @@ inline void gemm(cublasHandle_t handle,
9387
C,
9488
ldc));}
9589

96-
} // Blas
90+
inline void gemm_batched(cublasHandle_t handle,
91+
cublasOperation_t transa,
92+
cublasOperation_t transb,
93+
int m, int n, int k,
94+
const double *alpha,
95+
const double *Aarray[], int lda,
96+
const double *Barray[], int ldb,
97+
const double *beta,
98+
double *Carray[], int ldc,
99+
int batchCount) {
100+
CUBLAS_CHECK(cublasDgemmBatched(handle,
101+
transa,
102+
transb,
103+
m, n, k,
104+
alpha,
105+
Aarray, lda,
106+
Barray, ldb,
107+
beta,
108+
Carray, ldc,
109+
batchCount));
110+
}
97111

112+
inline void gemm_batched(cublasHandle_t handle,
113+
cublasOperation_t transa,
114+
cublasOperation_t transb,
115+
int m, int n, int k,
116+
const float *alpha,
117+
const float *Aarray[], int lda,
118+
const float *Barray[], int ldb,
119+
const float *beta,
120+
float *Carray[], int ldc,
121+
int batchCount) {
122+
CUBLAS_CHECK(cublasSgemmBatched(handle,
123+
transa,
124+
transb,
125+
m, n, k,
126+
alpha,
127+
Aarray, lda,
128+
Barray, ldb,
129+
beta,
130+
Carray, ldc,
131+
batchCount));
132+
}
133+
134+
inline void gemm_strided_batched(
135+
cublasHandle_t handle,
136+
cublasOperation_t transA, cublasOperation_t transB,
137+
int M, int N, int K,
138+
const double* alpha,
139+
const double* A, int ldA, int strideA,
140+
const double* B, int ldB, int strideB,
141+
const double* beta,
142+
double* C, int ldC, int strideC,
143+
int batchCount) {
144+
CUBLAS_CHECK(cublasDgemmStridedBatched(handle,
145+
transA,
146+
transB,
147+
M, N, K,
148+
alpha,
149+
A, ldA,
150+
strideA,
151+
B, ldB,
152+
strideB,
153+
beta,
154+
C, ldC,
155+
strideC,
156+
batchCount));
157+
}
158+
159+
inline void gemm_strided_batched(
160+
cublasHandle_t handle,
161+
cublasOperation_t transA, cublasOperation_t transB,
162+
int M, int N, int K,
163+
const float* alpha,
164+
const float* A, int ldA, int strideA,
165+
const float* B, int ldB, int strideB,
166+
const float* beta,
167+
float* C, int ldC, int strideC,
168+
int batchCount) {
169+
CUBLAS_CHECK(cublasSgemmStridedBatched(handle,
170+
transA,
171+
transB,
172+
M, N, K,
173+
alpha,
174+
A, ldA,
175+
strideA,
176+
B, ldB,
177+
strideB,
178+
beta,
179+
C, ldC,
180+
strideC,
181+
batchCount));
182+
}
183+
184+
} // Blas
98185
} // KMeans
99186
} // H2O4GPU
100187

src/gpu/kmeans/KmMatrix/utils.cuh

+12-6
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,21 @@ M_DEVINLINE size_t global_thread_idx () {
1515
return threadIdx.x + blockIdx.x * blockDim.x;
1616
}
1717

18-
M_DEVINLINE size_t grid_stride () {
18+
M_DEVINLINE size_t global_thread_idy () {
19+
return threadIdx.y + blockIdx.y * blockDim.y;
20+
}
21+
22+
M_DEVINLINE size_t grid_stride_x () {
1923
return blockDim.x * gridDim.x;
2024
}
2125

22-
// This wrapper function is created to work around a possible bug in nvcc,
23-
// which threats GpuInfo::ins() as calling base class method when used inside a
24-
// class member function.
25-
size_t get_blocks(size_t _mul, int _device=0) {
26-
return GpuInfo::ins().blocks(_mul, _device);
26+
M_DEVINLINE size_t grid_stride_y () {
27+
return blockDim.y * gridDim.y;
28+
}
29+
30+
template <typename T1, typename T2>
31+
T1 M_HOSTDEVINLINE div_roundup(const T1 a, const T2 b) {
32+
return static_cast<T1>(ceil(static_cast<double>(a) / b));
2733
}
2834

2935
} // KMeans

0 commit comments

Comments
 (0)