
Commit 8900d05

TransformerEngine Integration (#1282)
* Implemented ColumnParallelLinear with Transformer-Engine
* Implemented RowParallelLinear with Transformer-Engine
* Implemented LayerNormMLP with Transformer-Engine
* Implemented MultiheadAttention with Transformer-Engine
* Cleaned up transformer.py
* Cleaned up neox_args
* Fixed TE_MHA and added rope support
* Implemented delayed scaling
* Fixed mixed files.
* Changed get_linear name
* Added rng tracker to lnmlp and placed rope in te_mha init instead of forward
* Updated fp8 arguments to te_fp8
* Added EAI copyright
* precommit
* add sample TE config
* add te to readme
* remove pip install prefix from reqs file
* Force TE pytorch in requirements file

Co-authored-by: Quentin Anthony <[email protected]>
1 parent 29080c3 commit 8900d05
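
The per-layer flags introduced by this commit (see the sample config `configs/1-3B-transformer-engine.yml` in the diff below) choose between Transformer Engine modules and the stock GPT-NeoX layers. The sketch below is a hypothetical illustration of that pattern only, not the actual GPT-NeoX wrapper code: the helper name `build_mlp` and the `neox_args` attribute access are assumptions, while `te.LayerNormMLP` is Transformer Engine's public module.

```python
# Hypothetical sketch of flag-gated layer construction (not the real GPT-NeoX code).
import torch
import transformer_engine.pytorch as te


def build_mlp(neox_args, hidden_size, ffn_hidden_size):
    """Pick a TE-fused MLP when the config asks for it, else a plain PyTorch one."""
    if getattr(neox_args, "te_layernorm_mlp", False):
        # Fused LayerNorm + 2-layer MLP from Transformer Engine.
        return te.LayerNormMLP(hidden_size, ffn_hidden_size)
    # Fallback: unfused LayerNorm followed by a GELU MLP.
    return torch.nn.Sequential(
        torch.nn.LayerNorm(hidden_size),
        torch.nn.Linear(hidden_size, ffn_hidden_size),
        torch.nn.GELU(),
        torch.nn.Linear(ffn_hidden_size, hidden_size),
    )
```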

File tree

8 files changed: +968, −171 lines

Diff for: README.md

+17-1
@@ -18,6 +18,8 @@ GPT-NeoX leverages many of the same features and technologies as the popular Meg
 * Easy connections with the open source ecosystem, including Hugging Face's [tokenizers](https://github.com/huggingface/tokenizers) and [transformers](https://github.com/huggingface/transformers/) libraries, monitor experiments via [WandB](https://wandb.ai/site)/[Comet](https://www.comet.com/site/)/TensorBoard, and evaluation via our [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness).
 
 ## News
+**[10/9/2024]** We now support [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) integration
+
 **[9/9/2024]** We now support preference learning via [DPO](https://arxiv.org/abs/2305.18290), [KTO](https://arxiv.org/abs/2402.01306), and reward modeling
 
 **[9/9/2024]** We now support integration with [Comet ML](https://www.comet.com/site/), a machine learning monitoring platform
@@ -60,6 +62,7 @@ Prior to 3/9/2023, GPT-NeoX relied on [DeeperSpeed](https://github.com/EleutherA
 * [Environment and Dependencies](#environment-and-dependencies)
   + [Host Setup](#host-setup)
   + [Flash Attention](#flash-attention)
+  + [Transformer Engine](#transformer-engine)
   + [Multi-Node Launching](#multi-node-launching)
   + [Containerized Setup](#containerized-setup)
 * [Usage](#usage)
@@ -130,7 +133,20 @@ This will automatically adapts building process over different GPU vendors (AMD,
 
 ### Flash Attention
 
-To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` and set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details.
+To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` or use a PyTorch NGC container with it pre-installed (note that functionality is not guaranteed using versions different from our requirements file). Then set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details.
+
+### Transformer Engine
+
+To use [Transformer Engine (TE)](https://github.com/NVIDIA/TransformerEngine), install the additional dependencies in `./requirements/requirements-transformer-engine.txt` or use a PyTorch NGC container with it pre-installed (note that functionality is not guaranteed using versions different from our requirements file). See [this config](https://github.com/EleutherAI/gpt-neox/configs/1-3B-transformer-engine.yml) for an example of using TE on a 1.3B model. This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere and Hopper GPUs; see the repository for more details.
+
+
+TE provides very efficient kernels for both A100 and H100 GPUs. We've run some sample ablations on A100:
+
+[A100 ablation results figure]
+
+and H100:
+
+[H100 ablation results figure]
 
 
 ### Multi-Node Launching
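
The README above only covers installation; the snippet below is a minimal, hedged smoke test that a TE install works. `fp8_autocast` needs an FP8-capable GPU (Hopper/Ada); on Ampere parts such as A100 you would drop the FP8 context and still use TE's fused BF16/FP16 kernels. The tensor sizes are arbitrary, chosen only to satisfy FP8's divisible-by-16 shape requirement.

```python
# Minimal Transformer Engine smoke test (sketch; sizes are arbitrary).
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

layer = te.Linear(1024, 1024, bias=True).cuda()
x = torch.randn(16, 1024, device="cuda")

# HYBRID = E4M3 for the forward pass, E5M2 for gradients.
fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)

print(y.shape)  # torch.Size([16, 1024])
```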

Diff for: configs/1-3B-transformer-engine.yml

+105
@@ -0,0 +1,105 @@
# GPT-2 pretraining setup
{
  # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
  # across the node boundaries )
  "pipe_parallel_size": 1,
  "model_parallel_size": 1,

  # model settings
  "num_layers": 24,
  "hidden_size": 2048,
  "num_attention_heads": 16,
  "seq_length": 2048,
  "max_position_embeddings": 2048,
  "norm": "layernorm",
  "pos_emb": "rotary",
  "no_weight_tying": true,
  "gpt_j_residual": false,
  "output_layer_parallelism": "column",

  # Transformer Engine settings
  "te_columnparallel": false,
  "te_rowparallel": false,
  "te_layernorm_mlp": true,
  "te_mha": true,
  "te_fp8_format": "hybrid",
  "te_fp8_wgrad": true,
  "te_fp8_amax_history_len": 1,
  "te_fp8_amax_compute_algo": "most_recent",
  "te_fp8_margin": 0,
  "te_fp8_mha": false,

  # these should provide some speedup but takes a while to build, set to true if desired
  "scaled_upper_triang_masked_softmax_fusion": false,
  "bias_gelu_fusion": false,
  "rope_fusion": false,
  "layernorm_fusion": false,

  # init methods
  "init_method": "small_init",
  "output_layer_init_method": "wang_init",

  # optimizer settings
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.0002,
      "betas": [0.9, 0.95],
      "eps": 1.0e-8,
    }
  },
  "min_lr": 0.00002,

  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
  "zero_optimization": {
    "stage": 1,
    "allgather_partitions": True,
    "allgather_bucket_size": 500000000,
    "overlap_comm": True,
    "reduce_scatter": True,
    "reduce_bucket_size": 500000000,
    "contiguous_gradients": True,
  },

  # batch / data settings
  "train_micro_batch_size_per_gpu": 4,
  "data_impl": "mmap",

  # activation checkpointing
  "checkpoint_activations": true,
  "checkpoint_num_layers": 1,
  "partition_activations": true,
  "synchronize_each_layer": true,

  # regularization
  "gradient_clipping": 1.0,
  "weight_decay": 0.1,
  "hidden_dropout": 0,
  "attention_dropout": 0,

  # precision settings
  "fp16": {
    "fp16": true,
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },

  # misc. training settings
  "train_iters": 320000,
  "lr_decay_iters": 320000,
  "distributed_backend": "nccl",
  "lr_decay_style": "cosine",
  "warmup": 0.01,
  "checkpoint_factor": 10000,
  "eval_interval": 1000,
  "eval_iters": 10,

  # logging
  "log_interval": 100,
  "steps_per_print": 10,
  "keep_last_n_checkpoints": 4,
  "wall_clock_breakdown": true,
}
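
For reference, here is how the `te_fp8_*` knobs above plausibly map onto Transformer Engine's delayed-scaling FP8 recipe. The recipe/autocast API shown is TE's public one, but the place where GPT-NeoX actually constructs the recipe is inside the integration code and is not part of this diff, so treat the mapping as an illustration rather than the implementation.

```python
# Sketch: the te_fp8_* config values expressed as a TE DelayedScaling recipe.
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

fp8_recipe = DelayedScaling(
    fp8_format=Format.HYBRID,         # te_fp8_format: "hybrid" (E4M3 fwd, E5M2 bwd)
    amax_history_len=1,               # te_fp8_amax_history_len
    amax_compute_algo="most_recent",  # te_fp8_amax_compute_algo
    margin=0,                         # te_fp8_margin
)

# Forward passes of TE modules run their GEMMs in FP8 inside this context:
# with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
#     output = model(input_ids)
```

The remaining flags (`te_fp8_wgrad`, `te_fp8_mha`) appear to toggle FP8 for weight gradients and for the attention block respectively; they are consumed elsewhere in the integration rather than by the recipe object.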

Diff for: megatron/model/positional_embeddings.py

+5
@@ -67,6 +67,8 @@ def _prepare_cache(self, seq_len, precision, base):
         freqs = torch.einsum("i,j->ij", t, inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1)
 
+        self.emb = emb.reshape(emb.size(0), 1, 1, emb.size(1))
+
         cos_cached = emb.cos()[:, None, None, :]
         sin_cached = emb.sin()[:, None, None, :]
 

@@ -76,6 +78,9 @@ def _prepare_cache(self, seq_len, precision, base):
             inv_freq.to(precision),
         )
 
+    def get_emb(self):
+        return self.emb.to(self.precision).cuda()
+
     def forward(self, x, seq_dim=0, seq_len=None):
         if seq_len is None:
             seq_len = x.shape[seq_dim]
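
The reshape to `(seq, 1, 1, dim)` and the new `get_emb()` accessor expose the cached rotary frequencies in a layout that matches what TE's `MultiheadAttention` accepts through its `rotary_pos_emb` forward argument. The sketch below shows that hand-off in isolation; the module sizes and the random stand-in for the cache are assumptions for illustration, not code from this commit.

```python
# Sketch: feeding a [seq, 1, 1, dim]-shaped rotary cache to TE attention.
import torch
import transformer_engine.pytorch as te

hidden_size, num_heads, seq_len, batch = 2048, 16, 128, 2
head_dim = hidden_size // num_heads

mha = te.MultiheadAttention(hidden_size, num_heads).cuda()

# Stand-in for RotaryEmbedding.get_emb(): frequencies cached as [s, 1, 1, d].
rotary_pos_emb = torch.randn(seq_len, 1, 1, head_dim, device="cuda")

x = torch.randn(seq_len, batch, hidden_size, device="cuda")  # [s, b, h] layout
out = mha(x, rotary_pos_emb=rotary_pos_emb)
```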
