Skip to content

Commit f8409e7

Browse files
committed
modify muon
1 parent 441dbba commit f8409e7

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

keras/src/optimizers/muon.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,15 +273,17 @@ def zeropower_via_newtonschulz5(self, x, steps: int):
273273
return x
274274

275275
def _apply_weight_decay(self, variables):
276-
if self.weight_decay is None:
277-
return
278276
for variable in variables:
279277
if self._use_weight_decay(variable):
280-
lr = ops.cast(self.learning_rate, variable.dtype)
281278
if self._should_use_adamw(variable):
279+
if self.adam_weight_decay is None:
280+
return
282281
wd = ops.cast(self.adam_weight_decay, variable.dtype)
283282
else:
283+
if self.weight_decay is None:
284+
return
284285
wd = ops.cast(self.weight_decay, variable.dtype)
286+
lr = ops.cast(self.learning_rate, variable.dtype)
285287
variable.assign(variable - variable * wd * lr)
286288

287289
def get_config(self):

0 commit comments

Comments
 (0)