Skip to content

Commit 570db21

Browse files
committed
use purity check in map, too?
1 parent 86c3bb4 commit 570db21

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

src/lib/array.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -215,19 +215,27 @@ _tryreverse(m, x) = x
215215
_tryreverse(m::typeof(map), x::Union{AbstractVector, Tuple}) = reverse(x)
216216

217217
for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)]
218-
@eval function $∇mapfunc(cx, f, args...)
218+
@eval function $∇mapfunc(cx, f::F, args...) where {F}
219219
ys_and_backs = $mapfunc((args...) -> _pullback(cx, f, args...), args...)
220220
if isempty(ys_and_backs)
221221
ys_and_backs, _ -> nothing
222222
else
223-
ys, backs = unzip(ys_and_backs)
223+
ys = map(first, ys_and_backs)
224224
ys, function (Δ)
225225
isnothing(Δ) && return nothing
226-
# Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
227-
Δf_and_args_zipped = $mapfunc((f, δ) -> f(δ), _tryreverse($mapfunc, backs, Δ)...)
228-
Δf_and_args = unzip(_tryreverse($mapfunc, Δf_and_args_zipped))
229-
Δf = reduce(accum, Δf_and_args[1])
230-
(Δf, Δf_and_args[2:end]...)
226+
if _purefun(F) && length(args) == 1
227+
Δarg = $mapfunc(((_,pb), δ) -> last(pb(δ)), ys_and_backs, Δ) # No unzip needed
228+
(nothing, Δarg)
229+
elseif _purefun(F)
230+
Δargs = unzip($mapfunc(((_,pb), δ) -> Base.tail(pb(δ)), ys_and_backs, Δ))
231+
(nothing, Δargs...)
232+
else
233+
# Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
234+
Δf_and_args_zipped = $mapfunc(((_,pb), δ) -> pb(δ), _tryreverse($mapfunc, ys_and_backs, Δ)...)
235+
Δf_and_args = unzip(_tryreverse($mapfunc, Δf_and_args_zipped))
236+
Δf = reduce(accum, Δf_and_args[1])
237+
(Δf, Δf_and_args[2:end]...)
238+
end
231239
end
232240
end
233241
end

0 commit comments

Comments
 (0)