@@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc.

import math
-from typing import List
+from typing import List, Optional, Tuple

import mlx.core as mx
from mlx.utils import tree_map
@@ -76,7 +76,7 @@ def apply_single(


class SGD(Optimizer):
-    r"""Stochastic gradient descent optimizer.
+    r"""The stochastic gradient descent optimizer.

    Updates a parameter :math:`w` with a gradient :math:`g` as follows

@@ -141,7 +141,7 @@ def apply_single(


class RMSprop(Optimizer):
-    r"""Implementation of the RMSprop optimizer [1].
+    r"""The RMSprop optimizer [1].

    [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning

@@ -190,7 +190,7 @@ def apply_single(


class Adagrad(Optimizer):
-    r"""Implementation of the Adagrad optimizer [1].
+    r"""The Adagrad optimizer [1].

    Our Adagrad implementation follows the original paper. In detail,

@@ -235,7 +235,7 @@ def apply_single(


class AdaDelta(Optimizer):
-    r"""Implementation of the AdaDelta optimizer with learning rate[1].
+    r"""The AdaDelta optimizer with a learning rate [1].

    Our AdaDelta implementation follows the original paper. In detail,

@@ -294,7 +294,7 @@ def apply_single(


class Adam(Optimizer):
-    r"""Implementation of the Adam optimizer [1].
+    r"""The Adam optimizer [1].

    Our Adam implementation follows the original paper and omits the bias
    correction in the first and second moment estimates. In detail,
@@ -346,7 +346,7 @@ def apply_single(


class AdamW(Adam):
-    r"""Implementation of the AdamW optimizer [1].
+    r"""The AdamW optimizer [1].

    Following the above convention, in contrast with [1], we do not use bias
    correction in the first and second moments for AdamW. We update the weights
@@ -395,8 +395,7 @@ def apply_single(


class Adamax(Adam):
-    r"""Implementation of the Adamax optimizer. It is a variant of Adam based
-    on the infinity norm [1].
+    r"""The Adamax optimizer, a variant of Adam based on the infinity norm [1].

    Our Adam implementation follows the original paper and omits the bias
    correction in the first and second moment estimates. In detail,
@@ -449,7 +448,7 @@ def apply_single(


class Lion(Optimizer):
-    r"""Implementation of the Lion optimizer [1].
+    r"""The Lion optimizer [1].

    Since updates are computed through the sign operation, they tend to
    have larger norm than for other optimizers such as SGD and Adam.
@@ -502,3 +501,139 @@ def apply_single(
        if weight_decay > 0:
            parameter = (1 - lr * weight_decay) * parameter
        return parameter - lr * mx.sign(c)
+
+
+class Adafactor(Optimizer):
+    r"""The Adafactor optimizer.
+
+    Our Adafactor implementation follows the original paper: `Adafactor:
+    Adaptive Learning Rates with Sublinear Memory Cost
+    <https://arxiv.org/abs/1804.04235>`_
+
+    Args:
+        learning_rate (float, optional): The learning rate. Default: ``None``.
+        eps (tuple(float, float), optional): The first term :math:`\epsilon_1`
+            is added to the square of the gradients to improve numerical
+            stability and the second term :math:`\epsilon_2` is used for
+            parameter scaling if ``scale_parameter`` is set to ``True``.
+            Default: ``(1e-30, 1e-3)``.
+        clip_threshold (float, optional): Clips the unscaled update at
+            ``clip_threshold``. Default: ``1.0``.
+        decay_rate (float, optional): Coefficient for the running average
+            of the squared gradient. Default: ``-0.8``.
+        beta_1 (float, optional): If set to a value bigger than zero
+            then the first moment will be used. Default: ``None``.
+        weight_decay (float, optional): The weight decay :math:`\lambda`.
+            Default: ``0.0``.
+        scale_parameter (bool, optional): If set to ``True`` the learning rate
+            will be scaled by :math:`\max(\epsilon_2, \text{RMS}(w_{t-1}))`.
+            Default: ``True``.
+        relative_step (bool, optional): If set to ``True`` the ``learning_rate``
+            will be ignored and the relative step size will be computed.
+            Default: ``True``.
+        warmup_init (bool, optional): If set to ``True`` then the relative
+            step size will be calculated using the current step. Default:
+            ``False``.
+    """
+
+    def __init__(
+        self,
+        learning_rate: Optional[float] = None,
+        eps: Tuple[float, float] = (1e-30, 1e-3),
+        clip_threshold: float = 1.0,
+        decay_rate: float = -0.8,
+        beta_1: Optional[float] = None,
+        weight_decay: float = 0.0,
+        scale_parameter: bool = True,
+        relative_step: bool = True,
+        warmup_init: bool = False,
+    ):
+        super().__init__()
+        self.learning_rate = learning_rate
+        self.eps = eps
+        self.clip_threshold = clip_threshold
+        self.decay_rate = decay_rate
+        self.beta_1 = beta_1
+        self.weight_decay = weight_decay
+        self.scale_parameter = scale_parameter
+        self.relative_step = relative_step
+        self.warmup_init = warmup_init
+
+    def _compute_rms(self, inputs):
+        return mx.sqrt(mx.mean(mx.square(inputs)))
+
+    def _compute_learning_rate(self, step, parameter_rms):
+        relative_step_size = self.learning_rate
+        if self.relative_step:
+            min_step = 1e-6 * step if self.warmup_init else 1e-2
+            relative_step_size = min(min_step, 1 / math.sqrt(step))
+
+        parameter_scale = 1.0
+        if self.scale_parameter:
+            parameter_scale = mx.maximum(self.eps[1], parameter_rms)
+        return parameter_scale * relative_step_size
+
+    def _approximate_exp_moving_avg(self, exp_avg_sq_row, exp_avg_sq_col):
+        r_factor = mx.rsqrt(
+            exp_avg_sq_row / mx.mean(exp_avg_sq_row, axis=-1, keepdims=True)
+        )
+        c_factor = mx.rsqrt(exp_avg_sq_col)
+        return mx.matmul(
+            mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0)
+        )
+
+    def apply_single(
+        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+    ):
+        """Performs the Adafactor parameter and state update."""
+        gradient_shape = gradient.shape
+        factored = len(gradient_shape) >= 2
+        step = state.get("step", 0) + 1
+        state["step"] = step
+        use_first_moment = self.beta_1 is not None
+
+        parameter_rms = self._compute_rms(parameter)
+        learning_rate = self._compute_learning_rate(step, parameter_rms)
+        beta_2 = 1.0 - math.pow(step, self.decay_rate)
+        update = mx.square(gradient) + self.eps[0]
+
+        if factored:
+            exp_avg_sq_row = state.get(
+                "exp_avg_sq_row", mx.zeros(gradient_shape[:-1], dtype=gradient.dtype)
+            )
+            exp_avg_sq_col = state.get(
+                "exp_avg_sq_col",
+                mx.zeros(
+                    gradient_shape[:-2] + gradient_shape[-1:], dtype=gradient.dtype
+                ),
+            )
+            exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + (
+                (1 - beta_2) * mx.mean(update, axis=-1)
+            )
+            exp_avg_sq_col = (beta_2 * exp_avg_sq_col) + (
+                (1 - beta_2) * mx.mean(update, axis=-2)
+            )
+            state["exp_avg_sq_row"] = exp_avg_sq_row
+            state["exp_avg_sq_col"] = exp_avg_sq_col
+            update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col)
+            update = update * gradient
+        else:
+            exp_avg_sq = state.get("exp_avg_sq", mx.zeros_like(gradient))
+            exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update)
+            state["exp_avg_sq"] = exp_avg_sq
+            update = mx.rsqrt(exp_avg_sq) * gradient
+
+        update = update / mx.maximum(
+            1.0, self._compute_rms(update) / self.clip_threshold
+        )
+        update = learning_rate * update
+
+        if use_first_moment:
+            exp_avg = state.get("exp_avg", mx.zeros_like(gradient))
+            exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update)
+            state["exp_avg"] = exp_avg
+            update = exp_avg
+
+        if self.weight_decay != 0:
+            parameter += parameter * (-self.weight_decay * learning_rate)
+        return parameter - update
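
The `Adafactor` added above implements the same `apply_single` hook as the other optimizers in this file, so it should drop straight into an MLX training loop. The following is a minimal usage sketch rather than part of the change: it assumes the base class's `update()`/`state` interface, the `mlx.nn.value_and_grad` helper, and the `mlx.optimizers` import path, and the toy model, batch, and step count are hypothetical.

# Minimal usage sketch; Adafactor, Optimizer.update()/state, and
# mlx.nn.value_and_grad are assumed to be available as described above.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(16, 4)  # hypothetical toy model

# With relative_step=True the learning_rate argument is ignored and the step
# size is min(1e-2, 1 / sqrt(step)); warmup_init=True replaces the 1e-2 cap
# with 1e-6 * step so that early steps stay small.
optimizer = optim.Adafactor(relative_step=True, warmup_init=True)


def loss_fn(model, x, y):
    return mx.mean(mx.square(model(x) - y))


loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

x = mx.random.normal((8, 16))  # hypothetical batch
y = mx.random.normal((8, 4))
for _ in range(10):
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)  # invokes Adafactor.apply_single per parameter
    mx.eval(model.parameters(), optimizer.state)

Because the layer's weight is 2-D, `apply_single` takes the `factored` branch and keeps only a row accumulator (`exp_avg_sq_row`) and a column accumulator (`exp_avg_sq_col`) in place of the full second-moment matrix, which is the sublinear-memory factorization from the paper; the 1-D bias falls back to the dense `exp_avg_sq` state.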