Feature muon #3925
Conversation
```rust
/// hidden layers (weight matrices). Other parameters such as biases and embeddings
/// should be optimized using a standard method such as AdamW.
```
Are those parameters ignored during training if you use only a single optimizer?
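For illustration only, here is a minimal routing sketch of the rule the doc comment describes: 2D hidden-layer weight matrices go to Muon, everything else goes to AdamW. This is hypothetical plain Rust (the names `OptimizerChoice` and `route_param` are invented), not burn API, and it deliberately excludes embeddings even though they are 2D, per the doc comment.

```rust
// Hypothetical sketch only: burn currently has no built-in way to split one
// module's parameters across optimizers, so this just illustrates the rule.
#[derive(Debug, PartialEq)]
enum OptimizerChoice {
    Muon,  // 2D hidden-layer weight matrices
    AdamW, // biases, embeddings, and other non-matrix parameters
}

fn route_param(name: &str, rank: usize) -> OptimizerChoice {
    // Embeddings are 2D but still excluded, per the doc comment above.
    if rank == 2 && !name.contains("embedding") {
        OptimizerChoice::Muon
    } else {
        OptimizerChoice::AdamW
    }
}

fn main() {
    assert_eq!(route_param("linear.weight", 2), OptimizerChoice::Muon);
    assert_eq!(route_param("linear.bias", 1), OptimizerChoice::AdamW);
    assert_eq!(route_param("token_embedding", 2), OptimizerChoice::AdamW);
}
```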
```rust
/// - Original: https://github.com/KellerJordan/Muon/blob/master/muon.py
/// - PyTorch: https://github.com/pytorch/pytorch/blob/main/torch/optim/muon.py
fn zeropower_via_newtonschulz<const D: usize>(&self, g: Tensor<B, D>) -> Tensor<B, D> {
    assert!(
```
Unsure if this should be the default behavior. There isn't a great way yet to define multiple optimizers for a single burn module (e.g. a linear layer with a bias vector). Do you have an idea, @laggui?
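For context, here is a minimal sketch of the quintic Newton–Schulz orthogonalization that `zeropower_via_newtonschulz` refers to, using the coefficients from Keller Jordan's reference muon.py linked above. It is written against ndarray rather than burn's Tensor API, so treat it as an illustration of the algorithm, not of this PR's code.

```rust
use ndarray::Array2;

/// Quintic Newton–Schulz iteration approximating the nearest (semi-)orthogonal
/// factor of `g`. Per the reference muon.py, these coefficients push singular
/// values into a loose band around 1 rather than exactly to 1.
fn zeropower_via_newtonschulz(g: &Array2<f32>, steps: usize) -> Array2<f32> {
    let (a, b, c) = (3.4445f32, -4.7750f32, 2.0315f32);
    let (rows, cols) = g.dim();
    // Work on the wide orientation so X * X^T is the smaller Gram matrix.
    let mut x = if rows > cols { g.t().to_owned() } else { g.clone() };
    // Normalize so the spectral norm is at most 1, as the iteration requires.
    let norm = x.mapv(|v| v * v).sum().sqrt() + 1e-7;
    x /= norm;
    for _ in 0..steps {
        let gram = x.dot(&x.t());                   // A = X X^T
        let poly = gram.dot(&gram) * c + &gram * b; // B = c*A^2 + b*A
        x = poly.dot(&x) + &x * a;                  // X = a*X + B X
    }
    // Undo the transpose so the output shape matches the input shape.
    if rows > cols { x.t().to_owned() } else { x }
}

fn main() {
    // Rough sanity check: for a well-conditioned square input, X X^T should
    // land near the identity. The tolerance is generous because the quintic
    // coefficients trade exact convergence for speed.
    let g = Array2::from_shape_fn((4, 4), |(i, j)| {
        if i == j { 1.0 } else { 0.1 * ((i + 2 * j) as f32).cos() }
    });
    let x = zeropower_via_newtonschulz(&g, 5);
    let gram = x.dot(&x.t());
    let eye = Array2::<f32>::eye(4);
    let max_err = (&gram - &eye).mapv(f32::abs).fold(0.0f32, |m, &v| m.max(v));
    assert!(max_err < 0.75, "max deviation from identity: {max_err}");
}
```

The transpose trick keeps the Gram matrix at the smaller of the two possible sizes, which is also what the reference implementation does for tall matrices.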
Add muon optimizer to burn-optim
What's new?
Notes:
- The Newton–Schulz iteration runs in bfloat16, matching the PyTorch implementation.
- Tests run in f32; bfloat16 tests could be added soon.
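The test summary below exercises two learning-rate adjustment rules. As a point of reference, here is a minimal sketch of what they plausibly compute, assuming the Original rule follows Keller Jordan's muon.py (scale by sqrt(max(1, rows/cols))) and MatchRmsAdamW follows the Moonlight report (scale by 0.2 * sqrt(max(rows, cols))); both formulas are assumptions, not read from this PR's code.

```rust
/// Assumed Original adjustment (per the reference muon.py): scale the step
/// by sqrt(max(1, rows / cols)) so tall matrices get a larger update.
fn adjust_lr_original(lr: f64, rows: usize, cols: usize) -> f64 {
    lr * (rows as f64 / cols as f64).max(1.0).sqrt()
}

/// Assumed MatchRmsAdamW adjustment (per the Moonlight report): scale by
/// 0.2 * sqrt(max(rows, cols)) so the update RMS roughly matches AdamW's.
fn adjust_lr_match_rms_adamw(lr: f64, rows: usize, cols: usize) -> f64 {
    lr * 0.2 * (rows.max(cols) as f64).sqrt()
}

fn main() {
    // Square matrix: the Original rule leaves the rate unchanged.
    assert!((adjust_lr_original(0.02, 128, 128) - 0.02).abs() < 1e-12);
    // Tall matrix (256 x 64): the Original rule scales by sqrt(4) = 2.
    assert!((adjust_lr_original(0.02, 256, 64) - 0.04).abs() < 1e-12);
    // MatchRmsAdamW for 256 x 64: 0.02 * 0.2 * sqrt(256) = 0.064.
    assert!((adjust_lr_match_rms_adamw(0.02, 256, 64) - 0.064).abs() < 1e-12);
}
```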
Test summary

- test_adjust_lr_fn_original - Verifies the Original learning-rate adjustment ratios for square, tall, and wide matrices.
- test_adjust_lr_fn_match_rms_adamw - Verifies the MatchRmsAdamW learning-rate adjustment ratios for example shapes.
- test_1d_tensor_panics - Ensures Newton–Schulz orthogonalization panics for 1D tensors (2D is required).
- test_muon_optimizer_save_load_state - Verifies optimizer state can be saved and loaded for a Linear layer without bias.
- test_muon_with_weight_decay - Ensures weight decay is applied (weights are reduced) for a Linear layer without bias.
- test_newton_schulz_orthogonalization - Checks that Newton–Schulz produces approximately orthogonal output (A * A^T ≈ I).
- test_tall_matrix_transpose - Ensures tall matrices are transposed internally and shape is preserved; verifies orthogonalization changes values and wide-matrix behavior.
- test_zero_gradient - Confirms Muon handles zero gradients without NaNs, creates state, and weight decay still reduces values when gradients are zero.

Related issue, readings, etc.