@@ -116,7 +116,7 @@ def __init__(
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation: Callable[[Any], Any] = relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
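The flipped default makes pre-norm ("norm first") the standard behaviour for `TransformerEncoderLayer`, and the hunks below do the same for the decoder layer, both stacks, and the top-level `Transformer`. A minimal sketch of the two residual orderings the flag selects, using stand-in callables rather than the actual MLX submodules:

```python
# Illustrative only: `attn`, `mlp`, `ln1`, `ln2` stand in for the layer's real
# submodules; the true MLX forward pass differs in detail (e.g. dropout).
def pre_norm_block(x, attn, mlp, ln1, ln2, mask):
    # norm_first=True (new default): normalize, transform, then add residual.
    x = x + attn(ln1(x), mask)
    x = x + mlp(ln2(x))
    return x

def post_norm_block(x, attn, mlp, ln1, ln2, mask):
    # norm_first=False (old default): transform, add residual, then normalize.
    x = ln1(x + attn(x, mask))
    x = ln2(x + mlp(x))
    return x
```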
@@ -167,7 +167,7 @@ def __init__(
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation=relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
@@ -182,10 +182,8 @@ def __init__(

     def __call__(self, x, mask):
         for l in self.layers:
-            if self.checkpoint:
-                x = checkpoint(l)(x, mask)
-            else:
-                x = l(x, mask)
+            l = checkpoint(l) if self.checkpoint else l
+            x = l(x, mask)
         return self.ln(x)


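The rewritten loop is behaviour-preserving: instead of branching on every layer call, the layer is wrapped once per iteration when checkpointing is enabled. The same idiom in isolation, as a hedged sketch:

```python
# `checkpoint_fn` is assumed to behave like the helper this module imports:
# it wraps a callable/Module, preserves its call signature, and recomputes the
# wrapped forward pass during backprop instead of storing its activations.
def run_stack(layers, x, mask, use_checkpoint, checkpoint_fn):
    for l in layers:
        l = checkpoint_fn(l) if use_checkpoint else l  # wrap only when asked
        x = l(x, mask)  # the computed values are identical either way
    return x
```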
@@ -197,7 +195,7 @@ def __init__(
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation: Callable[[Any], Any] = relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
@@ -260,7 +258,7 @@ def __init__(
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation=relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
@@ -275,10 +273,8 @@ def __init__(

     def __call__(self, x, memory, x_mask, memory_mask):
         for l in self.layers:
-            if self.checkpoint:
-                x = checkpoint(l)(x, memory, x_mask, memory_mask)
-            else:
-                x = l(x, memory, x_mask, memory_mask)
+            l = checkpoint(l) if self.checkpoint else l
+            x = l(x, memory, x_mask, memory_mask)
         return self.ln(x)


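The decoder loop above consumes the target sequence `x` under `x_mask` (typically causal) and the encoder output `memory` under `memory_mask` (often `None`). A hedged sketch of building such a mask, assuming MLX's additive causal-mask helper:

```python
import mlx.nn as nn

T = 16  # target length, illustrative
# Assumed helper: returns a (T, T) additive mask with large negative values
# above the diagonal so position t only attends to positions <= t.
x_mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
memory_mask = None  # leave cross-attention to the encoder output unrestricted
```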
@@ -317,7 +313,7 @@ class Transformer(Module):
             standard Transformer decoder. Default: ``None``.
         norm_first (bool, optional): if ``True``, encoder and decoder layers
             will perform layer normalization before attention and MLP
-            operations, otherwise after. Default: ``False``.
+            operations, otherwise after. Default: ``True``.
         checkpoint (bool, optional): if ``True`` perform gradient checkpointing
             to reduce the memory usage at the expense of more computation.
             Default: ``False``.
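With the new default, `Transformer()` builds pre-norm encoder/decoder stacks; pass `norm_first=False` explicitly to recover the previous post-norm behaviour. A hedged end-to-end sketch, assuming the constructor keywords shown in this diff and MLX's causal-mask helper (hyper-parameters and shapes are illustrative only):

```python
import mlx.core as mx
import mlx.nn as nn

model = nn.Transformer(
    dims=256,
    num_heads=8,
    num_encoder_layers=2,
    num_decoder_layers=2,
    checkpoint=True,  # recompute activations in backprop to save memory
)  # norm_first now defaults to True (pre-norm)

src = mx.random.normal((1, 10, 256))  # (batch, source length, dims)
tgt = mx.random.normal((1, 16, 256))  # (batch, target length, dims)
tgt_mask = nn.MultiHeadAttention.create_additive_causal_mask(16)

out = model(src, tgt, src_mask=None, tgt_mask=tgt_mask, memory_mask=None)
```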
@@ -334,37 +330,32 @@ def __init__(
         activation: Callable[[Any], Any] = relu,
         custom_encoder: Optional[Any] = None,
         custom_decoder: Optional[Any] = None,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
-        if custom_encoder is not None:
-            self.encoder = custom_encoder
-        else:
-            self.encoder = TransformerEncoder(
-                num_encoder_layers,
-                dims,
-                num_heads,
-                mlp_dims,
-                dropout,
-                activation,
-                norm_first,
-                checkpoint,
-            )

-        if custom_decoder is not None:
-            self.decoder = custom_decoder
-        else:
-            self.decoder = TransformerDecoder(
-                num_decoder_layers,
-                dims,
-                num_heads,
-                mlp_dims,
-                dropout,
-                activation,
-                norm_first,
-                checkpoint,
-            )
+        self.encoder = custom_encoder or TransformerEncoder(
+            num_encoder_layers,
+            dims,
+            num_heads,
+            mlp_dims,
+            dropout,
+            activation,
+            norm_first,
+            checkpoint,
+        )
+
+        self.decoder = custom_decoder or TransformerDecoder(
+            num_decoder_layers,
+            dims,
+            num_heads,
+            mlp_dims,
+            dropout,
+            activation,
+            norm_first,
+            checkpoint,
+        )

     def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
         memory = self.encoder(src, src_mask)
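The `custom_encoder or TransformerEncoder(...)` form short-circuits: because an `nn.Module` instance is always truthy, the default stack is only constructed when no custom module is supplied, matching the old `is not None` check. A hedged sketch of plugging in a custom encoder (the wrapper class here is hypothetical, not part of MLX):

```python
import mlx.nn as nn

class TinyEncoder(nn.Module):
    """Hypothetical drop-in encoder; anything callable as (src, src_mask) works."""

    def __init__(self, dims: int):
        super().__init__()
        self.proj = nn.Linear(dims, dims)

    def __call__(self, src, src_mask):
        return self.proj(src)  # ignores the mask, purely illustrative

# The default TransformerEncoder is never built here: `custom_encoder or ...`
# short-circuits on the truthy module, while the decoder falls back to the
# standard TransformerDecoder.
model = nn.Transformer(dims=256, custom_encoder=TinyEncoder(256))
```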