@@ -215,19 +215,27 @@ _tryreverse(m, x) = x
215215_tryreverse (m:: typeof (map), x:: Union{AbstractVector, Tuple} ) = reverse (x)
216216
217217for (mapfunc,∇mapfunc) in [(:map ,:∇map ),(:pmap ,:∇pmap )]
218- @eval function $∇mapfunc (cx, f, args... )
218+ @eval function $∇mapfunc (cx, f:: F , args... ) where {F}
219219 ys_and_backs = $ mapfunc ((args... ) -> _pullback (cx, f, args... ), args... )
220220 if isempty (ys_and_backs)
221221 ys_and_backs, _ -> nothing
222222 else
223- ys, backs = unzip ( ys_and_backs)
223+ ys = map (first, ys_and_backs)
224224 ys, function (Δ)
225225 isnothing (Δ) && return nothing
226- # Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
227- Δf_and_args_zipped = $ mapfunc ((f, δ) -> f (δ), _tryreverse ($ mapfunc, backs, Δ)... )
228- Δf_and_args = unzip (_tryreverse ($ mapfunc, Δf_and_args_zipped))
229- Δf = reduce (accum, Δf_and_args[1 ])
230- (Δf, Δf_and_args[2 : end ]. .. )
226+ if _purefun (F) && length (args) == 1
227+ Δarg = $ mapfunc (((_,pb), δ) -> last (pb (δ)), ys_and_backs, Δ) # No unzip needed
228+ (nothing , Δarg)
229+ elseif _purefun (F)
230+ Δargs = unzip ($ mapfunc (((_,pb), δ) -> Base. tail (pb (δ)), ys_and_backs, Δ))
231+ (nothing , Δargs... )
232+ else
233+ # Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
234+ Δf_and_args_zipped = $ mapfunc (((_,pb), δ) -> pb (δ), _tryreverse ($ mapfunc, ys_and_backs, Δ)... )
235+ Δf_and_args = unzip (_tryreverse ($ mapfunc, Δf_and_args_zipped))
236+ Δf = reduce (accum, Δf_and_args[1 ])
237+ (Δf, Δf_and_args[2 : end ]. .. )
238+ end
231239 end
232240 end
233241 end
0 commit comments