
Commit 9dcae27

NewRecur experimental interface (#11)
1 parent 62e7d06 commit 9dcae27

5 files changed: +333 -0 lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ uuid = "3102ee7a-c841-4564-8f7f-ec69bd4fd658"
 version = "0.1.3"

 [deps]
+Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"

src/Fluxperimental.jl

Lines changed: 2 additions & 0 deletions
@@ -13,4 +13,6 @@ include("chain.jl")

 include("compact.jl")

+include("new_recur.jl")
+
 end # module Fluxperimental

src/new_recur.jl

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
import Flux: ChainRulesCore
import Compat: stack

##### Helper scan functions which can likely be put into NNlib. #####
"""
    scan_full

Recreates `jax.lax.scan` functionality in Julia. Takes a function, an initial carry
and a sequence, then returns the full output of the sequence and the final carry.
See `scan_partial` to only return the final output of the sequence.
"""
function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray})
  # Recurrence operation used in the fold. Takes the state of the
  # fold and the next input, returns the new state.
  function recurrence_op((carry, outputs), input)
    carry, out = func(carry, input)
    return carry, vcat(outputs, [out])
  end
  # Fold left to right.
  return Base.mapfoldl_impl(identity, recurrence_op, (init_carry, empty(xs)), xs)
end

function scan_full(func, init_carry, x_block)
  # x_block is an AbstractArray and we want to scan over the last dimension.
  xs_ = Flux.eachlastdim(x_block)

  # This is needed due to a bug in eachlastdim, which produces a vector in a
  # gradient context, but a generator otherwise.
  xs = if xs_ isa Base.Generator
    collect(xs_) # eachlastdim produces a generator in a non-gradient environment
  else
    xs_
  end
  scan_full(func, init_carry, xs)
end

# Chain rule for Base.mapfoldl_impl.
function ChainRulesCore.rrule(
    config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
    ::typeof(Base.mapfoldl_impl),
    ::typeof(identity),
    op::G,
    init,
    x::Union{AbstractArray, Tuple};
) where {G}
  hobbits = Vector{Any}(undef, length(x)) # Unfortunately Zygote needs this
  accumulate!(hobbits, x; init=(init, nothing)) do (a, _), b
    c, back = ChainRulesCore.rrule_via_ad(config, op, a, b)
  end
  y = first(last(hobbits))
  axe = axes(x)
  project = ChainRulesCore.ProjectTo(x)
  function unfoldl(dy)
    trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
      ds, da, db = back(dc)
    end
    dop = sum(first, trio)
    dx = map(last, Iterators.reverse(trio))
    d_init = trio[end][2]
    return (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), dop, d_init, project(reshape(dx, axe)))
  end
  return y, unfoldl
end


"""
    scan_partial

Recreates `jax.lax.scan` functionality in Julia. Takes a function, an initial carry
and a sequence, then returns the final output of the sequence and the final carry.
See `scan_full` to return the entire output sequence.
"""
function scan_partial(func, init_carry, xs::AbstractVector{<:AbstractArray})
  x_init, x_rest = Iterators.peel(xs)
  (carry, y) = func(init_carry, x_init)
  for x in x_rest
    (carry, y) = func(carry, x)
  end
  carry, y
end

function scan_partial(func, init_carry, x_block)
  # x_block is an AbstractArray and we want to scan over the last dimension.
  xs_ = Flux.eachlastdim(x_block)

  # This is needed due to a bug in eachlastdim, which produces a vector in a
  # gradient context, but a generator otherwise.
  xs = if xs_ isa Base.Generator
    collect(xs_) # eachlastdim produces a generator in a non-gradient environment
  else
    xs_
  end
  scan_partial(func, init_carry, xs)
end


"""
    NewRecur

An experimental `Recur` interface for removing statefulness in recurrent
architectures for Flux. This struct has two type parameters. The first,
`RET_SEQUENCE`, is a boolean which determines whether `scan_full`
(`RET_SEQUENCE=true`) or `scan_partial` (`RET_SEQUENCE=false`) is used to scan
through the sequence. This structure has no internal state, and instead returns:

```julia
l = NewRNN(1, 2)
xs          # some input array of size Input x BatchSize x Time
init_carry  # the initial carry of the cell
l(xs)             # returns the output of the RNN, using cell.state0 as init_carry
l(init_carry, xs) # returns (final_carry, output), where the size of output is determined by RET_SEQUENCE
```
"""
struct NewRecur{RET_SEQUENCE, T}
  cell::T
  # state::S
  function NewRecur(cell; return_sequence::Bool=false)
    new{return_sequence, typeof(cell)}(cell)
  end
  function NewRecur{true}(cell)
    new{true, typeof(cell)}(cell)
  end
  function NewRecur{false}(cell)
    new{false, typeof(cell)}(cell)
  end
end

Flux.@functor NewRecur
Flux.trainable(a::NewRecur) = (; cell = a.cell)
Base.show(io::IO, m::NewRecur) = print(io, "Recur(", m.cell, ")")
NewRNN(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.RNNCell(a...; ka...); return_sequence=return_sequence)

(l::NewRecur)(init_carry, x_mat::AbstractMatrix) = MethodError("Matrix is ambiguous with NewRecur")
(l::NewRecur)(init_carry, x_mat::AbstractVector{T}) where {T<:Number} = MethodError("Vector is ambiguous with NewRecur")

function (l::NewRecur)(xs::AbstractArray)
  results = l(l.cell.state0, xs)
  results[2] # Only return the output here.
end

function (l::NewRecur{false})(init_carry, xs)
  results = scan_partial(l.cell, init_carry, xs)
  results[1], results[2]
end

function (l::NewRecur{true})(init_carry, xs)
  results = scan_full(l.cell, init_carry, xs)
  results[1], stack(results[2], dims=3)
end
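
For orientation, a minimal usage sketch of the scan helpers and the new layer above (not part of the commit): the toy `acc` function, the array sizes, and the variable names are illustrative assumptions, and the unexported helpers are accessed through the module.

```julia
using Flux, Fluxperimental

# Toy step function: the new carry and the per-step output are both the running sum.
acc(carry, x) = (carry .+ x, carry .+ x)

c, ys = Fluxperimental.scan_full(acc, [0.0f0], [[1.0f0], [2.0f0], [3.0f0]])
# c  == [6.0f0]                       final carry
# ys == [[1.0f0], [3.0f0], [6.0f0]]   one output per step

c, y = Fluxperimental.scan_partial(acc, [0.0f0], [[1.0f0], [2.0f0], [3.0f0]])
# c == [6.0f0], y == [6.0f0]          only the last output is kept

# NewRecur wraps a cell and scans it over the last (time) dimension of a block.
l = Fluxperimental.NewRNN(1, 2; return_sequence=true)
xs = rand(Float32, 1, 4, 3)            # features × batch × time
carry, out = l(l.cell.state0, xs)      # explicit initial carry
size(out)                              # (2, 4, 3): full output sequence
l(xs)                                  # same output, carry taken from cell.state0

l_last = Fluxperimental.NewRNN(1, 2)   # return_sequence=false is the default
size(l_last(xs))                       # (2, 4): only the final step's output
```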

test/new_recur.jl

Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
@testset "NewRecur RNN" begin
  @testset "Forward Pass" begin
    # tanh is needed for the forward check to determine the ordering of inputs.
    cell = Flux.RNNCell(1, 1, tanh)
    layer = Fluxperimental.NewRecur(cell; return_sequence=true)
    layer.cell.Wi .= 5.0
    layer.cell.Wh .= 4.0
    layer.cell.b .= 0.0f0
    layer.cell.state0 .= 7.0
    x = reshape([2.0f0, 3.0f0], 1, 1, 2)

    # Let's make sure the output is correct.
    h = cell.state0
    h, out = cell(h, [2.0f0])
    h, out = cell(h, [3.0f0])

    @test eltype(layer(x)) <: Float32
    @test size(layer(x)) == (1, 1, 2)
    @test layer(x)[1, 1, 2] ≈ out[1, 1]

    @test length(layer(cell.state0, x)) == 2 # should return a tuple. Maybe a better test is needed.
    @test layer(cell.state0, x)[2][1, 1, 2] ≈ out[1, 1]

    @test_throws MethodError layer([2.0f0])
    @test_throws MethodError layer([2.0f0;; 3.0f0])
  end

  @testset "gradients-implicit" begin
    cell = Flux.RNNCell(1, 1, identity)
    layer = Flux.Recur(cell)
    layer.cell.Wi .= 5.0
    layer.cell.Wh .= 4.0
    layer.cell.b .= 0.0f0
    layer.cell.state0 .= 7.0
    x = [[2.0f0], [3.0f0]]

    # theoretical primal and gradients
    primal =
      layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
      x[2] .* layer.cell.Wi
    ∇Wi = x[1] .* layer.cell.Wh .+ x[2]
    ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
    ∇b = layer.cell.Wh .+ 1
    ∇state0 = layer.cell.Wh .^ 2

    nm_layer = Fluxperimental.NewRecur(cell; return_sequence = true)
    ps = Flux.params(nm_layer)
    x_block = reshape(vcat(x...), 1, 1, length(x))
    e, g = Flux.withgradient(ps) do
      out = nm_layer(x_block)
      sum(out[1, 1, 2])
    end

    @test primal[1] ≈ e
    @test ∇Wi ≈ g[ps[1]]
    @test ∇Wh ≈ g[ps[2]]
    @test ∇b ≈ g[ps[3]]
    @test ∇state0 ≈ g[ps[4]]
  end

  @testset "gradients-explicit" begin
    cell = Flux.RNNCell(1, 1, identity)
    layer = Flux.Recur(cell)
    layer.cell.Wi .= 5.0
    layer.cell.Wh .= 4.0
    layer.cell.b .= 0.0f0
    layer.cell.state0 .= 7.0
    x = [[2.0f0], [3.0f0]]

    # theoretical primal and gradients
    primal =
      layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
      x[2] .* layer.cell.Wi
    ∇Wi = x[1] .* layer.cell.Wh .+ x[2]
    ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
    ∇b = layer.cell.Wh .+ 1
    ∇state0 = layer.cell.Wh .^ 2

    x_block = reshape(vcat(x...), 1, 1, length(x))
    nm_layer = Fluxperimental.NewRecur(cell; return_sequence = true)
    e, g = Flux.withgradient(nm_layer) do layer
      out = layer(x_block)
      sum(out[1, 1, 2])
    end
    grads = g[1][:cell]

    @test primal[1] ≈ e
    @test ∇Wi ≈ grads[:Wi]
    @test ∇Wh ≈ grads[:Wh]
    @test ∇b ≈ grads[:b]
    @test ∇state0 ≈ grads[:state0]
  end
end

@testset "New Recur RNN Partial Sequence" begin
  @testset "Forward Pass" begin
    cell = Flux.RNNCell(1, 1, identity)
    layer = Fluxperimental.NewRecur(cell)
    layer.cell.Wi .= 5.0
    layer.cell.Wh .= 4.0
    layer.cell.b .= 0.0f0
    layer.cell.state0 .= 7.0
    x = reshape([2.0f0, 3.0f0], 1, 1, 2)

    h = cell.state0
    h, out = cell(h, [2.0f0])
    h, out = cell(h, [3.0f0])

    @test eltype(layer(x)) <: Float32
    @test size(layer(x)) == (1, 1)
    @test layer(x)[1, 1] ≈ out[1, 1]

    @test length(layer(cell.state0, x)) == 2
    @test layer(cell.state0, x)[2][1, 1] ≈ out[1, 1]

    @test_throws MethodError layer([2.0f0])
    @test_throws MethodError layer([2.0f0;; 3.0f0])
  end

  @testset "gradients-implicit" begin
    cell = Flux.RNNCell(1, 1, identity)
    layer = Flux.Recur(cell)
    layer.cell.Wi .= 5.0
    layer.cell.Wh .= 4.0
    layer.cell.b .= 0.0f0
    layer.cell.state0 .= 7.0
    x = [[2.0f0], [3.0f0]]

    # theoretical primal and gradients
    primal =
      layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
      x[2] .* layer.cell.Wi
    ∇Wi = x[1] .* layer.cell.Wh .+ x[2]
    ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
    ∇b = layer.cell.Wh .+ 1
    ∇state0 = layer.cell.Wh .^ 2

    nm_layer = Fluxperimental.NewRecur(cell; return_sequence = false)
    ps = Flux.params(nm_layer)
    x_block = reshape(vcat(x...), 1, 1, length(x))
    e, g = Flux.withgradient(ps) do
      out = (nm_layer)(x_block)
      sum(out)
    end

    @test primal[1] ≈ e
    @test ∇Wi ≈ g[ps[1]]
    @test ∇Wh ≈ g[ps[2]]
    @test ∇b ≈ g[ps[3]]
    @test ∇state0 ≈ g[ps[4]]
  end

  @testset "gradients-explicit" begin
    cell = Flux.RNNCell(1, 1, identity)
    layer = Flux.Recur(cell)
    layer.cell.Wi .= 5.0
    layer.cell.Wh .= 4.0
    layer.cell.b .= 0.0f0
    layer.cell.state0 .= 7.0
    x = [[2.0f0], [3.0f0]]

    # theoretical primal and gradients
    primal =
      layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
      x[2] .* layer.cell.Wi
    ∇Wi = x[1] .* layer.cell.Wh .+ x[2]
    ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
    ∇b = layer.cell.Wh .+ 1
    ∇state0 = layer.cell.Wh .^ 2

    x_block = reshape(vcat(x...), 1, 1, length(x))
    nm_layer = Fluxperimental.NewRecur(cell; return_sequence = false)
    e, g = Flux.withgradient(nm_layer) do layer
      out = layer(x_block)
      sum(out)
    end
    grads = g[1][:cell]

    @test primal[1] ≈ e
    @test ∇Wi ≈ grads[:Wi]
    @test ∇Wh ≈ grads[:Wh]
    @test ∇b ≈ grads[:b]
    @test ∇state0 ≈ grads[:state0]
  end
end
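
As a side note (not part of the commit), the hard-coded "theoretical" values in these tests can be checked by hand. A minimal sketch, assuming the same 1×1 identity cell and treating all quantities as scalars with illustrative names:

```julia
# Unrolling the 1×1 identity RNNCell for two steps gives
#   out₂ = Wh*(Wh*h₀ + Wi*x₁ + b) + Wi*x₂ + b
# so the partial derivatives of out₂ are
#   ∂Wi = Wh*x₁ + x₂,  ∂Wh = 2*Wh*h₀ + Wi*x₁,  ∂b = Wh + 1,  ∂h₀ = Wh^2,
# matching the ∇Wi, ∇Wh, ∇b and ∇state0 expressions used above.
Wi, Wh, b, h0, x1, x2 = 5.0f0, 4.0f0, 0.0f0, 7.0f0, 2.0f0, 3.0f0
out2 = Wh * (Wh * h0 + Wi * x1 + b) + Wi * x2 + b   # 167.0f0, the tested primal
∂Wi, ∂Wh, ∂b, ∂h0 = Wh * x1 + x2, 2 * Wh * h0 + Wi * x1, Wh + 1, Wh^2
# (11.0f0, 66.0f0, 5.0f0, 16.0f0)
```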

test/runtests.jl

Lines changed: 2 additions & 0 deletions
@@ -8,4 +8,6 @@ using Flux, Fluxperimental

 include("compact.jl")

+include("new_recur.jl")
+
 end
