Commit 8553751

[bugfix] Fix megatron lr_mult (#9524)
1 parent 749f6d4 commit 8553751

1 file changed

Lines changed: 2 additions & 0 deletions

swift/megatron/trainers/base.py

@@ -446,6 +446,8 @@ def _get_param_groups(
             param_group['max_lr'] = lr
             param_group['min_lr'] = min_lr
             lr_mult = param_group.pop('lr_mult')
+            # Instead of using lr_mult to control the learning rate, we directly use max_lr/min_lr.
+            param_group['lr_mult'] = 1.
             param_group['max_lr'] *= lr_mult
             param_group['min_lr'] *= lr_mult
         return param_groups
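
The effect of the two added lines can be illustrated with a small standalone sketch. It assumes each parameter group is a plain dict carrying an `lr_mult` key; the function name `fold_lr_mult`, the group names, and the learning-rate values are hypothetical and do not come from the repository. The multiplier is folded into `max_lr`/`min_lr`, and the key is kept with a neutral value of `1.`, presumably so that any later code that still reads `lr_mult` neither fails on a missing key nor applies the multiplier a second time.

```python
def fold_lr_mult(param_groups, lr, min_lr):
    """Hypothetical standalone version of the patched loop (not the repository's API)."""
    for param_group in param_groups:
        param_group['max_lr'] = lr
        param_group['min_lr'] = min_lr
        # Take the per-group multiplier out of the group...
        lr_mult = param_group.pop('lr_mult')
        # ...and put back a neutral value so the key still exists.
        param_group['lr_mult'] = 1.
        # The multiplier now lives only in the max_lr/min_lr bounds.
        param_group['max_lr'] *= lr_mult
        param_group['min_lr'] *= lr_mult
    return param_groups


# Illustrative values only.
groups = fold_lr_mult(
    [{'name': 'backbone', 'lr_mult': 1.0}, {'name': 'head', 'lr_mult': 10.0}],
    lr=1e-4,
    min_lr=1e-5,
)
for g in groups:
    print(g['name'], g['max_lr'], g['min_lr'], g['lr_mult'])
```

After the call, every group reports `lr_mult == 1.0`, while the `head` group's `max_lr` and `min_lr` are ten times those of `backbone`.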