Modify Muon optimizer #21885

Changes from 1 commit
@@ -20,7 +20,7 @@ class Muon(optimizer.Optimizer):
     The Muon optimizer can use both the Muon update step or the
     AdamW update step based on the following:

-    - For any variable that isn't 2D, 3D or 4D, the AdamW step
+    - For any variable that isn't 2D, the AdamW step
       will be used. This is not configurable.
     - If the argument `exclude_embeddings` (defaults to `True`) is set
       to `True`, the AdamW step will be used.

@@ -46,10 +46,12 @@ class Muon(optimizer.Optimizer):
             that takes no arguments and returns the actual value to use.
             The exponential decay rate for the 1st moment estimates. Defaults to
             `0.9`.
-        adam_beta_2: A float value or a constant float tensor, ora callable
+        adam_beta_2: A float value or a constant float tensor, or a callable
             that takes no arguments and returns the actual value to use.
             The exponential decay rate for the 2nd moment estimates. Defaults to
             `0.999`.
+        adam_weight_decay: Float. If set, weight decay is applied when using
+            the Adam optimizer.
         epsilon: A small constant for numerical stability. This is
             "epsilon hat" in the Kingma and Ba paper
             (in the formula just before Section 2.1),

@@ -67,20 +69,26 @@ class Muon(optimizer.Optimizer):
             It is recommended to use the default value.
         adam_lr_ratio: Float, the ratio of the learning rate when
             using Adam to the main learning rate.
-            it is recommended to set it to 0.1
+            it is recommended to set it to 1
         momentum: Float, momentum used by internal SGD.
         ns_steps: Integer, number of Newton-Schulz iterations to run.
         nesterov: Boolean, whether to use Nesterov-style momentum
         {{base_optimizer_keyword_args}}
+        `rms_rate`: A trick from https://arxiv.org/abs/2502.16982.
+            This parameter can enhance the stability of Muon,
+            allowing it to use the same learning rate and weight decay
+            as Adam. Defaults to `0.2`; set it to `None` to disable.
     """

     def __init__(
         self,
         learning_rate=0.001,
         adam_beta_1=0.9,
         adam_beta_2=0.999,
+        adam_weight_decay=0.004,
         epsilon=1e-7,
-        weight_decay=0.1,
+        weight_decay=0.004,
         clipnorm=None,
         clipvalue=None,
         global_clipnorm=None,

@@ -95,10 +103,11 @@ def __init__(
         muon_a=3.4445,
         muon_b=-4.7750,
         muon_c=2.0315,
-        adam_lr_ratio=0.1,
+        adam_lr_ratio=1,
         momentum=0.95,
-        ns_steps=6,
+        ns_steps=5,
         nesterov=True,
+        rms_rate=0.2,
         **kwargs,
     ):
         super().__init__(

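For reference, here is a construction sketch using the new defaults from this diff. The `keras.optimizers.Muon` path and argument names are assumed from the signature shown above rather than taken from a released API:

```python
import keras

# Construction sketch using the defaults introduced in this PR;
# argument names are assumed from the signature above.
optimizer = keras.optimizers.Muon(
    learning_rate=1e-3,
    weight_decay=0.004,       # decay for Muon-updated (2D) variables
    adam_weight_decay=0.004,  # decay for AdamW-updated variables
    adam_lr_ratio=1,          # AdamW branch now shares the main learning rate
    ns_steps=5,               # Newton-Schulz iterations
    rms_rate=0.2,             # None disables the RMS-matching scale
)
```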
@@ -127,12 +136,14 @@ def __init__(
         self.nesterov = nesterov
         self.exclude_embeddings = exclude_embeddings
         self.exclude_layers = exclude_layers or []
+        self.adam_weight_decay = adam_weight_decay
+        self.rms_rate = rms_rate

     def _should_use_adamw(self, variable):
-        # To use it with 4D convolutional filters,
-        # it works well to just flatten their last 3 dimensions.
         # any {0,1}-D parameters should all be optimized by adam
-        if not 1 < len(variable.shape) < 4:
+        if len(variable.shape) != 2:
             return True

         if self.exclude_embeddings and "embedding" in variable.path.lower():
             return True

Comment on lines +144 to +145

Collaborator:
I'm not following this change. In the Moonlight implementation the criteria for using Muon is that the [...]. So for AdamW, the criteria would be [...].

Contributor (Author):
The optimization target of Muon is matrices. In the 3D case, reshaping into matrices is necessary for effective optimization. However, this involves too many assumptions, and introducing it would only unnecessarily increase complexity. In fact, Muon never considered the case of CNNs. It was designed with only 1D-Transformer scenarios in mind.

Contributor (Author):
In the original implementation of Moonlight, they could ensure that the optimization target is a Transformer model based on PyTorch. However, in the Keras implementation, we cannot guarantee this. For example, in a typical case with the PyTorch backend, if we mix [...]. Similarly, if the optimization target is a 3D CNN model, the parameter meanings differ between the "channels_last" and "channels_first" formats, so we lack reasonable assumptions to perform reshaping in such cases. The Muon optimizer in Keras should be a general-purpose optimizer, and a general-purpose optimizer should not rely on too many assumptions. Therefore, we take the most conservative approach: we do not apply the Muon step to anything other than matrices. This is also the reason why we do not use the Keller Jordan version, which assumes that the optimized matrix must be either [...].

Collaborator:
Thanks for the explanation.
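To make the new routing rule concrete, here is a small self-contained sketch (hypothetical helper name, illustrative shapes) of which variables now take the AdamW step versus the Muon step:

```python
# Standalone illustration of the routing introduced above: only 2D,
# non-embedding variables take the Muon step; everything else uses AdamW.
def uses_adamw(shape, path="", exclude_embeddings=True):
    if len(shape) != 2:
        return True
    if exclude_embeddings and "embedding" in path.lower():
        return True
    return False

assert uses_adamw((1024,))                            # 1D bias -> AdamW
assert not uses_adamw((1024, 4096))                   # 2D dense kernel -> Muon
assert uses_adamw((3, 64, 128))                       # 3D conv kernel -> AdamW (was Muon-eligible before this change)
assert uses_adamw((50000, 768), "embedding/kernel")   # embedding matrix -> AdamW
```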
@@ -185,17 +196,15 @@ def update_step(self, gradient, variable, learning_rate):
     def _muon_update_step(self, gradient, variable, lr):
         m = self.adam_momentums[variable.path]
         self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
-        shape = variable.shape
         if self.nesterov:
             g = ops.add(gradient, self.momentum * m)
         else:
             g = m
+        update = self.zeropower_via_newtonschulz5(g, self.ns_steps)

         self.assign_sub(
             variable,
-            lr
-            * self.zeropower_via_newtonschulz5(g, self.ns_steps)
-            * max(1, shape[0] / shape[1]) ** 0.5,
+            self.lr_adjust(lr * update),
         )

     def _adamw_update_step(self, gradient, variable, learning_rate):

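As a reading aid, here is a NumPy sketch of what the rewritten `_muon_update_step` computes. `newton_schulz` is a stand-in for `zeropower_via_newtonschulz5` and is assumed to return an approximately orthogonalized matrix of the same shape:

```python
import numpy as np

def muon_step(w, grad, m, lr, momentum=0.95, nesterov=True, rms_rate=0.2,
              newton_schulz=lambda g: g):
    """Sketch of the rewritten update; returns the new weights and momentum."""
    m = m + grad + m * (momentum - 1)           # m <- momentum * m + grad
    g = grad + momentum * m if nesterov else m  # Nesterov-style lookahead
    update = newton_schulz(g)                   # orthogonalized gradient
    if rms_rate is not None:
        # new lr_adjust scaling, replacing the old max(1, n/m) ** 0.5 factor
        update = update * np.sqrt(max(w.shape[0], w.shape[1])) * rms_rate
    return w - lr * update, m

# Illustrative call with random data:
w = np.random.randn(64, 128)
m = np.zeros_like(w)
grad = np.random.randn(64, 128)
w, m = muon_step(w, grad, m, lr=1e-3)
```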
@@ -239,6 +248,18 @@ def transpose_last_axis(self, X):
         X = ops.transpose(X, temp_order)
         return X

+    def lr_adjust(self, x):
+        """
+        You can check the details at https://arxiv.org/pdf/2502.16982.
+        For a 2D matrix of shape (n, m), the analytical solution provided
+        in the paper is rate * x * sqrt(max(n, m)).
+        """
+        if self.rms_rate is None:
+            return x
+        # moonlight version
+        # https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
+        return x * ops.sqrt(ops.maximum(x.shape[0], x.shape[1])) * self.rms_rate

     def zeropower_via_newtonschulz5(self, x, steps: int):
         """We apply the Newton-Schulz iteration to compute matrix G.

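A quick numeric illustration of the scale `lr_adjust` introduces (shape chosen arbitrarily, not from the PR):

```python
# Per the Moonlight trick, 0.2 * sqrt(max(n, m)) matches the update RMS to
# AdamW's, so the same learning rate and weight decay can be reused.
n, m = 1024, 4096
print(0.2 * max(n, m) ** 0.5)  # 12.8
```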
@@ -268,6 +289,20 @@ def zeropower_via_newtonschulz5(self, x, steps: int):
         x = self.transpose_last_axis(x)
         return x

+    def _apply_weight_decay(self, variables):
+        for variable in variables:
+            if self._use_weight_decay(variable):
+                if self._should_use_adamw(variable):
+                    if self.adam_weight_decay is None:
+                        continue
+                    wd = ops.cast(self.adam_weight_decay, variable.dtype)
+                else:
+                    if self.weight_decay is None:
+                        continue
+                    wd = ops.cast(self.weight_decay, variable.dtype)
+                lr = ops.cast(self.learning_rate, variable.dtype)
+                variable.assign(variable - variable * wd * lr)

     def get_config(self):
         config = super().get_config()
         config.update(

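The added `_apply_weight_decay` keeps decay decoupled from the gradient step and only switches the coefficient based on the AdamW/Muon routing; a minimal numeric sketch of the shrink it applies (values illustrative):

```python
# Decoupled decay as in the hunk above: each eligible variable is shrunk by
# lr * wd independently of its gradient update.
def decayed(value, lr, wd):
    return value - value * wd * lr

print(decayed(1.0, lr=1e-3, wd=0.004))  # 0.999996
```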
@@ -284,6 +319,8 @@ def get_config(self):
                 "ns_steps": self.ns_steps,
                 "nesterov": self.nesterov,
                 "exclude_embeddings": self.exclude_embeddings,
+                "adam_weight_decay": self.adam_weight_decay,
+                "rms_rate": self.rms_rate,
             }
         )
         return config
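Finally, a hedged round-trip check, assuming the constructor and attributes shown in this diff, that the two new entries serialize and restore:

```python
import keras

# Config round trip; attribute names follow the __init__ assignments above.
opt = keras.optimizers.Muon(adam_weight_decay=0.004, rms_rate=0.2)
cfg = opt.get_config()
restored = keras.optimizers.Muon.from_config(cfg)
assert restored.adam_weight_decay == 0.004
assert restored.rms_rate == 0.2
```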