Commit 7b9679a

Merge branch 'main' into dmoe_integration
2 parents: 542103f + f7a5a6f

22 files changed: +1891 −181 lines

Diff for: README.md

+17 −1
@@ -18,6 +18,8 @@ GPT-NeoX leverages many of the same features and technologies as the popular Meg
 * Easy connections with the open source ecosystem, including Hugging Face's [tokenizers](https://github.com/huggingface/tokenizers) and [transformers](https://github.com/huggingface/transformers/) libraries, monitor experiments via [WandB](https://wandb.ai/site)/[Comet](https://www.comet.com/site/)/TensorBoard, and evaluation via our [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness).

 ## News
+**[10/9/2024]** We now support [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) integration
+
 **[9/9/2024]** We now support preference learning via [DPO](https://arxiv.org/abs/2305.18290), [KTO](https://arxiv.org/abs/2402.01306), and reward modeling

 **[9/9/2024]** We now support integration with [Comet ML](https://www.comet.com/site/), a machine learning monitoring platform
@@ -60,6 +62,7 @@ Prior to 3/9/2023, GPT-NeoX relied on [DeeperSpeed](https://github.com/EleutherA
 * [Environment and Dependencies](#environment-and-dependencies)
   + [Host Setup](#host-setup)
   + [Flash Attention](#flash-attention)
+  + [Transformer Engine](#transformer-engine)
   + [Multi-Node Launching](#multi-node-launching)
   + [Containerized Setup](#containerized-setup)
 * [Usage](#usage)
@@ -130,7 +133,20 @@ This will automatically adapts building process over different GPU vendors (AMD,

 ### Flash Attention

-To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` and set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details.
+To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` or use a PyTorch NGC container with it pre-installed (note that functionality is not guaranteed using versions different from our requirements file). Then set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details.
+
+### Transformer Engine
+
+To use [Transformer Engine (TE)](https://github.com/NVIDIA/TransformerEngine), install the additional dependencies in `./requirements/requirements-transformer-engine.txt` or use a PyTorch NGC container with it pre-installed (note that functionality is not guaranteed using versions different from our requirements file). See [this config](https://github.com/EleutherAI/gpt-neox/configs/1-3B-transformer-engine.yml) for an example of using TE on a 1.3B model. This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere and Hopper GPUs; see the repository for more details.
+
+TE provides very efficient kernels for both A100 and H100 GPUs. We've run some sample ablations on A100:
+
+and H100:
+

 ### Multi-Node Launching
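
Before enabling either backend in a config, it can help to confirm the optional packages referenced above are actually importable. A minimal sanity check, illustrative only and not part of the diff; `flash_attn` and `transformer_engine` are the import names of the packages installed by the respective requirements files:

# Illustrative sanity check (not part of this commit): verify that the optional
# Flash-Attention and Transformer Engine packages are importable before enabling
# them in a GPT-NeoX config.
import importlib.util

for pkg in ("flash_attn", "transformer_engine"):
    found = importlib.util.find_spec(pkg) is not None
    print(f"{pkg}: {'installed' if found else 'missing'}")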

Diff for: configs/1-3B-transformer-engine.yml

+105
@@ -0,0 +1,105 @@
# GPT-2 pretraining setup
{
  # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
  # across the node boundaries )
  "pipe_parallel_size": 1,
  "model_parallel_size": 1,

  # model settings
  "num_layers": 24,
  "hidden_size": 2048,
  "num_attention_heads": 16,
  "seq_length": 2048,
  "max_position_embeddings": 2048,
  "norm": "layernorm",
  "pos_emb": "rotary",
  "no_weight_tying": true,
  "gpt_j_residual": false,
  "output_layer_parallelism": "column",

  # Transformer Engine settings
  "te_columnparallel": false,
  "te_rowparallel": false,
  "te_layernorm_mlp": true,
  "te_mha": true,
  "te_fp8_format": "hybrid",
  "te_fp8_wgrad": true,
  "te_fp8_amax_history_len": 1,
  "te_fp8_amax_compute_algo": "most_recent",
  "te_fp8_margin": 0,
  "te_fp8_mha": false,

  # these should provide some speedup but takes a while to build, set to true if desired
  "scaled_upper_triang_masked_softmax_fusion": false,
  "bias_gelu_fusion": false,
  "rope_fusion": false,
  "layernorm_fusion": false,

  # init methods
  "init_method": "small_init",
  "output_layer_init_method": "wang_init",

  # optimizer settings
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.0002,
      "betas": [0.9, 0.95],
      "eps": 1.0e-8,
    }
  },
  "min_lr": 0.00002,

  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
  "zero_optimization": {
    "stage": 1,
    "allgather_partitions": True,
    "allgather_bucket_size": 500000000,
    "overlap_comm": True,
    "reduce_scatter": True,
    "reduce_bucket_size": 500000000,
    "contiguous_gradients": True,
  },

  # batch / data settings
  "train_micro_batch_size_per_gpu": 4,
  "data_impl": "mmap",

  # activation checkpointing
  "checkpoint_activations": true,
  "checkpoint_num_layers": 1,
  "partition_activations": true,
  "synchronize_each_layer": true,

  # regularization
  "gradient_clipping": 1.0,
  "weight_decay": 0.1,
  "hidden_dropout": 0,
  "attention_dropout": 0,

  # precision settings
  "fp16": {
    "fp16": true,
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },

  # misc. training settings
  "train_iters": 320000,
  "lr_decay_iters": 320000,
  "distributed_backend": "nccl",
  "lr_decay_style": "cosine",
  "warmup": 0.01,
  "checkpoint_factor": 10000,
  "eval_interval": 1000,
  "eval_iters": 10,

  # logging
  "log_interval": 100,
  "steps_per_print": 10,
  "keep_last_n_checkpoints": 4,
  "wall_clock_breakdown": true,
}
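
For orientation, an assumption-laden sketch (not part of this commit): the `te_fp8_*` keys above correspond roughly to the knobs of Transformer Engine's delayed-scaling FP8 recipe. The snippet below shows what those settings would translate to if you drove TE's public API directly; GPT-NeoX wires this up internally, so this is only illustrative.

# Rough, illustrative mapping of the te_fp8_* settings above onto Transformer Engine's
# FP8 recipe API (assumes transformer_engine is installed; not part of the commit).
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

fp8_recipe = DelayedScaling(
    fp8_format=Format.HYBRID,         # "te_fp8_format": "hybrid" (E4M3 forward, E5M2 backward)
    amax_history_len=1,               # "te_fp8_amax_history_len": 1
    amax_compute_algo="most_recent",  # "te_fp8_amax_compute_algo": "most_recent"
    margin=0,                         # "te_fp8_margin": 0
)

# TE modules would then run their forward passes under:
#   with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
#       ...

The config itself is launched like any other GPT-NeoX config, e.g. `python ./deepy.py train.py configs/1-3B-transformer-engine.yml`.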

Diff for: configs/eleutherai_cluster.yml

+1
@@ -24,6 +24,7 @@
   "tensorboard_dir": "/mnt/ssd-1/tensorboard",
   "log_dir": "/mnt/ssd-1/logs",
   "wandb_team": "eleutherai",
+  #"wandb_run_name": "experiment"
   "wandb_project": "neox",
   "wandb_group": "example"
 }

Diff for: megatron/data/data_utils.py

+51 −1
@@ -24,6 +24,7 @@
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.gpt2_dataset import GPT2Dataset
 from megatron.data.pairwise_dataset import PairwiseDataset
+from megatron.data.online_dataset import OnlineDataset
 from megatron.data.samplers import DistributedBatchSampler


@@ -532,7 +533,56 @@ def build_train_valid_test_data_loaders(neox_args):
         pipe_load = True

     # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_model_parallel_rank() == 0 and pipe_load:
+    if (
+        pipe_load
+        and (neox_args.dataset_impl == "online")
+        and (mpu.get_model_parallel_rank() == 0)
+    ):
+        # Can skip most of the work...
+        train_iters = neox_args.train_iters
+        eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters
+        test_iters = neox_args.eval_iters
+        # Build datasets...
+        print(
+            f"train_iters: {train_iters}, eval_iters: {eval_iters}, test_iters: {test_iters}"
+        )
+        train_datasets = OnlineDataset(
+            leave_one_out=neox_args.reinforce_leave_one_out,
+            data_split="train",
+            num_samples=train_iters * neox_args.train_batch_size,
+            seq_length=neox_args.seq_length,
+            dataserver_ips=neox_args.online_dataserver_ips,
+            dataserver_ports=neox_args.online_dataserver_ports,
+        )
+        valid_datasets = OnlineDataset(
+            leave_one_out=neox_args.reinforce_leave_one_out,
+            data_split="valid",
+            num_samples=eval_iters * neox_args.train_batch_size,
+            seq_length=neox_args.seq_length,
+            dataserver_ips=neox_args.online_dataserver_ips,
+            dataserver_ports=neox_args.online_dataserver_ports,
+        )
+        test_datasets = OnlineDataset(
+            leave_one_out=neox_args.reinforce_leave_one_out,
+            data_split="test",
+            num_samples=test_iters * neox_args.train_batch_size,
+            seq_length=neox_args.seq_length,
+            dataserver_ips=neox_args.online_dataserver_ips,
+            dataserver_ports=neox_args.online_dataserver_ports,
+        )
+        # Build dataloaders.
+        train_dataloader = make_data_loader(train_datasets, neox_args=neox_args)
+        valid_dataloader = make_data_loader(valid_datasets, neox_args=neox_args)
+        test_dataloader = make_data_loader(test_datasets, neox_args=neox_args)
+
+        # Flags to know if we need to do training/validation/testing.
+        do_train = train_dataloader is not None and neox_args.train_iters > 0
+        do_valid = valid_dataloader is not None and neox_args.eval_iters > 0
+        do_test = test_dataloader is not None and neox_args.eval_iters > 0
+        # Need to broadcast num_tokens and num_type_tokens.
+        flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)])
+    elif mpu.get_model_parallel_rank() == 0 and pipe_load:
         # Number of train/valid/test samples.
         if neox_args.train_iters is not None:
             train_iters = neox_args.train_iters
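
As a small worked example of the iteration bookkeeping in the new online branch (using the values from configs/1-3B-transformer-engine.yml in this same commit; illustrative only, not part of the diff):

# Worked example of the eval/test iteration arithmetic in the online-dataset branch,
# with train_iters=320000, eval_interval=1000, eval_iters=10 from the 1-3B config above.
train_iters = 320_000
eval_interval = 1_000
eval_iters = 10

total_eval_iters = (train_iters // eval_interval + 1) * eval_iters  # (320 + 1) * 10 = 3210
test_iters = eval_iters                                             # 10
print(total_eval_iters, test_iters)  # -> 3210 10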

Diff for: megatron/data/online_dataset.py

+128
@@ -0,0 +1,128 @@
# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Online dataset."""
from typing import Union, List

import numpy as np
import torch
import torch.utils.data
import socket
import pickle
from megatron.mpu.initialize import get_data_parallel_rank


class OnlineDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        num_samples,
        seq_length,
        leave_one_out=False,
        data_split="train",
        dataserver_ips: Union[str, List[str]] = "localhost",
        dataserver_ports: Union[int, List[int]] = 10000,
    ):
        self.num_samples = num_samples
        self.global_rank = get_data_parallel_rank()
        self.leave_one_out = leave_one_out
        self.reward_buffer = []
        self.online_batching_data = []
        self.data_split = data_split
        self.seq_length = seq_length
        self.dataserver_ips = dataserver_ips
        self.dataserver_ports = dataserver_ports

    def __len__(self):
        # dummy value since it's decided by the Online Trainer
        return self.num_samples

    def update_online_batches(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if isinstance(self.dataserver_ips, str):
            ipaddr = self.dataserver_ips
        else:
            ipaddr = self.dataserver_ips[self.global_rank]
        if isinstance(self.dataserver_ports, int):
            # simply add over the global rank
            port = self.dataserver_ports
        else:
            # in case we want to use different ports for different ranks, e.g. per machine sampling
            port = self.dataserver_ports[self.global_rank]
        print(f"Connecting to {ipaddr}:{port}")
        s.connect((ipaddr, port))
        s.send(self.data_split.encode())
        data = b""
        while True:
            chunk = s.recv(4096)
            if not chunk:
                break
            data += chunk
        batch_data = pickle.loads(data)
        s.close()
        print(f"Received {len(batch_data)} samples from the server.")
        for data in batch_data:
            if self.leave_one_out:
                rewards = list()
                for i in range(len(data["rewards"])):
                    rewards.append(
                        data["rewards"][i]
                        - np.mean(
                            [
                                data["rewards"][j]
                                for j in range(len(data["rewards"]))
                                if j != i
                            ]
                        )
                    )
                data["raw_rewards"] = data["rewards"]
                data["rewards"] = rewards
            else:
                moving_average = 0
                if len(self.reward_buffer) > 0:
                    moving_average = np.mean(self.reward_buffer)
                self.reward_buffer.append(np.mean(data["rewards"]))
                if len(self.reward_buffer) > 100:
                    self.reward_buffer.pop(0)
                # For metrics...
                data["raw_rewards"] = data["rewards"]
                data["rewards"] = [r - moving_average for r in data["rewards"]]
            for i in range(len(data["completions"])):
                self.online_batching_data.append(
                    [
                        data["prefix"],
                        data["completions"][i],
                        data["rewards"][i],
                        data["raw_rewards"][i],
                    ]
                )

    def __getitem__(self, idx):
        if len(self.online_batching_data) == 0:
            self.update_online_batches()
        batch = self.online_batching_data.pop(0)
        text = batch[0] + batch[1]
        label = [-100 for _ in batch[0]] + batch[1]
        # +1 because of causal masking
        if len(text) <= self.seq_length:
            text = text + [0] * ((self.seq_length + 1) - len(text))
            label = label + [-100] * ((self.seq_length + 1) - len(label))
        return {
            "text": np.array(text, dtype=np.int64),
            "label": np.array(label, dtype=np.int64),
            "reward": np.array([batch[2]], dtype=np.float32),
            "raw_reward": np.array([batch[3]], dtype=np.float32),
        }
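
OnlineDataset is only the client half of a very simple protocol: it connects, sends the split name, reads until the server closes the connection, and unpickles a list of dicts carrying "prefix", "completions", and "rewards". A minimal, hypothetical server sketch that satisfies that contract (the trainer-side server is not part of this commit and may differ):

# Hypothetical stand-in for the data server that OnlineDataset.update_online_batches()
# talks to. It accepts one connection, reads the requested split name, replies with a
# single pickled list of sample dicts, and closes the socket (the client reads to EOF).
import pickle
import socket

def serve_once(host="localhost", port=10000):
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:
        srv.bind((host, port))
        srv.listen()
        conn, _ = srv.accept()
        with conn:
            split = conn.recv(64).decode()  # "train", "valid", or "test"
            print(f"Serving dummy {split} batch")
            batch = [
                {
                    "prefix": [1, 2, 3],                 # prompt token ids
                    "completions": [[4, 5], [6, 7, 8]],  # sampled completions (token ids)
                    "rewards": [0.2, 0.7],               # one reward per completion
                }
            ]
            conn.sendall(pickle.dumps(batch))            # closing the connection marks end-of-data

if __name__ == "__main__":
    serve_once()

Note that the rewards in each dict are post-processed on the client side, either with leave-one-out baselining or with a running-mean baseline over the last 100 batches, as in the class above.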
