Skip to content

Support 'auto' argument which defaults to pre-MixedPrecisionPolicy be…#3810

Open
cspades wants to merge 1 commit intoNVIDIA:mainfrom
cspades:cye/mfsdp-dtype-gradaccumfuse-bugfix
Open

Support 'auto' argument which defaults to pre-MixedPrecisionPolicy be…#3810
cspades wants to merge 1 commit intoNVIDIA:mainfrom
cspades:cye/mfsdp-dtype-gradaccumfuse-bugfix

Conversation

@cspades
Copy link
Member

@cspades cspades commented Mar 11, 2026

…havior for supporting per-parameter grad dtypes.

What does this PR do ?

  • Support the auto (i.e. None) argument for Megatron-FSDP MixedPrecisionPolicy. Addresses the possible case where model gradients are mixed precision and static gradient buffers are used (as in the case of --gradient-accumulation-fusion), in which case we fall-back to the original logic of using the parameter data-type for the gradient buffer data-type, and always use BF16 for quantized parameters.
    • Mixed-precision gradients per-parameter or per-module isn't supported yet, if we hypothetically wanted to pre-allocate gradient buffers for fused gradient accumulation.

Details / Backstory

  • --gradient-accumulation-fusion is not compatible with the default torch.float32 value for gradient communication dtype when using FP8 parameters, because get_main_grad() is called during Megatron-LM's backward implementation and produces an FP32 gradient buffer for BF16 gradients if --megatron-fsdp-grad-comm-dtype bf16 is not set. Megatron-Bridge currently does not support setting this argument, so we have to hard-code it to unblock MLPerf. This PR generalizes this compatibility in case models actually have mixed precision gradients, we need the None option in Megatron-LM.

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact the @mcore-oncall.

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into `dev` branch The proposed review process for `dev` branch is under active discussion.

MRs are mergable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

@cspades cspades self-assigned this Mar 11, 2026
@copy-pr-bot
Copy link

copy-pr-bot bot commented Mar 11, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

Comment on lines -861 to 870
'fp32': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16, 'fp8': torch.uint8,
'fp32': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16, 'fp8': torch.uint8, 'auto': None,
}
map_dtype = lambda d: d if isinstance(d, torch.dtype) else dtype_map[d]

args.main_grads_dtype = map_dtype(args.main_grads_dtype)
args.main_params_dtype = map_dtype(args.main_params_dtype)
args.exp_avg_dtype = map_dtype(args.exp_avg_dtype)
args.exp_avg_sq_dtype = map_dtype(args.exp_avg_sq_dtype)
args.mamba_inference_conv_states_dtype = map_dtype(args.mamba_inference_conv_states_dtype)
args.mamba_inference_ssm_states_dtype = map_dtype(args.mamba_inference_ssm_states_dtype)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I guard the args that don't support None? The argparser should already take care of it...

…havior for supporting per-parameter grad dtypes.

Signed-off-by: Cory Ye <cye@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants