@@ -37,16 +37,23 @@ function accum(x::RefValue, y::RefValue)
3737 return x
3838end
3939
40+ accum(x:: NamedTuple , y:: ChainRulesCore.Tangent ) = accum(x, wrap_chainrules_output(y))
41+ accum(x:: ChainRulesCore.Tangent , y:: NamedTuple ) = accum(wrap_chainrules_output(x), y)
42+
43+ accum(x, y:: AbstractThunk ) = @thunk(accum(x, unthunk(y)))
44+ accum(x:: AbstractThunk , y) = @thunk(accum(unthunk(x), y))
45+ accum(x:: AbstractThunk , y:: AbstractThunk ) = @thunk(accum(unthunk(x), unthunk(y)))
46+
4047# Core functions
41- @adjoint deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)
48+ @_adjoint_keepthunks deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)
4249
43- @adjoint (:: Type{V} )(x... ) where V<: Val = V(x... ), _ -> nothing
50+ @_adjoint_keepthunks (:: Type{V} )(x... ) where V<: Val = V(x... ), _ -> nothing
4451
45- @adjoint ifelse(cond:: Bool , t, f) =
52+ @_adjoint_keepthunks ifelse(cond:: Bool , t, f) =
4653 ifelse(cond, t, f),
4754 Δ -> cond ? (nothing , Δ, zero(Δ)) : (nothing , zero(Δ), Δ)
4855
49- @adjoint Base. typeassert(x, T) = Base. typeassert(x, T), Δ -> (Δ, nothing )
56+ @_adjoint_keepthunks Base. typeassert(x, T) = Base. typeassert(x, T), Δ -> (Δ, nothing )
5057
5158accum_param(:: Context{false} , _, Δ) = Δ
5259@generated function accum_param(cx:: Context , x, Δ)
7077
7178unwrap(x) = x
7279
73- @adjoint unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),)
80+ @_adjoint_keepthunks unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),)
7481
7582unwrap(ref, x) = x
7683
77- @adjoint unwrap(ref, x) = unwrap(x), function (x̄)
84+ @_adjoint_keepthunks unwrap(ref, x) = unwrap(x), function (x̄)
7885 accum_global(__context__, ref, x̄)
7986 (accum_param(__context__, x, x̄),)
8087end
@@ -88,7 +95,7 @@ function global_set(ref, val)
8895 end
8996end
9097
91- @adjoint ! function global_set(ref, x)
98+ @_adjoint_keepthunks ! function global_set(ref, x)
9299 global_set(ref, x), function (x̄)
93100 gs = cache(__context__)
94101 x̄ = accum(get(gs, ref, nothing ), x̄)
101108
102109using Base: tail
103110
104- @adjoint tuple(xs... ) = xs, identity
111+ @_adjoint_keepthunks tuple(xs... ) = xs, identity
105112
106- @adjoint function literal_getindex(xs:: NTuple{N,Any} , :: Val{i} ) where {N,i}
113+ @_adjoint_keepthunks function literal_getindex(xs:: NTuple{N,Any} , :: Val{i} ) where {N,i}
107114 val = xs[i]
108115 function back(Δ)
109116 accum_param(__context__, val, Δ) === nothing && return
@@ -112,7 +119,7 @@ using Base: tail
112119 val, back
113120end
114121
115- @adjoint function getindex(xs:: NTuple{N,Any} , i:: Integer ) where N
122+ @_adjoint_keepthunks function getindex(xs:: NTuple{N,Any} , i:: Integer ) where N
116123 val = xs[i]
117124 function back(Δ)
118125 accum_param(__context__, val, Δ) === nothing && return
@@ -121,10 +128,10 @@ end
121128 return val, back
122129end
123130
124- @adjoint getindex(xs:: NTuple{N,Any} , r:: AbstractUnitRange ) where N =
131+ @_adjoint_keepthunks getindex(xs:: NTuple{N,Any} , r:: AbstractUnitRange ) where N =
125132 (xs[r], Δ -> (ntuple(j -> j in r ? Δ[findfirst(isequal(j), r)] : nothing , Val(N)), nothing ))
126133
127- @adjoint function getindex(xs:: NTuple{N,Any} , r:: AbstractVector ) where N
134+ @_adjoint_keepthunks function getindex(xs:: NTuple{N,Any} , r:: AbstractVector ) where N
128135 val = xs[r]
129136 function back(Δ)
130137 dxs = ntuple(Val(length(xs))) do x
@@ -155,18 +162,18 @@ function _pullback(cx::AContext, ::typeof(literal_indexed_iterate), xs::Tuple, :
155162end
156163
157164# Needed for iteration lowering
158- @adjoint Core. getfield(xs:: NTuple{N,Any} , i:: Int ) where N =
165+ @_adjoint_keepthunks Core. getfield(xs:: NTuple{N,Any} , i:: Int ) where N =
159166 (xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing , Val(N)), nothing ))
160167
161- @adjoint Core. getfield(xs:: NamedTuple{K,<:NTuple{N,Any}} , i:: Int ) where {K,N} =
168+ @_adjoint_keepthunks Core. getfield(xs:: NamedTuple{K,<:NTuple{N,Any}} , i:: Int ) where {K,N} =
162169 (xs[i], Δ -> (NamedTuple{K}(ntuple(j -> i == j ? Δ : nothing , Val(N))), nothing ))
163170
164- @adjoint function Base. first(xs:: Tuple )
171+ @_adjoint_keepthunks function Base. first(xs:: Tuple )
165172 drest = map(_-> nothing , tail(xs))
166173 first(xs), Δ -> ((Δ, drest... ),)
167174end
168175
169- @adjoint Base. tail(xs:: Tuple ) = tail(xs), x̄s -> ((nothing , x̄s... ),)
176+ @_adjoint_keepthunks Base. tail(xs:: Tuple ) = tail(xs), x̄s -> ((nothing , x̄s... ),)
170177
171178_empty(x) = length(x)
172179_empty(x:: Union{Tuple,NamedTuple} ) = map(_-> nothing , x)
188195
189196unapply(t, xs) = _unapply(t, xs)[1 ]
190197
191- @adjoint ! function Core. _apply(f, args... )
198+ @_adjoint_keepthunks ! function Core. _apply(f, args... )
192199 y, back = Core. _apply(_pullback, (__context__, f), args... )
193200 st = map(_empty, args)
194201 y, function (Δ)
@@ -198,7 +205,7 @@ unapply(t, xs) = _unapply(t, xs)[1]
198205 end
199206end
200207
201- @adjoint ! function Core. _apply_iterate(:: typeof (iterate), f, args... )
208+ @_adjoint_keepthunks ! function Core. _apply_iterate(:: typeof (iterate), f, args... )
202209 y, back = Core. _apply(_pullback, (__context__, f), args... )
203210 st = map(_empty, args)
204211 y, function (Δ)
223230@generated pair(:: Val{k} , v, _= nothing ) where k = :($ k = v,)
224231@generated pair(:: Val{k} , v, :: NamedTuple{keys} ) where {k,keys} = k isa Int ? :($ (getfield(keys, k)) = v,) : :($ k = v,)
225232
226- @adjoint function literal_getfield(x, :: Val{f} ) where f
233+ @_adjoint_keepthunks function literal_getfield(x, :: Val{f} ) where f
227234 val = getfield(x, f)
228235 function back(Δ)
229236 accum_param(__context__, val, Δ) === nothing && return
@@ -273,8 +280,7 @@ function _get!(default::Base.Callable, ch, x)
273280 end
274281end
275282
276-
277- @adjoint! function setfield!(x, f, val)
283+ @_adjoint_keepthunks! function setfield!(x, f, val)
278284 y = setfield!(x, f, val)
279285 g = grad_mut(__context__, x)
280286 y, function (_)
@@ -290,13 +296,13 @@ end
290296
291297Jnew{T}(g) where T = Jnew{T,typeof(g)}(g)
292298
293- @adjoint ! function __new__(T, args... )
299+ @_adjoint_keepthunks ! function __new__(T, args... )
294300 x = __new__(T, args... )
295301 g = ! ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x)
296302 x, Jnew{T,typeof(g),false }(g)
297303end
298304
299- @adjoint ! function __splatnew__(T, args)
305+ @_adjoint_keepthunks ! function __splatnew__(T, args)
300306 x = __splatnew__(T, args)
301307 g = ! ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x)
302308 x, Jnew{T,typeof(g),true }(g)
0 commit comments