Skip to content

Commit 9e1b1bb

Browse files
added initialstates
1 parent 74c3a63 commit 9e1b1bb

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

src/layers/recurrent.jl

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ end
6969

7070
@layer RNNCell
7171

72+
initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))
73+
7274
function RNNCell(
7375
(in, out)::Pair,
7476
σ = tanh;
@@ -82,7 +84,10 @@ function RNNCell(
8284
return RNNCell(σ, Wi, Wh, b)
8385
end
8486

85-
(m::RNNCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 1)))
87+
function (rnn::RNNCell)(x::AbstractVecOrMat)
88+
state = initialstates(rnn)
89+
rnn(x, state)
90+
end
8691

8792
function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
8893
_size_check(m, x, 1 => size(m.Wi, 2))
@@ -261,6 +266,10 @@ end
261266

262267
@layer LSTMCell
263268

269+
function initialstates(lstm:: LSTMCell)
270+
return zeros_like(lstm.Wh, size(lstm.Wh, 2)), zeros_like(lstm.Wh, size(lstm.Wh, 2))
271+
end
272+
264273
function LSTMCell(
265274
(in, out)::Pair;
266275
init_kernel = glorot_uniform,
@@ -274,10 +283,9 @@ function LSTMCell(
274283
return cell
275284
end
276285

277-
function (m::LSTMCell)(x::AbstractVecOrMat)
278-
h = zeros_like(x, size(m.Wh, 2))
279-
c = zeros_like(h)
280-
return m(x, (h, c))
286+
function (lstm::LSTMCell)(x::AbstractVecOrMat)
287+
state, cstate = initialstates(lstm)
288+
return lstm(x, (state, cstate))
281289
end
282290

283291
function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
@@ -447,6 +455,8 @@ end
447455

448456
@layer GRUCell
449457

458+
initialstates(gru::GRUCell) = zeros_like(gru.Wh, size(gru.Wh, 2))
459+
450460
function GRUCell(
451461
(in, out)::Pair;
452462
init_kernel = glorot_uniform,
@@ -459,7 +469,10 @@ function GRUCell(
459469
return GRUCell(Wi, Wh, b)
460470
end
461471

462-
(m::GRUCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
472+
function (gru::GRUCell)(x::AbstractVecOrMat)
473+
state = initialstates(gru)
474+
return gru(x, state)
475+
end
463476

464477
function (m::GRUCell)(x::AbstractVecOrMat, h)
465478
_size_check(m, x, 1 => size(m.Wi, 2))
@@ -603,6 +616,8 @@ end
603616

604617
@layer GRUv3Cell
605618

619+
initialstates(gru::GRUv3Cell) = zeros_like(gru.Wh, size(gru.Wh, 2))
620+
606621
function GRUv3Cell(
607622
(in, out)::Pair;
608623
init_kernel = glorot_uniform,
@@ -616,7 +631,10 @@ function GRUv3Cell(
616631
return GRUv3Cell(Wi, Wh, b, Wh_h̃)
617632
end
618633

619-
(m::GRUv3Cell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
634+
function (gru::GRUv3Cell)(x::AbstractVecOrMat)
635+
state = initialstates(gru)
636+
return gru(x, state)
637+
end
620638

621639
function (m::GRUv3Cell)(x::AbstractVecOrMat, h)
622640
_size_check(m, x, 1 => size(m.Wi, 2))

0 commit comments

Comments
 (0)