@@ -17,9 +17,9 @@ namespace KMeans {
17
17
namespace Blas {
18
18
// LEVEL 1
19
19
inline void axpy (cublasHandle_t handle, int n,
20
- const float *alpha,
21
- const float *x, int incx,
22
- float *y, int incy) {
20
+ const float *alpha,
21
+ const float *x, int incx,
22
+ float *y, int incy) {
23
23
CUBLAS_CHECK (cublasSaxpy (handle, n,
24
24
alpha,
25
25
x, incx,
@@ -50,19 +50,13 @@ inline void gemm(cublasHandle_t handle,
50
50
float *C,
51
51
int ldc) {
52
52
CUBLAS_CHECK (cublasSgemm (handle,
53
- transa,
54
- transb,
55
- m,
56
- n,
57
- k,
53
+ transa, transb,
54
+ m, n, k,
58
55
alpha, /* host or device pointer */
59
- A,
60
- lda,
61
- B,
62
- ldb,
56
+ A, lda,
57
+ B, ldb,
63
58
beta, /* host or device pointer */
64
- C,
65
- ldc));}
59
+ C, ldc));}
66
60
67
61
inline void gemm (cublasHandle_t handle,
68
62
cublasOperation_t transa,
@@ -93,8 +87,101 @@ inline void gemm(cublasHandle_t handle,
93
87
C,
94
88
ldc));}
95
89
96
- } // Blas
90
+ inline void gemm_batched (cublasHandle_t handle,
91
+ cublasOperation_t transa,
92
+ cublasOperation_t transb,
93
+ int m, int n, int k,
94
+ const double *alpha,
95
+ const double *Aarray[], int lda,
96
+ const double *Barray[], int ldb,
97
+ const double *beta,
98
+ double *Carray[], int ldc,
99
+ int batchCount) {
100
+ CUBLAS_CHECK (cublasDgemmBatched (handle,
101
+ transa,
102
+ transb,
103
+ m, n, k,
104
+ alpha,
105
+ Aarray, lda,
106
+ Barray, ldb,
107
+ beta,
108
+ Carray, ldc,
109
+ batchCount));
110
+ }
97
111
112
+ inline void gemm_batched (cublasHandle_t handle,
113
+ cublasOperation_t transa,
114
+ cublasOperation_t transb,
115
+ int m, int n, int k,
116
+ const float *alpha,
117
+ const float *Aarray[], int lda,
118
+ const float *Barray[], int ldb,
119
+ const float *beta,
120
+ float *Carray[], int ldc,
121
+ int batchCount) {
122
+ CUBLAS_CHECK (cublasSgemmBatched (handle,
123
+ transa,
124
+ transb,
125
+ m, n, k,
126
+ alpha,
127
+ Aarray, lda,
128
+ Barray, ldb,
129
+ beta,
130
+ Carray, ldc,
131
+ batchCount));
132
+ }
133
+
134
+ inline void gemm_strided_batched (
135
+ cublasHandle_t handle,
136
+ cublasOperation_t transA, cublasOperation_t transB,
137
+ int M, int N, int K,
138
+ const double * alpha,
139
+ const double * A, int ldA, int strideA,
140
+ const double * B, int ldB, int strideB,
141
+ const double * beta,
142
+ double * C, int ldC, int strideC,
143
+ int batchCount) {
144
+ CUBLAS_CHECK (cublasDgemmStridedBatched (handle,
145
+ transA,
146
+ transB,
147
+ M, N, K,
148
+ alpha,
149
+ A, ldA,
150
+ strideA,
151
+ B, ldB,
152
+ strideB,
153
+ beta,
154
+ C, ldC,
155
+ strideC,
156
+ batchCount));
157
+ }
158
+
159
+ inline void gemm_strided_batched (
160
+ cublasHandle_t handle,
161
+ cublasOperation_t transA, cublasOperation_t transB,
162
+ int M, int N, int K,
163
+ const float * alpha,
164
+ const float * A, int ldA, int strideA,
165
+ const float * B, int ldB, int strideB,
166
+ const float * beta,
167
+ float * C, int ldC, int strideC,
168
+ int batchCount) {
169
+ CUBLAS_CHECK (cublasSgemmStridedBatched (handle,
170
+ transA,
171
+ transB,
172
+ M, N, K,
173
+ alpha,
174
+ A, ldA,
175
+ strideA,
176
+ B, ldB,
177
+ strideB,
178
+ beta,
179
+ C, ldC,
180
+ strideC,
181
+ batchCount));
182
+ }
183
+
184
+ } // Blas
98
185
} // KMeans
99
186
} // H2O4GPU
100
187
0 commit comments