Commit d130816

Modify Muon optimizer (#21885)
* modify muon.
* modify gemini review.
* modify
1 parent 8c87f5d commit d130816

File tree

2 files changed: +82 -16 lines changed

keras/src/optimizers/muon.py
keras/src/optimizers/muon_test.py


keras/src/optimizers/muon.py

Lines changed: 49 additions & 15 deletions
@@ -20,7 +20,7 @@ class Muon(optimizer.Optimizer):
     The Muon optimizer can use both the Muon update step or the
     AdamW update step based on the following:
 
-    - For any variable that isn't 2D, 3D or 4D, the AdamW step
+    - For any variable that isn't 2D, the AdamW step
         will be used. This is not configurable.
     - If the argument `exclude_embeddings` (defaults to `True`) is set
         to `True`, the AdamW step will be used.
@@ -46,10 +46,12 @@ class Muon(optimizer.Optimizer):
             that takes no arguments and returns the actual value to use.
             The exponential decay rate for the 1st moment estimates. Defaults to
             `0.9`.
-        adam_beta_2: A float value or a constant float tensor, ora callable
+        adam_beta_2: A float value or a constant float tensor, or a callable
             that takes no arguments and returns the actual value to use.
             The exponential decay rate for the 2nd moment estimates. Defaults to
             `0.999`.
+        adam_weight_decay: Float. If set, weight decay is applied when using
+            the Adam optimizer.
         epsilon: A small constant for numerical stability. This is
             "epsilon hat" in the Kingma and Ba paper
             (in the formula just before Section 2.1),
@@ -67,20 +69,25 @@ class Muon(optimizer.Optimizer):
             It is recommended to use the default value
         adam_lr_ratio: Float, the ratio of the learning rate when
             using Adam to the main learning rate.
-            it is recommended to set it to 0.1
+            It is recommended to set it to 1
         momentum: Float, momentum used by internal SGD.
         ns_steps: Integer, number of Newton-Schulz iterations to run.
         nesterov: Boolean, whether to use Nesterov-style momentum
         {{base_optimizer_keyword_args}}
+        rms_rate: Float. A parameter from https://arxiv.org/abs/2502.16982
+            that can enhance the stability of Muon, allowing it to use the
+            same learning rate and weight decay as Adam. Defaults to `0.2`.
+            Set to `None` to disable this feature.
     """
 
     def __init__(
         self,
         learning_rate=0.001,
         adam_beta_1=0.9,
         adam_beta_2=0.999,
+        adam_weight_decay=0.004,
         epsilon=1e-7,
-        weight_decay=0.1,
+        weight_decay=0.004,
         clipnorm=None,
         clipvalue=None,
         global_clipnorm=None,
@@ -95,10 +102,11 @@ def __init__(
         muon_a=3.4445,
         muon_b=-4.7750,
         muon_c=2.0315,
-        adam_lr_ratio=0.1,
+        adam_lr_ratio=1,
         momentum=0.95,
-        ns_steps=6,
+        ns_steps=5,
         nesterov=True,
+        rms_rate=0.2,
         **kwargs,
     ):
         super().__init__(
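
As a quick, hedged illustration of the new surface area (commentary, not part of the diff): the constructor now accepts `adam_weight_decay` and `rms_rate`, and the defaults for `weight_decay`, `adam_lr_ratio`, and `ns_steps` change as shown above. The sketch below assumes a Keras build that already contains this commit; all argument names come from the diff.

# Sketch only: constructing Muon with the arguments introduced or changed here.
import keras

optimizer = keras.optimizers.Muon(
    learning_rate=1e-3,
    weight_decay=0.004,       # decoupled decay for variables taking the Muon step
    adam_weight_decay=0.004,  # decoupled decay for variables routed to AdamW
    adam_lr_ratio=1,          # Adam now shares the main learning rate by default
    ns_steps=5,               # Newton-Schulz iterations
    rms_rate=0.2,             # Moonlight-style update scaling; None disables it
)

model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(4)])
model.compile(optimizer=optimizer, loss="mse")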
@@ -127,12 +135,13 @@ def __init__(
         self.nesterov = nesterov
         self.exclude_embeddings = exclude_embeddings
         self.exclude_layers = exclude_layers or []
+        self.adam_weight_decay = adam_weight_decay
+        self.rms_rate = rms_rate
 
     def _should_use_adamw(self, variable):
-        # To use it with 4D convolutional filters,
         # it works well to just flatten their last 3 dimensions.
         # any {0,1}-D parameters should all be optimized by adam
-        if not 1 < len(variable.shape) < 4:
+        if len(variable.shape) != 2:
             return True
         if self.exclude_embeddings and "embedding" in variable.path.lower():
             return True
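
To make the new routing rule concrete, here is a minimal standalone sketch (written for this note, not code from the repository; it simplifies the real path-based `exclude_layers` matching): after this change only rank-2 variables take the Muon step, while everything else, plus embeddings and excluded layers, falls back to AdamW.

# Illustrative dispatch rule only; the hypothetical helper mirrors the diff above.
def should_use_adamw(shape, path="", exclude_embeddings=True, exclude_layers=()):
    if len(shape) != 2:  # 0D/1D/3D/4D variables -> AdamW
        return True
    if exclude_embeddings and "embedding" in path.lower():
        return True
    return any(keyword in path for keyword in exclude_layers)

assert should_use_adamw((16,)) is True            # bias vector -> AdamW
assert should_use_adamw((64, 64)) is False        # dense kernel -> Muon
assert should_use_adamw((3, 3, 8, 16)) is True    # conv kernel -> AdamW (previously Muon)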
@@ -185,18 +194,13 @@ def update_step(self, gradient, variable, learning_rate):
     def _muon_update_step(self, gradient, variable, lr):
         m = self.adam_momentums[variable.path]
         self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
-        shape = variable.shape
         if self.nesterov:
             g = ops.add(gradient, self.momentum * m)
         else:
             g = m
+        update = self.zeropower_via_newtonschulz5(g, self.ns_steps)
 
-        self.assign_sub(
-            variable,
-            lr
-            * self.zeropower_via_newtonschulz5(g, self.ns_steps)
-            * max(1, shape[0] / shape[1]) ** 0.5,
-        )
+        self.assign_sub(variable, self.lr_adjust(lr * update))
 
     def _adamw_update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""
@@ -239,6 +243,20 @@ def transpose_last_axis(self, X):
         X = ops.transpose(X, temp_order)
         return X
 
+    def lr_adjust(self, x):
+        """Adjusts learning rate based on the Moonlight implementation.
+        This method enhances the stability of Muon, allowing it to use the same
+        learning rate and weight decay as Adam. For details, see
+        https://arxiv.org/abs/2502.16982.
+        For a 2D matrix, the update is scaled by `sqrt(max(n, m)) * rms_rate`,
+        where `n` and `m` are the dimensions of the matrix.
+        """
+        if self.rms_rate is None:
+            return x
+        # moonlight version
+        # https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
+        return x * ops.sqrt(ops.maximum(x.shape[0], x.shape[1])) * self.rms_rate
+
     def zeropower_via_newtonschulz5(self, x, steps: int):
         """We apply the Newton-Schulz iteration to compute matrix G.
 
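
A worked check of the new scaling (arithmetic written for this note, not repository code): the old update step scaled the orthogonalized update by `max(1, n / m) ** 0.5`, while `lr_adjust` scales it by `sqrt(max(n, m)) * rms_rate`. For a `(4, 2)` update with the default `rms_rate=0.2`:

# Illustrative arithmetic only; both formulas appear in the diff above.
import math

n, m, rms_rate = 4, 2, 0.2
old_scale = max(1, n / m) ** 0.5             # previous behaviour: sqrt(2) ~= 1.414
new_scale = math.sqrt(max(n, m)) * rms_rate  # lr_adjust: sqrt(4) * 0.2 = 0.4
assert abs(new_scale - 0.4) < 1e-12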
@@ -268,6 +286,20 @@ def zeropower_via_newtonschulz5(self, x, steps: int):
         x = self.transpose_last_axis(x)
         return x
 
+    def _apply_weight_decay(self, variables):
+        for variable in variables:
+            if not self._use_weight_decay(variable):
+                continue
+            if self._should_use_adamw(variable):
+                weight_decay_value = self.adam_weight_decay
+            else:
+                weight_decay_value = self.weight_decay
+            if weight_decay_value is None:
+                continue
+            wd = ops.cast(weight_decay_value, variable.dtype)
+            lr = ops.cast(self.learning_rate, variable.dtype)
+            variable.assign(variable - variable * wd * lr)
+
     def get_config(self):
         config = super().get_config()
         config.update(
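
The new `_apply_weight_decay` override applies decoupled weight decay, choosing `adam_weight_decay` or `weight_decay` with the same AdamW/Muon routing. A small numeric sketch of the formula (commentary, not repository code):

# variable <- variable - variable * wd * lr, with wd picked per routing.
variable, lr, wd = 2.0, 1.0, 0.01
variable = variable - variable * wd * lr
assert abs(variable - 1.98) < 1e-9  # matches the expectation in test_adamw_weight_decay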
@@ -284,6 +316,8 @@ def get_config(self):
                 "ns_steps": self.ns_steps,
                 "nesterov": self.nesterov,
                 "exclude_embeddings": self.exclude_embeddings,
+                "adam_weight_decay": self.adam_weight_decay,
+                "rms_rate": self.rms_rate,
             }
         )
         return config

keras/src/optimizers/muon_test.py

Lines changed: 33 additions & 1 deletion
@@ -74,7 +74,10 @@ def test_muon_single_step(self):
         optimizer.build([vars])
         optimizer._muon_update_step(grads, vars, 0.5)
         self.assertAllClose(
-            vars, [[1.13, 1.51], [2.57, 4.06]], rtol=1e-2, atol=1e-2
+            vars,
+            [[0.988775, 1.887053], [2.873428, 3.97035]],
+            rtol=1e-2,
+            atol=1e-2,
         )
 
     def test_clip_norm(self):
@@ -88,3 +91,32 @@ def test_clip_value(self):
         grad = [np.array([100.0, 100.0])]
         clipped_grad = optimizer._clip_gradients(grad)
         self.assertAllClose(clipped_grad[0], [1.0, 1.0])
+
+    def test_muon_weight_decay(self):
+        variable = backend.Variable([[1.0, 2.0], [3.0, 4.0]])
+        weight_decay = 0.01
+        expected_variable = variable - variable * weight_decay
+        optimizer = Muon(learning_rate=1.0, weight_decay=weight_decay)
+        optimizer._apply_weight_decay([variable])
+        self.assertAllClose(variable, expected_variable, rtol=1e-4, atol=1e-4)
+
+    def test_adamw_weight_decay(self):
+        variable = backend.Variable(2.0)
+        weight_decay = 0.01
+        expected_variable = variable - variable * weight_decay
+        optimizer = Muon(learning_rate=1.0, adam_weight_decay=weight_decay)
+        optimizer._apply_weight_decay([variable])
+
+        self.assertAllClose(variable, expected_variable, rtol=1e-4, atol=1e-4)
+
+    def test_lr_adjust_none(self):
+        opt = Muon(rms_rate=None)
+        x = ops.ones((4, 4))
+        want = x
+        self.assertAllClose(opt.lr_adjust(x), want)
+
+    def test_lr_adjust_2d(self):
+        opt = Muon(rms_rate=0.2)
+        x = ops.ones((4, 2))
+        want = x * 0.2 * 2
+        self.assertAllClose(opt.lr_adjust(x), want)
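
For reference, the expectation in `test_lr_adjust_2d` can be reproduced directly. This is a hedged sketch that assumes a Keras build containing this commit, since `rms_rate` and `lr_adjust` only exist after it:

# Every entry of a (4, 2) ones tensor should become sqrt(max(4, 2)) * 0.2 = 0.4.
from keras import ops
from keras.optimizers import Muon

opt = Muon(rms_rate=0.2)
x = ops.ones((4, 2))
print(opt.lr_adjust(x))  # ~0.4 everywhere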
