Commit 5234065

FSDP2 tutorial outline

Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 13103fa Pull Request resolved: #3354

1 parent 35c68ea commit 5234065

1 file changed: +55 -0 lines changed
Getting Started with Fully Sharded Data Parallel (FSDP2)
=========================================================

**Author**: `Wei Feng <https://github.com/weifengpy>`__, `Will Constable <https://github.com/wconstab>`__, `Yifan Mao <https://github.com/mori360>`__

.. note::
   |edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/intermediate_source/FSDP2_tutorial.rst>`__.

How FSDP2 works
---------------
In `DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`__ (DDP) training, each rank owns a model replica and processes a batch of data, and finally uses all-reduce to synchronize gradients across ranks.

Compared with DDP, FSDP reduces GPU memory footprint by sharding model parameters, gradients, and optimizer states. This makes it feasible to train models that cannot fit on a single GPU. As shown in the figure below:

* Outside of forward and backward computation, parameters stay fully sharded.
* Before forward and backward, FSDP all-gathers to unshard parameters for computation.
* During backward, FSDP reduce-scatters gradients so that each rank keeps fully sharded gradients.
* The optimizer updates the sharded parameters with the sharded gradients, resulting in sharded optimizer states.
.. figure:: /_static/img/distributed/fsdp_workflow.png
   :width: 100%
   :align: center
   :alt: FSDP workflow

   FSDP Workflow
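
The code below is a minimal sketch of this workflow (not part of the committed outline). It assumes a recent PyTorch build where FSDP2's ``fully_shard`` is importable from ``torch.distributed.fsdp`` (older releases expose it under ``torch.distributed._composable.fsdp``) and is meant to be launched with ``torchrun``:

.. code-block:: python

    # torchrun --nproc_per_node=2 fsdp2_workflow.py
    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.fsdp import fully_shard

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)).cuda()

    # Shard each parametrized submodule first, then the root module. Outside of
    # forward/backward the parameters stay fully sharded (one dim-0 shard per rank).
    for layer in model:
        if isinstance(layer, nn.Linear):
            fully_shard(layer)
    fully_shard(model)

    optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

    x = torch.randn(8, 1024, device="cuda")
    loss = model(x).sum()  # forward: each FSDP unit all-gathers its parameters
    loss.backward()        # backward: reduce-scatter leaves sharded gradients
    optim.step()           # optimizer updates sharded params -> sharded states
    optim.zero_grad()

    dist.destroy_process_group()
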
FSDP can be viewed as decomposing the DDP all-reduce into a reduce-scatter and an all-gather.

.. figure:: /_static/img/distributed/fsdp_sharding.png
   :width: 100%
   :align: center
   :alt: FSDP allreduce

   FSDP Allreduce
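
To make the decomposition concrete, here is a small hedged sketch (not from the tutorial) checking that a reduce-scatter followed by an all-gather matches a single all-reduce, using raw ``torch.distributed`` collectives:

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    world_size = dist.get_world_size()

    # A per-rank "gradient"; its length is a multiple of the world size.
    grad = torch.randn(4 * world_size, device="cuda")

    # DDP-style: a single all-reduce sums gradients across ranks.
    allreduced = grad.clone()
    dist.all_reduce(allreduced)

    # FSDP-style, step 1: reduce-scatter leaves each rank with its summed shard.
    shard = torch.empty(4, device="cuda")
    dist.reduce_scatter_tensor(shard, grad)

    # FSDP-style, step 2: all-gather reassembles the full reduced tensor on demand.
    gathered = torch.empty_like(grad)
    dist.all_gather_into_tensor(gathered, shard)

    torch.testing.assert_close(gathered, allreduced)
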
Compared with FSDP1, FSDP2 has the following advantages:

* Representing sharded parameters as DTensors sharded on dim-i, allowing for easy manipulation of individual parameters, communication-free sharded state dicts, and a simpler meta-device initialization flow.
* Improving the memory management system to achieve lower and deterministic GPU memory usage, by avoiding ``recordStream`` and without any CPU synchronization.
* Offering an extension point to customize the all-gather, e.g. fp8 all-gather for fp8 linears.
* Allowing frozen and non-frozen parameters to be mixed in the same communication group without using extra memory.
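
As a hedged illustration of the first two points, continuing the sketch above (``DTensor`` is importable from ``torch.distributed.tensor`` in recent releases, ``torch.distributed._tensor`` in older ones):

.. code-block:: python

    from torch.distributed.tensor import DTensor

    # Every parameter managed by fully_shard is a DTensor; .to_local() exposes
    # the dim-0 shard that this rank actually stores.
    for name, param in model.named_parameters():
        assert isinstance(param, DTensor)
        print(name, tuple(param.shape), param.placements, tuple(param.to_local().shape))

    # The state dict holds the same sharded DTensors, so building it needs no
    # all-gather: each rank contributes only its own shards.
    sharded_state_dict = model.state_dict()
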
How to use FSDP2
----------------
Model Initialization: nested wrapping, dim-0 sharding, activation checkpointing (AC)
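
The full tutorial fills these topics in; as a rough preview, nested wrapping with optional activation checkpointing might look like the sketch below. ``ToyTransformer`` is a made-up stand-in, and ``checkpoint_wrapper`` is one common way to apply AC; the tutorial itself may use a different model and AC mechanism.

.. code-block:: python

    import torch.nn as nn
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
    from torch.distributed.fsdp import fully_shard

    class ToyTransformer(nn.Module):
        """Stand-in for a real transformer: embedding, a stack of blocks, a head."""
        def __init__(self, vocab=1000, dim=256, n_layers=4):
            super().__init__()
            self.embed = nn.Embedding(vocab, dim)
            self.layers = nn.ModuleList(
                nn.TransformerEncoderLayer(dim, nhead=4, batch_first=True)
                for _ in range(n_layers)
            )
            self.head = nn.Linear(dim, vocab)

        def forward(self, tokens):
            h = self.embed(tokens)
            for layer in self.layers:
                h = layer(h)
            return self.head(h)

    model = ToyTransformer().cuda()

    for i, block in enumerate(model.layers):
        # Optional activation checkpointing (AC) per block.
        model.layers[i] = block = checkpoint_wrapper(block)
        # Nested wrapping: each block is its own FSDP unit, parameters sharded on dim-0.
        fully_shard(block)

    # The root wrap picks up whatever is left (embedding and output head here).
    fully_shard(model)
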
Loading State Dict
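
For loading, one hedged possibility is ``torch.distributed.checkpoint`` (DCP), which understands sharded ``DTensor`` state dicts so each rank reads only the shards it owns; the checkpoint path below is illustrative:

.. code-block:: python

    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import (
        get_model_state_dict,
        set_model_state_dict,
    )

    # Build a sharded state dict skeleton, let DCP fill in this rank's shards,
    # then copy the result back into the (sharded) model parameters.
    model_sd = get_model_state_dict(model)
    dcp.load(model_sd, checkpoint_id="checkpoints/step_1000")  # illustrative path
    set_model_state_dict(model, model_sd)
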
Forward and Backward
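
Forward and backward themselves need no FSDP-specific calls; reusing ``model`` from the first sketch above, the wrapped model is used like any ``nn.Module`` and FSDP issues the collectives for you:

.. code-block:: python

    import torch

    x = torch.randn(8, 1024, device="cuda")
    loss = model(x).sum()  # all-gather runs per FSDP unit just before its forward
    loss.backward()        # reduce-scatter runs during backward, leaving sharded grads
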
Gradient Clipping and Scaler, and Optimizer with DTensor
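
For clipping and the optimizer, gradients and parameters are ``DTensor`` s, so the stock utilities operate on local shards while agreeing on global values across ranks; a hedged sketch continuing the step above (the AMP ``GradScaler`` part is omitted):

.. code-block:: python

    import torch

    # clip_grad_norm_ computes the global gradient norm across all ranks' shards
    # and rescales each local shard in place.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # The optimizer sees DTensor parameters and gradients, so its states
    # (e.g. Adam moments) are created sharded as well.
    optim.step()
    optim.zero_grad()
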
Saving State Dict
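
Saving is symmetric: because the state dict already holds sharded ``DTensor`` s, DCP can write each rank's shards without any all-gather (path again illustrative):

.. code-block:: python

    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import (
        get_model_state_dict,
        get_optimizer_state_dict,
    )

    state_dict = {
        "model": get_model_state_dict(model),
        "optim": get_optimizer_state_dict(model, optim),
    }
    dcp.save(state_dict, checkpoint_id="checkpoints/step_2000")  # illustrative path
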
FSDP1-to-FSDP2 Migration Guide
------------------------------
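
The migration guide body is not part of this outline yet; as a rough, hedged illustration of the headline API difference, FSDP1 returns a wrapper module while FSDP2's ``fully_shard`` transforms the module in place:

.. code-block:: python

    import torch.nn as nn

    def build_model() -> nn.Module:
        return nn.Linear(1024, 1024).cuda()

    # FSDP1: the constructor returns a FullyShardedDataParallel wrapper module.
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    fsdp1_model = FSDP(build_model(), use_orig_params=True)

    # FSDP2: fully_shard modifies the module in place; it keeps its original class
    # and its parameters become dim-0 sharded DTensors.
    from torch.distributed.fsdp import fully_shard
    fsdp2_model = build_model()
    fully_shard(fsdp2_model)
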
