diff --git a/ci/h-test.sh b/ci/h-test.sh index 3dc63cac2e542..ec9498f821e4c 100644 --- a/ci/h-test.sh +++ b/ci/h-test.sh @@ -165,7 +165,8 @@ concurrency_list="^test_fp8_deep_gemm$|\ ^test_scaled_dot_product_attention$|\ ^test_compat_scaled_dot_product_attention$|\ ^test_flash_attention$|\ -^test_batched_gemm$" +^test_batched_gemm$|\ +^test_parallel_dygraph_muon$" cd ${work_dir}/build tmp_dir=$(mktemp -d) diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 9e2911a57d65e..04d7be79e54af 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -337,6 +337,8 @@ def __init__(self) -> None: ] self.sync_param_name: list[str] = ["embedding", "layer_norm", ".b_"] + self.use_muon_sharding: bool = False + self.__lock_attr = True logger.info("distributed strategy initialized") diff --git a/python/paddle/distributed/fleet/base/meta_optimizer_factory.py b/python/paddle/distributed/fleet/base/meta_optimizer_factory.py index 389cc84a3c74b..1bdbea5f488d0 100755 --- a/python/paddle/distributed/fleet/base/meta_optimizer_factory.py +++ b/python/paddle/distributed/fleet/base/meta_optimizer_factory.py @@ -25,6 +25,7 @@ meta_optimizer_names.remove("HybridParallelOptimizer") meta_optimizer_names.remove("HeterParallelOptimizer") meta_optimizer_names.remove("DGCMomentumOptimizer") +meta_optimizer_names.remove("MuonShardingOptimizer") class MetaOptimizerFactory: diff --git a/python/paddle/distributed/fleet/meta_optimizers/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/__init__.py index e8fedd586de34..8f15c856e28a6 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/__init__.py +++ b/python/paddle/distributed/fleet/meta_optimizers/__init__.py @@ -31,6 +31,7 @@ AdaptiveLocalSGDOptimizer, LocalSGDOptimizer, ) +from .muon_sharding_optimizer import MuonShardingOptimizer # noqa: F401 from 
.pipeline_optimizer import PipelineOptimizer # noqa: F401 from .ps_optimizer import ParameterServerOptimizer # noqa: F401 from .qat_optimizer import QATOptimizer # noqa: F401 diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py index da6c6e1ec3353..8b6ca7571bdcf 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py @@ -23,6 +23,9 @@ DygraphShardingOptimizer, DygraphShardingOptimizerV2, ) +from paddle.distributed.fleet.meta_optimizers.muon_sharding_optimizer import ( + MuonShardingOptimizer, +) from paddle.distributed.fleet.utils.hybrid_parallel_util import ( obtain_optimizer_parameters_list, ) @@ -284,11 +287,13 @@ def __init__(self, optimizer, hcg, strategy): split_param = strategy.hybrid_configs[ 'sharding_configs' ].split_param - ShardingOptimizer = ( - DygraphShardingOptimizerV2 - if split_param - else DygraphShardingOptimizer - ) + use_muon_sharding = getattr(strategy, "use_muon_sharding", False) + if use_muon_sharding: + ShardingOptimizer = MuonShardingOptimizer + elif split_param: + ShardingOptimizer = DygraphShardingOptimizerV2 + else: + ShardingOptimizer = DygraphShardingOptimizer optimizer = ShardingOptimizer(optimizer, hcg) self._enable_timer = strategy.hybrid_configs["enable_optimizer_timer"] @@ -335,6 +340,7 @@ def __init__(self, optimizer, hcg, strategy): MixPrecisionOptimizer, DygraphShardingOptimizer, DygraphShardingOptimizerV2, + MuonShardingOptimizer, ), ) @@ -628,7 +634,11 @@ def _hybrid_sync_grad(self, parameter_list): if self._sharding_enable: assert isinstance( self._inner_opt, - (DygraphShardingOptimizer, DygraphShardingOptimizerV2), + ( + DygraphShardingOptimizer, + DygraphShardingOptimizerV2, + MuonShardingOptimizer, + ), ) 
self._inner_opt.reduce_gradients(parameter_list, self._hcg) dp_parameter_list = self._inner_opt.filter_parameters( diff --git a/python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py new file mode 100644 index 0000000000000..15b52bee738ef --- /dev/null +++ b/python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py @@ -0,0 +1,1152 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +MuonShardingOptimizer (Sharding Stage1 V3): Hybrid Tensor-wise + Element-wise Sharding +================================================================== + +Designed for Muon optimizer compatibility: + - 2D (Muon) parameters: assigned as *whole tensors* to ranks (like V1). + This avoids the expensive sharding gather in Muon's _muon_update. + - Non-2D (AdamW) parameters: split element-wise via reduce-scatter (like V2). + This provides memory balancing across ranks. + +The key insight is that Muon requires the full 2D matrix for Newton-Schulz +orthogonalisation, so keeping 2D params whole on each rank eliminates the +need for gather_varlen communication during the optimizer step. 
+ +Parameters are grouped by their `color` attribute, which specifies the +communication group to use: + - color=None or -1: default sharding_group + - color='moe_expert': moe_sharding_group + - color=NAME: hcg.get_NAME_parallel_group() (extensible design) +""" + +import math +import os +import warnings +from collections import defaultdict +from functools import reduce as functools_reduce + +import paddle +from paddle import framework +from paddle.base.framework import EagerParamBase +from paddle.distributed import fleet +from paddle.distributed.communication.reduce import ( + ReduceOp, + is_avg_reduce_op_supported, +) +from paddle.distributed.fleet.utils import timer_helper as timer +from paddle.distributed.fleet.utils.log_util import logger +from paddle.distributed.fleet.utils.tensor_fusion_helper import ( + HOOK_ACTION, + FusedCommBuffer, + assign_group_by_size, +) + +g_shard_bypass_dygraph_optimizer = int( + os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0) +) +g_shard_fused_gradient = int(os.environ.get("FLAGS_shard_fused_gradient", 0)) + + +def _is_trainable(param): + return not param.stop_gradient + + +class MuonShardingOptimizer: + """ + Hybrid sharding optimizer for Muon: + - 2D (Muon) parameters: tensor-wise assignment to ranks (no cross-rank split). + Gradient communication uses reduce; parameter sync uses broadcast. + - Non-2D (AdamW) parameters: element-wise split across ranks (like V2). + Gradient communication uses reduce-scatter; parameter sync uses all-gather. + + Parameters are grouped by `color` attribute to determine the communication + group. Each color group has its own 2D parameter partition and communication. + + This avoids the expensive gather_varlen in Muon's _muon_update while + maintaining memory balance across ranks. 
+ """ + + def __init__(self, optimizer, hcg=None): + logger.info("init MuonShardingOptimizer") + + if isinstance(optimizer._parameter_list[0], dict): + raise TypeError( + "Do not support param_groups now, please set optimizer._parameter_list as a list of Parameter" + ) + if not hasattr(optimizer, '_apply_optimize') or not callable( + optimizer._apply_optimize + ): + raise ValueError( + "the optimizer object should have _apply_optimize function" + ) + + self._inner_opt = optimizer + # Get hcg from fleet if not provided + if hcg is None: + hcg = fleet.fleet._hcg + self._hcg = hcg + self._sharding_world_size = self._hcg.get_sharding_parallel_world_size() + self._sharding_rank = self._hcg.get_sharding_parallel_rank() + self._global_rank = paddle.distributed.get_rank() + + # Temporarily: TP is not supported in MuonShardingOptimizer + _tp_world_size = self._hcg.get_model_parallel_world_size() + assert _tp_world_size == 1, ( + f"MuonShardingOptimizer does not support tensor parallelism yet. " + f"Got tp_world_size={_tp_world_size}. Please set tensor_parallel_degree=1." + ) + + strategy = fleet.fleet._user_defined_strategy + sharding_configs = strategy.hybrid_configs['sharding_configs'] + + self.tensor_fusion = sharding_configs.tensor_fusion + self.accumulate_steps = sharding_configs.accumulate_steps + self.comm_overlap = sharding_configs.comm_overlap + self.comm_buffer_size_MB = sharding_configs.comm_buffer_size_MB + self.use_reduce_avg = sharding_configs.use_reduce_avg + + if self.use_reduce_avg and (not is_avg_reduce_op_supported()): + self.use_reduce_avg = False + warnings.warn( + "nccl reduce_avg requires paddle compiled with cuda and nccl>=2.10.0, " + "please check compilation setups." 
) + + pp_overlap = strategy.hybrid_configs['pp_configs'].sharding_comm_overlap + self.pp_overlap = pp_overlap + + self._use_main_grad = hasattr(optimizer._parameter_list[0], "main_grad") + + # The full original parameter list + self._parameter_list = list(optimizer._parameter_list) + self._origin_parameter_list = list(optimizer._parameter_list) + + # Build color -> group_info mapping + self._color_to_group_info = self._build_color_to_group_info(hcg) + + # Extract MoE group info from color_to_group_info for backward compatibility + moe_info = self._color_to_group_info.get('moe_expert', {}) + self._moe_sharding_world_size = moe_info.get('world_size', 1) + self._moe_sharding_rank = moe_info.get('rank', 0) + self._moe_sharding_group = moe_info.get('group', None) + + # Get muon_param_info_map from Muon optimizer + # This map contains use_muon field for each parameter, determined by Trainer + self._muon_param_info_map = getattr( + optimizer, '_muon_param_info_map', {} + ) + + # ---- Step 1: Separate params into categories by color ---- + # Parameters are grouped by their `color` attribute: + # - color=None or -1: default sharding_group (key: None) + # - color='moe_expert': moe_sharding_group (key: 'moe_expert') + # - color=NAME: corresponding parallel group (key: NAME) + # + # For each color group: + # - 2D (Muon) params: whole tensor, assigned to ranks via tensor-wise partition + # - non-2D (AdamW) params: element-wise split via FusedCommBuffer + # + # This design is extensible: adding a new communication group only requires + # setting the `color` attribute on parameters, no code changes needed here. 
+ self._params_2d_by_color = defaultdict( + list + ) # color -> list of 2D params + self._params_1d = [] # All non-2D params (single list, sharding_group only) + self.clear_color = set() + self._color_to_comm_buffer_list = {} + for p in self._parameter_list: + if not _is_trainable(p): + continue + + # Extract color value + color = getattr(p, 'color', -1) + if isinstance(color, dict): + color_val = color.get('color', -1) + else: + color_val = color + + # Normalize color: treat None/-1 as default (None key) + if color_val == -1 or color_val is None: + color_key = None + else: + color_key = color_val + + # Check if this color group supports 2D tensor-wise partition + group_info = self._color_to_group_info.get(color_key) + + param_info = self._muon_param_info_map.get(p.name) + assert param_info is not None, ( + f"Parameter {p.name!r} (shape={list(p.shape)}) has no muon_param_info. " + f"Trainer._build_muon_param_info_map must set muon_param_info on all " + f"trainable parameters before MuonShardingOptimizer is constructed." 
+ ) + use_muon = param_info.use_muon + + if use_muon: + self._params_2d_by_color[color_key].append(p) + else: + # Non-2D params always go to 1D element-wise split (sharding_group only) + self._params_1d.append(p) + + # ---- Step 2: Partition 2D params for each color group ---- + # For each color, compute rank-to-params and param-to-rank mappings + self._rank2params_2d_by_color = {} # color -> {rank -> [params]} + self._param2rank_2d_by_color = {} # color -> {param_name -> rank} + + for color_key, params_2d in self._params_2d_by_color.items(): + group_info = self._color_to_group_info.get(color_key, {}) + world_size = group_info.get('world_size', 1) + + if world_size <= 1: + # No partition needed, all params stay on rank 0 + self._rank2params_2d_by_color[color_key] = {0: list(params_2d)} + self._param2rank_2d_by_color[color_key] = { + p.name: 0 for p in params_2d + } + else: + # Greedy partition across ranks + label = color_key if color_key else "default" + self._rank2params_2d_by_color[color_key] = ( + self._partition_2d_parameters( + list(params_2d), world_size, label=label + ) + ) + self._param2rank_2d_by_color[color_key] = {} + for rank, params in self._rank2params_2d_by_color[ + color_key + ].items(): + for p in params: + self._param2rank_2d_by_color[color_key][p.name] = rank + + # add sort 2d params + for color_key, params_2d in self._params_2d_by_color.items(): + params_2d.sort( + key=lambda p: self._param2rank_2d_by_color[color_key][p.name] + ) + + # ---- Backward compatibility: expose legacy attributes ---- + # These are kept for any external code that might reference them + self._params_2d = self._params_2d_by_color.get(None, []) + self._params_2d_moe = self._params_2d_by_color.get('moe_expert', []) + self._rank2params_2d = self._rank2params_2d_by_color.get(None, {0: []}) + self._param2rank_2d = self._param2rank_2d_by_color.get(None, {}) + self._rank2params_2d_moe = self._rank2params_2d_by_color.get( + 'moe_expert', {0: []} + ) + self._param2rank_2d_moe = 
self._param2rank_2d_by_color.get( + 'moe_expert', {} + ) + + self._use_fuse_gradients = g_shard_fused_gradient + # ---- Build comm buffers for 2D params (V1-style) ---- + if self._use_fuse_gradients: + if not hasattr(self, 'comm_buffer_2d'): + self.comm_buffer_2d = self._build_2d_comm_buffers() + self.comm_buffer_2d.sort(key=lambda x: x._dst) + + # ---- Step 3: Build comm buffers for 1D params (V2-style) ---- + self._slice_params = {} + self._comm_buffer_list = [] + self._local_parameter_list_1d = [ + self._create_slice_param(p) for p in self._params_1d + ] + + self.param2bucket = {} + self.sd_release_grads = ( + strategy.hybrid_configs['pp_configs'].release_gradients + or sharding_configs.release_gradients + ) + self._build_1d_comm_buffers() + + # ---- Step 4: Build the optimizer's parameter list ---- + # The optimizer should see: + # - Non-MoE 2D params assigned to this rank (as whole tensors) + # - MoE expert 2D params assigned to this rank in moe_sharding_group + # - 1D slice_params for all non-2D params (element-wise shards) + local_2d_params = list( + self._rank2params_2d.get(self._sharding_rank, []) + ) + + if self._moe_sharding_world_size > 1: + local_2d_moe_params = list( + self._rank2params_2d_moe.get(self._moe_sharding_rank, []) + ) + else: + # moe_sharding_degree=1: this rank owns all its MoE expert params + local_2d_moe_params = list(self._rank2params_2d_moe.get(0, [])) + + local_opt_params = ( + local_2d_params + + local_2d_moe_params + + list(self._local_parameter_list_1d) + ) + + self._set_inner_opt_attr('_parameter_list', local_opt_params) + self._set_inner_opt_attr('_param_groups', local_opt_params) + + # For external iteration (clear_grad, etc.), expose all params + self._local_parameter_list = local_opt_params + + self._enable_timer = strategy.hybrid_configs.get( + "enable_optimizer_timer", False + ) + if self._enable_timer: + if not timer.is_timer_initialized(): + timer.set_timers() + self.timers = timer.get_timers() + + # --- [SLICE SIZE 
SUMMARY] Per-rank slice param sizes within this PP stage --- + _sg_group = hcg.get_sharding_parallel_group() + _N = self._sharding_world_size + + # 2D (non-MoE) params owned by this rank + _local_2d_numel = sum( + int(functools_reduce(lambda x, y: x * y, p.shape, 1)) + for p in self._rank2params_2d.get(self._sharding_rank, []) + ) + # 2D (MoE) params owned by this rank + _moe_rank_key = ( + self._moe_sharding_rank if self._moe_sharding_world_size > 1 else 0 + ) + _local_2d_moe_numel = sum( + int(functools_reduce(lambda x, y: x * y, p.shape, 1)) + for p in self._rank2params_2d_moe.get(_moe_rank_key, []) + ) + # 1D (AdamW) slice: each rank owns ceil(param.numel / world_size) elements per param. + # Sum over all 1D params in this sharding group (same color). + _local_1d_numel = sum( + math.ceil( + int(functools_reduce(lambda x, y: x * y, p.shape, 1)) / _N + ) + for p in self._params_1d + ) + + _local_total_numel = ( + _local_2d_numel + _local_2d_moe_numel + _local_1d_numel + ) + _local_total_MB = ( + _local_total_numel * 2 / (1024 * 1024) + ) # bf16/fp16 = 2 bytes + + # All-gather total numel from all sharding ranks in this PP stage + _local_numel_tensor = paddle.to_tensor( + [_local_total_numel], dtype='int64' + ) + _all_numel_list = [] + paddle.distributed.all_gather( + _all_numel_list, _local_numel_tensor, group=_sg_group + ) + _all_numel = [int(t.item()) for t in _all_numel_list] + _all_MB = [n * 2 / (1024 * 1024) for n in _all_numel] + + _max_MB = max(_all_MB) + _min_MB = min(_all_MB) + _imbalance = (_max_MB - _min_MB) / _max_MB if _max_MB > 0 else 0.0 + + if self._sharding_rank == 0: + logger.info( + f"[MuonSharding global_rank={self._global_rank} sharding_rank={self._sharding_rank}] " + f"SliceSize sharding_group ranks={_sg_group.ranks} | " + f"per-rank MB: {[f'{mb:.1f}' for mb in _all_MB]} | " + f"max memory diff={_imbalance * 100:.2f}%" + ) + + # ------------------------------------------------------------------ + # 2D partition (V1-style greedy) + # 
------------------------------------------------------------------ + + @staticmethod + def _build_color_to_group_info(hcg): + """Build a mapping from color to communication group info. + + Returns: + dict: { + None: {'group': sharding_group, 'world_size': N, 'rank': r}, + 'moe_expert': {'group': moe_sharding_group, 'world_size': M, 'rank': s}, + # Future colors can be added here + } + """ + color_to_info = {} + + # Default sharding group + sharding_world_size = hcg.get_sharding_parallel_world_size() + sharding_group = hcg.get_sharding_parallel_group() + color_to_info[None] = { + 'group': sharding_group, + 'world_size': sharding_world_size, + 'rank': sharding_group.rank if sharding_group else 0, + } + + # MoE sharding group (if available) + if hasattr(hcg, "get_moe_sharding_parallel_world_size"): + moe_world_size = hcg.get_moe_sharding_parallel_world_size() + if moe_world_size > 0: + moe_group = hcg.get_moe_sharding_parallel_group() + color_to_info['moe_expert'] = { + 'group': moe_group, + 'world_size': moe_world_size, + 'rank': moe_group.rank if moe_group else 0, + } + + # Future: Add more color -> group mappings here as needed + # Example: + # if hasattr(hcg, "get_custom_parallel_world_size"): + # custom_world_size = hcg.get_custom_parallel_world_size() + # if custom_world_size > 0: + # custom_group = hcg.get_custom_parallel_group() + # color_to_info['custom'] = { + # 'group': custom_group, + # 'world_size': custom_world_size, + # 'rank': custom_group.rank if custom_group else 0, + # } + + return color_to_info + + def _partition_2d_parameters(self, params, world_size, label=""): + """Partition 2D parameters among ranks using greedy bin-packing.""" + mapping = {} + for rank in range(world_size): + mapping[rank] = [] + sizes = [0] * world_size + + parameters = list(params) + parameters.sort( + key=lambda p: functools_reduce(lambda x, y: x * y, p.shape), + reverse=True, + ) + + for param in parameters: + rank = sizes.index(min(sizes)) + mapping[rank].append(param) + 
numel = functools_reduce(lambda x, y: x * y, param.shape, 1) + sizes[rank] += numel + + return mapping + + def _build_2d_comm_buffers(self): + """Build communication buffers for 2D (Tensor-wise) parameters using all-reduce.""" + group_size = ( + self.comm_buffer_size_MB * 1024 * 1024 + if self.comm_buffer_size_MB > 0 + else 256 * 1024 * 1024 + ) + comm_buffers = [] + + for color_key, params_2d in self._params_2d_by_color.items(): + group_info = self._color_to_group_info.get(color_key, {}) + comm_group = group_info.get('group', None) + + fused_parameter_group = defaultdict(list) + + for p in params_2d: + dst_rank = self._param2rank_2d_by_color[color_key][p.name] + fused_parameter_group[dst_rank].append(p) + + absolute_dst_ranks = { + rank: comm_group.ranks[rank] for rank in fused_parameter_group + } + + for dst, params in fused_parameter_group.items(): + var_groups = assign_group_by_size(params, group_size) + abs_dst = absolute_dst_ranks[dst] + + buffer = [ + FusedCommBuffer( + group_idx, + parameters, + comm_group, + self.accumulate_steps, + act=HOOK_ACTION.REDUCE, + dst=abs_dst, + release_grads=False, + use_reduce_avg=True, + ) + for group_idx, parameters in var_groups.items() + ] + comm_buffers.extend(buffer) + + return comm_buffers + + # ------------------------------------------------------------------ + # 1D slice creation (V2-style) + # ------------------------------------------------------------------ + + def _create_slice_param(self, param): + """Create a placeholder slice parameter for 1D (element-wise) sharding.""" + slice_param = EagerParamBase(shape=[1], dtype=param.dtype) + slice_param.name = param.name + + def copy_attr(attr_name): + if hasattr(param, attr_name): + setattr(slice_param, attr_name, getattr(param, attr_name)) + + copy_attr("is_distributed") + copy_attr("optimize_attr") + copy_attr("do_model_average") + copy_attr("need_clip") + copy_attr("no_sync") + + self._slice_params[param.name] = slice_param + return slice_param + + def 
_build_1d_comm_buffers(self): + """Build communication buffers for 1D (AdamW) parameters using reduce-scatter.""" + if self.pp_overlap: + return + + comm_group = self._hcg.get_sharding_parallel_group() + group_size = ( + self.comm_buffer_size_MB * 1024 * 1024 + if self.comm_buffer_size_MB > 0 + else 256 * 1024 * 1024 + ) + + # Group 1D params by color (for MoE compatibility) + color_dict = defaultdict(list) + for param in self._params_1d: + color = getattr(param, 'color', -1) + color_group = comm_group + if isinstance(color, dict): + color_color = color.get('color', -1) + color_group = color.get('group', comm_group) + else: + color_color = color + color_dict[(color_color, color_group)].append(param) + + if not self.comm_overlap: + for color, params in color_dict.items(): + params.sort(key=lambda x: str(x.dtype)) + + group_idx = 0 + for color, params in color_dict.items(): + g_color = color[0] + g_group = color[1] + var_groups = assign_group_by_size(params, group_size) + for _, parameters in var_groups.items(): + buffer = FusedCommBuffer( + group_idx, + parameters, + g_group, + self.accumulate_steps, + act=HOOK_ACTION.REDUCE_SCATTER, + release_grads=self.sd_release_grads, + use_reduce_avg=self.use_reduce_avg, + free_grads_in_comm=False, + init_slice_param=False, + slice_params=self._slice_params, + ) + group_idx += 1 + self._comm_buffer_list.append(buffer) + if g_color not in self._color_to_comm_buffer_list.keys(): + self._color_to_comm_buffer_list[g_color] = [] + self._color_to_comm_buffer_list[g_color].append(buffer) + for p in parameters: + if p.name in self.param2bucket: + self.param2bucket[p.name].append(buffer) + else: + self.param2bucket[p.name] = [buffer] + + self._comm_buffer_list.sort(key=lambda x: x._dst) + + def clear_param_storage(self, color): + self.clear_color.add(color) + if color in self._color_to_comm_buffer_list.keys(): + for comm_buffer in self._color_to_comm_buffer_list[color]: + for param in comm_buffer.params: + grad_view = 
comm_buffer._sharding_param_grad_view[ + param.name + ] + slice_param = self._slice_params[param.name] + if ( + not g_shard_bypass_dygraph_optimizer + and grad_view._param_begin < grad_view._param_end + ): + grad_view.fill_slice_param(slice_param) + self._create_master_weight(slice_param) + slice_param._clear_dataptr() + comm_buffer._clear_param_storage() + + def reset_param_storage(self): + for color in self.clear_color: + if color is None: + continue + if color in self._color_to_comm_buffer_list.keys(): + for comm_buffer in self._color_to_comm_buffer_list[color]: + comm_buffer._reset_param_storage() + + # ------------------------------------------------------------------ + # Gradient communication + # ------------------------------------------------------------------ + + def _get_param_grad(self, param): + if not param.trainable: + return None + if hasattr(param, "main_grad"): + assert param._grad_ivar() is None, ( + "param.grad should be None when using main_grad" + ) + return param.main_grad + return param._grad_ivar() + + def _reduce_2d_grads(self, params, param2rank, comm_group): + """Reduce gradients for 2D params to their owner rank within comm_group.""" + for param in params: + g_var = self._get_param_grad(param) + if g_var is None: + if hasattr(param, "main_grad"): + g_var = paddle.zeros_like(param, dtype=paddle.float32) + param.main_grad = g_var + else: + g_var = paddle.zeros_like(param, dtype=param.dtype) + param.grad = g_var + + reduce_op = ReduceOp.AVG + if not self.use_reduce_avg: + nranks = comm_group.nranks + g_var.scale_(1.0 / nranks) + reduce_op = ReduceOp.SUM + + if paddle.distributed.in_auto_parallel_align_mode(): + reduce_op = ReduceOp.SUM + + param_rank = param2rank[param.name] + paddle.distributed.reduce( + g_var, + dst=comm_group.ranks[param_rank], + op=reduce_op, + group=comm_group, + sync_op=True, + ) + + def reduce_gradients(self, parameter_list, hcg): + """Reduce gradients: reduce for 2D params, reduce-scatter for 1D params.""" + if ( + 
paddle.is_compiled_with_xpu() + and os.getenv("XPU_CDNN_CLUSTER_PARALLEL") is not None + ): + paddle.device.synchronize() + + with framework.no_grad(): + # --- 2D params: reduce via comm buffers | per tensors --- + if self._use_fuse_gradients: + for comm_buffer in self.comm_buffer_2d: + comm_buffer._comm_grads() + else: + # --- Non-MoE 2D params: reduce to owner rank via sharding_group --- + sharding_group = hcg.get_sharding_parallel_group() + self._reduce_2d_grads( + self._params_2d, self._param2rank_2d, sharding_group + ) + + # --- MoE expert 2D params: reduce to owner rank via moe_sharding_group --- + if self._params_2d_moe and self._moe_sharding_group is not None: + if self._moe_sharding_world_size > 1: + self._reduce_2d_grads( + self._params_2d_moe, + self._param2rank_2d_moe, + self._moe_sharding_group, + ) + # When moe_sharding_degree=1, no reduce needed (single rank group) + + # --- 1D params: reduce-scatter via comm buffers --- + for comm_buffer in self._comm_buffer_list: + if self.sd_release_grads and comm_buffer.grad_storage is None: + if comm_buffer.need_reduce_scale_sync(): + for param in comm_buffer.params: + comm_buffer._copy_grad_to_buffer(param) + + if not self.comm_overlap: + comm_buffer._comm_grads() + + # wait for all comm_buffer tasks to finish + if self._use_fuse_gradients: + for comm_buffer in self.comm_buffer_2d: + comm_buffer.scale_grads() + for comm_buffer in self._comm_buffer_list: + comm_buffer.scale_grads() + + def filter_parameters(self, parameter_list, hcg): + """Filter parameters: return local 2D params + initialized 1D slices.""" + sharding_rank = hcg.get_sharding_parallel_rank() + local_2d = [ + p + for p in parameter_list + if p.name in self._param2rank_2d + and self._param2rank_2d[p.name] == sharding_rank + ] + # Also include MoE 2D params owned by this rank + if self._moe_sharding_world_size > 1: + moe_rank = self._moe_sharding_rank + else: + moe_rank = 0 + local_2d_moe = [ + p + for p in parameter_list + if p.name in 
self._param2rank_2d_moe + and self._param2rank_2d_moe[p.name] == moe_rank + ] + local_1d = [ + self._slice_params[p.name] + for p in parameter_list + if p.name in self._slice_params + ] + local_1d = [p for p in local_1d if p._is_initialized()] + return local_2d + local_2d_moe + local_1d + + # ------------------------------------------------------------------ + # Parameter sync after optimizer step + # ------------------------------------------------------------------ + + def _broadcast_2d_params(self, rank2params, comm_group): + """Broadcast 2D params from owner ranks within comm_group.""" + broadcast_tasks = [] + for rank, params in rank2params.items(): + src_rank = comm_group.ranks[rank] + for param in params: + if param.stop_gradient: + continue + task = paddle.distributed.broadcast( + param, + src=src_rank, + group=comm_group, + sync_op=False, + ) + broadcast_tasks.append(task) + return broadcast_tasks + + def _sharding_sync_parameters(self): + """Sync parameters: broadcast 2D, all-gather 1D.""" + comm_group = self._hcg.get_sharding_parallel_group() + + with framework.no_grad(): + all_tasks = [] + + # --- Non-MoE 2D params: broadcast from owner via sharding_group --- + all_tasks.extend( + self._broadcast_2d_params(self._rank2params_2d, comm_group) + ) + + # --- MoE expert 2D params: broadcast from owner via moe_sharding_group --- + if self._params_2d_moe and self._moe_sharding_group is not None: + if self._moe_sharding_world_size > 1: + all_tasks.extend( + self._broadcast_2d_params( + self._rank2params_2d_moe, self._moe_sharding_group + ) + ) + # When moe_sharding_degree=1, no broadcast needed (single rank group) + + for task in all_tasks: + task.wait() + + # --- 1D params: all-gather via comm buffers --- + for comm_buffer in self._comm_buffer_list: + comm_buffer.sync_params() + + # ------------------------------------------------------------------ + # Clear gradients + # ------------------------------------------------------------------ + + def 
    def clear_grad(self, set_to_zero=True):
        """Clear gradients for all parameters.

        Args:
            set_to_zero (bool): If True, zero the grad storage in place;
                otherwise release the storage (and drop the reference).
        """

        def clear_grad_func(p):
            # Params carrying a fused fp32 ``main_grad`` must not also hold a
            # regular grad ivar at this point (enforced by the assert).
            if hasattr(p, "main_grad") and p.main_grad is not None:
                assert p._grad_ivar() is None
                if set_to_zero:
                    p.main_grad.zero_()
                else:
                    p.main_grad._clear()
                    p.main_grad = None
            elif not hasattr(p, "main_grad"):
                if self.tensor_fusion:
                    if set_to_zero:
                        p.grad.zero_()
                    else:
                        p.grad._clear()
                        p.grad = None
                else:
                    p.clear_gradient(set_to_zero)

        for p in self._parameter_list:
            clear_grad_func(p)

        # 1D params are managed by comm buffers
        if self.sd_release_grads and not self.pp_overlap:
            for comm_buffer in self._comm_buffer_list:
                if comm_buffer.need_reduce_scale_sync():
                    comm_buffer._clear_grad_storage()

        if self._use_fuse_gradients:
            for comm_buffer in self.comm_buffer_2d:
                if comm_buffer.need_reduce_scale_sync():
                    comm_buffer._clear_grad_storage()

    # ------------------------------------------------------------------
    # Optimizer step
    # ------------------------------------------------------------------

    def _collect_comm_buffers(self):
        """Collect communication buffers (for PP overlap compatibility)."""
        # Idempotent: once buffers have been collected, do nothing.
        if self._comm_buffer_list:
            return
        for param in self._params_1d:
            if not hasattr(param, "comm_buffer_ref"):
                continue
            # ``comm_buffer_ref`` is dereferenced (it is called, so presumably
            # a weakref set elsewhere — TODO confirm) and the attribute is
            # removed so collection happens exactly once per param.
            comm_buffer_ref = param.comm_buffer_ref
            del param.comm_buffer_ref
            comm_buffer = comm_buffer_ref()
            self._comm_buffer_list.append(comm_buffer)

        # Index buckets by the names of the params they hold.
        for bucket in self._comm_buffer_list:
            for p in bucket._params:
                if p.name in self.param2bucket:
                    self.param2bucket[p.name].append(bucket)
                else:
                    self.param2bucket[p.name] = [bucket]

    def _assign_slice_grad(self):
        """Assign gradients from comm buffers to slice params for 1D params."""
        for comm_buffer in self._comm_buffer_list:
            for param in comm_buffer.params:
                if param.name in self._slice_params:
                    slice_param = self._slice_params[param.name]
                    # Drop a stale, uninitialized main_grad so the comm
                    # buffer can (re)attach the reduce-scattered grad view.
                    if self.sd_release_grads and hasattr(
                        slice_param, "main_grad"
                    ):
                        if not slice_param.main_grad._is_initialized():
                            del slice_param.main_grad
                    comm_buffer.assign_slice_grad(param, slice_param)

    def step(self):
        """Optimizer step: update local 2D params and 1D slices, then sync."""
        self._collect_comm_buffers()
        self._assign_slice_grad()

        # NOTE(review): only the flat-parameter-list case is handled here;
        # parameter groups (list of dict) fall through to the final sync.
        if not isinstance(self._origin_parameter_list[0], dict):
            params_grads = []

            # --- Non-MoE 2D params on this rank: full tensors ---
            local_2d = self._rank2params_2d.get(self._sharding_rank, [])
            for param in local_2d:
                if param.stop_gradient:
                    continue
                grad_var = param._grad_ivar()
                # Prefer the fused fp32 main_grad when present.
                if hasattr(param, "main_grad") and param.main_grad is not None:
                    grad_var = param.main_grad
                if grad_var is not None:
                    params_grads.append((param, grad_var))

            # --- MoE expert params on this rank ---
            # Pass the original param (2D or 3D) directly to the optimizer.
            # _muon_update already handles both shapes:
            #   - 2D [H, I]: standard Newton-Schulz
            #   - 3D [n_experts, H, I]: per-expert Newton-Schulz loop (Step 4)
            # Keeping the original name avoids registering _expert_N
            # accumulator keys that are absent from model_sharded_state_dict,
            # which would break sharded_state_dict (checkpoint save).
            if self._moe_sharding_world_size > 1:
                local_2d_moe = self._rank2params_2d_moe.get(
                    self._moe_sharding_rank, []
                )
            else:
                local_2d_moe = self._rank2params_2d_moe.get(0, [])

            for param in local_2d_moe:
                if param.stop_gradient:
                    continue
                grad_var = param._grad_ivar()
                if hasattr(param, "main_grad") and param.main_grad is not None:
                    grad_var = param.main_grad
                if grad_var is None:
                    continue
                params_grads.append((param, grad_var))

            # --- 1D params: slice params (element-wise shards) ---
            for param in self._params_1d:
                if param.stop_gradient:
                    continue
                if param.name not in self._slice_params:
                    continue
                slice_p = self._slice_params[param.name]
                grad_var = slice_p._grad_ivar()
                if (
                    hasattr(slice_p, "main_grad")
                    and slice_p.main_grad is not None
                ):
                    grad_var = slice_p.main_grad
                if grad_var is not None:
                    params_grads.append((slice_p, grad_var))

            self._apply_optimize(
                loss=None,
                startup_program=None,
                params_grads=params_grads,
            )

        # Sync parameters across sharding ranks
        self._sharding_sync_parameters()
    # ------------------------------------------------------------------
    # State dict (checkpoint save/load)
    # ------------------------------------------------------------------

    @framework.dygraph_only
    def set_state_dict(self, state_dict):
        """Load an optimizer state dict, remapping entries to local params.

        Only states belonging to parameters owned by this rank (local 2D,
        local MoE 2D, and 1D slice params) are forwarded to the inner
        optimizer.
        """
        inner_state = {}
        # Local parameters = local 2D + local MoE 2D + 1D slice params
        local_2d = list(self._rank2params_2d.get(self._sharding_rank, []))
        if self._moe_sharding_world_size > 1:
            local_2d_moe = list(
                self._rank2params_2d_moe.get(self._moe_sharding_rank, [])
            )
        else:
            local_2d_moe = list(self._rank2params_2d_moe.get(0, []))
        parameters = local_2d + local_2d_moe
        # Add 1D params (use original param names for matching)
        for p in self._params_1d:
            parameters.append(p)

        if "LR_Scheduler" in state_dict:
            inner_state["LR_Scheduler"] = state_dict.pop("LR_Scheduler")

        if "master_weights" in state_dict:
            master = state_dict.pop("master_weights")
            inner_state["master_weights"] = {}
            for p in parameters:
                for k, v in master.items():
                    if p.name == k:
                        # Rename the loaded master weight so it matches the
                        # name the inner optimizer will generate for p.
                        v.name = self._inner_opt._gen_master_weight_var_name(p)
                        inner_state["master_weights"][k] = v

        for p in parameters:
            for k, v in state_dict.items():
                # NOTE(review): substring match — a param name that is a
                # prefix of another could over-match; confirm naming scheme
                # guarantees uniqueness.
                if p.name in k:
                    inner_state[k] = v

        self._inner_opt.set_state_dict(inner_state)

    # ------------------------------------------------------------------
    # Utility
    # ------------------------------------------------------------------

    def _set_inner_opt_attr(self, attr_name, value):
        """Set *attr_name* on every optimizer in the inner-opt chain."""
        inner_opt = self._inner_opt
        inner_opt_name = '_inner_opt'
        if not isinstance(attr_name, str):
            raise TypeError(
                f"attr_name should be str type, but is {type(attr_name)}"
            )
        # Walk down the wrapper chain until an optimizer lacks the attr.
        while hasattr(inner_opt, attr_name):
            setattr(inner_opt, attr_name, value)
            inner_opt = getattr(inner_opt, inner_opt_name, None)
            if inner_opt is None:
                break

    def sharded_state_dict(self, model_sharded_state_dict):
        """Build a sharded optimizer state dict for flex checkpoint save/load.

        Overrides the inner Muon optimizer's sharded_state_dict to handle V3's
        hybrid sharding scheme:
          - 2D Muon params (non-MoE and MoE): whole tensor, shape matches
            model's local_shape. Handled by delegating to the inner Muon's
            sharded_state_dict after filtering out 1D param states.
          - 1D AdamW params: accumulators are 1D shards (from reduce-scatter);
            wrapped with is_flattened=True + flattened_range, like V2.
        """
        import paddle as _paddle
        from paddle.distributed.flex_checkpoint.dcp.sharded_weight import (
            ShardedWeight,
            create_sharded_weight_with_new_local,
        )

        # ---- Step 1: Collect flattened_range for each 1D (AdamW) param ----
        # Identical logic to DygraphShardingOptimizerV2.sharded_state_dict.
        param_slice_info = {}  # param_name -> slice(begin, end)
        padded_param = set()
        for buffer in self._comm_buffer_list:
            for (
                param_name,
                grad_view,
            ) in buffer._sharding_param_grad_view.items():
                numel = grad_view._param.numel().item()
                param_begin = grad_view._param_begin
                param_end = grad_view._param_end
                index = grad_view._index
                padding_begin = index + numel
                # Clamp the range so it never goes negative and never
                # extends into the buffer's padding region.
                flattened_range = slice(
                    param_begin - index,
                    max(
                        min(padding_begin - index, param_end - index),
                        param_begin - index,
                    ),
                )
                if param_end > padding_begin:
                    padded_param.add(param_name)
                param_slice_info[param_name] = flattened_range

        # ---- Step 2: Build static_name -> struct_name mapping ----
        model_sharded_sorted = dict(sorted(model_sharded_state_dict.items()))
        static_to_struct = {}
        for struct_name, sw in model_sharded_sorted.items():
            # First struct name wins for a given static (tensor) name.
            if sw.local_tensor.name not in static_to_struct:
                static_to_struct[sw.local_tensor.name] = struct_name

        # ---- Step 3: Process all optimizer states ----
        _FP32_MASTER = "fp32_master_0"
        _optimizer_scalar_names = ["beta1_pow_acc_0", "beta2_pow_acc_0"]
        _optimizer_vector_names = ["moment1_0", "moment2_0"]

        def _make_2d_entry(uname, t, sp):
            """Reshape tensor if numel matches but shape differs, then wrap as ShardedWeight."""
            target = sp.local_shape
            if (
                tuple(t.shape) != tuple(target)
                and t.numel() == _paddle.to_tensor(list(target)).prod().item()
            ):
                t = t.reshape(target)
            return create_sharded_weight_with_new_local(uname, t, sp)

        def _split_state_name(vname):
            # "<static>_fp32_master_0_<suffix>" or "<static>_<suffix>".
            if _FP32_MASTER in vname:
                return tuple(vname.split("_" + _FP32_MASTER + "_", 1))
            for suffix in _optimizer_scalar_names + _optimizer_vector_names:
                if vname.endswith(suffix):
                    return vname[: -(len(suffix) + 1)], suffix
            raise ValueError(
                f"Cannot parse optimizer state variable name: {vname!r}"
            )

        optimizer_state_dict = self._inner_opt.state_dict()
        master_weights = optimizer_state_dict.pop("master_weights", None)
        optimizer_state_dict.pop("LR_Scheduler", None)

        sharded_state = {}

        for key, tensor in optimizer_state_dict.items():
            static_name, state_type = _split_state_name(key)
            if static_name not in static_to_struct:
                continue

            struct_name = static_to_struct[static_name]
            sharded_param = model_sharded_sorted[struct_name]
            unified_name = f"{struct_name}.{state_type}"

            is_1d_param = static_name in param_slice_info

            if state_type in _optimizer_vector_names:
                if is_1d_param:
                    # 1D AdamW shard: wrap with is_flattened=True (like V2)
                    flattened_range = param_slice_info[static_name]
                    if flattened_range.stop - flattened_range.start == 0:
                        continue
                    is_padded = static_name in padded_param
                    if is_padded:
                        # Cut off the padding tail before wrapping.
                        local_tensor = _paddle.slice(
                            tensor,
                            axes=[0],
                            starts=[0],
                            ends=[flattened_range.stop - flattened_range.start],
                        )
                    else:
                        local_tensor = tensor
                    sharded_state[unified_name] = ShardedWeight(
                        key=unified_name,
                        local_tensor=local_tensor,
                        local_shape=sharded_param.local_shape,
                        global_shape=sharded_param.global_shape,
                        global_offset=sharded_param.global_offset,
                        is_flattened=True,
                        flattened_range=flattened_range,
                    )
                elif tensor.is_dist():
                    sharded_state[unified_name] = ShardedWeight(
                        key=unified_name,
                        local_tensor=tensor,
                        local_shape=tensor.shape,
                        global_shape=tensor.shape,
                        global_offset=sharded_param.global_offset,
                    )
                else:
                    # 2D Muon param (non-MoE or MoE): shape may differ between
                    # Python param.shape (3D view) and model storage (2D).
                    sharded_state[unified_name] = _make_2d_entry(
                        unified_name, tensor, sharded_param
                    )
            else:
                # Scalar states (beta_pow): replicated
                sharded_state[unified_name] = ShardedWeight(
                    key=unified_name,
                    local_tensor=tensor,
                    local_shape=(1,),
                    global_shape=(1,),
                    global_offset=(0,),
                )

        # FP32 master weights
        if master_weights:
            for weight_key, tensor in master_weights.items():
                if weight_key not in static_to_struct:
                    continue
                struct_name = static_to_struct[weight_key]
                sharded_param = model_sharded_sorted[struct_name]
                unified_name = f"{struct_name}.w_0"
                is_1d_param = weight_key in param_slice_info

                if is_1d_param:
                    flattened_range = param_slice_info[weight_key]
                    if flattened_range.stop - flattened_range.start == 0:
                        continue
                    is_padded = weight_key in padded_param
                    if is_padded:
                        local_tensor = _paddle.slice(
                            tensor,
                            axes=[0],
                            starts=[0],
                            ends=[flattened_range.stop - flattened_range.start],
                        )
                    else:
                        local_tensor = tensor
                    sharded_state[unified_name] = ShardedWeight(
                        key=unified_name,
                        local_tensor=local_tensor,
                        local_shape=sharded_param.local_shape,
                        global_shape=sharded_param.global_shape,
                        global_offset=sharded_param.global_offset,
                        is_flattened=True,
                        flattened_range=flattened_range,
                    )
                elif tensor.is_dist():
                    sharded_state[unified_name] = ShardedWeight(
                        key=unified_name,
                        local_tensor=tensor,
                        local_shape=tensor.shape,
                        global_shape=tensor.shape,
                        global_offset=sharded_param.global_offset,
                    )
                else:
                    # Same reshape logic as for optimizer vector states:
                    # FP32 master weight may be 3D (e.g. grouped_gemm_experts
                    # [n_experts, H, I]) while model storage is 2D [n_experts*H, I].
                    sharded_state[unified_name] = _make_2d_entry(
                        unified_name, tensor, sharded_param
                    )

        return sharded_state

    def __getattr__(self, item):
        # Delegate any unknown attribute to the wrapped inner optimizer.
        return getattr(self._inner_opt, item)
+from __future__ import annotations + +import logging +import os +from dataclasses import dataclass + +import paddle +from paddle.base import framework +from paddle.distributed.flex_checkpoint.dcp.sharded_weight import ( + ShardedStateDict, + ShardedWeight, + create_sharded_weight_with_new_local, +) + +from ..nn.clip import GradientClipBase +from .optimizer import Optimizer + +# Debug logging for Muon optimizer +_logger = logging.getLogger(__name__) +MUON_DEBUG = os.environ.get("MUON_DEBUG", "0") == "1" + +__all__ = [] + + +# ------------------------------------------------------------------ +# Parameter metadata dataclasses +# ------------------------------------------------------------------ + + +@dataclass +class QKVInfo: + """Metadata for QKV weight matrices (GQA). + + Attributes: + head_num: Number of attention heads (Q heads). + kv_head_num: Number of key-value heads (for GQA). + num_key_value_groups: Number of Q heads per KV head. + """ + + head_num: int + kv_head_num: int + num_key_value_groups: int + + +@dataclass +class MLAInfo: + """Metadata for MLA weight matrices needed for head-split. + + Attributes: + param_name: Name of the parameter (q_b_proj, kv_b_proj, o_proj). + head_num: Number of attention heads. + """ + + param_name: str + head_num: int + + +@dataclass +class MuonParamInfo: + """Muon update metadata for a single parameter. + + This replaces the previous approach of setting dynamic attributes + directly on param objects. + + Attributes: + use_muon: If True, use Muon (orthogonal) updates; otherwise AdamW. + qkv_info: Required for QKV weight matrices. + intermediate_size: Required for FFN gate_up weights when muon_ffn_split is True. 
+ """ + + use_muon: bool = True + qkv_info: QKVInfo | None = None + mla_info: MLAInfo | None = None + intermediate_size: int | None = None + + @property + def is_qkv(self) -> bool: + """True if this is a QKV weight matrix.""" + return self.qkv_info is not None + + @property + def is_mla(self) -> bool: + """True if this is an MLA weight matrix.""" + return self.mla_info is not None + + @property + def is_ffn_gate_up(self) -> bool: + """True if this is an FFN gate_up weight matrix.""" + return self.intermediate_size is not None + + +# Type alias for the parameter info mapping +MuonParamInfoMap = dict[str, MuonParamInfo] + +# ------------------------------------------------------------------ +# Newton-Schulz coefficient sets +# ------------------------------------------------------------------ + +_NS_COEFFICIENT_SETS = { + # Simple coefficient set (original) + "simple": [ + (3.4445, -4.7750, 2.0315), + ], + # Quintic iteration with optimized coefficients + # Source: https://leloykun.github.io/ponder/muon-opt-coeffs/ + "quintic": [ + (4.0848, -6.8946, 2.9270), + (3.9505, -6.3029, 2.6377), + (3.7418, -5.5913, 2.3037), + (2.8769, -3.1427, 1.2046), + (2.8366, -3.0525, 1.2012), + ], + # Polar Express iteration from https://arxiv.org/abs/2505.16932 + "polar_express": [ + (8.2051, -22.9019, 16.4607), + (4.0664, -2.8612, 0.5184), + (3.9096, -2.8234, 0.5250), + (3.2856, -2.4153, 0.4853), + (2.2779, -1.6198, 0.3985), + (1.8726, -1.2307, 0.3585), + (1.8564, -1.2132, 0.3568), + (1.8750, -1.2500, 0.3750), + ], + # AOL coefficients from https://github.com/thib-s/flash-newton-schulz + "aol": [ + (4.0098, -7.0585, 2.4635), + (3.4585, -5.5479, 2.5959), + (2.7573, -3.2939, 1.4254), + (2.7215, -3.0494, 1.3169), + ], +} + +# ------------------------------------------------------------------ +# Default parameter classification +# ------------------------------------------------------------------ + + +def _default_should_use_muon(name, shape, exclude_patterns): + """Default fallback logic 
for determining if a parameter should use Muon. + + This is only used when param.is_muon is not set. The actual exclusion + patterns must be configured via training_args.muon_exclude_patterns in yaml. + + Args: + name: Parameter name. + shape: Parameter shape. + exclude_patterns: List of substrings to exclude from Muon updates. + Must be provided (e.g., ['embed', 'bias', 'lm_head', 'mlp.gate']). + + Returns: + True if the parameter should use Muon (orthogonal) updates. + + Raises: + ValueError: If exclude_patterns is None. + """ + if exclude_patterns is None: + raise ValueError( + "muon_exclude_patterns must be set in yaml config. " + "Example: muon_exclude_patterns: ['embed', 'bias', 'lm_head', 'mlp.gate']" + ) + + if len(shape) not in (2, 3): + return False + + name_lower = name.lower() + for pattern in exclude_patterns: + if pattern.lower() in name_lower: + return False + return True + + +class Muon(Optimizer): + r""" + Muon optimizer for MuonShardingOptimizer (Sharding Stage1 V3) usage. + + For 2-D weight matrices (identified by :func:`_default_should_use_muon`), Muon + applies orthogonal gradient updates via Newton-Schulz iteration. For all + other parameters (embeddings, biases, expert weights, …) it falls back to + a standard AdamW update. + + Designed for ``MuonShardingOptimizer`` (Sharding Stage1 V3), where 2D parameters are + assigned as whole tensors to ranks. Currently we do not support TP=1, no sharding gather + or TP communication is needed during the optimizer step. + + Args: + learning_rate (float | LRScheduler): Learning rate. Default: ``0.02``. + parameters (list[Tensor]): Flat list of parameters to optimize. + momentum (float): Momentum coefficient for the Muon update. Default: ``0.95``. + adam_beta1 (float): β₁ for the AdamW fallback. Default: ``0.9``. + adam_beta2 (float): β₂ for the AdamW fallback. Default: ``0.95``. + weight_decay (float): Decoupled weight decay. Default: ``0.01``. + ns_steps (int): Newton-Schulz iteration steps. 
Default: ``5``. + nesterov (bool): Use Nesterov momentum in Muon. Default: ``True``. + adam_epsilon (float): ε for numerical stability in AdamW. Default: ``1e-9``. + grad_clip (GradientClipBase | None): Gradient clipping. Default: ``None``. + apply_decay_param_fun (callable | None): Function to select which + parameters receive weight decay. Default: ``None``. + muon_version (int): Scaling-function version (1/2/3). Default: ``1``. + muon_exclude_patterns (list[str] | None): Parameter names containing + any of these substrings will use AdamW instead of Muon. + Example: ``['embed', 'bias', 'lm_head', 'mlp.gate']``. + Default: ``None``. + muon_qkv_update_mode (str): Strategy for QKV fused weight matrices. + ``"split_head"`` orthogonalises each Q/K/V head independently; + ``"split_qkv"`` treats Q, K, V as three separate matrices; + ``"fused_qkv"`` treats the entire QKV matrix as one. + Default: ``"split_head"``. + muon_ffn_split (bool): If True, split FFN gate_up fused weights into + gate and up projections and orthogonalise them independently. + Default: ``False``. + muon_extra_scale_factor (float): Extra multiplicative scale applied + after the dimension-dependent scaling in ``_scaling_fn``. + Default: ``0.2``. + muon_param_info_map (MuonParamInfoMap | None): Per-parameter metadata + dict mapping param name to :class:`MuonParamInfo` (use_muon, + qkv_info, intermediate_size). Built by Trainer and passed in. + Default: ``None``. + multi_precision (bool): Maintain FP32 master weights when training in + BF16/FP16. Default: ``False``. + name (str | None): Optional name for the optimizer instance. 
+ """ + + _moment_acc_str = "moment1" + _moment2_acc_str = "moment2" + _beta1_pow_acc_str = "beta1_pow_acc" + _beta2_pow_acc_str = "beta2_pow_acc" + + def __init__( + self, + learning_rate=0.02, + parameters=None, + momentum=0.95, + adam_beta1=0.9, + adam_beta2=0.95, + weight_decay=0.01, + ns_steps=5, + ns_coeff_type="simple", + nesterov=True, + adam_epsilon=1e-9, + grad_clip=None, + apply_decay_param_fun=None, + muon_version=1, + muon_exclude_patterns=None, + muon_qkv_update_mode="split_head", + muon_ffn_split=False, + muon_extra_scale_factor=0.2, + muon_param_info_map: MuonParamInfoMap | None = None, + multi_precision=False, + name=None, + **kwargs, + ): + if parameters is None: + raise ValueError( + "parameters argument given to the Optimizer should not be None." + ) + if not isinstance(parameters, list): + raise TypeError("parameters must be a list.") + if len(parameters) > 0 and isinstance(parameters[0], dict): + raise TypeError( + "Muon optimizer only supports a flat list of parameters, " + "not a list of parameter groups." 
+ ) + if grad_clip is not None and not isinstance( + grad_clip, GradientClipBase + ): + raise TypeError( + "'grad_clip' should be an instance of GradientClipBase's derived class" + ) + + defaults = { + "momentum": momentum, + "adam_beta1": adam_beta1, + "adam_beta2": adam_beta2, + "weight_decay": weight_decay, + "ns_steps": ns_steps, + "nesterov": nesterov, + "epsilon": adam_epsilon, + "muon_version": muon_version, + "ns_coeff_type": ns_coeff_type, + } + + super().__init__( + learning_rate=learning_rate, + parameters=parameters, + weight_decay=weight_decay, + grad_clip=grad_clip, + name=name, + ) + + self._multi_precision = multi_precision + self._master_weights = {} + self._apply_decay_param_fun = apply_decay_param_fun + self._muon_split_logged = False + self._muon_exclude_patterns = muon_exclude_patterns + self._muon_qkv_update_mode = muon_qkv_update_mode + self._muon_ffn_split = muon_ffn_split + self._muon_extra_scale_factor = muon_extra_scale_factor + self._ns_coeff_type = ns_coeff_type + self._muon_param_info_map = muon_param_info_map or {} + self._default_dict.update(defaults) + + # ------------------------------------------------------------------ + # Accumulator management + # ------------------------------------------------------------------ + + def _ensure_accumulators(self, param, use_muon, group): + """Create optimizer accumulators for *param* if they do not exist yet.""" + if ( + self._moment_acc_str in self._accumulators + and param.name in self._accumulators[self._moment_acc_str] + ): + return + + # FP32 master weight for mixed-precision training + if self._multi_precision and self._is_dtype_fp16_or_bf16(param.dtype): + if param.name not in self._master_weights: + self._create_master_weight(param) + + self._add_accumulator( + self._moment_acc_str, + param, + dtype=paddle.float32, + fill_value=0.0, + shape=param.shape, + type=framework.core.VarDesc.VarType.DENSE_TENSOR, + ) + + if not use_muon: + # AdamW-specific states + self._add_accumulator( + 
self._moment2_acc_str, + param, + dtype=paddle.float32, + fill_value=0.0, + shape=param.shape, + type=framework.core.VarDesc.VarType.DENSE_TENSOR, + ) + for acc_name, init_val in [ + (self._beta1_pow_acc_str, group.get("adam_beta1", 0.9)), + (self._beta2_pow_acc_str, group.get("adam_beta2", 0.95)), + ]: + self._add_accumulator( + acc_name, + param, + dtype=paddle.float32, + fill_value=1.0, + shape=[1], + type=framework.core.VarDesc.VarType.DENSE_TENSOR, + ) + + def _create_accumulators(self, block, parameters): + """Standard entry-point used by checkpoint-resume infrastructure.""" + if isinstance(parameters, dict): + parameters = self._update_param_group(parameters) + for p in parameters: + param_info = self._muon_param_info_map.get(p.name) + if param_info is not None: + use_muon = param_info.use_muon + else: + use_muon = _default_should_use_muon( + p.name, + getattr(p, "original_shape", p.shape), + self._muon_exclude_patterns, + ) + self._ensure_accumulators(p, use_muon, self._default_dict) + + # ------------------------------------------------------------------ + # Newton-Schulz orthogonalisation + # ------------------------------------------------------------------ + + @staticmethod + def _zeropower_via_newtonschulz5( + X, steps=5, eps=1e-9, ns_coeff_type="simple" + ): + """Approximate the matrix sign function via Newton-Schulz iteration. + + Args: + X: Input tensor to orthogonalize. + steps: Number of Newton-Schulz iterations. + eps: Small constant for numerical stability. + ns_coeff_type: Type of coefficient set to use. + Options: "simple", "quintic", "polar_express", "aol". 
+ """ + # Get coefficient set + coeff_sets = _NS_COEFFICIENT_SETS.get( + ns_coeff_type, _NS_COEFFICIENT_SETS["simple"] + ) + + if X.shape[-2] > X.shape[-1]: + X = X.T + transpose = True + else: + transpose = False + + orig_shape = X.shape + X_flat = X.flatten(start_axis=-2) + X_flat = paddle.nn.functional.normalize( + X_flat, p=2, axis=-1, epsilon=eps + ) + X = X_flat.reshape(orig_shape).astype(paddle.bfloat16) + + # Iterate with cycling coefficients + for i in range(steps): + a, b, c = coeff_sets[i % len(coeff_sets)] + A = paddle.matmul(X, X, transpose_y=True) + B = paddle.addmm(input=A, x=A, y=A, beta=b, alpha=c) + X = paddle.addmm(input=X, x=B, y=X, beta=a, alpha=1.0) + + return X.T if transpose else X + + @staticmethod + def _scaling_fn(orthogonal_update, version, extra_scale_factor=1.0): + """Apply dimension-dependent scaling to the orthogonal update.""" + din, dout = orthogonal_update.shape[0], orthogonal_update.shape[1] + if version == 1: + scale = max(1, dout / din) ** 0.5 + elif version == 2: + scale = (dout / din) ** 0.5 + else: # version == 3 (default) + scale = max(dout, din) ** 0.5 + return orthogonal_update * scale * extra_scale_factor + + @staticmethod + def _ortho_qkv_per_head( + matrix_2d_global, + kv_head_num, + num_key_value_groups, + ortho_fn, + ): + """Orthogonalise each Q/K/V head independently (interleaved layout). + + Args: + matrix_2d_global: QKV weight matrix [hidden, (num_key_value_groups + 2)*kv_head_num*head_dim]. + kv_head_num: Number of K/V heads. + num_key_value_groups: Number of Q heads per KV head. + ortho_fn: Callable (2d_matrix) -> 2d_matrix applying NS + scaling. + + Returns: + orthogonal_update: Same shape as input, each head orthogonalised. + """ + # Interleaved layout: [Q_kv0, K0, V0, Q_kv1, K1, V1, ...] 
+ head_dim = matrix_2d_global.shape[1] // ( + num_key_value_groups * kv_head_num + 2 * kv_head_num + ) + groups = paddle.split(matrix_2d_global, kv_head_num, axis=1) + + processed_groups = [] + for group in groups: + q_part, k_head, v_head = paddle.split( + group, + [num_key_value_groups * head_dim, head_dim, head_dim], + axis=1, + ) + q_heads = paddle.split(q_part, num_key_value_groups, axis=1) + q_ortho = paddle.concat([ortho_fn(h) for h in q_heads], axis=1) + processed_groups.append( + paddle.concat( + [q_ortho, ortho_fn(k_head), ortho_fn(v_head)], axis=1 + ) + ) + + return paddle.concat(processed_groups, axis=1) + + @staticmethod + def _ortho_qkv_sep( + matrix_2d, + kv_head_num, + num_key_value_groups, + ortho_fn, + ): + """Orthogonalise Q, K, V as three separate whole matrices (interleaved layout). + + Gathers all Q heads into one block, all K heads into one block, all V heads + into one block (across kv_groups), orthogonalises each block as a whole with + one NS call, then scatters back to interleaved order. + + Args: + matrix_2d: QKV weight matrix [hidden, (num_key_value_groups + 2)*kv_head_num*head_dim]. + kv_head_num: Number of K/V heads. + num_key_value_groups: Number of Q heads per KV head. + ortho_fn: Callable (2d_matrix) -> 2d_matrix applying NS + scaling. + + Returns: + orthogonal_update: Same shape as input, Q/K/V each orthogonalised as whole. + """ + # Interleaved layout: [Q_kv0, K0, V0, Q_kv1, K1, V1, ...] 
+ head_dim = matrix_2d.shape[1] // ( + num_key_value_groups * kv_head_num + 2 * kv_head_num + ) + q_group_size = num_key_value_groups * head_dim + + # Step 1: gather Q / K / V parts from each kv_group + groups = paddle.split(matrix_2d, kv_head_num, axis=1) + q_parts, k_parts, v_parts = [], [], [] + for group in groups: + q_p, k_p, v_p = paddle.split( + group, [q_group_size, head_dim, head_dim], axis=1 + ) + q_parts.append(q_p) + k_parts.append(k_p) + v_parts.append(v_p) + + # Step 2: orthogonalise each projection as one whole matrix + q_ortho = ortho_fn(paddle.concat(q_parts, axis=1)) + k_ortho = ortho_fn(paddle.concat(k_parts, axis=1)) + v_ortho = ortho_fn(paddle.concat(v_parts, axis=1)) + + # Step 3: split back and restore interleaved layout + q_groups = paddle.split(q_ortho, kv_head_num, axis=1) + k_groups = paddle.split(k_ortho, kv_head_num, axis=1) + v_groups = paddle.split(v_ortho, kv_head_num, axis=1) + + return paddle.concat( + [ + paddle.concat([q_groups[i], k_groups[i], v_groups[i]], axis=1) + for i in range(kv_head_num) + ], + axis=1, + ) + + @staticmethod + def _ortho_ffn_gate_up(matrix, intermediate_size, ortho_fn): + """Orthogonalise gate and up projections independently for FFN. + + Args: + matrix: FFN weight tensor. + - 2D: [hidden, 2*intermediate_size] for standard FFN + - 3D: [num_experts, hidden, 2*intermediate_size] for MoE FFN + intermediate_size: Size of each of gate/up projections. + ortho_fn: Callable (2d_matrix) -> 2d_matrix applying NS + scaling. + + Returns: + orthogonal_update: Tensor with gate and up orthogonalised separately. 
+ """ + if matrix.ndim == 2: + gate, up = paddle.split( + matrix, [intermediate_size, intermediate_size], axis=1 + ) + return paddle.concat([ortho_fn(gate), ortho_fn(up)], axis=1) + + elif matrix.ndim == 3: + # MoE FFN: [n_experts, hidden, 2*intermediate_size] + expert_updates = [] + for ei in range(matrix.shape[0]): + gate, up = paddle.split( + matrix[ei], [intermediate_size, intermediate_size], axis=1 + ) + expert_updates.append( + paddle.concat([ortho_fn(gate), ortho_fn(up)], axis=1) + ) + return paddle.stack(expert_updates, axis=0) + + else: + raise ValueError( + f"FFN gate_up split expects 2D or 3D tensor, got shape {matrix.shape}" + ) + + @staticmethod + def _ortho_mla_per_head( + matrix_2d_global, + head_num, + ortho_fn, + axis, + ): + """Orthogonalise each MLA head independently.""" + groups = paddle.split(matrix_2d_global, head_num, axis=axis) + + processed_groups = [] + for group in groups: + processed_groups.append(ortho_fn(group)) + + return paddle.concat(processed_groups, axis=axis) + + # ------------------------------------------------------------------ + # Per-parameter update rules + # ------------------------------------------------------------------ + + @staticmethod + def _adamw_update( + param, + grad, + lr, + moment1, + moment2, + beta1_pow, + beta2_pow, + beta1, + beta2, + epsilon, + weight_decay, + ): + """In-place AdamW update for 1-D sharded parameters.""" + with paddle.no_grad(): + beta1_pow.scale_(beta1) + beta2_pow.scale_(beta2) + + if weight_decay > 0: + param.scale_(1.0 - lr * weight_decay) + + grad_f32 = ( + grad.astype(paddle.float32) + if grad.dtype != paddle.float32 + else grad + ) + + moment1.scale_(beta1).add_(grad_f32, alpha=1.0 - beta1) + moment2.scale_(beta2).add_( + paddle.square(grad_f32), alpha=1.0 - beta2 + ) + + bias1 = 1.0 - beta1_pow + bias2 = 1.0 - beta2_pow + update = ( + (moment1 / bias1) + / ((paddle.sqrt(moment2) / paddle.sqrt(bias2)) + epsilon) + * lr + ) + + if update.dtype != param.dtype: + update = 
update.astype(param.dtype) + + if hasattr(param, "subtract_"): + param.subtract_(update) + else: + paddle.assign(param - update, param) + + def _muon_update( + self, + param, + grad, + lr, + momentum_buffer, + momentum_beta, + ns_steps, + nesterov, + epsilon, + weight_decay, + version, + ): + """In-place Muon update for a 2D parameter tensor. + + Applies Newton-Schulz orthogonalisation to the 2D weight matrix and + updates the parameter in-place. MuonShardingOptimizer assigns whole + 2D tensors to ranks, so no sharding gather or TP communication is needed. + """ + param_shape = getattr(param, "original_shape", param.shape) + param_info = self._muon_param_info_map.get(param.name) + is_qkv = param_info is not None and param_info.is_qkv + is_mla: bool = param_info is not None and param_info.is_mla + is_ffn_gate_up = param_info is not None and param_info.is_ffn_gate_up + + with paddle.no_grad(): + grad_f32 = ( + grad.astype(momentum_buffer.dtype) + if grad.dtype != momentum_buffer.dtype + else grad + ) + + # Step 1: Momentum update + new_momentum = paddle.lerp( + momentum_buffer, grad_f32, 1.0 - momentum_beta + ) + paddle.assign(new_momentum, momentum_buffer) + update_buffer = ( + paddle.lerp(grad_f32, momentum_buffer, momentum_beta) + if nesterov + else momentum_buffer + ) + + # Step 2: Reshape update buffer to 2D matrix. + # MuonShardingOptimizer assigns whole 2D tensors to ranks, so params + # are already 2D/3D (no sharding gather needed). 
    def _muon_update(
        self,
        param,
        grad,
        lr,
        momentum_buffer,
        momentum_beta,
        ns_steps,
        nesterov,
        epsilon,
        weight_decay,
        version,
    ):
        """In-place Muon update for a 2D parameter tensor.

        Applies Newton-Schulz orthogonalisation to the 2D weight matrix and
        updates the parameter in-place. MuonShardingOptimizer assigns whole
        2D tensors to ranks, so no sharding gather or TP communication is
        needed.

        Args:
            param: Parameter to update (2D, or 3D fused-expert tensor).
            grad: Gradient tensor (same logical shape as param).
            lr: Scalar learning rate for this step.
            momentum_buffer: Persistent moment1 accumulator (fp32).
            momentum_beta: Momentum coefficient.
            ns_steps: Newton-Schulz iteration count.
            nesterov: Use Nesterov-style lookahead on the momentum.
            epsilon: Numerical-stability constant for NS normalisation.
            weight_decay: Decoupled weight-decay coefficient.
            version: Scaling-function version passed to _scaling_fn.
        """
        param_shape = getattr(param, "original_shape", param.shape)
        param_info = self._muon_param_info_map.get(param.name)
        is_qkv = param_info is not None and param_info.is_qkv
        is_mla: bool = param_info is not None and param_info.is_mla
        is_ffn_gate_up = param_info is not None and param_info.is_ffn_gate_up

        with paddle.no_grad():
            grad_f32 = (
                grad.astype(momentum_buffer.dtype)
                if grad.dtype != momentum_buffer.dtype
                else grad
            )

            # Step 1: Momentum update.
            # lerp(buf, g, 1-beta) == beta*buf + (1-beta)*g; written back into
            # the persistent buffer. With nesterov, the applied update looks
            # one step ahead: lerp(g, buf, beta).
            new_momentum = paddle.lerp(
                momentum_buffer, grad_f32, 1.0 - momentum_beta
            )
            paddle.assign(new_momentum, momentum_buffer)
            update_buffer = (
                paddle.lerp(grad_f32, momentum_buffer, momentum_beta)
                if nesterov
                else momentum_buffer
            )

            # Step 2: Reshape update buffer to 2D matrix.
            # MuonShardingOptimizer assigns whole 2D tensors to ranks, so params
            # are already 2D/3D (no sharding gather needed).
            matrix_2d_global = update_buffer.reshape(param_shape)

            # Shared NS + scaling closure (captures ns_steps, epsilon, version,
            # ns_coeff_type).
            def ortho_fn(m):
                ns_out = Muon._zeropower_via_newtonschulz5(
                    m,
                    steps=ns_steps,
                    eps=epsilon,
                    ns_coeff_type=self._ns_coeff_type,
                )
                scaled = Muon._scaling_fn(
                    ns_out, version, self._muon_extra_scale_factor
                )
                return scaled

            # Step 3: Newton-Schulz orthogonalisation, dispatched on the
            # parameter's layout metadata.
            if is_ffn_gate_up and self._muon_ffn_split:
                # FFN gate_up split: orthogonalise gate and up projections
                # independently.
                intermediate_size = param_info.intermediate_size
                if MUON_DEBUG:
                    _global_rank = paddle.distributed.get_rank()
                    if _global_rank == 0:
                        _logger.info(
                            f"[Muon] FFN split: param={param.name}, "
                            f"shape={matrix_2d_global.shape}, "
                            f"intermediate_size={intermediate_size}"
                        )

                orthogonal_update = Muon._ortho_ffn_gate_up(
                    matrix_2d_global, intermediate_size, ortho_fn
                )
            elif matrix_2d_global.ndim == 3:
                # 3D fused MoE expert tensor [n_experts, H, I].
                # Apply Newton-Schulz independently to each expert's 2D slice.
                n_experts = matrix_2d_global.shape[0]
                orthogonal_update = paddle.stack(
                    [ortho_fn(matrix_2d_global[ei]) for ei in range(n_experts)],
                    axis=0,
                )
            elif is_qkv and self._muon_qkv_update_mode in (
                "split_head",
                "split_qkv",
            ):
                # Read QKV head info from param_info
                qkv_info = param_info.qkv_info
                kv_head_num = qkv_info.kv_head_num
                num_key_value_groups = qkv_info.num_key_value_groups

                if self._muon_qkv_update_mode == "split_head":
                    # split_head update: each Q/K/V head orthogonalised
                    # independently.
                    if MUON_DEBUG:
                        _global_rank = paddle.distributed.get_rank()
                        if _global_rank == 0:
                            _logger.info(
                                f"[Muon] QKV split_head: param={param.name}, "
                                f"shape={matrix_2d_global.shape}, "
                                f"heads={qkv_info.head_num}/{kv_head_num}, "
                                f"num_key_value_groups={num_key_value_groups}"
                            )
                    orthogonal_update = Muon._ortho_qkv_per_head(
                        matrix_2d_global,
                        kv_head_num,
                        num_key_value_groups,
                        ortho_fn,
                    )
                else:
                    # split_qkv: Q, K, V each as a whole matrix, one NS call
                    # each.
                    if MUON_DEBUG:
                        _global_rank = paddle.distributed.get_rank()
                        if _global_rank == 0:
                            _logger.info(
                                f"[Muon] QKV split_qkv: param={param.name}, "
                                f"shape={matrix_2d_global.shape}, "
                                f"head_num={qkv_info.head_num}, kv_head_num={kv_head_num}, "
                                f"num_key_value_groups={num_key_value_groups}"
                            )
                    orthogonal_update = Muon._ortho_qkv_sep(
                        matrix_2d_global,
                        kv_head_num,
                        num_key_value_groups,
                        ortho_fn,
                    )
            elif is_mla and self._muon_qkv_update_mode == "split_head":
                # MLA split_head update: each head of [q_b_proj, kv_b_proj,
                # o_proj] orthogonalised independently.
                mla_info = param_info.mla_info
                param_name: str = mla_info.param_name
                head_num = mla_info.head_num
                if MUON_DEBUG:
                    _global_rank = paddle.distributed.get_rank()
                    if _global_rank == 0:
                        _logger.info(
                            f"[Muon] MLA split_head: param={param.name}, param_name={param_name}, "
                            f"shape={matrix_2d_global.shape}, "
                            f"head_num={head_num}"
                        )
                assert param_name in ("q_b_proj", "kv_b_proj", "o_proj"), (
                    f"Unsupported MLA param name: {param_name}"
                )
                # o_proj heads lie along axis 0; q_b_proj/kv_b_proj along 1.
                orthogonal_update = Muon._ortho_mla_per_head(
                    matrix_2d_global,
                    head_num,
                    ortho_fn,
                    0 if param_name == "o_proj" else 1,
                )
            else:
                # Standard 2D update: entire matrix as one Newton-Schulz call.
                orthogonal_update = ortho_fn(matrix_2d_global)

            # Step 4: Apply update with optional (decoupled) weight decay.
            if weight_decay > 0:
                param.scale_(1.0 - lr * weight_decay)

            final_step = orthogonal_update * lr
            if final_step.dtype != param.dtype:
                final_step = final_step.astype(param.dtype)

            if hasattr(param, "subtract_"):
                param.subtract_(final_step)
            else:
                paddle.assign(param - final_step, param)
self._master_weights[param.name])
+                    del _cast_tmp
+
+        # --- Pass 2: AdamW updates ---
+        for param, grad in adamw_params:
+            self._adamw_update(
+                param,
+                grad,
+                lr,
+                self._get_accumulator(self._moment_acc_str, param),
+                self._get_accumulator(self._moment2_acc_str, param),
+                self._get_accumulator(self._beta1_pow_acc_str, param),
+                self._get_accumulator(self._beta2_pow_acc_str, param),
+                group.get("adam_beta1", 0.9),
+                group.get("adam_beta2", 0.95),
+                group.get("epsilon", 1e-9),
+                wd,
+            )
+            if self._multi_precision and param.name in self._master_weights:
+                with paddle.no_grad():
+                    _cast_tmp = paddle.cast(param, paddle.float32)
+                    paddle.assign(_cast_tmp, self._master_weights[param.name])
+                    del _cast_tmp
+
+    @framework.dygraph_only
+    def step(self) -> None:
+        params_grads = [
+            (param, param._grad_ivar())
+            for param in self._parameter_list
+            if not param.stop_gradient and param._grad_ivar() is not None
+        ]
+        self._apply_optimize(
+            loss=None, startup_program=None, params_grads=params_grads
+        )
+
+    def sharded_state_dict(
+        self,
+        model_sharded_state_dict: ShardedStateDict,
+    ) -> ShardedStateDict:
+        """Build a sharded optimizer state dict for flex checkpoint save/load.
+
+        The layout mirrors :class:`paddle.optimizer.AdamW`'s implementation so
+        that the same ``dist.save_state_dict`` / ``dist.load_state_dict`` path
+        works for Muon checkpoints.
+
+        Args:
+            model_sharded_state_dict: Sharded model state dict produced by
+                ``model.sharded_state_dict()``.
+
+        Returns:
+            A dict mapping ``"<struct_name>.<state_type>"`` keys to
+            :class:`ShardedWeight` objects.
+ """ + _FP32_MASTER = "fp32_master_0" + _optimizer_scalar_names = [ + "beta1_pow_acc_0", + "beta2_pow_acc_0", + ] + _optimizer_vector_names = [ + "moment1_0", + "moment2_0", + ] + + def _split_state_name(vname): + if _FP32_MASTER in vname: + return tuple(vname.split("_" + _FP32_MASTER + "_", 1)) + for suffix in _optimizer_scalar_names + _optimizer_vector_names: + if vname.endswith(suffix): + return vname[: -(len(suffix) + 1)], suffix + raise ValueError( + f"Cannot parse optimizer state variable name: {vname!r}" + ) + + model_sharded_state_dict = dict( + sorted(model_sharded_state_dict.items()) + ) + + # Build static-name → struct-name mapping (handles shared weights) + static_to_struct = {} + for struct_name, sw in model_sharded_state_dict.items(): + local_name = sw.local_tensor.name + if local_name not in static_to_struct: + static_to_struct[local_name] = struct_name + + optimizer_state_dict = self.state_dict() + master_weights = optimizer_state_dict.pop("master_weights", None) + optimizer_state_dict.pop("LR_Scheduler", None) + + sharded_state: ShardedStateDict = {} + + # Optimizer states (moment1, moment2, beta_pow scalars) + for key, tensor in optimizer_state_dict.items(): + static_name, state_type = _split_state_name(key) + struct_name = static_to_struct[static_name] + sharded_param = model_sharded_state_dict[struct_name] + unified_name = f"{struct_name}.{state_type}" + + if state_type in _optimizer_vector_names: + # Vector states share the same sharding layout as the parameter + if tensor.is_dist(): + sharded_state[unified_name] = ShardedWeight( + key=unified_name, + local_tensor=tensor, + local_shape=tensor.shape, + global_shape=tensor.shape, + global_offset=sharded_param.global_offset, + ) + else: + # Reshape accumulator if numel matches but shape differs. + # MoE: grouped_gemm_experts param.shape is 3D + # [n_experts, H, I] but model.state_dict() returns actual + # C++ storage shape 2D [n_experts*H, I]. 
moment1 was + # created with 3D shape, so we need to reshape here. + # V2 is unaffected: its moments are always 1D shards, + # so shape always matches and reshape is never triggered. + target_shape = sharded_param.local_shape + if ( + tuple(tensor.shape) != tuple(target_shape) + and tensor.numel() + == paddle.to_tensor(list(target_shape)).prod().item() + ): + tensor = tensor.reshape(target_shape) + sharded_state[unified_name] = ( + create_sharded_weight_with_new_local( + unified_name, tensor, sharded_param + ) + ) + else: + # Scalar states (beta_pow) are replicated – save as-is + sharded_state[unified_name] = ShardedWeight( + key=unified_name, + local_tensor=tensor, + local_shape=(1,), + global_shape=(1,), + global_offset=(0,), + ) + + # FP32 master weights + if master_weights: + for weight_key, tensor in master_weights.items(): + struct_name = static_to_struct[weight_key] + sharded_param = model_sharded_state_dict[struct_name] + unified_name = f"{struct_name}.w_0" + + if tensor.is_dist(): + sharded_state[unified_name] = ShardedWeight( + key=unified_name, + local_tensor=tensor, + local_shape=tensor.shape, + global_shape=tensor.shape, + global_offset=sharded_param.global_offset, + ) + else: + sharded_state[unified_name] = ( + create_sharded_weight_with_new_local( + unified_name, tensor, sharded_param + ) + ) + + return sharded_state diff --git a/test/collective/fleet/CMakeLists.txt b/test/collective/fleet/CMakeLists.txt index aafcd874f001f..96bbf7f4777e3 100644 --- a/test/collective/fleet/CMakeLists.txt +++ b/test/collective/fleet/CMakeLists.txt @@ -946,3 +946,19 @@ if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) test_schedule_node MODULES test_schedule_node ENVS "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") endif() +if(WITH_NCCL) + if((WITH_GPU) AND LOCAL_ALL_PLAT) + bash_test_modules( + test_parallel_dygraph_muon + START_BASH + ../../legacy_test/dist_test.sh + TIMEOUT + "400" + LABELS + "RUN_TYPE=DIST" + ENVS + 
"PADDLE_DIST_UT_PORT=22348;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python" + ) + set_tests_properties(test_parallel_dygraph_muon PROPERTIES TIMEOUT "400") + endif() +endif() diff --git a/test/collective/fleet/hybrid_parallel_sharding_muon_model.py b/test/collective/fleet/hybrid_parallel_sharding_muon_model.py new file mode 100644 index 0000000000000..a478e8e155cdb --- /dev/null +++ b/test/collective/fleet/hybrid_parallel_sharding_muon_model.py @@ -0,0 +1,369 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Validates Muon optimizer with MuonShardingOptimizer. +# Muon requires whole 2D tensors for orthogonalization, so split_param is disabled. +# Tests all combinations of QKV/FFN/ns_coeff_type modes. 
+# Topology: sharding_degree=2, mp_degree=1 (2 ranks total) + +import random +import unittest +from dataclasses import dataclass + +import numpy as np + +import paddle +from paddle.distributed import fleet +from paddle.distributed.fleet.utils import mix_precision_utils +from paddle.optimizer.muon import ( + MuonParamInfo, + QKVInfo, + _default_should_use_muon, +) + +# Parameter combinations +QKV_UPDATE_MODES = ["split_head", "split_qkv", "fused_qkv"] +FFN_SPLITS = [True, False] +NS_COEFF_TYPES = ["simple", "quintic", "polar_express", "aol"] + +# Model config +vocab_size = 20 +hidden_size = 64 +head_num = 4 # num_attention_heads +kv_head_num = 2 # num_key_value_heads (GQA) +head_dim = hidden_size // head_num +intermediate_size = 128 +qkv_dim = (head_num + 2 * kv_head_num) * head_dim +seq_length = 2 +batch_size = 4 +STEPS = 3 + +sharding_degree = 2 + + +@dataclass +class TestConfig: + """Test model config.""" + + vocab_size: int = vocab_size + hidden_size: int = hidden_size + head_num: int = head_num + kv_head_num: int = kv_head_num + head_dim: int = head_dim + intermediate_size: int = intermediate_size + qkv_dim: int = qkv_dim + + +class QKVFFNNet(paddle.nn.Layer): + """Test model with QKV and FFN gate_up. 
+ + Parameter naming follows PaddleFormers: + - qkv_proj.weight: QKV fused weights + - up_gate_proj.weight: FFN gate_up fused weights + """ + + def __init__( + self, + config, + np_embed, + np_qkv, + np_o_proj, + np_up_gate, + np_down_proj, + np_lm_head, + ): + super().__init__() + self.config = config + + self.embed_tokens = paddle.nn.Embedding( + config.vocab_size, + config.hidden_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_embed) + ), + ) + + self.qkv_proj = paddle.nn.Linear( + config.hidden_size, + config.qkv_dim, + bias_attr=False, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_qkv) + ), + ) + + self.o_proj = paddle.nn.Linear( + config.head_num * config.head_dim, + config.hidden_size, + bias_attr=False, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_o_proj) + ), + ) + + self.up_gate_proj = paddle.nn.Linear( + config.hidden_size, + 2 * config.intermediate_size, + bias_attr=False, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_up_gate) + ), + ) + + self.down_proj = paddle.nn.Linear( + config.hidden_size, + config.hidden_size, + bias_attr=False, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_down_proj) + ), + ) + + self.lm_head = paddle.nn.Linear( + config.hidden_size, + config.vocab_size, + bias_attr=False, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_lm_head) + ), + ) + + def forward(self, x): + # Embedding + h = self.embed_tokens(x) # [batch, seq, hidden] + + # QKV projection (simplified: no real attention) + qkv = self.qkv_proj(h) # [batch, seq, qkv_dim] + # Simplified: mean pooling to simulate attention output + out = qkv.mean(axis=-1, keepdim=True) # [batch, seq, 1] + out = out.expand([x.shape[0], x.shape[1], self.config.hidden_size]) + + # Output projection + out = self.o_proj(out) # [batch, seq, hidden] 
+ + # FFN up_gate projection + gate_up = self.up_gate_proj(out) # [batch, seq, 2*intermediate] + # Simplified: mean pooling + out = gate_up.mean(axis=-1, keepdim=True) + out = out.expand([x.shape[0], x.shape[1], self.config.hidden_size]) + + # Down projection + out = self.down_proj(out) # [batch, seq, hidden] + + # LM head + logits = self.lm_head(out) # [batch, seq, vocab] + + return logits + + +class TestDistShardingMuonTraining(unittest.TestCase): + def setUp(self): + random.seed(2021) + np.random.seed(2021) + paddle.seed(2021) + + self.config = TestConfig() + + self.strategy = fleet.DistributedStrategy() + self.strategy.hybrid_configs = { + "sharding_degree": sharding_degree, + "dp_degree": 1, + "mp_degree": 1, + "pp_degree": 1, + } + self.strategy.use_muon_sharding = True + + fleet.init(is_collective=True, strategy=self.strategy) + self.data = [ + np.random.randint(0, vocab_size, (batch_size, seq_length)) + for _ in range(STEPS) + ] + + def train_batch(self, batch, model, optimizer): + with paddle.amp.auto_cast(dtype='bfloat16'): + output = model(batch) + loss = output.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss + + def build_optimizer(self, model, qkv_mode, ffn_split, ns_coeff): + """Build Muon optimizer, ref: PaddleFormers trainer.py L3122-3173.""" + + muon_param_info_map = {} + exclude_patterns = ["embed", "bias", "lm_head"] + + for name, param in model.named_parameters(): + use_muon = _default_should_use_muon( + name, param.shape, exclude_patterns + ) + + # QKV params: set QKVInfo + if "qkv_proj.weight" in name and len(param.shape) == 2: + param_info = MuonParamInfo( + use_muon=use_muon, + qkv_info=QKVInfo( + head_num=self.config.head_num, + kv_head_num=self.config.kv_head_num, + num_key_value_groups=self.config.head_num + // self.config.kv_head_num, + ), + ) + # FFN gate_up params: set intermediate_size + elif "up_gate_proj.weight" in name and ffn_split: + param_info = MuonParamInfo( + use_muon=use_muon, + 
intermediate_size=self.config.intermediate_size, + ) + else: + param_info = MuonParamInfo(use_muon=use_muon) + + muon_param_info_map[param.name] = param_info + return paddle.optimizer.Muon( + parameters=model.parameters(), + learning_rate=0.001, + weight_decay=0.00001, + grad_clip=paddle.nn.ClipGradByGlobalNorm(0.5), + muon_param_info_map=muon_param_info_map, + muon_qkv_update_mode=qkv_mode, + muon_ffn_split=ffn_split, + ns_coeff_type=ns_coeff, + ) + + def _run_single_test(self, qkv_mode, ffn_split, ns_coeff): + """Run single test combination.""" + # Init weights + np_embed = np.random.random_sample((vocab_size, hidden_size)) + np_qkv = np.random.random_sample((hidden_size, qkv_dim)) + np_o_proj = np.random.random_sample((head_num * head_dim, hidden_size)) + np_up_gate = np.random.random_sample( + (hidden_size, 2 * intermediate_size) + ) + np_down_proj = np.random.random_sample((hidden_size, hidden_size)) + np_lm_head = np.random.random_sample((hidden_size, vocab_size)) + + # Distributed model + model_a = QKVFFNNet( + self.config, + np_embed, + np_qkv, + np_o_proj, + np_up_gate, + np_down_proj, + np_lm_head, + ) + model_a = mix_precision_utils.MixPrecisionLayer( + model_a, dtype="bfloat16" + ) + model_a = paddle.amp.decorate( + models=model_a, level='O2', dtype='bfloat16' + ) + optimizer_a = self.build_optimizer( + model_a, qkv_mode, ffn_split, ns_coeff + ) + + # Single-GPU reference model (same MixPrecisionLayer pattern for consistency) + model_b = QKVFFNNet( + self.config, + np_embed, + np_qkv, + np_o_proj, + np_up_gate, + np_down_proj, + np_lm_head, + ) + model_b = mix_precision_utils.MixPrecisionLayer( + model_b, dtype="bfloat16" + ) + model_b = paddle.amp.decorate( + models=model_b, level='O2', dtype='bfloat16' + ) + optimizer_b = self.build_optimizer( + model_b, qkv_mode, ffn_split, ns_coeff + ) + optimizer_b = mix_precision_utils.MixPrecisionOptimizer(optimizer_b) + + # Distributed wrapper + model_a = fleet.distributed_model(model_a) + optimizer_a = 
fleet.distributed_optimizer(optimizer_a) + + hcg = fleet.get_hybrid_communicate_group() + sharding_rank = hcg.get_sharding_parallel_rank() + local_batch_size = batch_size // sharding_degree + + for idx in range(STEPS): + start = sharding_rank * local_batch_size + batch_a = paddle.to_tensor( + self.data[idx][start : start + local_batch_size] + ) + batch_b = paddle.to_tensor(self.data[idx]) + + loss_a = self.train_batch(batch_a, model_a, optimizer_a) + loss_b = self.train_batch(batch_b, model_b, optimizer_b) + + # Verify param consistency + for param_a, param_b in zip( + model_a.parameters(), model_b.parameters() + ): + a_fp32 = param_a.cast('float32').numpy() + b_fp32 = param_b.cast('float32').numpy() + np.testing.assert_allclose( + a_fp32, + b_fp32, + atol=1e-3, + err_msg=f"Param {param_a.name} mismatch at step {idx}!", + ) + + @unittest.skipIf( + not paddle.is_compiled_with_cuda() + or paddle.device.cuda.get_device_capability()[0] < 8, + "BF16 matmul requires GPU compute capability >= 80 (Ampere+)", + ) + def test_sharding_muon(self): + """Test all 24 parameter combinations.""" + total = len(QKV_UPDATE_MODES) * len(FFN_SPLITS) * len(NS_COEFF_TYPES) + passed = 0 + failed = [] + + for qkv_mode in QKV_UPDATE_MODES: + for ffn_split in FFN_SPLITS: + for ns_coeff in NS_COEFF_TYPES: + print( + f"\n[Muon Test] qkv_mode={qkv_mode}, ffn_split={ffn_split}, ns_coeff={ns_coeff}" + ) + try: + self._run_single_test(qkv_mode, ffn_split, ns_coeff) + passed += 1 + print(f"[PASS] {qkv_mode}, {ffn_split}, {ns_coeff}") + except Exception as e: + failed.append((qkv_mode, ffn_split, ns_coeff, str(e))) + print( + f"[FAIL] {qkv_mode}, {ffn_split}, {ns_coeff}: {e}" + ) + + print(f"\n{'=' * 60}") + print(f"Muon Sharding Test Summary: {passed}/{total} passed") + if failed: + print("Failed combinations:") + for qkv, ffn, ns, err in failed: + print(f" - {qkv}, {ffn}, {ns}: {err[:100]}...") + print(f"{'=' * 60}") + + if failed: + raise AssertionError(f"{len(failed)} test combinations failed") + 
+ +if __name__ == "__main__": + unittest.main() diff --git a/test/collective/fleet/test_parallel_dygraph_muon.py b/test/collective/fleet/test_parallel_dygraph_muon.py new file mode 100644 index 0000000000000..c908087b465fe --- /dev/null +++ b/test/collective/fleet/test_parallel_dygraph_muon.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from legacy_test.test_parallel_dygraph_dataparallel import ( + TestMultipleAccelerators, +) + +import paddle + + +class TestMuonParallel(TestMultipleAccelerators): + @unittest.skipIf( + not paddle.is_compiled_with_cuda() + or paddle.device.cuda.get_device_capability()[0] < 8, + "BF16 matmul requires GPU compute capability >= 80 (Ampere+)", + ) + def test_muon_sharding_optimizer(self): + """MuonSharding test: iterate all QKV/FFN/ns_coeff_type combinations. + + Test logic is in hybrid_parallel_sharding_muon_model.py, + iterating 24 combinations (3 qkv_modes * 2 ffn_splits * 4 ns_coeff_types). + """ + self.run_mnist_2accelerators('hybrid_parallel_sharding_muon_model.py') + + +if __name__ == "__main__": + unittest.main()