## Summary

Bridge does not currently expose the new Megatron-FSDP `MixedPrecisionPolicy` dtype customization arguments added in Megatron-LM #3067 (merged). Users (e.g., MLPerf) who need fine-grained control over gradient and parameter dtypes in Megatron-FSDP cannot configure them through Bridge.
## Background

MLM PR #3067 added three new arguments to Megatron-FSDP's `MixedPrecisionPolicy`:

| MLM Argument | Purpose |
|---|---|
| `--megatron-fsdp-main-params-dtype` | Controls the dtype of the main parameter copy (generalizes the deprecated `preserve_fp32_weights`) |
| `--megatron-fsdp-main-grads-dtype` | Controls the dtype of the main gradient buffer (generalizes `grad_reduce_in_fp32`) |
| `--megatron-fsdp-grad-comm-dtype` | Controls the dtype used for gradient communication (all-reduce & reduce-scatter) |
A follow-up MLM PR #3810 adds `auto` (`None`) support for these arguments, which falls back to per-parameter gradient dtypes (needed for `--gradient-accumulation-fusion` compatibility with FP8 parameters).
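For intuition, the three knobs can be pictured as fields on a policy object. This is a hedged sketch using string dtype names for readability; the class name, field names, and defaults are illustrative and are not the actual Megatron-FSDP `MixedPrecisionPolicy` definition:

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical sketch of the three dtype knobs from MLM #3067; the field
# names mirror the CLI flags, but this is NOT the real class definition.
@dataclass
class MixedPrecisionPolicySketch:
    main_params_dtype: Optional[str] = "fp32"  # dtype of the main parameter copy
    main_grads_dtype: Optional[str] = "fp32"   # dtype of the main gradient buffer
    grad_comm_dtype: Optional[str] = None      # dtype for grad all-reduce/reduce-scatter;
                                               # None = "auto" per-parameter dtypes (MLM #3810)

# Example: requesting bf16 gradient communication, as in the MLPerf case below.
policy = MixedPrecisionPolicySketch(grad_comm_dtype="bf16")
print(policy.grad_comm_dtype)
```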
## Motivation

- **MLPerf use case:** MLPerf benchmarks use Bridge for DSv3 training but need `--megatron-fsdp-grad-comm-dtype bf16` for correctness with FP8 + gradient accumulation fusion. Currently this must be hard-coded.
- **General gap:** Bridge's `DistributedDataParallelConfig` inherits from MCore, but the new FSDP `MixedPrecisionPolicy` fields are not surfaced or validated in Bridge's config system.
## Current State

- Bridge already has `grad_reduce_in_fp32` in `DistributedDataParallelConfig` (inherited from MCore), but the new FSDP-specific dtype args from #3067 are not exposed in Bridge's config or CLI.
- The DDP wrapper in `src/megatron/bridge/models/common/unimodal.py` (`_ddp_wrap()`) passes `ddp_config` to MCore's `FullyShardedDataParallel`, so the fields will flow through once they are set on the config; the main gap is config surface area.
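A minimal sketch of that pass-through, using toy stand-ins for Bridge's `_ddp_wrap()` and MCore's `FullyShardedDataParallel` (the real code lives in `unimodal.py`; these names and the dict-based config are assumptions for illustration only):

```python
# Toy stand-in for MCore's FullyShardedDataParallel: it simply receives
# whatever ddp_config Bridge hands it.
class FakeFullyShardedDataParallel:
    def __init__(self, module, ddp_config):
        self.module = module
        self.ddp_config = ddp_config

def ddp_wrap_sketch(module, ddp_config):
    # Mirrors the behavior described above: Bridge forwards ddp_config
    # verbatim, so any field set on the config reaches the FSDP wrapper.
    return FakeFullyShardedDataParallel(module, ddp_config)

wrapped = ddp_wrap_sketch(module=object(), ddp_config={"grad_comm_dtype": "bf16"})
print(wrapped.ddp_config["grad_comm_dtype"])
```

The point of the sketch: since the wrapper forwards the config unchanged, exposing the new fields on Bridge's config is the only missing piece.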
## Proposed Changes

1. **Verify MCore DDP config fields:** Once MCore is bumped to include #3067, confirm that `DistributedDataParallelConfig` (or the FSDP-specific config) already contains `main_params_dtype`, `main_grads_dtype`, and `grad_comm_dtype`. If they live in a separate FSDP config, surface them in Bridge's config hierarchy.
2. **Add validation:** In `ConfigContainer.validate()`, add appropriate validation for the new fields when `use_megatron_fsdp=True` (e.g., warn if `grad_comm_dtype` conflicts with `gradient_accumulation_fusion` + FP8).
3. **Support `auto`/`None`:** Once MLM #3810 merges, ensure Bridge correctly handles the `None`/`auto` option for these fields.
4. **Update recipes:** Review existing Megatron-FSDP recipes (e.g., `qwen_vl`, `gpt_oss`, `nemotron_3_nano`) to use the new dtype args instead of the legacy `grad_reduce_in_fp32` where appropriate.
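The validation item could look roughly like the following. The function name, signature, and the exact conflict condition are hypothetical sketches, not Bridge's actual `ConfigContainer.validate()` API:

```python
import warnings

def validate_fsdp_dtypes(use_megatron_fsdp: bool,
                         grad_comm_dtype,
                         gradient_accumulation_fusion: bool,
                         fp8: bool) -> None:
    # Hypothetical check: with FP8 params + gradient accumulation fusion,
    # fp32 gradient communication is the problematic combination, so steer
    # users toward bf16 or None ("auto", once MLM #3810 lands).
    if not use_megatron_fsdp:
        return
    if fp8 and gradient_accumulation_fusion and grad_comm_dtype == "fp32":
        warnings.warn(
            "grad_comm_dtype='fp32' may conflict with "
            "gradient_accumulation_fusion under FP8; "
            "consider 'bf16' or None (auto)."
        )

# Exercise the warning path.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    validate_fsdp_dtypes(True, "fp32", True, True)
print(len(caught))
```

A warning (rather than a hard error) keeps existing recipes running while flagging the known-bad combination.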
## Upstream PRs

- Megatron-LM #3067: Add dtype customization to Megatron-FSDP (merged)
- Megatron-LM #3810: Support `auto` argument for per-parameter grad dtypes (open)
## References

- Bridge DDP wrapper: `src/megatron/bridge/models/common/unimodal.py` L192-252
- Bridge DDP config: `src/megatron/bridge/training/config.py` L56-78
- Bridge MixedPrecisionConfig: `src/megatron/bridge/training/mixed_precision.py` L27-68
- Bridge FSDP validation: `src/megatron/bridge/training/config.py` L1536-1564