
Commit fd3377d

Support bias correction in Adam and AdamW optimizers (#1640)
1 parent d0b6cb0 commit fd3377d

2 files changed: +78 -11 lines


python/mlx/optimizers/optimizers.py

Lines changed: 24 additions & 11 deletions
@@ -395,10 +395,7 @@ def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
 
 
 class Adam(Optimizer):
-    r"""The Adam optimizer [1].
-
-    Our Adam implementation follows the original paper and omits the bias
-    correction in the first and second moment estimates. In detail,
+    r"""The Adam optimizer [1]. In detail,
 
     [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
     optimization. ICLR 2015.
@@ -416,19 +413,23 @@ class Adam(Optimizer):
             gradient and its square. Default: ``(0.9, 0.999)``
         eps (float, optional): The term :math:`\epsilon` added to the
             denominator to improve numerical stability. Default: ``1e-8``
+        bias_correction (bool, optional): If set to ``True``, bias correction
+            is applied. Default: ``False``
     """
 
     def __init__(
         self,
         learning_rate: Union[float, Callable[[mx.array], mx.array]],
         betas: List[float] = [0.9, 0.999],
         eps: float = 1e-8,
+        bias_correction: bool = False,
     ):
         super().__init__()
 
         self._maybe_schedule("learning_rate", learning_rate)
         self.betas = betas
         self.eps = eps
+        self.bias_correction = bias_correction
 
     def init_single(self, parameter: mx.array, state: dict):
         """Initialize optimizer state"""
@@ -441,6 +442,8 @@ def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         lr = self.learning_rate.astype(gradient.dtype)
         b1, b2 = self.betas
         eps = self.eps
+        bias_correction = self.bias_correction
+        step = self.step
 
         m = state["m"]
         v = state["v"]
@@ -449,15 +452,17 @@ def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         state["m"] = m
         state["v"] = v
 
-        return parameter - lr * m / (mx.sqrt(v) + eps)
+        if bias_correction:
+            numerator = lr / (1 - b1**step) * m
+            denominator = mx.sqrt(v) / mx.sqrt(1 - b2**step) + eps
+            return parameter - numerator / denominator
+        else:
+            return parameter - lr * m / (mx.sqrt(v) + eps)
 
 
 class AdamW(Adam):
-    r"""The AdamW optimizer [1].
-
-    Following the above convention, in contrast with [1], we do not use bias
-    correction in the first and second moments for AdamW. We update the weights
-    with a weight_decay (:math:`\lambda`) value:
+    r"""The AdamW optimizer [1]. We update the weights with a weight_decay
+    (:math:`\lambda`) value:
 
     [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
     regularization. ICLR 2019.
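
Aside on the hunk above: folding the corrections into the numerator and denominator is algebraically the textbook bias-corrected Adam step from Kingma & Ba. A minimal NumPy sketch with made-up scalar values (not part of this commit) checks the equivalence:

import numpy as np

# One Adam step on a scalar, illustrative values only.
g, lr, b1, b2, eps, step = 0.5, 1e-3, 0.9, 0.999, 1e-8, 1
m = (1 - b1) * g       # first moment after one update (m starts at 0)
v = (1 - b2) * g**2    # second moment after one update (v starts at 0)

# Form used in the diff: fold the corrections into numerator and denominator.
upd_commit = (lr / (1 - b1**step) * m) / (np.sqrt(v) / np.sqrt(1 - b2**step) + eps)

# Textbook form: correct the moment estimates first, then take the usual step.
m_hat, v_hat = m / (1 - b1**step), v / (1 - b2**step)
upd_paper = lr * m_hat / (np.sqrt(v_hat) + eps)

assert np.isclose(upd_commit, upd_paper)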
@@ -477,6 +482,8 @@ class AdamW(Adam):
             denominator to improve numerical stability. Default: ``1e-8``
         weight_decay (float, optional): The weight decay :math:`\lambda`.
             Default: ``0``.
+        bias_correction (bool, optional): If set to ``True``, bias correction
+            is applied. Default: ``False``
     """
 
     def __init__(
@@ -485,8 +492,14 @@ def __init__(
         betas: List[float] = [0.9, 0.999],
         eps: float = 1e-8,
         weight_decay: float = 0.01,
+        bias_correction: bool = False,
     ):
-        super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
+        super().__init__(
+            learning_rate=learning_rate,
+            betas=betas,
+            eps=eps,
+            bias_correction=bias_correction,
+        )
         self.weight_decay = weight_decay
 
     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
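
The flag is opt-in on both optimizers. A minimal usage sketch (hypothetical model, data, and shapes; it follows the pattern used in the new test below):

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt

model = nn.Linear(3, 1)
optimizer = opt.AdamW(learning_rate=3e-4, bias_correction=True)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
x, y = mx.random.normal((8, 3)), mx.random.normal((8, 1))
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)  # one bias-corrected AdamW step
mx.eval(model.parameters(), optimizer.state)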

python/tests/test_optimizers.py

Lines changed: 54 additions & 0 deletions
@@ -10,8 +10,17 @@
 import mlx.optimizers as opt
 import mlx.utils
 import mlx_tests
+import numpy as np
 from mlx.utils import tree_flatten, tree_map, tree_unflatten
 
+try:
+    import torch
+    import torch.nn.functional as F
+
+    has_torch = True
+except ImportError as e:
+    has_torch = False
+
 
 def get_all_optimizers():
     classes = dict()
@@ -186,6 +195,51 @@ def test_adam(self):
                 )
             )
 
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_adamw_matches_pytorch(self):
+        mx.random.seed(0)
+        np.random.seed(0)
+
+        model = nn.Linear(3, 1)
+        init_weight = np.array(model.weight.tolist())
+        init_bias = np.array(model.bias.tolist())
+
+        def loss_fn(model, x, y):
+            pred = model(x)
+            return nn.losses.mse_loss(pred, y)
+
+        x = np.random.rand(3, 3)
+        y = np.random.rand(3, 1)
+
+        optimizer = opt.AdamW(learning_rate=3e-4, bias_correction=True)
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, mx.array(x), mx.array(y))
+        optimizer.update(model, grads)
+
+        # Equivalent torch code
+        torch_model = torch.nn.Linear(3, 1)
+
+        # copy over the parameters
+        torch_model.weight.data = torch.tensor(init_weight, dtype=torch.float32)
+        torch_model.bias.data = torch.tensor(init_bias, dtype=torch.float32)
+
+        torch_optimizer = torch.optim.AdamW(torch_model.parameters(), lr=3e-4)
+        torch_optimizer.zero_grad()
+        pred = torch_model(torch.tensor(x, dtype=torch.float32))
+        loss = torch.nn.MSELoss()(pred, torch.tensor(y, dtype=torch.float32))
+        loss.backward()
+        torch_optimizer.step()
+
+        for name, param in torch_model.named_parameters():
+            mlx_grad = np.array(grads[name])
+            torch_grad = param.grad.detach().numpy()
+            self.assertTrue(np.allclose(torch_grad, mlx_grad))
+
+        for name, param in torch_model.named_parameters():
+            mlx_param = np.array(model[name])
+            torch_param = param.data.detach().numpy()
+            self.assertTrue(np.allclose(torch_param, mlx_param))
+
     def test_lion(self):
         params = {
             "first": [mx.zeros((10,)), mx.zeros((1,))],

0 commit comments
