|
1 | 1 | import Enzyme.EnzymeRules: augmented_primal, reverse, RevConfig, AugmentedReturn, needs_primal, needs_shadow |
2 | 2 | import LatticeMatrices: add_matrix!, add_matrix_Adag!, add_matrix_shiftedA!, add_matrix_shiftedAdag!, kernel_add_4D!, kernel_add_4D_dag!, kernel_add_4D_shift!, Adjoint_Lattice, get_shift, |
3 | 3 | kernel_Dmatrix_mul_AshiftB!, kernel_Dmatrix_mul_AshiftBdag!, kernel_clear_4D!, |
4 | | - mul_ABdag!, mul_A_shiftBdag!, mul_AshiftB!, mul_shiftAshiftB!, substitute!, AbstractLattice, expt_TA!, clear_matrix!, set_halo!, |
| 4 | + mul_ABdag!, mul_A_shiftBdag!, mul_AshiftB!, mul_shiftAshiftB!, substitute!, AbstractLattice, expt!, expt_TA!, clear_matrix!, set_halo!, |
5 | 5 | fold_halo_dim_to_core_grad! |
6 | 6 | using PreallocatedArrays |
7 | 7 | using MPI |
@@ -494,6 +494,34 @@ function ER.augmented_primal(cfg::ER.RevConfig, |
494 | 494 | return ER.AugmentedReturn(nothing, nothing, (tapeA, tapeB)) |
495 | 495 | end |
496 | 496 |
|
# Forward (augmented) pass of the reverse-mode rule for the five-argument
# `mul!(C, A, B, α, β)` on `LatticeMatrix` operands.
#
# Arguments
# - `cfg`: reverse-mode configuration; queried for whether the primal and the
#   shadow of the return value are requested.
# - `C`, `A`, `B`: annotated `LatticeMatrix` operands (all of the same type `T`).
# - `α`, `β`: scalars, either raw values or annotation wrappers exposing `.val`.
#
# Returns an `ER.AugmentedReturn(primal, shadow, tape)` where
# `tape == ((copyA, slotA), (copyB, slotB), αval)`. The copies live in blocks
# obtained from the operands' `temps` pools via `get_block`; the slot indices
# are kept on the tape so the reverse pass can hand the blocks back.
function ER.augmented_primal(cfg::ER.RevConfig,
    ::ER.Const{typeof(mul!)},
    ::Type{RT},
    C::ER.Annotation{T},
    A::ER.Annotation{T},
    B::ER.Annotation{T},
    α::S1,
    β::S2,
) where {T<:LatticeMatrix,RT,S1,S2}
    # Scalars may arrive wrapped (Const/Active) or as plain numbers.
    αval = hasproperty(α, :val) ? α.val : α
    βval = hasproperty(β, :val) ? β.val : β
    primal_ret = mul!(C.val, A.val, B.val, αval, βval)

    # Snapshot A and B so the reverse pass does not depend on their later state.
    tapeA_obj, it_tapeA = get_block(A.val.temps)
    tapeA_obj .= A.val.A
    tapeA = (tapeA_obj, it_tapeA)

    tapeB_obj, it_tapeB = get_block(B.val.temps)
    tapeB_obj .= B.val.A
    tapeB = (tapeB_obj, it_tapeB)

    tape = (tapeA, tapeB, αval)

    primal = ER.needs_primal(cfg) ? primal_ret : nothing
    # The shadow of the returned value is C's shadow when one exists.
    # NOTE(review): assumes `mul!` returns `C.val` whenever a shadow is
    # requested — confirm against the LatticeMatrix `mul!` method.
    shadow = (ER.needs_shadow(cfg) && !(C isa ER.Const)) ? C.dval : nothing

    # Build the return the same way as the other rules in this file.
    return ER.AugmentedReturn(primal, shadow, tape)
end
| 524 | + |
497 | 525 | function ER.reverse(cfg::ER.RevConfig, |
498 | 526 | ::ER.Const{typeof(mul!)}, |
499 | 527 | dCout, tape, |
@@ -2184,6 +2212,94 @@ end |
# Threshold handed as the trailing argument to the SU(2)/SU(3) exponential
# reverse kernels.
const _expt_ta_eps_q = 1.0e-18
# Precomputed one third.
const fac13 = 1 / 3
2186 | 2214 |
|
# Forward (augmented) pass of the reverse-mode rule for `expt!(C, A, t)`.
#
# Runs the primal `expt!`, then copies the data arrays of `A` and `C` (in that
# order) into blocks taken from their respective `temps` pools. The tape is
# `((copyA, slotA), (copyC, slotC))`; no primal or shadow value is returned.
function ER.augmented_primal(cfg::ER.RevConfig,
    ::ER.Const{typeof(expt!)},
    ::Type{RT},
    C::ER.Annotation{<:LatticeMatrix},
    A::ER.Annotation{<:LatticeMatrix},
    t::S,
) where {RT,S}
    # `t` may be a raw scalar or an annotation wrapper carrying `.val`.
    step = if hasproperty(t, :val)
        t.val
    else
        t
    end
    expt!(C.val, A.val, step)

    # Snapshot A first, then C, each into a block from its own pool.
    snapshots = map((A, C)) do operand
        lattice = operand.val
        buffer, slot = get_block(lattice.temps)
        buffer .= lattice.A
        (buffer, slot)
    end

    return ER.AugmentedReturn(nothing, nothing, snapshots)
end
| 2235 | + |
# Reverse pass of the rule for `expt!(C, A, t)`.
#
# Arguments
# - `dCout`: return-value cotangent/activity passed by Enzyme; forwarded to
#   `_getshadow_out` and `_should_zero_dC`.
# - `tape`: `((copyA, slotA), (copyC, slotC))` from the augmented primal, or
#   `nothing`, in which case the live `A.val.A` / `C.val.A` arrays are used.
# - `C`, `A`: annotated `LatticeMatrix` operands.
# - `t`: scalar, raw or annotated; a cotangent is produced only for `Active`.
#
# Returns `(nothing, nothing, dt)`; `dt` is `nothing` unless `t isa Active`.
#
# Every exit path returns the tape blocks to their pools with `unused!`.
# When `A` carries no shadow, only the `dA` kernel is skipped: `dt` is still
# reduced and `dC` is still cleared.
#
# Throws via `error` when `A` has a shadow and `C` is neither 2x2 nor 3x3.
function ER.reverse(cfg::ER.RevConfig,
    ::ER.Const{typeof(expt!)},
    dCout, tape,
    C::ER.Annotation{<:LatticeMatrix},
    A::ER.Annotation{<:LatticeMatrix},
    t::S,
) where {S}
    # Unpack the tape first so each exit below can release its blocks.
    tapeA = (tape === nothing) ? nothing : tape[1]
    tapeC = (tape === nothing) ? nothing : tape[2]

    dC_struct = _getshadow_out(dCout, C)
    dC_struct isa LatticeMatrix || (dC_struct = _getshadow(C.dval))
    if dC_struct === nothing
        # Nothing to propagate: release the tape blocks and hand an Active `t`
        # a zero cotangent instead of `nothing`.
        tapeA === nothing || unused!(A.val.temps, tapeA[2])
        tapeC === nothing || unused!(C.val.temps, tapeC[2])
        return (nothing, nothing, t isa Active ? zero(t.val) : nothing)
    end
    dCval = dC_struct.A

    dA_struct = _getshadow(A.dval)
    dAval = (dA_struct isa LatticeMatrix) ? dA_struct.A : nothing

    Aval = (tapeA === nothing) ? A.val.A : tapeA[1]
    Cval = (tapeC === nothing) ? C.val.A : tapeC[1]

    tval = hasproperty(t, :val) ? t.val : t

    dt = nothing
    if t isa Active
        # Site-local reduction followed by a sum over the communicator.
        init = zero(real(zero(eltype(dCval))))
        dt_local = JACC.parallel_reduce(
            prod(C.val.PN),
            kernel_expt_TA_dt!,
            dCval, Cval, Aval,
            C.val.indexer, Val(C.val.NC1), Val(C.val.nw);
            init=init, op=+
        )
        dt = MPI.Allreduce(dt_local, MPI.SUM, C.val.comm)
    end

    if dAval !== nothing
        if C.val.NC1 == 2 && C.val.NC2 == 2
            JACC.parallel_for(
                prod(C.val.PN),
                kernel_expt_TA_rev_su2!,
                dAval, dCval, Aval,
                C.val.indexer, Val(C.val.nw),
                tval, _expt_ta_eps_q
            )
        elseif C.val.NC1 == 3 && C.val.NC2 == 3
            JACC.parallel_for(
                prod(C.val.PN),
                kernel_expt_TA_rev_su3!,
                dAval, dCval, Aval,
                C.val.indexer, Val(C.val.nw),
                tval, _expt_ta_eps_q
            )
        else
            # Give the blocks back before failing so the pools stay balanced.
            tapeA === nothing || unused!(A.val.temps, tapeA[2])
            tapeC === nothing || unused!(C.val.temps, tapeC[2])
            error("expt! reverse is only implemented for NC=2 or NC=3.")
        end
    end

    tapeA === nothing || unused!(A.val.temps, tapeA[2])
    tapeC === nothing || unused!(C.val.temps, tapeC[2])

    _should_zero_dC(dCout) && _zero_shadow!(dC_struct)
    return (nothing, nothing, dt)
end
| 2302 | + |
2187 | 2303 | function ER.augmented_primal(cfg::ER.RevConfig, |
2188 | 2304 | ::ER.Const{typeof(expt_TA!)}, |
2189 | 2305 | ::Type{RT}, |
@@ -3265,6 +3381,64 @@ function Enzyme.EnzymeRules.reverse(::RevConfig, |
3265 | 3381 | return (nothing, nothing, nothing) |
3266 | 3382 | end |
3267 | 3383 |
|
# Reverse pass of the rule for the five-argument `mul!(C, A, B, α, β)`.
#
# Arguments
# - `dCout`: return-value cotangent/activity passed by Enzyme; forwarded to
#   `_getshadow_out` and `_should_zero_dC`.
# - `_tape`: `((copyA, slotA), (copyB, slotB), αval)` from the augmented
#   primal, or `nothing`, in which case the live `A.val.A` / `B.val.A` arrays
#   and the value of `α` itself are used.
# - `C`, `A`, `B`: annotated `LatticeMatrix` operands.
# - `α`, `β`: scalars, raw or annotated.
#
# Accumulates `dA += conj(α) * dC * B'` and `dB += conj(α) * A' * dC` per site
# for whichever of `A`, `B` has a shadow, and returns
# `(nothing, nothing, nothing, dα, dβ)` with `dα`, `dβ` from `_zero_cotangent`.
#
# Every exit path returns the tape blocks to their pools with `unused!`.
#
# NOTE(review): `dC` is not scaled by `conj(β)`, and `dα`/`dβ` are never
# accumulated — confirm this is intended for β ≠ 0 or Active scalars.
# NOTE(review): both kernels get (NC1, NC2, NC3) = (C.NC1, C.NC2, A.NC2); the
# index ranges appear to match the dA/dB shapes only when these are equal —
# confirm that square operands are assumed.
function ER.reverse(cfg::ER.RevConfig,
    ::ER.Const{typeof(mul!)},
    dCout, _tape,
    C::ER.Annotation{<:LatticeMatrix},
    A::ER.Annotation{<:LatticeMatrix},
    B::ER.Annotation{<:LatticeMatrix},
    α::S1,
    β::S2,
) where {S1,S2}
    dα = _zero_cotangent(α)
    dβ = _zero_cotangent(β)

    # Unpack the tape first so each exit below can release its blocks.
    if _tape === nothing
        tapeA = nothing
        tapeB = nothing
        tape_α = hasproperty(α, :val) ? α.val : α
    else
        tapeA, tapeB, tape_α = _tape
    end

    dC_struct = _getshadow_out(dCout, C)
    dC_struct isa LatticeMatrix || (dC_struct = _getshadow(C.dval))
    if dC_struct === nothing
        # Nothing to propagate, but the forward-pass blocks must be returned.
        tapeA === nothing || unused!(A.val.temps, tapeA[2])
        tapeB === nothing || unused!(B.val.temps, tapeB[2])
        return (nothing, nothing, nothing, dα, dβ)
    end
    dCval = dC_struct.A

    dA_struct = _getshadow(A.dval)
    dB_struct = _getshadow(B.dval)
    dAval = (dA_struct === nothing) ? nothing : dA_struct.A
    dBval = (dB_struct === nothing) ? nothing : dB_struct.A

    Aval = (tapeA === nothing) ? A.val.A : tapeA[1]
    Bval = (tapeB === nothing) ? B.val.A : tapeB[1]

    NC1 = Val(C.val.NC1)
    NC2 = Val(C.val.NC2)
    NC3 = Val(A.val.NC2)
    nw = Val(C.val.nw)
    idxr = C.val.indexer
    Nsites = prod(C.val.PN)
    fac = conj(tape_α)

    if dAval isa AbstractArray
        JACC.parallel_for(
            Nsites, kernel_Dmatrix_mulABdagadd_scaled!,
            dAval, dCval, Bval, NC1, NC2, NC3, nw, idxr, fac
        )
    end

    if dBval isa AbstractArray
        JACC.parallel_for(
            Nsites, kernel_Dmatrix_mulAdagBadd_scaled!,
            dBval, Aval, dCval, NC1, NC2, NC3, nw, idxr, fac
        )
    end

    tapeA === nothing || unused!(A.val.temps, tapeA[2])
    tapeB === nothing || unused!(B.val.temps, tapeB[2])

    _should_zero_dC(dCout) && _zero_shadow!(dC_struct)
    return (nothing, nothing, nothing, dα, dβ)
end
| 3441 | + |
3268 | 3442 | function Enzyme.EnzymeRules.reverse(cfg::RevConfig, |
3269 | 3443 | ::Const{typeof(LinearAlgebra.mul!)}, |
3270 | 3444 | dCout, _tape, |
@@ -3382,6 +3556,18 @@ end |
3382 | 3556 | end |
3383 | 3557 | end |
3384 | 3558 |
|
# Per-site kernel: accumulates `C[r, c] += fac * A[r, k] * conj(B[c, k])` for
# r in 1:NC1, c in 1:NC2, k in 1:NC3, i.e. `C += fac * A * B'` at the site
# addressed by linear index `i` (decoded with `delinearize`).
@inline function kernel_Dmatrix_mulABdagadd_scaled!(i, C, A, B, ::Val{NC1}, ::Val{NC2}, ::Val{NC3}, ::Val{nw}, dindexer, fac) where {NC1,NC2,NC3,nw}
    site = delinearize(dindexer, i, nw)
    @inbounds for col in 1:NC2, k in 1:NC3
        # Conjugated entry of B is shared by every row of this (col, k) pair.
        bdag = conj(B[col, k, site...])
        for row in 1:NC1
            C[row, col, site...] += fac * A[row, k, site...] * bdag
        end
    end
end
| 3570 | + |
3385 | 3571 | @inline function kernel_Dmatrix_mulACadd_matrix!(i, dA, dC, B, ::Val{NC1}, ::Val{NC2}, ::Val{NC3}, ::Val{nw}, dindexer) where {NC1,NC2,NC3,nw} |
3386 | 3572 | indices = delinearize(dindexer, i, nw) |
3387 | 3573 | @inbounds for jc = 1:NC2 |
@@ -3452,6 +3638,18 @@ end |
3452 | 3638 | end |
3453 | 3639 | end |
3454 | 3640 |
|
# Per-site kernel: accumulates `C[r, c] += fac * conj(A[k, r]) * B[k, c]` for
# r in 1:NC1, c in 1:NC2, k in 1:NC3, i.e. `C += fac * A' * B` at the site
# addressed by linear index `i` (decoded with `delinearize`).
@inline function kernel_Dmatrix_mulAdagBadd_scaled!(i, C, A, B, ::Val{NC1}, ::Val{NC2}, ::Val{NC3}, ::Val{nw}, dindexer, fac) where {NC1,NC2,NC3,nw}
    site = delinearize(dindexer, i, nw)
    @inbounds for col in 1:NC2, k in 1:NC3
        # Entry of B is shared by every row of this (col, k) pair.
        bval = B[k, col, site...]
        for row in 1:NC1
            C[row, col, site...] += fac * conj(A[k, row, site...]) * bval
        end
    end
end
| 3652 | + |
# Return `true` exactly when `dCout` is something other than `nothing`.
@inline _should_zero_dC(dCout) = !(dCout === nothing)
|
0 commit comments