@@ -473,6 +473,178 @@ function ER.reverse(cfg::ER.RevConfig,
473473 return (nothing , nothing , nothing )
474474end
475475
476+ # C = A * B' where B is Adjoint_Lattice wrapper.
477+ function ER. augmented_primal(cfg:: ER.RevConfig ,
478+ :: ER.Const{typeof(mul!)} ,
479+ :: Type{RT} ,
480+ C:: ER.Annotation{<:LatticeMatrix} ,
481+ A:: ER.Annotation{<:LatticeMatrix} ,
482+ B:: ER.Annotation{<:Adjoint_Lattice} ,
483+ ) where {RT}
484+ mul_ABdag!(C. val, A. val, B. val. data)
485+
486+ tapeA_obj, itA = get_block(A. val. temps)
487+ tapeA_obj .= A. val. A
488+ tapeA = (tapeA_obj, itA)
489+
490+ tapeB_obj, itB = get_block(B. val. data. temps)
491+ tapeB_obj .= B. val. data. A
492+ tapeB = (tapeB_obj, itB)
493+
494+ return ER. AugmentedReturn(nothing , nothing , (tapeA, tapeB))
495+ end
496+
497+ function ER. reverse(cfg:: ER.RevConfig ,
498+ :: ER.Const{typeof(mul!)} ,
499+ dCout, tape,
500+ C:: ER.Annotation{<:LatticeMatrix} ,
501+ A:: ER.Annotation{<:LatticeMatrix} ,
502+ B:: ER.Annotation{<:Adjoint_Lattice} ,
503+ )
504+ dC_struct = _getshadow_out(dCout, C)
505+ dC_struct isa LatticeMatrix || (dC_struct = _getshadow(C. dval))
506+ dC_struct === nothing && return (nothing , nothing , nothing )
507+ dCval = dC_struct. A
508+
509+ dA_struct = hasproperty(A, :dval) ? _getshadow(A. dval) : nothing
510+ dAval = (dA_struct isa LatticeMatrix) ? dA_struct. A : nothing
511+
512+ dB_struct = hasproperty(B, :dval) ? _getshadow_data(B. dval) : nothing
513+ dBval = (dB_struct isa LatticeMatrix) ? dB_struct. A : nothing
514+
515+ tapeA, tapeB = tape
516+ Aval = (tapeA === nothing ) ? A. val. A : tapeA[1 ]
517+ Bval = (tapeB === nothing ) ? B. val. data. A : tapeB[1 ]
518+
519+ NC1 = Val(C. val. NC1)
520+ NC2 = Val(C. val. NC2)
521+ NC3 = Val(A. val. NC2)
522+ nw = Val(C. val. nw)
523+ idxr = C. val. indexer
524+ Nsites = prod(C. val. PN)
525+
526+ if dAval != = nothing
527+ JACC. parallel_for(
528+ Nsites,
529+ kernel_Dmatrix_mulACadd!,
530+ dAval, dCval, Bval,
531+ NC1, NC2, NC3, nw, idxr
532+ )
533+ end
534+
535+ if dBval != = nothing
536+ JACC. parallel_for(
537+ Nsites,
538+ kernel_Dmatrix_mulCdagAadd!,
539+ dBval, dCval, Aval,
540+ NC2, NC1, NC3, nw, idxr
541+ )
542+ end
543+
544+ if tapeA != = nothing
545+ unused!(A. val. temps, tapeA[2 ])
546+ end
547+ if tapeB != = nothing
548+ unused!(B. val. data. temps, tapeB[2 ])
549+ end
550+
551+ _should_zero_dC(dCout) && _zero_shadow!(dC_struct)
552+ return (nothing , nothing , nothing )
553+ end
554+
555+ # C = β*C + α*A*B' where B is Adjoint_Lattice wrapper.
556+ function ER. augmented_primal(cfg:: ER.RevConfig ,
557+ :: ER.Const{typeof(mul!)} ,
558+ :: Type{RT} ,
559+ C:: ER.Annotation{<:LatticeMatrix} ,
560+ A:: ER.Annotation{<:LatticeMatrix} ,
561+ B:: ER.Annotation{<:Adjoint_Lattice} ,
562+ α:: S1 ,
563+ β:: S2 ,
564+ ) where {RT,S1,S2}
565+ αval = hasproperty(α, :val) ? α. val : α
566+ βval = hasproperty(β, :val) ? β. val : β
567+ primal_ret = mul_ABdag!(C. val, A. val, B. val. data, αval, βval)
568+
569+ tapeA_obj, itA = get_block(A. val. temps)
570+ tapeA_obj .= A. val. A
571+ tapeA = (tapeA_obj, itA)
572+
573+ tapeB_obj, itB = get_block(B. val. data. temps)
574+ tapeB_obj .= B. val. data. A
575+ tapeB = (tapeB_obj, itB)
576+
577+ tape = (tapeA, tapeB, αval)
578+ RetT = ER. augmented_rule_return_type(cfg, RT, tape)
579+ primal = ER. needs_primal(cfg) ? primal_ret : nothing
580+ shadow = ER. needs_shadow(cfg) ? nothing : nothing
581+ return RetT(primal, shadow, tape)
582+ end
583+
584+ function ER. reverse(cfg:: ER.RevConfig ,
585+ :: ER.Const{typeof(mul!)} ,
586+ dCout, tape,
587+ C:: ER.Annotation{<:LatticeMatrix} ,
588+ A:: ER.Annotation{<:LatticeMatrix} ,
589+ B:: ER.Annotation{<:Adjoint_Lattice} ,
590+ α:: S1 ,
591+ β:: S2 ,
592+ ) where {S1,S2}
593+ dα = _zero_cotangent(α)
594+ dβ = _zero_cotangent(β)
595+
596+ dC_struct = _getshadow_out(dCout, C)
597+ dC_struct isa LatticeMatrix || (dC_struct = _getshadow(C. dval))
598+ dC_struct === nothing && return (nothing , nothing , nothing , dα, dβ)
599+ dCval = dC_struct. A
600+
601+ dA_struct = hasproperty(A, :dval) ? _getshadow(A. dval) : nothing
602+ dAval = (dA_struct isa LatticeMatrix) ? dA_struct. A : nothing
603+
604+ dB_struct = hasproperty(B, :dval) ? _getshadow_data(B. dval) : nothing
605+ dBval = (dB_struct isa LatticeMatrix) ? dB_struct. A : nothing
606+
607+ tapeA, tapeB, tape_α = tape
608+ Aval = (tapeA === nothing ) ? A. val. A : tapeA[1 ]
609+ Bval = (tapeB === nothing ) ? B. val. data. A : tapeB[1 ]
610+
611+ NC1 = Val(C. val. NC1)
612+ NC2 = Val(C. val. NC2)
613+ NC3 = Val(A. val. NC2)
614+ nw = Val(C. val. nw)
615+ idxr = C. val. indexer
616+ Nsites = prod(C. val. PN)
617+ fac = conj(tape_α)
618+
619+ if dAval != = nothing
620+ JACC. parallel_for(
621+ Nsites,
622+ kernel_Dmatrix_mulACadd_scaled!,
623+ dAval, dCval, Bval,
624+ NC1, NC2, NC3, nw, idxr, fac
625+ )
626+ end
627+
628+ if dBval != = nothing
629+ JACC. parallel_for(
630+ Nsites,
631+ kernel_Dmatrix_mulCdagAadd_scaled!,
632+ dBval, dCval, Aval,
633+ NC2, NC1, NC3, nw, idxr, fac
634+ )
635+ end
636+
637+ if tapeA != = nothing
638+ unused!(A. val. temps, tapeA[2 ])
639+ end
640+ if tapeB != = nothing
641+ unused!(B. val. data. temps, tapeB[2 ])
642+ end
643+
644+ _should_zero_dC(dCout) && _zero_shadow!(dC_struct)
645+ return (nothing , nothing , nothing , dα, dβ)
646+ end
647+
476648# C = A * (shifted B)'
477649function ER. augmented_primal(cfg:: ER.RevConfig ,
478650 :: ER.Const{typeof(mul!)} ,
@@ -1027,6 +1199,167 @@ function ER.reverse(cfg::ER.RevConfig,
10271199 return _rev_mul_ABdag!(cfg, dCout, tape, C, A, B; do_dB= do_dB)
10281200end
10291201
1202+ # C = β*C + α*A*B'
1203+ function ER. augmented_primal(cfg:: ER.RevConfig ,
1204+ :: ER.Const{typeof(mul_ABdag!)} ,
1205+ :: Type{RT} ,
1206+ C:: ER.Annotation{<:LatticeMatrix} ,
1207+ A:: ER.Annotation{<:LatticeMatrix} ,
1208+ B:: ER.Annotation{<:LatticeMatrix} ,
1209+ α:: S1 ,
1210+ β:: S2 ,
1211+ ) where {RT,S1,S2}
1212+ αval = hasproperty(α, :val) ? α. val : α
1213+ βval = hasproperty(β, :val) ? β. val : β
1214+ primal_ret = mul_ABdag!(C. val, A. val, B. val, αval, βval)
1215+
1216+ tapeA_obj, itA = get_block(A. val. temps)
1217+ tapeA_obj .= A. val. A
1218+ tapeA = (tapeA_obj, itA)
1219+
1220+ tapeB_obj, itB = get_block(B. val. temps)
1221+ tapeB_obj .= B. val. A
1222+ tapeB = (tapeB_obj, itB)
1223+
1224+ tape = (tapeA, tapeB, αval)
1225+ RetT = ER. augmented_rule_return_type(cfg, RT, tape)
1226+ primal = ER. needs_primal(cfg) ? primal_ret : nothing
1227+ shadow = ER. needs_shadow(cfg) ? nothing : nothing
1228+ return RetT(primal, shadow, tape)
1229+ end
1230+
1231+ function ER. reverse(cfg:: ER.RevConfig ,
1232+ :: ER.Const{typeof(mul_ABdag!)} ,
1233+ dCout, tape,
1234+ C:: ER.Annotation{<:LatticeMatrix} ,
1235+ A:: ER.Annotation{<:LatticeMatrix} ,
1236+ B:: ER.Duplicated{<:LatticeMatrix} ,
1237+ α:: S1 ,
1238+ β:: S2 ,
1239+ ) where {S1,S2}
1240+ dα = _zero_cotangent(α)
1241+ dβ = _zero_cotangent(β)
1242+
1243+ dC_struct = _getshadow_out(dCout, C)
1244+ dC_struct isa LatticeMatrix || (dC_struct = _getshadow(C. dval))
1245+ dC_struct === nothing && return (nothing , nothing , nothing , dα, dβ)
1246+ dCval = dC_struct. A
1247+
1248+ dA_struct = hasproperty(A, :dval) ? _getshadow(A. dval) : nothing
1249+ dAval = (dA_struct isa LatticeMatrix) ? dA_struct. A : nothing
1250+
1251+ s = _getshadow(B. dval)
1252+ dBval = (s isa LatticeMatrix) ? s. A : nothing
1253+
1254+ tapeA, tapeB, tape_α = tape
1255+ Aval = (tapeA === nothing ) ? A. val. A : tapeA[1 ]
1256+ Bval = (tapeB === nothing ) ? B. val. A : tapeB[1 ]
1257+
1258+ NC1 = Val(C. val. NC1)
1259+ NC2 = Val(C. val. NC2)
1260+ NC3 = Val(A. val. NC2)
1261+ nw = Val(C. val. nw)
1262+ idxr = C. val. indexer
1263+ Nsites = prod(C. val. PN)
1264+ fac = conj(tape_α)
1265+
1266+ if dAval != = nothing && Bval != = nothing
1267+ JACC. parallel_for(
1268+ Nsites,
1269+ kernel_Dmatrix_mulACadd_scaled!,
1270+ dAval, dCval, Bval,
1271+ NC1, NC2, NC3, nw, idxr, fac
1272+ )
1273+ end
1274+
1275+ if dBval != = nothing
1276+ JACC. parallel_for(
1277+ Nsites,
1278+ kernel_Dmatrix_mulCdagAadd_scaled!,
1279+ dBval, dCval, Aval,
1280+ NC2, NC1, NC3, nw, idxr, fac
1281+ )
1282+ end
1283+
1284+ if tapeA != = nothing
1285+ unused!(A. val. temps, tapeA[2 ])
1286+ end
1287+ if tapeB != = nothing
1288+ unused!(B. val. temps, tapeB[2 ])
1289+ end
1290+
1291+ _should_zero_dC(dCout) && _zero_shadow!(dC_struct)
1292+ return (nothing , nothing , nothing , dα, dβ)
1293+ end
1294+
1295+ function ER. reverse(cfg:: ER.RevConfig ,
1296+ :: ER.Const{typeof(mul_ABdag!)} ,
1297+ dCout, tape,
1298+ C:: ER.Annotation{<:LatticeMatrix} ,
1299+ A:: ER.Annotation{<:LatticeMatrix} ,
1300+ B,
1301+ α:: S1 ,
1302+ β:: S2 ,
1303+ ) where {S1,S2}
1304+ dα = _zero_cotangent(α)
1305+ dβ = _zero_cotangent(β)
1306+
1307+ dC_struct = _getshadow_out(dCout, C)
1308+ dC_struct isa LatticeMatrix || (dC_struct = _getshadow(C. dval))
1309+ dC_struct === nothing && return (nothing , nothing , nothing , dα, dβ)
1310+ dCval = dC_struct. A
1311+
1312+ dA_struct = hasproperty(A, :dval) ? _getshadow(A. dval) : nothing
1313+ dAval = (dA_struct isa LatticeMatrix) ? dA_struct. A : nothing
1314+
1315+ dB_struct = (hasproperty(B, :dval) ? _getshadow(B. dval) : nothing )
1316+ dBval = (dB_struct isa LatticeMatrix) ? dB_struct. A : nothing
1317+
1318+ tapeA, tapeB, tape_α = tape
1319+ Aval = (tapeA === nothing ) ? A. val. A : tapeA[1 ]
1320+ Bval = if tapeB === nothing
1321+ hasproperty(B, :val) ? B. val. A : nothing
1322+ else
1323+ tapeB[1 ]
1324+ end
1325+
1326+ NC1 = Val(C. val. NC1)
1327+ NC2 = Val(C. val. NC2)
1328+ NC3 = Val(A. val. NC2)
1329+ nw = Val(C. val. nw)
1330+ idxr = C. val. indexer
1331+ Nsites = prod(C. val. PN)
1332+ fac = conj(tape_α)
1333+
1334+ if dAval != = nothing && Bval != = nothing
1335+ JACC. parallel_for(
1336+ Nsites,
1337+ kernel_Dmatrix_mulACadd_scaled!,
1338+ dAval, dCval, Bval,
1339+ NC1, NC2, NC3, nw, idxr, fac
1340+ )
1341+ end
1342+
1343+ if dBval != = nothing
1344+ JACC. parallel_for(
1345+ Nsites,
1346+ kernel_Dmatrix_mulCdagAadd_scaled!,
1347+ dBval, dCval, Aval,
1348+ NC2, NC1, NC3, nw, idxr, fac
1349+ )
1350+ end
1351+
1352+ if tapeA != = nothing
1353+ unused!(A. val. temps, tapeA[2 ])
1354+ end
1355+ if tapeB != = nothing && hasproperty(B, :val)
1356+ unused!(B. val. temps, tapeB[2 ])
1357+ end
1358+
1359+ _should_zero_dC(dCout) && _zero_shadow!(dC_struct)
1360+ return (nothing , nothing , nothing , dα, dβ)
1361+ end
1362+
10301363function _rev_mul_ABdag!(
10311364 cfg:: ER.RevConfig ,
10321365 dCout, tape,
@@ -3160,6 +3493,35 @@ end
31603493 end
31613494end
31623495
3496+ @inline function kernel_Dmatrix_mulACadd_scaled!(i, dA, dC, B,
3497+ :: Val{NC1} , :: Val{NC2} , :: Val{NC3} , :: Val{nw} , dindexer, fac
3498+ ) where {NC1,NC2,NC3,nw}
3499+ indices = delinearize(dindexer, i, nw)
3500+ @inbounds for kc = 1 : NC3
3501+ for jc = 1 : NC2
3502+ b = B[jc, kc, indices... ]
3503+ for ic = 1 : NC1
3504+ dA[ic, kc, indices... ] += fac * dC[ic, jc, indices... ] * b
3505+ end
3506+ end
3507+ end
3508+ end
3509+
3510+ @inline function kernel_Dmatrix_mulCdagAadd_scaled!(i, dB, dC, A,
3511+ :: Val{NC2} , :: Val{NC1} , :: Val{NC3} , :: Val{nw} , dindexer, fac
3512+ ) where {NC2,NC1,NC3,nw}
3513+ indices = delinearize(dindexer, i, nw)
3514+ @inbounds for jc = 1 : NC2
3515+ for kc = 1 : NC3
3516+ acc = zero(eltype(dB))
3517+ for ic = 1 : NC1
3518+ acc += fac * conj(dC[ic, jc, indices... ]) * A[ic, kc, indices... ]
3519+ end
3520+ dB[jc, kc, indices... ] += acc
3521+ end
3522+ end
3523+ end
3524+
31633525
31643526
31653527@inline function _replace_index(indices, dim, newval)
0 commit comments