2929# Extract the shadow lattice matrix from an Annotation (Duplicated/MixedDuplicated).
3030@inline _shadow_of(ann:: ER.Annotation ) = _getshadow(ann. dval)
3131
32+ # Return a zero cotangent for Active scalar arguments; otherwise return nothing.
33+ @inline _zero_cotangent(:: Any ) = nothing
34+ @inline _zero_cotangent(x:: ER.Active{T} ) where {T} = zero(T)
35+
3236
3337
3438
@@ -962,9 +966,14 @@ function ER.augmented_primal(cfg::ER.RevConfig,
962966 A:: ER.Annotation{<:LatticeMatrix} ,
963967 α:: S ,
964968) where {RT,S}
969+ RealRt = eltype(RT)
965970 αval = hasproperty(α, :val) ? α. val : α
966- add_matrix_Adag!(C. val, A. val, αval)
967- return ER. AugmentedReturn(nothing , nothing , nothing )
971+ primal_ret = add_matrix_Adag!(C. val, A. val, αval)
972+ primal = ER. needs_primal(cfg) ? convert(RealRt, primal_ret) : nothing
973+ shadow = ER. needs_shadow(cfg) ? convert(RealRt, nothing ) : nothing
974+ cache = nothing :: Any
975+ RetT = ER. augmented_rule_return_type(cfg, RT, cache)
976+ return RetT(primal, shadow, cache)
968977end
969978
970979function ER. reverse(cfg:: ER.RevConfig ,
@@ -974,9 +983,10 @@ function ER.reverse(cfg::ER.RevConfig,
974983 A:: ER.Annotation{<:LatticeMatrix} ,
975984 α:: S ,
976985) where {S}
986+ dα = _zero_cotangent(α)
977987 dC_struct = _getshadow_out(dCout, C)
978988 dC_struct isa LatticeMatrix || (dC_struct = _getshadow(C. dval))
979- dC_struct === nothing && return (nothing , nothing , nothing )
989+ dC_struct === nothing && return (nothing , nothing , dα )
980990 dCval = dC_struct. A
981991
982992 dA_struct = hasproperty(A, :dval) ? _getshadow(A. dval) : nothing
@@ -992,7 +1002,7 @@ function ER.reverse(cfg::ER.RevConfig,
9921002 )
9931003 end
9941004
995- return (nothing , nothing , nothing )
1005+ return (nothing , nothing , dα )
9961006end
9971007
9981008# add_matrix! (C += α * A)
@@ -1003,9 +1013,14 @@ function ER.augmented_primal(cfg::ER.RevConfig,
10031013 A:: ER.Annotation{<:LatticeMatrix} ,
10041014 α:: S ,
10051015) where {RT,S}
1016+ RealRt = eltype(RT)
10061017 αval = hasproperty(α, :val) ? α. val : α
1007- add_matrix!(C. val, A. val, αval)
1008- return ER. AugmentedReturn(nothing , nothing , nothing )
1018+ primal_ret = add_matrix!(C. val, A. val, αval)
1019+ primal = ER. needs_primal(cfg) ? convert(RealRt, primal_ret) : nothing
1020+ shadow = ER. needs_shadow(cfg) ? convert(RealRt, nothing ) : nothing
1021+ cache = nothing :: Any
1022+ RetT = ER. augmented_rule_return_type(cfg, RT, cache)
1023+ return RetT(primal, shadow, cache)
10091024end
10101025
10111026function ER. reverse(cfg:: ER.RevConfig ,
@@ -1015,9 +1030,10 @@ function ER.reverse(cfg::ER.RevConfig,
10151030 A:: ER.Annotation{<:LatticeMatrix} ,
10161031 α:: S ,
10171032) where {S}
1033+ dα = _zero_cotangent(α)
10181034 dC_struct = _getshadow_out(dCout, C)
10191035 dC_struct isa LatticeMatrix || (dC_struct = _getshadow(C. dval))
1020- dC_struct === nothing && return (nothing , nothing , nothing )
1036+ dC_struct === nothing && return (nothing , nothing , dα )
10211037 dCval = dC_struct. A
10221038
10231039 dA_struct = hasproperty(A, :dval) ? _getshadow(A. dval) : nothing
@@ -1033,7 +1049,7 @@ function ER.reverse(cfg::ER.RevConfig,
10331049 )
10341050 end
10351051
1036- return (nothing , nothing , nothing )
1052+ return (nothing , nothing , dα )
10371053end
10381054
10391055
0 commit comments