Skip to content

Commit 40448eb

Browse files
committed
feat: add Muon optimizer support and related tests
1 parent a2ec403 commit 40448eb

File tree

8 files changed

+1277
-4
lines changed

8 files changed

+1277
-4
lines changed

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,8 +1249,17 @@ def step(self):
12491249
self._collect_comm_buffers()
12501250
self._assign_slice_grad()
12511251

1252+
# Detect Muon by walking the wrapper chain; use name comparison to avoid
1253+
# a hard circular import.
1254+
core_opt = self._inner_opt
1255+
while hasattr(core_opt, '_inner_opt'):
1256+
core_opt = core_opt._inner_opt
1257+
is_muon = type(core_opt).__name__ == 'Muon'
1258+
12521259
if not isinstance(self._parameter_list[0], dict):
12531260
params_grads = []
1261+
# Build name→original-param map so Muon can recover full 2-D shape.
1262+
global_param_map = {p.name: p for p in self._parameter_list}
12541263
for param in self._parameter_list:
12551264
if (
12561265
hasattr(param, "regularizer")
@@ -1268,8 +1277,59 @@ def step(self):
12681277
if hasattr(param, "main_grad") and param.main_grad is not None:
12691278
grad_var = param.main_grad
12701279
if grad_var is not None:
1280+
if is_muon:
1281+
# Lazy import to avoid circular dependency.
1282+
from ...utils.muon_comm_utils import get_sharding_info, should_use_muon
1283+
original_p = global_param_map[param.name]
1284+
if should_use_muon(original_p.name, original_p.shape):
1285+
# Skip uninitialised slices and shape-[1] sentinels.
1286+
if not param._is_initialized():
1287+
continue
1288+
if list(param.shape) == [1] and list(original_p.shape) != [1]:
1289+
continue
1290+
1291+
# Annotate whether this rank holds a partial shard or the full weight.
1292+
param.is_sharded_gather = int(param.numel()) < int(original_p.numel())
1293+
param.original_shape = original_p.shape
1294+
param.split_axis = getattr(original_p, "split_axis", None)
1295+
param.needs_qkv_split = getattr(original_p, "needs_qkv_split", False)
1296+
param.head_num = getattr(original_p, "head_num", 0)
1297+
param.kv_head_num = getattr(original_p, "kv_head_num", 0)
1298+
param.is_muon = True
1299+
1300+
# MoE experts use a dedicated expert-parallel sharding group.
1301+
if getattr(original_p, "no_sync", False):
1302+
sharding_group = self._hcg.get_moe_sharding_parallel_group()
1303+
else:
1304+
sharding_group = self._hcg.get_sharding_parallel_group()
1305+
1306+
sharding_rank = sharding_group.rank
1307+
if sharding_rank == -1:
1308+
sharding_rank = 0
1309+
sharding_world_size = sharding_group.nranks
1310+
1311+
if param.is_sharded_gather:
1312+
# Compute per-rank element counts for the variable-length gather.
1313+
target_buffer = self.param2bucket[param.name][0]
1314+
indices, my_offset = get_sharding_info(
1315+
target_buffer, param.name,
1316+
sharding_world_size, sharding_rank,
1317+
)
1318+
param.sharding_indices = indices
1319+
param.sharding_my_offset = my_offset
1320+
12711321
params_grads.append((param, grad_var))
12721322

1323+
if is_muon and params_grads:
1324+
import numpy as np
1325+
# Sort: largest fully-owned params first for better allocator locality.
1326+
params_grads.sort(
1327+
key=lambda x: (
1328+
getattr(x[0], "is_sharded_gather", False),
1329+
np.prod(getattr(x[0], "original_shape", [])) if getattr(x[0], "original_shape", None) else 0
1330+
),
1331+
reverse=True,
1332+
)
12731333
if self._enable_timer:
12741334
self.timers("apply-optimize").start()
12751335

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import paddle
16+
import paddle.distributed as dist
17+
from paddle.distributed.communication.batch_isend_irecv import (
18+
_coalescing_manager as batch_isend_irecv_coalescing_manager,
19+
)
20+
21+
22+
def gather_varlen(input, dst, group, all_shape_and_dtype):
    """Gather variable-length tensors from all ranks to *dst*.

    The destination rank pre-allocates a single contiguous buffer for all
    incoming data to avoid memory fragmentation from intermediate concat.
    Non-destination ranks send their local slice and return None.

    Args:
        input: Local tensor slice (may be None if this rank contributes nothing).
        dst: Global rank of the destination.
        group: The process group.
        all_shape_and_dtype: List of (shape, dtype) tuples, one per rank,
            indexed by local rank within *group*. shape is None (or
            shape[0] == 0) when a rank has no data.

    Returns:
        Concatenated 1-D tensor on the destination rank; None elsewhere.
    """
    tasks = []

    # group.ranks maps local rank -> global rank, so this compares this
    # rank's global id against the (global) destination id.
    if group.ranks[group.rank] == dst:
        # Destination: allocate one contiguous buffer and receive all slices.
        total_len = sum([s[0] for s, _ in all_shape_and_dtype if s is not None])
        # NOTE(review): assumes rank 0's dtype entry is valid even when its
        # shape is None — TODO confirm callers always populate dtype.
        dtype = all_shape_and_dtype[0][1]
        output_tensor = paddle.empty([total_len], dtype=dtype)

        task_info_list = []
        current_offset = 0

        with batch_isend_irecv_coalescing_manager(group, tasks):
            for src in range(group.nranks):
                shape = all_shape_and_dtype[src][0]
                if shape is None or shape[0] == 0:
                    # Rank contributes nothing; its offset advance is zero.
                    continue
                length = shape[0]
                if src != group.rank:
                    # Post a non-blocking receive into a staging tensor;
                    # it is copied into the output buffer after the wait.
                    recv_tensor = paddle.empty(shape, dtype=all_shape_and_dtype[src][1])
                    task = dist.irecv(recv_tensor, group.ranks[src], group=group)
                    tasks.append(task)
                    task_info_list.append((task, recv_tensor, current_offset, length))
                else:
                    # Local slice: copy directly, no communication needed.
                    output_tensor[current_offset : current_offset + length] = input
                # Offsets advance in local-rank order, matching indices built
                # by get_sharding_info.
                current_offset += length

        # Wait per receive, copy into place, and drop the staging tensor
        # promptly to keep peak memory low.
        for task, recv_tensor, offset, length in task_info_list:
            task.wait()
            output_tensor[offset : offset + length] = recv_tensor
            del recv_tensor

        return output_tensor

    else:
        # Sender: push local slice to dst and return None.
        with batch_isend_irecv_coalescing_manager(group, tasks):
            if input is not None and input.shape[0] != 0:
                task = dist.isend(input, dst, group=group)
                tasks.append(task)

        # Block until the send has completed before returning.
        for task in tasks:
            task.wait()

        return None
83+
84+
85+
def get_sharding_info(buffer, param_name, world_size, rank):
    """Compute per-rank element counts and local offset for a sharded parameter.

    ShardingV2 splits the flat param storage evenly across ranks. This
    function intersects each rank's slice of that storage with the parameter's
    global range to produce the element count each rank owns.

    Args:
        buffer: The FusedCommBuffer that contains the parameter.
        param_name: Name of the parameter.
        world_size: Number of ranks in the sharding group.
        rank: Local rank in the sharding group.

    Returns:
        indices: List of element counts per rank (length == world_size).
        my_slice_offset: Offset of this rank's slice within the full flat param.
    """
    view = buffer._sharding_param_grad_view[param_name]

    # Global [start, end) range of this parameter inside the flat storage.
    p_start = view._index
    p_end = view._index + view._padded_size

    # ShardingV2 splits the storage buffer evenly across ranks.
    per_rank = buffer.param_storage.shape[0] // world_size

    # Overlap of each rank's storage slice [r*per_rank, (r+1)*per_rank)
    # with the parameter's global range; clamp negative overlap to zero.
    indices = [
        max(0, min(p_end, (r + 1) * per_rank) - max(p_start, r * per_rank))
        for r in range(world_size)
    ]

    # This rank's slice starts after all lower ranks' contributions.
    # Out-of-range rank values yield offset 0, matching the loop-based form.
    my_slice_offset = sum(indices[:rank]) if 0 <= rank < world_size else 0

    return indices, my_slice_offset
126+
127+
128+
def should_use_muon(name, shape):
    """Return True if a parameter should receive Muon (orthogonal) updates.

    Muon applies only to 2-D weight matrices. Embeddings, biases, and
    LM-head weights fall back to AdamW.
    """
    # Muon's orthogonalized update is defined only for matrices.
    if len(shape) != 2:
        return False
    lowered = name.lower()
    # Name-based exclusions: these layers are known to train better on AdamW.
    return not any(tag in lowered for tag in ("embed", "bias", "lm_head"))

0 commit comments

Comments
 (0)