Skip to content

Commit ef0cd18

Browse files
committed
TA is revised
1 parent 693904b commit ef0cd18

File tree

1 file changed

+108
-101
lines changed

1 file changed

+108
-101
lines changed

src/TA/TA.jl

Lines changed: 108 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,26 @@ const sr3i = 1 / sr3
2121
const sr3ih = 0.5 * sr3i
2222
const sqr3inv = sr3i
2323
const sr3i2 = 2 * sr3i
24+
25+
function traceless_antihermitian!(A::TA,B::TB) where {D,T1,AT1,N,nw,DI,
26+
TA<:LatticeMatrix{D,T1,AT1,N,N,nw,DI},TB<:LatticeMatrix{D,T1,AT1,N,N,nw,DI}}
27+
substitute!(A, B)
28+
traceless_antihermitian!(A)
29+
end
2430

2531
function traceless_antihermitian!(A::TA) where {D,T1,AT1,N,nw,DI,TA<:LatticeMatrix{D,T1,AT1,N,N,nw,DI}}
2632
if N == 3
2733
JACC.parallel_for(
28-
prod(A.PN), kernel_traceless_antihermitian_4DNC3!, A.A, A.nw, A.PN)
34+
prod(A.PN), kernel_traceless_antihermitian_4DNC3!, A.A, A.nw, A.indexer)
2935
elseif N == 2
3036
JACC.parallel_for(
31-
prod(A.PN), kernel_traceless_antihermitian_4DNC2!, A.A, A.nw, A.PN)
37+
prod(A.PN), kernel_traceless_antihermitian_4DNC2!, A.A, A.nw, A.indexer)
3238
elseif N == 1
3339
@warn("No traceless antihermitian condition applied for SU(1). This is a scalar lattice, so no special unitary condition is needed.")
3440
# For N=1, no SU(N) condition is needed, as it is just a scalar.
3541
else
3642
JACC.parallel_for(
37-
prod(A.PN), kernel_traceless_antihermitian_4D!, A.A, N, A.nw, A.PN)
43+
prod(A.PN), kernel_traceless_antihermitian_4D!, A.A, N, A.nw, A.indexer)
3844
#error("Unsupported number of colors for special unitary lattice: $N")
3945
end
4046
set_halo!(A)
@@ -62,40 +68,41 @@ function traceless_antihermitian!(A::TA) where {T,AT,N, TA<:TALattice{4,T,AT,N}}
6268
end
6369

6470

65-
function kernel_traceless_antihermitian_4DNC2!(i, v, nw, PN)
66-
ix, iy, iz, it = get_4Dindex(i, PN)
67-
v11 = v[1, 1, ix, iy, iz, it]
68-
v22 = v[2, 2, ix, iy, iz, it]
71+
function kernel_traceless_antihermitian_4DNC2!(i, v, nw, dindexer)
72+
#ix, iy, iz, it = get_4Dindex(i, PN)
73+
indices = delinearize(dindexer, i, nw)
74+
v11 = v[1, 1, indices...]
75+
v22 = v[2, 2, indices...]
6976

7077
tri = (imag(v11) + imag(v22)) * 0.5
7178

72-
v12 = v[1, 2, ix, iy, iz, it]
73-
v21 = v[2, 1, ix, iy, iz, it]
79+
v12 = v[1, 2, indices...]
80+
v21 = v[2, 1, indices...]
7481

7582
x12 = v12 - conj(v21)
7683

7784
x21 = -conj(x12)
7885

79-
v[1, 1, ix, iy, iz, it] = (imag(v11) - tri) * im
80-
v[1, 2, ix, iy, iz, it] = 0.5 * x12
81-
v[2, 1, ix, iy, iz, it] = 0.5 * x21
82-
v[2, 2, ix, iy, iz, it] = (imag(v22) - tri) * im
86+
v[1, 1, indices...] = (imag(v11) - tri) * im
87+
v[1, 2, indices...] = 0.5 * x12
88+
v[2, 1, indices...] = 0.5 * x21
89+
v[2, 2, indices...] = (imag(v22) - tri) * im
8390

8491
end
8592

86-
function kernel_traceless_antihermitian_4DNC3!(i, v, nw, PN)
87-
ix, iy, iz, it = get_4Dindex(i, PN)
88-
v11 = v[1, 1, ix, iy, iz, it]
89-
v21 = v[2, 1, ix, iy, iz, it]
90-
v31 = v[3, 1, ix, iy, iz, it]
93+
function kernel_traceless_antihermitian_4DNC3!(i, v, nw, dindexer)
94+
indices = delinearize(dindexer, i, nw)
95+
v11 = v[1, 1, indices...]
96+
v21 = v[2, 1, indices...]
97+
v31 = v[3, 1, indices...]
9198

92-
v12 = v[1, 2, ix, iy, iz, it]
93-
v22 = v[2, 2, ix, iy, iz, it]
94-
v32 = v[3, 2, ix, iy, iz, it]
99+
v12 = v[1, 2, indices...]
100+
v22 = v[2, 2, indices...]
101+
v32 = v[3, 2, indices...]
95102

96-
v13 = v[1, 3, ix, iy, iz, it]
97-
v23 = v[2, 3, ix, iy, iz, it]
98-
v33 = v[3, 3, ix, iy, iz, it]
103+
v13 = v[1, 3, indices...]
104+
v23 = v[2, 3, indices...]
105+
v33 = v[3, 3, indices...]
99106

100107

101108
tri = fac13 * (imag(v11) + imag(v22) + imag(v33))
@@ -119,43 +126,43 @@ function kernel_traceless_antihermitian_4DNC3!(i, v, nw, PN)
119126
y31 = 0.5 * x31
120127
y32 = 0.5 * x32
121128

122-
v[1, 1, ix, iy, iz, it] = y11
123-
v[2, 1, ix, iy, iz, it] = y21
124-
v[3, 1, ix, iy, iz, it] = y31
129+
v[1, 1, indices...] = y11
130+
v[2, 1, indices...] = y21
131+
v[3, 1, indices...] = y31
125132

126-
v[1, 2, ix, iy, iz, it] = y12
127-
v[2, 2, ix, iy, iz, it] = y22
128-
v[3, 2, ix, iy, iz, it] = y32
133+
v[1, 2, indices...] = y12
134+
v[2, 2, indices...] = y22
135+
v[3, 2, indices...] = y32
129136

130-
v[1, 3, ix, iy, iz, it] = y13
131-
v[2, 3, ix, iy, iz, it] = y23
132-
v[3, 3, ix, iy, iz, it] = y33
137+
v[1, 3, indices...] = y13
138+
v[2, 3, indices...] = y23
139+
v[3, 3, indices...] = y33
133140

134141
end
135142

136-
function kernel_traceless_antihermitian_4D!(i, v, N, nw, PN)
137-
ix, iy, iz, it = get_4Dindex(i, PN)
143+
function kernel_traceless_antihermitian_4D!(i, v, N, nw, dindexer)
144+
indices = delinearize(dindexer, i, nw)
138145
fac1N = 1 / N
139146
tri = 0.0
140147
for k = 1:N
141-
tri += imag(v[k, k, ix, iy, iz, it])
148+
tri += imag(v[k, k, indices...])
142149
end
143150
tri *= fac1N
144151
for k = 1:N
145-
v[k, k, ix, iy, iz, it] =
146-
(imag(v[k, k, ix, iy, iz, it]) - tri) * im
152+
v[k, k, indices...] =
153+
(imag(v[k, k, indices...]) - tri) * im
147154
end
148155

149156

150157
for k1 = 1:N
151158
for k2 = k1+1:N
152159
vv =
153160
0.5 * (
154-
v[k1, k2, ix, iy, iz, it] -
155-
conj(v[k2, k1, ix, iy, iz, it])
161+
v[k1, k2, indices...] -
162+
conj(v[k2, k1, indices...])
156163
)
157-
v[k1, k2, ix, iy, iz, it] = vv
158-
v[k2, k1, ix, iy, iz, it] = -conj(vv)
164+
v[k1, k2, indices...] = vv
165+
v[k2, k1, indices...] = -conj(vv)
159166
end
160167
end
161168

@@ -182,19 +189,19 @@ function expt!(C::TC, A::TA, t::S=one(S)) where {T,AT,NC1,NC2,S<:Number,T1,AT1,
182189
set_halo!(C)
183190
end
184191

185-
function kernel_4Dexpt_SU3!(i, C, A, PN, t)
186-
ix, iy, iz, it = get_4Dindex(i, PN)
192+
function kernel_4Dexpt_SU3!(i, C, A, dindexer, t)
193+
indices = delinearize(dindexer, i, nw)
187194
T = eltype(C)
188195

189-
y11 = A[1, 1, ix, iy, iz, it]
190-
y22 = A[2, 2, ix, iy, iz, it]
191-
y33 = A[3, 3, ix, iy, iz, it]
192-
y12 = A[1, 2, ix, iy, iz, it]
193-
y13 = A[1, 3, ix, iy, iz, it]
194-
y21 = A[2, 1, ix, iy, iz, it]
195-
y23 = A[2, 3, ix, iy, iz, it]
196-
y31 = A[3, 1, ix, iy, iz, it]
197-
y32 = A[3, 2, ix, iy, iz, it]
196+
y11 = A[1, 1, indices...]
197+
y22 = A[2, 2, indices...]
198+
y33 = A[3, 3, indices...]
199+
y12 = A[1, 2, indices...]
200+
y13 = A[1, 3, indices...]
201+
y21 = A[2, 1, indices...]
202+
y23 = A[2, 3, indices...]
203+
y31 = A[3, 1, indices...]
204+
y32 = A[3, 2, indices...]
198205

199206
c1_0 = (imag(y12) + imag(y21))
200207
c2_0 = (real(y12) - real(y21))
@@ -217,15 +224,15 @@ function kernel_4Dexpt_SU3!(i, C, A, PN, t)
217224
csum = c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8
218225
if csum == 0
219226
c = Mat3{eltype(C)}(one(eltype(C)))
220-
C[1, 1, ix, iy, iz, it] = c.a11
221-
C[1, 2, ix, iy, iz, it] = c.a12
222-
C[1, 3, ix, iy, iz, it] = c.a13
223-
C[2, 1, ix, iy, iz, it] = c.a21
224-
C[2, 2, ix, iy, iz, it] = c.a22
225-
C[2, 3, ix, iy, iz, it] = c.a23
226-
C[3, 1, ix, iy, iz, it] = c.a31
227-
C[3, 2, ix, iy, iz, it] = c.a32
228-
C[3, 3, ix, iy, iz, it] = c.a33
227+
C[1, 1, indices...] = c.a11
228+
C[1, 2, indices...] = c.a12
229+
C[1, 3, indices...] = c.a13
230+
C[2, 1, indices...] = c.a21
231+
C[2, 2, indices...] = c.a22
232+
C[2, 3, indices...] = c.a23
233+
C[3, 1, indices...] = c.a31
234+
C[3, 2, indices...] = c.a32
235+
C[3, 3, indices...] = c.a33
229236

230237
end
231238

@@ -382,51 +389,51 @@ function kernel_4Dexpt_SU3!(i, C, A, PN, t)
382389
ww15 + im * ww16,
383390
ww17 + im * ww18)
384391
c = mul3(conjugate3(w), ww)
385-
#C[:, :, ix, iy, iz, it] = T[c.a11 c.a12 c.a13;
392+
#C[:, :, indices...] = T[c.a11 c.a12 c.a13;
386393
# c.a21 c.a22 c.a23;
387394
# c.a31 c.a32 c.a33]
388395

389-
C[1, 1, ix, iy, iz, it] = c.a11
390-
C[1, 2, ix, iy, iz, it] = c.a12
391-
C[1, 3, ix, iy, iz, it] = c.a13
392-
C[2, 1, ix, iy, iz, it] = c.a21
393-
C[2, 2, ix, iy, iz, it] = c.a22
394-
C[2, 3, ix, iy, iz, it] = c.a23
395-
C[3, 1, ix, iy, iz, it] = c.a31
396-
C[3, 2, ix, iy, iz, it] = c.a32
397-
C[3, 3, ix, iy, iz, it] = c.a33
396+
C[1, 1, indices...] = c.a11
397+
C[1, 2, indices...] = c.a12
398+
C[1, 3, indices...] = c.a13
399+
C[2, 1, indices...] = c.a21
400+
C[2, 2, indices...] = c.a22
401+
C[2, 3, indices...] = c.a23
402+
C[3, 1, indices...] = c.a31
403+
C[3, 2, indices...] = c.a32
404+
C[3, 3, indices...] = c.a33
398405

399406
#=
400-
w[1, 1, ix, iy, iz, it] = w1 + im * w2
401-
w[1, 2, ix, iy, iz, it] = w3 + im * w4
402-
w[1, 3, ix, iy, iz, it] = w5 + im * w6
403-
w[2, 1, ix, iy, iz, it] = w7 + im * w8
404-
w[2, 2, ix, iy, iz, it] = w9 + im * w10
405-
w[2, 3, ix, iy, iz, it] = w11 + im * w12
406-
w[3, 1, ix, iy, iz, it] = w13 + im * w14
407-
w[3, 2, ix, iy, iz, it] = w15 + im * w16
408-
w[3, 3, ix, iy, iz, it] = w17 + im * w18
409-
410-
ww[1, 1, ix, iy, iz, it] = ww1 + im * ww2
411-
ww[1, 2, ix, iy, iz, it] = ww3 + im * ww4
412-
ww[1, 3, ix, iy, iz, it] = ww5 + im * ww6
413-
ww[2, 1, ix, iy, iz, it] = ww7 + im * ww8
414-
ww[2, 2, ix, iy, iz, it] = ww9 + im * ww10
415-
ww[2, 3, ix, iy, iz, it] = ww11 + im * ww12
416-
ww[3, 1, ix, iy, iz, it] = ww13 + im * ww14
417-
ww[3, 2, ix, iy, iz, it] = ww15 + im * ww16
418-
ww[3, 3, ix, iy, iz, it] = ww17 + im * ww18
407+
w[1, 1, indices...] = w1 + im * w2
408+
w[1, 2, indices...] = w3 + im * w4
409+
w[1, 3, indices...] = w5 + im * w6
410+
w[2, 1, indices...] = w7 + im * w8
411+
w[2, 2, indices...] = w9 + im * w10
412+
w[2, 3, indices...] = w11 + im * w12
413+
w[3, 1, indices...] = w13 + im * w14
414+
w[3, 2, indices...] = w15 + im * w16
415+
w[3, 3, indices...] = w17 + im * w18
416+
417+
ww[1, 1, indices...] = ww1 + im * ww2
418+
ww[1, 2, indices...] = ww3 + im * ww4
419+
ww[1, 3, indices...] = ww5 + im * ww6
420+
ww[2, 1, indices...] = ww7 + im * ww8
421+
ww[2, 2, indices...] = ww9 + im * ww10
422+
ww[2, 3, indices...] = ww11 + im * ww12
423+
ww[3, 1, indices...] = ww13 + im * ww14
424+
ww[3, 2, indices...] = ww15 + im * ww16
425+
ww[3, 3, indices...] = ww17 + im * ww18
419426
=#
420427

421428
end
422429

423-
function kernel_4Dexpt_SU2!(i, uout, v, PN, t)
424-
ix, iy, iz, it = get_4Dindex(i, PN)
430+
function kernel_4Dexpt_SU2!(i, uout, v, dindexer, t)
431+
indices = delinearize(dindexer, i, nw)
425432

426-
y11 = v[1, 1, ix, iy, iz, it]
427-
y12 = v[1, 2, ix, iy, iz, it]
428-
y21 = v[2, 1, ix, iy, iz, it]
429-
y22 = v[2, 2, ix, iy, iz, it]
433+
y11 = v[1, 1, indices...]
434+
y12 = v[1, 2, indices...]
435+
y21 = v[2, 1, indices...]
436+
y22 = v[2, 2, indices...]
430437

431438
c1_0 = (imag(y12) + imag(y21))
432439
c2_0 = (real(y12) - real(y21))
@@ -444,8 +451,8 @@ function kernel_4Dexpt_SU2!(i, uout, v, PN, t)
444451
a2 = u2 * sR
445452
a3 = u3 * sR
446453

447-
uout[1, 1, ix, iy, iz, it] = cos(R) + im * a3
448-
uout[1, 2, ix, iy, iz, it] = im * a1 + a2
449-
uout[2, 1, ix, iy, iz, it] = im * a1 - a2
450-
uout[2, 2, ix, iy, iz, it] = cos(R) - im * a3
454+
uout[1, 1, indices...] = cos(R) + im * a3
455+
uout[1, 2, indices...] = im * a1 + a2
456+
uout[2, 1, indices...] = im * a1 - a2
457+
uout[2, 2, indices...] = cos(R) - im * a3
451458
end

0 commit comments

Comments
 (0)