@@ -55,9 +55,12 @@ function init_rnn_bias(rng::AbstractRNG, init_bias, hidden_dims, bias_len)
5555end
5656
5757"""
58- Recurrence(cell;
58+ Recurrence(
59+ cell;
5960 ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex(),
60- return_sequence::Bool=false)
61+ return_sequence::Bool=false,
62+ mincut::Bool=false,
63+ )
6164
6265Wraps a recurrent cell (like [`RNNCell`](@ref), [`LSTMCell`](@ref), [`GRUCell`](@ref)) to
6366automatically operate over a sequence of inputs.
@@ -79,6 +82,8 @@ automatically operate over a sequence of inputs.
7982 the last output. Defaults to `false`.
8083 - `ordering`: The ordering of the batch and time dimensions in the input. Defaults to
8184 `BatchLastIndex()`. Alternatively can be set to `TimeLastIndex()`.
85+ - `mincut`: If `true`, we will using mincut for the reverse mode differentiation.
86+ *(Only for Reactant)*
8287
8388# Extended Help
8489
@@ -119,24 +124,36 @@ struct Recurrence{R<:StaticBool,C,O<:AbstractTimeSeriesDataBatchOrdering} <:
119124 cell:: C
120125 ordering:: O
121126 return_sequence:: R
127+ # FIXME : checkpointing is intentionally not documented.
128+ # See https://github.com/LuxDL/Lux.jl/pull/1561#issuecomment-3564283063
129+ checkpointing:: Bool
130+ mincut:: Bool
122131
123132 function Recurrence(
124- cell:: C , ordering:: AbstractTimeSeriesDataBatchOrdering , return_sequence:: R
133+ cell:: C ,
134+ ordering:: AbstractTimeSeriesDataBatchOrdering ,
135+ return_sequence:: R ,
136+ checkpointing:: Bool ,
137+ mincut:: Bool ,
125138 ) where {C,R}
126139 @assert cell isa Union{
127140 <: AbstractRecurrentCell ,
128141 <: Experimental.DebugLayer{<:Any,<:Any,<:AbstractRecurrentCell} ,
129142 }
130- return new{R,C,typeof(ordering)}(cell, ordering, return_sequence)
143+ return new{R,C,typeof(ordering)}(
144+ cell, ordering, return_sequence, checkpointing, mincut
145+ )
131146 end
132147end
133148
134149function Recurrence(
135150 cell;
136151 ordering:: AbstractTimeSeriesDataBatchOrdering = BatchLastIndex(),
137152 return_sequence:: Bool = false ,
153+ checkpointing:: Bool = false ,
154+ mincut:: Bool = false ,
138155)
139- return Recurrence(cell, ordering, static(return_sequence))
156+ return Recurrence(cell, ordering, static(return_sequence), checkpointing, mincut )
140157end
141158
142159function (r:: Recurrence )(x:: AbstractArray , ps, st:: NamedTuple )
@@ -233,6 +250,9 @@ function applyrecurrentcell(l::AbstractRecurrentCell, x, ps, st, carry)
233250end
234251applyrecurrentcell(l:: AbstractRecurrentCell , x, ps, st, :: Nothing ) = apply(l, x, ps, st)
235252
253+ # Used to construct the initial state of the recurrent cell
254+ function init_recurrent_state end
255+
236256@doc doc"""
237257 RNNCell(in_dims => out_dims, activation=tanh; use_bias=True(), train_state=False(),
238258 init_bias=nothing, init_weight=nothing, init_recurrent_weight=init_weight,
@@ -343,15 +363,20 @@ end
343363
344364initialstates(rng:: AbstractRNG , :: RNNCell ) = (rng= Utils. sample_replicate(rng),)
345365
346- function (rnn:: RNNCell{False} )( x:: AbstractMatrix , ps, st:: NamedTuple )
366+ function init_recurrent_state (rnn:: RNNCell{False} , x:: AbstractMatrix , ps, st:: NamedTuple )
347367 rng = replicate(st. rng)
348368 hidden_state = init_rnn_hidden_state(rng, rnn, x)
349- return rnn((x, ( hidden_state,)), ps, merge(st, (; rng) ))
369+ return ( hidden_state,), merge(st, (; rng))
350370end
351371
352- function (rnn :: RNNCell{True} )( x:: AbstractMatrix , ps, st:: NamedTuple )
372+ function init_recurrent_state( :: RNNCell{True} , x:: AbstractMatrix , ps, st:: NamedTuple )
353373 hidden_state = init_trainable_rnn_hidden_state(ps. hidden_state, x)
354- return rnn((x, (hidden_state,)), ps, st)
374+ return (hidden_state,), st
375+ end
376+
377+ function (rnn:: RNNCell )(x:: AbstractMatrix , ps, st:: NamedTuple )
378+ hidden_state, st = init_recurrent_state(rnn, x, ps, st)
379+ return rnn((x, hidden_state), ps, st)
355380end
356381
357382@trace function (rnn:: RNNCell )(
@@ -547,31 +572,42 @@ end
547572
548573initialstates(rng:: AbstractRNG , :: LSTMCell ) = (rng= Utils. sample_replicate(rng),)
549574
550- function (lstm:: LSTMCell{False,False} )(x:: AbstractMatrix , ps, st:: NamedTuple )
575+ function init_recurrent_state(
576+ lstm:: LSTMCell{False,False} , x:: AbstractMatrix , ps, st:: NamedTuple
577+ )
551578 rng = replicate(st. rng)
552579 hidden_state = init_rnn_hidden_state(rng, lstm, x)
553580 memory = init_rnn_hidden_state(rng, lstm, x)
554- return lstm((x, ( hidden_state, memory)), ps, merge(st, (; rng) ))
581+ return ( hidden_state, memory), merge(st, (; rng))
555582end
556583
557- function (lstm:: LSTMCell{True,False} )(x:: AbstractMatrix , ps, st:: NamedTuple )
584+ function init_recurrent_state(
585+ lstm:: LSTMCell{True,False} , x:: AbstractMatrix , ps, st:: NamedTuple
586+ )
558587 rng = replicate(st. rng)
559588 hidden_state = init_trainable_rnn_hidden_state(ps. hidden_state, x)
560589 memory = init_rnn_hidden_state(rng, lstm, x)
561- return lstm((x, ( hidden_state, memory)), ps, merge(st, (; rng) ))
590+ return ( hidden_state, memory), merge(st, (; rng))
562591end
563592
564- function (lstm:: LSTMCell{False,True} )(x:: AbstractMatrix , ps, st:: NamedTuple )
593+ function init_recurrent_state(
594+ lstm:: LSTMCell{False,True} , x:: AbstractMatrix , ps, st:: NamedTuple
595+ )
565596 rng = replicate(st. rng)
566597 hidden_state = init_rnn_hidden_state(rng, lstm, x)
567598 memory = init_trainable_rnn_hidden_state(ps. memory, x)
568- return lstm((x, ( hidden_state, memory)), ps, merge(st, (; rng) ))
599+ return ( hidden_state, memory), merge(st, (; rng))
569600end
570601
571- function (lstm :: LSTMCell{True,True} )( x:: AbstractMatrix , ps, st:: NamedTuple )
602+ function init_recurrent_state( :: LSTMCell{True,True} , x:: AbstractMatrix , ps, st:: NamedTuple )
572603 hidden_state = init_trainable_rnn_hidden_state(ps. hidden_state, x)
573604 memory = init_trainable_rnn_hidden_state(ps. memory, x)
574- return lstm((x, (hidden_state, memory)), ps, st)
605+ return (hidden_state, memory), st
606+ end
607+
608+ function (lstm:: LSTMCell )(x:: AbstractMatrix , ps, st:: NamedTuple )
609+ hidden_state, st = init_recurrent_state(lstm, x, ps, st)
610+ return lstm((x, hidden_state), ps, st)
575611end
576612
577613const _LSTMCellInputType = Tuple{<: AbstractMatrix ,Tuple{<: AbstractMatrix ,<: AbstractMatrix }}
@@ -744,16 +780,20 @@ end
744780
745781initialstates(rng:: AbstractRNG , :: GRUCell ) = (rng= Utils. sample_replicate(rng),)
746782
747- function (gru :: GRUCell{True} )( x:: AbstractMatrix , ps, st:: NamedTuple )
783+ function init_recurrent_state( :: GRUCell{True} , x:: AbstractMatrix , ps, st:: NamedTuple )
748784 hidden_state = init_trainable_rnn_hidden_state(ps. hidden_state, x)
749- return gru((x, ( hidden_state,)), ps, st)
785+ return ( hidden_state,), st
750786end
751787
752- function (gru:: GRUCell{False} )( x:: AbstractMatrix , ps, st:: NamedTuple )
788+ function init_recurrent_state (gru:: GRUCell{False} , x:: AbstractMatrix , ps, st:: NamedTuple )
753789 rng = replicate(st. rng)
754- st = merge(st, (; rng))
755790 hidden_state = init_rnn_hidden_state(rng, gru, x)
756- return gru((x, (hidden_state,)), ps, st)
791+ return (hidden_state,), merge(st, (; rng))
792+ end
793+
794+ function (gru:: GRUCell )(x:: AbstractMatrix , ps, st:: NamedTuple )
795+ hidden_state, st = init_recurrent_state(gru, x, ps, st)
796+ return gru((x, hidden_state), ps, st)
757797end
758798
759799const _GRUCellInputType = Tuple{<: AbstractMatrix ,Tuple{<: AbstractMatrix }}
0 commit comments