@@ -465,7 +465,7 @@ function rrule(
465465 y = first(last(hobbits))
466466 project = ProjectTo(x)
467467 function foldl_pullback_tuple(dy)
468- trio = accumulate(_reverse1 (hobbits); init= (0 , dy, 0 )) do (_, dc, _), (_, back)
468+ trio = accumulate(reverse (hobbits); init= (0 , dy, 0 )) do (_, dc, _), (_, back)
469469 ds, da, db = back(dc)
470470 # Don't need to store every `da`, need one for the next iteration + the last.
471471 end
@@ -501,78 +501,43 @@ end
501501
502502# The implementation was originally for both tuples and arrays, although using accumulate
503503# to carry intermediate results along creates arrays of tuples which could be avoided.
504- # Using a loop can be a few times faster, this should be replaced.
505- # Note also that it does not return a gradient for `init`.
504+ # Using a loop can be a few times faster, this should be replaced:
505+ # https://github.com/FluxML/Zygote.jl/issues/644#issuecomment-628762305
506+
507+ # Note also that it does not return a gradient for `init`, now marked `@not_implemented`.
506508
507509function rrule(
508- config:: RuleConfig{>:HasReverseMode} , :: typeof (Base. mapfoldl_impl), :: typeof (identity), op:: G , init, x:: Union{AbstractArray, Tuple} ;
510+ config:: RuleConfig{>:HasReverseMode} , :: typeof (Base. mapfoldl_impl), :: typeof (identity), op:: G , init, x:: Union{AbstractArray, Tuple} ;
509511 ) where {G}
510- list, start = if init === _INIT
511- _drop1(x), first (x)
512+ start, list = if init === Base . _InitialValue()
513+ Iterators . peel (x)
512514 else
513515 # Case with init keyword is simpler to understand first!
514- _reshape1(x, :), init # (vec is for Julia 1.0, accumulate is fussy)
516+ init, x
515517 end
516- hobbits = accumulate(list; init= (start, nothing )) do (a,_), b
517- # Here `a` is what we would normally cary forward, and `_` ignores
518- # the previous iteration's pullback function (needed later),
519- # while `b` is the fresh input from `list` as usual.
520- c, back = rrule_via_ad(config, op, a, b) # LHS is just documentation here!
521- # We don't really need to store every `c`, last one is `foldl` output.
522- # (The name, BTW, is because "there and back again" is the subtitle of Tolkien's book.)
518+ hobbits = accumulate(list; init= (start, nothing )) do (a, _), b
519+ c, back = rrule_via_ad(config, op, a, b)
523520 end
524521 y = first(last(hobbits))
525522 axe = axes(x)
526523 project = ProjectTo(x)
527524 function unfoldl(dy)
528- trio = accumulate(_reverse1 (hobbits); init= (0 , dy, 0 )) do (_, dc, _), (_, back)
525+ trio = accumulate(Iterators . reverse (hobbits); init= (0 , dy, 0 )) do (_, dc, _), (_, back)
529526 ds, da, db = back(dc)
530- # Don't need to store every `da`, need one for the next iteration + maybe last
531527 end
532528 dop = sum(first, trio)
533- dx = map(last, _reverse1(trio))
534- if init === _INIT
535- # `hobbits` is one short
529+ dx = map(last, Iterators. reverse(trio))
530+ if init === Base. _InitialValue() # `hobbits` is one short
536531 dx = _vcat1(trio[end ][2 ], dx)
537532 end
538533 d_init = @not_implemented " gradient for foldl does not at present include init, sorry"
539- return (NoTangent(), NoTangent(), dop, d_init, project(_reshape1 (dx, axe)))
534+ return (NoTangent(), NoTangent(), dop, d_init, project(reshape (dx, axe)))
540535 end
541536 return y, unfoldl
542537end
543538
544-
545- # ####
546- # #### Iterator-or-Tuple functions
547- # ####
548-
549- # This zoo of underscore functions helps `foldl` & `accumulate` handle both tuples and arrays,
550- # and also provides some alternatives for versions of Julia where iterators weren't supported.
551- # Inspired by `Base._reverse`, used in defn of `foldr`.
552-
553- # To support 2nd derivatives, some may need their own gradient rules. And _drop1 should perhaps
554- # be replaced by _peel1 like Iterators.peel
555-
556- _reverse1(x) = Iterators. reverse(x)
557- _drop1(x) = Iterators. drop(x, 1 )
558- _zip2(x, y) = zip(x, y) # for `accumulate`, below
559-
560- _reverse1(x:: Tuple ) = reverse(x)
561- _drop1(x:: Tuple ) = Base. tail(x)
562- _zip2(x:: Tuple{Vararg{Any,N}} , y:: Tuple{Vararg{Any,N}} ) where N = ntuple(i -> (x[i],y[i]), N)
563-
564- const _INIT = Base. _InitialValue()
565-
566539_vcat1(x, ys:: AbstractVector ) = vcat(x, ys)
567540_vcat1(x:: AbstractArray , ys:: AbstractVector ) = vcat([x], ys)
568- _vcat1(x, ys:: Tuple ) = (x, ys... )
569-
570- _reshape1(x:: AbstractArray , axe) = reshape(x, axe)
571- _reshape1(x:: Tuple , axe) = x
572-
573- _no_tuple_tangent(dx:: Tangent ) = ChainRulesCore. backing(dx)
574- _no_tuple_tangent(dx) = dx
575-
576541
577542# ####
578543# #### `accumulate`
@@ -584,13 +549,18 @@ _no_tuple_tangent(dx) = dx
584549# Move it down to: `_accumulate!(op, B, A::AbstractVector, dims::Nothing, init::Nothing)`
585550
586551function rrule(
587- config:: RuleConfig{>:HasReverseMode} , :: typeof (Base. _accumulate!), op:: G , y, x:: AbstractVector , dims:: Nothing , init,
552+ config:: RuleConfig{>:HasReverseMode} ,
553+ :: typeof (Base. _accumulate!),
554+ op:: G , y:: AbstractVector ,
555+ x:: AbstractVector ,
556+ dims:: Nothing ,
557+ init,
588558 ) where {G}
589559
590- list, start = if init === nothing
591- _drop1(x), first (x)
560+ start, list = if init === nothing
561+ Iterators . peel (x)
592562 else
593- x, something(init)
563+ something(init), x
594564 end
595565 hobbits = accumulate(list; init = (start, nothing )) do (a, _), b
596566 c, back = rrule_via_ad(config, op, a, b)
@@ -607,28 +577,24 @@ function rrule(
607577 axe = axes(x)
608578 project = ProjectTo(x)
609579 function decumulate(dy)
610- dy_plain = _no_tuple_tangent(unthunk(dy))
611- rev_list = if init === nothing
612- # Here we rely on `zip` to stop early. Begin explicit with _reverse1(_drop1(...))
613- # gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{"
614- _zip2(_reverse1(hobbits), _reverse1(dy_plain))
615- else
616- _zip2(_reverse1(hobbits), _reverse1(dy_plain))
617- end
580+ dy_plain = unthunk(dy)
581+ rev_list = zip(Iterators. reverse(hobbits), Iterators. reverse(dy_plain))
582+ # Here we rely on `zip` to stop early when init === nothing. Begin explicit with Iterators.reverse(Iterators.drop(..., 1))
583+ # gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{"
618584 trio = accumulate(rev_list; init= (0 , ZeroTangent(), 0 )) do (_, dc, _), ((_, back), dz)
619585 ds, da, db = back(dc + dz)
620586 # Don't need to store every 'da', but need for next iteration, and the last one.
621587 end
622588 dop = sum(first, trio)
623- dx = map(last, _reverse1 (trio))
589+ dx = map(last, Iterators . reverse (trio))
624590 if init == nothing
625591 # `hobbits` is one short, and the first one is weird
626592 dx = _vcat1(trio[end ][2 ] + dy_plain[1 ], dx)
627593 end
628594 dy = @not_implemented " no gradient for `B` in `accumulate!(f, B, A)`, the rule intends to support `accumulate` only"
629595 d_init_not = @not_implemented " gradient for accumulate does not at present include init, sorry"
630596 d_init = init === nothing ? NoTangent() : Tangent{typeof(init)}(; value = d_init_not)
631- return (NoTangent(), dop, dy, project(_reshape1 (dx, axe)), NoTangent(), d_init)
597+ return (NoTangent(), dop, dy, project(reshape (dx, axe)), NoTangent(), d_init)
632598 end
633- return _reshape1 (y, axe), decumulate
599+ return reshape (y, axe), decumulate
634600end
0 commit comments