Skip to content

Commit 9768756

Browse files
committed
Merge branch 'fix_lr_mult' into 'main'
fix lr_mult setting will be reset in get_param_groups inner loop See merge request ADLR/megatron-lm!1578
2 parents 1ef53c3 + f11303b commit 9768756

1 file changed

Lines changed: 7 additions & 7 deletions

File tree

megatron/core/optimizer/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,13 @@ def _get_param_groups(
9090
scale_lr = False
9191

9292
if not no_wd and not scale_lr:
93-
wd_mult, lr_mult = 1.0, 1.0
93+
wd_mult, _lr_mult = 1.0, 1.0
9494
elif not no_wd and scale_lr:
95-
wd_mult, lr_mult = 1.0, lr_mult
95+
wd_mult, _lr_mult = 1.0, lr_mult
9696
elif no_wd and not scale_lr:
97-
wd_mult, lr_mult = 0.0, 1.0
97+
wd_mult, _lr_mult = 0.0, 1.0
9898
else:
99-
wd_mult, lr_mult = 0.0, lr_mult
99+
wd_mult, _lr_mult = 0.0, lr_mult
100100

101101
is_decoupled_lr = False
102102
# For input/embedding and output layer: embedding.word_embeddings.weight / output_layer.weight.
@@ -105,19 +105,19 @@ def _get_param_groups(
105105
):
106106
is_decoupled_lr = True
107107

108-
key = (wd_mult, lr_mult, is_expert_parallel, is_decoupled_lr)
108+
key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr)
109109
if key not in params_map:
110110
params_map[key] = []
111111
params_map[key].append(param)
112112

113113
param_groups = []
114-
for (wd_mult, lr_mult, is_expert_parallel, is_decoupled_lr), params in params_map.items():
114+
for (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr), params in params_map.items():
115115
assert len(params) > 0
116116
param_groups.append(
117117
{
118118
'params': params,
119119
'wd_mult': wd_mult,
120-
'lr_mult': lr_mult,
120+
'lr_mult': _lr_mult,
121121
'is_expert_parallel': is_expert_parallel,
122122
'is_decoupled_lr': is_decoupled_lr,
123123
}

0 commit comments

Comments
 (0)