 import jax.numpy as jnp
 import numpy as np

-Array = Any
-PRNGKey = Any
+Array = jax.Array
+PRNGKey = jax.random.KeyArray


-class EncoderLSTM(nn.Module):
-  """EncoderLSTM Module wrapped in a lifted scan transform."""
-  eos_id: int
-
-  @functools.partial(
-      nn.scan,
-      variable_broadcast='params',
-      in_axes=1,
-      out_axes=1,
-      split_rngs={'params': False})
-  @nn.compact
-  def __call__(self, carry: Tuple[Array, Array],
-               x: Array) -> Tuple[Tuple[Array, Array], Array]:
-    """Applies the module."""
-    lstm_state, is_eos = carry
-    new_lstm_state, y = nn.LSTMCell()(lstm_state, x)
-    # Pass forward the previous state if EOS has already been reached.
-    def select_carried_state(new_state, old_state):
-      return jnp.where(is_eos[:, np.newaxis], old_state, new_state)
-    # LSTM state is a tuple (c, h).
-    carried_lstm_state = tuple(
-        select_carried_state(*s) for s in zip(new_lstm_state, lstm_state))
-    # Update `is_eos`.
-    is_eos = jnp.logical_or(is_eos, x[:, self.eos_id])
-    return (carried_lstm_state, is_eos), y
-
-  @staticmethod
-  def initialize_carry(batch_size: int, hidden_size: int):
-    # Use a dummy key since the default state init fn is just zeros.
-    return nn.LSTMCell.initialize_carry(
-        jax.random.PRNGKey(0), (batch_size,), hidden_size)
-
-
-class Encoder(nn.Module):
-  """LSTM encoder, returning state after finding the EOS token in the input."""
-  hidden_size: int
-  eos_id: int
-
-  @nn.compact
-  def __call__(self, inputs: Array):
-    # inputs.shape = (batch_size, seq_length, vocab_size).
-    batch_size = inputs.shape[0]
-    lstm = EncoderLSTM(name='encoder_lstm', eos_id=self.eos_id)
-    init_lstm_state = lstm.initialize_carry(batch_size, self.hidden_size)
-    # We use the `is_eos` array to determine whether the encoder should carry
-    # over the last lstm state, or apply the LSTM cell on the previous state.
-    init_is_eos = jnp.zeros(batch_size, dtype=bool)
-    init_carry = (init_lstm_state, init_is_eos)
-    (final_state, _), _ = lstm(init_carry, inputs)
-    return final_state
-
-
-class DecoderLSTM(nn.Module):
+class DecoderLSTMCell(nn.RNNCellBase):
   """DecoderLSTM Module wrapped in a lifted scan transform.
-
   Attributes:
     teacher_force: See docstring on Seq2seq module.
     vocab_size: Size of the vocabulary.
   """
   teacher_force: bool
   vocab_size: int

-  @functools.partial(
-      nn.scan,
-      variable_broadcast='params',
-      in_axes=1,
-      out_axes=1,
-      split_rngs={'params': False, 'lstm': True})
   @nn.compact
   def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array:
     """Applies the DecoderLSTM model."""
@@ -116,40 +57,6 @@ def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array:
     return (lstm_state, prediction), (logits, prediction)


-class Decoder(nn.Module):
-  """LSTM decoder.
-
-  Attributes:
-    init_state: [batch_size, hidden_size]
-      Initial state of the decoder (i.e., the final state of the encoder).
-    teacher_force: See docstring on Seq2seq module.
-    vocab_size: Size of the vocabulary.
-  """
-  init_state: Tuple[Any]
-  teacher_force: bool
-  vocab_size: int
-
-  @nn.compact
-  def __call__(self, inputs: Array) -> Tuple[Array, Array]:
-    """Applies the decoder model.
-
-    Args:
-      inputs: [batch_size, max_output_len-1, vocab_size]
-        Contains the inputs to the decoder at each time step (only used when not
-        using teacher forcing). Since each token at position i is fed as input
-        to the decoder at position i+1, the last token is not provided.
-
-    Returns:
-      Pair (logits, predictions), which are two arrays of respectively decoded
-      logits and predictions (in one-hot encoding format).
-    """
-    lstm = DecoderLSTM(teacher_force=self.teacher_force,
-                       vocab_size=self.vocab_size)
-    init_carry = (self.init_state, inputs[:, 0])
-    _, (logits, predictions) = lstm(init_carry, inputs)
-    return logits, predictions
-
-
 class Seq2seq(nn.Module):
   """Sequence-to-sequence class using encoder/decoder architecture.

@@ -189,12 +96,25 @@ def __call__(self, encoder_inputs: Array,
       encoding format).
     """
     # Encode inputs.
-    init_decoder_state = Encoder(
-        hidden_size=self.hidden_size, eos_id=self.eos_id)(encoder_inputs)
-    # Decode outputs.
-    logits, predictions = Decoder(
-        init_state=init_decoder_state,
-        teacher_force=self.teacher_force,
-        vocab_size=self.vocab_size)(decoder_inputs[:, :-1])
+    encoder = nn.RNN(nn.LSTMCell(), self.hidden_size, return_carry=True, name='encoder')
+    decoder = nn.RNN(DecoderLSTMCell(self.teacher_force, self.vocab_size), decoder_inputs.shape[-1],
+                     split_rngs={'params': False, 'lstm': True}, name='decoder')
+
+    segmentation_mask = self.get_segmentation_mask(encoder_inputs)
+
+    encoder_state, _ = encoder(encoder_inputs, segmentation_mask=segmentation_mask)
+    logits, predictions = decoder(decoder_inputs[:, :-1], initial_carry=(encoder_state, decoder_inputs[:, 0]))

     return logits, predictions
+
+  def get_segmentation_mask(self, inputs: Array) -> Array:
+    """Get segmentation mask for inputs."""
+    # Undo the one-hot encoding.
+    inputs = jnp.argmax(inputs, axis=-1)
+    # Calculate the index of the first EOS token in each sequence.
+    eos_idx = jnp.argmax(inputs == self.eos_id, axis=-1, keepdims=True)
+    # Create an array of timestep indexes, broadcast over the batch.
+    indexes = jnp.arange(inputs.shape[1])
+    indexes = jnp.broadcast_to(indexes, inputs.shape[:2])
+    # Keep only the timesteps strictly before the first EOS token.
+    return indexes < eos_idx
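
Note: the change replaces both hand-lifted nn.scan wrappers with nn.RNN, which owns the scan, the carry initialization, and (via segmentation_mask) the carry-through-after-EOS behavior that EncoderLSTM previously implemented by hand with jnp.where. Below is a minimal sketch of the new encoder path, assuming the nn.RNN API this commit targets (cell size as the second positional argument, return_carry=True to get the final state); the shapes and dummy inputs are illustrative only, not part of the commit:

import jax
import jax.numpy as jnp
import flax.linen as nn

batch_size, seq_len, vocab_size, hidden_size = 2, 6, 8, 16
# Dummy one-hot token sequences standing in for real data.
inputs = jax.nn.one_hot(
    jnp.zeros((batch_size, seq_len), dtype=jnp.int32), vocab_size)

encoder = nn.RNN(nn.LSTMCell(), hidden_size, return_carry=True)
variables = encoder.init(jax.random.PRNGKey(0), inputs)
# With return_carry=True the call returns the final LSTM carry (c, h)
# alongside the per-step outputs of shape (batch, seq_len, hidden).
(c, h), outputs = encoder.apply(variables, inputs)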
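get_segmentation_mask reproduces, as a precomputed mask, the EOS bookkeeping the deleted EncoderLSTM did step by step: jnp.argmax on a boolean array returns the index of the first True, so eos_idx is the position of the first EOS token and the mask is True only for timesteps strictly before it. A toy check of just that logic (the token ids and eos_id here are made up for illustration):

import jax.numpy as jnp

eos_id = 1
# One sequence with EOS at position 2, followed by padding.
ids = jnp.array([[4, 7, 1, 0, 0]])
eos_idx = jnp.argmax(ids == eos_id, axis=-1, keepdims=True)  # [[2]]
indexes = jnp.broadcast_to(jnp.arange(ids.shape[1]), ids.shape[:2])
print(indexes < eos_idx)  # [[ True  True False False False]]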