Skip to content

Commit a9a20bd

Browse files
committed
optimization for reduce-gradient by applying comm_buffer to 2d params
1 parent 1714b7d commit a9a20bd

File tree

1 file changed

+88
-14
lines changed

1 file changed

+88
-14
lines changed

python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py

Lines changed: 88 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
g_shard_bypass_dygraph_optimizer = int(
5959
os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0)
6060
)
61+
g_shard_fused_gradient = int(os.environ.get("FLAGS_shard_fused_gradient", 0))
6162

6263

6364
def _is_trainable(param):
@@ -231,6 +232,12 @@ def __init__(self, optimizer, hcg=None):
231232
for p in params:
232233
self._param2rank_2d_by_color[color_key][p.name] = rank
233234

235+
# add sort 2d params
236+
for color_key, params_2d in self._params_2d_by_color.items():
237+
params_2d.sort(
238+
key=lambda p: self._param2rank_2d_by_color[color_key][p.name]
239+
)
240+
234241
# ---- Backward compatibility: expose legacy attributes ----
235242
# These are kept for any external code that might reference them
236243
self._params_2d = self._params_2d_by_color.get(None, [])
@@ -244,6 +251,13 @@ def __init__(self, optimizer, hcg=None):
244251
'moe_expert', {}
245252
)
246253

254+
self._use_fuse_gradients = g_shard_fused_gradient
255+
# ---- Build comm buffers for 2D params (V1-style) ----
256+
if self._use_fuse_gradients:
257+
if not hasattr(self, 'comm_buffer_2d'):
258+
self.comm_buffer_2d = self._build_2d_comm_buffers()
259+
self.comm_buffer_2d.sort(key=lambda x: x._dst)
260+
247261
# ---- Step 3: Build comm buffers for 1D params (V2-style) ----
248262
self._slice_params = {}
249263
self._comm_buffer_list = []
@@ -423,6 +437,50 @@ def _partition_2d_parameters(self, params, world_size, label=""):
423437

424438
return mapping
425439

440+
def _build_2d_comm_buffers(self):
441+
"""Build communication buffers for 2D (Tensor-wise) parameters using all-reduce."""
442+
group_size = (
443+
self.comm_buffer_size_MB * 1024 * 1024
444+
if self.comm_buffer_size_MB > 0
445+
else 256 * 1024 * 1024
446+
)
447+
comm_buffers = []
448+
449+
for color_key, params_2d in self._params_2d_by_color.items():
450+
group_info = self._color_to_group_info.get(color_key, {})
451+
comm_group = group_info.get('group', None)
452+
453+
fused_parameter_group = defaultdict(list)
454+
455+
for p in params_2d:
456+
dst_rank = self._param2rank_2d_by_color[color_key][p.name]
457+
fused_parameter_group[dst_rank].append(p)
458+
459+
absolute_dst_ranks = {
460+
rank: comm_group.ranks[rank] for rank in fused_parameter_group
461+
}
462+
463+
for dst, params in fused_parameter_group.items():
464+
var_groups = assign_group_by_size(params, group_size)
465+
abs_dst = absolute_dst_ranks[dst]
466+
467+
buffer = [
468+
FusedCommBuffer(
469+
group_idx,
470+
parameters,
471+
comm_group,
472+
self.accumulate_steps,
473+
act=HOOK_ACTION.REDUCE,
474+
dst=abs_dst,
475+
release_grads=False,
476+
use_reduce_avg=True,
477+
)
478+
for group_idx, parameters in var_groups.items()
479+
]
480+
comm_buffers.extend(buffer)
481+
482+
return comm_buffers
483+
426484
# ------------------------------------------------------------------
427485
# 1D slice creation (V2-style)
428486
# ------------------------------------------------------------------
@@ -583,21 +641,26 @@ def reduce_gradients(self, parameter_list, hcg):
583641
paddle.device.synchronize()
584642

585643
with framework.no_grad():
586-
# --- Non-MoE 2D params: reduce to owner rank via sharding_group ---
587-
sharding_group = hcg.get_sharding_parallel_group()
588-
self._reduce_2d_grads(
589-
self._params_2d, self._param2rank_2d, sharding_group
590-
)
644+
# --- 2D params: reduce via comm buffers | per tensors ---
645+
if self._use_fuse_gradients:
646+
for comm_buffer in self.comm_buffer_2d:
647+
comm_buffer._comm_grads()
648+
else:
649+
# --- Non-MoE 2D params: reduce to owner rank via sharding_group ---
650+
sharding_group = hcg.get_sharding_parallel_group()
651+
self._reduce_2d_grads(
652+
self._params_2d, self._param2rank_2d, sharding_group
653+
)
591654

592-
# --- MoE expert 2D params: reduce to owner rank via moe_sharding_group ---
593-
if self._params_2d_moe and self._moe_sharding_group is not None:
594-
if self._moe_sharding_world_size > 1:
595-
self._reduce_2d_grads(
596-
self._params_2d_moe,
597-
self._param2rank_2d_moe,
598-
self._moe_sharding_group,
599-
)
600-
# When moe_sharding_degree=1, no reduce needed (single rank group)
655+
# --- MoE expert 2D params: reduce to owner rank via moe_sharding_group ---
656+
if self._params_2d_moe and self._moe_sharding_group is not None:
657+
if self._moe_sharding_world_size > 1:
658+
self._reduce_2d_grads(
659+
self._params_2d_moe,
660+
self._param2rank_2d_moe,
661+
self._moe_sharding_group,
662+
)
663+
# When moe_sharding_degree=1, no reduce needed (single rank group)
601664

602665
# --- 1D params: reduce-scatter via comm buffers ---
603666
for comm_buffer in self._comm_buffer_list:
@@ -608,6 +671,12 @@ def reduce_gradients(self, parameter_list, hcg):
608671

609672
if not self.comm_overlap:
610673
comm_buffer._comm_grads()
674+
675+
# wait for all comm_buffer tasks to finish
676+
if self._use_fuse_gradients:
677+
for comm_buffer in self.comm_buffer_2d:
678+
comm_buffer.scale_grads()
679+
for comm_buffer in self._comm_buffer_list:
611680
comm_buffer.scale_grads()
612681

613682
def filter_parameters(self, parameter_list, hcg):
@@ -722,6 +791,11 @@ def clear_grad_func(p):
722791
if comm_buffer.need_reduce_scale_sync():
723792
comm_buffer._clear_grad_storage()
724793

794+
if self._use_fuse_gradients:
795+
for comm_buffer in self.comm_buffer_2d:
796+
if comm_buffer.need_reduce_scale_sync():
797+
comm_buffer._clear_grad_storage()
798+
725799
# ------------------------------------------------------------------
726800
# Optimizer step
727801
# ------------------------------------------------------------------

0 commit comments

Comments (0)