[training] feat: Support Megatron-FSDP MixedPrecisionPolicy dtype args (main_params_dtype, main_grads_dtype, grad_comm_dtype) #2763

@yaoyu-33

Description

Summary

Bridge does not currently expose the new Megatron-FSDP MixedPrecisionPolicy dtype customization arguments added in Megatron-LM #3067 (merged). Users (e.g., MLPerf) who need fine-grained control over gradient and parameter dtypes in Megatron-FSDP cannot configure them through Bridge.

Background

MLM PR #3067 added three new arguments to Megatron-FSDP's MixedPrecisionPolicy:

  • --megatron-fsdp-main-params-dtype: controls the dtype of the main parameter copy (generalizes the deprecated preserve_fp32_weights)
  • --megatron-fsdp-main-grads-dtype: controls the dtype of the main gradient buffer (generalizes grad_reduce_in_fp32)
  • --megatron-fsdp-grad-comm-dtype: controls the dtype used for gradient communication (all-reduce and reduce-scatter)

A follow-up MLM PR #3810 adds auto (None) support for these arguments, which falls back to per-parameter gradient dtypes (needed for --gradient-accumulation-fusion compatibility with FP8 parameters).
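For illustration, the three knobs and the auto fallback can be modeled as a small config sketch. The field names follow the CLI flags above, but the dtype strings, defaults, and the class itself are assumptions for this sketch, not MLM's actual MixedPrecisionPolicy signature:

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical sketch of the dtype knobs from MLM #3067 / #3810;
# the real MixedPrecisionPolicy fields and defaults may differ.
@dataclass
class FsdpDtypeArgs:
    main_params_dtype: Optional[str] = "fp32"  # main parameter copy
    main_grads_dtype: Optional[str] = "fp32"   # main gradient buffer
    grad_comm_dtype: Optional[str] = None      # None = auto: per-parameter grad dtypes
```

Under this sketch, the MLPerf setting described below would be `FsdpDtypeArgs(grad_comm_dtype="bf16")`, leaving the other two fields at their fp32 defaults.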

Motivation

  • MLPerf use case: MLPerf benchmarks use Bridge for DSv3 training but need --megatron-fsdp-grad-comm-dtype bf16 for correctness with FP8 + gradient accumulation fusion. Currently this must be hard-coded.
  • General gap: Bridge's DistributedDataParallelConfig inherits from MCore but the new FSDP MixedPrecisionPolicy fields are not surfaced or validated in Bridge's config system.

Current State

  • Bridge already has grad_reduce_in_fp32 in DistributedDataParallelConfig (inherited from MCore), but the new FSDP-specific dtype args from #3067 are not exposed in Bridge's config or CLI.
  • The DDP wrapper in src/megatron/bridge/models/common/unimodal.py (_ddp_wrap()) passes ddp_config to MCore's FullyShardedDataParallel, so the fields will flow through once they are set on the config — the main gap is config surface area.

Proposed Changes

  1. Verify MCore DDP config fields: Once MCore is bumped to include #3067, confirm that DistributedDataParallelConfig (or the FSDP-specific config) already contains main_params_dtype, main_grads_dtype, and grad_comm_dtype. If they live in a separate FSDP config, surface them in Bridge's config hierarchy.

  2. Add validation: In ConfigContainer.validate(), add appropriate validation for the new fields when use_megatron_fsdp=True (e.g., warn if grad_comm_dtype conflicts with gradient_accumulation_fusion + FP8).

  3. Support auto / None: Once MLM #3810 merges, ensure Bridge correctly handles the None / auto option for these fields.

  4. Update recipes: Review existing Megatron-FSDP recipes (e.g., qwen_vl, gpt_oss, nemotron_3_nano) to use the new dtype args instead of the legacy grad_reduce_in_fp32 where appropriate.
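The validation proposed in step 2 (including the auto/None handling from step 3) might look roughly like the following. This is a hedged sketch: `cfg` is a hypothetical object exposing use_megatron_fsdp, gradient_accumulation_fusion, fp8, and grad_comm_dtype attributes, and the real hook in ConfigContainer.validate() will differ.

```python
import warnings

def validate_fsdp_dtype_args(cfg):
    """Warn on a known-bad combination of the new FSDP dtype fields (sketch)."""
    if not getattr(cfg, "use_megatron_fsdp", False):
        return
    # None means "auto" (MLM #3810): fall back to per-parameter grad dtypes.
    if cfg.grad_comm_dtype is None:
        return
    if cfg.gradient_accumulation_fusion and cfg.fp8 and cfg.grad_comm_dtype != "bf16":
        warnings.warn(
            "grad_comm_dtype is typically 'bf16' when FP8 is combined with "
            "gradient accumulation fusion under Megatron-FSDP."
        )
```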

Upstream PRs

  • Megatron-LM #3067 (merged): adds main_params_dtype, main_grads_dtype, and grad_comm_dtype to Megatron-FSDP's MixedPrecisionPolicy
  • Megatron-LM #3810: adds auto (None) support for the new dtype arguments

References

  • Bridge DDP wrapper: src/megatron/bridge/models/common/unimodal.py L192-252
  • Bridge DDP config: src/megatron/bridge/training/config.py L56-78
  • Bridge MixedPrecisionConfig: src/megatron/bridge/training/mixed_precision.py L27-68
  • Bridge FSDP validation: src/megatron/bridge/training/config.py L1536-1564
