forked from mlcommons/training
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path__init__.py
More file actions
executable file
·204 lines (175 loc) · 7.87 KB
/
__init__.py
File metadata and controls
executable file
·204 lines (175 loc) · 7.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD
from megatron import get_args
from .distrib_optimizer import DistributedOptimizer
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
def get_param_groups(modules,
no_weight_decay_cond,
scale_lr_cond,
lr_mult,
do_layernorm_bias_weight_decay):
"""creates param groups based on weight decay condition (regularized vs non regularized)
and learning rate scale condition (args.lr vs lr_mult * args.lr)
scale_lr_cond is used during finetuning where head of the network requires a scaled
version of the base learning rate.
"""
wd_no_scale_lr = []
wd_scale_lr = []
no_wd_no_scale_lr = []
no_wd_scale_lr = []
for module in modules:
for name, param in module.named_parameters():
if not param.requires_grad:
continue
if no_weight_decay_cond is not None:
no_wd = no_weight_decay_cond(name, param)
elif do_layernorm_bias_weight_decay:
no_wd = False
else:
# do not regularize biases nor Norm parameters
no_wd = name.endswith(".bias") or len(param.shape) == 1
if scale_lr_cond is not None:
scale_lr = scale_lr_cond(name, param)
else:
scale_lr = False
if not no_wd and not scale_lr:
wd_no_scale_lr.append(param)
elif not no_wd and scale_lr:
wd_scale_lr.append(param)
elif no_wd and not scale_lr:
no_wd_no_scale_lr.append(param)
else:
no_wd_scale_lr.append(param)
param_groups = []
if len(wd_no_scale_lr):
param_groups.append({'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0})
if len(wd_scale_lr):
param_groups.append({'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult})
if len(no_wd_no_scale_lr):
param_groups.append({'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0})
if len(no_wd_scale_lr):
param_groups.append({'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult})
return param_groups
def named_state_dict(optimizer, model):
r"""Returns the state of the optimizer as a :class:`dict`.
It contains two entries:
* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a list containing all parameter groups where each
parameter group is a dict
"""
# Save order indices instead of Tensors
param_mappings = {}
start_index = 0
model_id_map = {id(params): name for name, params in model.named_parameters()}
def pack_group(group):
nonlocal start_index
packed = {k: v for k, v in group.items() if k != 'params'}
param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index)
if id(p) not in param_mappings})
packed['params'] = [param_mappings[id(p)] for p in group['params']]
start_index += len(packed['params'])
return packed
param_groups = [pack_group(g) for g in optimizer.param_groups]
#print(packed_state)
name_state = {param_mappings[k] : (model_id_map[k] if k in model_id_map else "noname") for k, v in param_mappings.items()}
return name_state
def get_megatron_optimizer(model,
no_weight_decay_cond=None,
scale_lr_cond=None,
lr_mult=1.0):
args = get_args()
# Base optimizer.
param_groups = get_param_groups(model,
no_weight_decay_cond,
scale_lr_cond,
lr_mult,
args.do_layernorm_bias_weight_decay)
if args.optimizer == 'adam':
optimizer = Adam(param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps)
# TODO: verify for distributed optimizer
def init_state_fn(opt):
for group in opt.param_groups:
for p in group['params']:
if len(opt.state[p]) == 0:
opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)
elif args.optimizer == 'sgd':
optimizer = SGD(param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
momentum=args.sgd_momentum)
init_state_fn = None
else:
raise Exception('{} optimizer is not supported.'.format(
args.optimizer))
# Determine whether the params have main-grad field.
args.optimizer_model_map = named_state_dict(optimizer, model[0])
params_have_main_grad = False
if args.DDP_impl == 'local':
params_have_main_grad = True
# Mixed precision optimizer.
# - Note: both the Float16Optimizer and the DistributedOptimizer inherit
# from the MixedPrecisionOptimizer, which manages any optimizer where
# the model params and main params are distinct.
if args.fp16 or args.bf16 or args.use_distributed_optimizer:
# Grad scaler:
# if loss-scale is provided, instantiate the constant scaler.
# if we are using fp16 and loss-scale is not present, use a
# dynamic scaler.
# otherwise we are running in bf16 with no loss-scale so
# leave it as None.
grad_scaler = None
# Constant loss scale.
if args.loss_scale:
grad_scaler = ConstantGradScaler(args.loss_scale)
# Dynamic loss scale.
else:
if args.fp16:
grad_scaler = DynamicGradScaler(
initial_scale=args.initial_loss_scale,
min_scale=args.min_loss_scale,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=args.loss_scale_window,
hysteresis=args.hysteresis)
# Megatron optimizer.
opt_ty = DistributedOptimizer \
if args.use_distributed_optimizer else \
Float16OptimizerWithFloat16Params
return opt_ty(optimizer,
args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad,
args.use_contiguous_buffers_in_local_ddp,
args.fp16,
args.bf16,
grad_scaler,
model,
init_state_fn)
# FP32.
return FP32Optimizer(optimizer, args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad,
args.use_contiguous_buffers_in_local_ddp,
model, init_state_fn)