Skip to content

Commit 52849c2

Browse files
committed
Revert "[Megatron-FSDP] Add dtype customization to Megatron-FSDP. (#3067)"
This reverts commit b969f76. Signed-off-by: oliver könig <okoenig@nvidia.com>
1 parent 905c0e3 commit 52849c2

File tree

13 files changed

+306
-870
lines changed

13 files changed

+306
-870
lines changed

megatron/core/distributed/distributed_data_parallel_config.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,7 @@ class DistributedDataParallelConfig:
3333
"""
3434

3535
check_for_nan_in_grad: bool = False
36-
"""
37-
If true, check for NaNs and Infs in gradients _before_ communication collective.
38-
Invoked by `start_grad_sync` such as in the Megatron-LM DDP training API.
39-
"""
36+
"""If true, check for NaNs and Infs in gradients _before_ communication collective."""
4037

4138
check_for_large_grads: bool = False
4239
"""If true, check for unexpectedly large gradients _before_ communication collective."""
@@ -81,7 +78,7 @@ class DistributedDataParallelConfig:
8178

8279
data_parallel_sharding_strategy: str = 'no_shard'
8380
"""Sharding strategy for FSDP. Valid values are 'no_shard', 'optim',
84-
'optim_grads', 'optim_grads_params'."""
81+
'optim_grads', 'optim_grads_params'."""
8582

8683
gradient_reduce_div_fusion: bool = True
8784
"""If true, perform gradient reduce and division fusion."""
@@ -94,6 +91,9 @@ class DistributedDataParallelConfig:
9491
disables prefetching and may degrade performance. Adjust this value
9592
based on your system's memory and performance requirements."""
9693

94+
preserve_fp32_weights: bool = True
95+
"""If true, preserve fp32 weights in the Megatron FSDP ParamAndGradBuffer."""
96+
9797
keep_fp8_transpose_cache: bool = False
9898
"""If true, keep the fp8 transpose cache when using Megatron FSDP."""
9999

@@ -128,7 +128,7 @@ class DistributedDataParallelConfig:
128128
allocated buffer for the bucket that does not fit, it will enable NCCL
129129
user buffer with the cost of more memory usage. If false, FSDP will use
130130
Dynamic memory allocator, NCCL user buffer won't be enabled, which
131-
usually leads to low performance.
131+
usually leads to low performance.
132132
"""
133133

134134
fsdp_all_gather_in_start_param_sync: bool = True
@@ -142,7 +142,8 @@ class DistributedDataParallelConfig:
142142
outer_dp_sharding_strategy: str = 'no_shard'
143143
"""
144144
Sharding strategy for outer data parallel group in Hybrid Sharded Data Parallel (HSDP) mode.
145-
Valid values are 'no_shard', 'optim'. This option is only effective when Hybrid FSDP is enabled.
145+
Valid values are 'no_shard', 'optim', 'optim_grads', 'optim_grads_params'.
146+
This option is only effective when Hybrid FSDP is enabled.
146147
"""
147148

148149
disable_symmetric_registration: bool = False

megatron/core/distributed/fsdp/mcore_fsdp_adapter.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,7 @@
4545
from megatron.core.utils import is_te_min_version, log_single_rank
4646

4747
try:
48-
from megatron.core.distributed.fsdp.src.megatron_fsdp import (
49-
FSDPDistributedIndex,
50-
MegatronFSDP,
51-
MixedPrecisionPolicy,
52-
)
48+
from megatron.core.distributed.fsdp.src.megatron_fsdp import FSDPDistributedIndex, MegatronFSDP
5349

5450
HAVE_MEGATRON_FSDP = True
5551
except ImportError as import_megatron_fsdp_error:
@@ -70,9 +66,6 @@ def __init__(
7066
ddp_config: DistributedDataParallelConfig,
7167
module: torch.nn.Module,
7268
fsdp_unit_modules: Optional[List[torch.nn.Module]] = None,
73-
main_params_dtype: Optional[torch.dtype] = torch.float32,
74-
main_grads_dtype: Optional[torch.dtype] = torch.float32,
75-
grad_comm_dtype: Optional[torch.dtype] = torch.float32,
7669
disable_bucketing: bool = False,
7770
device: Optional[torch.device] = None,
7871
pg_collection: Optional[ProcessGroupCollection] = None,
@@ -89,17 +82,6 @@ def __init__(
8982
logging.INFO,
9083
f'Setting up DistributedDataParallel with config {self.ddp_config}',
9184
)
92-
self.mp_policy = MixedPrecisionPolicy(
93-
main_params_dtype=main_params_dtype,
94-
# Grandfathered Argument: grad_reduce_in_fp32
95-
main_grads_dtype=torch.float32 if ddp_config.grad_reduce_in_fp32 else main_grads_dtype,
96-
grad_comm_dtype=grad_comm_dtype,
97-
)
98-
log_single_rank(
99-
logger,
100-
logging.INFO,
101-
f'Setting up Megatron-FSDP MixedPrecisionPolicy with config {self.mp_policy}',
102-
)
10385

10486
self.megatron_fsdp_dist_index = self._init_dist_index(pg_collection)
10587

@@ -128,7 +110,6 @@ def __init__(
128110
config=config,
129111
module=MegatronFSDP(
130112
ddp_config=ddp_config,
131-
mixed_precision_policy=self.mp_policy,
132113
module=module,
133114
fsdp_unit_modules=self.fsdp_unit_modules,
134115
disable_bucketing=disable_bucketing,

megatron/core/distributed/fsdp/src/README.md

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ import torch
129129
from megatron_fsdp import (
130130
fully_shard_model,
131131
fully_shard_optimizer,
132-
MixedPrecisionPolicy,
133132
)
134133
```
135134

@@ -197,8 +196,10 @@ model = fully_shard_model(
197196
outer_dp_sharding_strategy=1,
198197
# Initialize the model on devices in shards to avoid OOM. Requires device("meta")-init for model.
199198
init_model_with_meta_device=True,
200-
# Mixed-Precision Policy for controlling compute and communication precision in Megatron-FSDP.
201-
mixed_precision_policy=MixedPrecisionPolicy(),
199+
# Reduce gradients in FP32.
200+
grad_reduce_in_fp32=False,
201+
# Store distributed optimization state in FP32.
202+
preserve_fp32_weights=True,
202203
# Sync parameters and gradients each step. Allows for gradient transformations after backward pass,
203204
# and synchronizes parameters and gradients across HSDP groups, but deactivates compute-communication
204205
# overlap going into the subsequent training step.
@@ -284,7 +285,7 @@ Megatron-FSDP's `fully_shard_*` API has a comprehensive set of arguments for fin
284285

285286
- `fsdp_unit_modules` is a list of sub-module classes or `str` import-paths associated with modules that you want `MegatronFSDP` to fully-shard.
286287
- Required if `1`, `2`, or `3` are specified as the sharding strategy. Defaults to `None`, in which case Megatron-FSDP will replicate the parameters similar to DDP.
287-
- `zero_dp_strategy` (and `outer_dp_sharding_strategy`) configure different degrees of zero-redundancy data parallelism as described in [ZeRO (Zero Redundancy Optimizer)](https://arxiv.org/abs/1910.02054). It reduces CUDA memory utilization during model training by distributing model parameters, gradients, and optimizer states across multiple devices in the DP `ProcessGroup`, and collectively communicating subsets of parameters and gradients to specific devices when needed for computation or differentiation. More aggressive sharding strategies will entail more communication overhead, with `no_shard` being the least memory efficient but most communication efficient, and `optim_grads_params` being the most memory efficient but least communication efficient. Additionally, `outer_dp_sharding_strategy` supports `no_shard` ([Hybrid-Sharded Data Parallelism (HSDP)](https://arxiv.org/pdf/2304.11277)) and `optim` (`HFSDP` = Fully-Sharded Optimizer State + _HSDP_, requires `zero_dp_strategy='optim_grads_params'`), after specifying the "outer" DP group (`dp_outer_dim` / `hybrid_fsdp_group`).
288+
- `zero_dp_strategy` (and `outer_dp_sharding_strategy`) configure different degrees of zero-redundancy data parallelism as described in [ZeRO (Zero Redundancy Optimizer)](https://arxiv.org/abs/1910.02054). It reduces CUDA memory utilization during model training by distributing model parameters, gradients, and optimizer states across multiple devices in the DP `ProcessGroup`, and collectively communicating subsets of parameters and gradients to specific devices when needed for computation or differentiation. More aggressive sharding strategies will entail more communication overhead, with `no_shard` being the least memory efficient but most communication efficient, and `optim_grads_params` being the most memory efficient but least communication efficient. `outer_dp_sharding_strategy` has the same options, except for the (required) "outer" DP group (`dp_outer_dim` / `hybrid_fsdp_group`) when using [Hybrid-Sharded Data Parallelism (HSDP)](https://arxiv.org/pdf/2304.11277), and only `no_shard` (DP Replication) and `optim` (Optimizer State Hybrid Sharding, requires `zero_dp_strategy='optim_grads_params'`) are supported.
288289
- Default: `optim_grads_params` or `3` for `zero_dp_strategy` and `no_shard` or `0` for `outer_dp_sharding_strategy`
289290
- `0` or `no_shard` implies that your model is not sharded. Similar memory usage to `DDP`.
290291
- `1` or `optim` implies that your optimizer state is sharded for distributed optimization. Similar to optimizer state sharding in `ZeRO-DP`.
@@ -304,25 +305,16 @@ Megatron-FSDP's `fully_shard_*` API has a comprehensive set of arguments for fin
304305
- `init_model_with_meta_device` has `MegatronFSDP` initialize your `meta`-device model in shards on every CUDA device to avoid OOM when initializing extremely large models that cannot fit on a single device. Users can initialize their model on a [`meta`-device](https://docs.pytorch.org/docs/stable/meta.html) (`with torch.device('meta'): ...`), and ``MegatronFSDP`` will further shard and initialize the model parameters layer-by-layer adhering to the customizable `module.reset_parameters` method, which prevents the entire model from being allocated in memory at any point during runtime.
305306
- Defaults to `False`.
306307
- Note that the `device` argument which installs your model on a specific device or rank will be deactivated when `init_model_with_meta_device=True`.
307-
- `mixed_precision_policy` takes a `megatron_fsdp.MixedPrecisionPolicy` that configures mixed-precision compute and communication for Megatron-FSDP. Configuration options include:
308-
- `main_params_dtype` controls the data-type for parameters used in distributed optimization or quantization.
309-
- Defaults to `torch.float32`.
310-
- If set to `None`, the native model compute parameter data-type will be utilized.
311-
- Requires specification (cannot be `None`) when using `FP8` parameters with Megatron-FSDP.
312-
- `main_grads_dtype` controls the data-type for gradients used in distributed optimization.
313-
- Defaults to `torch.float32`, which is highly-recommended for accuracy at scale.
314-
- If set to `None`, the model native gradient data-type will be utilized.
315-
- `grad_comm_dtype` controls the data-type for gradient communications (RS / AR) when reducing gradients. Lower precision `grad_comm_dtype` improves (communication) performance, but may increase memory utilization or sacrifice gradient precision in certain cases.
316-
- Defaults to `torch.float32`.
317-
- If set to `None`, the `main_grads_dtype` data-type will be utilized.
318-
- If using `no_shard`, `optim`, or a `FixedPoolAllocator` (`fsdp_double_buffer`), allocating `dtype`-custom gradient communication buffers (per FSDP group) adds memory overhead.
319-
- If using NCCL UBR v2.27+ (`nccl_ub=True`), gradient reduction may be performed in high-precision depending on the network domain (NVLink or IB), and can enable mixed-precision communication and accumulation, e.g. setting grad_comm_dtype to BF16 can support FP32 reduction even though we have BF16 input and output communication buffers. Otherwise, gradients will be reduced in `grad_comm_dtype` (and accumulated in `main_grads_dtype`) as usual.
308+
- `grad_reduce_in_fp32` will reduce gradients in `FP32` precision (in contrast to the lower `BF16` or `FP8` model training precision).
309+
- Defaults to `False`.
310+
- `torch.distributed.fsdp.MixedPrecisionPolicy` will be supported in the near future.
311+
- `preserve_fp32_weights` will preserve a `FP32` precision version of model parameters utilized for optimization.
312+
- Defaults to `True`.
313+
- `torch.distributed.fsdp.MixedPrecisionPolicy` will be supported in the near future.
320314
- `overlap_grad_reduce` and `overlap_param_gather` will overlap gradient [`reduce-scatter`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter) and parameter [`all-gather`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather) group communications with backward and forward compute with asynchronous calls and pre-fetching. (In the case of `no_shard`, parameters are not gathered but gradient [`all-reduce`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce) is overlapped.)
321315
- Both default to `True`.
322-
- `sync_model_each_microbatch` will trigger a `wait` (`MegatronFSDP.finish_grad_sync()`) on gradient reduction, parameter de-allocation, and optimizer parameter / gradient installation (in preparation for `optimizer.step()`) after every forward-backward pass. When using HSDP, parameters and gradients will be all-gathered and reduced respectively on the "outer" DP group each training step instead of each optimization cycle. This behavior is desirable for a transparent and user-friendly sharded training loop where post-backward transformations on the gradient and a clean compute / memory state are necessary within and between training iterations, but damages performance in situations where optimization is delayed (e.g. gradient accumulation) when the communications of the previous training iteration can be overlapped with the compute of the next training iteration. Will also override `is_last_microbatch` / `microbatch_count` logic in `MegatronFSDP`.
316+
- `sync_model_each_microbatch` will trigger a `wait` (`MegatronFSDP.finish_grad_sync()`) on gradient reduction, parameter de-allocation, and optimizer parameter / gradient installation (in preparation for `optimizer.step()`) after every forward-backward pass. When using HSDP, parameters and gradients will be all-gathered and reduced respectively on the "outer" DP group each training step instead of each optimization cycle. This behavior is desirable for a transparent and user-friendly sharded training loop where post-backward transformations on the gradient and a clean compute / memory state are necessary between training iterations, but damages performance in situations where optimization is delayed (e.g. gradient accumulation) when the communications of the previous training iteration can be overlapped with the compute of the next training iteration. Will also override `is_last_microbatch` / `microbatch_count` logic in `MegatronFSDP`.
323317
- Defaults to `True` for `fully_shard`, but defaults to `False` when using the `MegatronFSDP` class directly.
324-
- Can also be controlled with the `MegatronFSDP.sync()` context manager, or through invoking `MegatronFSDP.set_model_auto_sync(bool)`.
325-
- WARNING: When this synchronization feature is activated in conjunction with `no_shard` / `0` or `optim` / `1` sharding strategies, the user is responsible for calling `MegatronFSDP.zero_grad_buffer()` or `optimizer.zero_grad()` after the subsequent forward-backward pass. This is because un-sharded gradients are all-reduced directly into the gradient accumulation buffer, and this buffer should not be all-reduced more than once per optimization cycle! Analogous to the justification for the [`no_sync()` API for PyTorch DistributedDataParallel](https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.no_sync).
326318
- `enable_fine_grained_param_gather` modifies FSDP to all-gather parameters with per-Module granularity instead of collectively unsharding all sub-modules of a unit module in Megatron-FSDP.
327319
- Defaults to `False`.
328320
- `keep_fp8_transpose_cache` will keep the fp8 transpose cache when using `MegatronFSDP`. This option will cause (number of parameter $\times$ 1 Byte) of memory overhead, but can skip the weight transpose operation in the backward propagation. This feature will not give any benefit from the Blackwell architecture.
@@ -370,10 +362,8 @@ mfsdp_model = fully_shard_model(
370362
fsdp_unit_modules=[te.pytorch.TransformerLayer],
371363
# Only FSDP / ZeRO-3 supports FP8 parameters.
372364
zero_dp_strategy=3,
373-
# FP32 main weights needed for FP8 parameters.
374-
mixed_precision_policy=MixedPrecisionPolicy(
375-
main_params_dtype=torch.float32
376-
),
365+
# Needed for FP8 parameters. (Default is already True.)
366+
preserve_fp32_weights=True,
377367
# Needed for select FP8 recipes.
378368
keep_fp8_transpose_cache=True,
379369
)

megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from .distributed_data_parallel_config import DistributedDataParallelConfig
1616
from .fully_shard import fully_shard, fully_shard_model, fully_shard_optimizer
1717
from .megatron_fsdp import MegatronFSDP
18-
from .mixed_precision import MixedPrecisionPolicy
1918
from .package_info import (
2019
__contact_emails__,
2120
__contact_names__,
@@ -35,7 +34,6 @@
3534
"DistributedDataParallelConfig",
3635
"MegatronFSDP",
3736
"FSDPDistributedIndex",
38-
"MixedPrecisionPolicy",
3937
"fully_shard",
4038
"fully_shard_model",
4139
"fully_shard_optimizer",

0 commit comments

Comments
 (0)