Skip to content

Commit a29036f

Browse files
committed
AD with GPU is fixed
1 parent 1596828 commit a29036f

File tree

1 file changed

+199
-1
lines changed

1 file changed

+199
-1
lines changed

ext/AD/AD.jl

Lines changed: 199 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import Enzyme.EnzymeRules: augmented_primal, reverse, RevConfig, AugmentedReturn, needs_primal, needs_shadow
22
import LatticeMatrices: add_matrix!, add_matrix_Adag!, add_matrix_shiftedA!, add_matrix_shiftedAdag!, kernel_add_4D!, kernel_add_4D_dag!, kernel_add_4D_shift!, Adjoint_Lattice, get_shift,
33
kernel_Dmatrix_mul_AshiftB!, kernel_Dmatrix_mul_AshiftBdag!, kernel_clear_4D!,
4-
mul_ABdag!, mul_A_shiftBdag!, mul_AshiftB!, mul_shiftAshiftB!, substitute!, AbstractLattice, expt_TA!, clear_matrix!, set_halo!,
4+
mul_ABdag!, mul_A_shiftBdag!, mul_AshiftB!, mul_shiftAshiftB!, substitute!, AbstractLattice, expt!, expt_TA!, clear_matrix!, set_halo!,
55
fold_halo_dim_to_core_grad!
66
using PreallocatedArrays
77
using MPI
@@ -494,6 +494,34 @@ function ER.augmented_primal(cfg::ER.RevConfig,
494494
return ER.AugmentedReturn(nothing, nothing, (tapeA, tapeB))
495495
end
496496

497+
# Forward (augmented) pass of the Enzyme reverse rule for the five-argument
# `mul!(C, A, B, α, β)` when C, A and B are annotations of one LatticeMatrix type `T`.
#
# Steps:
#   1. Unwrap `α`/`β` (anything exposing a `.val` property is unwrapped, plain values
#      are used as they are) and run the primal `mul!` on the wrapped values.
#   2. Copy `A.val.A` and `B.val.A` into blocks borrowed from each operand's `temps`
#      pool, so the reverse pass can read the operands from the tape.
#
# Tape layout: `((A_copy, A_slot), (B_copy, B_slot), αval)`. The slots are what the
# reverse pass hands to `unused!` to give the blocks back. `β` is used only for the
# primal call and is not recorded on the tape.
#
# Returns the `AugmentedReturn` type computed by `ER.augmented_rule_return_type`.
function ER.augmented_primal(cfg::ER.RevConfig,
    ::ER.Const{typeof(mul!)},
    ::Type{RT},
    C::ER.Annotation{T},
    A::ER.Annotation{T},
    B::ER.Annotation{T},
    α::S1,
    β::S2,
) where {T<:LatticeMatrix,RT,S1,S2}
    αval = hasproperty(α, :val) ? α.val : α
    βval = hasproperty(β, :val) ? β.val : β
    primal_ret = mul!(C.val, A.val, B.val, αval, βval)

    tapeA_obj, it_tapeA = get_block(A.val.temps)
    tapeA_obj .= A.val.A
    tapeA = (tapeA_obj, it_tapeA)

    tapeB_obj, it_tapeB = get_block(B.val.temps)
    tapeB_obj .= B.val.A
    tapeB = (tapeB_obj, it_tapeB)

    tape = (tapeA, tapeB, αval)
    RetT = ER.augmented_rule_return_type(cfg, RT, tape)
    primal = ER.needs_primal(cfg) ? primal_ret : nothing
    # A requested shadow must be the shadow of the returned object. The previous code
    # produced `nothing` in both branches, which cannot populate a non-Nothing shadow
    # slot of `RetT`.
    # NOTE(review): assumes `mul!` returns `C.val` (as LinearAlgebra.mul! does) — confirm.
    shadow = ER.needs_shadow(cfg) ? C.dval : nothing
    return RetT(primal, shadow, tape)
end
524+
497525
function ER.reverse(cfg::ER.RevConfig,
498526
::ER.Const{typeof(mul!)},
499527
dCout, tape,
@@ -2184,6 +2212,94 @@ end
21842212
# Threshold handed as the last argument to the `kernel_expt_TA_rev_su2!` /
# `kernel_expt_TA_rev_su3!` launches of the `expt!` reverse rule.
# NOTE(review): presumably a small-norm cutoff inside those kernels — confirm.
const _expt_ta_eps_q = 1e-18
21852213
# The constant one third.
const fac13 = 1 / 3
21862214

2215+
# Forward (augmented) pass of the Enzyme reverse rule for `expt!(C, A, t)`.
#
# Executes the primal `expt!` on the wrapped values, then snapshots the input
# (`A.val.A`) and the freshly computed output (`C.val.A`) into blocks borrowed from
# the respective `temps` pools. The tape is `((A_copy, A_slot), (C_copy, C_slot))`;
# the slots are what the reverse pass passes to `unused!`.
#
# `t` may be an annotation (anything with a `.val` property) or a plain value.
# Neither a primal nor a shadow is returned, only the tape.
function ER.augmented_primal(cfg::ER.RevConfig,
    ::ER.Const{typeof(expt!)},
    ::Type{RT},
    C::ER.Annotation{<:LatticeMatrix},
    A::ER.Annotation{<:LatticeMatrix},
    t::S,
) where {RT,S}
    step = if hasproperty(t, :val)
        t.val
    else
        t
    end
    expt!(C.val, A.val, step)

    # Snapshot of the input operand.
    snapA, slotA = get_block(A.val.temps)
    snapA .= A.val.A

    # Snapshot of the result, taken after the primal has filled C.
    snapC, slotC = get_block(C.val.temps)
    snapC .= C.val.A

    saved = ((snapA, slotA), (snapC, slotC))
    return ER.AugmentedReturn(nothing, nothing, saved)
end
2235+
2236+
# Reverse pass of the Enzyme rule for `expt!(C, A, t)`.
#
# Inputs:
#   dCout  - output cotangent handed in by Enzyme; resolved through `_getshadow_out`,
#            falling back to the shadow stored in `C.dval`.
#   tape   - `((A_copy, A_slot), (C_copy, C_slot))` from the augmented pass, or
#            `nothing`, in which case the live `A.val.A` / `C.val.A` are read instead.
#   t      - annotation or plain value; only an `Active` `t` gets a computed cotangent.
#
# Effects:
#   * accumulates into the shadow of A through `kernel_expt_TA_rev_su2!` (NC = 2) or
#     `kernel_expt_TA_rev_su3!` (NC = 3); any other size raises an error,
#   * for `t isa Active`, reduces `kernel_expt_TA_dt!` over the local sites and sums
#     the result across `C.val.comm` with `MPI.Allreduce`,
#   * returns the borrowed tape blocks with `unused!` on every exit that returns a
#     value (the early exits used to skip this and leaked the blocks),
#   * zeroes the shadow of C when `_should_zero_dC(dCout)` holds.
#
# Returns `(nothing, nothing, dt)`. On the early exits `dt` is `_zero_cotangent(t)`,
# the same helper the `mul!` rule uses for its scalar arguments, so an `Active` `t`
# never receives `nothing`.
#
# NOTE(review): this rule reuses the `kernel_expt_TA_*` kernels for `expt!` — confirm
# they are the intended kernels for the non-TA exponential.
function ER.reverse(cfg::ER.RevConfig,
    ::ER.Const{typeof(expt!)},
    dCout, tape,
    C::ER.Annotation{<:LatticeMatrix},
    A::ER.Annotation{<:LatticeMatrix},
    t::S,
) where {S}
    tapeA = (tape === nothing) ? nothing : tape[1]
    tapeC = (tape === nothing) ? nothing : tape[2]

    # Hand the blocks taken in the augmented pass back to their pools.
    release_tapes = () -> begin
        tapeA === nothing || unused!(A.val.temps, tapeA[2])
        tapeC === nothing || unused!(C.val.temps, tapeC[2])
        return nothing
    end

    dt = _zero_cotangent(t)

    dC_struct = _getshadow_out(dCout, C)
    dC_struct isa LatticeMatrix || (dC_struct = _getshadow(C.dval))
    if dC_struct === nothing
        release_tapes()
        return (nothing, nothing, dt)
    end
    dCval = dC_struct.A

    dA_struct = _getshadow(A.dval)
    dAval = (dA_struct isa LatticeMatrix) ? dA_struct.A : nothing
    if dAval === nothing
        release_tapes()
        return (nothing, nothing, dt)
    end

    Aval = (tapeA === nothing) ? A.val.A : tapeA[1]
    Cval = (tapeC === nothing) ? C.val.A : tapeC[1]

    tval = hasproperty(t, :val) ? t.val : t

    if t isa Active
        # Local reduction over the sites of this rank, then a sum over all ranks.
        init = zero(real(zero(eltype(dCval))))
        dt_local = JACC.parallel_reduce(
            prod(C.val.PN),
            kernel_expt_TA_dt!,
            dCval, Cval, Aval,
            C.val.indexer, Val(C.val.NC1), Val(C.val.nw);
            init=init, op=+
        )
        dt = MPI.Allreduce(dt_local, MPI.SUM, C.val.comm)
    end

    if C.val.NC1 == 2 && C.val.NC2 == 2
        JACC.parallel_for(
            prod(C.val.PN),
            kernel_expt_TA_rev_su2!,
            dAval, dCval, Aval,
            C.val.indexer, Val(C.val.nw),
            tval, _expt_ta_eps_q
        )
    elseif C.val.NC1 == 3 && C.val.NC2 == 3
        JACC.parallel_for(
            prod(C.val.PN),
            kernel_expt_TA_rev_su3!,
            dAval, dCval, Aval,
            C.val.indexer, Val(C.val.nw),
            tval, _expt_ta_eps_q
        )
    else
        error("expt! reverse is only implemented for NC=2 or NC=3.")
    end

    release_tapes()

    _should_zero_dC(dCout) && _zero_shadow!(dC_struct)
    return (nothing, nothing, dt)
end
2302+
21872303
function ER.augmented_primal(cfg::ER.RevConfig,
21882304
::ER.Const{typeof(expt_TA!)},
21892305
::Type{RT},
@@ -3265,6 +3381,64 @@ function Enzyme.EnzymeRules.reverse(::RevConfig,
32653381
return (nothing, nothing, nothing)
32663382
end
32673383

3384+
# Reverse pass of the Enzyme rule for the five-argument `mul!(C, A, B, α, β)` on
# LatticeMatrix operands.
#
# Inputs:
#   dCout  - output cotangent handed in by Enzyme; resolved through `_getshadow_out`,
#            falling back to the shadow stored in `C.dval`.
#   _tape  - `((A_copy, A_slot), (B_copy, B_slot), αval)` from the augmented pass.
#   α, β   - scalar arguments; their cotangents are always `_zero_cotangent(...)`,
#            i.e. no gradient with respect to the scalars is computed here.
#
# Per lattice site, with `fac = conj(α)`:
#   dA += fac * dC * B'      (kernel_Dmatrix_mulABdagadd_scaled!)
#   dB += fac * A' * dC      (kernel_Dmatrix_mulAdagBadd_scaled!)
#
# The borrowed tape blocks are returned with `unused!` on every exit (the early exit
# used to skip this), and the shadow of C is zeroed when `_should_zero_dC(dCout)`.
#
# NOTE(review): `β` does not enter the reverse pass; the shadow of C is zeroed rather
# than scaled by conj(β). Confirm this is intended for β != 0.
#
# Returns `(nothing, nothing, nothing, dα, dβ)`.
function ER.reverse(cfg::ER.RevConfig,
    ::ER.Const{typeof(mul!)},
    dCout, _tape,
    C::ER.Annotation{<:LatticeMatrix},
    A::ER.Annotation{<:LatticeMatrix},
    B::ER.Annotation{<:LatticeMatrix},
    α::S1,
    β::S2,
) where {S1,S2}
    dα = _zero_cotangent(α)
    dβ = _zero_cotangent(β)

    tapeA, tapeB, tape_α = _tape

    dC_struct = _getshadow_out(dCout, C)
    dC_struct isa LatticeMatrix || (dC_struct = _getshadow(C.dval))
    if dC_struct === nothing
        # Nothing to propagate, but the tape blocks still have to go back to the pools.
        tapeA === nothing || unused!(A.val.temps, tapeA[2])
        tapeB === nothing || unused!(B.val.temps, tapeB[2])
        return (nothing, nothing, nothing, dα, dβ)
    end
    dCval = dC_struct.A

    dA_struct = _getshadow(A.dval)
    dB_struct = _getshadow(B.dval)
    dAval = (dA_struct === nothing) ? nothing : dA_struct.A
    dBval = (dB_struct === nothing) ? nothing : dB_struct.A

    Aval = (tapeA === nothing) ? A.val.A : tapeA[1]
    Bval = (tapeB === nothing) ? B.val.A : tapeB[1]

    # C is NC1 x NC2, A is NC1 x NC3, B is NC3 x NC2.
    NC1 = Val(C.val.NC1)
    NC2 = Val(C.val.NC2)
    NC3 = Val(A.val.NC2)
    nw = Val(C.val.nw)
    idxr = C.val.indexer
    Nsites = prod(C.val.PN)
    fac = conj(tape_α)

    # Each kernel takes (rows of target, columns of target, contracted dimension).
    # For square operands the order is irrelevant; for rectangular ones the target of
    # each product dictates it.
    if dAval isa AbstractArray
        # dA (NC1 x NC3) += fac * dC (NC1 x NC2) * B' (NC2 x NC3)
        JACC.parallel_for(
            Nsites, kernel_Dmatrix_mulABdagadd_scaled!,
            dAval, dCval, Bval, NC1, NC3, NC2, nw, idxr, fac
        )
    end

    if dBval isa AbstractArray
        # dB (NC3 x NC2) += fac * A' (NC3 x NC1) * dC (NC1 x NC2)
        JACC.parallel_for(
            Nsites, kernel_Dmatrix_mulAdagBadd_scaled!,
            dBval, Aval, dCval, NC3, NC2, NC1, nw, idxr, fac
        )
    end

    tapeA === nothing || unused!(A.val.temps, tapeA[2])
    tapeB === nothing || unused!(B.val.temps, tapeB[2])

    _should_zero_dC(dCout) && _zero_shadow!(dC_struct)
    return (nothing, nothing, nothing, dα, dβ)
end
3441+
32683442
function Enzyme.EnzymeRules.reverse(cfg::RevConfig,
32693443
::Const{typeof(LinearAlgebra.mul!)},
32703444
dCout, _tape,
@@ -3382,6 +3556,18 @@ end
33823556
end
33833557
end
33843558

3559+
# Per-site kernel: accumulates `C += fac * A * B'` on the lattice site addressed by
# the linear index `i`.
#
#   C is indexed [1:NC1, 1:NC2, site...]
#   A is indexed [1:NC1, 1:NC3, site...]
#   B is indexed [1:NC2, 1:NC3, site...]   (read conjugated)
#
# `delinearize(dindexer, i, nw)` supplies the trailing site indices.
@inline function kernel_Dmatrix_mulABdagadd_scaled!(i, C, A, B, ::Val{NC1}, ::Val{NC2}, ::Val{NC3}, ::Val{nw}, dindexer, fac) where {NC1,NC2,NC3,nw}
    site = delinearize(dindexer, i, nw)
    @inbounds for col in 1:NC2, k in 1:NC3
        # Conjugated B entry is independent of the row, so it is read once.
        bconj = conj(B[col, k, site...])
        for row in 1:NC1
            C[row, col, site...] += fac * A[row, k, site...] * bconj
        end
    end
end
3570+
33853571
@inline function kernel_Dmatrix_mulACadd_matrix!(i, dA, dC, B, ::Val{NC1}, ::Val{NC2}, ::Val{NC3}, ::Val{nw}, dindexer) where {NC1,NC2,NC3,nw}
33863572
indices = delinearize(dindexer, i, nw)
33873573
@inbounds for jc = 1:NC2
@@ -3452,6 +3638,18 @@ end
34523638
end
34533639
end
34543640

3641+
# Per-site kernel: accumulates `C += fac * A' * B` on the lattice site addressed by
# the linear index `i`.
#
#   C is indexed [1:NC1, 1:NC2, site...]
#   A is indexed [1:NC3, 1:NC1, site...]   (read conjugated)
#   B is indexed [1:NC3, 1:NC2, site...]
#
# `delinearize(dindexer, i, nw)` supplies the trailing site indices.
@inline function kernel_Dmatrix_mulAdagBadd_scaled!(i, C, A, B, ::Val{NC1}, ::Val{NC2}, ::Val{NC3}, ::Val{nw}, dindexer, fac) where {NC1,NC2,NC3,nw}
    site = delinearize(dindexer, i, nw)
    @inbounds for col in 1:NC2, k in 1:NC3
        # The B entry is independent of the row, so it is read once.
        bval = B[k, col, site...]
        for row in 1:NC1
            C[row, col, site...] += fac * conj(A[k, row, site...]) * bval
        end
    end
end
3652+
34553653
# Tell the reverse rules whether the shadow of C has to be cleared afterwards:
# `true` for every `dCout` except `nothing`.
@inline _should_zero_dC(dCout) = !(dCout === nothing)

0 commit comments

Comments
 (0)