@@ -50,6 +50,8 @@ class Muon(optimizer.Optimizer):
5050 that takes no arguments and returns the actual value to use.
5151 The exponential decay rate for the 2nd moment estimates. Defaults to
5252 `0.999`.
53+        adam_weight_decay: Float. The weight decay rate applied to the
54+            variables that are updated with the AdamW fallback optimizer
55+            (instead of `weight_decay`). Defaults to `0.004`.
5355 epsilon: A small constant for numerical stability. This is
5456 "epsilon hat" in the Kingma and Ba paper
5557 (in the formula just before Section 2.1),
@@ -79,6 +81,7 @@ def __init__(
7981 learning_rate = 0.001 ,
8082 adam_beta_1 = 0.9 ,
8183 adam_beta_2 = 0.999 ,
84+ adam_weight_decay = 0.004 ,
8285 epsilon = 1e-7 ,
8386 weight_decay = 0.1 ,
8487 clipnorm = None ,
@@ -127,6 +130,7 @@ def __init__(
127130 self .nesterov = nesterov
128131 self .exclude_embeddings = exclude_embeddings
129132 self .exclude_layers = exclude_layers or []
133+ self .adam_weight_decay = adam_weight_decay
130134
131135 def _should_use_adamw (self , variable ):
132136 # To use it with 4D convolutional filters,
@@ -268,6 +272,18 @@ def zeropower_via_newtonschulz5(self, x, steps: int):
268272 x = self .transpose_last_axis (x )
269273 return x
270274
275+ def _apply_weight_decay (self , variables ):
276+ if self .weight_decay is None :
277+ return
278+ for variable in variables :
279+ if self ._use_weight_decay (variable ):
280+ lr = ops .cast (self .learning_rate , variable .dtype )
281+ if self ._should_use_adamw (variable ):
282+ wd = ops .cast (self .adam_weight_decay , variable .dtype )
283+ else :
284+ wd = ops .cast (self .weight_decay , variable .dtype )
285+ variable .assign (variable - variable * wd * lr )
286+
271287 def get_config (self ):
272288 config = super ().get_config ()
273289 config .update (
@@ -284,6 +300,7 @@ def get_config(self):
284300 "ns_steps" : self .ns_steps ,
285301 "nesterov" : self .nesterov ,
286302 "exclude_embeddings" : self .exclude_embeddings ,
303+ "adam_weight_decay" : self .adam_weight_decay ,
287304 }
288305 )
289306 return config
0 commit comments