Skip to content

Commit 1596828

Browse files
committed
AD with GPU bugfix
1 parent 8ddc10e commit 1596828

File tree

1 file changed

+362
-0
lines changed

1 file changed

+362
-0
lines changed

ext/AD/AD.jl

Lines changed: 362 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,178 @@ function ER.reverse(cfg::ER.RevConfig,
473473
return (nothing, nothing, nothing)
474474
end
475475

476+
# C = A * B' where B is Adjoint_Lattice wrapper.
477+
function ER.augmented_primal(cfg::ER.RevConfig,
478+
::ER.Const{typeof(mul!)},
479+
::Type{RT},
480+
C::ER.Annotation{<:LatticeMatrix},
481+
A::ER.Annotation{<:LatticeMatrix},
482+
B::ER.Annotation{<:Adjoint_Lattice},
483+
) where {RT}
484+
mul_ABdag!(C.val, A.val, B.val.data)
485+
486+
tapeA_obj, itA = get_block(A.val.temps)
487+
tapeA_obj .= A.val.A
488+
tapeA = (tapeA_obj, itA)
489+
490+
tapeB_obj, itB = get_block(B.val.data.temps)
491+
tapeB_obj .= B.val.data.A
492+
tapeB = (tapeB_obj, itB)
493+
494+
return ER.AugmentedReturn(nothing, nothing, (tapeA, tapeB))
495+
end
496+
497+
function ER.reverse(cfg::ER.RevConfig,
498+
::ER.Const{typeof(mul!)},
499+
dCout, tape,
500+
C::ER.Annotation{<:LatticeMatrix},
501+
A::ER.Annotation{<:LatticeMatrix},
502+
B::ER.Annotation{<:Adjoint_Lattice},
503+
)
504+
dC_struct = _getshadow_out(dCout, C)
505+
dC_struct isa LatticeMatrix || (dC_struct = _getshadow(C.dval))
506+
dC_struct === nothing && return (nothing, nothing, nothing)
507+
dCval = dC_struct.A
508+
509+
dA_struct = hasproperty(A, :dval) ? _getshadow(A.dval) : nothing
510+
dAval = (dA_struct isa LatticeMatrix) ? dA_struct.A : nothing
511+
512+
dB_struct = hasproperty(B, :dval) ? _getshadow_data(B.dval) : nothing
513+
dBval = (dB_struct isa LatticeMatrix) ? dB_struct.A : nothing
514+
515+
tapeA, tapeB = tape
516+
Aval = (tapeA === nothing) ? A.val.A : tapeA[1]
517+
Bval = (tapeB === nothing) ? B.val.data.A : tapeB[1]
518+
519+
NC1 = Val(C.val.NC1)
520+
NC2 = Val(C.val.NC2)
521+
NC3 = Val(A.val.NC2)
522+
nw = Val(C.val.nw)
523+
idxr = C.val.indexer
524+
Nsites = prod(C.val.PN)
525+
526+
if dAval !== nothing
527+
JACC.parallel_for(
528+
Nsites,
529+
kernel_Dmatrix_mulACadd!,
530+
dAval, dCval, Bval,
531+
NC1, NC2, NC3, nw, idxr
532+
)
533+
end
534+
535+
if dBval !== nothing
536+
JACC.parallel_for(
537+
Nsites,
538+
kernel_Dmatrix_mulCdagAadd!,
539+
dBval, dCval, Aval,
540+
NC2, NC1, NC3, nw, idxr
541+
)
542+
end
543+
544+
if tapeA !== nothing
545+
unused!(A.val.temps, tapeA[2])
546+
end
547+
if tapeB !== nothing
548+
unused!(B.val.data.temps, tapeB[2])
549+
end
550+
551+
_should_zero_dC(dCout) && _zero_shadow!(dC_struct)
552+
return (nothing, nothing, nothing)
553+
end
554+
555+
# C = β*C + α*A*B' where B is Adjoint_Lattice wrapper.
556+
function ER.augmented_primal(cfg::ER.RevConfig,
557+
::ER.Const{typeof(mul!)},
558+
::Type{RT},
559+
C::ER.Annotation{<:LatticeMatrix},
560+
A::ER.Annotation{<:LatticeMatrix},
561+
B::ER.Annotation{<:Adjoint_Lattice},
562+
α::S1,
563+
β::S2,
564+
) where {RT,S1,S2}
565+
αval = hasproperty(α, :val) ? α.val : α
566+
βval = hasproperty(β, :val) ? β.val : β
567+
primal_ret = mul_ABdag!(C.val, A.val, B.val.data, αval, βval)
568+
569+
tapeA_obj, itA = get_block(A.val.temps)
570+
tapeA_obj .= A.val.A
571+
tapeA = (tapeA_obj, itA)
572+
573+
tapeB_obj, itB = get_block(B.val.data.temps)
574+
tapeB_obj .= B.val.data.A
575+
tapeB = (tapeB_obj, itB)
576+
577+
tape = (tapeA, tapeB, αval)
578+
RetT = ER.augmented_rule_return_type(cfg, RT, tape)
579+
primal = ER.needs_primal(cfg) ? primal_ret : nothing
580+
shadow = ER.needs_shadow(cfg) ? nothing : nothing
581+
return RetT(primal, shadow, tape)
582+
end
583+
584+
function ER.reverse(cfg::ER.RevConfig,
585+
::ER.Const{typeof(mul!)},
586+
dCout, tape,
587+
C::ER.Annotation{<:LatticeMatrix},
588+
A::ER.Annotation{<:LatticeMatrix},
589+
B::ER.Annotation{<:Adjoint_Lattice},
590+
α::S1,
591+
β::S2,
592+
) where {S1,S2}
593+
= _zero_cotangent(α)
594+
= _zero_cotangent(β)
595+
596+
dC_struct = _getshadow_out(dCout, C)
597+
dC_struct isa LatticeMatrix || (dC_struct = _getshadow(C.dval))
598+
dC_struct === nothing && return (nothing, nothing, nothing, dα, dβ)
599+
dCval = dC_struct.A
600+
601+
dA_struct = hasproperty(A, :dval) ? _getshadow(A.dval) : nothing
602+
dAval = (dA_struct isa LatticeMatrix) ? dA_struct.A : nothing
603+
604+
dB_struct = hasproperty(B, :dval) ? _getshadow_data(B.dval) : nothing
605+
dBval = (dB_struct isa LatticeMatrix) ? dB_struct.A : nothing
606+
607+
tapeA, tapeB, tape_α = tape
608+
Aval = (tapeA === nothing) ? A.val.A : tapeA[1]
609+
Bval = (tapeB === nothing) ? B.val.data.A : tapeB[1]
610+
611+
NC1 = Val(C.val.NC1)
612+
NC2 = Val(C.val.NC2)
613+
NC3 = Val(A.val.NC2)
614+
nw = Val(C.val.nw)
615+
idxr = C.val.indexer
616+
Nsites = prod(C.val.PN)
617+
fac = conj(tape_α)
618+
619+
if dAval !== nothing
620+
JACC.parallel_for(
621+
Nsites,
622+
kernel_Dmatrix_mulACadd_scaled!,
623+
dAval, dCval, Bval,
624+
NC1, NC2, NC3, nw, idxr, fac
625+
)
626+
end
627+
628+
if dBval !== nothing
629+
JACC.parallel_for(
630+
Nsites,
631+
kernel_Dmatrix_mulCdagAadd_scaled!,
632+
dBval, dCval, Aval,
633+
NC2, NC1, NC3, nw, idxr, fac
634+
)
635+
end
636+
637+
if tapeA !== nothing
638+
unused!(A.val.temps, tapeA[2])
639+
end
640+
if tapeB !== nothing
641+
unused!(B.val.data.temps, tapeB[2])
642+
end
643+
644+
_should_zero_dC(dCout) && _zero_shadow!(dC_struct)
645+
return (nothing, nothing, nothing, dα, dβ)
646+
end
647+
476648
# C = A * (shifted B)'
477649
function ER.augmented_primal(cfg::ER.RevConfig,
478650
::ER.Const{typeof(mul!)},
@@ -1027,6 +1199,167 @@ function ER.reverse(cfg::ER.RevConfig,
10271199
return _rev_mul_ABdag!(cfg, dCout, tape, C, A, B; do_dB=do_dB)
10281200
end
10291201

1202+
# C = β*C + α*A*B'
1203+
function ER.augmented_primal(cfg::ER.RevConfig,
1204+
::ER.Const{typeof(mul_ABdag!)},
1205+
::Type{RT},
1206+
C::ER.Annotation{<:LatticeMatrix},
1207+
A::ER.Annotation{<:LatticeMatrix},
1208+
B::ER.Annotation{<:LatticeMatrix},
1209+
α::S1,
1210+
β::S2,
1211+
) where {RT,S1,S2}
1212+
αval = hasproperty(α, :val) ? α.val : α
1213+
βval = hasproperty(β, :val) ? β.val : β
1214+
primal_ret = mul_ABdag!(C.val, A.val, B.val, αval, βval)
1215+
1216+
tapeA_obj, itA = get_block(A.val.temps)
1217+
tapeA_obj .= A.val.A
1218+
tapeA = (tapeA_obj, itA)
1219+
1220+
tapeB_obj, itB = get_block(B.val.temps)
1221+
tapeB_obj .= B.val.A
1222+
tapeB = (tapeB_obj, itB)
1223+
1224+
tape = (tapeA, tapeB, αval)
1225+
RetT = ER.augmented_rule_return_type(cfg, RT, tape)
1226+
primal = ER.needs_primal(cfg) ? primal_ret : nothing
1227+
shadow = ER.needs_shadow(cfg) ? nothing : nothing
1228+
return RetT(primal, shadow, tape)
1229+
end
1230+
1231+
function ER.reverse(cfg::ER.RevConfig,
1232+
::ER.Const{typeof(mul_ABdag!)},
1233+
dCout, tape,
1234+
C::ER.Annotation{<:LatticeMatrix},
1235+
A::ER.Annotation{<:LatticeMatrix},
1236+
B::ER.Duplicated{<:LatticeMatrix},
1237+
α::S1,
1238+
β::S2,
1239+
) where {S1,S2}
1240+
= _zero_cotangent(α)
1241+
= _zero_cotangent(β)
1242+
1243+
dC_struct = _getshadow_out(dCout, C)
1244+
dC_struct isa LatticeMatrix || (dC_struct = _getshadow(C.dval))
1245+
dC_struct === nothing && return (nothing, nothing, nothing, dα, dβ)
1246+
dCval = dC_struct.A
1247+
1248+
dA_struct = hasproperty(A, :dval) ? _getshadow(A.dval) : nothing
1249+
dAval = (dA_struct isa LatticeMatrix) ? dA_struct.A : nothing
1250+
1251+
s = _getshadow(B.dval)
1252+
dBval = (s isa LatticeMatrix) ? s.A : nothing
1253+
1254+
tapeA, tapeB, tape_α = tape
1255+
Aval = (tapeA === nothing) ? A.val.A : tapeA[1]
1256+
Bval = (tapeB === nothing) ? B.val.A : tapeB[1]
1257+
1258+
NC1 = Val(C.val.NC1)
1259+
NC2 = Val(C.val.NC2)
1260+
NC3 = Val(A.val.NC2)
1261+
nw = Val(C.val.nw)
1262+
idxr = C.val.indexer
1263+
Nsites = prod(C.val.PN)
1264+
fac = conj(tape_α)
1265+
1266+
if dAval !== nothing && Bval !== nothing
1267+
JACC.parallel_for(
1268+
Nsites,
1269+
kernel_Dmatrix_mulACadd_scaled!,
1270+
dAval, dCval, Bval,
1271+
NC1, NC2, NC3, nw, idxr, fac
1272+
)
1273+
end
1274+
1275+
if dBval !== nothing
1276+
JACC.parallel_for(
1277+
Nsites,
1278+
kernel_Dmatrix_mulCdagAadd_scaled!,
1279+
dBval, dCval, Aval,
1280+
NC2, NC1, NC3, nw, idxr, fac
1281+
)
1282+
end
1283+
1284+
if tapeA !== nothing
1285+
unused!(A.val.temps, tapeA[2])
1286+
end
1287+
if tapeB !== nothing
1288+
unused!(B.val.temps, tapeB[2])
1289+
end
1290+
1291+
_should_zero_dC(dCout) && _zero_shadow!(dC_struct)
1292+
return (nothing, nothing, nothing, dα, dβ)
1293+
end
1294+
1295+
function ER.reverse(cfg::ER.RevConfig,
1296+
::ER.Const{typeof(mul_ABdag!)},
1297+
dCout, tape,
1298+
C::ER.Annotation{<:LatticeMatrix},
1299+
A::ER.Annotation{<:LatticeMatrix},
1300+
B,
1301+
α::S1,
1302+
β::S2,
1303+
) where {S1,S2}
1304+
= _zero_cotangent(α)
1305+
= _zero_cotangent(β)
1306+
1307+
dC_struct = _getshadow_out(dCout, C)
1308+
dC_struct isa LatticeMatrix || (dC_struct = _getshadow(C.dval))
1309+
dC_struct === nothing && return (nothing, nothing, nothing, dα, dβ)
1310+
dCval = dC_struct.A
1311+
1312+
dA_struct = hasproperty(A, :dval) ? _getshadow(A.dval) : nothing
1313+
dAval = (dA_struct isa LatticeMatrix) ? dA_struct.A : nothing
1314+
1315+
dB_struct = (hasproperty(B, :dval) ? _getshadow(B.dval) : nothing)
1316+
dBval = (dB_struct isa LatticeMatrix) ? dB_struct.A : nothing
1317+
1318+
tapeA, tapeB, tape_α = tape
1319+
Aval = (tapeA === nothing) ? A.val.A : tapeA[1]
1320+
Bval = if tapeB === nothing
1321+
hasproperty(B, :val) ? B.val.A : nothing
1322+
else
1323+
tapeB[1]
1324+
end
1325+
1326+
NC1 = Val(C.val.NC1)
1327+
NC2 = Val(C.val.NC2)
1328+
NC3 = Val(A.val.NC2)
1329+
nw = Val(C.val.nw)
1330+
idxr = C.val.indexer
1331+
Nsites = prod(C.val.PN)
1332+
fac = conj(tape_α)
1333+
1334+
if dAval !== nothing && Bval !== nothing
1335+
JACC.parallel_for(
1336+
Nsites,
1337+
kernel_Dmatrix_mulACadd_scaled!,
1338+
dAval, dCval, Bval,
1339+
NC1, NC2, NC3, nw, idxr, fac
1340+
)
1341+
end
1342+
1343+
if dBval !== nothing
1344+
JACC.parallel_for(
1345+
Nsites,
1346+
kernel_Dmatrix_mulCdagAadd_scaled!,
1347+
dBval, dCval, Aval,
1348+
NC2, NC1, NC3, nw, idxr, fac
1349+
)
1350+
end
1351+
1352+
if tapeA !== nothing
1353+
unused!(A.val.temps, tapeA[2])
1354+
end
1355+
if tapeB !== nothing && hasproperty(B, :val)
1356+
unused!(B.val.temps, tapeB[2])
1357+
end
1358+
1359+
_should_zero_dC(dCout) && _zero_shadow!(dC_struct)
1360+
return (nothing, nothing, nothing, dα, dβ)
1361+
end
1362+
10301363
function _rev_mul_ABdag!(
10311364
cfg::ER.RevConfig,
10321365
dCout, tape,
@@ -3160,6 +3493,35 @@ end
31603493
end
31613494
end
31623495

3496+
@inline function kernel_Dmatrix_mulACadd_scaled!(i, dA, dC, B,
3497+
::Val{NC1}, ::Val{NC2}, ::Val{NC3}, ::Val{nw}, dindexer, fac
3498+
) where {NC1,NC2,NC3,nw}
3499+
indices = delinearize(dindexer, i, nw)
3500+
@inbounds for kc = 1:NC3
3501+
for jc = 1:NC2
3502+
b = B[jc, kc, indices...]
3503+
for ic = 1:NC1
3504+
dA[ic, kc, indices...] += fac * dC[ic, jc, indices...] * b
3505+
end
3506+
end
3507+
end
3508+
end
3509+
3510+
@inline function kernel_Dmatrix_mulCdagAadd_scaled!(i, dB, dC, A,
3511+
::Val{NC2}, ::Val{NC1}, ::Val{NC3}, ::Val{nw}, dindexer, fac
3512+
) where {NC2,NC1,NC3,nw}
3513+
indices = delinearize(dindexer, i, nw)
3514+
@inbounds for jc = 1:NC2
3515+
for kc = 1:NC3
3516+
acc = zero(eltype(dB))
3517+
for ic = 1:NC1
3518+
acc += fac * conj(dC[ic, jc, indices...]) * A[ic, kc, indices...]
3519+
end
3520+
dB[jc, kc, indices...] += acc
3521+
end
3522+
end
3523+
end
3524+
31633525

31643526

31653527
@inline function _replace_index(indices, dim, newval)

0 commit comments

Comments
 (0)