Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from megatron.core.distributed.distributed_data_parallel_config import (
DistributedDataParallelConfig,
)
from megatron.core.transformer import TransformerLayer
from megatron.core.utils import is_submodule
except ImportError:
# Megatron-LM is not installed, use Megatron-FSDP as a standalone module.
Expand All @@ -73,6 +74,46 @@ class TrainingState(Enum):
IDLE = auto()


def _maybe_setup_delayed_wgrad_for_experts(module, process_post_backward_gradients_fn):
    """Wire up delayed-wgrad gradient handling for the experts of an MoE layer.

    Only acts on ``TransformerLayer`` MoE layers whose config enables
    ``delay_wgrad_compute_for_te_grouped_gemm``; every other module is left
    untouched. For a qualifying layer this:
    1. Tags each expert parameter with ``_fsdp_delay_grad_reduce`` so the
       regular post-accumulate-grad hook skips it.
    2. Installs a callback on the MoE layer that triggers FSDP's gradient
       reduce-scatter once the delayed wgrad computation has finished.

    Args:
        module: Module seen by the forward pre-hook; filtered as described above.
        process_post_backward_gradients_fn: FSDP's gradient processing function
            (``_process_post_backward_gradients``), invoked with the expert
            parameters after the delayed wgrad compute completes.
    """
    is_delayed_moe_layer = (
        isinstance(module, TransformerLayer)
        and module.is_moe_layer
        and getattr(module.config, 'delay_wgrad_compute_for_te_grouped_gemm', False)
    )
    if not is_delayed_moe_layer:
        return

    mlp = module.mlp
    # Already configured on a previous forward pass — nothing to do.
    if getattr(mlp, '_process_expert_grads_fn', None) is not None:
        return

    # Tag expert params so the normal post-accumulate-grad hook skips them.
    for param in mlp.experts.parameters():
        param._fsdp_delay_grad_reduce = True

    def _reduce_expert_grads():
        # Collect expert params at call time and hand them to FSDP's
        # gradient processing (reduce-scatter) function.
        process_post_backward_gradients_fn(list(mlp.experts.parameters()))

    mlp.register_process_expert_grads_fn(_reduce_expert_grads)


class MegatronFSDP(torch.nn.Module):
"""Fully Sharded Data Parallel training.

Expand Down Expand Up @@ -719,6 +760,9 @@ def _pre_forward_param_unshard(
prefetch=fsdp_forward_prefetch,
prefetch_order=PrefetchOrder.FORWARD_PASS_ORDER,
)

# Set post backward hook for TE grouped gemm if enabled comm overlap
_maybe_setup_delayed_wgrad_for_experts(module, _process_post_backward_gradients)
return args, kwargs

@torch.compiler.disable
Expand Down Expand Up @@ -1022,7 +1066,11 @@ def _register_pre_backward_param_unshard_hook(module):
for param in grad_acc_param_list:
self.grad_acc_hooks[f"grad_acc and reduce for {self.param_to_name[param]}"] = (
param.register_post_accumulate_grad_hook(
lambda p: _process_post_backward_gradients([p])
lambda p: (
None
if getattr(p, '_fsdp_delay_grad_reduce', False)
else _process_post_backward_gradients([p])
)
)
)

Expand Down
10 changes: 7 additions & 3 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1701,10 +1701,14 @@ def __init__(
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache

extra_kwargs = _get_extra_te_kwargs(config)
self.delay_wgrad_compute = (
self.config.delay_wgrad_compute
or self.config.delay_wgrad_compute_for_te_grouped_gemm
)

if self.config.delay_wgrad_compute:
if self.delay_wgrad_compute:
if is_te_min_version("2.3.0"):
extra_kwargs["delay_wgrad_compute"] = self.config.delay_wgrad_compute
extra_kwargs["delay_wgrad_compute"] = True
else:
raise RuntimeError(
"Only TE with version >=2.3.0 supports delay_wgrad_compute now."
Expand Down Expand Up @@ -2012,7 +2016,7 @@ def backward_dw(self):
Compute weight gradients during the backward pass
if delay_wgrad_compute is enabled.
"""
if self.config.delay_wgrad_compute:
if self.delay_wgrad_compute:
super().backward_dw()

class TEColumnParallelGroupedLinear(TEGroupedLinear):
Expand Down
9 changes: 9 additions & 0 deletions megatron/core/model_parallel_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,15 @@ class ModelParallelConfig:
delay_wgrad_compute: bool = False
"""Delay the weight gradient computation to improve batch-level communication overlapping"""

delay_wgrad_compute_for_te_grouped_gemm: bool = False
"""Delay the weight gradient computation for TE Grouped GEMM MoE experts.
When enabled with FSDP, the expert weight gradients are computed on a separate
CUDA stream after the data gradients finish, allowing overlap of wgrad compute
with EP A2A communication. The FSDP gradient reduce-scatter for
expert parameters is deferred until the delayed wgrad computation completes.
This requires transformer_engine with GroupedLinear support (TE >= 2.3.0).
"""

ep_overlap_early_attn_memory_release: bool = False
"""Enable early memory release of attention activations during EP overlap.
EP overlap can increase peak memory usage when the overlapped forward module allocates
Expand Down
88 changes: 88 additions & 0 deletions megatron/core/transformer/moe/moe_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,14 @@ def __init__(
self.cudagraph_tensor_store = MoECudaGraphTensorStore()
self.fwd_execution_map = ["route", "expert_compute", "postprocess"]

# Delay wgrad computation for TE grouped GEMM
self._delayed_wgrad_event: Optional[torch.cuda.Event] = None
self._delayed_wgrad_stream: Optional[torch.cuda.Stream] = None
self._process_expert_grads_fn = None
if self.config.delay_wgrad_compute_for_te_grouped_gemm:
self._delayed_wgrad_event = torch.cuda.Event()
self._delayed_wgrad_stream = torch.cuda.Stream(device="cuda")

def _setup_inference_mode(self, pg_collection):
"""Set up inference-optimized token dispatcher and state.

Expand Down Expand Up @@ -373,6 +381,8 @@ def dispatch(self, hidden_states: torch.Tensor, probs: torch.Tensor):
tokens and their associated probabilities to the devices hosting their assigned
experts.
"""
if self.config.delay_wgrad_compute_for_te_grouped_gemm:
hidden_states = _RegisterDelayedWgradForExperts.apply(self, hidden_states)
return self.token_dispatcher.token_dispatch(hidden_states, probs)

@maybe_skip_or_early_return_by_cudagraph("shared_experts_compute")
Expand Down Expand Up @@ -411,6 +421,10 @@ def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tenso
for each expert. It then passes the tokens through the local experts.
The output from the experts is preprocessed for the combine step.
"""
if self.config.delay_wgrad_compute_for_te_grouped_gemm:
hidden_states = _RecordExpertDgradCompletion.apply(
self._delayed_wgrad_event, hidden_states
)
dispatched_input, tokens_per_expert, permuted_probs = (
self.token_dispatcher.dispatch_postprocess(hidden_states, probs)
)
Expand Down Expand Up @@ -584,3 +598,77 @@ def set_for_recompute_pre_mlp_layernorm(self):
from megatron.core.extensions.transformer_engine import set_save_original_input

set_save_original_input(self.shared_experts.linear_fc1)

def register_process_expert_grads_fn(self, fn):
    """Store *fn* as the post-delayed-wgrad gradient processing callback.

    FSDP uses this hook to postpone the reduce-scatter of expert parameter
    gradients until the delayed wgrad computation has finished.

    Args:
        fn: Callable invoked once the delayed expert wgrad compute is done
            (e.g. one that launches the FSDP reduce-scatter for expert
            parameters).
    """
    self._process_expert_grads_fn = fn


class _RecordExpertDgradCompletion(torch.autograd.Function):
"""Autograd function that records a CUDA event when expert data gradients finish.

Placed in the forward graph just before the expert computation so that during
the backward pass, when the expert dgrad completes, we record an event. The
subsequent ``_RegisterDelayedWgradForExperts`` waits on this event before
launching the delayed wgrad computation on a separate CUDA stream.
"""

@staticmethod
def forward(ctx, event: torch.cuda.Event, *inputs):
"""Forward pass that stores the event and passes through inputs unchanged."""
ctx.event = event
return inputs[0] if len(inputs) == 1 else inputs

@staticmethod
def backward(ctx, *grad_outputs):
"""Backward pass that records the event when expert dgrad completes."""
ctx.event.record(torch.cuda.current_stream())
ctx.event = None
return (None,) + grad_outputs


class _RegisterDelayedWgradForExperts(torch.autograd.Function):
"""Autograd function that orchestrates delayed wgrad computation for MoE experts.

Placed in the forward graph at the dispatch boundary. During the backward pass,
this function:
1. Records an event on the current (backward) stream to signal the dgrad is done.
2. Executes the delayed wgrad computation on a dedicated CUDA stream.
3. Waits for the wgrad computation to complete.
4. Invokes the registered gradient processing callback (e.g., FSDP reduce-scatter).
"""

@staticmethod
def forward(ctx, module: MoELayer, *inputs):
"""Forward pass that stores the MoE module and passes through inputs unchanged."""
ctx.module = module
return inputs[0] if len(inputs) == 1 else inputs

@staticmethod
def backward(ctx, *grad_outputs):
"""Backward pass that executes delayed wgrad computation on a separate stream."""
module = ctx.module
event = module._delayed_wgrad_event
wgrad_stream = module._delayed_wgrad_stream

wgrad_stream.wait_event(event)
with torch.cuda.stream(wgrad_stream):
with torch.cuda.nvtx.range("delayed_expert_wgrad"):
module.backward_dw(routed_experts=True, shared_experts=False)
event.record(wgrad_stream)

torch.cuda.current_stream().wait_event(event)
Comment on lines +662 to +668
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: When moe_latent_size is configured, backward_dw() internally launches fc2_latent_proj.backward_dw() on comm_stream (line 585-586). However, the event here is only recorded on wgrad_stream, so torch.cuda.current_stream().wait_event(event) on line 668 does not wait for the comm_stream work to finish. This could lead to a data race where the main stream proceeds (e.g., starts the next iteration) before the latent projection weight gradients are fully computed.

You'd need to either:

  1. Also synchronize comm_stream back to wgrad_stream (or main stream) before recording the event, or
  2. Add a validation assertion that delay_wgrad_compute_for_te_grouped_gemm and moe_latent_size are mutually exclusive (if that combination isn't intended to be supported yet).


if module._process_expert_grads_fn is not None:
module._process_expert_grads_fn()

ctx.module = None
return (None,) + grad_outputs
13 changes: 13 additions & 0 deletions megatron/core/transformer/transformer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2101,6 +2101,19 @@ def __post_init__(self):
'partial cuda graph'
)

if self.delay_wgrad_compute_for_te_grouped_gemm:
assert not self.overlap_moe_expert_parallel_comm, (
'overlap_moe_expert_parallel_comm must be disabled when enabling '
'delay_wgrad_compute_for_te_grouped_gemm.'
)
assert is_te_min_version(
"2.3.0"
), 'TE version >= 2.3.0 is required for delay_wgrad_compute_for_te_grouped_gemm'
assert not self.delay_wgrad_compute, (
'delay_wgrad_compute and delay_wgrad_compute_for_te_grouped_gemm '
'are mutually exclusive; use only one'
)

if self.ep_overlap_early_attn_memory_release:
assert self.overlap_moe_expert_parallel_comm, (
'overlap_moe_expert_parallel_comm must be enabled when enabling '
Expand Down
Loading