add muon clip optimizer #6662
Conversation
…ner (modelscope#6638) The get_chord_sft_dataloader() method relies on GRPOTrainer.accelerator, but it was previously called before the parent class initializer (super().__init__) had finished setting up the accelerator. As a result, get_chord_sft_dataloader would raise an AttributeError for the not-yet-existing attribute GRPOTrainer.accelerator.
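The underlying pitfall is generic Python initialization order, independent of this repository; a minimal sketch with hypothetical class names (not the actual GRPOTrainer code):

```python
class Base:
    def __init__(self):
        # attribute that only exists after the parent initializer runs
        self.accelerator = "accelerator"

class BrokenChild(Base):
    def __init__(self):
        # BUG: uses self.accelerator before Base.__init__ has run
        self.dataloader = self.make_dataloader()
        super().__init__()

    def make_dataloader(self):
        return f"dataloader on {self.accelerator}"

class FixedChild(Base):
    def __init__(self):
        super().__init__()                       # parent sets self.accelerator first
        self.dataloader = self.make_dataloader() # now safe to use it

    def make_dataloader(self):
        return f"dataloader on {self.accelerator}"

try:
    BrokenChild()
except AttributeError as e:
    print("broken:", e)   # the attribute does not exist yet

print("fixed:", FixedChild().dataloader)
```

Moving the dependent call after `super().__init__()`, as this fix does, resolves the AttributeError.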
Summary of Changes

Hello @vx120, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.
Code Review
This pull request introduces the MuonClip optimizer, an extension of the Muon optimizer with QK-clipping for improved training stability. The changes include the core optimizer implementation, integration into the optimizer factory, and UI components for configuration.
My review has identified a few critical issues, including a potential TypeError due to incorrect data type handling in both the optimizer creation logic and the core numerical iteration function. I've also found a high-severity UI bug where a tab is duplicated. Additionally, there are opportunities for code refactoring to improve maintainability, removal of dead code, and standardizing comment language for consistency.
Overall, the feature is a valuable addition, but the identified issues should be addressed to ensure correctness and robustness.
```python
def zeropower_via_newtonschulz5(G, steps):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G.
    Optimized version with torch.compile.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(0) > G.size(1):
        X = X.T
    # Ensure spectral norm is at most 1
    X = X / (X.norm() + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X
```
The function zeropower_via_newtonschulz5 hard-converts the input tensor G to bfloat16 and returns a bfloat16 tensor. This can cause a TypeError in _apply_muon_update if the model's parameter dtype does not match, as p.data.add_ requires matching dtypes. For numerical stability and correctness, it's better to perform calculations in float32 and then cast back to the original dtype, similar to the newton_schulz function. This also provides an opportunity to avoid checking the matrix dimensions twice.
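The mismatch is easy to reproduce in isolation; a small sketch with hypothetical tensors (the exact exception type can vary across PyTorch versions, hence the broad catch):

```python
import torch

p16 = torch.randn(8, 8, dtype=torch.float16)   # e.g. a half-precision parameter
update_bf16 = torch.randn(8, 8).bfloat16()     # what the current function returns

try:
    # mixed half-precision dtypes promote to float32, which cannot be
    # written back into a float16 tensor in place
    p16.add_(update_bf16)
except (RuntimeError, TypeError) as e:
    print("in-place add failed:", e)

# casting the update back to the parameter dtype avoids the issue
p16.add_(update_bf16.to(p16.dtype))
print("dtype preserved:", p16.dtype)
```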
```python
def zeropower_via_newtonschulz5(G, steps):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G.
    Optimized version with torch.compile.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.float()
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    # Ensure spectral norm is at most 1
    X = X / (X.norm() + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if transposed:
        X = X.T
    return X.to(G.dtype)
```

```python
optim_args[key] = value

# Set default values for MuonClip parameters
lr = optim_args.get('lr', args.learning_rate)
```
The learning rate lr is retrieved from optim_args as a string but not converted to a float. This will likely cause a TypeError when it's passed to the MuonClip optimizer, which expects a float. All other parameters from optim_args are correctly converted.
```diff
- lr = optim_args.get('lr', args.learning_rate)
+ lr = float(optim_args.get('lr', args.learning_rate))
```
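More generally, values arriving through optim_args are strings, so every numeric option needs an explicit cast; a hypothetical helper sketching one way to do this uniformly (not the PR's actual code):

```python
def parse_optim_args(raw: dict, defaults: dict) -> dict:
    """Cast string-valued optimizer options to the type of their default."""
    out = dict(defaults)
    for key, value in raw.items():
        if key in defaults and isinstance(value, str):
            # e.g. '1e-3' -> float(1e-3), matching the default's type
            out[key] = type(defaults[key])(value)
        else:
            out[key] = value
    return out

opts = parse_optim_args({'lr': '1e-3', 'momentum': '0.95'},
                        {'lr': 1e-4, 'momentum': 0.9, 'weight_decay': 0.0})
# opts['lr'] is now a float, so the optimizer's arithmetic won't break
```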
```python
with gr.TabItem(elem_id='multimodal_tab'):
    with gr.Row():
        gr.Textbox(elem_id='vit_lr', lines=1, scale=20)
        gr.Textbox(elem_id='aligner_lr', lines=1, scale=20)
```
```python
def newton_schulz(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """
    Newton-Schulz iteration for matrix orthogonalization.
    """
    # Coefficients from Muon paper
    a, b, c = (3.4445, -4.7750, 2.0315)

    # Convert to float for precision
    X = G.float()
    X /= (X.norm() + eps)

    # Handle rectangular matrices by transposing
    if G.size(0) > G.size(1):
        X = X.T
        transposed = True
    else:
        transposed = False

    # Newton-Schulz iterations
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X

    # Transpose back if needed
    if transposed:
        X = X.T

    return X.to(G.dtype)
```
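For intuition on the unusual coefficients: the quintic p(x) = a·x + b·x³ + c·x⁵ acts on each singular value independently, and a few iterations push values from anywhere in (0, 1] into a band around 1 (an illustration added here for context, not part of the PR):

```python
def p(x: float) -> float:
    # Muon's Newton-Schulz quintic, applied to a single singular value
    a, b, c = (3.4445, -4.7750, 2.0315)
    return a * x + b * x**3 + c * x**5

for sigma in (0.01, 0.05, 0.3, 1.0):
    x = sigma
    for _ in range(5):
        x = p(x)
    print(f"sigma={sigma:<5} -> {x:.3f}")
# small singular values are inflated quickly (factor ~3.44 per step near 0),
# and all of them settle into a rough 0.4-1.3 band around 1
```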
```python
# Store parameter names for classification
self.param_names = {}

# 修复:避免在初始化时立即分类参数,等待set_model调用
```
The comments in this file are a mix of English and Chinese. For consistency and to make the code accessible to a wider audience, it's best to use English for all comments and docstrings. For example:

- line 114: `# 修复:避免在初始化时立即分类参数,等待set_model调用` should be `# Fix: Avoid classifying parameters immediately on initialization, wait for set_model call.`
- line 128: `# 修复:使用更安全的方式获取参数名称` should be `# Fix: Use a safer way to get parameter names.`
- line 139: `# 特别处理weight_norm相关的参数` should be `# Special handling for weight_norm-related parameters.`
- line 159: `# 处理可能的deepcopy错误` should be `# Handle possible deepcopy errors.`
- line 172: `# 修复:先移除可能的weight_norm,然后再进行参数操作` should be `# Fix: First remove any weight_norm, then perform parameter operations.`
- line 186: `"""处理weight_norm相关的深度拷贝问题"""` should be `"""Handle deepcopy issues related to weight_norm."""`
- line 190: `# 检查模型中是否使用了weight_norm` should be `# Check whether weight_norm is used in the model.`
- line 356: `# 确保参数已经分类` should be `# Ensure parameters are classified.`
```python
# For Qwen2-style attention
if hasattr(attention_layer, 'q_proj') and hasattr(attention_layer, 'k_proj'):
    max_logits = getattr(attention_layer, 'max_logits', 0.0)

    if max_logits > tau:
        gamma = tau / max_logits
        sqrt_gamma = math.sqrt(gamma)

        # Apply scaling to query and key projection weights
        with torch.no_grad():
            attention_layer.q_proj.weight.data *= sqrt_gamma
            attention_layer.k_proj.weight.data *= sqrt_gamma

        # Reset max_logits
        if hasattr(attention_layer, 'max_logits'):
            attention_layer.max_logits = 0.0

# For standard attention
elif hasattr(attention_layer, 'query') and hasattr(attention_layer, 'key'):
    max_logits = getattr(attention_layer, 'max_logits', 0.0)

    if max_logits > tau:
        gamma = tau / max_logits
        sqrt_gamma = math.sqrt(gamma)

        with torch.no_grad():
            attention_layer.query.weight.data *= sqrt_gamma
            attention_layer.key.weight.data *= sqrt_gamma

        if hasattr(attention_layer, 'max_logits'):
            attention_layer.max_logits = 0.0
```
The logic for applying QK-Clip is duplicated for Qwen2-style attention layers (q_proj, k_proj) and standard attention layers (query, key). This can be refactored to reduce code duplication and improve maintainability.
```python
q_proj, k_proj = None, None
# For Qwen2-style attention
if hasattr(attention_layer, 'q_proj') and hasattr(attention_layer, 'k_proj'):
    q_proj = attention_layer.q_proj
    k_proj = attention_layer.k_proj
# For standard attention
elif hasattr(attention_layer, 'query') and hasattr(attention_layer, 'key'):
    q_proj = attention_layer.query
    k_proj = attention_layer.key

if q_proj and k_proj:
    max_logits = getattr(attention_layer, 'max_logits', 0.0)
    if max_logits > tau:
        gamma = tau / max_logits
        sqrt_gamma = math.sqrt(gamma)
        # Apply scaling to query and key projection weights
        with torch.no_grad():
            q_proj.weight.data *= sqrt_gamma
            k_proj.weight.data *= sqrt_gamma
        # Reset max_logits
        if hasattr(attention_layer, 'max_logits'):
            attention_layer.max_logits = 0.0
```
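To check that the consolidated QK-Clip path behaves as intended, here is a self-contained sketch with a dummy attention module (hypothetical names, not the repository's actual classes): scaling both projections by sqrt(gamma) shrinks the Q·Kᵀ logits by exactly gamma = tau / max_logits.

```python
import math
import torch
import torch.nn as nn

class DummyAttention(nn.Module):
    """Minimal stand-in for a Qwen2-style attention layer."""
    def __init__(self, dim: int = 8):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.max_logits = 0.0

def qk_clip(layer: nn.Module, tau: float) -> None:
    """Scale q/k projections by sqrt(tau / max_logits) when logits exceed tau."""
    max_logits = getattr(layer, 'max_logits', 0.0)
    if max_logits > tau:
        sqrt_gamma = math.sqrt(tau / max_logits)
        with torch.no_grad():
            layer.q_proj.weight.data *= sqrt_gamma
            layer.k_proj.weight.data *= sqrt_gamma
        layer.max_logits = 0.0

torch.manual_seed(0)
attn = DummyAttention()
x = torch.randn(4, 8)
logits_before = attn.q_proj(x) @ attn.k_proj(x).T
attn.max_logits = logits_before.abs().max().item()

tau = attn.max_logits / 2          # force clipping with gamma = 0.5
qk_clip(attn, tau)
logits_after = attn.q_proj(x) @ attn.k_proj(x).T
# both projections scaled by sqrt(gamma), so logits shrink by exactly gamma
assert torch.allclose(logits_after, 0.5 * logits_before, atol=1e-5)
```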
Hello, please change the Chinese comments to English comments. 😊
PR type

PR information

I added a MuonClip optimizer, an enhanced version of the Muon optimizer that incorporates QK-clip to mitigate max-logit explosion and improve training stability.