
Commit 386bf30

Document LSTMCell carry as the (c, h) tuple it actually is
The LSTMCell.__call__ docstring described 'carry' as 'the hidden state of the LSTM cell', leaving users to infer that it is actually a tuple (c, h) of the cell state and the hidden state, both of shape (*batch, features), typically created via LSTMCell.initialize_carry. Spell out that contract in both the Linen and NNX implementations, and apply the same wording to OptimizedLSTMCell, whose docstring now points to OptimizedLSTMCell.initialize_carry.

Fixes #4124
1 parent 218c7ff commit 386bf30

2 files changed

Lines changed: 10 additions & 6 deletions


flax/linen/recurrent.py

Lines changed: 5 additions & 3 deletions
@@ -136,7 +136,8 @@ def __call__(self, carry, inputs):
     r"""A long short-term memory (LSTM) cell.

     Args:
-      carry: the hidden state of the LSTM cell,
+      carry: a tuple ``(c, h)`` of the cell state ``c`` and the hidden
+        state ``h``, both of shape ``(*batch, features)``. Typically
         initialized using ``LSTMCell.initialize_carry``.
       inputs: an ndarray with the input for the current time step.
         All dimensions except the final are considered batch dimensions.
@@ -285,8 +286,9 @@ def __call__(
     r"""An optimized long short-term memory (LSTM) cell.

     Args:
-      carry: the hidden state of the LSTM cell, initialized using
-        ``LSTMCell.initialize_carry``.
+      carry: a tuple ``(c, h)`` of the cell state ``c`` and the hidden
+        state ``h``, both of shape ``(*batch, features)``. Typically
+        initialized using ``OptimizedLSTMCell.initialize_carry``.
       inputs: an ndarray with the input for the current time step. All
         dimensions except the final are considered batch dimensions.

flax/nnx/nn/recurrent.py

Lines changed: 5 additions & 3 deletions
@@ -201,7 +201,8 @@ def __call__(
     r"""A long short-term memory (LSTM) cell.

     Args:
-      carry: the hidden state of the LSTM cell,
+      carry: a tuple ``(c, h)`` of the cell state ``c`` and the hidden
+        state ``h``, both of shape ``(*batch, features)``. Typically
         initialized using ``LSTMCell.initialize_carry``.
       inputs: an ndarray with the input for the current time step.
         All dimensions except the final are considered batch dimensions.
@@ -382,8 +383,9 @@ def __call__(
     r"""An optimized long short-term memory (LSTM) cell.

     Args:
-      carry: the hidden state of the LSTM cell, initialized using
-        ``LSTMCell.initialize_carry``.
+      carry: a tuple ``(c, h)`` of the cell state ``c`` and the hidden
+        state ``h``, both of shape ``(*batch, features)``. Typically
+        initialized using ``OptimizedLSTMCell.initialize_carry``.
       inputs: an ndarray with the input for the current time step.
         All dimensions except the final are considered batch dimensions.
