-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathinitialize.py
224 lines (197 loc) · 7.08 KB
/
initialize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# This file is adapted from initialize.py in Megatron-LM
import random
import os
import time
import numpy as np
import torch
from domino.arguments import parse_args, get_args, set_args, build_tokenizer_g, set_timers
import domino.parallel_state as mpu
from domino.tensor_parallel.random import model_parallel_cuda_manual_seed
from domino.modules.fused_func import bias_dropout_add_fused_train
from domino.modules.fused_bias_gelu import bias_gelu
from megatron import fused_kernels
import deepspeed
def initialize_domino():
    """Set up global state, the distributed backends and the random seeds.

    Steps, in order: parse and register the arguments, build the tokenizer
    and timers, bind this process to a CUDA device, initialize
    torch.distributed and DeepSpeed, initialize model parallelism, seed the
    random number generators (only when the seed is a positive number), and
    compile the runtime dependencies.

    Returns:
        None. There is no continuation function.
    """
    # Global variables: arguments, tokenizer and timers.
    parsed_args = parse_args()
    set_args(parsed_args)
    build_tokenizer_g(parsed_args)
    set_timers()

    args = get_args()

    # Map the global rank onto one of the visible CUDA devices.
    visible_gpus = torch.cuda.device_count()
    local_device = args.rank % visible_gpus
    args.local_rank = local_device
    torch.cuda.set_device(local_device)

    # Distributed setup: torch.distributed first, then DeepSpeed, then mpu.
    torch.distributed.init_process_group(
        backend=args.distributed_backend,
        world_size=args.world_size,
        rank=args.rank,
    )
    deepspeed.init_distributed()
    mpu.initialize_model_parallel(args.tensor_model_parallel_size)

    # Random seeds.
    base_seed = args.seed
    data_parallel_random_init = False
    if base_seed is not None and base_seed > 0:
        # The per-pipeline-stage seed offset is disabled here, so the
        # configured seed is used as is. Data parallel ranks only receive
        # distinct seeds when data_parallel_random_init is switched on.
        seed = (
            base_seed + (10 * mpu.get_data_parallel_rank())
            if data_parallel_random_init
            else base_seed
        )
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(seed)
        if torch.cuda.device_count() > 0:
            model_parallel_cuda_manual_seed(seed)

    # Compile dependencies.
    _compile_dependencies()

    # No continuation function
    return None
def _compile_dependencies():
    """Compile the dataset helper and load the fused kernels.

    Rank 0 compiles the dataset index builder. A warning is printed from
    rank 0 when the fused softmax kernel constraints are not met. Fused
    kernels are then loaded with rank 0 going first and the remaining ranks
    following after a barrier; a final barrier is crossed by every rank
    before rank 0 reports the elapsed time.
    """
    args = get_args()
    on_rank_zero = torch.distributed.get_rank() == 0

    # =========================
    # Compile dataset C++ code.
    # =========================
    # TODO: move this to ninja
    if on_rank_zero:
        start_time = time.time()
        print("> compiling dataset index builder ...")
        from megatron.data.dataset_utils import compile_helper

        compile_helper()
        elapsed = time.time() - start_time
        print(
            ">>> done with dataset index builder. Compilation time: {:.3f} "
            "seconds".format(elapsed),
            flush=True,
        )

    # ==================
    # Load fused kernels
    # ==================

    # Custom kernel constraints check.
    seq_len = args.seq_length
    heads_per_partition = (
        args.num_attention_heads / args.tensor_model_parallel_size
    )
    attn_batch_size = heads_per_partition * args.micro_batch_size
    # Sequence length and attn_batch_size limits that enable the warp based
    # optimization and the upper triangular optimization (for causal mask).
    custom_kernel_constraint = (
        16 < seq_len <= 16384
        and seq_len % 4 == 0
        and attn_batch_size % 4 == 0
    )
    fused_softmax_usable = (
        (args.fp16 or args.bf16)
        and custom_kernel_constraint
        and args.masked_softmax_fusion
    )
    # Print a warning.
    if not fused_softmax_usable and args.rank == 0:
        print(
            "WARNING: constraints for invoking optimized"
            " fused softmax kernel are not met. We default"
            " back to unfused kernel invocations.",
            flush=True,
        )

    # Always build on rank zero first.
    if not on_rank_zero:
        # Wait at the barrier until rank 0 has finished loading.
        torch.distributed.barrier()
        fused_kernels.load(args)
    else:
        start_time = time.time()
        print("> compiling and loading fused kernels ...", flush=True)
        fused_kernels.load(args)
        torch.distributed.barrier()

    # Simple barrier to make sure all ranks have passed the
    # compilation phase successfully before moving on to the
    # rest of the program. We think this might ensure that
    # the lock is released.
    torch.distributed.barrier()
    if on_rank_zero:
        elapsed = time.time() - start_time
        print(
            ">>> done with compiling and loading fused kernels. "
            "Compilation time: {:.3f} seconds".format(elapsed),
            flush=True,
        )
def set_jit_fusion_options():
    """Set PyTorch JIT layer fusion options, then warm up the JIT functions.

    The set of flags depends on the installed PyTorch version: 1.10 and
    newer take the nvfuser settings, older versions take the legacy fuser
    settings.
    """
    # flags required to enable jit fusion kernels
    version_fields = torch.__version__.split(".")
    torch_major = int(version_fields[0])
    torch_minor = int(version_fields[1])

    if (torch_major, torch_minor) >= (1, 10):
        # nvfuser
        jit_settings = (
            ("_jit_set_profiling_executor", True),
            ("_jit_set_profiling_mode", True),
            ("_jit_override_can_fuse_on_cpu", False),
            ("_jit_override_can_fuse_on_gpu", False),
            ("_jit_set_texpr_fuser_enabled", False),
            ("_jit_set_nvfuser_enabled", True),
            ("_debug_set_autodiff_subgraph_inlining", False),
        )
    else:
        # legacy pytorch fuser
        jit_settings = (
            ("_jit_set_profiling_mode", False),
            ("_jit_set_profiling_executor", False),
            ("_jit_override_can_fuse_on_cpu", True),
            ("_jit_override_can_fuse_on_gpu", True),
        )

    # Each setter is looked up right before it is called, so the settings
    # are applied one at a time in the listed order.
    for setter_name, enabled in jit_settings:
        getattr(torch._C, setter_name)(enabled)

    _warmup_jit_function()
def _warmup_jit_function():
    """Compile JIT functions before the main training steps.

    Runs the fused bias+gelu and the fused bias+dropout+add functions a few
    times on random CUDA tensors sized from the global arguments, covering
    the requires_grad combinations listed below, then frees the tensors and
    empties the CUDA cache.
    """
    args = get_args()
    dtype = (
        torch.bfloat16
        if args.bf16
        else torch.float16
        if args.fp16
        else torch.float32
    )
    warmup_iters = 5

    # Warmup fused bias+gelu
    ffn_partition_size = args.ffn_hidden_size // args.tensor_model_parallel_size
    bias = torch.rand(ffn_partition_size, dtype=dtype, device="cuda")
    activations = torch.rand(
        (args.seq_length, args.micro_batch_size, ffn_partition_size),
        dtype=dtype,
        device="cuda",
    )
    # Warmup JIT fusions with the input grad_enable state of both forward
    # prop and recomputation
    for bias_grad, input_grad in ((True, False), (True, True)):
        bias.requires_grad = bias_grad
        activations.requires_grad = input_grad
        for _ in range(warmup_iters):
            output = bias_gelu(bias, activations)
    del bias, activations, output

    # Warmup fused bias+dropout+add
    seq_length = args.seq_length
    hidden_shape = (seq_length, args.micro_batch_size, args.hidden_size)
    activations = torch.rand(hidden_shape, dtype=dtype, device="cuda")
    residual = torch.rand(hidden_shape, dtype=dtype, device="cuda")
    bias = torch.rand(args.hidden_size, dtype=dtype, device="cuda")
    bias = bias.expand_as(residual)
    dropout_rate = 0.1
    # Warmup JIT fusions with the input grad_enable state of both forward
    # prop and recomputation
    grad_states = ((False, True, True), (True, True, True))
    for input_grad, bias_grad, residual_grad in grad_states:
        activations.requires_grad = input_grad
        bias.requires_grad = bias_grad
        residual.requires_grad = residual_grad
        for _ in range(warmup_iters):
            output = bias_dropout_add_fused_train(
                activations, bias, residual, dropout_rate
            )
    del bias, activations, residual, output
    torch.cuda.empty_cache()