 import torch.nn as tnn
 import torch.nn.utils.rnn as tnnur
 
-from reinvent_models.lib_invent.enums.generative_model_parameters import (
-    GenerativeModelParametersEnum,
-)
+from reinvent_models.lib_invent.enums.generative_model_parameters import GenerativeModelParametersEnum
 
 
 class Encoder(tnn.Module):
@@ -27,16 +25,10 @@ def __init__(self, num_layers, num_dimensions, vocabulary_size, dropout):
 
         self._embedding = tnn.Sequential(
             tnn.Embedding(self.vocabulary_size, self.num_dimensions),
-            tnn.Dropout(dropout),
-        )
-        self._rnn = tnn.LSTM(
-            self.num_dimensions,
-            self.num_dimensions,
-            self.num_layers,
-            batch_first=True,
-            dropout=self.dropout,
-            bidirectional=True,
+            tnn.Dropout(dropout)
         )
+        self._rnn = tnn.LSTM(self.num_dimensions, self.num_dimensions, self.num_layers,
+                             batch_first=True, dropout=self.dropout, bidirectional=True)
 
     def forward(self, padded_seqs, seq_lengths): # pylint: disable=arguments-differ
         # FIXME: This fails with a batch of 1 because squeezing looses a dimension with size 1
@@ -53,40 +45,26 @@ def forward(self, padded_seqs, seq_lengths): # pylint: disable=arguments-differ
         padded_seqs = self._embedding(padded_seqs)
         hs_h, hs_c = (hidden_state, hidden_state.clone().detach())
 
-        # FIXME: this is to guard against non compatible `gpu` input for pack_padded_sequence() method in pytorch 1.7
+        #FIXME: this is to guard against non compatible `gpu` input for pack_padded_sequence() method in pytorch 1.7
         seq_lengths = seq_lengths.cpu()
 
-        packed_seqs = tnnur.pack_padded_sequence(
-            padded_seqs, seq_lengths, batch_first=True, enforce_sorted=False
-        )
+        packed_seqs = tnnur.pack_padded_sequence(padded_seqs, seq_lengths, batch_first=True, enforce_sorted=False)
         packed_seqs, (hs_h, hs_c) = self._rnn(packed_seqs, (hs_h, hs_c))
         padded_seqs, _ = tnnur.pad_packed_sequence(packed_seqs, batch_first=True)
 
         # sum up bidirectional layers and collapse
-        hs_h = (
-            hs_h.view(self.num_layers, 2, batch_size, self.num_dimensions)
-            .sum(dim=1)
-            .squeeze()
-        ) # (layers, batch, dim)
-        hs_c = (
-            hs_c.view(self.num_layers, 2, batch_size, self.num_dimensions)
-            .sum(dim=1)
-            .squeeze()
-        ) # (layers, batch, dim)
-        padded_seqs = (
-            padded_seqs.view(batch_size, max_seq_size, 2, self.num_dimensions)
-            .sum(dim=2)
-            .squeeze()
-        ) # (batch, seq, dim)
+        hs_h = hs_h.view(self.num_layers, 2, batch_size, self.num_dimensions)\
+            .sum(dim=1).squeeze() # (layers, batch, dim)
+        hs_c = hs_c.view(self.num_layers, 2, batch_size, self.num_dimensions)\
+            .sum(dim=1).squeeze() # (layers, batch, dim)
+        padded_seqs = padded_seqs.view(batch_size, max_seq_size, 2, self.num_dimensions)\
+            .sum(dim=2).squeeze() # (batch, seq, dim)
 
         return padded_seqs, (hs_h, hs_c)
 
     def _initialize_hidden_state(self, batch_size):
         if torch.cuda.is_available():
-            return torch.zeros(
-                self.num_layers * 2, batch_size, self.num_dimensions
-            ).cuda()
-        return torch.zeros(self.num_layers * 2, batch_size, self.num_dimensions)
+            return torch.zeros(self.num_layers * 2, batch_size, self.num_dimensions).cuda()
 
     def get_params(self):
         parameter_enums = GenerativeModelParametersEnum
@@ -98,23 +76,23 @@ def get_params(self):
             parameter_enums.NUMBER_OF_LAYERS: self.num_layers,
             parameter_enums.NUMBER_OF_DIMENSIONS: self.num_dimensions,
             parameter_enums.VOCABULARY_SIZE: self.vocabulary_size,
-            parameter_enums.DROPOUT: self.dropout,
+            parameter_enums.DROPOUT: self.dropout
         }
 
 
 class AttentionLayer(tnn.Module):
+
     def __init__(self, num_dimensions):
         super(AttentionLayer, self).__init__()
 
         self.num_dimensions = num_dimensions
 
         self._attention_linear = tnn.Sequential(
-            tnn.Linear(self.num_dimensions * 2, self.num_dimensions), tnn.Tanh()
+            tnn.Linear(self.num_dimensions * 2, self.num_dimensions),
+            tnn.Tanh()
         )
 
-    def forward(
-        self, padded_seqs, encoder_padded_seqs, decoder_mask
-    ): # pylint: disable=arguments-differ
+    def forward(self, padded_seqs, encoder_padded_seqs, decoder_mask): # pylint: disable=arguments-differ
         """
         Performs the forward pass.
         :param padded_seqs: A tensor with the output sequences (batch, seq_d, dim).
@@ -124,19 +102,12 @@ def forward(
         """
         # scaled dot-product
         # (batch, seq_d, 1, dim)*(batch, 1, seq_e, dim) => (batch, seq_d, seq_e*)
-        attention_weights = (
-            (padded_seqs.unsqueeze(dim=2) * encoder_padded_seqs.unsqueeze(dim=1))
-            .sum(dim=3)
-            .div(math.sqrt(self.num_dimensions))
+        attention_weights = (padded_seqs.unsqueeze(dim=2)*encoder_padded_seqs.unsqueeze(dim=1))\
+            .sum(dim=3).div(math.sqrt(self.num_dimensions))\
             .softmax(dim=2)
-        )
         # (batch, seq_d, seq_e*)@(batch, seq_e, dim) => (batch, seq_d, dim)
         attention_context = attention_weights.bmm(encoder_padded_seqs)
-        return (
-            self._attention_linear(torch.cat([padded_seqs, attention_context], dim=2))
-            * decoder_mask,
-            attention_weights,
-        )
+        return (self._attention_linear(torch.cat([padded_seqs, attention_context], dim=2))*decoder_mask, attention_weights)
 
 
 class Decoder(tnn.Module):
@@ -154,26 +125,16 @@ def __init__(self, num_layers, num_dimensions, vocabulary_size, dropout):
 
         self._embedding = tnn.Sequential(
             tnn.Embedding(self.vocabulary_size, self.num_dimensions),
-            tnn.Dropout(dropout),
-        )
-        self._rnn = tnn.LSTM(
-            self.num_dimensions,
-            self.num_dimensions,
-            self.num_layers,
-            batch_first=True,
-            dropout=self.dropout,
-            bidirectional=False,
+            tnn.Dropout(dropout)
         )
+        self._rnn = tnn.LSTM(self.num_dimensions, self.num_dimensions, self.num_layers,
+                             batch_first=True, dropout=self.dropout, bidirectional=False)
 
         self._attention = AttentionLayer(self.num_dimensions)
 
-        self._linear = tnn.Linear(
-            self.num_dimensions, self.vocabulary_size
-        ) # just to redimension
+        self._linear = tnn.Linear(self.num_dimensions, self.vocabulary_size) # just to redimension
 
-    def forward(
-        self, padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states
-    ): # pylint: disable=arguments-differ
+    def forward(self, padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states): # pylint: disable=arguments-differ
         """
         Performs the forward pass.
         :param padded_seqs: A tensor with the output sequences (batch, seq_d, dim).
@@ -187,20 +148,13 @@ def forward(
 
         padded_encoded_seqs = self._embedding(padded_seqs)
         packed_encoded_seqs = tnnur.pack_padded_sequence(
-            padded_encoded_seqs, seq_lengths, batch_first=True, enforce_sorted=False
-        )
-        packed_encoded_seqs, hidden_states = self._rnn(
-            packed_encoded_seqs, hidden_states
-        )
-        padded_encoded_seqs, _ = tnnur.pad_packed_sequence(
-            packed_encoded_seqs, batch_first=True
-        ) # (batch, seq, dim)
+            padded_encoded_seqs, seq_lengths, batch_first=True, enforce_sorted=False)
+        packed_encoded_seqs, hidden_states = self._rnn(packed_encoded_seqs, hidden_states)
+        padded_encoded_seqs, _ = tnnur.pad_packed_sequence(packed_encoded_seqs, batch_first=True) # (batch, seq, dim)
 
         mask = (padded_encoded_seqs[:, :, 0] != 0).unsqueeze(dim=-1).type(torch.float)
-        attn_padded_encoded_seqs, attention_weights = self._attention(
-            padded_encoded_seqs, encoder_padded_seqs, mask
-        )
-        logits = self._linear(attn_padded_encoded_seqs) * mask # (batch, seq, voc_size)
+        attn_padded_encoded_seqs, attention_weights = self._attention(padded_encoded_seqs, encoder_padded_seqs, mask)
+        logits = self._linear(attn_padded_encoded_seqs)*mask # (batch, seq, voc_size)
         return logits, hidden_states, attention_weights
 
     def get_params(self):
@@ -213,7 +167,7 @@ def get_params(self):
             parameter_enum.NUMBER_OF_LAYERS: self.num_layers,
             parameter_enum.NUMBER_OF_DIMENSIONS: self.num_dimensions,
             parameter_enum.VOCABULARY_SIZE: self.vocabulary_size,
-            parameter_enum.DROPOUT: self.dropout,
+            parameter_enum.DROPOUT: self.dropout
         }
 
 
@@ -228,9 +182,7 @@ def __init__(self, encoder_params, decoder_params):
         self._encoder = Encoder(**encoder_params)
         self._decoder = Decoder(**decoder_params)
 
-    def forward(
-        self, encoder_seqs, encoder_seq_lengths, decoder_seqs, decoder_seq_lengths
-    ): # pylint: disable=arguments-differ
+    def forward(self, encoder_seqs, encoder_seq_lengths, decoder_seqs, decoder_seq_lengths): # pylint: disable=arguments-differ
         """
         Performs the forward pass.
         :param encoder_seqs: A tensor with the output sequences (batch, seq_d, dim).
@@ -239,12 +191,8 @@ def forward(
         :param decoder_seq_lengths: The lengths of the decoder sequences.
         :return: The output logits as a tensor (batch, seq_d, dim).
         """
-        encoder_padded_seqs, hidden_states = self.forward_encoder(
-            encoder_seqs, encoder_seq_lengths
-        )
-        logits, _, _ = self.forward_decoder(
-            decoder_seqs, decoder_seq_lengths, encoder_padded_seqs, hidden_states
-        )
+        encoder_padded_seqs, hidden_states = self.forward_encoder(encoder_seqs, encoder_seq_lengths)
+        logits, _, _ = self.forward_decoder(decoder_seqs, decoder_seq_lengths, encoder_padded_seqs, hidden_states)
         return logits
 
     def forward_encoder(self, padded_seqs, seq_lengths):
@@ -256,19 +204,15 @@ def forward_encoder(self, padded_seqs, seq_lengths):
         """
         return self._encoder(padded_seqs, seq_lengths)
 
-    def forward_decoder(
-        self, padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states
-    ):
+    def forward_decoder(self, padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states):
         """
         Does a forward pass only of the decoder.
         :param hidden_states: The hidden states from the encoder.
         :param padded_seqs: The data to feed to the decoder.
         :param seq_lengths: The length of each sequence in the batch.
         :return: Returns the logits and the hidden state for each element of the sequence passed.
         """
-        return self._decoder(
-            padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states
-        )
+        return self._decoder(padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states)
 
     def get_params(self):
         """
@@ -277,5 +221,5 @@ def get_params(self):
         """
         return {
             "encoder_params": self._encoder.get_params(),
-            "decoder_params": self._decoder.get_params(),
+            "decoder_params": self._decoder.get_params()
         }
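
The backslash-continued chain introduced in AttentionLayer.forward is dense. The standalone sketch below (not part of the commit; all shapes and values are arbitrary toy choices) spells out the broadcast-multiply-and-sum form of the scaled dot product and checks it against the equivalent batched matrix product:

import math
import torch

# Toy shapes: `batch`, `seq_d` (decoder length), `seq_e` (encoder length) and
# `num_dimensions` are arbitrary values chosen only for this check.
batch, seq_d, seq_e, num_dimensions = 4, 10, 12, 128
padded_seqs = torch.randn(batch, seq_d, num_dimensions)          # decoder outputs
encoder_padded_seqs = torch.randn(batch, seq_e, num_dimensions)  # encoder outputs

# Broadcast form used in the file:
# (batch, seq_d, 1, dim) * (batch, 1, seq_e, dim) -> sum over dim -> (batch, seq_d, seq_e)
attention_weights = (padded_seqs.unsqueeze(dim=2)*encoder_padded_seqs.unsqueeze(dim=1))\
    .sum(dim=3).div(math.sqrt(num_dimensions))\
    .softmax(dim=2)

# Equivalent batched matrix product, for comparison.
reference = (padded_seqs.bmm(encoder_padded_seqs.transpose(1, 2))
             / math.sqrt(num_dimensions)).softmax(dim=2)
assert torch.allclose(attention_weights, reference, atol=1e-5)

# Context vectors: (batch, seq_d, seq_e) @ (batch, seq_e, dim) -> (batch, seq_d, dim)
attention_context = attention_weights.bmm(encoder_padded_seqs)
print(attention_weights.shape, attention_context.shape)

The broadcast form materialises an intermediate of shape (batch, seq_d, seq_e, dim), so the bmm formulation is also the lighter option on memory for long sequences.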
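
For reference, a minimal end-to-end sketch of how the Encoder and Decoder above fit together (not part of the commit). The import path, hyperparameters and batch shapes are illustrative assumptions; the batch holds more than one sequence because of the squeeze() FIXME, and a CUDA device is assumed because the right-hand version of _initialize_hidden_state only covers the GPU case.

import torch

# Hypothetical import: the commit page does not show the file path, so adjust
# this to wherever Encoder and Decoder live inside reinvent_models.lib_invent.
from reinvent_models.lib_invent.models.model import Encoder, Decoder

device = torch.device("cuda")  # the Encoder shown here only builds its initial hidden state on GPU

# Toy hyperparameters matching the __init__ signatures shown above.
params = dict(num_layers=3, num_dimensions=128, vocabulary_size=40, dropout=0)
encoder = Encoder(**params).to(device)
decoder = Decoder(**params).to(device)

# Batches of padded token indices; length tensors stay on the CPU for pack_padded_sequence().
batch_size, enc_len, dec_len = 4, 12, 10
encoder_seqs = torch.randint(1, params["vocabulary_size"], (batch_size, enc_len), device=device)
decoder_seqs = torch.randint(1, params["vocabulary_size"], (batch_size, dec_len), device=device)
encoder_lengths = torch.full((batch_size,), enc_len, dtype=torch.long)
decoder_lengths = torch.full((batch_size,), dec_len, dtype=torch.long)

# Bidirectional encoder: outputs and states are summed over the two directions.
encoder_padded_seqs, hidden_states = encoder(encoder_seqs, encoder_lengths)

# Decoder attends over the encoder outputs and redimensions to vocabulary logits.
logits, hidden_states, attention_weights = decoder(
    decoder_seqs, decoder_lengths, encoder_padded_seqs, hidden_states)

print(logits.shape)             # (batch_size, dec_len, vocabulary_size)
print(attention_weights.shape)  # (batch_size, dec_len, enc_len)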