Skip to content

Commit

Permalink
Update on "[Not for land] Added changes for GPT-2 perf"
Browse files Browse the repository at this point in the history
Credit: felipemello1 for most of the work here (especially around chunked cross entropy)

Running on 4xH100s:
Without these changes (`torch.compile`), the max local batch size is 5:
```
[rank0]:2024-08-19 11:10:26,196 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:2024-08-19 11:10:33,811 - root - INFO - step:  1  loss: 12.2365  memory: 81.67GiB(85.93%)  wps: 5,380  mfu: 1.09%
[rank0]:2024-08-19 11:10:33,811 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 11:10:37,109 - root - INFO - step: 10  loss: 12.1951  memory: 81.67GiB(85.93%)  wps: 111,770  mfu: 22.68%
[rank0]:2024-08-19 11:10:40,777 - root - INFO - step: 20  loss: 11.9455  memory: 81.67GiB(85.93%)  wps: 111,714  mfu: 22.67%
[rank0]:2024-08-19 11:10:44,428 - root - INFO - step: 30  loss: 11.0407  memory: 81.67GiB(85.93%)  wps: 112,194  mfu: 22.76%
[rank0]:2024-08-19 11:10:48,083 - root - INFO - step: 40  loss:  9.9520  memory: 81.67GiB(85.93%)  wps: 112,109  mfu: 22.75%
[rank0]:2024-08-19 11:10:51,734 - root - INFO - step: 50  loss:  9.3392  memory: 81.67GiB(85.93%)  wps: 112,218  mfu: 22.77%
[rank0]:2024-08-19 11:10:55,386 - root - INFO - step: 60  loss:  8.7255  memory: 81.67GiB(85.93%)  wps: 112,198  mfu: 22.77%
[rank0]:2024-08-19 11:10:59,037 - root - INFO - step: 70  loss:  8.1659  memory: 81.67GiB(85.93%)  wps: 112,234  mfu: 22.77%
[rank0]:2024-08-19 11:11:02,701 - root - INFO - step: 80  loss:  7.8037  memory: 81.67GiB(85.93%)  wps: 111,802  mfu: 22.68%
[rank0]:2024-08-19 11:11:06,361 - root - INFO - step: 90  loss:  7.5327  memory: 81.67GiB(85.93%)  wps: 111,937  mfu: 22.71%
[rank0]:2024-08-19 11:11:10,026 - root - INFO - step: 100  loss:  7.3730  memory: 81.67GiB(85.93%)  wps: 111,803  mfu: 22.69%
```
Without these changes (no `torch.compile`), local batch size 5:
```
[rank0]:2024-08-19 14:24:32,150 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
[rank0]:2024-08-19 14:24:38,558 - root - INFO - step:  1  loss: 12.2581  memory: 86.47GiB(90.99%)  wps: 6,393  mfu: 1.30%
[rank0]:2024-08-19 14:24:38,558 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 14:24:42,308 - root - INFO - step: 10  loss: 12.2099  memory: 86.48GiB(90.99%)  wps: 98,305  mfu: 19.95%
[rank0]:2024-08-19 14:24:46,482 - root - INFO - step: 20  loss: 11.9421  memory: 86.48GiB(90.99%)  wps: 98,230  mfu: 19.93%
[rank0]:2024-08-19 14:24:50,648 - root - INFO - step: 30  loss: 11.0090  memory: 86.48GiB(90.99%)  wps: 98,435  mfu: 19.97%
[rank0]:2024-08-19 14:24:54,788 - root - INFO - step: 40  loss:  9.9780  memory: 86.48GiB(90.99%)  wps: 99,064  mfu: 20.10%
[rank0]:2024-08-19 14:24:58,936 - root - INFO - step: 50  loss:  9.3572  memory: 86.48GiB(90.99%)  wps: 98,813  mfu: 20.05%
[rank0]:2024-08-19 14:25:03,181 - root - INFO - step: 60  loss:  8.7479  memory: 86.48GiB(90.99%)  wps: 96,567  mfu: 19.59%
[rank0]:2024-08-19 14:25:07,339 - root - INFO - step: 70  loss:  8.1769  memory: 86.48GiB(90.99%)  wps: 98,604  mfu: 20.01%
[rank0]:2024-08-19 14:25:11,497 - root - INFO - step: 80  loss:  7.8070  memory: 86.48GiB(90.99%)  wps: 98,579  mfu: 20.00%
[rank0]:2024-08-19 14:25:15,649 - root - INFO - step: 90  loss:  7.5329  memory: 86.48GiB(90.99%)  wps: 98,743  mfu: 20.04%
[rank0]:2024-08-19 14:25:19,798 - root - INFO - step: 100  loss:  7.3700  memory: 86.48GiB(90.99%)  wps: 98,818  mfu: 20.05%
```

With these changes, we can use local batch size 16:
```
[rank0]:2024-08-19 11:16:09,534 - root - INFO - Training starts at step 1, with local batch size 16, global batch size 64, sequence length 8192, total steps 100 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:2024-08-19 11:16:15,523 - root - INFO - step:  1  loss: 12.2386  memory: 72.29GiB(76.06%)  wps: 21,887  mfu: 4.44%
[rank0]:2024-08-19 11:16:15,523 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 11:16:22,538 - root - INFO - step: 10  loss: 12.1966  memory: 72.30GiB(76.07%)  wps: 168,174  mfu: 34.12%
[rank0]:2024-08-19 11:16:30,332 - root - INFO - step: 20  loss: 11.9229  memory: 72.30GiB(76.07%)  wps: 168,196  mfu: 34.13%
[rank0]:2024-08-19 11:16:38,129 - root - INFO - step: 30  loss: 10.9399  memory: 72.30GiB(76.07%)  wps: 168,144  mfu: 34.12%
[rank0]:2024-08-19 11:16:45,937 - root - INFO - step: 40  loss:  9.8742  memory: 72.30GiB(76.07%)  wps: 167,898  mfu: 34.07%
[rank0]:2024-08-19 11:16:53,734 - root - INFO - step: 50  loss:  9.2517  memory: 72.30GiB(76.07%)  wps: 168,130  mfu: 34.11%
[rank0]:2024-08-19 11:17:01,518 - root - INFO - step: 60  loss:  8.6441  memory: 72.30GiB(76.07%)  wps: 168,435  mfu: 34.18%
[rank0]:2024-08-19 11:17:09,279 - root - INFO - step: 70  loss:  8.0827  memory: 72.30GiB(76.07%)  wps: 168,927  mfu: 34.28%
[rank0]:2024-08-19 11:17:17,047 - root - INFO - step: 80  loss:  7.7330  memory: 72.30GiB(76.07%)  wps: 168,772  mfu: 34.24%
[rank0]:2024-08-19 11:17:25,139 - root - INFO - step: 90  loss:  7.4835  memory: 72.30GiB(76.07%)  wps: 162,008  mfu: 32.87%
[rank0]:2024-08-19 11:17:32,944 - root - INFO - step: 100  loss:  7.3274  memory: 72.30GiB(76.07%)  wps: 167,963  mfu: 34.08%
```

22.7% MFU -> 34.1% MFU


[ghstack-poisoned]
  • Loading branch information
awgu committed Sep 7, 2024
2 parents b4a24d2 + 6dc57df commit b6a84e4
Show file tree
Hide file tree
Showing 18 changed files with 323 additions and 327 deletions.
1 change: 0 additions & 1 deletion .github/workflows/integration_test_4gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,5 @@ jobs:
pip config --user set global.progress_bar off
python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
mkdir artifacts-to-be-uploaded
python ./test_runner.py artifacts-to-be-uploaded --ngpu 4
66 changes: 60 additions & 6 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Contributing to torchtitan
We want to make contributing to this project as easy and transparent as
possible.
possible. Contributions should follow the [Contributing Guidelines](#contributing-guidelines) below.

## Setup
### Setup
```
pip install -r dev-requirements.txt
```

## Pull Requests
### Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `main`.
Expand All @@ -17,20 +17,74 @@ We actively welcome your pull requests.
5. Make sure your code lints (`pre-commit run --all-files`).
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
### Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Meta's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
### Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License
### License
By contributing to `torchtitan`, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.

---

## Contributing Guidelines

### Principles of contribution

- Apply PyTorch-native training techniques.
- The technique should be of general interests for distributed training.
- A technique with moderate to large complexity should be sitting in the proper repo (e.g. pytorch/pytorch for a new parallelism, or pytorch/data for a new data loader) instead of `torchtitan`.
- The main branch of `torchtitan` should have minimal dependency on non-PyTorch libraries. Interesting models/techniques that depend on external libraries can be demonstrated in forks of `torchtitan`.
- Aim for minimum (if not zero) code change to the model. For the Llama model in `torchtitan`, if one has to make (justifiable) model change:
- After the model change, it should still load the original checkpoint correctly.
- Document the reasons for the code change, similar to [composability.md](docs/composability.md).
- Keep code modularized, especially for [train.py](train.py), so that it remains easy to copy-paste into a minimal code example. If necessary:
- Introduce new config options/category in [config_manager.py](torchtitan/config_manager.py).
- Create separate functions/files.

### Proof of Value

It is the contributor’s responsibility to justify the change. The requirements include, but are not limited to

#### Loss

- If a change does not impact computation results, one should see identical loss before vs. after, with fixed random seeds. An example is activation checkpointing.
```
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
```
- If a change is expected to impact computation results, loss convergence should be verified via end-to-end training on representable datasets (e.g. Llama3 models on the C4 dataset). One can refer to the example jobs reported in [performance.md](docs/performance.md) on what configs and how many steps to run.
- The resulted loss curve should be compared with a verified baseline.
- 1D FSDP – Preferred, whose effectiveness can be proven by comparing with 1D DDP and single-GPU training.
- 2D FSDP + TP can be used as the baseline when 1D FSDP does not suffice to make comparisons due to limited scalability. For example, this should be the baseline when experimenting with 3D parallelisms on the Llama 3.1 405B model.

#### Performance
- Memory and WPS / MFU, which are available from logging, should meet expectations.
- It is worth noting that performance expectations vary from case to case. For example, there are cases when a technique targeting memory reduction may cause throughput regression but still be acceptable (e.g. activation checkpointing). Again, it is the contributor's job to justify the feature, whether by achieving hypothetical performance, or by comparing with existing well-known implementations, etc.
- If necessary, verify the numbers on jobs spanning multiple nodes (e.g. on 64 GPUs). Please reach out to the `torchtitan` team for help if you are resource-constrained.
- When appropriate, one should show profile traces and/or memory snapshots to prove the effectiveness.

### Best practices

When appropriate, one should consider

- Adding CPU/GPU unit/integration tests.
- To add a unit test, put it in the [test](test/) folder and follow the existing test files.
- To add a GPU integration test, create a new `OverrideDefinitions` in [test_runner.py](test_runner.py). It will override the default config to run on the [debug model](train_configs/debug_model.toml).
- Updating [README](README.md) and writing a new note in the [docs](docs/) folder on installation and usage, similar to [float8.md](docs/float8.md).
- Updating [performance.md](docs/performance.md) with new performance results.
- Creating GitHub issues for things that cannot be addressed at the moment.
- Writing a post on [PyTorch Dev Discussions](https://dev-discuss.pytorch.org/c/distributed/6) forum and linking to it.
37 changes: 18 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ Our guiding principles when building `torchtitan`:
### Dive into the code

You may want to see how the model is defined or how parallelism techniques are applied. For a guided tour, see these files first:
* [train.py](https://github.com/pytorch/torchtitan/blob/main/train.py) - the main training loop and high-level setup code
* [torchtitan/parallelisms/parallelize_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py) - helpers for applying Data Parallel, Tensor Parallel, activation checkpointing, and `torch.compile` to the model
* [torchtitan/parallelisms/pipeline_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/pipeline_llama.py) - helpers for applying Pipeline Parallel to the model
* [torchtitan/checkpoint.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py) - utils for saving/loading distributed checkpoints
* [torchtitan/float8.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/float8.py) - utils for applying Float8 techniques
* [torchtitan/models/llama/model.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py) - the Llama model definition (shared for Llama2 and Llama3 variants)
* [train.py](train.py) - the main training loop and high-level setup code
* [torchtitan/parallelisms/parallelize_llama.py](torchtitan/parallelisms/parallelize_llama.py) - helpers for applying Data Parallel, Tensor Parallel, activation checkpointing, and `torch.compile` to the model
* [torchtitan/parallelisms/pipeline_llama.py](torchtitan/parallelisms/pipeline_llama.py) - helpers for applying Pipeline Parallel to the model
* [torchtitan/checkpoint.py](torchtitan/checkpoint.py) - utils for saving/loading distributed checkpoints
* [torchtitan/float8.py](torchtitan/float8.py) - utils for applying Float8 techniques
* [torchtitan/models/llama/model.py](torchtitan/models/llama/model.py) - the Llama model definition (shared for Llama2 and Llama3 variants)

## Pre-Release Updates:
#### (4/25/2024): `torchtitan` is now public but in a pre-release state and under development.
Expand All @@ -35,26 +35,25 @@ Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes
### Key features available

1. [FSDP2 with per param sharding](docs/fsdp.md)
2. [Tensor Parallel](https://pytorch.org/docs/stable/distributed.tensor.parallel.html)
2. [Tensor Parallel](https://pytorch.org/docs/stable/distributed.tensor.parallel.html) (including async TP)
3. Selective layer and operator activation checkpointing
4. Distributed checkpointing
5. 2 datasets pre-configured (45K - 144M)
6. GPU usage, MFU, tokens per second and more displayed via TensorBoard
6. Learning rate scheduler, meta init, Optional Fused RMSNorm
7. All options easily configured via [toml files](train_configs/)
8. [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine tuning
9. [Float8 support](docs/float8.md)
4. Distributed checkpointing (including async checkpointing)
5. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries)
6. Loss, GPU memory, tokens-per-second, and MFU displayed and logged via TensorBoard
7. Learning rate scheduler, meta-init, optional Fused RMSNorm into [`torchtune`](https://github.com/pytorch/torchtune) for fine tuning
8. [Float8 support](docs/float8.md)
9. `torch.compile` support
10. All options easily configured via [toml files](train_configs/)
11. [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly

We report our [Performance](docs/performance.md) verified on 64 A100 GPUs


### Coming soon

1. Async checkpointing
2. Context Parallel
3. 3D Pipeline Parallel
4. `torch.compile` support
5. Scalable data loading solution
1. Context Parallel
2. Pipeline Parallel (and 3D parallellism)
3. HSDP


## Installation
Expand Down
Binary file added assets/images/llama3_1_405B_loss_curves.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 8 additions & 9 deletions docs/composability.md
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
# Building a Clean, Readable Distributed LLM
One of the main goals for TorchTitan was to provide a version of distributed LLM that was not only high performance, but utilized native pytorch techniques and readable code. The challenge is how to compose together so many individual library components (FSDP, TP, PP, FP8, Compile, DCP, ...) just to name a few, and avoid having to make too many changes to the model guts in the process. A lot of the work is behind the scenes, designing individual components to make fewer assumptions, use common abstractions (e.g. DTensor) and generally 'get along'. But we found a few tweaks to the model code invaluable as well, and wanted to share those changes and the rationale for them.
One of the main goals for torchtitan was to provide a version of distributed LLM that was not only high performance, but utilized native PyTorch techniques and readable code. The challenge is how to compose together so many individual library components (FSDP, TP, PP, Float8, Compile, DCP, ..., just to name a few), and avoid having to make too many changes to the model guts in the process. A lot of the work is behind the scenes, designing individual components to make fewer assumptions, use common abstractions (e.g. DTensor) and generally "get along". But we found a few tweaks to the model code invaluable as well, and wanted to share those changes and the rationale for them.



# Making the model "pipeline friendly"
## Making the model "pipeline friendly"
When applying Pipeline Parallelism, you will have to construct nn.Module objects representing the portion of the model that runs on a given pipeline stage. Whether you plan to manually edit your model code, or use techniques like tracing to extract model chunks, a few changes to the original model code can go a long way to making this process easier.

### Simplifying the top-level model forward
Most likely, you can write your model in such a way that the top-level nn.Module owns a sequence of child modules that it calls during forward, delegating most of the complexity to the child module forwards. If you can reduce your top level forward to mostly a for-loop over child module calls, then you'll simplify the pipeline-partitioning task to choosing the set of submodules to keep per stage. If you have non-trivial logic in the top-level forward, you'll have to find a way to patch that logic back onto the resulting pipeline stage model, which can be annoying.

example ([PR #321](https://github.com/pytorch/torchtitan/pull/321)):
we used to slice the `freqs_cis` buffer by `seq_len` in the top level forward, pass that into child modules, and expect that inside the child modules the `seq_len` would match up with the size of other local tensors. But we don't know about whether TP was applied or not when we consider PP splitting and could create a mismatch. Its just as easy to perform the `freqs_cis` slicing inside the child submodule, using the runtime-accurate local `seq_len`, and this sidesteps the issue at PP slicing time.
Example ([PR #321](https://github.com/pytorch/torchtitan/pull/321)):
We used to slice the `freqs_cis` buffer by `seq_len` in the top level forward, pass that into child modules, and expect that inside the child modules the `seq_len` would match up with the size of other local tensors. But we don't know about whether TP was applied or not when we consider PP splitting and could create a mismatch. Its just as easy to perform the `freqs_cis` slicing inside the child submodule, using the runtime-accurate local `seq_len`, and this sidesteps the issue at PP slicing time.

example ([PR #322])https://github.com/pytorch/torchtitan/pull/322)): We decided to actually reuse the top-level model object on every PP stage, just delete the layers we don't want, and make sure that the top-level forward would do the right thing. This means we don't have to make a separate runtime pp_forward that glues together child modules per stage. The first change was using a moduledict instead of modulelist to store layers. This preserves layer Fully Qualified Names (FQNs) even when deleting some layers - e.g. layers.1 stays layers.1 even if you remove layers.0, which isn't true for a list- this matters for checkpoint save/load. Preserving FQNs is a requirement for using Distributed Checkpointing (DCP) since it uses FQNs as globally unique IDs for sharding metadata. The second change was making the input and output layers optional- if the layer exists, we run it, otherwise we feed the input through to bypass it. With these two changes, we can just (meta)-initialize the whole model, delete the unused parts per stage, then materialize the remaining part on GPU before loading a checkpoint.
Example ([PR #322](https://github.com/pytorch/torchtitan/pull/322)):
We decided to actually reuse the top-level model object on every PP stage, just delete the layers we don't want, and make sure that the top-level forward would do the right thing. This means we don't have to make a separate runtime pp_forward that glues together child modules per stage. The first change was using a moduledict instead of modulelist to store layers. This preserves layer Fully Qualified Names (FQNs) even when deleting some layers - e.g. layers.1 stays layers.1 even if you remove layers.0, which isn't true for a list- this matters for checkpoint save/load. Preserving FQNs is a requirement for using Distributed Checkpointing (DCP) since it uses FQNs as globally unique IDs for sharding metadata. The second change was making the input and output layers optional- if the layer exists, we run it, otherwise we feed the input through to bypass it. With these two changes, we can just (meta)-initialize the whole model, delete the unused parts per stage, then materialize the remaining part on GPU before loading a checkpoint.

# Using a seed checkpoint for init
## Using a seed checkpoint for init
Initializing the pipeline-parallel model is challenging becuase we assume the model could be so large as to not fit on local GPU (or possibly, even on CPU), and we also want to use the (bitwise) same initialization as we use for 1D or 2D parallel models, to ease debugging or comparisons between runs. It's not that easy to rewrite the original model's `init_weights` function to be tolerant of initializing only some layers, and also serializing initialization operations globally for consistent RNG order.

For now, we sidestep all these problems with a simple but brutal solution: Initialize the whole model on some CPU instance, save a checkpoint file, and then lean on Distributed Checkpointing's "load" functionality to initialize the FQNs that are present on a given PP stage after stage creation. For future work, we consider adding a more elaborate initialization scheme to `torch.pipelining`.

One issue with seed checkpoints is that we rely on initializing _every_ model state from the checkpoint, which means the model can't have any non-persistent buffers, or else we have to specially initialize those in `train.py` after pipeline splitting. `freqs_cis` was originally a non-persistent buffer, and we changed this to persistent in order to load it from the seed checkpoint.
One issue with seed checkpoints is that we rely on initializing _every_ model state from the checkpoint, which means the model can't have any non-persistent buffers, or else we have to specially initialize those in [train.py](../train.py) after pipeline splitting. `freqs_cis` was originally a non-persistent buffer, and we changed this to persistent in order to load it from the seed checkpoint.
2 changes: 1 addition & 1 deletion docs/fsdp.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def fully_shard(
- `fully_shard(module)` is similar to `FullyShardedDataParallel(module)`, constructing one communication bucket from `module.parameters()` except those already assigned to a nested `fully_shard`/`FullyShardedDataParallel` call.
- `fully_shard(module)` adds an `FSDPState` object on `module`, accessible via `fully_shard.state(module)`, instead of being an `nn.Module` wrapper. This is done via the `@contract` decorator.
- Calling `model.named_parameters()` for a `model` with FSDP2 applied returns unchanged parameter names and `DTensor` sharded parameters. This means that the optimizer and gradient norm clipping see `DTensor`s.
- `fully_shard(module)` performs a dynamic class swap on `module`. E.g., if `type(module) is Transformer`, then FSDP2 constructs a new class `FSDPTransformer` that inherits from a class `FSDP` and `Transformer` and sets `module.__class__` to be `FSDPTransformer`. This allows us to add new methods and override methods via `FSDP` without constructing an `nn.Module` wrapper.
- `fully_shard(module)` performs a dynamic class swap on `module`. E.g., if `type(module) is Transformer`, then FSDP2 constructs a new class `FSDPTransformer` that inherits from a class `FSDPModule` and `Transformer` and sets `module.__class__` to be `FSDPTransformer`. This allows us to add new methods and override methods via `FSDPModule` without constructing an `nn.Module` wrapper.
- FSDP1's `sharding_strategy` and `process_group`/`device_mesh` maps to FSDP2's `mesh` and `reshard_after_forward`.
- `mesh` should be 1D for FSDP and 2D for HSDP. For HSDP, we assume replication on the 0th mesh dim and sharding on the 1st mesh dim. If `mesh is None`, then FSDP2 initializes a 1D global mesh over the default process group.
- `reshard_after_forward=True` or `False` determines whether parameters are resharded (freed) after forward. If `True`, then they are re-all-gathered in backward. This trades off saving memory at the cost of extra communication.
Expand Down
Loading

0 comments on commit b6a84e4

Please sign in to comment.