[Draft][Main][feat] Support overlapping A2A Combine backprop with wgrad GEMM #3795
Wohox wants to merge 2 commits into NVIDIA:main
What does this PR do?
PR for dev: #3766
Problem
In MoE models, the expert weight gradient (wgrad) computation during backward is serialized on the main CUDA stream. This blocks the data gradient (dgrad) from flowing to earlier layers until the expert wgrad finishes, even though there is no data dependency between them. The result is wasted GPU cycles — earlier layers' backward pass sits idle waiting for expert wgrad to complete.
With FSDP, this is further compounded because the gradient reduce-scatter for expert parameters is also blocked on the same critical path.
Solution
This PR introduces a new flag `--delay-wgrad-compute-for-te-grouped-gemm` that separates the expert wgrad computation from the main backward stream:

- Two autograd functions are inserted into the MoE layer's forward graph:
  - `_RecordExpertDgradCompletion`: placed before the expert computation; during backward, it records a CUDA event once the expert dgrad is done.
  - `_RegisterDelayedWgradForExperts`: placed at the dispatch boundary; during backward, it waits on the dgrad event, then launches `backward_dw()` on a dedicated CUDA stream, and synchronizes back to the main stream before proceeding.
- FSDP integration: when used with MegatronFSDP, expert parameters are marked with `_fsdp_delay_grad_reduce = True` so the normal post-accumulate-grad hook skips them. A callback registered via `register_process_expert_grads_fn()` triggers the FSDP reduce-scatter for expert parameters only after the delayed wgrad computation completes.
- TE GroupedLinear is configured with `delay_wgrad_compute=True`, which tells Transformer Engine to skip wgrad during the normal autograd backward and instead wait for an explicit `backward_dw()` call.
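The following is a minimal sketch of the event/stream hand-off pattern described above, assuming standard PyTorch CUDA stream and event APIs. The class names, argument plumbing, and `backward_dw_fn` hook are illustrative placeholders, not the exact implementation in this PR.

```python
# Illustrative sketch only; names and wiring are simplified.
import torch


class _RecordDgradDone(torch.autograd.Function):
    """Identity op whose backward records an event once the expert dgrad exists."""

    @staticmethod
    def forward(ctx, x, dgrad_event):
        ctx.dgrad_event = dgrad_event
        return x

    @staticmethod
    def backward(ctx, grad_out):
        # grad_out is the expert dgrad produced on the main stream.
        ctx.dgrad_event.record(torch.cuda.current_stream())
        return grad_out, None


class _LaunchDelayedWgrad(torch.autograd.Function):
    """Identity op whose backward runs the deferred wgrad on a side stream."""

    @staticmethod
    def forward(ctx, x, dgrad_event, wgrad_stream, backward_dw_fn):
        ctx.dgrad_event = dgrad_event
        ctx.wgrad_stream = wgrad_stream
        ctx.backward_dw_fn = backward_dw_fn
        return x

    @staticmethod
    def backward(ctx, grad_out):
        # Make the side stream wait until the expert dgrad is ready, then launch
        # the delayed wgrad GEMMs there so they overlap with the A2A combine and
        # earlier layers' backward running on the main stream.
        ctx.wgrad_stream.wait_event(ctx.dgrad_event)
        with torch.cuda.stream(ctx.wgrad_stream):
            ctx.backward_dw_fn()  # e.g. TE GroupedLinear's backward_dw()
        # Re-synchronize so later main-stream work (e.g. the FSDP reduce-scatter
        # callback for expert grads) only starts after the wgrad is enqueued.
        torch.cuda.current_stream().wait_stream(ctx.wgrad_stream)
        return grad_out, None, None, None
```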
How to enable

Requirements:

- `moe_grouped_gemm` enabled (not legacy grouped gemm)
- `--delay-wgrad-compute` (the existing A2A-overlap-based delay)
- `--overlap-moe-expert-parallel-comm`

Works with both FSDP and 3-D parallelism (TP/EP/PP).
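As a hedged illustration, assuming the CLI flags map onto `TransformerConfig` fields of the same snake_case names (which may differ across Megatron-LM versions), enabling the feature programmatically might look like this:

```python
# Sketch under stated assumptions: only delay_wgrad_compute_for_te_grouped_gemm
# is confirmed by this PR; the other field names are inferred from the CLI flags.
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=2,
    hidden_size=1024,
    num_attention_heads=8,
    num_moe_experts=8,
    moe_grouped_gemm=True,                          # TE grouped GEMM path (not legacy)
    delay_wgrad_compute=True,                       # existing A2A-overlap-based delay
    overlap_moe_expert_parallel_comm=True,          # overlap EP A2A communication
    delay_wgrad_compute_for_te_grouped_gemm=True,   # new flag added by this PR
)
```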
What is achieved
The expert wgrad computation runs on a separate CUDA stream, overlapping with the EP communication within the same transformer layer. This reduces the wall-clock time of the backward pass without changing numerical results — the feature is bit-exact with the non-delayed baseline (verified by unit tests comparing per-step losses and final weights over multiple optimizer steps).
Changes
- `megatron/core/model_parallel_config.py`: adds `delay_wgrad_compute_for_te_grouped_gemm`
- `megatron/core/transformer/transformer_config.py`
- `megatron/core/transformer/moe/moe_layer.py`: `register_process_expert_grads_fn` callback
- `megatron/core/extensions/transformer_engine.py`: passes `delay_wgrad_compute=True` to TE GroupedLinear when the new flag is set
- `megatron/core/distributed/fsdp/.../megatron_fsdp.py`
- `tests/unit_tests/a2a_overlap/test_delay_wgrad_compute.py`

Test plan
- `test_delay_wgrad_compute_for_te_grouped_gemm`: full-model training loop (forward → backward → optimizer) comparing delayed vs. non-delayed across `num_layers × shared_experts × dispatcher_type × fp8_flag`
- `test_delay_wgrad_compute_for_te_grouped_gemm_with_fsdp`: same comparison with MegatronFSDP wrapping (`fully_shard_model` + `fully_shard_optimizer`), verifying the deferred reduce-scatter path
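The bit-exactness check those tests describe could be sketched roughly as follows; `build_moe_model` and `train_steps` are hypothetical placeholder helpers, not the actual utilities in `test_delay_wgrad_compute.py`.

```python
# Rough sketch of the delayed-vs-non-delayed comparison; helper names are placeholders.
import torch


def assert_bitwise_equal_training(build_moe_model, train_steps, num_steps=4):
    torch.manual_seed(123)
    ref_model = build_moe_model(delay_wgrad_compute_for_te_grouped_gemm=False)
    torch.manual_seed(123)
    test_model = build_moe_model(delay_wgrad_compute_for_te_grouped_gemm=True)

    ref_losses = train_steps(ref_model, num_steps)
    test_losses = train_steps(test_model, num_steps)

    # Per-step losses must match bit-for-bit.
    for step, (l_ref, l_test) in enumerate(zip(ref_losses, test_losses)):
        assert torch.equal(l_ref, l_test), f"loss mismatch at step {step}"

    # Final weights must also be identical after all optimizer steps.
    for (name, p_ref), (_, p_test) in zip(
        ref_model.named_parameters(), test_model.named_parameters()
    ):
        assert torch.equal(p_ref, p_test), f"weight mismatch in {name}"
```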
Contribution process

Pre-checks
Code review
Feel free to message or comment @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!
All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.
Step 1: Mark PR as "Ready for Review"
Reviewers are assigned automatically based on `.github/CODEOWNERS`. Final Review might get declined if these requirements are not fulfilled.
Step 2: Final Review
- For PRs that change `megatron/core`: once all expert reviewers have approved, the `Final Review` label is applied automatically and final reviewers are assigned.
- For PRs outside `megatron/core`, this step is skipped.

Step 3: Approved
Once all required reviewers have approved, the `Approved` label is applied automatically.

Merge
Any member of mcore-engineers will be able to merge your PR.
For MRs into `dev` branch
The proposed review process for the `dev` branch is under active discussion. MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.