From c007097db97a09d51fa28f1cba56c66c5435d5fe Mon Sep 17 00:00:00 2001 From: Matthew Schlegel Date: Sun, 25 Jun 2023 17:58:22 -0600 Subject: [PATCH 01/12] First pass on NewRecur. --- src/Fluxperimental.jl | 2 + src/new_recur.jl | 106 ++++++++++++++++++++++++++++++++++++++++ test/new_recur.jl | 111 ++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 + 4 files changed, 221 insertions(+) create mode 100644 src/new_recur.jl create mode 100644 test/new_recur.jl diff --git a/src/Fluxperimental.jl b/src/Fluxperimental.jl index 91438b0..893dc46 100644 --- a/src/Fluxperimental.jl +++ b/src/Fluxperimental.jl @@ -13,4 +13,6 @@ include("chain.jl") include("compact.jl") +include("new_recur.jl") + end # module Fluxperimental diff --git a/src/new_recur.jl b/src/new_recur.jl new file mode 100644 index 0000000..1c315e5 --- /dev/null +++ b/src/new_recur.jl @@ -0,0 +1,106 @@ + + +""" + NewRecur +New Recur. An experimental recur interface for removing statefullness in recurrent architectures for flux. +""" +struct NewRecur{RET_SEQUENCE, T} + cell::T + # state::S + function NewRecur(cell; return_sequence::Bool=false) + new{return_sequence, typeof(cell)}(cell) + end + function NewRecur{true}(cell) + new{true, typeof(cell)}(cell) + end + function NewRecur{false}(cell) + new{false, typeof(cell)}(cell) + end +end + +# This is the same way we do 3-tensers from Flux.Recur +function (m::NewRecur{false})(x::AbstractArray{T, N}, carry) where {T, N} + @assert N >= 3 + # h = [m(x_t) for x_t in eachlastdim(x)] + + cell = l.cell + x_init, x_rest = Iterators.peel(xs) + (carry, y) = cell(carry, x_init) + for x in x_rest + (carry, y) = cell(carry, x) + end + # carry, y + y + +end + +function (l::NewRecur{false})(x::AbstractArray{T, 3}, carry=l.cell.state0) where T + m(Flux.eachlastdim(x), carry) +end + +function (l::NewRecur{false})(xs::Union{AbstractVector{<:AbstractArray}, Base.Generator}, + carry=l.cell.state0) + rnn = l.cell + # carry = layer.stamte + x_init, x_rest = Iterators.peel(xs) + (carry, y) = rnn(carry, x_init) + for x in x_rest + (carry, y) = rnn(carry, x) + end + y +end + +# From Lux.jl: https://github.com/LuxDL/Lux.jl/pull/287/ +function (l::NewRecur{true})(xs::Union{AbstractVector{<:AbstractArray}, Base.Generator}, + carry=l.cell.state0) + rnn = l.cell + _xs = if xs isa Base.Generator + collect(xs) # TODO: Fix. I can't figure out how to get around this for generators. + else + xs + end + x_init, _ = Iterators.peel(_xs) + + (carry, out_) = rnn(carry, x_init) + + init = (typeof(out_)[out_], carry) + + function recurrence_op(input, (outputs, carry)) + carry, out = rnn(carry, input) + return vcat(outputs, typeof(out)[out]), carry + end + results = foldr(recurrence_op, _xs[(begin+1):end]; init) + # return NewRecur{true}(rnn, results[1][end]), first(results) + first(results) +end + +Flux.@functor NewRecur +Flux.trainable(a::NewRecur) = (; cell = a.cell) + +Base.show(io::IO, m::NewRecur) = print(io, "Recur(", m.cell, ")") + +NewRNN(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.RNNCell(a...; ka...); return_sequence=return_sequence) +# NewRecur(cell::Flux.RNNCell; return_sequence::Bool=false) = NewRecur(cell; return_sequence=return_sequence) + +# Quick Reset functionality + +# struct RecurWalk <: Flux.Functors.AbstractWalk end +# (::RecurWalk)(recurse, x) = x isa Fluxperimental.NewRecur ? 
reset(x) : Flux.Functors.DefaultWalk()(recurse, x) + +# function reset(m::NewRecur{SEQ}) where SEQ +# NewRecur{SEQ}(m.cell, m.cell.state0) +# end +# reset(m) = m +# function reset(m::Flux.Chain) +# ret = Flux.Functors.fmap((l)->l, m; walk=RecurWalk()) +# end + + +## +# Fallback apply timeseries data to other layers. Likely needs to be thoought through a bit more. +## + +# function apply(l, xs::Union{AbstractVector{<:AbstractArray}, Base.Generator}) +# l, [l(x) for x in xs] +# end + diff --git a/test/new_recur.jl b/test/new_recur.jl new file mode 100644 index 0000000..fc978ab --- /dev/null +++ b/test/new_recur.jl @@ -0,0 +1,111 @@ + + +@testset "RNN gradients-implicit" begin + cell = Flux.RNNCell(1, 1, identity) + layer = Flux.Recur(cell) + layer.cell.Wi .= 5.0 + layer.cell.Wh .= 4.0 + layer.cell.b .= 0.0f0 + layer.cell.state0 .= 7.0 + x = [[2.0f0], [3.0f0]] + + # theoretical primal gradients + primal = + layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+ + x[2] .* layer.cell.Wi + ∇Wi = x[1] .* layer.cell.Wh .+ x[2] + ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi + ∇b = layer.cell.Wh .+ 1 + ∇state0 = layer.cell.Wh .^ 2 + + nm_layer = Fluxperimental.NewRecur(cell; return_sequence = true) + ps = Flux.params(nm_layer) + e, g = Flux.withgradient(ps) do + out = nm_layer(x) + sum(out[2]) + end + + @test primal[1] ≈ e + @test ∇Wi ≈ g[ps[1]] + @test ∇Wh ≈ g[ps[2]] + @test ∇b ≈ g[ps[3]] + @test ∇state0 ≈ g[ps[4]] +end + +@testset "RNN gradients-implicit-partial sequence" begin + cell = Flux.RNNCell(1, 1, identity) + layer = Flux.Recur(cell) + layer.cell.Wi .= 5.0 + layer.cell.Wh .= 4.0 + layer.cell.b .= 0.0f0 + layer.cell.state0 .= 7.0 + x = [[2.0f0], [3.0f0]] + + # theoretical primal gradients + primal = + layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+ + x[2] .* layer.cell.Wi + ∇Wi = x[1] .* layer.cell.Wh .+ x[2] + ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi + ∇b = layer.cell.Wh .+ 1 + ∇state0 = layer.cell.Wh .^ 2 + + nm_layer = Fluxperimental.NewRecur(cell; return_sequence = false) + ps = Flux.params(nm_layer) + e, g = Flux.withgradient(ps) do + out = (nm_layer)(x) + sum(out) + end + + @test primal[1] ≈ e + @test ∇Wi ≈ g[ps[1]] + @test ∇Wh ≈ g[ps[2]] + @test ∇b ≈ g[ps[3]] + @test ∇state0 ≈ g[ps[4]] +end + +@testset "RNN gradients-explicit partial sequence" begin + + + cell = Flux.RNNCell(1, 1, identity) + layer = Flux.Recur(cell) + layer.cell.Wi .= 5.0 + layer.cell.Wh .= 4.0 + layer.cell.b .= 0.0f0 + layer.cell.state0 .= 7.0 + x = [[2.0f0], [3.0f0]] + + # theoretical primal gradients + primal = + layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+ + x[2] .* layer.cell.Wi + ∇Wi = x[1] .* layer.cell.Wh .+ x[2] + ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi + ∇b = layer.cell.Wh .+ 1 + ∇state0 = layer.cell.Wh .^ 2 + + + + nm_layer = Fluxperimental.NewRecur(cell; return_sequence = false) + e, g = Flux.withgradient(nm_layer) do layer + # r_l = Fluxperimental.reset(layer) + out = layer(x) + sum(out) + end + grads = g[1][:cell] + + @test primal[1] ≈ e + + if VERSION < v"1.7" + @test ∇Wi ≈ grads[:Wi] + @test ∇Wh ≈ grads[:Wh] + @test ∇b ≈ grads[:b] + @test ∇state0 ≈ grads[:state0] + else + @test ∇Wi ≈ grads[:Wi] + @test ∇Wh ≈ grads[:Wh] + @test ∇b ≈ grads[:b] + @test ∇state0 ≈ grads[:state0] + end +end + diff --git a/test/runtests.jl b/test/runtests.jl index 55315cc..5291a8a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,4 +8,6 
@@ using Flux, Fluxperimental include("compact.jl") + include("new_recur.jl") + end From c2a0ec9138226508c65b0958616f057d61b22815 Mon Sep 17 00:00:00 2001 From: Matthew Schlegel Date: Sun, 25 Jun 2023 18:05:37 -0600 Subject: [PATCH 02/12] Remove comments. --- src/new_recur.jl | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/src/new_recur.jl b/src/new_recur.jl index 1c315e5..6469391 100644 --- a/src/new_recur.jl +++ b/src/new_recur.jl @@ -80,27 +80,4 @@ Flux.trainable(a::NewRecur) = (; cell = a.cell) Base.show(io::IO, m::NewRecur) = print(io, "Recur(", m.cell, ")") NewRNN(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.RNNCell(a...; ka...); return_sequence=return_sequence) -# NewRecur(cell::Flux.RNNCell; return_sequence::Bool=false) = NewRecur(cell; return_sequence=return_sequence) - -# Quick Reset functionality - -# struct RecurWalk <: Flux.Functors.AbstractWalk end -# (::RecurWalk)(recurse, x) = x isa Fluxperimental.NewRecur ? reset(x) : Flux.Functors.DefaultWalk()(recurse, x) - -# function reset(m::NewRecur{SEQ}) where SEQ -# NewRecur{SEQ}(m.cell, m.cell.state0) -# end -# reset(m) = m -# function reset(m::Flux.Chain) -# ret = Flux.Functors.fmap((l)->l, m; walk=RecurWalk()) -# end - - -## -# Fallback apply timeseries data to other layers. Likely needs to be thoought through a bit more. -## - -# function apply(l, xs::Union{AbstractVector{<:AbstractArray}, Base.Generator}) -# l, [l(x) for x in xs] -# end From 49c601c038762db60549595075762b4572337afa Mon Sep 17 00:00:00 2001 From: Matthew Schlegel Date: Mon, 26 Jun 2023 10:37:44 -0600 Subject: [PATCH 03/12] Updating interface, adding more tests and re-orging tests. --- src/new_recur.jl | 63 ++++++------ test/new_recur.jl | 246 +++++++++++++++++++++++++++++----------------- 2 files changed, 189 insertions(+), 120 deletions(-) diff --git a/src/new_recur.jl b/src/new_recur.jl index 6469391..35fb3ed 100644 --- a/src/new_recur.jl +++ b/src/new_recur.jl @@ -18,60 +18,59 @@ struct NewRecur{RET_SEQUENCE, T} end end -# This is the same way we do 3-tensers from Flux.Recur -function (m::NewRecur{false})(x::AbstractArray{T, N}, carry) where {T, N} - @assert N >= 3 - # h = [m(x_t) for x_t in eachlastdim(x)] +# assumes single timestep with batch=1. +function (l::NewRecur)(x_vec::AbstractVector{T}, + init_carry=l.cell.state0) where T<:Number + x_block = reshape(x_vec, :, 1, 1) + l(x_block, init_carry)[:, 1, 1] +end + +(l::NewRecur)(x_mat::AbstractMatrix, args...) = error("Matrix is ambiguous with NewRecur") +function (l::NewRecur{false})(x_block::AbstractArray{T, 3}, + init_carry=l.cell.state0) where {T} + xs = Flux.eachlastdim(x_block) cell = l.cell x_init, x_rest = Iterators.peel(xs) - (carry, y) = cell(carry, x_init) + (carry, y) = cell(init_carry, x_init) for x in x_rest (carry, y) = cell(carry, x) end # carry, y y - -end - -function (l::NewRecur{false})(x::AbstractArray{T, 3}, carry=l.cell.state0) where T - m(Flux.eachlastdim(x), carry) -end - -function (l::NewRecur{false})(xs::Union{AbstractVector{<:AbstractArray}, Base.Generator}, - carry=l.cell.state0) - rnn = l.cell - # carry = layer.stamte - x_init, x_rest = Iterators.peel(xs) - (carry, y) = rnn(carry, x_init) - for x in x_rest - (carry, y) = rnn(carry, x) - end - y end # From Lux.jl: https://github.com/LuxDL/Lux.jl/pull/287/ -function (l::NewRecur{true})(xs::Union{AbstractVector{<:AbstractArray}, Base.Generator}, - carry=l.cell.state0) - rnn = l.cell - _xs = if xs isa Base.Generator - collect(xs) # TODO: Fix. 
I can't figure out how to get around this for generators. +function (l::NewRecur{true})(x_block::AbstractArray{T, 3}, + init_carry=l.cell.state0) where {T} + + # Time index is always the last index. + xs = Flux.eachlastdim(x_block) + xs_ = if xs isa Base.Generator + # This is because eachlastdim has different behavior in + # a gradient environment vs outside a gradient environment. + # Needs to be fixed.... + collect(xs) else xs end - x_init, _ = Iterators.peel(_xs) - (carry, out_) = rnn(carry, x_init) + cell = l.cell + x_init, x_rest = Iterators.peel(xs_) + + (carry, out_) = cell(init_carry, x_init) init = (typeof(out_)[out_], carry) function recurrence_op(input, (outputs, carry)) - carry, out = rnn(carry, input) + carry, out = cell(carry, input) return vcat(outputs, typeof(out)[out]), carry end - results = foldr(recurrence_op, _xs[(begin+1):end]; init) + results = foldr(recurrence_op, xs_[(begin+1):end]; init) # return NewRecur{true}(rnn, results[1][end]), first(results) - first(results) + h = first(results) + sze = size(h[1]) + reshape(reduce(hcat, h), sze[1], sze[2], length(h)) end Flux.@functor NewRecur diff --git a/test/new_recur.jl b/test/new_recur.jl index fc978ab..5e43c5b 100644 --- a/test/new_recur.jl +++ b/test/new_recur.jl @@ -1,111 +1,181 @@ -@testset "RNN gradients-implicit" begin - cell = Flux.RNNCell(1, 1, identity) - layer = Flux.Recur(cell) - layer.cell.Wi .= 5.0 - layer.cell.Wh .= 4.0 - layer.cell.b .= 0.0f0 - layer.cell.state0 .= 7.0 - x = [[2.0f0], [3.0f0]] - - # theoretical primal gradients - primal = - layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+ - x[2] .* layer.cell.Wi - ∇Wi = x[1] .* layer.cell.Wh .+ x[2] - ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi - ∇b = layer.cell.Wh .+ 1 - ∇state0 = layer.cell.Wh .^ 2 - - nm_layer = Fluxperimental.NewRecur(cell; return_sequence = true) - ps = Flux.params(nm_layer) - e, g = Flux.withgradient(ps) do - out = nm_layer(x) - sum(out[2]) +@testset "NewRecur RNN" begin + @testset "Forward Pass" begin + cell = Flux.RNNCell(1, 1, identity) + layer = Fluxperimental.NewRecur(cell; return_sequence=true) + layer.cell.Wi .= 5.0 + layer.cell.Wh .= 4.0 + layer.cell.b .= 0.0f0 + layer.cell.state0 .= 7.0 + x = reshape([2.0f0, 3.0f0], 1, 1, 2) + + # @show layer(x) + @test eltype(layer(x)) <: Float32 + @test size(layer(x)) == (1, 1, 2) + @test size(layer([2.0f0])) == (1, ) + + @test_throws ErrorException layer([2.0f0;; 3.0f0]) end - - @test primal[1] ≈ e - @test ∇Wi ≈ g[ps[1]] - @test ∇Wh ≈ g[ps[2]] - @test ∇b ≈ g[ps[3]] - @test ∇state0 ≈ g[ps[4]] -end - -@testset "RNN gradients-implicit-partial sequence" begin - cell = Flux.RNNCell(1, 1, identity) - layer = Flux.Recur(cell) - layer.cell.Wi .= 5.0 - layer.cell.Wh .= 4.0 - layer.cell.b .= 0.0f0 - layer.cell.state0 .= 7.0 - x = [[2.0f0], [3.0f0]] - - # theoretical primal gradients - primal = - layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+ - x[2] .* layer.cell.Wi - ∇Wi = x[1] .* layer.cell.Wh .+ x[2] - ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi - ∇b = layer.cell.Wh .+ 1 - ∇state0 = layer.cell.Wh .^ 2 - - nm_layer = Fluxperimental.NewRecur(cell; return_sequence = false) - ps = Flux.params(nm_layer) - e, g = Flux.withgradient(ps) do - out = (nm_layer)(x) - sum(out) - end - - @test primal[1] ≈ e - @test ∇Wi ≈ g[ps[1]] - @test ∇Wh ≈ g[ps[2]] - @test ∇b ≈ g[ps[3]] - @test ∇state0 ≈ g[ps[4]] -end -@testset "RNN gradients-explicit partial sequence" begin + @testset 
"gradients-implicit" begin + cell = Flux.RNNCell(1, 1, identity) + layer = Flux.Recur(cell) + layer.cell.Wi .= 5.0 + layer.cell.Wh .= 4.0 + layer.cell.b .= 0.0f0 + layer.cell.state0 .= 7.0 + x = [[2.0f0], [3.0f0]] + + # theoretical primal gradients + primal = + layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+ + x[2] .* layer.cell.Wi + ∇Wi = x[1] .* layer.cell.Wh .+ x[2] + ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi + ∇b = layer.cell.Wh .+ 1 + ∇state0 = layer.cell.Wh .^ 2 + + nm_layer = Fluxperimental.NewRecur(cell; return_sequence = true) + ps = Flux.params(nm_layer) + x_block = reshape(vcat(x...), 1, 1, length(x)) + e, g = Flux.withgradient(ps) do + out = nm_layer(x_block) + sum(out[1, 1, 2]) + end + + @test primal[1] ≈ e + @test ∇Wi ≈ g[ps[1]] + @test ∇Wh ≈ g[ps[2]] + @test ∇b ≈ g[ps[3]] + @test ∇state0 ≈ g[ps[4]] + end - cell = Flux.RNNCell(1, 1, identity) - layer = Flux.Recur(cell) - layer.cell.Wi .= 5.0 - layer.cell.Wh .= 4.0 - layer.cell.b .= 0.0f0 - layer.cell.state0 .= 7.0 - x = [[2.0f0], [3.0f0]] - # theoretical primal gradients - primal = - layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+ - x[2] .* layer.cell.Wi - ∇Wi = x[1] .* layer.cell.Wh .+ x[2] - ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi - ∇b = layer.cell.Wh .+ 1 - ∇state0 = layer.cell.Wh .^ 2 + @testset "gradients-explicit" begin + cell = Flux.RNNCell(1, 1, identity) + layer = Flux.Recur(cell) + layer.cell.Wi .= 5.0 + layer.cell.Wh .= 4.0 + layer.cell.b .= 0.0f0 + layer.cell.state0 .= 7.0 + x = [[2.0f0], [3.0f0]] + # theoretical primal gradients + primal = + layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+ + x[2] .* layer.cell.Wi + ∇Wi = x[1] .* layer.cell.Wh .+ x[2] + ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi + ∇b = layer.cell.Wh .+ 1 + ∇state0 = layer.cell.Wh .^ 2 - nm_layer = Fluxperimental.NewRecur(cell; return_sequence = false) - e, g = Flux.withgradient(nm_layer) do layer - # r_l = Fluxperimental.reset(layer) - out = layer(x) - sum(out) - end - grads = g[1][:cell] - @test primal[1] ≈ e + x_block = reshape(vcat(x...), 1, 1, length(x)) + nm_layer = Fluxperimental.NewRecur(cell; return_sequence = true) + e, g = Flux.withgradient(nm_layer) do layer + out = layer(x_block) + sum(out[1, 1, 2]) + end + grads = g[1][:cell] - if VERSION < v"1.7" + @test primal[1] ≈ e @test ∇Wi ≈ grads[:Wi] @test ∇Wh ≈ grads[:Wh] @test ∇b ≈ grads[:b] @test ∇state0 ≈ grads[:state0] - else + + end +end + +@testset "New Recur RNN Partial Sequence" begin + + @testset "Forward Pass" begin + cell = Flux.RNNCell(1, 1, identity) + layer = Fluxperimental.NewRecur(cell) + layer.cell.Wi .= 5.0 + layer.cell.Wh .= 4.0 + layer.cell.b .= 0.0f0 + layer.cell.state0 .= 7.0 + x = reshape([2.0f0, 3.0f0], 1, 1, 2) + + @test eltype(layer(x)) <: Float32 + @test size(layer(x)) == (1, 1) + @test size(layer([2.0f0])) == (1, ) + + @test_throws ErrorException layer([2.0f0;; 3.0f0]) + end + + @testset "gradients-implicit" begin + cell = Flux.RNNCell(1, 1, identity) + layer = Flux.Recur(cell) + layer.cell.Wi .= 5.0 + layer.cell.Wh .= 4.0 + layer.cell.b .= 0.0f0 + layer.cell.state0 .= 7.0 + x = [[2.0f0], [3.0f0]] + + # theoretical primal gradients + primal = + layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+ + x[2] .* layer.cell.Wi + ∇Wi = x[1] .* layer.cell.Wh .+ x[2] + ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi + ∇b = layer.cell.Wh .+ 1 
+ ∇state0 = layer.cell.Wh .^ 2 + + nm_layer = Fluxperimental.NewRecur(cell; return_sequence = false) + ps = Flux.params(nm_layer) + x_block = reshape(vcat(x...), 1, 1, length(x)) + e, g = Flux.withgradient(ps) do + out = (nm_layer)(x_block) + sum(out) + end + + @test primal[1] ≈ e + @test ∇Wi ≈ g[ps[1]] + @test ∇Wh ≈ g[ps[2]] + @test ∇b ≈ g[ps[3]] + @test ∇state0 ≈ g[ps[4]] + end + + @testset "gradients-explicit" begin + + + cell = Flux.RNNCell(1, 1, identity) + layer = Flux.Recur(cell) + layer.cell.Wi .= 5.0 + layer.cell.Wh .= 4.0 + layer.cell.b .= 0.0f0 + layer.cell.state0 .= 7.0 + x = [[2.0f0], [3.0f0]] + + # theoretical primal gradients + primal = + layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+ + x[2] .* layer.cell.Wi + ∇Wi = x[1] .* layer.cell.Wh .+ x[2] + ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi + ∇b = layer.cell.Wh .+ 1 + ∇state0 = layer.cell.Wh .^ 2 + + + x_block = reshape(vcat(x...), 1, 1, length(x)) + nm_layer = Fluxperimental.NewRecur(cell; return_sequence = false) + e, g = Flux.withgradient(nm_layer) do layer + out = layer(x_block) + sum(out) + end + grads = g[1][:cell] + + @test primal[1] ≈ e @test ∇Wi ≈ grads[:Wi] @test ∇Wh ≈ grads[:Wh] @test ∇b ≈ grads[:b] @test ∇state0 ≈ grads[:state0] + end end From 79e72618fcbc236afd9bc0d131e6e14ae441c20f Mon Sep 17 00:00:00 2001 From: Matthew Schlegel Date: Tue, 27 Jun 2023 13:09:58 -0600 Subject: [PATCH 04/12] Moving functionality to separate helper functions. --- src/new_recur.jl | 121 +++++++++++++++++++++++++++------------------- test/new_recur.jl | 11 +++-- 2 files changed, 78 insertions(+), 54 deletions(-) diff --git a/src/new_recur.jl b/src/new_recur.jl index 35fb3ed..389a949 100644 --- a/src/new_recur.jl +++ b/src/new_recur.jl @@ -1,5 +1,59 @@ + +##### Helper scan funtion which can likely be put into NNLib. ##### +""" + scan + +Recreating jax.lax.scan functionality in julia. +""" +function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray}) + # xs = Flux.eachlastdim(x_block) + x_init, x_rest = Iterators.peel(xs) + + (carry, out_) = func(init_carry, x_init) + + init = (typeof(out_)[out_], carry) + + function recurrence_op(input, (outputs, carry)) + carry, out = func(carry, input) + return vcat(outputs, typeof(out)[out]), carry + end + results = foldr(recurrence_op, xs[(begin+1):end]; init) + results[2], results[1] +end + +function scan_full(func, init_carry, x_block) + xs_ = Flux.eachlastdim(x_block) + xs = if xs_ isa Base.Generator + collect(xs_) # eachlastdim produces a generator in non-gradient environment + else + xs_ + end + scan_full(func, init_carry, xs) +end + +function scan_partial(func, init_carry, xs::AbstractVector{<:AbstractArray}) + x_init, x_rest = Iterators.peel(xs) + (carry, y) = func(init_carry, x_init) + for x in x_rest + (carry, y) = func(carry, x) + end + # carry, y + carry, y +end + +function scan_partial(func, init_carry, x_block) + xs_ = Flux.eachlastdim(x_block) + xs = if xs_ isa Base.Generator + collect(xs_) # eachlastdim produces a generator in non-gradient environment + else + xs_ + end + scan_partial(func, init_carry, xs) +end + + """ NewRecur New Recur. An experimental recur interface for removing statefullness in recurrent architectures for flux. @@ -18,65 +72,34 @@ struct NewRecur{RET_SEQUENCE, T} end end -# assumes single timestep with batch=1. 
-function (l::NewRecur)(x_vec::AbstractVector{T}, - init_carry=l.cell.state0) where T<:Number - x_block = reshape(x_vec, :, 1, 1) - l(x_block, init_carry)[:, 1, 1] -end +Flux.@functor NewRecur +Flux.trainable(a::NewRecur) = (; cell = a.cell) +Base.show(io::IO, m::NewRecur) = print(io, "Recur(", m.cell, ")") +NewRNN(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.RNNCell(a...; ka...); return_sequence=return_sequence) -(l::NewRecur)(x_mat::AbstractMatrix, args...) = error("Matrix is ambiguous with NewRecur") -function (l::NewRecur{false})(x_block::AbstractArray{T, 3}, - init_carry=l.cell.state0) where {T} - xs = Flux.eachlastdim(x_block) - cell = l.cell - x_init, x_rest = Iterators.peel(xs) - (carry, y) = cell(init_carry, x_init) - for x in x_rest - (carry, y) = cell(carry, x) - end - # carry, y - y -end +(l::NewRecur)(init_carry, x_mat::AbstractMatrix) = MethodError("Matrix is ambiguous with NewRecur") +(l::NewRecur)(init_carry, x_mat::AbstractVector{T}) where {T<:Number} = MethodError("Vector is ambiguous with NewRecur") -# From Lux.jl: https://github.com/LuxDL/Lux.jl/pull/287/ -function (l::NewRecur{true})(x_block::AbstractArray{T, 3}, - init_carry=l.cell.state0) where {T} - - # Time index is always the last index. - xs = Flux.eachlastdim(x_block) - xs_ = if xs isa Base.Generator - # This is because eachlastdim has different behavior in - # a gradient environment vs outside a gradient environment. - # Needs to be fixed.... - collect(xs) - else - xs - end +(l::NewRecur)(xs) = l(l.cell.state0, xs) - cell = l.cell - x_init, x_rest = Iterators.peel(xs_) - (carry, out_) = cell(init_carry, x_init) +function (l::NewRecur{false})(init_carry, + xs) + results = scan_partial(l.cell, init_carry, xs) + results[2] +end - init = (typeof(out_)[out_], carry) +# From Lux.jl: https://github.com/LuxDL/Lux.jl/pull/287/ +function (l::NewRecur{true})(init_carry, + xs,) - function recurrence_op(input, (outputs, carry)) - carry, out = cell(carry, input) - return vcat(outputs, typeof(out)[out]), carry - end - results = foldr(recurrence_op, xs_[(begin+1):end]; init) - # return NewRecur{true}(rnn, results[1][end]), first(results) - h = first(results) + results = scan_full(l.cell, init_carry, xs) + + h = results[2] sze = size(h[1]) reshape(reduce(hcat, h), sze[1], sze[2], length(h)) end -Flux.@functor NewRecur -Flux.trainable(a::NewRecur) = (; cell = a.cell) - -Base.show(io::IO, m::NewRecur) = print(io, "Recur(", m.cell, ")") -NewRNN(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.RNNCell(a...; ka...); return_sequence=return_sequence) diff --git a/test/new_recur.jl b/test/new_recur.jl index 5e43c5b..16bcf00 100644 --- a/test/new_recur.jl +++ b/test/new_recur.jl @@ -13,9 +13,9 @@ # @show layer(x) @test eltype(layer(x)) <: Float32 @test size(layer(x)) == (1, 1, 2) - @test size(layer([2.0f0])) == (1, ) - @test_throws ErrorException layer([2.0f0;; 3.0f0]) + @test_throws MethodError layer([2.0f0]) + @test_throws MethodError layer([2.0f0;; 3.0f0]) end @@ -103,9 +103,10 @@ end @test eltype(layer(x)) <: Float32 @test size(layer(x)) == (1, 1) - @test size(layer([2.0f0])) == (1, ) - - @test_throws ErrorException layer([2.0f0;; 3.0f0]) + + @test_throws MethodError layer([2.0f0]) + @test_throws MethodError layer([2.0f0;; 3.0f0]) + end @testset "gradients-implicit" begin From b28bb57bdfef6c97d26873e96ebe2634bc213b9e Mon Sep 17 00:00:00 2001 From: Matthew Schlegel Date: Mon, 10 Jul 2023 16:39:08 -0600 Subject: [PATCH 05/12] Modified interface slightly according to comments. 
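This patch settles the calling convention: calling the layer with only the input block returns just the output, while passing an explicit initial carry returns a `(final_carry, output)` tuple. A minimal usage sketch of that convention follows (the toy dimensions and the qualified, unexported names are illustrative assumptions, not code from this diff):

```julia
using Flux, Fluxperimental

rnn = Fluxperimental.NewRNN(2, 3; return_sequence = true)  # wraps Flux.RNNCell(2, 3)
x = rand(Float32, 2, 4, 5)                # features x batch x time

y = rnn(x)                                # output only; initial carry defaults to cell.state0
carry, y2 = rnn(rnn.cell.state0, x)       # explicit carry in, (final_carry, output) out

last_only = Fluxperimental.NewRNN(2, 3)   # return_sequence = false
y_last = last_only(x)                     # only the final step's output
```

With `return_sequence = true` the output keeps the time dimension last (here 3 x 4 x 5); with `return_sequence = false` only the last step survives (a 3 x 4 matrix).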
--- src/new_recur.jl | 64 +++++++++++++++++++++++++++++------------------ test/new_recur.jl | 23 ++++++++++++++--- 2 files changed, 59 insertions(+), 28 deletions(-) diff --git a/src/new_recur.jl b/src/new_recur.jl index 389a949..c1b2e57 100644 --- a/src/new_recur.jl +++ b/src/new_recur.jl @@ -3,28 +3,39 @@ ##### Helper scan funtion which can likely be put into NNLib. ##### """ - scan + scan_full -Recreating jax.lax.scan functionality in julia. +Recreating jax.lax.scan functionality in julia. Takes a function, initial carry and a sequence, +then returns the output sequence and the final carry. """ function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray}) - # xs = Flux.eachlastdim(x_block) + # get the first input to setup the initial state, + # get the rest of the input to run the fold over. x_init, x_rest = Iterators.peel(xs) + # the following does the same as peel, but doesn't produce correct gradients? + ### x_init = first(xs) + ### x_rest = xs[begin+1:end] - (carry, out_) = func(init_carry, x_init) + # set up the initial state of the fold. + (carry_, out_) = func(init_carry, x_init) + init = (carry_, [out_]) - init = (typeof(out_)[out_], carry) - - function recurrence_op(input, (outputs, carry)) + # recurrence operation used in the fold. Takes the state of the + # folde and the next input, returns the new state. + function __recurrence_op((carry, outputs), input) carry, out = func(carry, input) - return vcat(outputs, typeof(out)[out]), carry + return carry, vcat(outputs, [out]) end - results = foldr(recurrence_op, xs[(begin+1):end]; init) - results[2], results[1] + # Fold left to right. + foldl(__recurrence_op, x_rest; init) end function scan_full(func, init_carry, x_block) + # x_block is an abstractarray and we want to scan over the last dimension. xs_ = Flux.eachlastdim(x_block) + + # this is needed due to a bug in eachlastdim which produces a vector in a + # gradient context, but a generator otherwise. xs = if xs_ isa Base.Generator collect(xs_) # eachlastdim produces a generator in non-gradient environment else @@ -33,18 +44,28 @@ function scan_full(func, init_carry, x_block) scan_full(func, init_carry, xs) end + +""" + scan_partial + +Recreating jax.lax.scan functionality in julia. Takes a function, initial carry and a sequence, +then returns the final output of the sequence and the final carry. +""" function scan_partial(func, init_carry, xs::AbstractVector{<:AbstractArray}) x_init, x_rest = Iterators.peel(xs) (carry, y) = func(init_carry, x_init) for x in x_rest (carry, y) = func(carry, x) end - # carry, y carry, y end function scan_partial(func, init_carry, x_block) + # x_block is an abstractarray and we want to scan over the last dimension. xs_ = Flux.eachlastdim(x_block) + + # this is needed due to a bug in eachlastdim which produces a vector in a + # gradient context, but a generator otherwise. xs = if xs_ isa Base.Generator collect(xs_) # eachlastdim produces a generator in non-gradient environment else @@ -77,28 +98,23 @@ Flux.trainable(a::NewRecur) = (; cell = a.cell) Base.show(io::IO, m::NewRecur) = print(io, "Recur(", m.cell, ")") NewRNN(a...; return_sequence::Bool=false, ka...) 
= NewRecur(Flux.RNNCell(a...; ka...); return_sequence=return_sequence) - (l::NewRecur)(init_carry, x_mat::AbstractMatrix) = MethodError("Matrix is ambiguous with NewRecur") (l::NewRecur)(init_carry, x_mat::AbstractVector{T}) where {T<:Number} = MethodError("Vector is ambiguous with NewRecur") -(l::NewRecur)(xs) = l(l.cell.state0, xs) - +function (l::NewRecur)(xs::AbstractArray) + results = l(l.cell.state0, xs) + results[2] # Only return the output here. +end -function (l::NewRecur{false})(init_carry, - xs) +function (l::NewRecur{false})(init_carry, xs) results = scan_partial(l.cell, init_carry, xs) - results[2] + results[1], results[2] end -# From Lux.jl: https://github.com/LuxDL/Lux.jl/pull/287/ -function (l::NewRecur{true})(init_carry, - xs,) +function (l::NewRecur{true})(init_carry, xs) results = scan_full(l.cell, init_carry, xs) - - h = results[2] - sze = size(h[1]) - reshape(reduce(hcat, h), sze[1], sze[2], length(h)) + results[1], stack(results[2], dims=3) end diff --git a/test/new_recur.jl b/test/new_recur.jl index 16bcf00..0acc915 100644 --- a/test/new_recur.jl +++ b/test/new_recur.jl @@ -2,7 +2,8 @@ @testset "NewRecur RNN" begin @testset "Forward Pass" begin - cell = Flux.RNNCell(1, 1, identity) + # tanh is needed for forward check to determine ordering of inputs. + cell = Flux.RNNCell(1, 1, tanh) layer = Fluxperimental.NewRecur(cell; return_sequence=true) layer.cell.Wi .= 5.0 layer.cell.Wh .= 4.0 @@ -10,15 +11,22 @@ layer.cell.state0 .= 7.0 x = reshape([2.0f0, 3.0f0], 1, 1, 2) - # @show layer(x) + # Lets make sure th output is correct + h = cell.state0 + h, out = cell(h, [2.0f0]) + h, out = cell(h, [3.0f0]) + @test eltype(layer(x)) <: Float32 @test size(layer(x)) == (1, 1, 2) + @test layer(x)[1, 1, 2] ≈ out[1,1] + + @test length(layer(cell.state0, x)) == 2 # should return a tuple. Maybe better test is needed. + @test layer(cell.state0, x)[2][1,1,2] ≈ out[1,1] @test_throws MethodError layer([2.0f0]) @test_throws MethodError layer([2.0f0;; 3.0f0]) end - @testset "gradients-implicit" begin cell = Flux.RNNCell(1, 1, identity) layer = Flux.Recur(cell) @@ -52,7 +60,6 @@ @test ∇state0 ≈ g[ps[4]] end - @testset "gradients-explicit" begin cell = Flux.RNNCell(1, 1, identity) @@ -101,8 +108,16 @@ end layer.cell.state0 .= 7.0 x = reshape([2.0f0, 3.0f0], 1, 1, 2) + h = cell.state0 + h, out = cell(h, [2.0f0]) + h, out = cell(h, [3.0f0]) + @test eltype(layer(x)) <: Float32 @test size(layer(x)) == (1, 1) + @test layer(x)[1, 1] ≈ out[1,1] + + @test length(layer(cell.state0, x)) == 2 + @test layer(cell.state0, x)[2][1,1] ≈ out[1,1] @test_throws MethodError layer([2.0f0]) @test_throws MethodError layer([2.0f0;; 3.0f0]) From 7b60350bc843ee0f4bd22582fb18e088d1c1a4f1 Mon Sep 17 00:00:00 2001 From: Matthew Schlegel Date: Thu, 3 Aug 2023 10:25:59 -0600 Subject: [PATCH 06/12] Fixed gradients using Lux's impl. --- src/new_recur.jl | 143 ++++++++++++++++++++++++++++++++++++++++------- test/runtests.jl | 6 +- 2 files changed, 126 insertions(+), 23 deletions(-) diff --git a/src/new_recur.jl b/src/new_recur.jl index c1b2e57..35ae021 100644 --- a/src/new_recur.jl +++ b/src/new_recur.jl @@ -1,5 +1,6 @@ - +import Flux: ChainRulesCore +# import ChainRulesCore: rrule, HasReverseMode ##### Helper scan funtion which can likely be put into NNLib. ##### """ @@ -8,28 +9,130 @@ Recreating jax.lax.scan functionality in julia. Takes a function, initial carry and a sequence, then returns the output sequence and the final carry. 
""" + function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray}) - # get the first input to setup the initial state, - # get the rest of the input to run the fold over. - x_init, x_rest = Iterators.peel(xs) - # the following does the same as peel, but doesn't produce correct gradients? - ### x_init = first(xs) - ### x_rest = xs[begin+1:end] - - # set up the initial state of the fold. - (carry_, out_) = func(init_carry, x_init) - init = (carry_, [out_]) - - # recurrence operation used in the fold. Takes the state of the - # folde and the next input, returns the new state. - function __recurrence_op((carry, outputs), input) - carry, out = func(carry, input) - return carry, vcat(outputs, [out]) - end - # Fold left to right. - foldl(__recurrence_op, x_rest; init) + # Recurrence operation used in the fold. Takes the state of the + # fold and the next input, returns the new state. + function recurrence_op((carry, outputs), input) + carry, out = func(carry, input) + return carry, vcat(outputs, [out]) + end + # Fold left to right. + return Base.mapfoldl_impl(identity, recurrence_op, (init_carry, empty(xs)), xs) end +function ChainRulesCore.rrule( + config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, + ::typeof(Base.mapfoldl_impl), + ::typeof(identity), + op::G, + init, + x::Union{AbstractArray, Tuple}; +) where {G} + hobbits = Vector{Any}(undef, length(x)) # Unfornately Zygote needs this + accumulate!(hobbits, x; init=(init, nothing)) do (a, _), b + # hobbits = accumulate(x; init=(init, nothing)) do (a, _), b + c, back = ChainRulesCore.rrule_via_ad(config, op, a, b) + end + y = first(last(hobbits)) + axe = axes(x) + project = ChainRulesCore.ProjectTo(x) + function unfoldl(dy) + trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back) + ds, da, db = back(dc) + end + dop = sum(first, trio) + dx = map(last, Iterators.reverse(trio)) + d_init = trio[end][2] + return (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), dop, d_init, project(reshape(dx, axe))) + end + return y, unfoldl +end + +# function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, +# ::typeof(Base.mapfoldl_impl), +# op::G, +# x::AbstractArray, +# init) where {G} +# list, start = x, init +# hobbits = Vector{Any}(undef, length(list)) # Unfornately Zygote needs this +# accumulate!(hobbits, list; init=(start, nothing)) do (a, _), b +# return CRC.rrule_via_ad(cfg, op, a, b) +# end +# y = first(last(hobbits)) +# ax = axes(x) +# project = ChainRulesCore.ProjectTo(x) +# function ∇mapfoldl_impl(Δ) +# trio = accumulate(Iterators.reverse(hobbits); init=(0, Δ, 0)) do (_, dc, _), (_, back) +# return back(dc) +# end +# ∂op = sum(first, trio) +# ∂x = map(last, Iterators.reverse(trio)) +# return (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), dop, d_init, project(reshape(dx, axe))) #NoTangent(), ∂op, project(reshape(∂x, ax)), trio[end][2] +# end +# return y, ∇mapfoldl_impl +# end + + + +# function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray}) +# function __recurrence_op(::Tuple{Nothing, Nothing}, input) +# carry, out = func(init_carry, input) +# return carry, [out] +# end + +# # recurrence operation used in the fold. Takes the state of the +# function __recurrence_op((carry, outputs), input) +# carry, out = func(carry, input) +# return carry, vcat(outputs, [out]) +# end + +# # Fold left to right. 
+# foldl(__recurrence_op, xs; init=(nothing, nothing)) +# end + +# _cat_output(::Nothing, out) = [out] +# _cat_output(outputs, out) = vcat(outputs, [out]) + +# function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray}) +# # Recurrence operation used in the fold. Takes the state of the +# # fold and the next input, returns the new state. +# function recurrence_op((carry, outputs), input) +# carry = ifelse(carry === nothing, init_carry, carry) +# carry, out = func(carry, input) +# return carry, _cat_output(outputs, out) +# end +# # Fold left to right. +# return foldl(recurrence_op, xs; init=(nothing, nothing)) +# end + +# function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray}) +# # get the first input to setup the initial state, +# # get the rest of the input to run the fold over. +# # x_init, x_rest = Iterators.peel(xs) +# # the following does the same as peel, but doesn't produce correct gradients? +# ### x_init = first(xs) +# ### x_rest = xs[begin+1:end] + +# # set up the initial state of the fold. +# # (carry_, out_) = func(init_carry, x_init) +# # init = (carry_, [out_]) +# function __recurrence_op(::Nothing, input) +# carry, out = func(init_carry, input) +# return carry, [out] +# end + +# # recurrence operation used in the fold. Takes the state of the +# # folde and the next input, returns the new state. +# function __recurrence_op((carry, outputs), input) +# carry, out = func(carry, input) +# return carry, vcat(outputs, [out]) +# end +# # Fold left to right. +# # foldl(__recurrence_op, xs; init=nothing) +# foldl_init(__recurrence_op, xs) +# end + function scan_full(func, init_carry, x_block) # x_block is an abstractarray and we want to scan over the last dimension. xs_ = Flux.eachlastdim(x_block) diff --git a/test/runtests.jl b/test/runtests.jl index 5291a8a..7dfbc42 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,11 +2,11 @@ using Test using Flux, Fluxperimental @testset "Fluxperimental.jl" begin - include("split_join.jl") + # include("split_join.jl") - include("chain.jl") + # include("chain.jl") - include("compact.jl") + # include("compact.jl") include("new_recur.jl") From 832f860a296b033541bc3cab2960d749c8714852 Mon Sep 17 00:00:00 2001 From: Matthew Schlegel Date: Thu, 3 Aug 2023 14:48:03 -0600 Subject: [PATCH 07/12] Cleanup. --- src/new_recur.jl | 141 ++++++++++------------------------------------- 1 file changed, 28 insertions(+), 113 deletions(-) diff --git a/src/new_recur.jl b/src/new_recur.jl index 35ae021..81ae36e 100644 --- a/src/new_recur.jl +++ b/src/new_recur.jl @@ -9,7 +9,6 @@ import Flux: ChainRulesCore Recreating jax.lax.scan functionality in julia. Takes a function, initial carry and a sequence, then returns the output sequence and the final carry. """ - function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray}) # Recurrence operation used in the fold. Takes the state of the # fold and the next input, returns the new state. 
@@ -21,118 +20,6 @@ function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray}) return Base.mapfoldl_impl(identity, recurrence_op, (init_carry, empty(xs)), xs) end -function ChainRulesCore.rrule( - config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, - ::typeof(Base.mapfoldl_impl), - ::typeof(identity), - op::G, - init, - x::Union{AbstractArray, Tuple}; -) where {G} - hobbits = Vector{Any}(undef, length(x)) # Unfornately Zygote needs this - accumulate!(hobbits, x; init=(init, nothing)) do (a, _), b - # hobbits = accumulate(x; init=(init, nothing)) do (a, _), b - c, back = ChainRulesCore.rrule_via_ad(config, op, a, b) - end - y = first(last(hobbits)) - axe = axes(x) - project = ChainRulesCore.ProjectTo(x) - function unfoldl(dy) - trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back) - ds, da, db = back(dc) - end - dop = sum(first, trio) - dx = map(last, Iterators.reverse(trio)) - d_init = trio[end][2] - return (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), dop, d_init, project(reshape(dx, axe))) - end - return y, unfoldl -end - -# function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, -# ::typeof(Base.mapfoldl_impl), -# op::G, -# x::AbstractArray, -# init) where {G} -# list, start = x, init -# hobbits = Vector{Any}(undef, length(list)) # Unfornately Zygote needs this -# accumulate!(hobbits, list; init=(start, nothing)) do (a, _), b -# return CRC.rrule_via_ad(cfg, op, a, b) -# end -# y = first(last(hobbits)) -# ax = axes(x) -# project = ChainRulesCore.ProjectTo(x) -# function ∇mapfoldl_impl(Δ) -# trio = accumulate(Iterators.reverse(hobbits); init=(0, Δ, 0)) do (_, dc, _), (_, back) -# return back(dc) -# end -# ∂op = sum(first, trio) -# ∂x = map(last, Iterators.reverse(trio)) -# return (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), dop, d_init, project(reshape(dx, axe))) #NoTangent(), ∂op, project(reshape(∂x, ax)), trio[end][2] -# end -# return y, ∇mapfoldl_impl -# end - - - -# function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray}) -# function __recurrence_op(::Tuple{Nothing, Nothing}, input) -# carry, out = func(init_carry, input) -# return carry, [out] -# end - -# # recurrence operation used in the fold. Takes the state of the -# function __recurrence_op((carry, outputs), input) -# carry, out = func(carry, input) -# return carry, vcat(outputs, [out]) -# end - -# # Fold left to right. -# foldl(__recurrence_op, xs; init=(nothing, nothing)) -# end - -# _cat_output(::Nothing, out) = [out] -# _cat_output(outputs, out) = vcat(outputs, [out]) - -# function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray}) -# # Recurrence operation used in the fold. Takes the state of the -# # fold and the next input, returns the new state. -# function recurrence_op((carry, outputs), input) -# carry = ifelse(carry === nothing, init_carry, carry) -# carry, out = func(carry, input) -# return carry, _cat_output(outputs, out) -# end -# # Fold left to right. -# return foldl(recurrence_op, xs; init=(nothing, nothing)) -# end - -# function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray}) -# # get the first input to setup the initial state, -# # get the rest of the input to run the fold over. -# # x_init, x_rest = Iterators.peel(xs) -# # the following does the same as peel, but doesn't produce correct gradients? -# ### x_init = first(xs) -# ### x_rest = xs[begin+1:end] - -# # set up the initial state of the fold. 
-# # (carry_, out_) = func(init_carry, x_init) -# # init = (carry_, [out_]) -# function __recurrence_op(::Nothing, input) -# carry, out = func(init_carry, input) -# return carry, [out] -# end - -# # recurrence operation used in the fold. Takes the state of the -# # folde and the next input, returns the new state. -# function __recurrence_op((carry, outputs), input) -# carry, out = func(carry, input) -# return carry, vcat(outputs, [out]) -# end -# # Fold left to right. -# # foldl(__recurrence_op, xs; init=nothing) -# foldl_init(__recurrence_op, xs) -# end - function scan_full(func, init_carry, x_block) # x_block is an abstractarray and we want to scan over the last dimension. xs_ = Flux.eachlastdim(x_block) @@ -147,6 +34,34 @@ function scan_full(func, init_carry, x_block) scan_full(func, init_carry, xs) end +# Chain Rule for Base.mapfoldl_impl +function ChainRulesCore.rrule( + config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, + ::typeof(Base.mapfoldl_impl), + ::typeof(identity), + op::G, + init, + x::Union{AbstractArray, Tuple}; +) where {G} + hobbits = Vector{Any}(undef, length(x)) # Unfornately Zygote needs this + accumulate!(hobbits, x; init=(init, nothing)) do (a, _), b + c, back = ChainRulesCore.rrule_via_ad(config, op, a, b) + end + y = first(last(hobbits)) + axe = axes(x) + project = ChainRulesCore.ProjectTo(x) + function unfoldl(dy) + trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back) + ds, da, db = back(dc) + end + dop = sum(first, trio) + dx = map(last, Iterators.reverse(trio)) + d_init = trio[end][2] + return (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), dop, d_init, project(reshape(dx, axe))) + end + return y, unfoldl +end + """ scan_partial From 72a7fe149b4403b605e3897ee3a750b20389804b Mon Sep 17 00:00:00 2001 From: Matthew Schlegel Date: Thu, 3 Aug 2023 14:55:10 -0600 Subject: [PATCH 08/12] Fixed tests. --- test/runtests.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 7dfbc42..5291a8a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,11 +2,11 @@ using Test using Flux, Fluxperimental @testset "Fluxperimental.jl" begin - # include("split_join.jl") + include("split_join.jl") - # include("chain.jl") + include("chain.jl") - # include("compact.jl") + include("compact.jl") include("new_recur.jl") From b238091a66c4bfd09aa5c23b086360a58f91ba90 Mon Sep 17 00:00:00 2001 From: Matthew Schlegel Date: Fri, 4 Aug 2023 10:46:53 -0600 Subject: [PATCH 09/12] Added Compat as a dependency. --- Project.toml | 1 + src/new_recur.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 5187a82..a2e4518 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "3102ee7a-c841-4564-8f7f-ec69bd4fd658" version = "0.1.2" [deps] +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" diff --git a/src/new_recur.jl b/src/new_recur.jl index 81ae36e..d5c6929 100644 --- a/src/new_recur.jl +++ b/src/new_recur.jl @@ -1,5 +1,6 @@ import Flux: ChainRulesCore +import Compat: stack # import ChainRulesCore: rrule, HasReverseMode ##### Helper scan funtion which can likely be put into NNLib. ##### From 2ed6588cf0f7de375e17e4c1d524f2cd4fc299f6 Mon Sep 17 00:00:00 2001 From: Matthew Schlegel Date: Mon, 7 Aug 2023 10:10:37 -0600 Subject: [PATCH 10/12] Minor edits to whitespace, Temp Docs. 
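The docstrings touched up here describe the two scan helpers: `scan_full` walks the whole sequence and hands back every per-step output plus the final carry, while `scan_partial` keeps only the last output. A toy sketch of those semantics, using a hypothetical accumulator in place of a real RNN cell (`toy_cell` and the literal values are illustrative, not part of this diff; the helpers are unexported, hence the qualified calls):

```julia
using Fluxperimental

# Hypothetical stand-in for an RNN cell: (carry, x) -> (new_carry, output).
toy_cell(carry, x) = (carry .+ x, carry .+ x)

xs   = [[1.0f0], [2.0f0], [3.0f0]]   # a three-step sequence of 1-element arrays
init = [0.0f0]                       # initial carry

carry, outs = Fluxperimental.scan_full(toy_cell, init, xs)
# carry == [6.0f0]; outs == [[1.0f0], [3.0f0], [6.0f0]]

carry, last = Fluxperimental.scan_partial(toy_cell, init, xs)
# carry == [6.0f0]; last == [6.0f0]
```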
--- src/new_recur.jl | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/src/new_recur.jl b/src/new_recur.jl index d5c6929..a44590f 100644 --- a/src/new_recur.jl +++ b/src/new_recur.jl @@ -8,17 +8,18 @@ import Compat: stack scan_full Recreating jax.lax.scan functionality in julia. Takes a function, initial carry and a sequence, -then returns the output sequence and the final carry. +then returns the full output of the sequence and the final carry. See `scan_partial` to only +return the final output of the sequence. """ function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray}) - # Recurrence operation used in the fold. Takes the state of the - # fold and the next input, returns the new state. - function recurrence_op((carry, outputs), input) - carry, out = func(carry, input) - return carry, vcat(outputs, [out]) - end - # Fold left to right. - return Base.mapfoldl_impl(identity, recurrence_op, (init_carry, empty(xs)), xs) + # Recurrence operation used in the fold. Takes the state of the + # fold and the next input, returns the new state. + function recurrence_op((carry, outputs), input) + carry, out = func(carry, input) + return carry, vcat(outputs, [out]) + end + # Fold left to right. + return Base.mapfoldl_impl(identity, recurrence_op, (init_carry, empty(xs)), xs) end function scan_full(func, init_carry, x_block) @@ -68,7 +69,8 @@ end scan_partial Recreating jax.lax.scan functionality in julia. Takes a function, initial carry and a sequence, -then returns the final output of the sequence and the final carry. +then returns the final output of the sequence and the final carry. See `scan_full` to return +the entire output sequence. """ function scan_partial(func, init_carry, xs::AbstractVector{<:AbstractArray}) x_init, x_rest = Iterators.peel(xs) @@ -97,6 +99,17 @@ end """ NewRecur New Recur. An experimental recur interface for removing statefullness in recurrent architectures for flux. +This struct has two type parameters. The first `RET_SEQUENCE` is a boolean which determines whether `scan_full` +(`RET_SEQUENCE=true`) or `scan_partial` (`RET_SEQUENCE=false`) is used to scan through the sequence. This +structure has no internal state, and instead returns: + +```julia +l = NewRNN(1,2) +xs # Some input array Input x BatchSize x Time +init_carry # the initial carry of the cell. +l(xs) # -> returns the output of the RNN, uses cell.state0 as init_carry. +l(init_carry, xs) # -> returns (final_carry, output), where the size ofoutput is determined by RET_SEQUENCE. +``` """ struct NewRecur{RET_SEQUENCE, T} cell::T From 176161495d2df7a20efc8053755f4882088ee6ac Mon Sep 17 00:00:00 2001 From: Matthew Schlegel Date: Tue, 8 Aug 2023 19:17:23 -0600 Subject: [PATCH 11/12] Remove extra newlines. --- src/new_recur.jl | 6 ------ test/new_recur.jl | 9 --------- 2 files changed, 15 deletions(-) diff --git a/src/new_recur.jl b/src/new_recur.jl index d5c6929..8394818 100644 --- a/src/new_recur.jl +++ b/src/new_recur.jl @@ -1,7 +1,5 @@ - import Flux: ChainRulesCore import Compat: stack -# import ChainRulesCore: rrule, HasReverseMode ##### Helper scan funtion which can likely be put into NNLib. 
##### """ @@ -131,10 +129,6 @@ function (l::NewRecur{false})(init_carry, xs) end function (l::NewRecur{true})(init_carry, xs) - results = scan_full(l.cell, init_carry, xs) results[1], stack(results[2], dims=3) end - - - diff --git a/test/new_recur.jl b/test/new_recur.jl index 0acc915..cb5cf2a 100644 --- a/test/new_recur.jl +++ b/test/new_recur.jl @@ -1,5 +1,3 @@ - - @testset "NewRecur RNN" begin @testset "Forward Pass" begin # tanh is needed for forward check to determine ordering of inputs. @@ -93,12 +91,10 @@ @test ∇Wh ≈ grads[:Wh] @test ∇b ≈ grads[:b] @test ∇state0 ≈ grads[:state0] - end end @testset "New Recur RNN Partial Sequence" begin - @testset "Forward Pass" begin cell = Flux.RNNCell(1, 1, identity) layer = Fluxperimental.NewRecur(cell) @@ -121,7 +117,6 @@ end @test_throws MethodError layer([2.0f0]) @test_throws MethodError layer([2.0f0;; 3.0f0]) - end @testset "gradients-implicit" begin @@ -158,8 +153,6 @@ end end @testset "gradients-explicit" begin - - cell = Flux.RNNCell(1, 1, identity) layer = Flux.Recur(cell) layer.cell.Wi .= 5.0 @@ -177,7 +170,6 @@ end ∇b = layer.cell.Wh .+ 1 ∇state0 = layer.cell.Wh .^ 2 - x_block = reshape(vcat(x...), 1, 1, length(x)) nm_layer = Fluxperimental.NewRecur(cell; return_sequence = false) e, g = Flux.withgradient(nm_layer) do layer @@ -194,4 +186,3 @@ end end end - From 52f3b7f120f4f339070429743ddb4b3f0f6b64be Mon Sep 17 00:00:00 2001 From: Matthew Schlegel Date: Tue, 8 Aug 2023 19:19:34 -0600 Subject: [PATCH 12/12] Some more small modifications. --- src/new_recur.jl | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/new_recur.jl b/src/new_recur.jl index 3891fcd..824644f 100644 --- a/src/new_recur.jl +++ b/src/new_recur.jl @@ -5,9 +5,7 @@ import Compat: stack """ scan_full -Recreating jax.lax.scan functionality in julia. Takes a function, initial carry and a sequence, -then returns the full output of the sequence and the final carry. See `scan_partial` to only -return the final output of the sequence. +Recreating jax.lax.scan functionality in julia. Takes a function, initial carry and a sequence, then returns the full output of the sequence and the final carry. See `scan_partial` to only return the final output of the sequence. """ function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray}) # Recurrence operation used in the fold. Takes the state of the @@ -66,9 +64,7 @@ end """ scan_partial -Recreating jax.lax.scan functionality in julia. Takes a function, initial carry and a sequence, -then returns the final output of the sequence and the final carry. See `scan_full` to return -the entire output sequence. +Recreating jax.lax.scan functionality in julia. Takes a function, initial carry and a sequence, then returns the final output of the sequence and the final carry. See `scan_full` to return the entire output sequence. """ function scan_partial(func, init_carry, xs::AbstractVector{<:AbstractArray}) x_init, x_rest = Iterators.peel(xs) @@ -96,10 +92,7 @@ end """ NewRecur -New Recur. An experimental recur interface for removing statefullness in recurrent architectures for flux. -This struct has two type parameters. The first `RET_SEQUENCE` is a boolean which determines whether `scan_full` -(`RET_SEQUENCE=true`) or `scan_partial` (`RET_SEQUENCE=false`) is used to scan through the sequence. This -structure has no internal state, and instead returns: +New Recur. An experimental recur interface for removing statefullness in recurrent architectures for flux. This struct has two type parameters. 
The first `RET_SEQUENCE` is a boolean which determines whether `scan_full` (`RET_SEQUENCE=true`) or `scan_partial` (`RET_SEQUENCE=false`) is used to scan through the sequence. This structure has no internal state, and instead returns: ```julia l = NewRNN(1,2)