# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# This file is adapted from pretrain_gpt.py in Megatron-LM
import time

import torch

from domino.utils import print_rank_0
from domino.initialize import initialize_domino, set_jit_fusion_options
from domino.arguments import get_args, core_transformer_config_from_args
from domino.data.gpt_dataset import build_train_valid_test_datasets
from domino.training import pretrain
from domino.modules.module import DominoModule
from domino.modules.enums import AttnMaskType
from domino.language_model import parallel_lm_logits, get_language_model
from domino.tensor_parallel.cross_entropy import (
    vocab_parallel_cross_entropy,
    LigerFusedLinearCrossEntropyFunction,
)

_TRAIN_START_TIME = time.time()
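
# Unfused loss path: project the hidden states to vocabulary logits with the
# tensor-parallel output layer, then compute vocab-parallel cross-entropy.
# The model output is sequence-first while labels arrive as [batch, seq],
# hence the transposes around the loss computation.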
def post_language_model_processing(lm_output, labels, logit_weights, parallel_output):
    output = parallel_lm_logits(lm_output, logit_weights, parallel_output)
    labels = labels.transpose(0, 1).contiguous()
    loss = vocab_parallel_cross_entropy(output.float(), labels)
    loss = loss.transpose(0, 1).contiguous()
    return loss
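
# Fused loss path: LigerFusedLinearCrossEntropyFunction computes the output
# projection and the cross-entropy loss in a fused, chunked fashion, so the
# full [tokens, vocab] logits tensor is never materialized at once.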
def post_language_model_processing_with_liger(lm_output, labels, logit_weights, parallel_output):
    b, s = labels.shape
    lm_output = lm_output.flatten(0, 1)
    labels = labels.transpose(0, 1).flatten(0, 1)
    loss, _ = LigerFusedLinearCrossEntropyFunction.apply(lm_output, logit_weights, labels)
    loss = loss.view(s, b).transpose(0, 1).contiguous()
    return loss
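
# GPT model wrapper for Domino: builds the causal language model, ties the
# output projection to the word embeddings (shared_embedding_or_output_weight),
# and selects the fused or unfused loss path via config.fused_linear_loss.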
class GPTModel(DominoModule):
    def __init__(
        self,
        config,
        num_tokentypes=0,
        parallel_output=True,
        pre_process=True,
        post_process=True,
    ):
        super().__init__(config=config)
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.language_model = get_language_model(
            config=config,
            num_tokentypes=num_tokentypes,
            encoder_attn_mask_type=AttnMaskType.causal,
            pre_process=self.pre_process,
            post_process=self.post_process,
        )
        self.initialize_word_embeddings()
        self.config = config

    def set_input_tensor(self, input_tensor):
        self.language_model.set_input_tensor(input_tensor)

    def forward(
        self,
        input_ids,
        position_ids,
        attention_mask,
        labels=None,
        inference_params=None,
    ):
        lm_output = self.language_model(
            input_ids,
            position_ids,
            attention_mask,
            inference_params=inference_params,
        )
        if self.post_process:
            if self.config.fused_linear_loss:
                return post_language_model_processing_with_liger(
                    lm_output,
                    labels,
                    self.shared_embedding_or_output_weight(),
                    self.parallel_output,
                )
            else:
                return post_language_model_processing(
                    lm_output,
                    labels,
                    self.shared_embedding_or_output_weight(),
                    self.parallel_output,
                )
        else:
            return lm_output
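
# Entry point: initialize Domino (and the underlying torch.distributed state),
# build the GPT model and the train/valid/test GPT datasets from the parsed
# arguments, then hand everything to the pretraining loop.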
def main():
    initialize_domino()
    # Set pytorch JIT layer fusion options and warmup JIT functions.
    set_jit_fusion_options()

    # Adjust the startup time so it reflects the largest value across ranks
    # (min-reduce the recorded start times); this is closer to what the
    # scheduler actually sees.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_rank_0('Building GPT model ...')
    config = core_transformer_config_from_args(get_args())
    model = GPTModel(
        config,
        num_tokentypes=0,
        parallel_output=True,
        pre_process=True,
        post_process=True
    )

    args = get_args()
    print_rank_0('Load GPT dataset ...')
    # Number of train/valid/test samples.
    if args.train_samples:
        train_samples = args.train_samples
    else:
        train_samples = args.train_iters * args.global_batch_size
    eval_iters = (args.train_iters // args.eval_interval + 1) * \
                 args.eval_iters
    test_iters = args.eval_iters
    train_val_test_num_samples = [train_samples,
                                  eval_iters * args.global_batch_size,
                                  test_iters * args.global_batch_size]
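    # Illustrative arithmetic (hypothetical settings): with train_iters=1000,
    # global_batch_size=256, eval_interval=100 and eval_iters=10, this yields
    # train = 1000 * 256 = 256,000 samples,
    # validation = (1000 // 100 + 1) * 10 * 256 = 28,160 samples,
    # test = 10 * 256 = 2,560 samples.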
    print_rank_0(' > datasets target sizes (minimum size):')
    print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
    print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
    print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        seq_length=args.seq_length,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        train_data_prefix=args.train_data_path,
        valid_data_prefix=args.valid_data_path,
        test_data_prefix=args.test_data_path,
        data_cache_path=args.data_cache_path)

    pretrain(model, train_ds, valid_ds, test_ds)
if __name__ == "__main__":
    main()
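
# Example launch (illustrative only; the flag names below are assumed to follow
# the Megatron-style arguments parsed by domino.arguments and may differ in
# this repository):
#
#   torchrun --nproc_per_node=8 pretrain_gpt.py \
#       --num-layers 24 --hidden-size 1024 --num-attention-heads 16 \
#       --seq-length 1024 --micro-batch-size 4 --global-batch-size 32 \
#       --train-iters 1000 --eval-interval 100 --eval-iters 10 \
#       --data-path /path/to/dataset_text_document --split 949,50,1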