-
Notifications
You must be signed in to change notification settings - Fork 3.8k
(Draft)[Main][feat] Support overlapping A2A Combine backprop with wgrad GEMM #3792
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
8b36432
71e07e7
f70e899
b6e1b5d
95c62e0
62310e1
a886174
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -283,6 +283,14 @@ def __init__( | |
| self.cudagraph_tensor_store = MoECudaGraphTensorStore() | ||
| self.fwd_execution_map = ["route", "expert_compute", "postprocess"] | ||
|
|
||
| # Delay wgrad computation for TE grouped GEMM | ||
| self._delayed_wgrad_event: Optional[torch.cuda.Event] = None | ||
| self._delayed_wgrad_stream: Optional[torch.cuda.Stream] = None | ||
| self._process_expert_grads_fn = None | ||
| if self.config.delay_wgrad_compute_for_te_grouped_gemm: | ||
| self._delayed_wgrad_event = torch.cuda.Event() | ||
| self._delayed_wgrad_stream = torch.cuda.Stream(device="cuda") | ||
|
|
||
| def _setup_inference_mode(self, pg_collection): | ||
| """Set up inference-optimized token dispatcher and state. | ||
|
|
||
|
|
@@ -373,6 +381,8 @@ def dispatch(self, hidden_states: torch.Tensor, probs: torch.Tensor): | |
| tokens and their associated probabilities to the devices hosting their assigned | ||
| experts. | ||
| """ | ||
| if self.config.delay_wgrad_compute_for_te_grouped_gemm: | ||
| hidden_states = _RegisterDelayedWgradForExperts.apply(self, hidden_states) | ||
| return self.token_dispatcher.token_dispatch(hidden_states, probs) | ||
|
|
||
| @maybe_skip_or_early_return_by_cudagraph("shared_experts_compute") | ||
|
|
@@ -411,6 +421,10 @@ def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tenso | |
| for each expert. It then passes the tokens through the local experts. | ||
| The output from the experts is preprocessed for the combine step. | ||
| """ | ||
| if self.config.delay_wgrad_compute_for_te_grouped_gemm: | ||
| hidden_states = _RecordExpertDgradCompletion.apply( | ||
| self._delayed_wgrad_event, hidden_states | ||
| ) | ||
| dispatched_input, tokens_per_expert, permuted_probs = ( | ||
| self.token_dispatcher.dispatch_postprocess(hidden_states, probs) | ||
| ) | ||
|
|
@@ -584,3 +598,77 @@ def set_for_recompute_pre_mlp_layernorm(self): | |
| from megatron.core.extensions.transformer_engine import set_save_original_input | ||
|
|
||
| set_save_original_input(self.shared_experts.linear_fc1) | ||
|
|
||
| def register_process_expert_grads_fn(self, fn): | ||
| """Register a callback to process expert gradients after delayed wgrad computation. | ||
|
|
||
| This is used by FSDP to defer the reduce-scatter of expert parameter | ||
| gradients until the delayed wgrad computation has completed. | ||
|
|
||
| Args: | ||
| fn: A callable that processes expert gradients (e.g., triggers | ||
| FSDP reduce-scatter for expert parameters). | ||
| """ | ||
| self._process_expert_grads_fn = fn | ||
|
|
||
|
|
||
| class _RecordExpertDgradCompletion(torch.autograd.Function): | ||
| """Autograd function that records a CUDA event when expert data gradients finish. | ||
|
|
||
| Placed in the forward graph just before the expert computation so that during | ||
| the backward pass, when the expert dgrad completes, we record an event. The | ||
| subsequent ``_RegisterDelayedWgradForExperts`` waits on this event before | ||
| launching the delayed wgrad computation on a separate CUDA stream. | ||
| """ | ||
|
|
||
| @staticmethod | ||
| def forward(ctx, event: torch.cuda.Event, *inputs): | ||
| """Forward pass that stores the event and passes through inputs unchanged.""" | ||
| ctx.event = event | ||
| return inputs[0] if len(inputs) == 1 else inputs | ||
|
|
||
| @staticmethod | ||
| def backward(ctx, *grad_outputs): | ||
| """Backward pass that records the event when expert dgrad completes.""" | ||
| ctx.event.record(torch.cuda.current_stream()) | ||
| ctx.event = None | ||
| return (None,) + grad_outputs | ||
|
|
||
|
|
||
| class _RegisterDelayedWgradForExperts(torch.autograd.Function): | ||
| """Autograd function that orchestrates delayed wgrad computation for MoE experts. | ||
|
|
||
| Placed in the forward graph at the dispatch boundary. During the backward pass, | ||
| this function: | ||
| 1. Records an event on the current (backward) stream to signal the dgrad is done. | ||
| 2. Executes the delayed wgrad computation on a dedicated CUDA stream. | ||
| 3. Waits for the wgrad computation to complete. | ||
| 4. Invokes the registered gradient processing callback (e.g., FSDP reduce-scatter). | ||
| """ | ||
|
|
||
| @staticmethod | ||
| def forward(ctx, module: MoELayer, *inputs): | ||
| """Forward pass that stores the MoE module and passes through inputs unchanged.""" | ||
| ctx.module = module | ||
| return inputs[0] if len(inputs) == 1 else inputs | ||
|
|
||
| @staticmethod | ||
| def backward(ctx, *grad_outputs): | ||
| """Backward pass that executes delayed wgrad computation on a separate stream.""" | ||
| module = ctx.module | ||
| event = module._delayed_wgrad_event | ||
| wgrad_stream = module._delayed_wgrad_stream | ||
|
|
||
| wgrad_stream.wait_event(event) | ||
| with torch.cuda.stream(wgrad_stream): | ||
| with torch.cuda.nvtx.range("delayed_expert_wgrad"): | ||
| module.backward_dw(routed_experts=True, shared_experts=False) | ||
| event.record(wgrad_stream) | ||
|
|
||
| torch.cuda.current_stream().wait_event(event) | ||
|
Comment on lines
+662
to
+668
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Bug: When […comment text truncated in extraction…] You'd need to either:
|
||
|
|
||
| if module._process_expert_grads_fn is not None: | ||
| module._process_expert_grads_fn() | ||
|
|
||
| ctx.module = None | ||
| return (None,) + grad_outputs | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docstring says the wgrad overlaps "with the backward pass of earlier layers," but looking at the implementation in
_RegisterDelayedWgradForExperts.backward, the main backward stream synchronizes (current_stream().wait_event(event)) before returning — so earlier layers' backward cannot start until wgrad finishes. The actual overlap is between the wgrad computation and the A2A combine backward (dispatch backward) within the same layer.