Skip to content

Commit 10a2123

Browse files
author
Flax Authors
committed
Merge pull request #2604 from cgarciae:rnn-v2
PiperOrigin-RevId: 512990454
2 parents f9dab0a + 66dec51 commit 10a2123

File tree

5 files changed

+587
-105
lines changed

5 files changed

+587
-105
lines changed

docs/api_reference/flax.linen.rst

+2
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,5 @@ RNN primitives
268268
LSTMCell
269269
OptimizedLSTMCell
270270
GRUCell
271+
RNNCellBase
272+
RNN

examples/seq2seq/models.py

+23-103
Original file line numberDiff line numberDiff line change
@@ -25,78 +25,19 @@
2525
import jax.numpy as jnp
2626
import numpy as np
2727

28-
Array = Any
29-
PRNGKey = Any
28+
Array = jax.Array
29+
PRNGKey = jax.random.KeyArray
3030

3131

32-
class EncoderLSTM(nn.Module):
33-
"""EncoderLSTM Module wrapped in a lifted scan transform."""
34-
eos_id: int
35-
36-
@functools.partial(
37-
nn.scan,
38-
variable_broadcast='params',
39-
in_axes=1,
40-
out_axes=1,
41-
split_rngs={'params': False})
42-
@nn.compact
43-
def __call__(self, carry: Tuple[Array, Array],
44-
x: Array) -> Tuple[Tuple[Array, Array], Array]:
45-
"""Applies the module."""
46-
lstm_state, is_eos = carry
47-
new_lstm_state, y = nn.LSTMCell()(lstm_state, x)
48-
# Pass forward the previous state if EOS has already been reached.
49-
def select_carried_state(new_state, old_state):
50-
return jnp.where(is_eos[:, np.newaxis], old_state, new_state)
51-
# LSTM state is a tuple (c, h).
52-
carried_lstm_state = tuple(
53-
select_carried_state(*s) for s in zip(new_lstm_state, lstm_state))
54-
# Update `is_eos`.
55-
is_eos = jnp.logical_or(is_eos, x[:, self.eos_id])
56-
return (carried_lstm_state, is_eos), y
57-
58-
@staticmethod
59-
def initialize_carry(batch_size: int, hidden_size: int):
60-
# Use a dummy key since the default state init fn is just zeros.
61-
return nn.LSTMCell.initialize_carry(
62-
jax.random.PRNGKey(0), (batch_size,), hidden_size)
63-
64-
65-
class Encoder(nn.Module):
66-
"""LSTM encoder, returning state after finding the EOS token in the input."""
67-
hidden_size: int
68-
eos_id: int
69-
70-
@nn.compact
71-
def __call__(self, inputs: Array):
72-
# inputs.shape = (batch_size, seq_length, vocab_size).
73-
batch_size = inputs.shape[0]
74-
lstm = EncoderLSTM(name='encoder_lstm', eos_id=self.eos_id)
75-
init_lstm_state = lstm.initialize_carry(batch_size, self.hidden_size)
76-
# We use the `is_eos` array to determine whether the encoder should carry
77-
# over the last lstm state, or apply the LSTM cell on the previous state.
78-
init_is_eos = jnp.zeros(batch_size, dtype=bool)
79-
init_carry = (init_lstm_state, init_is_eos)
80-
(final_state, _), _ = lstm(init_carry, inputs)
81-
return final_state
82-
83-
84-
class DecoderLSTM(nn.Module):
32+
class DecoderLSTMCell(nn.RNNCellBase):
8533
"""DecoderLSTM Module wrapped in a lifted scan transform.
86-
8734
Attributes:
8835
teacher_force: See docstring on Seq2seq module.
8936
vocab_size: Size of the vocabulary.
9037
"""
9138
teacher_force: bool
9239
vocab_size: int
9340

94-
@functools.partial(
95-
nn.scan,
96-
variable_broadcast='params',
97-
in_axes=1,
98-
out_axes=1,
99-
split_rngs={'params': False, 'lstm': True})
10041
@nn.compact
10142
def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array:
10243
"""Applies the DecoderLSTM model."""
@@ -116,40 +57,6 @@ def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array:
11657
return (lstm_state, prediction), (logits, prediction)
11758

11859

119-
class Decoder(nn.Module):
120-
"""LSTM decoder.
121-
122-
Attributes:
123-
init_state: [batch_size, hidden_size]
124-
Initial state of the decoder (i.e., the final state of the encoder).
125-
teacher_force: See docstring on Seq2seq module.
126-
vocab_size: Size of the vocabulary.
127-
"""
128-
init_state: Tuple[Any]
129-
teacher_force: bool
130-
vocab_size: int
131-
132-
@nn.compact
133-
def __call__(self, inputs: Array) -> Tuple[Array, Array]:
134-
"""Applies the decoder model.
135-
136-
Args:
137-
inputs: [batch_size, max_output_len-1, vocab_size]
138-
Contains the inputs to the decoder at each time step (only used when not
139-
using teacher forcing). Since each token at position i is fed as input
140-
to the decoder at position i+1, the last token is not provided.
141-
142-
Returns:
143-
Pair (logits, predictions), which are two arrays of respectively decoded
144-
logits and predictions (in one hot-encoding format).
145-
"""
146-
lstm = DecoderLSTM(teacher_force=self.teacher_force,
147-
vocab_size=self.vocab_size)
148-
init_carry = (self.init_state, inputs[:, 0])
149-
_, (logits, predictions) = lstm(init_carry, inputs)
150-
return logits, predictions
151-
152-
15360
class Seq2seq(nn.Module):
15461
"""Sequence-to-sequence class using encoder/decoder architecture.
15562
@@ -189,12 +96,25 @@ def __call__(self, encoder_inputs: Array,
18996
encoding format).
19097
"""
19198
# Encode inputs.
192-
init_decoder_state = Encoder(
193-
hidden_size=self.hidden_size, eos_id=self.eos_id)(encoder_inputs)
194-
# Decode outputs.
195-
logits, predictions = Decoder(
196-
init_state=init_decoder_state,
197-
teacher_force=self.teacher_force,
198-
vocab_size=self.vocab_size)(decoder_inputs[:, :-1])
99+
encoder = nn.RNN(nn.LSTMCell(), self.hidden_size, return_carry=True, name='encoder')
100+
decoder = nn.RNN(DecoderLSTMCell(self.teacher_force, self.vocab_size), decoder_inputs.shape[-1],
101+
split_rngs={'params': False, 'lstm': True}, name='decoder')
102+
103+
segmentation_mask = self.get_segmentation_mask(encoder_inputs)
104+
105+
encoder_state, _ = encoder(encoder_inputs, segmentation_mask=segmentation_mask)
106+
logits, predictions = decoder(decoder_inputs[:, :-1], initial_carry=(encoder_state, decoder_inputs[:, 0]))
199107

200108
return logits, predictions
109+
110+
def get_segmentation_mask(self, inputs: Array) -> Array:
111+
"""Get segmentation mask for inputs."""
112+
# undo one-hot encoding
113+
inputs = jnp.argmax(inputs, axis=-1)
114+
# calculate eos index
115+
eos_idx = jnp.argmax(inputs == self.eos_id, axis=-1, keepdims=True)
116+
# create index array
117+
indexes = jnp.arange(inputs.shape[1])
118+
indexes = jnp.broadcast_to(indexes, inputs.shape[:2])
119+
# return mask
120+
return indexes < eos_idx

flax/linen/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,9 @@
120120
ConvLSTMCell as ConvLSTMCell,
121121
GRUCell as GRUCell,
122122
LSTMCell as LSTMCell,
123-
OptimizedLSTMCell as OptimizedLSTMCell
123+
OptimizedLSTMCell as OptimizedLSTMCell,
124+
RNNCellBase as RNNCellBase,
125+
RNN as RNN,
124126
)
125127
from .stochastic import Dropout as Dropout
126128
from .transforms import (

0 commit comments

Comments
 (0)