Skip to content

Commit a3d2da0

Browse files
committed
AD is fixed
1 parent 0fedfca commit a3d2da0

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LatticeMatrices"
22
uuid = "dd6a91e4-736f-4540-ac85-13822ca7b545"
33
authors = ["Yuki Nagai <cometscome@gmail.com>"]
4-
version = "0.3.5"
4+
version = "0.3.6"
55

66
[deps]
77
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

ext/AD/AD.jl

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ end
2929
# Extract the shadow lattice matrix from an Annotation (Duplicated/MixedDuplicated).
3030
@inline _shadow_of(ann::ER.Annotation) = _getshadow(ann.dval)
3131

32+
# Return a zero cotangent for Active scalar arguments; otherwise return nothing.
33+
@inline _zero_cotangent(::Any) = nothing
34+
@inline _zero_cotangent(x::ER.Active{T}) where {T} = zero(T)
35+
3236

3337

3438

@@ -962,9 +966,14 @@ function ER.augmented_primal(cfg::ER.RevConfig,
962966
A::ER.Annotation{<:LatticeMatrix},
963967
α::S,
964968
) where {RT,S}
969+
RealRt = eltype(RT)
965970
αval = hasproperty(α, :val) ? α.val : α
966-
add_matrix_Adag!(C.val, A.val, αval)
967-
return ER.AugmentedReturn(nothing, nothing, nothing)
971+
primal_ret = add_matrix_Adag!(C.val, A.val, αval)
972+
primal = ER.needs_primal(cfg) ? convert(RealRt, primal_ret) : nothing
973+
shadow = ER.needs_shadow(cfg) ? convert(RealRt, nothing) : nothing
974+
cache = nothing::Any
975+
RetT = ER.augmented_rule_return_type(cfg, RT, cache)
976+
return RetT(primal, shadow, cache)
968977
end
969978

970979
function ER.reverse(cfg::ER.RevConfig,
@@ -974,9 +983,10 @@ function ER.reverse(cfg::ER.RevConfig,
974983
A::ER.Annotation{<:LatticeMatrix},
975984
α::S,
976985
) where {S}
986+
= _zero_cotangent(α)
977987
dC_struct = _getshadow_out(dCout, C)
978988
dC_struct isa LatticeMatrix || (dC_struct = _getshadow(C.dval))
979-
dC_struct === nothing && return (nothing, nothing, nothing)
989+
dC_struct === nothing && return (nothing, nothing, )
980990
dCval = dC_struct.A
981991

982992
dA_struct = hasproperty(A, :dval) ? _getshadow(A.dval) : nothing
@@ -992,7 +1002,7 @@ function ER.reverse(cfg::ER.RevConfig,
9921002
)
9931003
end
9941004

995-
return (nothing, nothing, nothing)
1005+
return (nothing, nothing, )
9961006
end
9971007

9981008
# add_matrix! (C += α * A)
@@ -1003,9 +1013,14 @@ function ER.augmented_primal(cfg::ER.RevConfig,
10031013
A::ER.Annotation{<:LatticeMatrix},
10041014
α::S,
10051015
) where {RT,S}
1016+
RealRt = eltype(RT)
10061017
αval = hasproperty(α, :val) ? α.val : α
1007-
add_matrix!(C.val, A.val, αval)
1008-
return ER.AugmentedReturn(nothing, nothing, nothing)
1018+
primal_ret = add_matrix!(C.val, A.val, αval)
1019+
primal = ER.needs_primal(cfg) ? convert(RealRt, primal_ret) : nothing
1020+
shadow = ER.needs_shadow(cfg) ? convert(RealRt, nothing) : nothing
1021+
cache = nothing::Any
1022+
RetT = ER.augmented_rule_return_type(cfg, RT, cache)
1023+
return RetT(primal, shadow, cache)
10091024
end
10101025

10111026
function ER.reverse(cfg::ER.RevConfig,
@@ -1015,9 +1030,10 @@ function ER.reverse(cfg::ER.RevConfig,
10151030
A::ER.Annotation{<:LatticeMatrix},
10161031
α::S,
10171032
) where {S}
1033+
= _zero_cotangent(α)
10181034
dC_struct = _getshadow_out(dCout, C)
10191035
dC_struct isa LatticeMatrix || (dC_struct = _getshadow(C.dval))
1020-
dC_struct === nothing && return (nothing, nothing, nothing)
1036+
dC_struct === nothing && return (nothing, nothing, )
10211037
dCval = dC_struct.A
10221038

10231039
dA_struct = hasproperty(A, :dval) ? _getshadow(A.dval) : nothing
@@ -1033,7 +1049,7 @@ function ER.reverse(cfg::ER.RevConfig,
10331049
)
10341050
end
10351051

1036-
return (nothing, nothing, nothing)
1052+
return (nothing, nothing, )
10371053
end
10381054

10391055

0 commit comments

Comments
 (0)