
Conversation

@vx120 commented Nov 19, 2025

PR type

  • New Feature

PR information

I added a MuonClip optimizer, an enhanced version of the Muon optimizer that incorporates QK-Clip to mitigate max-logit explosion and improve training stability.
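
For context, here is a minimal, illustrative sketch of the QK-Clip rescaling step (the attribute names follow the Qwen2-style branch quoted in the review below; the function name and the tau default are assumptions, not part of this PR):

import math

import torch
import torch.nn as nn


def qk_clip_sketch(attention_layer: nn.Module, tau: float = 100.0) -> None:
    """Illustrative sketch only; the helper name and tau default are hypothetical.

    If the largest attention logit tracked for this layer exceeds tau, both the
    query and key projection weights are scaled by sqrt(tau / max_logit), so the
    q.k products (and hence the logits) shrink by a factor of tau / max_logit.
    """
    max_logit = getattr(attention_layer, 'max_logits', 0.0)
    if max_logit > tau:
        scale = math.sqrt(tau / max_logit)
        with torch.no_grad():
            attention_layer.q_proj.weight.mul_(scale)
            attention_layer.k_proj.weight.mul_(scale)
        # Reset the tracked statistic so the next interval is measured afresh.
        attention_layer.max_logits = 0.0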

vx120 and others added 30 commits November 13, 2025 14:07
…ner (modelscope#6638)

The get_chord_sft_dataloader() method relies on GRPOTrainer.accelerator, but it was previously called before the parent class (super().__init__) had finished initializing the accelerator. As a result, get_chord_sft_dataloader() raised an exception about the non-existent attribute GRPOTrainer.accelerator.
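
A rough sketch of the corrected call order (the subclass name and the stored attribute are hypothetical, and the import path is assumed; only the ordering relative to super().__init__() matters):

from swift.trainers import GRPOTrainer  # import path assumed for illustration


class ChordGRPOTrainer(GRPOTrainer):  # hypothetical subclass for illustration
    def __init__(self, *args, **kwargs):
        # Let the parent finish initializing first, so that self.accelerator exists.
        super().__init__(*args, **kwargs)
        # Only now is it safe to build the CHORD SFT dataloader, which reads
        # self.accelerator internally (method provided by the trainer, per the
        # commit description above).
        self.chord_sft_dataloader = self.get_chord_sft_dataloader()
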
@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @vx120, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a novel MuonClip optimizer designed to enhance the stability and performance of model training. It combines the principles of the Muon optimizer with QK-Clip functionality, specifically targeting the prevention of max-logit explosion in attention mechanisms. The changes also address, and add warnings for, deepcopy issues that can arise with weight_norm layers, ensuring a more robust optimization process. The new optimizer is fully integrated into the existing swift framework, making it readily available for use.
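
As a rough illustration of the weight_norm/deepcopy interaction mentioned above (the helper name and warning text are hypothetical; the old-style torch.nn.utils.weight_norm registers weight_g/weight_v parameters, and deep-copying such modules can raise a RuntimeError):

import copy

import torch.nn as nn


def warn_if_weight_norm_blocks_deepcopy(model: nn.Module) -> bool:
    """Hypothetical helper: return True if the model can be deep-copied safely."""
    uses_weight_norm = any(
        name.endswith(('weight_g', 'weight_v')) for name, _ in model.named_parameters()
    )
    if not uses_weight_norm:
        return True
    try:
        copy.deepcopy(model)
        return True
    except RuntimeError as err:
        print(f'Warning: weight_norm layers prevent deepcopy of the model: {err}')
        return False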

Highlights

  • New Optimizer: Introduced a new MuonClip optimizer, which is an enhanced version of the Muon optimizer.
  • QK-Clip Integration: The MuonClip optimizer incorporates QK-Clip to prevent max-logit explosion and improve training stability in attention layers.
  • Weight Norm Handling: The implementation includes fixes and warnings for potential deepcopy issues related to weight_norm layers in the model.
  • Framework Integration: The MuonClip optimizer is integrated into the swift framework's optimizer creation factory and its UI for easy selection and configuration.

@gemini-code-assist bot left a comment

Code Review

This pull request introduces the MuonClip optimizer, an extension of the Muon optimizer with QK-clipping for improved training stability. The changes include the core optimizer implementation, integration into the optimizer factory, and UI components for configuration.

My review has identified a few critical issues, including a potential TypeError due to incorrect data type handling in both the optimizer creation logic and the core numerical iteration function. I've also found a high-severity UI bug where a tab is duplicated. Additionally, there are opportunities for code refactoring to improve maintainability, removal of dead code, and standardizing comment language for consistency.

Overall, the feature is a valuable addition, but the identified issues should be addressed to ensure correctness and robustness.

Comment on lines +43 to +63
def zeropower_via_newtonschulz5(G, steps):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G.
    Optimized version with torch.compile.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(0) > G.size(1):
        X = X.T
    # Ensure spectral norm is at most 1
    X = X / (X.norm() + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X

    if G.size(0) > G.size(1):
        X = X.T
    return X

critical

The function zeropower_via_newtonschulz5 hard-converts the input tensor G to bfloat16 and returns a bfloat16 tensor. This can cause a TypeError in _apply_muon_update if the model's parameter dtype does not match, as p.data.add_ requires matching dtypes. For numerical stability and correctness, it's better to perform calculations in float32 and then cast back to the original dtype, similar to the newton_schulz function. This also provides an opportunity to avoid checking the matrix dimensions twice.

def zeropower_via_newtonschulz5(G, steps):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G.
    Optimized version with torch.compile.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.float()
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    # Ensure spectral norm is at most 1
    X = X / (X.norm() + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X

    if transposed:
        X = X.T
    return X.to(G.dtype)

optim_args[key] = value

# Set default values for MuonClip parameters
lr = optim_args.get('lr', args.learning_rate)

critical

The learning rate lr is retrieved from optim_args as a string but not converted to a float. This will likely cause a TypeError when it's passed to the MuonClip optimizer, which expects a float. All other parameters from optim_args are correctly converted.

Suggested change
lr = optim_args.get('lr', args.learning_rate)
lr = float(optim_args.get('lr', args.learning_rate))

Comment on lines +158 to +161
with gr.TabItem(elem_id='multimodal_tab'):
    with gr.Row():
        gr.Textbox(elem_id='vit_lr', lines=1, scale=20)
        gr.Textbox(elem_id='aligner_lr', lines=1, scale=20)

high

The multimodal_tab is defined twice in the do_build_ui method. This will create a duplicate tab in the user interface, which is likely unintended and can cause confusion. The second instance should be removed.

Comment on lines +11 to +39
def newton_schulz(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """
    Newton-Schulz iteration for matrix orthogonalization.
    """
    # Coefficients from Muon paper
    a, b, c = (3.4445, -4.7750, 2.0315)

    # Convert to float for precision
    X = G.float()
    X /= (X.norm() + eps)

    # Handle rectangular matrices by transposing
    if G.size(0) > G.size(1):
        X = X.T
        transposed = True
    else:
        transposed = False

    # Newton-Schulz iterations
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X

    # Transpose back if needed
    if transposed:
        X = X.T

    return X.to(G.dtype)

medium

The function newton_schulz is defined but appears to be unused. The optimizer uses zeropower_via_newtonschulz5 instead. If this function is not needed, it should be removed to avoid dead code and improve maintainability.

# Store parameter names for classification
self.param_names = {}

# 修复:避免在初始化时立即分类参数,等待set_model调用

medium

The comments in this file are a mix of English and Chinese. For consistency and to make the code accessible to a wider audience, it's best to use English for all comments and docstrings. For example:

  • line 114: # 修复:避免在初始化时立即分类参数,等待set_model调用 should be # Fix: Avoid classifying parameters immediately on initialization, wait for set_model call.
  • line 128: # 修复:使用更安全的方式获取参数名称 should be # Fix: Use a safer way to get parameter name.
  • line 139: # 特别处理weight_norm相关的参数 should be # Special handling for weight_norm related parameters.
  • line 159: # 处理可能的deepcopy错误 should be # Handle possible deepcopy errors.
  • line 172: # 修复:先移除可能的weight_norm,然后再进行参数操作 should be # Fix: First handle possible weight_norm, then perform parameter operations.
  • line 186: """处理weight_norm相关的深度拷贝问题""" should be """Handle deepcopy issues related to weight_norm.""".
  • line 190: # 检查模型中是否使用了weight_norm should be # Check if weight_norm is used in the model.
  • line 356: # 确保参数已经分类 should be # Ensure parameters are classified.

Comment on lines +314 to +344
            # For Qwen2-style attention
            if hasattr(attention_layer, 'q_proj') and hasattr(attention_layer, 'k_proj'):
                max_logits = getattr(attention_layer, 'max_logits', 0.0)

                if max_logits > tau:
                    gamma = tau / max_logits
                    sqrt_gamma = math.sqrt(gamma)

                    # Apply scaling to query and key projection weights
                    with torch.no_grad():
                        attention_layer.q_proj.weight.data *= sqrt_gamma
                        attention_layer.k_proj.weight.data *= sqrt_gamma

                    # Reset max_logits
                    if hasattr(attention_layer, 'max_logits'):
                        attention_layer.max_logits = 0.0

            # For standard attention
            elif hasattr(attention_layer, 'query') and hasattr(attention_layer, 'key'):
                max_logits = getattr(attention_layer, 'max_logits', 0.0)

                if max_logits > tau:
                    gamma = tau / max_logits
                    sqrt_gamma = math.sqrt(gamma)

                    with torch.no_grad():
                        attention_layer.query.weight.data *= sqrt_gamma
                        attention_layer.key.weight.data *= sqrt_gamma

                    if hasattr(attention_layer, 'max_logits'):
                        attention_layer.max_logits = 0.0

medium

The logic for applying QK-Clip is duplicated for Qwen2-style attention layers (q_proj, k_proj) and standard attention layers (query, key). This can be refactored to reduce code duplication and improve maintainability.

            q_proj, k_proj = None, None
            # For Qwen2-style attention
            if hasattr(attention_layer, 'q_proj') and hasattr(attention_layer, 'k_proj'):
                q_proj = attention_layer.q_proj
                k_proj = attention_layer.k_proj
            # For standard attention
            elif hasattr(attention_layer, 'query') and hasattr(attention_layer, 'key'):
                q_proj = attention_layer.query
                k_proj = attention_layer.key

            if q_proj and k_proj:
                max_logits = getattr(attention_layer, 'max_logits', 0.0)
                
                if max_logits > tau:
                    gamma = tau / max_logits
                    sqrt_gamma = math.sqrt(gamma)
                    
                    # Apply scaling to query and key projection weights
                    with torch.no_grad():
                        q_proj.weight.data *= sqrt_gamma
                        k_proj.weight.data *= sqrt_gamma
                    
                    # Reset max_logits
                    if hasattr(attention_layer, 'max_logits'):
                        attention_layer.max_logits = 0.0

@Jintao-Huang (Collaborator) commented

Hello, please change the Chinese comments to English comments. 😊
