| 
 | 1 | + | 
 | 2 | +import Flux: ChainRulesCore  | 
 | 3 | +# Some experiments with chain to start removing the need for recur to be mutable.  | 
 | 4 | +# As per the conversation in the recurrent network rework issue.  | 
 | 5 | + | 
 | 6 | +# Main difference between this and the _applychain function is we return a new chain  | 
 | 7 | +# with the internal state modified as well as the output of applying x to the chain.  | 
 | 8 | +function apply(chain::Flux.Chain, x)  | 
 | 9 | +  layers, out = _apply(chain.layers, x)  | 
 | 10 | +  Flux.Chain(layers), out  | 
 | 11 | +end  | 
 | 12 | + | 
 | 13 | +function _apply(layers::NamedTuple{NMS, TPS}, x) where {NMS, TPS}  | 
 | 14 | +  layers, out = _apply(Tuple(layers), x)  | 
 | 15 | +  NamedTuple{NMS}(layers), out  | 
 | 16 | +end  | 
 | 17 | + | 
 | 18 | +function _scan(layers::AbstractVector, x)  | 
 | 19 | +  new_layers = typeof(layers)(undef, length(layers))  | 
 | 20 | +  for (idx, f) in enumerate(layers)  | 
 | 21 | +    new_layers[idx], x = _apply(f, x)  | 
 | 22 | +  end  | 
 | 23 | +  new_layers, x  | 
 | 24 | +end  | 
 | 25 | + | 
 | 26 | +# Reverse rule for _scan  | 
 | 27 | +# example pulled from https://github.com/mcabbott/Flux.jl/blob/chain_rrule/src/cuda/cuda.jl  | 
# Custom reverse rule for `_scan`: builds per-layer pullbacks on the forward
# pass via `rrule_via_ad`, then runs them in reverse order in the pullback.
function ChainRulesCore.rrule(cfg::ChainRulesCore.RuleConfig, ::typeof(_scan), layers, x)
  # Forward sweep. Each accumulate step destructures the previous element as
  # ((prev_layer, prev_output), _) and feeds prev_output into the next layer.
  # The block's value is the (out, back) pair from `rrule_via_ad`, where
  # `out` is the (new_layer, layer_output) tuple returned by `_apply`.
  duo = accumulate(layers; init=((nothing, x), nothing)) do ((pl,  input), _), cur_layer
    out, back = ChainRulesCore.rrule_via_ad(cfg, _apply, cur_layer, input)
  end
  outs = map(first, duo)   # (new_layer, layer_output) per step
  backs = map(last, duo)   # pullback per step
  
  function _scan_pullback(dy)
    # Reverse sweep: each pullback returns cotangents for (_apply, layer, input);
    # the input cotangent becomes the output cotangent of the preceding layer.
    # NOTE(review): on the first iteration `back` receives `dy` (the cotangent of
    # the overall (layers, final_out) result) rather than a per-step
    # (dlayer, dout) pair — confirm this matches what `rrule_via_ad`'s pullback
    # expects for the last layer.
    multi = accumulate(reverse(backs); init=(nothing, dy)) do (_, delta), back
      dapply, dlayer, din = back(delta)
      return dapply, (dlayer, din)
    end
    # NOTE(review): `first` of each element is `dapply` (the cotangent of the
    # `_apply` function itself, typically NoTangent), not the layer tangent
    # `dlayer` stored in the second slot — verify `layergrads` should not be
    # `reverse(map(m -> first(last(m)), multi))` instead.
    layergrads = reverse(map(first, multi))
    # NOTE(review): `last(multi[end])` is the (dlayer, din) pair for the first
    # layer, not `din` alone — confirm the intended input gradient extraction.
    xgrad = last(multi[end])
    return (ChainRulesCore.NoTangent(), layergrads, xgrad)
  end
  # Primal result mirrors `_scan`: (tuple-of-new-layers, final output).
  return (map(first, outs), last(outs[end])), _scan_pullback
end
 | 46 | + | 
 | 47 | +function _apply(layers::AbstractVector, x)  # type-unstable path, helps compile times  | 
 | 48 | +  _scan(layers, x)  | 
 | 49 | +end  | 
 | 50 | + | 
 | 51 | +# Generated function returns a tuple of args and the last output of the network.  | 
 | 52 | +@generated function _apply(layers::Tuple{Vararg{<:Any,N}}, x) where {N}  | 
 | 53 | +  x_symbols = vcat(:x, [gensym() for _ in 1:N])  | 
 | 54 | +  l_symbols = [gensym() for _ in 1:N]  | 
 | 55 | +  calls = [:(($(l_symbols[i]), $(x_symbols[i+1])) = _apply(layers[$i], $(x_symbols[i]))) for i in 1:N]  | 
 | 56 | +  push!(calls, :(return tuple($(l_symbols...)), $(x_symbols[end])))  | 
 | 57 | +  Expr(:block, calls...)  | 
 | 58 | +end  | 
 | 59 | + | 
 | 60 | +_apply(layer, x) = layer, layer(x)  | 
 | 61 | + | 
 | 62 | + | 
0 commit comments