
Commit deee214

Adding support for the Muon Optimizer (#1914)
* initial commit with working optimizer
* update ACKNOWLEDGMENTS.md
* nits and adding it to test
* nits
* G.astype(mx.bfloat16) to G.astype(G.dtype)
* G.ndim >= 2 to assert G.ndim == 2
* remove comments
* replace with mx.addmm
* remove comments
* format
* nits
* match muon
* fix addmm

---------

Co-authored-by: Awni Hannun <[email protected]>
1 parent 45adec1 commit deee214

6 files changed: +184, -7 lines changed


ACKNOWLEDGMENTS.md

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 - Paul Paczuski: Improved stability of BCE loss calculation
 - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
+- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.

 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
   <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

docs/src/python/optimizers/common_optimizers.rst

Lines changed: 1 addition & 0 deletions
@@ -19,3 +19,4 @@ Common Optimizers
    Adamax
    Lion
    MultiOptimizer
+   Muon

mlx/backend/cuda/matmul.cpp

Lines changed: 20 additions & 7 deletions
@@ -119,7 +119,6 @@ class MatMul {
       uint64_t b_rows,
       uint64_t b_cols,
       int64_t ldb,
-      bool c_transposed,
       int64_t ldc,
       int32_t batch_count,
       int64_t a_batch_stride,
@@ -141,7 +140,7 @@ class MatMul {
           b_batch_stride) {
     auto type = dtype_to_cuda_type(dtype);
     c_desc_ = create_matrix_layout(
-        type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
+        type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
   }

   ~MatMul() {
@@ -403,9 +402,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 3);
   auto& a_pre = inputs[0];
   auto& b_pre = inputs[1];
-  auto& c_pre = inputs[2];
-
-  out.set_data(allocator::malloc(out.nbytes()));
+  auto c = inputs[2];

   /////////////////////////////////////////////////////////////////////////////
   // Init checks and prep
@@ -418,7 +415,24 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // the arrays
   auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
   auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
-  auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
+
+  int64_t ldc;
+  {
+    auto stx = c.strides()[c.ndim() - 2];
+    auto sty = c.strides()[c.ndim() - 1];
+    if (sty == 1 && stx == c.shape(-1)) {
+      ldc = stx;
+      out.set_data(allocator::malloc(out.nbytes()));
+    } else if (sty == 1 && stx == 0) {
+      ldc = 0;
+      out.set_data(allocator::malloc(out.nbytes()));
+    } else {
+      // Copy C into out and set C to out
+      ldc = c.shape(-1);
+      copy_gpu(c, out, CopyType::General, s);
+      c = out;
+    }
+  }

   /////////////////////////////////////////////////////////////////////////////
   // Check and collapse batch dimensions
@@ -456,7 +470,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       K,
       N,
       ldb,
-      c_transposed,
       ldc,
       batch_shape.back(),
       a_batch_strides.back(),
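For reference, `mx.addmm(c, a, b, alpha, beta)` computes `beta * c + alpha * (a @ b)`. The three branches added above correspond to the layouts `c` can arrive in: row-contiguous (used directly as the accumulation buffer), broadcast along rows (row stride 0, so `ldc = 0`), and anything else (copied into `out` first). Below is a minimal sketch of those cases from the Python side; the shapes are illustrative and not taken from the commit.

import mlx.core as mx

a = mx.random.normal((4, 6))
b = mx.random.normal((6, 3))

# Row-contiguous c: usable directly as the accumulation buffer.
c = mx.zeros((4, 3))
assert mx.allclose(mx.addmm(c, a, b, alpha=0.5, beta=1.5),
                   1.5 * c + 0.5 * (a @ b))

# Broadcast c (a single row): its row stride is 0 after broadcasting.
c_row = mx.ones((1, 3))
assert mx.allclose(mx.addmm(c_row, a, b, alpha=0.5, beta=1.5),
                   1.5 * c_row + 0.5 * (a @ b))

# Transposed c: not row-contiguous, so this path copies it into out first.
c_t = mx.random.normal((3, 4)).T
assert mx.allclose(mx.addmm(c_t, a, b, alpha=0.5, beta=1.5),
                   1.5 * c_t + 0.5 * (a @ b))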

python/mlx/optimizers/optimizers.py

Lines changed: 100 additions & 0 deletions
@@ -848,6 +848,106 @@ def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         return parameter - update


+class Muon(Optimizer):
+    r"""The Muon optimizer.
+
+    Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
+    original implementation: `Muon: An optimizer for hidden layers in neural
+    networks <https://kellerjordan.github.io/posts/muon/>`_
+
+    Note:
+        - Muon may be sub-optimal for the embedding layer, the final fully
+          connected layer, or any 0D/1D parameters. Those should be optimized
+          by a different method (e.g., :class:`AdamW`).
+        - For 4D convolutional filters, it works by flattening their last
+          dimensions.
+
+    Args:
+        learning_rate (float or callable): The learning rate.
+        momentum (float, optional): The momentum strength. Default: ``0.95``
+        weight_decay (float, optional): The weight decay (L2 penalty).
+            Default: ``0.01``
+        nesterov (bool, optional): Enables Nesterov momentum. Recommended for
+            better performance. Default: ``True``
+        ns_steps (int, optional): Number of Newton-Schulz iteration steps for
+            orthogonalization. Default: ``5``
+    """
+
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        momentum: float = 0.95,
+        weight_decay: float = 0.01,
+        nesterov: bool = True,
+        ns_steps: int = 5,
+    ):
+        super().__init__()
+
+        self._maybe_schedule("learning_rate", learning_rate)
+        self.momentum = momentum
+        self.weight_decay = weight_decay
+        self.nesterov = nesterov
+        self.ns_steps = ns_steps
+
+    def init_single(self, parameter: mx.array, state: dict):
+        """Initialize optimizer state"""
+        state["v"] = mx.zeros_like(parameter)
+
+    def _zeropower_via_newtonschulz5(self, X, steps: int):
+        assert (
+            X.ndim == 2
+        ), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
+        a, b, c = (3.4445, -4.7750, 2.0315)
+        transpose_needed = X.shape[-2] > X.shape[-1]
+
+        if transpose_needed:
+            X = X.T
+
+        X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)
+
+        for _ in range(steps):
+            A = X @ X.T
+            B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
+            X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
+
+        if transpose_needed:
+            X = X.T
+        return X
+
+    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+        """Performs the Muon parameter update"""
+
+        if self.weight_decay != 0:
+            gradient = gradient + self.weight_decay * parameter
+
+        v = self.momentum * state["v"]
+        v = v + (1 - self.momentum) * gradient
+        state["v"] = v
+
+        if self.nesterov:
+            update = gradient * (1 - self.momentum) + v * self.momentum
+        else:
+            update = v
+
+        lr = self.learning_rate.astype(gradient.dtype)
+
+        if update.ndim >= 2:
+            original_shape = update.shape
+            reshape_needed = update.ndim > 2
+
+            if reshape_needed:
+                update = mx.reshape(update, (update.shape[0], -1))
+
+            update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
+
+            if reshape_needed:
+                update = mx.reshape(update, original_shape)
+
+            lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
+
+        return parameter - lr * update
+
+
 def clip_grad_norm(grads, max_norm):
     """Clips the global norm of the gradients.

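The `_zeropower_via_newtonschulz5` loop above is the quintic Newton-Schulz iteration from the Muon write-up: after normalizing by the Frobenius norm, the coefficients (3.4445, -4.7750, 2.0315) repeatedly push the singular values of `X` toward 1, so the result approximately orthogonalizes the momentum buffer without an SVD. Below is a small standalone sketch (not part of the commit) that runs the same iteration and measures how close the rows come to being orthonormal; the iteration is deliberately approximate, so expect a rough identity rather than an exact one.

import mlx.core as mx


def newton_schulz5(G, steps=5):
    # Same quintic iteration as the optimizer's helper, written without addmm.
    # Assumes rows <= cols, the orientation the helper transposes into.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (mx.linalg.norm(G, keepdims=True) + 1e-7)
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X


G = mx.random.normal((16, 64))
X = newton_schulz5(G)
# X @ X.T should be roughly the 16 x 16 identity; a small residual is expected.
print(mx.linalg.norm(X @ X.T - mx.eye(16)).item())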

python/tests/test_blas.py

Lines changed: 15 additions & 0 deletions
@@ -691,6 +691,21 @@ def test_addmm(self):
         self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
         self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

+        # Transposed c
+        a = mx.ones((10, 5)).T
+        b = mx.ones((5, 5))
+        out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
+        expected = 1.5 * a + 0.5 * (b @ a)
+        self.assertTrue(mx.allclose(expected, out))
+
+        # Broadcast c
+        a = mx.ones((5, 5))
+        b = mx.ones((5, 5))
+        c = mx.ones((1, 5))
+        out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
+        expected = 1.5 * c + 0.5 * (a @ b)
+        self.assertTrue(mx.allclose(expected, out))
+
     def test_addmm_grad(self):
         def make_ref_addmm(alpha, beta):
             return lambda c, a, b: alpha * (a @ b) + beta * c

python/tests/test_optimizers.py

Lines changed: 47 additions & 0 deletions
@@ -286,6 +286,53 @@ def test_adafactor(self):
         self.assertEqual(xp["x"].shape, x.shape)
         self.assertEqual(optimizer.state["step"], 2)

+    def test_muon(self):
+        params = {
+            "first": [mx.zeros((10, 5)), mx.zeros((1,))],
+            "second": mx.zeros((3, 3)),
+            "conv": mx.zeros((16, 8, 3, 3)),
+        }
+        grads = tree_map(lambda x: mx.ones_like(x), params)
+
+        # Explicit init
+        optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
+        optim.init(params)
+        self.assertTrue(
+            tree_equal(
+                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
+                params,
+                optim.state,
+            )
+        )
+
+        # Test update
+        updated_params = optim.apply_gradients(grads, params)
+
+        # Check that shapes are preserved
+        self.assertTrue(
+            tree_equal(
+                lambda p, u: p.shape == u.shape,
+                params,
+                updated_params,
+            )
+        )
+
+        # Check that parameters actually changed
+        self.assertFalse(
+            tree_equal(
+                lambda p, u: mx.array_equal(p, u),
+                params,
+                updated_params,
+            )
+        )
+
+        # Test with different configurations
+        optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
+        optim_no_nesterov.apply_gradients(grads, params)
+
+        optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
+        optim_no_momentum.apply_gradients(grads, params)
+
     def test_compiled_optimizer(self):
         model = nn.Linear(10, 10)
         x = mx.random.uniform(shape=(2, 10))
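As a quick end-to-end illustration, the new class plugs into the usual MLX training loop like any other optimizer in `mlx.optimizers`. This is a sketch rather than part of the commit: the model, data, and hyperparameters are made up, and per the docstring note one would usually hand embeddings, biases, and other 0D/1D parameters to a different optimizer such as AdamW instead of applying Muon to everything as done here for brevity.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
optimizer = optim.Muon(learning_rate=2e-2, momentum=0.95, nesterov=True)


def loss_fn(model, x, y):
    return nn.losses.cross_entropy(model(x), y, reduction="mean")


x = mx.random.normal((8, 32))
y = mx.random.randint(0, 10, (8,))

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)  # applies the Muon update to every parameter
mx.eval(model.parameters(), optimizer.state)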
