@@ -417,17 +417,73 @@ end
417417end
418418
419419# ####
420- # #### `foldl`
420+ # ####
421+ # #### `foldl(f, ::Tuple)`
421422# ####
422423
423424# `foldl` guarantees to execute `f` in order, left to right. So it makes sense even when
424- # this `f` is stateful, in which case the gradient must be calculated in the reverse order.
425+ # this `f` is stateful, in which case the gradient must be calculated in the reverse order.
426+
427+ # The rule is attached to `Base.mapfoldl_impl` because this gets the `init` keyword as an argument,
428+ # which is handled below. For tuples, `reduce` also comes here.
429+
# Reverse-mode rule for `foldl(op, x::Tuple)` without `init`, attached to
# `Base.mapfoldl_impl(identity, op, Base._InitialValue(), x)`.
#
# Forward pass: `op` is applied left to right via `rrule_via_ad`, keeping every
# intermediate result together with its pullback. Reverse pass: the stored pullbacks
# are run right to left, so a stateful `op` sees the gradient calls in reverse order.
#
# Returns `(y, pullback)`, where `pullback(dy)` gives tangents for
# `(rrule's function, identity, op, init, x)`; the `init` slot is always `NoTangent()`.
# An empty `x` throws, as `first(x)` is undefined for it.
function rrule(
    config::RuleConfig{>:HasReverseMode},
    ::typeof(Base.mapfoldl_impl),
    ::typeof(identity),
    op::G,
    init::Base._InitialValue,
    x::Tuple;
) where {G}
    if length(x) == 1
        # `foldl(op, (a,))` returns `a` without ever calling `op`. The general path
        # below would call `last` on an empty tuple of intermediate results and throw,
        # so handle this case directly: the gradient passes straight through to `a`
        # and `op` receives no tangent.
        project_single = ProjectTo(x)
        function foldl_pullback_tuple_single(dy)
            dx = project_single((dy,))
            return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), dx)
        end
        return first(x), foldl_pullback_tuple_single
    end

    hobbits = accumulate(Base.tail(x); init=(first(x), nothing)) do (a, _), b
        # Here `a` is what we would normally carry forward, and `_` ignores
        # the previous iteration's pullback function (needed later),
        # while `b` is the fresh input from `x` as usual.
        # Each stored entry is `(c, back)`. We don't really need to store every `c`,
        # only the last one is the `foldl` output.
        # (The name, BTW, is because "there and back again" is the subtitle of
        # Tolkien's book.)
        return rrule_via_ad(config, op, a, b)
    end
    y = first(last(hobbits))
    project = ProjectTo(x)
    function foldl_pullback_tuple(dy)
        trio = accumulate(_reverse1(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
            # Each stored entry is `(ds, da, db)`: tangents for `op` itself, for the
            # carried value `a`, and for the fresh input `b`. We don't need to store
            # every `da`, only one for the next iteration plus the last.
            return back(dc)
        end
        dop = sum(first, trio)
        # The final `da` belongs to `first(x)`; the `db`s were collected right to left.
        dx = (trio[end][2], reverse(map(last, trio))...)
        return (NoTangent(), NoTangent(), ProjectTo(op)(dop), NoTangent(), project(dx))
    end
    return y, foldl_pullback_tuple
end
459+
# Reverse-mode rule for `foldl(op, x::Tuple; init)` with an explicit `init`.
#
# `init` is handled by prepending it to `x` and delegating to the no-`init` rule;
# the resulting tangent is then split back into the `init` part (first element)
# and the `x` part (the rest).
function rrule(
    config::RuleConfig{>:HasReverseMode},
    ::typeof(Base.mapfoldl_impl),
    ::typeof(identity),
    op::G,
    init,
    x::Tuple;
) where {G}
    extended = (init, x...)
    y, inner_pullback = rrule(
        config, Base.mapfoldl_impl, identity, op, Base._InitialValue(), extended
    )
    project_x = ProjectTo(x)
    project_init = ProjectTo(init)
    function foldl_pullback_tuple_init(dy)
        _, _, dop, _, dextended = inner_pullback(dy)
        dinit = project_init(first(dextended))
        dx = project_x(Base.tail(dextended))
        return (NoTangent(), NoTangent(), dop, dinit, dx)
    end
    return y, foldl_pullback_tuple_init
end
425478
426- # The implementation aims to be efficient for both tuples and arrays, although using accumulate
427- # to carry intermediate results along creates arrays of tuples which could be avoided; using a
428- # loop can be a few times faster. Note also that it does not return a gradient for `init`.
479+ # ####
480+ # #### `foldl(f, ::Array)`
481+ # ####
429482
430- # Maybe that's a problem. Let's move the rule to `mapfoldr_impl(f, op, init, itr)`, where it's easier?
483+ # The implementation was originally for both tuples and arrays, although using accumulate
484+ # to carry intermediate results along creates arrays of tuples which could be avoided.
485+ # Using a loop can be a few times faster, so this implementation should be replaced.
486+ # Note also that it does not return a gradient for `init`.
431487
432488function rrule(
433489 config:: RuleConfig{>:HasReverseMode} , :: typeof (Base. mapfoldl_impl), :: typeof (identity), op:: G , init, x:: Union{AbstractArray, Tuple} ;
@@ -486,8 +542,7 @@ _reverse1(x::Tuple) = reverse(x)
# Return every element of the tuple `x` except the first.
function _drop1(x::Tuple)
    return Base.tail(x)
end
# Pair up two tuples of equal length `N`, returning a tuple of 2-tuples `(x[i], y[i])`.
function _zip2(x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,N}}) where {N}
    return ntuple(N) do i
        (x[i], y[i])
    end
end
488544
489- # struct _InitialValue end # Old versions don't have `Base._InitialValue`
490- const _INIT = VERSION >= v" 1.5" ? Base. _InitialValue() : NamedTuple()
# An instance of `Base._InitialValue`, kept as a module-level constant.
const _INIT = Base._InitialValue()
491546
# Prepend `x` to the vector `ys`.
function _vcat1(x, ys::AbstractVector)
    return vcat(x, ys)
end

# When `x` is itself an array it is wrapped in a one-element vector first, so that it
# is prepended as a single element rather than concatenated element by element.
function _vcat1(x::AbstractArray, ys::AbstractVector)
    return vcat([x], ys)
end
0 commit comments