
Commit 37fc9db

hazemessamm and awni authored
Added Adafactor (#415)
* Added adafactor
* Added Adafactor and ran pre-commit
* Modified operations
* Added docstrings
* Switched two ops to fix a bug
* Added underscore for internal functions and removed the plus sign in the last return statement
* Removed parameter rms from the optimizer state because it's not needed
* Added simple MNIST test for Adafactor and temporary training log
* Removed test files
* Nits in docs
* Comment nit

---------

Co-authored-by: Awni Hannun <[email protected]>
1 parent 755dcf6 commit 37fc9db

3 files changed: +164 −10 lines changed

docs/src/python/optimizers.rst

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ model's parameters and the **optimizer state**.
    SGD
    RMSprop
    Adagrad
+   Adafactor
    AdaDelta
    Adam
    AdamW

python/mlx/optimizers.py

Lines changed: 145 additions & 10 deletions
@@ -1,7 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import math
-from typing import List
+from typing import List, Optional, Tuple
 
 import mlx.core as mx
 from mlx.utils import tree_map
@@ -76,7 +76,7 @@ def apply_single(
 
 
 class SGD(Optimizer):
-    r"""Stochastic gradient descent optimizer.
+    r"""The stochastic gradient descent optimizer.
 
     Updates a parameter :math:`w` with a gradient :math:`g` as follows
@@ -141,7 +141,7 @@ def apply_single(
 
 
 class RMSprop(Optimizer):
-    r"""Implementation of the RMSprop optimizer [1].
+    r"""The RMSprop optimizer [1].
 
     [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning
@@ -190,7 +190,7 @@ def apply_single(
 
 
 class Adagrad(Optimizer):
-    r"""Implementation of the Adagrad optimizer [1].
+    r"""The Adagrad optimizer [1].
 
     Our Adagrad implementation follows the original paper. In detail,
@@ -235,7 +235,7 @@ def apply_single(
 
 
 class AdaDelta(Optimizer):
-    r"""Implementation of the AdaDelta optimizer with learning rate[1].
+    r"""The AdaDelta optimizer with a learning rate [1].
 
     Our AdaDelta implementation follows the original paper. In detail,
@@ -294,7 +294,7 @@ def apply_single(
 
 
 class Adam(Optimizer):
-    r"""Implementation of the Adam optimizer [1].
+    r"""The Adam optimizer [1].
 
     Our Adam implementation follows the original paper and omits the bias
     correction in the first and second moment estimates. In detail,
@@ -346,7 +346,7 @@ def apply_single(
 
 
 class AdamW(Adam):
-    r"""Implementation of the AdamW optimizer [1].
+    r"""The AdamW optimizer [1].
 
     Following the above convention, in contrast with [1], we do not use bias
     correction in the first and second moments for AdamW. We update the weights
@@ -395,8 +395,7 @@ def apply_single(
 
 
 class Adamax(Adam):
-    r"""Implementation of the Adamax optimizer. It is a variant of Adam based
-    on the infinity norm [1].
+    r"""The Adamax optimizer, a variant of Adam based on the infinity norm [1].
 
     Our Adam implementation follows the original paper and omits the bias
     correction in the first and second moment estimates. In detail,
@@ -449,7 +448,7 @@ def apply_single(
 
 
 class Lion(Optimizer):
-    r"""Implementation of the Lion optimizer [1].
+    r"""The Lion optimizer [1].
 
     Since updates are computed through the sign operation, they tend to
     have larger norm than for other optimizers such as SGD and Adam.
@@ -502,3 +501,139 @@ def apply_single(
         if weight_decay > 0:
             parameter = (1 - lr * weight_decay) * parameter
         return parameter - lr * mx.sign(c)
+
+
+class Adafactor(Optimizer):
+    r"""The Adafactor optimizer.
+
+    Our Adafactor implementation follows the original paper: `Adafactor:
+    Adaptive Learning Rates with Sublinear Memory Cost
+    <https://arxiv.org/abs/1804.04235>`_
+
+    Args:
+        learning_rate (float, optional): The learning rate. Default: ``None``.
+        eps (tuple(float, float), optional): The first term :math:`\epsilon_1`
+            is added to the square of the gradients to improve numerical
+            stability and the second term :math:`\epsilon_2` is used for
+            parameter scaling if ``scale_parameter`` is set to ``True``.
+            Default: ``(1e-30, 1e-3)``.
+        clip_threshold (float, optional): Clips the unscaled update at
+            ``clip_threshold``. Default: ``1.0``.
+        decay_rate (float, optional): Coefficient for the running average
+            of the squared gradient. Default: ``-0.8``.
+        beta_1 (float, optional): If set to a value greater than zero, the
+            first moment will be used. Default: ``None``.
+        weight_decay (float, optional): The weight decay :math:`\lambda`.
+            Default: ``0.0``.
+        scale_parameter (bool, optional): If set to ``True`` the learning rate
+            will be scaled by :math:`\max(\epsilon_2, \text{RMS}(w_{t-1}))`.
+            Default: ``True``.
+        relative_step (bool, optional): If set to ``True`` the ``learning_rate``
+            will be ignored and a relative step size will be computed.
+            Default: ``True``.
+        warmup_init (bool, optional): If set to ``True`` then the relative
+            step size will be calculated by the current step. Default:
+            ``False``.
+    """
+
+    def __init__(
+        self,
+        learning_rate: Optional[float] = None,
+        eps: Tuple[float, float] = (1e-30, 1e-3),
+        clip_threshold: float = 1.0,
+        decay_rate: float = -0.8,
+        beta_1: Optional[float] = None,
+        weight_decay: float = 0.0,
+        scale_parameter: bool = True,
+        relative_step: bool = True,
+        warmup_init: bool = False,
+    ):
+        super().__init__()
+        self.learning_rate = learning_rate
+        self.eps = eps
+        self.clip_threshold = clip_threshold
+        self.decay_rate = decay_rate
+        self.beta_1 = beta_1
+        self.weight_decay = weight_decay
+        self.scale_parameter = scale_parameter
+        self.relative_step = relative_step
+        self.warmup_init = warmup_init
+
+    def _compute_rms(self, inputs):
+        return mx.sqrt(mx.mean(mx.square(inputs)))
+
+    def _compute_learning_rate(self, step, parameter_rms):
+        relative_step_size = self.learning_rate
+        if self.relative_step:
+            min_step = 1e-6 * step if self.warmup_init else 1e-2
+            relative_step_size = min(min_step, 1 / math.sqrt(step))
+
+        parameter_scale = 1.0
+        if self.scale_parameter:
+            parameter_scale = mx.maximum(self.eps[1], parameter_rms)
+        return parameter_scale * relative_step_size
+
+    def _approximate_exp_moving_avg(self, exp_avg_sq_row, exp_avg_sq_col):
+        r_factor = mx.rsqrt(
+            exp_avg_sq_row / mx.mean(exp_avg_sq_row, axis=-1, keepdims=True)
+        )
+        c_factor = mx.rsqrt(exp_avg_sq_col)
+        return mx.matmul(
+            mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0)
+        )
+
+    def apply_single(
+        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+    ):
+        """Performs the Adafactor parameter and state update."""
+        gradient_shape = gradient.shape
+        factored = len(gradient_shape) >= 2
+        step = state.get("step", 0) + 1
+        state["step"] = step
+        use_first_moment = self.beta_1 is not None
+
+        parameter_rms = self._compute_rms(parameter)
+        learning_rate = self._compute_learning_rate(step, parameter_rms)
+        beta_2 = 1.0 - math.pow(step, self.decay_rate)
+        update = mx.square(gradient) + self.eps[0]
+
+        if factored:
+            exp_avg_sq_row = state.get(
+                "exp_avg_sq_row", mx.zeros(gradient_shape[:-1], dtype=gradient.dtype)
+            )
+            exp_avg_sq_col = state.get(
+                "exp_avg_sq_col",
+                mx.zeros(
+                    gradient_shape[:-2] + gradient_shape[-1:], dtype=gradient.dtype
+                ),
+            )
+            exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + (
+                (1 - beta_2) * mx.mean(update, axis=-1)
+            )
+            exp_avg_sq_col = (beta_2 * exp_avg_sq_col) + (
+                (1 - beta_2) * mx.mean(update, axis=-2)
+            )
+            state["exp_avg_sq_row"] = exp_avg_sq_row
+            state["exp_avg_sq_col"] = exp_avg_sq_col
+            update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col)
+            update = update * gradient
+        else:
+            exp_avg_sq = state.get("exp_avg_sq", mx.zeros_like(gradient))
+            exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update)
+            state["exp_avg_sq"] = exp_avg_sq
+            update = mx.rsqrt(exp_avg_sq) * gradient
+
+        update = update / mx.maximum(
+            1.0, self._compute_rms(update) / self.clip_threshold
+        )
+        update = learning_rate * update
+
+        if use_first_moment:
+            exp_avg = state.get("exp_avg", mx.zeros_like(gradient))
+            exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update)
+            state["exp_avg"] = exp_avg
+            update = exp_avg
+
+        if self.weight_decay != 0:
+            parameter += parameter * (-self.weight_decay * learning_rate)
+        return parameter - update
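
For context, the helper methods in the new class follow the step-size and second-moment rules from the Adafactor paper; the summary below is an orientation sketch in the paper's notation and is not part of the diff itself. With step count :math:`t`, parameter :math:`w`, and :math:`V_t` the running average of the squared gradient of an :math:`n \times m` parameter:

    % relative step size (with warmup_init, 10^{-2} is replaced by 10^{-6} t)
    \rho_t = \min\left(10^{-2}, \tfrac{1}{\sqrt{t}}\right)
    % scaled learning rate when scale_parameter is True
    \alpha_t = \max\left(\epsilon_2, \mathrm{RMS}(w_{t-1})\right)\,\rho_t
    % rank-1 factored estimate of the second moment for 2-D parameters
    \hat{V}_t = \frac{(V_t 1_m)\,(1_n^\top V_t)}{1_n^\top V_t 1_m}

``_compute_learning_rate`` returns :math:`\alpha_t`, and ``_approximate_exp_moving_avg`` returns :math:`1/\sqrt{\hat{V}_t}` assembled from the row and column accumulators (stored as means rather than sums, which cancels in the ratio).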

python/tests/test_optimizers.py

Lines changed: 18 additions & 0 deletions
@@ -39,6 +39,24 @@ def test_optimizers(self):
         all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
         self.assertTrue(all_equal)
 
+    def test_adafactor(self):
+        x = mx.zeros((5, 5))
+        grad = mx.ones_like(x)
+        optimizer = opt.Adafactor()
+        for _ in range(2):
+            xp = optimizer.apply_single(grad, x, optimizer.state)
+        self.assertEqual(xp.dtype, x.dtype)
+        self.assertEqual(xp.shape, x.shape)
+
+        x = mx.zeros((5, 5), mx.float16)
+        grad = mx.ones_like(x)
+        optimizer = opt.Adafactor()
+        for _ in range(2):
+            xp = optimizer.apply_single(grad, x, optimizer.state)
+        self.assertEqual(xp.dtype, x.dtype)
+        self.assertEqual(xp.shape, x.shape)
+        self.assertEqual(optimizer.state["step"], 2)
+
 
 if __name__ == "__main__":
     unittest.main()
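
The test above already demonstrates the calling pattern the diff exercises; a minimal standalone sketch along the same lines (assuming the same ``import mlx.optimizers as opt`` alias used by the test module, with a placeholder gradient) would be:

    import mlx.core as mx
    import mlx.optimizers as opt

    # Sketch only: uses just the apply_single / state API shown in the diff.
    w = mx.zeros((16, 8))            # a 2-D parameter, so the factored path is taken
    optimizer = opt.Adafactor()      # relative_step=True, so no learning_rate is needed
    for _ in range(3):
        grad = mx.ones_like(w)       # placeholder gradient
        w = optimizer.apply_single(grad, w, optimizer.state)
    print(optimizer.state["step"])   # 3

Passing ``beta_1`` enables the optional first-moment accumulator, and setting ``relative_step=False`` with an explicit ``learning_rate`` disables the step-dependent schedule.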
