Commit 50549b3

xxyux and claude committed
feat: add Muon optimizer with distributed sharding support
Add Muon optimizer implementation with Newton-Schulz orthogonalization for distributed training:

- Muon optimizer (python/paddle/optimizer/muon.py):
  - Newton-Schulz iteration for orthogonal gradient updates
  - QKV split modes: per_head, qkv_sep, full
  - FFN gate_up split support
  - Multiple NS coefficient types: simple, quintic, polar_express, aol
- MuonShardingOptimizer:
  - Whole-tensor assignment for 2D parameters (Muon)
  - Element-wise sharding for non-2D parameters (AdamW)
  - Hybrid memory balancing across ranks
- Test coverage:
  - All 24 parameter combinations tested
  - 2-GPU sharding validation against single-GPU reference

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 87bf071 commit 50549b3
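For context on the commit description above, here is a minimal sketch of the Newton-Schulz orthogonalization step it refers to, assuming the widely published quintic coefficients (3.4445, -4.7750, 2.0315); the actual constants, the QKV/FFN split handling, and the exact normalization in python/paddle/optimizer/muon.py may differ:

import paddle

def newton_schulz_orthogonalize(grad, steps=5):
    # Illustrative only: these are the commonly published quintic coefficients;
    # the commit's simple/quintic/polar_express/aol variants may use other constants.
    a, b, c = 3.4445, -4.7750, 2.0315
    x = grad.astype('float32')
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        # Iterate on the wide orientation so X @ X^T is the smaller Gram matrix.
        x = paddle.transpose(x, [1, 0])
    # Scale so the spectral norm is <= 1 (the Frobenius norm bounds it from above).
    x = x / (paddle.sqrt((x * x).sum()) + 1e-7)
    for _ in range(steps):
        s = paddle.matmul(x, x, transpose_y=True)                    # S = X X^T
        x = a * x + paddle.matmul(b * s + c * paddle.matmul(s, s), x)
    if transposed:
        x = paddle.transpose(x, [1, 0])
    return x

In Muon, the momentum-accumulated gradient of each 2-D weight is passed through an iteration of this kind before the update, which is presumably why the sharding optimizer assigns 2-D parameters to ranks whole rather than slicing them element-wise.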

11 files changed: +767 additions, -1107 deletions


python/paddle/distributed/fleet/base/distributed_strategy.py

Lines changed: 2 additions & 0 deletions
@@ -337,6 +337,8 @@ def __init__(self) -> None:
         ]
         self.sync_param_name: list[str] = ["embedding", "layer_norm", ".b_"]
 
+        self.use_muon_sharding: bool = False
+
         self.__lock_attr = True
         logger.info("distributed strategy initialized")
 
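Since the flag is read with getattr on the strategy object in hybrid_parallel_optimizer.py further down, enabling it from user code would presumably be a one-line opt-in. A minimal sketch, assuming fleet.DistributedStrategy exposes the attribute exactly as added above:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
# Opt in to the Muon-aware sharding path; per the diff above it defaults to False.
# Assumes the attribute is settable on this strategy object after construction.
strategy.use_muon_sharding = True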

python/paddle/distributed/fleet/meta_optimizers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -38,3 +38,4 @@
 from .recompute_optimizer import RecomputeOptimizer # noqa: F401
 from .sharding_optimizer import ShardingOptimizer # noqa: F401
 from .tensor_parallel_optimizer import TensorParallelOptimizer # noqa: F401
+from .muon_sharding_optimizer import MuonShardingOptimizer # noqa: F401

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py

Lines changed: 0 additions & 26 deletions
@@ -1252,17 +1252,8 @@ def step(self):
         self._collect_comm_buffers()
         self._assign_slice_grad()
 
-        # Detect Muon by walking the wrapper chain; use name comparison to avoid
-        # a hard circular import.
-        core_opt = self._inner_opt
-        while hasattr(core_opt, '_inner_opt'):
-            core_opt = core_opt._inner_opt
-        is_muon = type(core_opt).__name__ == 'Muon'
-
         if not isinstance(self._parameter_list[0], dict):
             params_grads = []
-            # Build name→original-param map so Muon can recover full 2-D shape.
-            global_param_map = {p.name: p for p in self._parameter_list}
             for param in self._parameter_list:
                 if (
                     hasattr(param, "regularizer")
@@ -1280,25 +1271,8 @@ def step(self):
                 if hasattr(param, "main_grad") and param.main_grad is not None:
                     grad_var = param.main_grad
                 if grad_var is not None:
-                    if is_muon:
-                        from .muon_sharding_annotations import (
-                            annotate_muon_params,
-                        )
-
-                        original_p = global_param_map[param.name]
-                        if not annotate_muon_params(
-                            param, original_p, self._hcg, self.param2bucket
-                        ):
-                            continue
-
                     params_grads.append((param, grad_var))
 
-            if is_muon and params_grads:
-                from .muon_sharding_annotations import (
-                    sort_muon_params_grads,
-                )
-
-                sort_muon_params_grads(params_grads)
         if self._enable_timer:
             self.timers("apply-optimize").start()
 

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py

Lines changed: 7 additions & 7 deletions
@@ -23,8 +23,8 @@
     DygraphShardingOptimizer,
     DygraphShardingOptimizerV2,
 )
-from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer_v3 import (
-    DygraphShardingOptimizerV3,
+from paddle.distributed.fleet.meta_optimizers.muon_sharding_optimizer import (
+    MuonShardingOptimizer,
 )
 from paddle.distributed.fleet.utils.hybrid_parallel_util import (
     obtain_optimizer_parameters_list,
@@ -287,9 +287,9 @@ def __init__(self, optimizer, hcg, strategy):
             split_param = strategy.hybrid_configs[
                 'sharding_configs'
             ].split_param
-            use_sharding_v3 = os.environ.get("FLAGS_sharding_v3", "0") == "1"
-            if use_sharding_v3 and split_param:
-                ShardingOptimizer = DygraphShardingOptimizerV3
+            use_muon_sharding = getattr(strategy, "use_muon_sharding", False)
+            if use_muon_sharding:
+                ShardingOptimizer = MuonShardingOptimizer
             elif split_param:
                 ShardingOptimizer = DygraphShardingOptimizerV2
             else:
@@ -340,7 +340,7 @@ def __init__(self, optimizer, hcg, strategy):
                 MixPrecisionOptimizer,
                 DygraphShardingOptimizer,
                 DygraphShardingOptimizerV2,
-                DygraphShardingOptimizerV3,
+                MuonShardingOptimizer,
             ),
         )
 
@@ -637,7 +637,7 @@ def _hybrid_sync_grad(self, parameter_list):
             (
                 DygraphShardingOptimizer,
                 DygraphShardingOptimizerV2,
-                DygraphShardingOptimizerV3,
+                MuonShardingOptimizer,
             ),
         )
         self._inner_opt.reduce_gradients(parameter_list, self._hcg)
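A hedged end-to-end sketch of how the branch above would be reached in a dygraph hybrid-parallel run. The Muon constructor arguments and the parallel degrees are illustrative assumptions; the real signature lives in python/paddle/optimizer/muon.py from this commit:

import paddle
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.use_muon_sharding = True     # selects MuonShardingOptimizer in the branch above
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 2,             # 2-GPU sharding, as in the commit's tests
}

fleet.init(is_collective=True, strategy=strategy)

model = fleet.distributed_model(paddle.nn.Linear(1024, 1024))

# Constructor arguments are illustrative assumptions, not the documented API.
opt = paddle.optimizer.Muon(learning_rate=0.02, parameters=model.parameters())
opt = fleet.distributed_optimizer(opt)

for _ in range(3):
    loss = model(paddle.randn([8, 1024])).mean()
    loss.backward()
    opt.step()                        # dispatches through HybridParallelOptimizer
    opt.clear_grad()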

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/muon_sharding_annotations.py

Lines changed: 0 additions & 102 deletions
This file was deleted.
