
Commit ba8d6bf

Change the transformer to norm_first by default (#599)
1 parent: 4a5f3b2

1 file changed: +31 -40 lines

python/mlx/nn/layers/transformer.py

Lines changed: 31 additions & 40 deletions
@@ -116,7 +116,7 @@ def __init__(
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation: Callable[[Any], Any] = relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
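
For reference, `norm_first` selects between pre-norm and post-norm residual blocks; the docstring hunk further down describes the same behavior. The sketch below is illustrative only and not the code in transformer.py: `attn` and `norm` are stand-in callables.

# Illustrative sketch of the two residual orderings toggled by norm_first.
# `attn` and `norm` are stand-ins, not the attributes used in transformer.py.
def attention_block(x, attn, norm, mask, norm_first=True):
    if norm_first:
        # Pre-norm (new default): normalize, attend, then add the residual.
        return x + attn(norm(x), mask)
    # Post-norm (old default): attend, add the residual, then normalize.
    return norm(x + attn(x, mask))
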
@@ -167,7 +167,7 @@ def __init__(
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation=relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
@@ -182,10 +182,8 @@ def __init__(
 
     def __call__(self, x, mask):
         for l in self.layers:
-            if self.checkpoint:
-                x = checkpoint(l)(x, mask)
-            else:
-                x = l(x, mask)
+            l = checkpoint(l) if self.checkpoint else l
+            x = l(x, mask)
         return self.ln(x)
 
 
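
The hunk above folds the per-layer if/else into one conditional expression that rebinds the loop variable to a checkpointed wrapper when `self.checkpoint` is set. A minimal sketch of the same pattern, assuming a stand-in `checkpoint` defined locally rather than the one transformer.py imports:

# Stand-in wrapper for illustration; a real checkpoint would discard
# intermediates and recompute them during the backward pass to save memory.
def checkpoint(fn):
    def wrapped(*args):
        return fn(*args)
    return wrapped

def run_layers(layers, x, mask, use_checkpoint=False):
    for l in layers:
        # Wrap the layer only when checkpointing is enabled, then call it,
        # mirroring the new two-line loop body in __call__.
        l = checkpoint(l) if use_checkpoint else l
        x = l(x, mask)
    return x
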
@@ -197,7 +195,7 @@ def __init__(
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation: Callable[[Any], Any] = relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
@@ -260,7 +258,7 @@ def __init__(
         mlp_dims: Optional[int] = None,
         dropout: float = 0.0,
         activation=relu,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
@@ -275,10 +273,8 @@ def __init__(
 
     def __call__(self, x, memory, x_mask, memory_mask):
         for l in self.layers:
-            if self.checkpoint:
-                x = checkpoint(l)(x, memory, x_mask, memory_mask)
-            else:
-                x = l(x, memory, x_mask, memory_mask)
+            l = checkpoint(l) if self.checkpoint else l
+            x = l(x, memory, x_mask, memory_mask)
         return self.ln(x)
 
 
@@ -317,7 +313,7 @@ class Transformer(Module):
             standard Transformer decoder. Default: ``None``.
         norm_first (bool, optional): if ``True``, encoder and decoder layers
             will perform layer normalization before attention and MLP
-            operations, otherwise after. Default: ``False``.
+            operations, otherwise after. Default: ``True``.
         chekpoint (bool, optional): if ``True`` perform gradient checkpointing
             to reduce the memory usage at the expense of more computation.
             Default: ``False``.
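
Because the default flips to ``True``, callers that relied on the old post-norm behavior must now opt back in explicitly. A usage sketch, assuming the class is exposed as `mlx.nn.Transformer` and accepts `dims` and `num_heads` as keyword arguments as in this file:

import mlx.nn as nn

# New default after this commit: pre-norm encoder/decoder layers.
model = nn.Transformer(dims=512, num_heads=8)

# Pass norm_first=False to keep the previous post-norm behavior.
post_norm_model = nn.Transformer(dims=512, num_heads=8, norm_first=False)
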
@@ -334,37 +330,32 @@ def __init__(
         activation: Callable[[Any], Any] = relu,
         custom_encoder: Optional[Any] = None,
         custom_decoder: Optional[Any] = None,
-        norm_first: bool = False,
+        norm_first: bool = True,
         checkpoint: bool = False,
     ):
         super().__init__()
-        if custom_encoder is not None:
-            self.encoder = custom_encoder
-        else:
-            self.encoder = TransformerEncoder(
-                num_encoder_layers,
-                dims,
-                num_heads,
-                mlp_dims,
-                dropout,
-                activation,
-                norm_first,
-                checkpoint,
-            )
 
-        if custom_decoder is not None:
-            self.decoder = custom_decoder
-        else:
-            self.decoder = TransformerDecoder(
-                num_decoder_layers,
-                dims,
-                num_heads,
-                mlp_dims,
-                dropout,
-                activation,
-                norm_first,
-                checkpoint,
-            )
+        self.encoder = custom_encoder or TransformerEncoder(
+            num_encoder_layers,
+            dims,
+            num_heads,
+            mlp_dims,
+            dropout,
+            activation,
+            norm_first,
+            checkpoint,
+        )
+
+        self.decoder = custom_decoder or TransformerDecoder(
+            num_decoder_layers,
+            dims,
+            num_heads,
+            mlp_dims,
+            dropout,
+            activation,
+            norm_first,
+            checkpoint,
+        )
 
     def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
         memory = self.encoder(src, src_mask)
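
The constructor now selects the encoder and decoder with `custom_encoder or TransformerEncoder(...)` instead of an explicit ``None`` check. A general caveat about Python's `or` (not a claim about this commit's intent): the fallback is taken for any falsy value, not only ``None``, and the default is only constructed when needed because `or` short-circuits. A small sketch of the difference:

# `default_factory` stands in for constructing TransformerEncoder/Decoder.
def pick_with_or(custom, default_factory):
    # Falls back for any falsy `custom` (None, 0, an empty container, ...).
    return custom or default_factory()

def pick_with_none_check(custom, default_factory):
    # Falls back only when `custom` is exactly None.
    return default_factory() if custom is None else custom
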
