Skip to content

Commit 6991636

Browse files
authored
feat: return sequence properly + checkpointing + mincut (#1561)
* feat: return sequence properly + checkpointing + mincut * feat: iterate over the full sequence * test: mincut
1 parent 255ad78 commit 6991636

File tree

4 files changed

+81
-33
lines changed

4 files changed

+81
-33
lines changed

ext/LuxReactantExt/LuxReactantExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ using Static: True, False
2121

2222
using Lux: Lux, LuxOps, Training, Utils, StatefulLuxLayer
2323
using Lux.Training: TrainingBackendCache, ReactantBackend
24+
using Lux: get_time_dimension, time_dimension_size, init_recurrent_state
2425
using LuxCore: LuxCore, AbstractLuxLayer
2526
using MLDataDevices: MLDataDevices, ReactantDevice, get_device
2627

ext/LuxReactantExt/layers.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
# Recurrent Layers
22
function (r::Lux.Recurrence)(x::AnyTracedRArray, ps, st::NamedTuple)
33
idxs = ntuple(Returns(Colon()), ndims(x) - 1)
4-
N = Lux.time_dimension_size(x, r.ordering)
4+
N = time_dimension_size(x, r.ordering)
55

6-
(out, carry), st = r.cell(Lux.get_time_dimension(x, 1, r.ordering), ps, st)
7-
sequence = similar(x, size(out)..., N)
6+
# execute the first step to get the types
7+
tmp = get_time_dimension(x, 1, r.ordering)
8+
carry, _ = init_recurrent_state(r.cell, tmp, ps, st)
9+
(tmp_result, _), _ = r.cell(tmp, ps, st)
810

9-
sequence[idxs..., 1] = out
10-
@trace for i in 2:N
11-
(out, carry), st = r.cell((Lux.get_time_dimension(x, i, r.ordering), carry), ps, st)
11+
final_result = similar(tmp_result)
12+
sequence = similar(tmp_result, size(tmp_result)..., N)
13+
@trace checkpointing = r.checkpointing mincut = r.mincut for i in 1:N
14+
(out, carry), st = r.cell((get_time_dimension(x, i, r.ordering), carry), ps, st)
15+
final_result[idxs...] = out
1216
sequence[idxs..., i] = out
1317
end
1418

15-
r.return_sequence isa False && return (out, st)
16-
return LuxOps.eachslice(sequence, Val(ndims(sequence))), st
19+
r.return_sequence isa False && return (final_result, st)
20+
return eachslice(sequence; dims=ndims(sequence)), st
1721
end

src/layers/recurrent.jl

Lines changed: 62 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,12 @@ function init_rnn_bias(rng::AbstractRNG, init_bias, hidden_dims, bias_len)
5555
end
5656

5757
"""
58-
Recurrence(cell;
58+
Recurrence(
59+
cell;
5960
ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex(),
60-
return_sequence::Bool=false)
61+
return_sequence::Bool=false,
62+
mincut::Bool=false,
63+
)
6164
6265
Wraps a recurrent cell (like [`RNNCell`](@ref), [`LSTMCell`](@ref), [`GRUCell`](@ref)) to
6366
automatically operate over a sequence of inputs.
@@ -79,6 +82,8 @@ automatically operate over a sequence of inputs.
7982
the last output. Defaults to `false`.
8083
- `ordering`: The ordering of the batch and time dimensions in the input. Defaults to
8184
`BatchLastIndex()`. Alternatively can be set to `TimeLastIndex()`.
85+
- `mincut`: If `true`, we will use mincut for the reverse mode differentiation.
86+
*(Only for Reactant)*
8287
8388
# Extended Help
8489
@@ -119,24 +124,36 @@ struct Recurrence{R<:StaticBool,C,O<:AbstractTimeSeriesDataBatchOrdering} <:
119124
cell::C
120125
ordering::O
121126
return_sequence::R
127+
# FIXME: checkpointing is intentionally not documented.
128+
# See https://github.com/LuxDL/Lux.jl/pull/1561#issuecomment-3564283063
129+
checkpointing::Bool
130+
mincut::Bool
122131

123132
function Recurrence(
124-
cell::C, ordering::AbstractTimeSeriesDataBatchOrdering, return_sequence::R
133+
cell::C,
134+
ordering::AbstractTimeSeriesDataBatchOrdering,
135+
return_sequence::R,
136+
checkpointing::Bool,
137+
mincut::Bool,
125138
) where {C,R}
126139
@assert cell isa Union{
127140
<:AbstractRecurrentCell,
128141
<:Experimental.DebugLayer{<:Any,<:Any,<:AbstractRecurrentCell},
129142
}
130-
return new{R,C,typeof(ordering)}(cell, ordering, return_sequence)
143+
return new{R,C,typeof(ordering)}(
144+
cell, ordering, return_sequence, checkpointing, mincut
145+
)
131146
end
132147
end
133148

134149
function Recurrence(
135150
cell;
136151
ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex(),
137152
return_sequence::Bool=false,
153+
checkpointing::Bool=false,
154+
mincut::Bool=false,
138155
)
139-
return Recurrence(cell, ordering, static(return_sequence))
156+
return Recurrence(cell, ordering, static(return_sequence), checkpointing, mincut)
140157
end
141158

142159
function (r::Recurrence)(x::AbstractArray, ps, st::NamedTuple)
@@ -233,6 +250,9 @@ function applyrecurrentcell(l::AbstractRecurrentCell, x, ps, st, carry)
233250
end
234251
applyrecurrentcell(l::AbstractRecurrentCell, x, ps, st, ::Nothing) = apply(l, x, ps, st)
235252

253+
# Used to construct the initial state of the recurrent cell
254+
function init_recurrent_state end
255+
236256
@doc doc"""
237257
RNNCell(in_dims => out_dims, activation=tanh; use_bias=True(), train_state=False(),
238258
init_bias=nothing, init_weight=nothing, init_recurrent_weight=init_weight,
@@ -343,15 +363,20 @@ end
343363

344364
initialstates(rng::AbstractRNG, ::RNNCell) = (rng=Utils.sample_replicate(rng),)
345365

346-
function (rnn::RNNCell{False})(x::AbstractMatrix, ps, st::NamedTuple)
366+
function init_recurrent_state(rnn::RNNCell{False}, x::AbstractMatrix, ps, st::NamedTuple)
347367
rng = replicate(st.rng)
348368
hidden_state = init_rnn_hidden_state(rng, rnn, x)
349-
return rnn((x, (hidden_state,)), ps, merge(st, (; rng)))
369+
return (hidden_state,), merge(st, (; rng))
350370
end
351371

352-
function (rnn::RNNCell{True})(x::AbstractMatrix, ps, st::NamedTuple)
372+
function init_recurrent_state(::RNNCell{True}, x::AbstractMatrix, ps, st::NamedTuple)
353373
hidden_state = init_trainable_rnn_hidden_state(ps.hidden_state, x)
354-
return rnn((x, (hidden_state,)), ps, st)
374+
return (hidden_state,), st
375+
end
376+
377+
function (rnn::RNNCell)(x::AbstractMatrix, ps, st::NamedTuple)
378+
hidden_state, st = init_recurrent_state(rnn, x, ps, st)
379+
return rnn((x, hidden_state), ps, st)
355380
end
356381

357382
@trace function (rnn::RNNCell)(
@@ -547,31 +572,42 @@ end
547572

548573
initialstates(rng::AbstractRNG, ::LSTMCell) = (rng=Utils.sample_replicate(rng),)
549574

550-
function (lstm::LSTMCell{False,False})(x::AbstractMatrix, ps, st::NamedTuple)
575+
function init_recurrent_state(
576+
lstm::LSTMCell{False,False}, x::AbstractMatrix, ps, st::NamedTuple
577+
)
551578
rng = replicate(st.rng)
552579
hidden_state = init_rnn_hidden_state(rng, lstm, x)
553580
memory = init_rnn_hidden_state(rng, lstm, x)
554-
return lstm((x, (hidden_state, memory)), ps, merge(st, (; rng)))
581+
return (hidden_state, memory), merge(st, (; rng))
555582
end
556583

557-
function (lstm::LSTMCell{True,False})(x::AbstractMatrix, ps, st::NamedTuple)
584+
function init_recurrent_state(
585+
lstm::LSTMCell{True,False}, x::AbstractMatrix, ps, st::NamedTuple
586+
)
558587
rng = replicate(st.rng)
559588
hidden_state = init_trainable_rnn_hidden_state(ps.hidden_state, x)
560589
memory = init_rnn_hidden_state(rng, lstm, x)
561-
return lstm((x, (hidden_state, memory)), ps, merge(st, (; rng)))
590+
return (hidden_state, memory), merge(st, (; rng))
562591
end
563592

564-
function (lstm::LSTMCell{False,True})(x::AbstractMatrix, ps, st::NamedTuple)
593+
function init_recurrent_state(
594+
lstm::LSTMCell{False,True}, x::AbstractMatrix, ps, st::NamedTuple
595+
)
565596
rng = replicate(st.rng)
566597
hidden_state = init_rnn_hidden_state(rng, lstm, x)
567598
memory = init_trainable_rnn_hidden_state(ps.memory, x)
568-
return lstm((x, (hidden_state, memory)), ps, merge(st, (; rng)))
599+
return (hidden_state, memory), merge(st, (; rng))
569600
end
570601

571-
function (lstm::LSTMCell{True,True})(x::AbstractMatrix, ps, st::NamedTuple)
602+
function init_recurrent_state(::LSTMCell{True,True}, x::AbstractMatrix, ps, st::NamedTuple)
572603
hidden_state = init_trainable_rnn_hidden_state(ps.hidden_state, x)
573604
memory = init_trainable_rnn_hidden_state(ps.memory, x)
574-
return lstm((x, (hidden_state, memory)), ps, st)
605+
return (hidden_state, memory), st
606+
end
607+
608+
function (lstm::LSTMCell)(x::AbstractMatrix, ps, st::NamedTuple)
609+
hidden_state, st = init_recurrent_state(lstm, x, ps, st)
610+
return lstm((x, hidden_state), ps, st)
575611
end
576612

577613
const _LSTMCellInputType = Tuple{<:AbstractMatrix,Tuple{<:AbstractMatrix,<:AbstractMatrix}}
@@ -744,16 +780,20 @@ end
744780

745781
initialstates(rng::AbstractRNG, ::GRUCell) = (rng=Utils.sample_replicate(rng),)
746782

747-
function (gru::GRUCell{True})(x::AbstractMatrix, ps, st::NamedTuple)
783+
function init_recurrent_state(::GRUCell{True}, x::AbstractMatrix, ps, st::NamedTuple)
748784
hidden_state = init_trainable_rnn_hidden_state(ps.hidden_state, x)
749-
return gru((x, (hidden_state,)), ps, st)
785+
return (hidden_state,), st
750786
end
751787

752-
function (gru::GRUCell{False})(x::AbstractMatrix, ps, st::NamedTuple)
788+
function init_recurrent_state(gru::GRUCell{False}, x::AbstractMatrix, ps, st::NamedTuple)
753789
rng = replicate(st.rng)
754-
st = merge(st, (; rng))
755790
hidden_state = init_rnn_hidden_state(rng, gru, x)
756-
return gru((x, (hidden_state,)), ps, st)
791+
return (hidden_state,), merge(st, (; rng))
792+
end
793+
794+
function (gru::GRUCell)(x::AbstractMatrix, ps, st::NamedTuple)
795+
hidden_state, st = init_recurrent_state(gru, x, ps, st)
796+
return gru((x, hidden_state), ps, st)
757797
end
758798

759799
const _GRUCellInputType = Tuple{<:AbstractMatrix,Tuple{<:AbstractMatrix}}

test/reactant/layer_tests.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,12 @@
4747

4848
@testset "gradient" begin
4949
∂x, ∂ps = ∇sumabs2_zygote(model, x, ps, st)
50-
∂x_ra, ∂ps_ra = @jit ∇sumabs2_enzyme(model, x_ra, ps_ra, st_ra)
51-
@test ∂x_ra ≈ ∂x atol = 1.0e-2 rtol = 1.0e-2
52-
@test check_approx(∂ps_ra, ∂ps; atol=1.0e-2, rtol=1.0e-2)
50+
@testset for mincut in (true, false), checkpointing in (false,)
51+
model_ = Recurrence(cell(4 => 4); ordering, mincut, checkpointing)
52+
∂x_ra, ∂ps_ra = @jit ∇sumabs2_enzyme(model_, x_ra, ps_ra, st_ra)
53+
@test ∂x_ra ≈ ∂x atol = 1.0e-2 rtol = 1.0e-2
54+
@test check_approx(∂ps_ra, ∂ps; atol=1.0e-2, rtol=1.0e-2)
55+
end
5356
end
5457

5558
model2 = Recurrence(cell(4 => 4); ordering, return_sequence=true)

0 commit comments

Comments
 (0)