 g_shard_bypass_dygraph_optimizer = int(
     os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0)
 )
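+# When enabled, gradients of 2D params are fused into comm buffers
+# (see _build_2d_comm_buffers below) instead of being reduced per tensor.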
+g_shard_fused_gradient = int(os.environ.get("FLAGS_shard_fused_gradient", 0))
 
 
 def _is_trainable(param):
@@ -231,6 +232,12 @@ def __init__(self, optimizer, hcg=None):
             for p in params:
                 self._param2rank_2d_by_color[color_key][p.name] = rank
 
+        # Sort 2D params within each color group by their owner rank.
+        for color_key, params_2d in self._params_2d_by_color.items():
+            params_2d.sort(
+                key=lambda p: self._param2rank_2d_by_color[color_key][p.name]
+            )
+
         # ---- Backward compatibility: expose legacy attributes ----
         # These are kept for any external code that might reference them
         self._params_2d = self._params_2d_by_color.get(None, [])
@@ -244,6 +251,13 @@ def __init__(self, optimizer, hcg=None):
             'moe_expert', {}
         )
 
+        self._use_fuse_gradients = g_shard_fused_gradient
+        # ---- Build comm buffers for 2D params (V1-style) ----
+        if self._use_fuse_gradients:
+            if not hasattr(self, 'comm_buffer_2d'):
+                self.comm_buffer_2d = self._build_2d_comm_buffers()
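+                # Sorted by destination rank, presumably to keep the
+                # collective launch order consistent across ranks.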
+                self.comm_buffer_2d.sort(key=lambda x: x._dst)
+
         # ---- Step 3: Build comm buffers for 1D params (V2-style) ----
         self._slice_params = {}
         self._comm_buffer_list = []
@@ -423,6 +437,50 @@ def _partition_2d_parameters(self, params, world_size, label=""):
 
         return mapping
 
+    def _build_2d_comm_buffers(self):
+        """Build fused comm buffers that reduce 2D (tensor-wise) param grads to their owner ranks."""
+        group_size = (
+            self.comm_buffer_size_MB * 1024 * 1024
+            if self.comm_buffer_size_MB > 0
+            else 256 * 1024 * 1024
+        )
+        comm_buffers = []
+
+        for color_key, params_2d in self._params_2d_by_color.items():
+            group_info = self._color_to_group_info.get(color_key, {})
+            comm_group = group_info.get('group', None)
+
+            fused_parameter_group = defaultdict(list)
+
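+            # Group params by the (group-local) rank that owns their optimizer state.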
+            for p in params_2d:
+                dst_rank = self._param2rank_2d_by_color[color_key][p.name]
+                fused_parameter_group[dst_rank].append(p)
+
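+            # Translate group-local owner indices to global ranks for the dst argument.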
+            absolute_dst_ranks = {
+                rank: comm_group.ranks[rank] for rank in fused_parameter_group
+            }
+
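+            # Bucket each owner's params so no fused buffer exceeds group_size bytes.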
+            for dst, params in fused_parameter_group.items():
+                var_groups = assign_group_by_size(params, group_size)
+                abs_dst = absolute_dst_ranks[dst]
+
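+                # One FusedCommBuffer per size bucket; use_reduce_avg folds the
+                # gradient averaging into the reduce op itself.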
+                buffer = [
+                    FusedCommBuffer(
+                        group_idx,
+                        parameters,
+                        comm_group,
+                        self.accumulate_steps,
+                        act=HOOK_ACTION.REDUCE,
+                        dst=abs_dst,
+                        release_grads=False,
+                        use_reduce_avg=True,
+                    )
+                    for group_idx, parameters in var_groups.items()
+                ]
+                comm_buffers.extend(buffer)
+
+        return comm_buffers
+
     # ------------------------------------------------------------------
     # 1D slice creation (V2-style)
     # ------------------------------------------------------------------
@@ -583,21 +641,26 @@ def reduce_gradients(self, parameter_list, hcg):
             paddle.device.synchronize()
 
         with framework.no_grad():
-            # --- Non-MoE 2D params: reduce to owner rank via sharding_group ---
-            sharding_group = hcg.get_sharding_parallel_group()
-            self._reduce_2d_grads(
-                self._params_2d, self._param2rank_2d, sharding_group
-            )
+            # --- 2D params: fused comm buffers, or per-tensor reduce ---
+            if self._use_fuse_gradients:
+                for comm_buffer in self.comm_buffer_2d:
+                    comm_buffer._comm_grads()
+            else:
+                # --- Non-MoE 2D params: reduce to owner rank via sharding_group ---
+                sharding_group = hcg.get_sharding_parallel_group()
+                self._reduce_2d_grads(
+                    self._params_2d, self._param2rank_2d, sharding_group
+                )
 
-            # --- MoE expert 2D params: reduce to owner rank via moe_sharding_group ---
-            if self._params_2d_moe and self._moe_sharding_group is not None:
-                if self._moe_sharding_world_size > 1:
-                    self._reduce_2d_grads(
-                        self._params_2d_moe,
-                        self._param2rank_2d_moe,
-                        self._moe_sharding_group,
-                    )
-            # When moe_sharding_degree=1, no reduce needed (single rank group)
+                # --- MoE expert 2D params: reduce to owner rank via moe_sharding_group ---
+                if self._params_2d_moe and self._moe_sharding_group is not None:
+                    if self._moe_sharding_world_size > 1:
+                        self._reduce_2d_grads(
+                            self._params_2d_moe,
+                            self._param2rank_2d_moe,
+                            self._moe_sharding_group,
+                        )
+                # When moe_sharding_degree=1, no reduce needed (single rank group)
 
             # --- 1D params: reduce-scatter via comm buffers ---
             for comm_buffer in self._comm_buffer_list:
@@ -608,6 +671,12 @@ def reduce_gradients(self, parameter_list, hcg):
 
                 if not self.comm_overlap:
                     comm_buffer._comm_grads()
+
+            # Wait for all comm-buffer tasks to finish, then scale the gradients
+            if self._use_fuse_gradients:
+                for comm_buffer in self.comm_buffer_2d:
+                    comm_buffer.scale_grads()
+            for comm_buffer in self._comm_buffer_list:
                 comm_buffer.scale_grads()
 
     def filter_parameters(self, parameter_list, hcg):
@@ -722,6 +791,11 @@ def clear_grad_func(p):
             if comm_buffer.need_reduce_scale_sync():
                 comm_buffer._clear_grad_storage()
 
+        if self._use_fuse_gradients:
+            for comm_buffer in self.comm_buffer_2d:
+                if comm_buffer.need_reduce_scale_sync():
+                    comm_buffer._clear_grad_storage()
+
     # ------------------------------------------------------------------
     # Optimizer step
     # ------------------------------------------------------------------