Getting Started with Fully Sharded Data Parallel (FSDP)
========================================================

**Author**: `Wei Feng <https://github.com/weifengpy>`__, `Will Constable <https://github.com/wconstab>`__, `Yifan Mao <https://github.com/mori360>`__

.. note::
   |edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/intermediate_source/FSDP2_tutorial.rst>`__.

How FSDP2 works
---------------
In `DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`__ (DDP) training, each rank owns a model replica and processes a batch of data, and finally uses all-reduce to synchronize gradients across ranks.

Compared with DDP, FSDP reduces GPU memory footprint by sharding model parameters, gradients, and optimizer states. This makes it feasible to train models that cannot fit on a single GPU. As shown in the figure below:

* Outside of forward and backward, parameters stay fully sharded.
* Before forward and backward, parameters are all-gathered (unsharded) for computation.
* During backward, gradients are reduce-scattered so that each rank ends up with fully sharded gradients.
* The optimizer updates the sharded parameters with the sharded gradients, resulting in sharded optimizer states.

| 18 | + |
| 19 | +.. figure:: /_static/img/distributed/fsdp_workflow.png |
| 20 | + :width: 100% |
| 21 | + :align: center |
| 22 | + :alt: FSDP workflow |
| 23 | + |
| 24 | + FSDP Workflow |
| 25 | + |
| 26 | + |
FSDP can be viewed as decomposing the DDP all-reduce into a reduce-scatter followed by an all-gather.
| 28 | + |
| 29 | +.. figure:: /_static/img/distributed/fsdp_sharding.png |
| 30 | + :width: 100% |
| 31 | + :align: center |
| 32 | + :alt: FSDP allreduce |
| 33 | + |
| 34 | + FSDP Allreduce |
| 35 | + |
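
To make the decomposition concrete, below is a minimal sketch that checks this equivalence with raw ``torch.distributed`` collectives (it does not use FSDP itself). It assumes the script is launched with ``torchrun`` so that a process group can be initialized, with one GPU per rank.

.. code-block:: python

    # Equivalence check: all-reduce == reduce-scatter + all-gather.
    # Launch with: torchrun --nproc_per_node=2 this_script.py
    import torch
    import torch.distributed as dist

    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    # A per-rank "gradient" whose length is divisible by the world size.
    grad = torch.randn(world_size * 4, device="cuda")

    # Path 1: a single all-reduce, as in DDP.
    allreduced = grad.clone()
    dist.all_reduce(allreduced)

    # Path 2: reduce-scatter to a local shard, then all-gather the shards, as in FSDP.
    shard = torch.empty(grad.numel() // world_size, device="cuda")
    dist.reduce_scatter_tensor(shard, grad)       # each rank keeps one reduced shard
    gathered = torch.empty_like(grad)
    dist.all_gather_into_tensor(gathered, shard)  # reassemble the full reduced tensor

    torch.testing.assert_close(allreduced, gathered)
    dist.destroy_process_group()

In FSDP, the reduce-scatter happens during backward (producing sharded gradients), and the all-gather happens right before forward and backward (producing unsharded parameters for computation), as described in the workflow above.
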
Compared with FSDP1, FSDP2 has the following advantages:

* Sharded parameters are represented as DTensors sharded on dim-i, allowing easy manipulation of individual parameters, communication-free sharded state dicts, and a simpler meta-device initialization flow (see the sketch after this list).
* An improved memory management system achieves lower and deterministic GPU memory usage by avoiding ``recordStream``, and does so without any CPU synchronization.
* An extension point makes it possible to customize the all-gather, e.g. fp8 all-gather for fp8 linears.
* Frozen and non-frozen parameters can be mixed in the same communication group without using extra memory.
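
As a small, hedged illustration of the first point above: the sketch below assumes ``fully_shard`` is importable from ``torch.distributed.fsdp`` (the path in recent PyTorch releases; older releases expose it under ``torch.distributed._composable.fsdp``) and that the script runs under ``torchrun`` with one GPU per rank. The toy ``nn.Linear`` model is used only for demonstration.

.. code-block:: python

    # Sketch: inspect dim-0 sharded parameters as DTensors after fully_shard.
    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import fully_shard

    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank())

    model = torch.nn.Linear(16, 16, device="cuda")
    fully_shard(model)

    for name, param in model.named_parameters():
        # Each parameter is now a DTensor; placements shows the sharding and
        # to_local() returns only this rank's shard of the data.
        print(name, param.placements, tuple(param.to_local().shape))

    # The sharded state dict is made of DTensors; building it needs no all-gather.
    sharded_state_dict = model.state_dict()

    dist.destroy_process_group()
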

How to use FSDP2
----------------

Model Initialization: nested wrapping, dim-0 sharding, AC
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

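
Below is a hedged sketch of nested wrapping. The toy model built from ``nn.TransformerEncoderLayer`` stands in for a real transformer; ``fully_shard`` is assumed to be importable from ``torch.distributed.fsdp`` (recent PyTorch releases), and ``checkpoint_wrapper`` is one possible way to compose per-layer activation checkpointing (AC). Each layer is sharded on dim-0 individually, and then the root module is sharded, so parameters are all-gathered per layer rather than for the whole model at once.

.. code-block:: python

    # Sketch: nested wrapping with optional per-layer activation checkpointing (AC).
    # Assumes torchrun --nproc_per_node=2 with one GPU per rank.
    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        CheckpointImpl,
        checkpoint_wrapper,
    )
    from torch.distributed.fsdp import fully_shard

    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank())

    # Toy stand-in for a transformer: a stack of encoder layers.
    model = nn.Sequential(
        *[nn.TransformerEncoderLayer(d_model=256, nhead=8, batch_first=True) for _ in range(4)]
    ).cuda()

    for i, layer in enumerate(model):
        # Optional AC, applied per layer before sharding.
        model[i] = checkpoint_wrapper(layer, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
        fully_shard(model[i])  # shard this layer's parameters on dim-0
    fully_shard(model)  # shard whatever parameters remain on the root module

    print(model)  # parameters are now dim-0 sharded DTensors managed by FSDP2
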
Loading State Dict
~~~~~~~~~~~~~~~~~~

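
One hedged way to load a checkpoint into the sharded model is sketched below, continuing from the model built in the previous sketch. It uses ``torch.distributed.checkpoint`` (DCP) together with the ``get_model_state_dict``/``set_model_state_dict`` helpers; ``checkpoint_dir`` is a placeholder path to a checkpoint previously saved with DCP.

.. code-block:: python

    # Sketch: load a distributed checkpoint into the sharded model from above.
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import (
        get_model_state_dict,
        set_model_state_dict,
    )

    checkpoint_dir = "/tmp/fsdp2_checkpoint"  # placeholder path

    # Build a sharded (DTensor) state dict as the in-memory destination, load the
    # matching shards from storage into it, then copy it back into the module.
    model_state_dict = get_model_state_dict(model)
    dcp.load(model_state_dict, checkpoint_id=checkpoint_dir)
    set_model_state_dict(model, model_state_dict)
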
Forward and Backward
~~~~~~~~~~~~~~~~~~~~

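
Once the model is sharded, the training step itself looks like ordinary PyTorch: as described above, FSDP2 all-gathers parameters right before forward and backward and reduce-scatters gradients during backward. A minimal sketch, continuing from the toy model above (the input shape and the loss are placeholders):

.. code-block:: python

    # Sketch: one forward/backward step on the sharded toy model from above.
    import torch

    x = torch.randn(8, 32, 256, device="cuda")  # (batch, seq_len, d_model): placeholder data
    loss = model(x).sum()  # placeholder loss, for illustration only
    loss.backward()        # gradients end up reduce-scattered, i.e. sharded DTensors
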
Gradient Clipping and Scaler, and Optimizer with DTensor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

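
Because parameters and gradients are DTensors, the standard utilities can be used directly. The hedged sketch below, continuing from the step above, uses ``torch.nn.utils.clip_grad_norm_`` and ``torch.optim.AdamW``; the optimizer states then also live as sharded DTensors. (A mixed-precision gradient scaler is not shown here.)

.. code-block:: python

    # Sketch: clip sharded DTensor gradients, then run the optimizer step.
    import torch

    optim = torch.optim.AdamW(model.parameters(), lr=1e-3)  # states become sharded DTensors

    # For DTensor gradients the total norm is computed across ranks; the returned
    # value may itself be a DTensor.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optim.step()
    optim.zero_grad()
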
Saving State Dict
~~~~~~~~~~~~~~~~~

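
Saving mirrors loading: the model's state dict already consists of sharded DTensors (no all-gather required), and ``torch.distributed.checkpoint`` lets every rank write its own shards in parallel. A hedged sketch, reusing the placeholder ``checkpoint_dir`` from the loading sketch:

.. code-block:: python

    # Sketch: save the sharded model state dict with DCP.
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import get_model_state_dict

    checkpoint_dir = "/tmp/fsdp2_checkpoint"  # placeholder path

    model_state_dict = get_model_state_dict(model)  # sharded DTensors, no all-gather
    dcp.save(model_state_dict, checkpoint_id=checkpoint_dir)

    # Optimizer state can be handled similarly with
    # torch.distributed.checkpoint.state_dict.get_optimizer_state_dict(model, optim).
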
FSDP1-to-FSDP2 Migration Guide
------------------------------