@@ -848,6 +848,106 @@ def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         return parameter - update


+class Muon(Optimizer):
+    r"""The Muon optimizer.
+
+    Our Muon (MomentUm Orthogonalized by Newton-Schulz) optimizer follows the
+    original implementation: `Muon: An optimizer for hidden layers in neural
+    networks <https://kellerjordan.github.io/posts/muon/>`_
+
+    Note:
+        - Muon may be sub-optimal for the embedding layer, the final fully
+          connected layer, or any 0D/1D parameters. Those should be optimized
+          by a different method (e.g., :class:`AdamW`).
+        - For 4D convolutional filters, it works by flattening their last
+          three dimensions.
+
+    Args:
+        learning_rate (float or callable): The learning rate.
+        momentum (float, optional): The momentum strength. Default: ``0.95``
+        weight_decay (float, optional): The weight decay (L2 penalty).
+            Default: ``0.01``
+        nesterov (bool, optional): Enables Nesterov momentum. Recommended for
+            better performance. Default: ``True``
+        ns_steps (int, optional): Number of Newton-Schulz iteration steps for
+            orthogonalization. Default: ``5``
+    """
+
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        momentum: float = 0.95,
+        weight_decay: float = 0.01,
+        nesterov: bool = True,
+        ns_steps: int = 5,
+    ):
+        super().__init__()
+
+        self._maybe_schedule("learning_rate", learning_rate)
+        self.momentum = momentum
+        self.weight_decay = weight_decay
+        self.nesterov = nesterov
+        self.ns_steps = ns_steps
+
+    def init_single(self, parameter: mx.array, state: dict):
+        """Initialize optimizer state"""
+        state["v"] = mx.zeros_like(parameter)
+
+    def _zeropower_via_newtonschulz5(self, X, steps: int):
+        assert (
+            X.ndim == 2
+        ), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
+        a, b, c = (3.4445, -4.7750, 2.0315)
+        transpose_needed = X.shape[-2] > X.shape[-1]
+
+        if transpose_needed:
+            X = X.T
+
+        X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)
+
+        for _ in range(steps):
+            A = X @ X.T
+            B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
+            X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
+
+        if transpose_needed:
+            X = X.T
+        return X
+
+    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+        """Performs the Muon parameter update"""
+
+        if self.weight_decay != 0:
+            gradient = gradient + self.weight_decay * parameter
+
+        v = self.momentum * state["v"]
+        v = v + (1 - self.momentum) * gradient
+        state["v"] = v
+
+        if self.nesterov:
+            update = gradient * (1 - self.momentum) + v * self.momentum
+        else:
+            update = v
+
+        lr = self.learning_rate.astype(gradient.dtype)
+
+        if update.ndim >= 2:
+            original_shape = update.shape
+            reshape_needed = update.ndim > 2
+
+            if reshape_needed:
+                update = mx.reshape(update, (update.shape[0], -1))
+
+            update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
+
+            if reshape_needed:
+                update = mx.reshape(update, original_shape)
+
+            lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
+
+        return parameter - lr * update
+
+
 def clip_grad_norm(grads, max_norm):
     """Clips the global norm of the gradients.

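For context, a minimal usage sketch (not part of the diff above) of how the new class could drive a training step once it lands in `mlx.optimizers`. The toy model, hyperparameters, and loss are illustrative assumptions, not from the PR:

# Illustrative sketch; assumes this PR's Muon is available as mlx.optimizers.Muon.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))

def loss_fn(model, x, y):
    return nn.losses.cross_entropy(model(x), y, reduction="mean")

# Muon targets the 2D hidden-layer weight matrices; per the docstring,
# embeddings, the final classifier, and 0D/1D parameters are usually better
# served by a separate AdamW instance (parameter routing not shown here).
optimizer = optim.Muon(learning_rate=0.02, momentum=0.95, nesterov=True)

x = mx.random.normal((8, 64))
y = mx.random.randint(0, 10, (8,))

loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
optimizer.update(model, grads)  # calls apply_single on every parameter
mx.eval(model.parameters(), optimizer.state)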
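And a rough sanity sketch (again illustrative, calling the private helper directly only for demonstration): the Newton-Schulz iteration pushes a matrix toward the nearest semi-orthogonal one, so the row Gram matrix of its output should sit near the identity. The fixed quintic coefficients trade exactness for speed, so only an approximation is expected:

import mlx.core as mx
import mlx.optimizers as optim

opt = optim.Muon(learning_rate=0.02)
g = mx.random.normal((16, 64))                    # fake 16x64 gradient matrix
o = opt._zeropower_via_newtonschulz5(g, steps=5)  # private helper, illustration only
# Deviation from the identity stays well below 1: rows are roughly orthonormal.
print(mx.abs(o @ o.T - mx.eye(16)).max())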