-
Notifications
You must be signed in to change notification settings - Fork 113
Expand file tree
/
Copy pathmulti_blas_core.cuh
More file actions
290 lines (253 loc) · 10.5 KB
/
multi_blas_core.cuh
File metadata and controls
290 lines (253 loc) · 10.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
#pragma once
#include <blas_helper.cuh>
#include <multi_blas_helper.cuh>
#include <array.h>
#include <constant_kernel_arg.h>
#include <kernel.h>
#include <warp_collective.h>
namespace quda
{
namespace blas
{
/**
   @brief Compile-time switch reporting whether warp splitting over the
   NXZ dimension is enabled. Returns true only when
   QUDA_FAST_COMPILE_REDUCE is defined, false otherwise.
*/
constexpr bool enable_warp_split()
{
#ifdef QUDA_FAST_COMPILE_REDUCE
  return true;
#else
  return false;
#endif
}
/**
   @brief Parameter struct for generic multi-blas kernel.
   @tparam warp_split_ The degree of warp splitting over the NXZ dimension
   @tparam real_ The precision of the calculation
   @tparam n_ The number of real elements per thread
   @tparam NXZ_ The number of input vectors in each of the X and Z sets
   @tparam store_t Default store type for the fields
   @tparam N Default field vector i/o length
   @tparam y_store_t Store type for the Y fields
   @tparam Ny Y-field vector i/o length
   @tparam Functor_ Functor used to operate on data
*/
template <int warp_split_, typename real_, int n_, int NXZ_, typename store_t, int N, typename y_store_t, int Ny, typename Functor_>
struct MultiBlasArg : kernel_param<>,
                      SpinorXZ<NXZ_, store_t, N, Functor_::use_z>,
                      SpinorYW<max_YW_size<NXZ_, store_t, y_store_t, Functor_>(), store_t, N, y_store_t, Ny, Functor_::use_w> {
  static constexpr ThreadsSync requires_threads_sync = ThreadsSyncAll;
  using real = real_;
  using Functor = Functor_;
  static constexpr int warp_split = warp_split_;
  static constexpr int n = n_;
  static constexpr int NXZ = NXZ_;
  // upper bound on the number of Y/W vectors this instantiation can hold
  static constexpr int NYW_max = max_YW_size<NXZ, store_t, y_store_t, Functor>();
  Functor f;

  /**
     @brief Bind the field sets and functor, and size the launch.
     The launch extent is (length * warp_split, NYW, x.SiteSubset()).
     @param[in] x Input set; its first NXZ entries are bound to X
     @param[in] y Set whose first NYW entries are bound to Y
     @param[in] z Input set; its first NXZ entries are bound to Z only when Functor::use_z
     @param[in] w Set whose first NYW entries are bound to W only when Functor::use_w
     @param[in] f Functor copied into the argument struct
     @param[in] NYW Number of Y/W vectors; errorQuda is raised if it exceeds NYW_max
     @param[in] length Per-thread-x extent before the warp_split factor is applied
  */
  template <typename V>
  MultiBlasArg(V &x, V &y, V &z, V &w, Functor f, int NYW, int length) :
    kernel_param(dim3(length * warp_split, NYW, x.SiteSubset())), f(f)
  {
    if (NYW > NYW_max) errorQuda("NYW = %d greater than maximum size of %d", NYW, NYW_max);
    for (int i = 0; i < NXZ; ++i) {
      this->X[i] = static_cast<ColorSpinorField&>(x[i]);
      if (Functor::use_z) this->Z[i] = static_cast<ColorSpinorField&>(z[i]);
    }
    for (int i = 0; i < NYW; ++i) {
      this->Y[i] = static_cast<ColorSpinorField&>(y[i]);
      if (Functor::use_w) this->W[i] = static_cast<ColorSpinorField&>(w[i]);
    }
  }
};
/**
@brief Generic multi-blas kernel functor. For one thread it loads the
output-side fields Y[k] / W[k] (when the functor reads them), applies the
functor against every input X[l] / Z[l] for l in [0, NXZ), and stores at
most two fields: Y[k] and W[k] (when the functor writes them).
@param[in,out] arg Argument struct with required meta data
(input/output fields, functor, etc.)
*/
template <typename Arg> struct MultiBlas_ {
const Arg &arg;
constexpr MultiBlas_(const Arg &arg) : arg(arg) {}
static constexpr const char *filename() { return KERNEL_FILE; }
/**
@brief Apply the functor for a single thread.
@param[in] i Thread index along x; decomposed below into a site index
(idx) and a warp-split group index (l_idx)
@param[in] k Index of the Y/W vector being updated
@param[in] parity Parity index forwarded to the field load/save calls
*/
__device__ __host__ inline void operator()(int i, int k, int parity)
{
using vec = array<complex<typename Arg::real>, Arg::n/2>;
// partition the warp between grid points and the NXZ update
// Each warp is split into warp_split groups of vector_site_width lanes:
// lanes within a group handle distinct sites (idx), while group l_idx
// handles the NXZ indices l_idx, l_idx + warp_split, l_idx + 2*warp_split, ...
constexpr int warp_size = device::warp_size();
constexpr int warp_split = Arg::warp_split;
constexpr int vector_site_width = warp_size / warp_split;
const int lane_id = i % warp_size;
const int warp_id = i / warp_size;
const int idx = warp_id * (warp_size / warp_split) + lane_id % vector_site_width;
const int l_idx = lane_id / vector_site_width;
vec x, y, z, w;
// Only the first group (l_idx == 0) starts from the stored Y/W values; the
// other groups start from zero so their partial results can presumably be
// summed by warp_combine below (see warp_collective.h to confirm).
// NOTE(review): if the functor does not read Y (or W), that vector is left
// uninitialized here on the first group, and the functor must assign it
// before use (e.g. multiaxpyz_ sets w = y when j == 0).
if (l_idx == 0 || warp_split == 1) {
if (arg.f.read.Y) arg.Y[k].load(y, idx, parity);
if (arg.f.read.W) arg.W[k].load(w, idx, parity);
} else {
y = ::quda::zero<complex<typename Arg::real>, Arg::n/2>();
w = ::quda::zero<complex<typename Arg::real>, Arg::n/2>();
}
// `l < Arg::NXZ` guards the final pass when NXZ is not a multiple of
// warp_split. When warp_split == 1, l_idx is always 0, so the bound already
// holds and the `|| warp_split == 1` term lets the check fold away.
#pragma unroll
for (int l_ = 0; l_ < Arg::NXZ; l_ += warp_split) {
const int l = l_ + l_idx;
if (l < Arg::NXZ || warp_split == 1) {
if (arg.f.read.X) arg.X[l].load(x, idx, parity);
if (arg.f.read.Z) arg.Z[l].load(z, idx, parity);
arg.f(x, y, z, w, k, l);
}
}
// now combine the results across the warp if needed
// warp_combine is called outside the leader-only branch below, so every
// lane participates; only outputs that will be written are combined.
if (arg.f.write.Y) y = warp_combine<warp_split>(y);
if (arg.f.write.W) w = warp_combine<warp_split>(w);
// only the first group writes back; Y[k] and W[k] are the sole outputs
if (l_idx == 0 || warp_split == 1) {
if (arg.f.write.Y) arg.Y[k].save(y, idx, parity);
if (arg.f.write.W) arg.W[k].save(w, idx, parity);
}
}
};
/**
   @brief Common base for the non-reducing multi-blas functors.
   Forwards NXZ/NYW to MultiBlasParam with the reducer flag fixed to false.
   @tparam coeff_t_ Coefficient type (real or complex)
   @tparam multi_1d_ Whether the derived functor uses the multi-1d coefficient layout
*/
template <typename coeff_t_, bool multi_1d_ = false>
struct MultiBlasFunctor : MultiBlasParam<coeff_t_, false, multi_1d_> {
  static constexpr bool multi_1d = multi_1d_;
  static constexpr bool coeff_mul = true;
  static constexpr bool reducer = false;
  using coeff_t = coeff_t_;

  MultiBlasFunctor(int NXZ, int NYW) : MultiBlasParam<coeff_t_, false, multi_1d_>(NXZ, NYW) {}
};
/**
   Functor accumulating y += a(j, i) * x with real coefficients, where i
   indexes the output (Y) vector and j the input (X) vector.
*/
template <typename real>
struct multiaxpy_ : public MultiBlasFunctor<real> {
  static constexpr memory_access<1, 1> read{ };
  static constexpr memory_access<0, 1> write{ };
  static constexpr bool use_z = false;
  static constexpr bool use_w = false;
  static constexpr int NXZ_max = 0;
  using MultiBlasFunctor<real>::a;

  multiaxpy_(int NXZ, int NYW) : MultiBlasFunctor<real>(NXZ, NYW) {}

  template <typename T> __device__ __host__ inline void operator()(T &x, T &y, T &, T &, int i, int j) const
  {
    // the coefficient does not depend on the element index
    const auto coeff = a(j, i);
#pragma unroll
    for (int s = 0; s < x.size(); s++) y[s] += coeff * x[s];
  }

  constexpr int flops() const { return 2; } //! flops per real element
};
/**
   Functor accumulating y += a(j, i) * x with complex coefficients, where i
   indexes the output (Y) vector and j the input (X) vector.
*/
template <typename real>
struct multicaxpy_ : public MultiBlasFunctor<complex<real>> {
  static constexpr memory_access<1, 1> read{ };
  static constexpr memory_access<0, 1> write{ };
  static constexpr bool use_z = false;
  static constexpr bool use_w = false;
  static constexpr int NXZ_max = 0;
  using MultiBlasFunctor<complex<real>>::a;

  multicaxpy_(int NXZ, int NYW) : MultiBlasFunctor<complex<real>>(NXZ, NYW) {}

  template <typename T> __device__ __host__ inline void operator()(T &x, T &y, T &, T &, int i, int j) const
  {
    // complex multiply-accumulate with a per-(j, i) coefficient
    const auto alpha = a(j, i);
#pragma unroll
    for (int s = 0; s < x.size(); s++) y[s] = cmac(alpha, x[s], y[s]);
  }

  constexpr int flops() const { return 4; } //! flops per real element
};
/**
   Functor computing w = y + sum_j a(j, i) * x_j with real coefficients.
   On the first input (j == 0) the accumulation is seeded from y;
   afterwards it continues from the running w.
*/
template <typename real>
struct multiaxpyz_ : public MultiBlasFunctor<real> {
  static constexpr memory_access<1, 1, 0, 0> read{ };
  static constexpr memory_access<0, 0, 0, 1> write{ };
  static constexpr bool use_z = false;
  static constexpr bool use_w = true;
  static constexpr int NXZ_max = 0;
  using MultiBlasFunctor<real>::a;

  multiaxpyz_(int NXZ, int NYW) : MultiBlasFunctor<real>(NXZ, NYW) {}

  template <typename T> __device__ __host__ inline void operator()(T &x, T &y, T &, T &w, int i, int j) const
  {
    const auto coeff = a(j, i);
#pragma unroll
    for (int s = 0; s < x.size(); s++) {
      const auto base = (j == 0) ? y[s] : w[s];
      w[s] = coeff * x[s] + base;
    }
  }

  constexpr int flops() const { return 2; } //! flops per real element
};
/**
   Functor computing w = y + sum_j a(j, i) * x_j with complex coefficients.
   On the first input (j == 0) the accumulation is seeded from y;
   afterwards it continues from the running w.
*/
template <typename real>
struct multicaxpyz_ : public MultiBlasFunctor<complex<real>> {
  static constexpr memory_access<1, 1, 0, 0> read{ };
  static constexpr memory_access<0, 0, 0, 1> write{ };
  static constexpr bool use_z = false;
  static constexpr bool use_w = true;
  static constexpr int NXZ_max = 0;
  using MultiBlasFunctor<complex<real>>::a;

  multicaxpyz_(int NXZ, int NYW) : MultiBlasFunctor<complex<real>>(NXZ, NYW) {}

  template <typename T> __device__ __host__ inline void operator()(T &x, T &y, T &, T &w, int i, int j) const
  {
    const auto alpha = a(j, i);
#pragma unroll
    for (int s = 0; s < x.size(); s++) w[s] = cmac(alpha, x[s], (j == 0) ? y[s] : w[s]);
  }

  constexpr int flops() const { return 4; } //! flops per real element
};
/**
   Functor performing the operations: y[i] = a*w[i] + y[i]; w[i] = b*x[i] + c*w[i]
   where i indexes the output (Y/W) vector and selects a[i], b[i], c[i].
   The coefficient arrays are zero-initialized in the constructor, as in
   multi_caxpyBxpz_, so entries the caller does not set hold a defined value
   rather than an indeterminate one.
*/
template <typename real>
struct multi_axpyBzpcx_ : public MultiBlasFunctor<real, true> {
  static constexpr memory_access<1, 1, 0, 1> read{ };
  static constexpr memory_access<0, 1, 0, 1> write{ };
  static constexpr bool use_z = false;
  static constexpr bool use_w = true;
  static constexpr int NXZ_max = 1; // we never have NXZ > 1 for this kernel
  // this is a multi-1d functor so the coefficients are stored in the struct
  // set max 1-d size equal to max power of two
  static constexpr int N = max_N_multi_1d();
  real a[N];
  real b[N];
  real c[N];

  multi_axpyBzpcx_(int NXZ, int NYW) : MultiBlasFunctor<real, true>(NXZ, NYW)
  {
    // without this the arrays are default-initialized, i.e. indeterminate
    for (int i = 0; i < N; i++) {
      a[i] = real(0);
      b[i] = real(0);
      c[i] = real(0);
    }
  }

  // i selects the coefficients; the input index is unused since NXZ_max == 1
  template <typename T> __device__ __host__ inline void operator()(T &x, T &y, T &, T &w, int i, int) const
  {
#pragma unroll
    for (int k = 0; k < x.size(); k++) {
      y[k] += a[i] * w[k]; // uses w before it is overwritten below
      w[k] = b[i] * x[k] + c[i] * w[k];
    }
  }

  constexpr int flops() const { return 5; } //! flops per real element
};
/**
   Functor performing, for each input index j, the complex updates
   y = a[j]*x + y and w = b[j]*x + w.
   The coefficients live in the functor itself (multi-1d layout) and are
   zeroed on construction.
   NOTE(review): c[] is zeroed but never read by operator().
*/
template <typename real>
struct multi_caxpyBxpz_ : public MultiBlasFunctor<complex<real>, true> {
  static constexpr memory_access<1, 1, 0, 1> read{ };
  static constexpr memory_access<0, 1, 0, 1> write{ };
  static constexpr bool use_z = false;
  static constexpr bool use_w = true;
  static constexpr int NXZ_max = 0;
  static constexpr int N = max_N_multi_1d();
  complex<real> a[N];
  complex<real> b[N];
  complex<real> c[N];

  multi_caxpyBxpz_(int NXZ, int NYW) : MultiBlasFunctor<complex<real>, true>(NXZ, NYW)
  {
    for (auto &coeff : a) coeff = 0.0;
    for (auto &coeff : b) coeff = 0.0;
    for (auto &coeff : c) coeff = 0.0;
  }

  // the output (NYW) index is unused; j (the NXZ index) selects the coefficients
  template <typename T> __device__ __host__ inline void operator()(T &x, T &y, T &, T &w, int, int j) const
  {
    const complex<real> alpha = a[j];
    const complex<real> beta = b[j];
#pragma unroll
    for (int s = 0; s < x.size(); s++) {
      y[s] = cmac(alpha, x[s], y[s]);
      w[s] = cmac(beta, x[s], w[s]);
    }
  }

  constexpr int flops() const { return 8; } //! flops per real element
};
} // namespace blas
} // namespace quda