6969
7070@layer RNNCell
7171
72+ initialstates(rnn:: RNNCell ) = zeros_like(rnn. Wh, size(rnn. Wh, 2 ))
73+
7274function RNNCell(
7375 (in, out):: Pair ,
7476 σ = tanh;
@@ -82,7 +84,10 @@ function RNNCell(
8284 return RNNCell(σ, Wi, Wh, b)
8385end
8486
85- (m:: RNNCell )(x:: AbstractVecOrMat ) = m(x, zeros_like(x, size(m. Wh, 1 )))
87+ function (rnn:: RNNCell )(x:: AbstractVecOrMat )
88+ state = initialstates(rnn)
89+ rnn(x, state)
90+ end
8691
8792function (m:: RNNCell )(x:: AbstractVecOrMat , h:: AbstractVecOrMat )
8893 _size_check(m, x, 1 => size(m. Wi, 2 ))
261266
262267@layer LSTMCell
263268
269+ function initialstates(lstm:: LSTMCell )
270+ return zeros_like(lstm. Wh, size(lstm. Wh, 2 )), zeros_like(lstm. Wh, size(lstm. Wh, 2 ))
271+ end
272+
264273function LSTMCell(
265274 (in, out):: Pair ;
266275 init_kernel = glorot_uniform,
@@ -274,10 +283,9 @@ function LSTMCell(
274283 return cell
275284end
276285
277- function (m:: LSTMCell )(x:: AbstractVecOrMat )
278- h = zeros_like(x, size(m. Wh, 2 ))
279- c = zeros_like(h)
280- return m(x, (h, c))
286+ function (lstm:: LSTMCell )(x:: AbstractVecOrMat )
287+ state, cstate = initialstates(lstm)
288+ return lstm(x, (state, cstate))
281289end
282290
283291function (m:: LSTMCell )(x:: AbstractVecOrMat , (h, c))
447455
448456@layer GRUCell
449457
458+ initialstates(gru:: GRUCell ) = zeros_like(gru. Wh, size(gru. Wh, 2 ))
459+
450460function GRUCell(
451461 (in, out):: Pair ;
452462 init_kernel = glorot_uniform,
@@ -459,7 +469,10 @@ function GRUCell(
459469 return GRUCell(Wi, Wh, b)
460470end
461471
462- (m:: GRUCell )(x:: AbstractVecOrMat ) = m(x, zeros_like(x, size(m. Wh, 2 )))
472+ function (gru:: GRUCell )(x:: AbstractVecOrMat )
473+ state = initialstates(gru)
474+ return gru(x, state)
475+ end
463476
464477function (m:: GRUCell )(x:: AbstractVecOrMat , h)
465478 _size_check(m, x, 1 => size(m. Wi, 2 ))
603616
604617@layer GRUv3Cell
605618
619+ initialstates(gru:: GRUv3Cell ) = zeros_like(gru. Wh, size(gru. Wh, 2 ))
620+
606621function GRUv3Cell(
607622 (in, out):: Pair ;
608623 init_kernel = glorot_uniform,
@@ -616,7 +631,10 @@ function GRUv3Cell(
616631 return GRUv3Cell(Wi, Wh, b, Wh_h̃)
617632end
618633
619- (m:: GRUv3Cell )(x:: AbstractVecOrMat ) = m(x, zeros_like(x, size(m. Wh, 2 )))
634+ function (gru:: GRUv3Cell )(x:: AbstractVecOrMat )
635+ state = initialstates(gru)
636+ return gru(x, state)
637+ end
620638
621639function (m:: GRUv3Cell )(x:: AbstractVecOrMat , h)
622640 _size_check(m, x, 1 => size(m. Wi, 2 ))
0 commit comments