Skip to content

Commit e6dd99f

Browse files
committed
modify muon
1 parent af595c5 commit e6dd99f

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

keras/src/optimizers/muon.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ class Muon(optimizer.Optimizer):
5050
that takes no arguments and returns the actual value to use.
5151
The exponential decay rate for the 2nd moment estimates. Defaults to
5252
`0.999`.
53+
adam_weight_decay: Float. If set, weight decay is applied when using
54+
the Adam optimizer.
5355
epsilon: A small constant for numerical stability. This is
5456
"epsilon hat" in the Kingma and Ba paper
5557
(in the formula just before Section 2.1),
@@ -79,6 +81,7 @@ def __init__(
7981
learning_rate=0.001,
8082
adam_beta_1=0.9,
8183
adam_beta_2=0.999,
84+
adam_weight_decay=0.004,
8285
epsilon=1e-7,
8386
weight_decay=0.1,
8487
clipnorm=None,
@@ -127,6 +130,7 @@ def __init__(
127130
self.nesterov = nesterov
128131
self.exclude_embeddings = exclude_embeddings
129132
self.exclude_layers = exclude_layers or []
133+
self.adam_weight_decay = adam_weight_decay
130134

131135
def _should_use_adamw(self, variable):
132136
# To use it with 4D convolutional filters,
@@ -268,6 +272,18 @@ def zeropower_via_newtonschulz5(self, x, steps: int):
268272
x = self.transpose_last_axis(x)
269273
return x
270274

275+
def _apply_weight_decay(self, variables):
276+
if self.weight_decay is None:
277+
return
278+
for variable in variables:
279+
if self._use_weight_decay(variable):
280+
lr = ops.cast(self.learning_rate, variable.dtype)
281+
if self._should_use_adamw(variable):
282+
wd = ops.cast(self.adam_weight_decay, variable.dtype)
283+
else:
284+
wd = ops.cast(self.weight_decay, variable.dtype)
285+
variable.assign(variable - variable * wd * lr)
286+
271287
def get_config(self):
272288
config = super().get_config()
273289
config.update(
@@ -284,6 +300,7 @@ def get_config(self):
284300
"ns_steps": self.ns_steps,
285301
"nesterov": self.nesterov,
286302
"exclude_embeddings": self.exclude_embeddings,
303+
"adam_weight_decay": self.adam_weight_decay,
287304
}
288305
)
289306
return config

0 commit comments

Comments
 (0)