
Commit 9d17862

xxyux and claude committed
feat: add Muon optimizer with distributed sharding support
Add Muon optimizer implementation with Newton-Schulz orthogonalization for distributed training:

- Muon optimizer (python/paddle/optimizer/muon.py):
  - Newton-Schulz iteration for orthogonal gradient updates
  - QKV split modes: per_head, qkv_sep, full
  - FFN gate_up split support
  - Multiple NS coefficient types: simple, quintic, polar_express, aol
- MuonShardingOptimizer:
  - Whole-tensor assignment for 2D parameters (Muon)
  - Element-wise sharding for non-2D parameters (AdamW)
  - Hybrid memory balancing across ranks
- Test coverage:
  - All 24 parameter combinations tested
  - 2-GPU sharding validation against single-GPU reference

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
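For orientation on the first bullet group: the Newton-Schulz iteration replaces each 2D weight's momentum with an approximately semi-orthogonal matrix (the UV^T of its SVD) before the update, which is also why whole matrices become the unit of sharding below. Here is a minimal sketch of the widely published quintic variant, with coefficients from the public Muon reference implementation; the coefficient tuples, function names, and dtypes in this commit's muon.py are not visible in the diff, so treat everything here as illustrative, not the commit's code.

```python
import paddle


def newton_schulz_orthogonalize(g, steps=5, eps=1e-7):
    # Iteratively pushes the singular values of a 2D gradient toward 1,
    # approximating the nearest semi-orthogonal matrix to g.
    # Quintic coefficients from the public Muon reference implementation;
    # the commit's "simple"/"polar_express"/"aol" modes use other schemes.
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g
    transposed = x.shape[0] > x.shape[1]
    if transposed:  # work in the short-fat orientation
        x = x.T
    # Frobenius normalization bounds the spectral norm by 1,
    # which the iteration needs to converge.
    x = x / (paddle.linalg.norm(x) + eps)
    for _ in range(steps):
        s = x @ x.T
        x = a * x + (b * s + c * (s @ s)) @ x
    return x.T if transposed else x


g = paddle.randn([256, 512])
o = newton_schulz_orthogonalize(g)  # o @ o.T is approximately the identity
```

In the reference formulation the orthogonalized result is then rescaled (roughly by sqrt(max(1, m/n))) and applied like an SGD-with-momentum step; how this commit scales the update is not visible from the diff.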
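On the MuonShardingOptimizer bullets: because Newton-Schulz needs the full matrix, 2D parameters cannot be sharded element-wise the way AdamW state can, so the scheme assigns whole matrices to ranks while splitting everything else flat, balancing the combined load across ranks ("hybrid memory balancing"). A hypothetical sketch of such an assignment follows; every name in it is invented for illustration and is not the commit's API.

```python
from collections import namedtuple

Param = namedtuple("Param", ["name", "shape"])


def assign_shards(params, num_ranks):
    """Greedy hybrid balancing: whole 2D tensors go to the lightest rank,
    non-2D tensors are split element-wise across every rank."""
    load = [0] * num_ranks  # optimizer-state elements held per rank
    owner = {}              # 2D param name -> owning rank (Muon)
    elementwise = []        # params sharded across all ranks (AdamW)

    def numel(p):
        n = 1
        for d in p.shape:
            n *= d
        return n

    for p in sorted(params, key=numel, reverse=True):  # largest first
        if len(p.shape) == 2:
            # Whole-tensor assignment: Muon needs the full matrix,
            # so place it on the currently least-loaded rank.
            r = load.index(min(load))
            owner[p.name] = r
            load[r] += numel(p)
        else:
            # Non-2D (biases, norms): flat, even element-wise split.
            elementwise.append(p.name)
            for r in range(num_ranks):
                load[r] += numel(p) // num_ranks
    return owner, elementwise, load


params = [
    Param("ffn.w", (4096, 11008)),
    Param("attn.qkv", (4096, 12288)),
    Param("ln.scale", (4096,)),
    Param("head.bias", (32000,)),
]
print(assign_shards(params, 2))
```

Sorting matrices largest-first before greedy placement is the standard trick for keeping per-rank memory close to even when tensor sizes vary widely.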
1 parent 0b2356c commit 9d17862

File tree

10 files changed (+2418, -7 lines)


ci/h-test.sh

Lines changed: 2 additions & 1 deletion
@@ -165,7 +165,8 @@ concurrency_list="^test_fp8_deep_gemm$|\
 ^test_scaled_dot_product_attention$|\
 ^test_compat_scaled_dot_product_attention$|\
 ^test_flash_attention$|\
-^test_batched_gemm$"
+^test_batched_gemm$|\
+^test_parallel_dygraph_muon$"
 
 cd ${work_dir}/build
 tmp_dir=$(mktemp -d)

python/paddle/distributed/fleet/base/distributed_strategy.py

Lines changed: 2 additions & 0 deletions
@@ -337,6 +337,8 @@ def __init__(self) -> None:
         ]
         self.sync_param_name: list[str] = ["embedding", "layer_norm", ".b_"]
 
+        self.use_muon_sharding: bool = False
+
         self.__lock_attr = True
         logger.info("distributed strategy initialized")
 
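Because the flag is registered in `__init__` before `__lock_attr` engages, it should be assignable like other plain strategy fields. A hedged usage sketch; whether direct assignment is the intended public toggle is an assumption:

```python
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
# Routes HybridParallelOptimizer's sharding path to MuonShardingOptimizer
# (see the hybrid_parallel_optimizer.py diff below).
strategy.use_muon_sharding = True
```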

python/paddle/distributed/fleet/meta_optimizers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@
     AdaptiveLocalSGDOptimizer,
     LocalSGDOptimizer,
 )
+from .muon_sharding_optimizer import MuonShardingOptimizer  # noqa: F401
 from .pipeline_optimizer import PipelineOptimizer  # noqa: F401
 from .ps_optimizer import ParameterServerOptimizer  # noqa: F401
 from .qat_optimizer import QATOptimizer  # noqa: F401

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py

Lines changed: 16 additions & 6 deletions
@@ -23,6 +23,9 @@
     DygraphShardingOptimizer,
     DygraphShardingOptimizerV2,
 )
+from paddle.distributed.fleet.meta_optimizers.muon_sharding_optimizer import (
+    MuonShardingOptimizer,
+)
 from paddle.distributed.fleet.utils.hybrid_parallel_util import (
     obtain_optimizer_parameters_list,
 )
@@ -284,11 +287,13 @@ def __init__(self, optimizer, hcg, strategy):
             split_param = strategy.hybrid_configs[
                 'sharding_configs'
             ].split_param
-            ShardingOptimizer = (
-                DygraphShardingOptimizerV2
-                if split_param
-                else DygraphShardingOptimizer
-            )
+            use_muon_sharding = getattr(strategy, "use_muon_sharding", False)
+            if use_muon_sharding:
+                ShardingOptimizer = MuonShardingOptimizer
+            elif split_param:
+                ShardingOptimizer = DygraphShardingOptimizerV2
+            else:
+                ShardingOptimizer = DygraphShardingOptimizer
             optimizer = ShardingOptimizer(optimizer, hcg)
 
         self._enable_timer = strategy.hybrid_configs["enable_optimizer_timer"]
@@ -335,6 +340,7 @@ def __init__(self, optimizer, hcg, strategy):
                 MixPrecisionOptimizer,
                 DygraphShardingOptimizer,
                 DygraphShardingOptimizerV2,
+                MuonShardingOptimizer,
             ),
         )
 
@@ -628,7 +634,11 @@ def _hybrid_sync_grad(self, parameter_list):
         if self._sharding_enable:
             assert isinstance(
                 self._inner_opt,
-                (DygraphShardingOptimizer, DygraphShardingOptimizerV2),
+                (
+                    DygraphShardingOptimizer,
+                    DygraphShardingOptimizerV2,
+                    MuonShardingOptimizer,
+                ),
             )
             self._inner_opt.reduce_gradients(parameter_list, self._hcg)
             dp_parameter_list = self._inner_opt.filter_parameters(
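Putting the pieces together, the new branch in `__init__` above is taken when the wrapped strategy carries the flag. A hedged end-to-end sketch under hybrid parallelism: the fleet calls are established Paddle APIs, but `paddle.optimizer.Muon` and its constructor arguments are assumptions inferred from the file path in the commit message, and the parallel degrees are illustrative.

```python
import paddle
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 2,  # 2-GPU sharding, as in the commit's tests
}
strategy.use_muon_sharding = True
fleet.init(is_collective=True, strategy=strategy)

model = paddle.nn.Linear(4096, 4096)
opt = paddle.optimizer.Muon(  # assumed export and signature; see muon.py
    learning_rate=2e-2,
    parameters=model.parameters(),
)
model = fleet.distributed_model(model)
# fleet.distributed_optimizer wraps opt in HybridParallelOptimizer, whose
# __init__ (patched above) now selects MuonShardingOptimizer.
opt = fleet.distributed_optimizer(opt)
```

Launched with, e.g., `python -m paddle.distributed.launch --gpus 0,1 train.py`, this matches the 2-GPU sharding validation described in the commit message.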
