-
Notifications
You must be signed in to change notification settings - Fork 113
Expand file tree
/
Copy pathcovariant_derivative.cuh
More file actions
193 lines (153 loc) · 7.95 KB
/
covariant_derivative.cuh
File metadata and controls
193 lines (153 loc) · 7.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
#pragma once
#include <dslash_helper.cuh>
#include <color_spinor_field_order.h>
#include <gauge_field_order.h>
#include <color_spinor.h>
#include <dslash_helper.cuh>
#include <index_helper.cuh>
#include <kernels/dslash_pack.cuh>
namespace quda
{
/**
   @brief Parameter structure for driving the covariant derivative operator.
   When shift_ is true the operator degenerates to a pure shift (the
   neighbor spinor is accumulated without a gauge-link multiplication —
   see applyCovDev).
*/
template <typename Float, int nSpin_, int nColor_, typename DDArg, QudaReconstructType reconstruct_, int nDim, bool shift_>
struct CovDevArg : DslashArg<Float, nDim, DDArg> {
  static constexpr int nColor = nColor_;
  static constexpr int nSpin = nSpin_;
  static constexpr bool spin_project = false; // covariant derivative acts on all spin components
  static constexpr bool spinor_direct_load = false; // false means texture load
  typedef typename colorspinor_mapper<Float, nSpin, nColor, spin_project, spinor_direct_load, true>::type F;
  using Ghost = typename colorspinor::GhostNOrder<Float, nSpin, nColor, spin_project, spinor_direct_load, false>;
  static constexpr QudaReconstructType reconstruct = reconstruct_;
  static constexpr bool gauge_direct_load = false; // false means texture load
  static constexpr QudaGhostExchange ghost = QUDA_GHOST_EXCHANGE_PAD;
  typedef typename gauge_mapper<Float, reconstruct, 18, QUDA_STAGGERED_PHASE_NO, gauge_direct_load, ghost>::type G;
  static constexpr bool shift = shift_; // if true, skip the gauge-link multiply in applyCovDev
  typedef typename mapper<Float>::type real;

  F out[MAX_MULTI_RHS]; /** output vector field */
  F in[MAX_MULTI_RHS];  /** input vector field */
  const Ghost halo_pack; /** accessor for writing the halo field */
  const Ghost halo;      /** accessor for reading the halo field */
  const G U;             /** the gauge field */
  int mu;                /** The direction in which to apply the derivative (0-3 forward, 4-7 backward) */

  /**
     @param[out] out Set of output fields, one per right-hand side
     @param[in] in Set of input fields, one per right-hand side
     @param[in] halo Halo field used for inter-process ghost exchange
     @param[in] U The gauge field
     @param[in] mu Derivative direction (forward for mu < 4, backward otherwise)
     @param[in] parity Site parity we are acting on
     @param[in] dagger Whether this is the daggered application
     @param[in] comm_override Override array for communication policy
     NOTE(review): `in` is also forwarded as the `x` argument of the DslashArg
     base constructor — presumably a placeholder since there is no xpay field
     here; confirm against DslashArg.
  */
  CovDevArg(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, const ColorSpinorField &halo,
            const GaugeField &U, int mu, int parity, bool dagger, const int *comm_override) :
    DslashArg<Float, nDim, DDArg>(out, in, halo, U, in, parity, dagger, false, spin_project, comm_override),
    halo_pack(halo),
    halo(halo),
    U(U),
    mu(mu)
  {
    // copy per-RHS accessors into the fixed-size arrays
    for (auto i = 0u; i < out.size(); i++) {
      this->out[i] = out[i];
      this->in[i] = in[i];
    }
  }
};
/**
   Applies the off-diagonal part of the covariant derivative operator:
   gathers the neighbor spinor in direction mu (forward for mu < 4,
   backward for mu >= 4) and accumulates it into out, multiplied by the
   appropriate gauge link unless Arg::shift is set (pure shift).
   @param[out] out The accumulated result vector
   @param[in] arg Parameter struct
   @param[in] coord Site coordinate struct
   @param[in] parity The site parity
   @param[in] idx Thread index (unused here; equal to face index for exterior kernels)
   @param[in] thread_dim Which dimension this thread corresponds to (fused exterior only)
   @param[in,out] active Whether this thread is active (fused exterior only)
   @param[in] src_idx Which right-hand side we are acting on
*/
template <bool dagger, KernelType kernel_type, int mu, typename Coord, typename Arg, typename Vector>
__device__ __host__ inline void applyCovDev(Vector &out, const Arg &arg, Coord &coord, int parity, int,
                                            int thread_dim, bool &active, int src_idx)
{
  typedef typename mapper<typename Arg::Float>::type real;
  typedef Matrix<complex<real>, Arg::nColor> Link;
  const int their_spinor_parity = (arg.nParity == 2) ? 1 - parity : 0;
  const int d = mu % 4;

  if (mu < 4 && arg.dd_in.doHopping(coord, d, +1)) {
    // Forward gather - compute fwd offset for vector fetch
    const int fwd_idx = getNeighborIndexCB(coord, d, +1, arg.dc);
    const bool ghost = (coord[d] + 1 >= arg.dc.X[d]) && isActive<kernel_type>(active, thread_dim, d, coord, arg);

    if (doHalo<kernel_type>(d) && ghost) {
      const int ghost_idx = ghostFaceIndex<1>(coord, arg.dc.X, d, arg.nFace);
      const Vector in = arg.halo.Ghost(d, 1, ghost_idx + src_idx * arg.dc.ghostFaceCB[d], their_spinor_parity);
      if constexpr (Arg::shift) {
        out += in;
      } else {
        // forward link lives at this site
        const Link U = arg.U(d, coord.x_cb, parity);
        out += U * in;
      }
    } else if (doBulk<kernel_type>() && !ghost) {
      const Vector in = arg.in[src_idx](fwd_idx, their_spinor_parity);
      if constexpr (Arg::shift) {
        out += in;
      } else {
        const Link U = arg.U(d, coord.x_cb, parity);
        out += U * in;
      }
    }
  } else if (mu >= 4 && arg.dd_in.doHopping(coord, d, -1)) {
    // Backward gather - compute back offset for spinor and gauge fetch
    const int back_idx = getNeighborIndexCB(coord, d, -1, arg.dc);
    const bool ghost = (coord[d] - 1 < 0) && isActive<kernel_type>(active, thread_dim, d, coord, arg);

    if (doHalo<kernel_type>(d) && ghost) {
      const int ghost_idx = ghostFaceIndex<0>(coord, arg.dc.X, d, arg.nFace);
      const Vector in = arg.halo.Ghost(d, 0, ghost_idx + src_idx * arg.dc.ghostFaceCB[d], their_spinor_parity);
      if constexpr (Arg::shift) {
        out += in;
      } else {
        // only load the link when we actually multiply by it (the original
        // code also performed an unconditional, shadowed load here)
        const Link U = arg.U.Ghost(d, ghost_idx, 1 - parity);
        out += conj(U) * in;
      }
    } else if (doBulk<kernel_type>() && !ghost) {
      const Vector in = arg.in[src_idx](back_idx, their_spinor_parity);
      if constexpr (Arg::shift) {
        out += in;
      } else {
        // backward link lives at the neighbor site; Arg::shift is known
        // false in this branch so no conditional load is needed
        const Link U = arg.U(d, back_idx, 1 - parity);
        out += conj(U) * in;
      }
    }
  } // Forward/backward derivative
}
// out(x) = M*in
template <bool dagger, bool xpay, KernelType kernel_type, typename Arg> struct covDev : dslash_default {

  const Arg &arg;
  template <typename Ftor> constexpr covDev(const Ftor &ftor) : arg(ftor.arg) { }
  static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation

  /**
     Kernel body: apply the covariant derivative in direction arg.mu to one
     checkerboard site for one right-hand side.
  */
  template <KernelType mykernel_type = kernel_type>
  __device__ __host__ inline void operator()(int idx, int src_idx, int parity)
  {
    using real = typename mapper<typename Arg::Float>::type;
    using Vector = ColorSpinor<real, Arg::nColor, Arg::nSpin>;

    // is thread active (non-trivial for the fused exterior kernel only)
    bool active = (mykernel_type != EXTERIOR_KERNEL_ALL);
    // which dimension this thread works on (fused kernel only)
    int dim_of_thread;
    auto coord = getCoords<QUDA_4D_PC, mykernel_type, Arg>(arg, idx, 0, parity, dim_of_thread);
    const int spinor_parity = (arg.nParity == 2) ? parity : 0;

    Vector result; // starts zero-initialized

    // domain-decomposition mask: a masked-out site just writes zero
    if (arg.dd_x.isZero(coord)) {
      if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](coord.x_cb, spinor_parity) = result;
      return;
    }

    // dispatch on arg.mu so it is a compile-time constant inside applyCovDev
    // (avoids register spillage from runtime indexing)
    switch (arg.mu) {
    case 0: applyCovDev<dagger, mykernel_type, 0>(result, arg, coord, parity, idx, dim_of_thread, active, src_idx); break;
    case 1: applyCovDev<dagger, mykernel_type, 1>(result, arg, coord, parity, idx, dim_of_thread, active, src_idx); break;
    case 2: applyCovDev<dagger, mykernel_type, 2>(result, arg, coord, parity, idx, dim_of_thread, active, src_idx); break;
    case 3: applyCovDev<dagger, mykernel_type, 3>(result, arg, coord, parity, idx, dim_of_thread, active, src_idx); break;
    case 4: applyCovDev<dagger, mykernel_type, 4>(result, arg, coord, parity, idx, dim_of_thread, active, src_idx); break;
    case 5: applyCovDev<dagger, mykernel_type, 5>(result, arg, coord, parity, idx, dim_of_thread, active, src_idx); break;
    case 6: applyCovDev<dagger, mykernel_type, 6>(result, arg, coord, parity, idx, dim_of_thread, active, src_idx); break;
    case 7: applyCovDev<dagger, mykernel_type, 7>(result, arg, coord, parity, idx, dim_of_thread, active, src_idx); break;
    }

    // exterior kernels accumulate onto the interior-kernel result already in out
    if (mykernel_type != INTERIOR_KERNEL && active) {
      Vector previous = arg.out[src_idx](coord.x_cb, spinor_parity);
      result += previous;
    }

    if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](coord.x_cb, spinor_parity) = result;
  }
};
} // namespace quda