
Commit 818cda1

postmalloc, awni, and angeloskath authored
Support LR schedulers (#334)
* Add a few LR schedulers
* Move parent's constructor call to the top
* Fix docstring
* Refactor optimizers into two files
* Add docs
* Nit
* Fix Callable type annotation for Python 3.8

Co-authored-by: Awni Hannun <[email protected]>
Co-authored-by: Angelos Katharopoulos <[email protected]>
1 parent 85143fe commit 818cda1
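
For orientation, the docstrings added in this commit show the intended API: each new scheduler factory returns a callable that can be passed wherever a float learning rate was accepted before. A short usage sketch assembled from those docstring examples (the `optim` and `nn` aliases are the usual MLX import conventions, not part of the diff):

    import mlx.nn as nn
    import mlx.optimizers as optim

    # Decay the learning rate by a factor of 0.9 every 10 optimizer steps.
    lr_schedule = optim.step_decay(1e-1, 0.9, 10)

    # A schedule is passed exactly where a float used to go.
    optimizer = optim.SGD(learning_rate=lr_schedule)
    model = nn.Linear(2, 2)
    optimizer.init(model.trainable_parameters())

    print(optimizer.learning_rate)  # array(0.1, dtype=float32) before any update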

File tree: 10 files changed, +235 −47 lines

docs/.gitignore

Lines changed: 1 addition & 0 deletions

@@ -1,2 +1,3 @@
 src/python/_autosummary*/
 src/python/nn/_autosummary*/
+src/python/optimizers/_autosummary*/

docs/src/python/optimizers.rst

Lines changed: 3 additions & 17 deletions

@@ -31,20 +31,6 @@ model's parameters and the **optimizer state**.
 
 .. toctree::
 
-   optimizer
-
-.. currentmodule:: mlx.optimizers
-
-.. autosummary::
-   :toctree: _autosummary
-   :template: optimizers-template.rst
-
-   SGD
-   RMSprop
-   Adagrad
-   Adafactor
-   AdaDelta
-   Adam
-   AdamW
-   Adamax
-   Lion
+   optimizers/optimizer
+   optimizers/common_optimizers
+   optimizers/schedulers
docs/src/python/optimizers/common_optimizers.rst

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+.. _common_optimizers:
+
+Common Optimizers
+=================
+
+.. currentmodule:: mlx.optimizers
+
+.. autosummary::
+   :toctree: _autosummary
+   :template: optimizers-template.rst
+
+   SGD
+   RMSprop
+   Adagrad
+   Adafactor
+   AdaDelta
+   Adam
+   AdamW
+   Adamax
+   Lion
docs/src/python/{optimizer.rst → optimizers/optimizer.rst}

File renamed without changes.
docs/src/python/optimizers/schedulers.rst

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+.. _schedulers:
+
+Schedulers
+==========
+
+.. currentmodule:: mlx.optimizers
+
+.. autosummary::
+   :toctree: _autosummary
+
+   step_decay
+   exponential_decay
+   cosine_decay

python/mlx/optimizers/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from mlx.optimizers.optimizers import *
+from mlx.optimizers.schedulers import *
python/mlx/optimizers/optimizers.py

Lines changed: 48 additions & 23 deletions

@@ -1,7 +1,7 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import math
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union
 
 import mlx.core as mx
 from mlx.utils import tree_map
@@ -12,9 +12,10 @@ class Optimizer:
     optimizer on a per-parameter basis and apply it to a parameter tree.
     """
 
-    def __init__(self):
+    def __init__(self, schedulers=None):
         self._initialized = False
-        self._state = {}
+        self._state = {"step": mx.array(0, mx.uint64)}
+        self._schedulers = {k: v for k, v in (schedulers or {}).items()}
 
     def update(self, model: "mlx.nn.Module", gradients: dict):
         """Apply the gradients to the parameters of the model and update the
@@ -44,9 +45,8 @@ def init(self, parameters: dict):
             >>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)
             >>> model = nn.Linear(2, 2)
             >>> optimizer.init(model.trainable_parameters())
-            >>> optimizer.state
-            {'learning_rate': array(0.1, dtype=float32), 'weight': {'v': array([[0, 0],
-            [0, 0]], dtype=float32)}, 'bias': {'v': array([0, 0], dtype=float32)}}
+            >>> optimizer.state.keys()
+            dict_keys(['step', 'learning_rate', 'weight', 'bias'])
         """
         self._state.update(tree_map(lambda x: {}, parameters))
         tree_map(self.init_single, parameters, self._state)
@@ -76,6 +76,15 @@ def apply_gradients(self, gradients: dict, parameters: dict):
         """
         if not self._initialized:
             self.init(gradients)
+
+        # Update any scheduled variables
+        for param, scheduler in self._schedulers.items():
+            self.state[param] = scheduler(self.step)
+
+        # Increment the step
+        self.state["step"] = self.step + 1
+
+        # Apply the update
         return tree_map(self.apply_single, gradients, parameters, self.state)
 
     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
@@ -97,14 +106,31 @@ def state(self):
     def state(self, state: dict):
         self._state = state
 
+    @property
+    def step(self):
+        return self.state["step"]
+
     @property
     def learning_rate(self):
         return self.state["learning_rate"]
 
     @learning_rate.setter
-    def learning_rate(self, learning_rate: mx.array):
+    def learning_rate(self, learning_rate: Union[float, mx.array]):
         self.state["learning_rate"] = mx.array(learning_rate)
 
+    def _maybe_schedule(
+        self, name: str, param: Union[float, Callable[[mx.array], mx.array]]
+    ):
+        """
+        To be used by derived classes to optionally put a parameter on a schedule.
+        """
+        if isinstance(param, Callable):
+            self._schedulers[name] = param
+            param = param(self.step)
+        else:
+            param = mx.array(param)
+        self.state[name] = param
+
 
 class SGD(Optimizer):
     r"""The stochastic gradient descent optimizer.
@@ -117,7 +143,7 @@ class SGD(Optimizer):
         w_{t+1} &= w_t - \lambda v_{t+1}
 
     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0``
         weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0``
         dampening (float, optional): Dampening for momentum :math:`\tau`. Default: ``0``
@@ -126,7 +152,7 @@ class SGD(Optimizer):
 
     def __init__(
         self,
-        learning_rate: float,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
         momentum: float = 0.0,
         weight_decay: float = 0.0,
         dampening: float = 0.0,
@@ -138,7 +164,7 @@ def __init__(
     ):
         super().__init__()
 
-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.momentum = momentum
         self.weight_decay = weight_decay
         self.dampening = dampening
@@ -194,7 +220,7 @@ class RMSprop(Optimizer):
     def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8):
         super().__init__()
 
-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.alpha = alpha
         self.eps = eps
 
@@ -246,7 +272,7 @@ class Adagrad(Optimizer):
     def __init__(self, learning_rate: float, eps: float = 1e-8):
         super().__init__()
 
-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.eps = eps
 
         if self.eps < 0.0:
@@ -295,7 +321,7 @@ class AdaDelta(Optimizer):
     def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
         super().__init__()
 
-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.rho = rho
         self.eps = eps
         if self.rho < 0.0:
@@ -361,7 +387,7 @@ def __init__(
     ):
         super().__init__()
 
-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.betas = betas
         self.eps = eps
 
@@ -526,7 +552,7 @@ def __init__(
     ):
         super().__init__()
 
-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.betas = betas
         self.weight_decay = weight_decay
 
@@ -596,7 +622,7 @@ def __init__(
     ):
         super().__init__()
         if learning_rate is not None:
-            self.learning_rate = learning_rate
+            self._maybe_schedule("learning_rate", learning_rate)
         self.eps = eps
         self.clip_threshold = clip_threshold
         self.decay_rate = decay_rate
@@ -608,7 +634,6 @@ def __init__(
 
     def init_single(self, parameter: mx.array, state: dict):
         """Initialize optimizer state"""
-        state["step"] = 0
         if parameter.ndim >= 2:
             shape = parameter.shape
             dtype = parameter.dtype
@@ -626,10 +651,11 @@ def _compute_rms(self, inputs):
     def _compute_learning_rate(self, step, parameter_rms):
         if self.relative_step:
             min_step = 1e-6 * step if self.warmup_init else 1e-2
-            relative_step_size = min(min_step, 1 / math.sqrt(step))
+            relative_step_size = mx.minimum(min_step, mx.rsqrt(step))
         else:
-            relative_step_size = self.learning_rate.astype(parameter_rms)
+            relative_step_size = self.learning_rate
 
+        relative_step_size = relative_step_size.astype(parameter_rms.dtype)
         parameter_scale = 1.0
         if self.scale_parameter:
             parameter_scale = mx.maximum(self.eps[1], parameter_rms)
@@ -648,13 +674,12 @@ def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         """Performs the Adafactor parameter and state update."""
         factored = gradient.ndim >= 2
 
-        step = state["step"] + 1
-        state["step"] = step
+        step = self.step
         use_first_moment = self.beta_1 is not None
 
         parameter_rms = self._compute_rms(parameter)
         learning_rate = self._compute_learning_rate(step, parameter_rms)
-        beta_2 = 1.0 - math.pow(step, self.decay_rate)
+        beta_2 = 1.0 - (step**self.decay_rate).astype(parameter_rms.dtype)
         update = mx.square(gradient) + self.eps[0]
 
         if factored:
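
The base-class changes above carry the mechanism: the optimizer state is seeded with a uint64 "step" counter, _maybe_schedule either stores a float as an array or registers a callable, and apply_gradients re-evaluates every registered schedule at the current step before incrementing it. A minimal sketch of a custom optimizer opting into the same hook, mirroring what the diff does for SGD and the other optimizers (PlainSGD is hypothetical, and it relies on the internal _maybe_schedule helper added here):

    import mlx.core as mx
    from mlx.optimizers import Optimizer

    class PlainSGD(Optimizer):
        """Vanilla SGD, written only to illustrate the new scheduling hooks."""

        def __init__(self, learning_rate):
            super().__init__()
            # learning_rate may be a float or a callable mapping the step
            # counter (an mx.array) to a value; callables are re-evaluated
            # on every call to update/apply_gradients.
            self._maybe_schedule("learning_rate", learning_rate)

        def init_single(self, parameter, state):
            # Plain SGD keeps no per-parameter state.
            pass

        def apply_single(self, gradient, parameter, state):
            # self.learning_rate already reflects the schedule at this step.
            return parameter - self.learning_rate * gradient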
python/mlx/optimizers/schedulers.py

Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
+# Copyright © 2023-2024 Apple Inc.
+
+import math
+
+import mlx.core as mx
+
+
+def exponential_decay(init: float, decay_rate: float):
+    r"""Make an exponential decay scheduler.
+
+    Args:
+        init (float): Initial value.
+        decay_rate (float): Multiplicative factor to decay by.
+
+    Example:
+        >>> lr_schedule = optim.exponential_decay(1e-1, 0.9)
+        >>> optimizer = optim.SGD(learning_rate=lr_schedule)
+        >>> optimizer.learning_rate
+        array(0.1, dtype=float32)
+        >>>
+        >>> for _ in range(5): optimizer.update({}, {})
+        ...
+        >>> optimizer.learning_rate
+        array(0.06561, dtype=float32)
+    """
+
+    def schedule(step):
+        return init * decay_rate**step
+
+    return schedule
+
+
+def step_decay(init: float, decay_rate: float, step_size: int):
+    r"""Make a step decay scheduler.
+
+    Args:
+        init (float): Initial value.
+        decay_rate (float): Multiplicative factor to decay by.
+        step_size (int): Decay every ``step_size`` steps.
+
+    Example:
+
+        >>> lr_schedule = optim.step_decay(1e-1, 0.9, 10)
+        >>> optimizer = optim.SGD(learning_rate=lr_schedule)
+        >>> optimizer.learning_rate
+        array(0.1, dtype=float32)
+        >>>
+        >>> for _ in range(21): optimizer.update({}, {})
+        ...
+        >>> optimizer.learning_rate
+        array(0.081, dtype=float32)
+    """
+
+    def schedule(step):
+        return init * (decay_rate ** (step // step_size))
+
+    return schedule
+
+
+def cosine_decay(init: float, decay_steps: int):
+    r"""Make a cosine decay scheduler.
+
+    Args:
+        init (float): Initial value.
+        decay_steps (int): Number of steps to decay over. The decayed
+            value is constant for steps beyond ``decay_steps``.
+
+    Example:
+
+        >>> lr_schedule = optim.cosine_decay(1e-1, 1000)
+        >>> optimizer = optim.SGD(learning_rate=lr_schedule)
+        >>> optimizer.learning_rate
+        array(0.1, dtype=float32)
+        >>>
+        >>> for _ in range(5): optimizer.update({}, {})
+        ...
+        >>> optimizer.learning_rate
+        array(0.0999961, dtype=float32)
+    """
+
+    def scheduler(step):
+        s = mx.minimum(step, decay_steps)
+        decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s))
+        return init * decay
+
+    return scheduler
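
Because a schedule is just a callable mapping the step counter (an mx.array) to a value, custom decays can be written in the same style as the three built-ins above. A hedged sketch of a linear warmup, which is not part of this commit:

    import mlx.core as mx
    import mlx.optimizers as optim

    def linear_warmup(init: float, warmup_steps: int):
        def schedule(step):
            # Ramp from 0 to `init` over `warmup_steps` steps, then hold.
            return init * mx.minimum(step, warmup_steps) / warmup_steps

        return schedule

    # Used exactly like the built-in schedulers.
    optimizer = optim.SGD(learning_rate=linear_warmup(1e-1, 100))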

python/src/array.cpp

Lines changed: 6 additions & 0 deletions

@@ -971,6 +971,12 @@ void init_array(py::module_& m) {
             return power(a, to_array(v, a.dtype()));
           },
           "other"_a)
+      .def(
+          "__rpow__",
+          [](const array& a, const ScalarOrArray v) {
+            return power(to_array(v, a.dtype()), a);
+          },
+          "other"_a)
       .def(
           "__ipow__",
           [](array& a, const ScalarOrArray v) {
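
The array.cpp change supports the schedulers: expressions such as decay_rate ** step put a Python float on the left of ** and an mx.array on the right, which dispatches to array.__rpow__, a binding that did not exist before this commit. A small illustration (the printed dtype and value are an expectation, not output copied from the source):

    import mlx.core as mx

    step = mx.array(3, mx.uint64)   # the optimizer's step counter
    lr = 0.1 * 0.9**step            # float ** array now resolves via __rpow__
    print(lr)                       # expected: roughly array(0.0729, dtype=float32)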
