File: `megatron/core/distributed/fsdp/src/README.md`
```python
import torch

from megatron_fsdp import (
    fully_shard_model,
    fully_shard_optimizer,
)
```
```python
model = fully_shard_model(
    # ... (earlier arguments elided) ...
    outer_dp_sharding_strategy=1,
    # Initialize the model on devices in shards to avoid OOM. Requires device("meta")-init for model.
    init_model_with_meta_device=True,
    # Reduce gradients in FP32.
    grad_reduce_in_fp32=False,
    # Store distributed optimization state in FP32.
    preserve_fp32_weights=True,
    # Sync parameters and gradients each step. Allows for gradient transformations after backward pass,
    # and synchronizes parameters and gradients across HSDP groups, but deactivates compute-communication
    # overlap going into the subsequent training step.
    # ...
)
```
- `fsdp_unit_modules` is a list of sub-module classes or `str` import paths identifying the modules that you want `MegatronFSDP` to fully shard.
  - Required if `1`, `2`, or `3` is specified as the sharding strategy. Defaults to `None`, in which case Megatron-FSDP replicates the parameters similar to DDP.
- `zero_dp_strategy` (and `outer_dp_sharding_strategy`) configure different degrees of zero-redundancy data parallelism as described in [ZeRO (Zero Redundancy Optimizer)](https://arxiv.org/abs/1910.02054). It reduces CUDA memory utilization during model training by distributing model parameters, gradients, and optimizer states across multiple devices in the DP `ProcessGroup`, and collectively communicating subsets of parameters and gradients to specific devices when needed for computation or differentiation. More aggressive sharding strategies entail more communication overhead, with `no_shard` being the least memory efficient but most communication efficient, and `optim_grads_params` being the most memory efficient but least communication efficient. `outer_dp_sharding_strategy` has the same options, except that the "outer" DP group (`dp_outer_dim` / `hybrid_fsdp_group`) is required when using [Hybrid-Sharded Data Parallelism (HSDP)](https://arxiv.org/pdf/2304.11277), and only `no_shard` (DP replication) and `optim` (optimizer state hybrid sharding, requires `zero_dp_strategy='optim_grads_params'`) are supported.
  - Default: `optim_grads_params` (`3`) for `zero_dp_strategy` and `no_shard` (`0`) for `outer_dp_sharding_strategy`.
  - `0` or `no_shard` implies that your model is not sharded. Similar memory usage to `DDP`.
  - `1` or `optim` implies that your optimizer state is sharded for distributed optimization. Similar to optimizer state sharding in `ZeRO-DP`.
  - `2` or `optim_grads` implies that your optimizer state and gradients are sharded.
  - `3` or `optim_grads_params` implies that your optimizer state, gradients, and model parameters are sharded.
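The memory trade-off between these strategies can be sketched with the ZeRO paper's back-of-envelope accounting. This is a hypothetical calculation for intuition, not Megatron-FSDP's internal bookkeeping; it assumes BF16 parameters and gradients (2 bytes each) and FP32 Adam state (master weights plus two moments, 12 bytes per parameter):

```python
def per_rank_bytes(num_params: int, dp_size: int, strategy: int) -> int:
    """Rough per-rank memory for each zero_dp_strategy level (ZeRO-style accounting)."""
    params = 2 * num_params   # BF16 parameters
    grads = 2 * num_params    # BF16 gradients
    optim = 12 * num_params   # FP32 master weights + two Adam moments
    if strategy >= 1:         # optim: shard optimizer state
        optim //= dp_size
    if strategy >= 2:         # optim_grads: also shard gradients
        grads //= dp_size
    if strategy >= 3:         # optim_grads_params: also shard parameters
        params //= dp_size
    return params + grads + optim

# For 1B parameters across 8 DP ranks: no_shard keeps the full ~16 GB
# replica per rank, while optim_grads_params needs only ~2 GB per rank.
```

The monotone structure makes the trade-off explicit: each step up the strategy ladder divides one more buffer by the DP world size, at the cost of the extra collectives needed to regather it.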
- `init_model_with_meta_device` has `MegatronFSDP` initialize your `meta`-device model in shards on every CUDA device to avoid OOM when initializing extremely large models that cannot fit on a single device. Users can initialize their model on a [`meta`-device](https://docs.pytorch.org/docs/stable/meta.html) (`with torch.device('meta'): ...`), and `MegatronFSDP` will further shard and initialize the model parameters layer by layer adhering to the customizable `module.reset_parameters` method, which prevents the entire model from being allocated in memory at any point during runtime.
  - Defaults to `False`.
  - Note that the `device` argument, which installs your model on a specific device or rank, is deactivated when `init_model_with_meta_device=True`.
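As a minimal illustration of the `meta`-device pattern in plain PyTorch (no Megatron-FSDP involved): modules constructed under `torch.device('meta')` carry shape and dtype metadata but no storage, so even an enormous model costs no memory until something materializes it.

```python
import torch
import torch.nn as nn

# Constructing the module under the meta device allocates no parameter storage.
with torch.device("meta"):
    model = nn.Linear(4096, 4096)

# Parameters exist only as metadata (shape/dtype), with no backing memory.
assert model.weight.is_meta
# A sharding framework (or to_empty() followed by reset_parameters()) later
# materializes real, initialized storage shard-by-shard.
```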
- `grad_reduce_in_fp32` reduces gradients in `FP32` precision (in contrast to the lower `BF16` or `FP8` model training precision).
  - Defaults to `False`.
  - `torch.distributed.fsdp.MixedPrecisionPolicy` will be supported in the near future.
- `preserve_fp32_weights` preserves an `FP32`-precision copy of the model parameters utilized for optimization.
  - Defaults to `True`.
  - `torch.distributed.fsdp.MixedPrecisionPolicy` will be supported in the near future.
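Why FP32 gradient reduction and FP32 master weights matter can be seen with a small stand-alone experiment (pure Python simulating BF16 round-to-nearest-even; this is an illustration, not Megatron-FSDP code): accumulating many small contributions in BF16 stalls once the running sum's spacing exceeds the increment, while higher-precision accumulation stays accurate.

```python
import struct

def bf16(x: float) -> float:
    """Round a float to bfloat16 precision (round-to-nearest-even on the top 16 bits)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # RNE rounding into the upper half
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

increment = 1e-3
bf16_sum, fp32_sum = 0.0, 0.0
for _ in range(10_000):
    bf16_sum = bf16(bf16_sum + increment)  # each partial sum rounded to BF16
    fp32_sum += increment                  # high-precision accumulation

# bf16_sum stalls at 0.5: adding 1e-3 no longer changes the rounded sum,
# while fp32_sum reaches ~10.0 as expected.
```

The same failure mode applies to reducing gradients across many DP ranks or microbatches, which is why FP32 gradient reduction and FP32 master weights are commonly recommended for accuracy at scale.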
- `overlap_grad_reduce` and `overlap_param_gather` overlap gradient [`reduce-scatter`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter) and parameter [`all-gather`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather) group communications with backward and forward compute via asynchronous calls and prefetching. (In the case of `no_shard`, parameters are not gathered, but gradient [`all-reduce`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce) is overlapped.)
  - Both default to `True`.
- `sync_model_each_microbatch` triggers a `wait` (`MegatronFSDP.finish_grad_sync()`) on gradient reduction, parameter de-allocation, and optimizer parameter / gradient installation (in preparation for `optimizer.step()`) after every forward-backward pass. When using HSDP, parameters and gradients are all-gathered and reduced, respectively, on the "outer" DP group each training step instead of each optimization cycle. This behavior is desirable for a transparent and user-friendly sharded training loop where post-backward transformations on the gradient and a clean compute / memory state are necessary between training iterations, but it damages performance in situations where optimization is delayed (e.g. gradient accumulation) and the communications of the previous training iteration could be overlapped with the compute of the next training iteration. Also overrides the `is_last_microbatch` / `microbatch_count` logic in `MegatronFSDP`.
  - Defaults to `True` for `fully_shard`, but defaults to `False` when using the `MegatronFSDP` class directly.
  - Can also be controlled with the `MegatronFSDP.sync()` context manager, or by invoking `MegatronFSDP.set_model_auto_sync(bool)`.
  - WARNING: When this synchronization feature is activated in conjunction with the `no_shard` / `0` or `optim` / `1` sharding strategies, the user is responsible for calling `MegatronFSDP.zero_grad_buffer()` or `optimizer.zero_grad()` after the subsequent forward-backward pass. This is because un-sharded gradients are all-reduced directly into the gradient accumulation buffer, and this buffer should not be all-reduced more than once per optimization cycle. Analogous to the justification for the [`no_sync()` API for PyTorch DistributedDataParallel](https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.no_sync).
- `enable_fine_grained_param_gather` modifies FSDP to all-gather parameters with per-`Module` granularity instead of collectively unsharding all sub-modules of a unit module in Megatron-FSDP.
  - Defaults to `False`.
- `keep_fp8_transpose_cache` keeps the FP8 transpose cache when using `MegatronFSDP`. This option incurs (number of parameters $\times$ 1 byte) of memory overhead, but skips the weight transpose operation in the backward pass. This feature provides no benefit on the Blackwell architecture.