@@ -90,13 +90,13 @@ def _get_param_groups(
9090 scale_lr = False
9191
9292 if not no_wd and not scale_lr :
93- wd_mult , lr_mult = 1.0 , 1.0
93+ wd_mult , _lr_mult = 1.0 , 1.0
9494 elif not no_wd and scale_lr :
95- wd_mult , lr_mult = 1.0 , lr_mult
95+ wd_mult , _lr_mult = 1.0 , lr_mult
9696 elif no_wd and not scale_lr :
97- wd_mult , lr_mult = 0.0 , 1.0
97+ wd_mult , _lr_mult = 0.0 , 1.0
9898 else :
99- wd_mult , lr_mult = 0.0 , lr_mult
99+ wd_mult , _lr_mult = 0.0 , lr_mult
100100
101101 is_decoupled_lr = False
102102 # For input/embedding and output layer: embedding.word_embeddings.weight / output_layer.weight.
@@ -105,19 +105,19 @@ def _get_param_groups(
105105 ):
106106 is_decoupled_lr = True
107107
108- key = (wd_mult , lr_mult , is_expert_parallel , is_decoupled_lr )
108+ key = (wd_mult , _lr_mult , is_expert_parallel , is_decoupled_lr )
109109 if key not in params_map :
110110 params_map [key ] = []
111111 params_map [key ].append (param )
112112
113113 param_groups = []
114- for (wd_mult , lr_mult , is_expert_parallel , is_decoupled_lr ), params in params_map .items ():
114+ for (wd_mult , _lr_mult , is_expert_parallel , is_decoupled_lr ), params in params_map .items ():
115115 assert len (params ) > 0
116116 param_groups .append (
117117 {
118118 'params' : params ,
119119 'wd_mult' : wd_mult ,
120- 'lr_mult' : lr_mult ,
120+ 'lr_mult' : _lr_mult ,
121121 'is_expert_parallel' : is_expert_parallel ,
122122 'is_decoupled_lr' : is_decoupled_lr ,
123123 }
0 commit comments