-
Notifications
You must be signed in to change notification settings - Fork 113
Expand file tree
/
Copy pathdslash.h
More file actions
628 lines (555 loc) · 26.4 KB
/
dslash.h
File metadata and controls
628 lines (555 loc) · 26.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
#pragma once
#include <typeinfo>
#include <color_spinor_field.h>
#include <dslash_quda.h>
#include <dslash_helper.cuh>
#include <tunable_nd.h>
#include <instantiate.h>
#include <instantiate_dslash.h>
#include <tma_helper.hpp>
namespace quda
{
/**
@brief This is the generic driver for launching Dslash kernels
(the base kernel of which is defined in dslash_helper.cuh). This
is templated on the a template template parameter which is the
underlying operator wrapped in a class,
@tparam D A class that defines the linear operator we wish to
apply. This class should define an operator() method that is
used to apply the operator by the dslash kernel. See the wilson
class in the file kernels/dslash_wilson.cuh as an exmaple.
@tparam Arg The argument struct that is used to parameterize the
kernel. For the wilson class example above, the WilsonArg class
defined in the same file is the corresponding argument class.
*/
template <template <bool, bool, KernelType, typename> class D, typename Arg> class Dslash : public TunableKernel3D
{
protected:
Arg &arg;
cvector_ref<ColorSpinorField> &out;
cvector_ref<const ColorSpinorField> ∈
const ColorSpinorField &halo;
const int nDimComms;
char aux_base[TuneKey::aux_n - 32];
char aux[8][TuneKey::aux_n];
char aux_pack[TuneKey::aux_n];
char aux_barrier[TuneKey::aux_n];
// pointers to ghost buffers we are packing to
void *packBuffer[4 * QUDA_MAX_DIM];
/**
@brief Set the base strings used by the different dslash kernel
types for autotuning.
*/
inline void fillAuxBase(const std::string &app_base)
{
strcpy(aux_base, TunableKernel3D::aux);
char comm[5];
comm[0] = (arg.commDim[0] ? '1' : '0');
comm[1] = (arg.commDim[1] ? '1' : '0');
comm[2] = (arg.commDim[2] ? '1' : '0');
comm[3] = (arg.commDim[3] ? '1' : '0');
comm[4] = '\0';
strcat(aux_base, ",commDim=");
strcat(aux_base, comm);
strcat(aux_base, app_base.c_str());
if (arg.xpay) strcat(aux_base, ",xpay");
if (arg.dagger) strcat(aux_base, ",dagger");
setRHSstring(aux_base, in.size());
strcat(aux_base, ",n_rhs_tile=");
char tile_str[16];
i32toa(tile_str, Arg::n_src_tile);
strcat(aux_base, tile_str);
if constexpr (dslash_double_store()) strcat(aux_base, ",double_store");
if constexpr (Arg::prefetch_distance > 0) {
strcat(aux_base, ",prefetch=");
i32toa(tile_str, Arg::prefetch_distance);
strcat(aux_base, tile_str);
if constexpr (dslash_prefetch_type() == PrefetchType::THREAD)
strcat(aux_base, ",prefetch=thread");
else if constexpr (dslash_prefetch_type() == PrefetchType::BULK)
strcat(aux_base, ",prefetch=bulk");
else if constexpr (dslash_prefetch_type() == PrefetchType::TENSOR)
strcat(aux_base, ",prefetch=tensor");
}
}
/**
@brief Specialize the auxiliary strings for each kernel type
@param[in] kernel_type The kernel_type we are generating the string got
@param[in] kernel_str String corresponding to the kernel type
*/
inline void fillAux(KernelType kernel_type, const char *kernel_str)
{
strcpy(aux[kernel_type], kernel_str);
strncat(aux[kernel_type], aux_base, TuneKey::aux_n - 1);
if (kernel_type == INTERIOR_KERNEL) strcat(aux[kernel_type], comm_dim_partitioned_string());
}
bool tuneSharedCarveOut() const override
{
// default is to do carve out tuning if the architecture supports it
static bool tune_shared = device::shared_carve_out_supported();
static bool init = false;
if (!init) {
char *enable_shared_env = getenv("QUDA_ENABLE_TUNING_SHARED_CARVE_OUT_DSLASH");
if (enable_shared_env) {
if (strcmp(enable_shared_env, "0") == 0) { tune_shared = false; }
}
init = true;
}
return tune_shared;
}
virtual bool tuneGridDim() const override { return arg.kernel_type == EXTERIOR_KERNEL_ALL && arg.shmem > 0; }
virtual unsigned int minThreads() const override { return arg.threads; }
virtual unsigned int minGridSize() const override
{
/* when using nvshmem we perform the exterior Dslash using a grid strided loop and uniquely assign communication
* directions to CUDA block and have all communication directions resident. We therefore figure out the number of
* communicating dimensions and make sure that the number of blocks is a multiple of the communicating directions (2*dim)
*/
if (arg.kernel_type == EXTERIOR_KERNEL_ALL && arg.shmem > 0) {
int nDimComms = 0;
for (int d = 0; d < 4; d++) nDimComms += arg.commDim[d];
return (device::processor_count() / (2 * nDimComms)) * (2 * nDimComms);
} else {
return TunableKernel3D::minGridSize();
}
}
virtual int gridStep() const override
{
/* see comment for minGridSize above for gridStep choice when using nvshmem */
if (arg.kernel_type == EXTERIOR_KERNEL_ALL && arg.shmem > 0) {
int nDimComms = 0;
for (int d = 0; d < 4; d++) nDimComms += arg.commDim[d];
return (device::processor_count() / (2 * nDimComms)) * (2 * nDimComms);
} else {
return TunableKernel3D::gridStep();
}
}
template <bool improved = false> inline void setParam(TuneParam &tp, const GaugeField &U, const GaugeField &L = {})
{
// Need to reset ghost pointers prior to every call since the
// ghost buffer may have been changed during policy tuning.
// Also, the accessor constructor calls Ghost(), which uses
// ghost_buf, but this is only presently set with the
// synchronous exchangeGhost.
static void *ghost[8] = {}; // needs to be persistent across interior and exterior calls
for (int dim = 0; dim < 4; dim++) {
for (int dir = 0; dir < 2; dir++) {
// if doing interior kernel, then this is the initial call,
// so we set all ghost pointers else if doing exterior
// kernel, then we only have to update the non-p2p ghosts,
// since these may have been assigned to zero-copy memory
if (!comm_peer2peer_enabled(dir, dim) || arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL) {
ghost[2 * dim + dir] = (typename Arg::Float *)((char *)halo.Ghost2() + halo.GhostOffset(dim, dir));
}
}
}
arg.halo.resetGhost(ghost, halo.SiteSubset());
if (arg.pack_threads && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL)) {
arg.blocks_per_dir = tp.aux.x;
arg.setPack(true, this->packBuffer); // need to recompute for updated block_per_dir
arg.halo_pack.resetGhost(this->packBuffer, halo.SiteSubset());
tp.grid.x += arg.pack_blocks;
arg.counter = dslash::get_dslash_shmem_sync_counter();
}
if (arg.shmem > 0 && arg.kernel_type == EXTERIOR_KERNEL_ALL) {
// if we are doing tuning we should not wait on the sync_arr to be set.
arg.counter = (activeTuning() && !policyTuning()) ? 2 : dslash::get_dslash_shmem_sync_counter();
}
if (arg.shmem > 0 && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL)) {
arg.counter = activeTuning() ? (uberTuning() && !policyTuning() ? dslash::inc_dslash_shmem_sync_counter() :
dslash::get_dslash_shmem_sync_counter()) :
dslash::get_dslash_shmem_sync_counter();
arg.exterior_blocks = ((arg.shmem & 64) && arg.exterior_dims > 0) ?
(device::processor_count() / (2 * arg.exterior_dims)) * (2 * arg.exterior_dims * tp.aux.y) :
0;
tp.grid.x += arg.exterior_blocks;
}
if constexpr (dslash_prefetch_type() == PrefetchType::TENSOR && Arg::prefetch_distance > 0) {
Dslash::arg.U.tensor_desc = get_tensor_descriptor(U, tp.block.x);
Dslash::arg.Uback.tensor_desc = get_tensor_descriptor(U.shift(), tp.block.x);
if constexpr (improved) {
assert(!U.empty());
Dslash::arg.L.tensor_desc = get_tensor_descriptor(L, tp.block.x);
Dslash::arg.Lback.tensor_desc = get_tensor_descriptor(L.shift(), tp.block.x);
}
}
}
virtual int blockStep() const override { return (arg.shmem & 64) ? 8 : 16; }
virtual int blockMin() const override { return (arg.shmem & 64) ? 8 : 16; }
unsigned int maxSharedBytesPerBlock() const override { return maxDynamicSharedBytesPerBlock(); }
virtual bool advanceAux(TuneParam ¶m) const override
{
if (arg.pack_threads && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL)) {
int max_threads_per_dir = 0;
for (int i = 0; i < 4; ++i) {
max_threads_per_dir = std::max(max_threads_per_dir, (arg.threadDimMapUpper[i] - arg.threadDimMapLower[i]) / 2);
}
int nDimComms = 0;
for (int d = 0; d < 4; d++) nDimComms += arg.commDim[d];
/* if doing the fused packing + interior kernel we tune how many blocks to use for communication */
// use up to a quarter of the GPU for packing (but at least up to 4 blocks per dir)
const int max_blocks_per_dir = std::max(device::processor_count() / (8 * nDimComms), 4u);
if (param.aux.x + 1 <= max_blocks_per_dir
&& (param.aux.x + 1) * param.block.x < (max_threads_per_dir + param.block.x - 1)) {
param.aux.x++;
return true;
} else {
param.aux.x = 1;
if (arg.exterior_dims > 0 && arg.shmem & 64) {
/* if doing a fused interior+exterior kernel we use aux.y to control the number of blocks we add for the
* exterior. We make sure to use multiple blocks per communication direction.
*/
if (param.aux.y < 4) {
param.aux.y++;
return true;
} else {
param.aux.y = 1;
return false;
}
}
return false;
}
} else {
return false;
}
}
virtual bool advanceBlockDim(TuneParam ¶m) const override
{
// if TMA is enabled we must keep parity separate in the block (2-d tuning)
if constexpr (dslash_prefetch_tma())
return TunableKernel2D_base<false>::advanceBlockDim(param);
else
return TunableKernel3D::advanceBlockDim(param);
}
virtual bool advanceTuneParam(TuneParam ¶m) const override
{
return advanceAux(param) || advanceSharedBytes(param) || advanceBlockDim(param) || advanceSharedCarveOut(param)
|| advanceGridDim(param);
}
virtual void initTuneParam(TuneParam ¶m) const override
{
/* for nvshmem uber kernels the current synchronization requires us to keep the y and z dimension local to the
* block. This can be removed when we introduce a finer grained synchronization which takes into account the y and
* z components explicitly */
step_y = arg.shmem & 64 ? vector_length_y : step_y_bkup;
step_z = arg.shmem & 64 ? vector_length_z : step_z_bkup;
TunableKernel3D::initTuneParam(param);
if (arg.pack_threads && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL))
param.aux.x = 1; // packing blocks per direction
if (arg.exterior_dims && arg.kernel_type == UBER_KERNEL) param.aux.y = 1; // exterior blocks
// if not autotuning the carve out, set to the historical optimal value (prefer shared memory)
param.shared_carve_out = tuneSharedCarveOut() ? 0 : 100;
}
virtual void defaultTuneParam(TuneParam ¶m) const override
{
/* for nvshmem uber kernels the current synchronization requires use to keep the y and z dimension local to the
* block. This can be removed when we introduce a finer grained synchronization which takes into account the y and
* z components explicitly. */
step_y = arg.shmem & 64 ? vector_length_y : step_y_bkup;
step_z = arg.shmem & 64 ? vector_length_z : step_z_bkup;
TunableKernel3D::defaultTuneParam(param);
if (arg.pack_threads && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL))
param.aux.x = 1; // packing blocks per direction
if (arg.exterior_dims && arg.kernel_type == UBER_KERNEL) param.aux.y = 1; // exterior blocks
param.shared_carve_out = 100; // historical optimal value
}
/**
@brief This is a helper class that is used to instantiate the
correct templated kernel for the dslash. This can be used for
all dslash types, though in some cases we specialize to reduce
compilation time.
*/
template <template <bool, QudaPCType, typename> class P, bool dagger, bool xpay, KernelType kernel_type>
inline void launch(TuneParam &tp, const qudaStream_t &stream)
{
tp.set_max_shared_bytes = true;
if (dslash_prefetch_tma() && tp.block.z > 1) errorQuda("Z-dimension block size must be 1 when using TMA");
launch_device<dslash_functor>(
tp, stream, dslash_functor_arg<D, P, dagger, xpay, kernel_type, Arg>(arg, tp.block.x * tp.grid.x));
}
public:
/**
@brief This instantiate function is used to instantiate the
the KernelType template required for the multi-GPU dslash kernels.
@param[in] tp The tuning parameters to use for this kernel
@param[in] stream The qudaStream_t where the kernel will run
*/
template <template <bool, QudaPCType, typename> class P, bool dagger, bool xpay>
inline void instantiate(TuneParam &tp, const qudaStream_t &stream)
{
if (in.Location() == QUDA_CPU_FIELD_LOCATION) {
errorQuda("Not implemented");
} else {
switch (arg.kernel_type) {
case INTERIOR_KERNEL: launch<P, dagger, xpay, INTERIOR_KERNEL>(tp, stream); break;
#ifdef MULTI_GPU
#ifdef NVSHMEM_COMMS
case UBER_KERNEL: launch<P, dagger, xpay, UBER_KERNEL>(tp, stream); break;
#endif
case EXTERIOR_KERNEL_X: launch<P, dagger, xpay, EXTERIOR_KERNEL_X>(tp, stream); break;
case EXTERIOR_KERNEL_Y: launch<P, dagger, xpay, EXTERIOR_KERNEL_Y>(tp, stream); break;
case EXTERIOR_KERNEL_Z: launch<P, dagger, xpay, EXTERIOR_KERNEL_Z>(tp, stream); break;
case EXTERIOR_KERNEL_T: launch<P, dagger, xpay, EXTERIOR_KERNEL_T>(tp, stream); break;
case EXTERIOR_KERNEL_ALL: launch<P, dagger, xpay, EXTERIOR_KERNEL_ALL>(tp, stream); break;
default: errorQuda("Unexpected kernel type %d", arg.kernel_type);
#else
default: errorQuda("Unexpected kernel type %d for single-GPU build", arg.kernel_type);
#endif
}
}
}
/**
@brief This instantiate function is used to instantiate the
the dagger template
@param[in] tp The tuning parameters to use for this kernel
@param[in] stream The qudaStream_t where the kernel will run
*/
template <template <bool, QudaPCType, typename> class P, bool xpay>
inline void instantiate(TuneParam &tp, const qudaStream_t &stream)
{
if (arg.dagger)
instantiate<P, true, xpay>(tp, stream);
else
instantiate<P, false, xpay>(tp, stream);
}
/**
@brief This instantiate function is used to instantiate the
the xpay template
@param[in] tp The tuning parameters to use for this kernel
@param[in] stream The qudaStream_t where the kernel will run
*/
template <template <bool, QudaPCType, typename> class P>
inline void instantiate(TuneParam &tp, const qudaStream_t &stream)
{
if (arg.xpay)
instantiate<P, true>(tp, stream);
else
instantiate<P, false>(tp, stream);
}
Arg &dslashParam; // temporary addition for policy compatibility
Dslash(Arg &arg, cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in,
const ColorSpinorField &halo, const std::string &app_base = "") :
TunableKernel3D(in[0], (halo.X(4) + Arg::n_src_tile - 1) / Arg::n_src_tile, arg.nParity),
arg(arg),
out(out),
in(in),
halo(halo),
nDimComms(4),
dslashParam(arg)
{
if (checkLocation(out, in) == QUDA_CPU_FIELD_LOCATION)
errorQuda("CPU Fields not supported in Dslash framework yet");
// this sets the communications pattern for the packing kernel
setPackComms(arg.commDim);
if (!TunableKernel3D::tuneSharedCarveOut() && tuneSharedCarveOut())
strcat(TunableKernel3D::aux, getSharedCarveOutStr().c_str());
fillAuxBase(app_base);
#ifdef MULTI_GPU
fillAux(INTERIOR_KERNEL, "policy_kernel=interior,");
fillAux(UBER_KERNEL, "policy_kernel=uber,");
fillAux(EXTERIOR_KERNEL_ALL, "policy_kernel=exterior_all,");
fillAux(EXTERIOR_KERNEL_X, "policy_kernel=exterior_x,");
fillAux(EXTERIOR_KERNEL_Y, "policy_kernel=exterior_y,");
fillAux(EXTERIOR_KERNEL_Z, "policy_kernel=exterior_z,");
fillAux(EXTERIOR_KERNEL_T, "policy_kernel=exterior_t,");
#else
fillAux(INTERIOR_KERNEL, "policy_kernel=single,");
#endif // MULTI_GPU
fillAux(KERNEL_POLICY, "policy,");
#ifdef NVSHMEM_COMMS
strcpy(aux_barrier, aux[EXTERIOR_KERNEL_ALL]);
strcat(aux_barrier, ",shmem");
#endif
}
#ifdef NVSHMEM_COMMS
void setShmem(int shmem)
{
arg.shmem = shmem;
setUberTuning(arg.shmem & 64);
}
#else
void setShmem(int) { setUberTuning(arg.shmem & 64); }
#endif
void setPack(bool pack, MemoryLocation location)
{
if (!pack) {
arg.setPack(pack, packBuffer);
return;
}
for (int dim = 0; dim < 4; dim++) {
for (int dir = 0; dir < 2; dir++) {
if ((location & Remote) && comm_peer2peer_enabled(dir, dim)) { // pack to p2p remote
packBuffer[2 * dim + dir] = static_cast<char *>(halo.remoteFace_d(dir, dim)) + halo.GhostOffset(dim, 1 - dir);
} else if (location & Host && !comm_peer2peer_enabled(dir, dim)) { // pack to cpu memory
packBuffer[2 * dim + dir] = halo.myFace_hd(dir, dim);
} else if (location & Shmem) {
// we check whether we can directly pack into the in.remoteFace_d(dir, dim) buffer on the remote GPU
// pack directly into remote or local memory
packBuffer[2 * dim + dir] = halo.remoteFace_d(dir, dim) ?
static_cast<char *>(halo.remoteFace_d(dir, dim)) + halo.GhostOffset(dim, 1 - dir) :
halo.myFace_d(dir, dim);
// whether we need to shmem_putmem into the receiving buffer
packBuffer[2 * QUDA_MAX_DIM + 2 * dim + dir] = halo.remoteFace_d(dir, dim) ?
nullptr :
static_cast<char *>(halo.remoteFace_r()) + halo.GhostOffset(dim, 1 - dir);
} else { // pack to local gpu memory
packBuffer[2 * dim + dir] = halo.myFace_d(dir, dim);
}
}
}
arg.setPack(pack, packBuffer);
// set the tuning string for the fused interior + packer kernel
strcpy(aux_pack, aux[arg.kernel_type]);
strcat(aux_pack, "");
// label the locations we are packing to
// location label is nonp2p-p2p
switch ((int)location) {
case Device | Remote: strcat(aux_pack, ",device-remote"); break;
case Host | Remote: strcat(aux_pack, ",host-remote"); break;
case Device: strcat(aux_pack, ",device-device"); break;
case Host: strcat(aux_pack, comm_peer2peer_enabled_global() ? ",host-device" : ",host-host"); break;
case Shmem:
strcat(aux_pack, arg.exterior_dims > 0 ? ",shmemuber" : ",shmem");
strcat(aux_pack, (arg.shmem & 1 && arg.shmem & 2) ? "3" : "1");
strcat(aux_pack, comm_dim_topology_string());
break;
default: errorQuda("Unknown pack target location %d\n", location);
}
}
int Nface() const
{
return 2 * arg.nFace;
} // factor of 2 is for forwards/backwards (convention used in dslash policy)
int Dagger() const { return arg.dagger; }
const char *getAux(KernelType type) const { return aux[type]; }
void setAux(KernelType type, const char *aux_) { strcpy(aux[type], aux_); }
void augmentAux(KernelType type, const char *extra) { strcat(aux[type], extra); }
virtual TuneKey tuneKey() const override
{
auto aux_ = (arg.pack_blocks > 0 && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL)) ?
aux_pack :
((arg.shmem > 0 && arg.kernel_type == EXTERIOR_KERNEL_ALL) ? aux_barrier : aux[arg.kernel_type]);
return TuneKey(in.VolString().c_str(), typeid(*this).name(), aux_);
}
/**
@brief Save the output field since the output field is both
read from and written to in the exterior kernels
*/
virtual void preTune() override
{
if (arg.kernel_type != INTERIOR_KERNEL && arg.kernel_type != UBER_KERNEL && arg.kernel_type != KERNEL_POLICY)
out.backup();
}
/**
@brief Restore the output field if doing exterior kernel
*/
virtual void postTune() override
{
if (arg.kernel_type != INTERIOR_KERNEL && arg.kernel_type != UBER_KERNEL && arg.kernel_type != KERNEL_POLICY)
out.restore();
}
/*
per direction / dimension flops
spin project flops = Nc * Ns
SU(3) matrix-vector flops = (8 Nc - 2) * Nc
spin reconstruction flops = 2 * Nc * Ns (just an accumulation to all components)
xpay = 2 * 2 * Nc * Ns
So for the full dslash we have, where for the final spin
reconstruct we have -1 since the first direction does not
require any accumulation.
flops = (2 * Nd * Nc * Ns) + (2 * Nd * (Ns/2) * (8*Nc-2) * Nc) + ((2 * Nd - 1) * 2 * Nc * Ns)
flops_xpay = flops + 2 * 2 * Nc * Ns
For Wilson this should give 1344 for Nc=3,Ns=2 and 1368 for the xpay equivalent
*/
virtual long long flops() const override
{
int mv_flops = (8 * in.Ncolor() - 2) * in.Ncolor(); // SU(3) matrix-vector flops
int num_mv_multiply = in.Nspin() == 4 ? 2 : 1;
int ghost_flops = (num_mv_multiply * mv_flops + 2 * in.Ncolor() * in.Nspin());
int xpay_flops = 2 * 2 * in.Ncolor() * in.Nspin(); // multiply and add per real component
int num_dir = 2 * 4; // set to 4-d since we take care of 5-d fermions in derived classes where necessary
int pack_flops = (in.Nspin() == 4 ? 2 * in.Nspin() / 2 * in.Ncolor() : 0); // only flops if spin projecting
long long flops_ = 0;
// FIXME - should we count the xpay flops in the derived kernels
// since some kernels require the xpay in the exterior (preconditiond clover)
switch (arg.kernel_type) {
case EXTERIOR_KERNEL_X:
case EXTERIOR_KERNEL_Y:
case EXTERIOR_KERNEL_Z:
case EXTERIOR_KERNEL_T:
flops_ = (ghost_flops + (arg.xpay ? xpay_flops : xpay_flops / 2)) * 2 * halo.GhostFace()[arg.kernel_type];
break;
case EXTERIOR_KERNEL_ALL: {
long long ghost_sites
= 2 * (halo.GhostFace()[0] + halo.GhostFace()[1] + halo.GhostFace()[2] + halo.GhostFace()[3]);
flops_ = (ghost_flops + (arg.xpay ? xpay_flops : xpay_flops / 2)) * ghost_sites;
break;
}
case INTERIOR_KERNEL:
case UBER_KERNEL:
case KERNEL_POLICY: {
if (arg.pack_threads && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL))
flops_ += pack_flops * arg.nParity * halo.getDslashConstant().Ls * arg.pack_threads;
long long sites = halo.Volume();
flops_ = (num_dir * (in.Nspin() / 4) * in.Ncolor() * in.Nspin() + // spin project (=0 for staggered)
num_dir * num_mv_multiply * mv_flops + // SU(3) matrix-vector multiplies
((num_dir - 1) * 2 * in.Ncolor() * in.Nspin()))
* sites; // accumulation
if (arg.xpay) flops_ += xpay_flops * sites;
if (arg.kernel_type == KERNEL_POLICY) break;
// now correct for flops done by exterior kernel
long long ghost_sites = 0;
for (int d = 0; d < 4; d++)
if (arg.commDim[d]) ghost_sites += 2 * halo.GhostFace()[d];
flops_ -= ghost_flops * ghost_sites;
if (arg.kernel_type == INTERIOR_KERNEL && arg.pack_threads)
flops_ += pack_flops * arg.nParity * halo.getDslashConstant().Ls * arg.pack_threads;
break;
}
}
return flops_;
}
virtual long long bytes() const override
{
int gauge_bytes = arg.reconstruct * in.Precision();
bool isFixed = (in.Precision() == sizeof(short) || in.Precision() == sizeof(char)) ? true : false;
int spinor_bytes = 2 * in.Ncolor() * in.Nspin() * in.Precision() + (isFixed ? sizeof(float) : 0);
int proj_spinor_bytes = in.Nspin() == 4 ? spinor_bytes / 2 : spinor_bytes;
int ghost_bytes = (proj_spinor_bytes + gauge_bytes) + 2 * spinor_bytes; // 2 since we have to load the partial
int num_dir = 2 * 4; // set to 4-d since we take care of 5-d fermions in derived classes where necessary
int pack_bytes = 2 * ((in.Nspin() == 4 ? in.Nspin() / 2 : in.Nspin()) + in.Nspin()) * in.Ncolor() * in.Precision();
long long bytes_ = 0;
switch (arg.kernel_type) {
case EXTERIOR_KERNEL_X:
case EXTERIOR_KERNEL_Y:
case EXTERIOR_KERNEL_Z:
case EXTERIOR_KERNEL_T: bytes_ = ghost_bytes * 2 * halo.GhostFace()[arg.kernel_type]; break;
case EXTERIOR_KERNEL_ALL: {
long long ghost_sites
= 2 * (halo.GhostFace()[0] + halo.GhostFace()[1] + halo.GhostFace()[2] + halo.GhostFace()[3]);
bytes_ = ghost_bytes * ghost_sites;
break;
}
case INTERIOR_KERNEL:
case UBER_KERNEL:
case KERNEL_POLICY: {
if (arg.pack_threads && (arg.kernel_type == INTERIOR_KERNEL || arg.kernel_type == UBER_KERNEL))
bytes_ += pack_bytes * arg.nParity * halo.getDslashConstant().Ls * arg.pack_threads;
long long sites = halo.Volume();
bytes_ = (num_dir * gauge_bytes + ((num_dir - 2) * spinor_bytes + 2 * proj_spinor_bytes) + spinor_bytes) * sites;
if (arg.xpay) bytes_ += spinor_bytes;
if (arg.kernel_type == KERNEL_POLICY) break;
// now correct for bytes done by exterior kernel
long long ghost_sites = 0;
for (int d = 0; d < 4; d++)
if (arg.commDim[d]) ghost_sites += 2 * halo.GhostFace()[d];
bytes_ -= ghost_bytes * ghost_sites;
if (arg.kernel_type == INTERIOR_KERNEL && arg.pack_threads)
bytes_ += pack_bytes * arg.nParity * halo.getDslashConstant().Ls * arg.pack_threads;
break;
}
}
return bytes_;
}
};
} // namespace quda