@@ -20,7 +20,7 @@ class Muon(optimizer.Optimizer):
     The Muon optimizer can use both the Muon update step or the
     AdamW update step based on the following:

-    - For any variable that isn't 2D, 3D or 4D, the AdamW step
+    - For any variable that isn't 2D, the AdamW step
       will be used. This is not configurable.
     - If the argument `exclude_embeddings` (defaults to `True`) is set
       to `True`, the AdamW step will be used.
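
To make the new routing rule concrete, here is a hypothetical mapping for a small model (the variable names and shapes are invented for illustration):

    # Which step each variable would receive under the new 2D-only rule:
    #   dense/kernel          shape (128, 64)    -> Muon update
    #   dense/bias            shape (64,)        -> AdamW update (not 2D)
    #   embedding/embeddings  shape (10000, 512) -> AdamW update (name matches
    #                                               "embedding" when
    #                                               exclude_embeddings=True)
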
@@ -46,10 +46,12 @@ class Muon(optimizer.Optimizer):
             that takes no arguments and returns the actual value to use.
             The exponential decay rate for the 1st moment estimates. Defaults to
             `0.9`.
-        adam_beta_2: A float value or a constant float tensor, ora callable
+        adam_beta_2: A float value or a constant float tensor, or a callable
             that takes no arguments and returns the actual value to use.
             The exponential decay rate for the 2nd moment estimates. Defaults to
             `0.999`.
+        adam_weight_decay: Float. If set, weight decay is applied when using
+            the Adam optimizer.
         epsilon: A small constant for numerical stability. This is
             "epsilon hat" in the Kingma and Ba paper
             (in the formula just before Section 2.1),
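
A minimal usage sketch for the new argument; the values and the `keras.optimizers.Muon` import path are assumptions for illustration:

    from keras.optimizers import Muon

    optimizer = Muon(
        learning_rate=1e-3,
        weight_decay=0.004,       # decay for variables on the Muon path
        adam_weight_decay=0.004,  # decay for variables on the AdamW path
    )
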
@@ -67,20 +69,25 @@ class Muon(optimizer.Optimizer):
             It is recommended to use the default value
         adam_lr_ratio: Float, the ratio of the learning rate when
             using Adam to the main learning rate.
-            it is recommended to set it to 0.1
+            It is recommended to set it to 1.
         momentum: Float, momentum used by internal SGD.
         ns_steps: Integer, number of Newton-Schulz iterations to run.
         nesterov: Boolean, whether to use Nesterov-style momentum
         {{base_optimizer_keyword_args}}
+        rms_rate: Float. A parameter from https://arxiv.org/abs/2502.16982
+            that can enhance the stability of Muon, allowing it to use the
+            same learning rate and weight decay as Adam. Defaults to `0.2`.
+            Set to `None` to disable this feature.
     """

     def __init__(
         self,
         learning_rate=0.001,
         adam_beta_1=0.9,
         adam_beta_2=0.999,
+        adam_weight_decay=0.004,
         epsilon=1e-7,
-        weight_decay=0.1,
+        weight_decay=0.004,
         clipnorm=None,
         clipvalue=None,
         global_clipnorm=None,
@@ -95,10 +102,11 @@ def __init__(
         muon_a=3.4445,
         muon_b=-4.7750,
         muon_c=2.0315,
-        adam_lr_ratio=0.1,
+        adam_lr_ratio=1,
         momentum=0.95,
-        ns_steps=6,
+        ns_steps=5,
         nesterov=True,
+        rms_rate=0.2,
         **kwargs,
     ):
         super().__init__(
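
Per the `adam_lr_ratio` docstring above, the Adam path's rate is derived from the main rate, so the default change is easy to read off (the derivation below is an assumption from the docstring, not shown in this diff):

    adam_lr = learning_rate * adam_lr_ratio
    # old defaults: 0.001 * 0.1 = 0.0001 (Adam path ran 10x slower)
    # new defaults: 0.001 * 1   = 0.001  (both paths share one rate)
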
@@ -127,12 +135,13 @@ def __init__(
         self.nesterov = nesterov
         self.exclude_embeddings = exclude_embeddings
         self.exclude_layers = exclude_layers or []
+        self.adam_weight_decay = adam_weight_decay
+        self.rms_rate = rms_rate

     def _should_use_adamw(self, variable):
-        # To use it with 4D convolutional filters,
         # it works well to just flatten their last 3 dimensions.
         # any {0,1}-D parameters should all be optimized by adam
-        if not 1 < len(variable.shape) < 4:
+        if len(variable.shape) != 2:
             return True
         if self.exclude_embeddings and "embedding" in variable.path.lower():
             return True
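
A self-contained sketch of the new routing check, using a hypothetical `Var` stand-in for a Keras variable (which exposes `shape` and `path`):

    from collections import namedtuple

    Var = namedtuple("Var", ["shape", "path"])

    def should_use_adamw(v, exclude_embeddings=True):
        if len(v.shape) != 2:  # only 2D matrices take the Muon step now
            return True
        if exclude_embeddings and "embedding" in v.path.lower():
            return True
        return False

    print(should_use_adamw(Var((128, 64), "dense/kernel")))  # False -> Muon
    print(should_use_adamw(Var((64,), "dense/bias")))        # True  -> AdamW
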
@@ -185,18 +194,13 @@ def update_step(self, gradient, variable, learning_rate):
     def _muon_update_step(self, gradient, variable, lr):
         m = self.adam_momentums[variable.path]
         self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
-        shape = variable.shape
         if self.nesterov:
             g = ops.add(gradient, self.momentum * m)
         else:
             g = m
+        update = self.zeropower_via_newtonschulz5(g, self.ns_steps)

-        self.assign_sub(
-            variable,
-            lr
-            * self.zeropower_via_newtonschulz5(g, self.ns_steps)
-            * max(1, shape[0] / shape[1]) ** 0.5,
-        )
+        self.assign_sub(variable, self.lr_adjust(lr * update))

     def _adamw_update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""
@@ -239,6 +243,20 @@ def transpose_last_axis(self, X):
         X = ops.transpose(X, temp_order)
         return X

+    def lr_adjust(self, x):
+        """Adjusts learning rate based on the Moonlight implementation.
+        This method enhances the stability of Muon, allowing it to use the
+        same learning rate and weight decay as Adam. For details, see
+        https://arxiv.org/abs/2502.16982.
+        For a 2D matrix, the update is scaled by `sqrt(max(n, m)) * rms_rate`,
+        where `n` and `m` are the dimensions of the matrix.
+        """
+        if self.rms_rate is None:
+            return x
+        # moonlight version
+        # https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
+        return x * ops.sqrt(ops.maximum(x.shape[0], x.shape[1])) * self.rms_rate
+
     def zeropower_via_newtonschulz5(self, x, steps: int):
         """We apply the Newton-Schulz iteration to compute matrix G.

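
A worked example of the scaling that `lr_adjust` applies (the shape is illustrative):

    # For a (1024, 4096) weight matrix with the default rms_rate = 0.2:
    #   scale = sqrt(max(1024, 4096)) * 0.2 = 64 * 0.2 = 12.8
    # so the orthogonalized update is multiplied by 12.8 before assign_sub.
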
@@ -268,6 +286,20 @@ def zeropower_via_newtonschulz5(self, x, steps: int):
         x = self.transpose_last_axis(x)
         return x

+    def _apply_weight_decay(self, variables):
+        for variable in variables:
+            if not self._use_weight_decay(variable):
+                continue
+            if self._should_use_adamw(variable):
+                weight_decay_value = self.adam_weight_decay
+            else:
+                weight_decay_value = self.weight_decay
+            if weight_decay_value is None:
+                continue
+            wd = ops.cast(weight_decay_value, variable.dtype)
+            lr = ops.cast(self.learning_rate, variable.dtype)
+            variable.assign(variable - variable * wd * lr)
+
     def get_config(self):
         config = super().get_config()
         config.update(
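
The decay above is decoupled from the gradient step: each variable shrinks by `wd * lr` of itself, with the Adam path using its own coefficient. A quick standalone numeric check (illustrative values):

    w = 2.0
    wd, lr = 0.004, 0.001
    w = w - w * wd * lr  # 2.0 * (1 - 4e-6) = 1.999992
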
@@ -284,6 +316,8 @@ def get_config(self):
284316 "ns_steps" : self .ns_steps ,
285317 "nesterov" : self .nesterov ,
286318 "exclude_embeddings" : self .exclude_embeddings ,
319+ "adam_weight_decay" : self .adam_weight_decay ,
320+ "rms_rate" : self .rms_rate ,
287321 }
288322 )
289323 return config
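
With the two new fields registered in `get_config`, serialization should round-trip them; a sketch assuming the public `Muon` export:

    opt = Muon(adam_weight_decay=0.004, rms_rate=0.2)
    restored = Muon.from_config(opt.get_config())
    assert restored.adam_weight_decay == 0.004
    assert restored.rms_rate == 0.2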