@@ -1,7 +1,7 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import math
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union
 
 import mlx.core as mx
 from mlx.utils import tree_map
@@ -12,9 +12,10 @@ class Optimizer:
     optimizer on a per-parameter basis and apply it to a parameter tree.
     """
 
-    def __init__(self):
+    def __init__(self, schedulers=None):
         self._initialized = False
-        self._state = {}
+        self._state = {"step": mx.array(0, mx.uint64)}
+        self._schedulers = {k: v for k, v in (schedulers or {}).items()}
 
     def update(self, model: "mlx.nn.Module", gradients: dict):
         """Apply the gradients to the parameters of the model and update the
@@ -44,9 +45,8 @@ def init(self, parameters: dict):
         >>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)
         >>> model = nn.Linear(2, 2)
         >>> optimizer.init(model.trainable_parameters())
-        >>> optimizer.state
-        {'learning_rate': array(0.1, dtype=float32), 'weight': {'v': array([[0, 0],
-        [0, 0]], dtype=float32)}, 'bias': {'v': array([0, 0], dtype=float32)}}
+        >>> optimizer.state.keys()
+        dict_keys(['step', 'learning_rate', 'weight', 'bias'])
         """
         self._state.update(tree_map(lambda x: {}, parameters))
         tree_map(self.init_single, parameters, self._state)
@@ -76,6 +76,15 @@ def apply_gradients(self, gradients: dict, parameters: dict):
7676 """
7777 if not self ._initialized :
7878 self .init (gradients )
79+
80+ # Update any scheduled variables
81+ for param , scheduler in self ._schedulers .items ():
82+ self .state [param ] = scheduler (self .step )
83+
84+ # Increment the step
85+ self .state ["step" ] = self .step + 1
86+
87+ # Apply the update
7988 return tree_map (self .apply_single , gradients , parameters , self .state )
8089
8190 def apply_single (self , gradient : mx .array , parameter : mx .array , state : dict ):
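
A minimal sketch of what this changes at the call site (assuming the module is imported as `mlx.optimizers`): each `apply_gradients` call evaluates any registered schedulers at the current `step` and then increments the counter, so the scheduled value used on update N is the schedule evaluated at step N-1. The decay function below is hypothetical.

import mlx.core as mx
import mlx.optimizers as optim

# Hypothetical inverse-time decay, matching Callable[[mx.array], mx.array]
def decay(step):
    return 1e-1 / (1.0 + 0.1 * step.astype(mx.float32))

opt = optim.SGD(learning_rate=decay)
params = {"w": mx.zeros((2,))}
for _ in range(5):
    grads = {"w": mx.ones((2,))}           # stand-in gradients
    params = opt.apply_gradients(grads, params)

print(opt.step)           # the step counter has advanced to 5
print(opt.learning_rate)  # decay evaluated at step 4, the rate used for the fifth update
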
@@ -97,14 +106,31 @@ def state(self):
     def state(self, state: dict):
         self._state = state
 
+    @property
+    def step(self):
+        return self.state["step"]
+
     @property
     def learning_rate(self):
         return self.state["learning_rate"]
 
     @learning_rate.setter
-    def learning_rate(self, learning_rate: mx.array):
+    def learning_rate(self, learning_rate: Union[float, mx.array]):
         self.state["learning_rate"] = mx.array(learning_rate)
 
+    def _maybe_schedule(
+        self, name: str, param: Union[float, Callable[[mx.array], mx.array]]
+    ):
+        """
+        To be used by derived classes to optionally put a parameter on a schedule.
+        """
+        if isinstance(param, Callable):
+            self._schedulers[name] = param
+            param = param(self.step)
+        else:
+            param = mx.array(param)
+        self.state[name] = param
+
 
 class SGD(Optimizer):
     r"""The stochastic gradient descent optimizer.
@@ -117,7 +143,7 @@ class SGD(Optimizer):
         w_{t+1} &= w_t - \lambda v_{t+1}
 
     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0``
         weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0``
         dampening (float, optional): Dampening for momentum :math:`\tau`. Default: ``0``
@@ -126,7 +152,7 @@ class SGD(Optimizer):
 
     def __init__(
         self,
-        learning_rate: float,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
         momentum: float = 0.0,
         weight_decay: float = 0.0,
         dampening: float = 0.0,
@@ -138,7 +164,7 @@ def __init__(
             )
         super().__init__()
 
-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.momentum = momentum
         self.weight_decay = weight_decay
         self.dampening = dampening
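
Putting the pieces together, a hedged end-to-end sketch of SGD with a scheduled learning rate in an `mlx.nn` training loop; the data, loss, and warmup schedule below are placeholders, not part of this change.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

def warmup(step):
    # Hypothetical linear warmup of the learning rate, capped at 1e-1
    return mx.minimum(1e-1, 1e-2 * (step + 1))

model = nn.Linear(2, 2)
optimizer = optim.SGD(learning_rate=warmup, momentum=0.9)

def loss_fn(model, x, y):
    return mx.mean(mx.square(model(x) - y))

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
x, y = mx.random.normal((8, 2)), mx.random.normal((8, 2))

for _ in range(10):
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)  # evaluates the schedule, then increments optimizer.step
    mx.eval(model.parameters(), optimizer.state)
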
@@ -194,7 +220,7 @@ class RMSprop(Optimizer):
     def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8):
         super().__init__()
 
-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.alpha = alpha
         self.eps = eps
 
@@ -246,7 +272,7 @@ class Adagrad(Optimizer):
     def __init__(self, learning_rate: float, eps: float = 1e-8):
         super().__init__()
 
-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.eps = eps
 
         if self.eps < 0.0:
@@ -295,7 +321,7 @@ class AdaDelta(Optimizer):
     def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
         super().__init__()
 
-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.rho = rho
         self.eps = eps
         if self.rho < 0.0:
@@ -361,7 +387,7 @@ def __init__(
     ):
         super().__init__()
 
-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.betas = betas
         self.eps = eps
 
@@ -526,7 +552,7 @@ def __init__(
     ):
         super().__init__()
 
-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.betas = betas
         self.weight_decay = weight_decay
 
@@ -596,7 +622,7 @@ def __init__(
     ):
         super().__init__()
         if learning_rate is not None:
-            self.learning_rate = learning_rate
+            self._maybe_schedule("learning_rate", learning_rate)
         self.eps = eps
         self.clip_threshold = clip_threshold
         self.decay_rate = decay_rate
@@ -608,7 +634,6 @@ def __init__(
 
     def init_single(self, parameter: mx.array, state: dict):
         """Initialize optimizer state"""
-        state["step"] = 0
         if parameter.ndim >= 2:
             shape = parameter.shape
             dtype = parameter.dtype
@@ -626,10 +651,11 @@ def _compute_rms(self, inputs):
     def _compute_learning_rate(self, step, parameter_rms):
         if self.relative_step:
             min_step = 1e-6 * step if self.warmup_init else 1e-2
-            relative_step_size = min(min_step, 1 / math.sqrt(step))
+            relative_step_size = mx.minimum(min_step, mx.rsqrt(step))
         else:
-            relative_step_size = self.learning_rate.astype(parameter_rms)
+            relative_step_size = self.learning_rate
 
+        relative_step_size = relative_step_size.astype(parameter_rms.dtype)
         parameter_scale = 1.0
         if self.scale_parameter:
             parameter_scale = mx.maximum(self.eps[1], parameter_rms)
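
A small worked example of the relative step size above, assuming `warmup_init=False` so that `min_step` is the constant `1e-2` (plain Python, just to show the arithmetic):

import math

# min(1e-2, 1 / sqrt(step)): capped at 1e-2, decaying as 1/sqrt(step) once step exceeds 1e4
for step in (1, 100, 10_000, 40_000):
    print(step, min(1e-2, 1.0 / math.sqrt(step)))
# 1      0.01
# 100    0.01
# 10000  0.01
# 40000  0.005

With `warmup_init=True`, `min_step` is `1e-6 * step` instead, giving a linear warmup that meets the `1/sqrt(step)` curve at step 1e4.
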
@@ -648,13 +674,12 @@ def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
648674 """Performs the Adafactor parameter and state update."""
649675 factored = gradient .ndim >= 2
650676
651- step = state ["step" ] + 1
652- state ["step" ] = step
677+ step = self .step
653678 use_first_moment = self .beta_1 is not None
654679
655680 parameter_rms = self ._compute_rms (parameter )
656681 learning_rate = self ._compute_learning_rate (step , parameter_rms )
657- beta_2 = 1.0 - math . pow (step , self .decay_rate )
682+ beta_2 = 1.0 - (step ** self .decay_rate ). astype ( parameter_rms . dtype )
658683 update = mx .square (gradient ) + self .eps [0 ]
659684
660685 if factored :