@@ -194,50 +194,103 @@ def zero_grad(self, set_to_none: bool = False) -> None:
194194 self .act_embedding_table .zero_grad (set_to_none = set_to_none )
195195
196196
197+ # def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type, betas):
198+ # """
199+ # 为UniZero模型配置带有差异化学习率的优化器。
200+ # """
201+ # # 1. 定义需要特殊处理的参数
202+ # param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
203+
204+ # # 2. 将参数分为三组:Transformer主干、Tokenizer、Heads
205+ # transformer_params = {pn: p for pn, p in param_dict.items() if 'transformer' in pn}
206+ # tokenizer_params = {pn: p for pn, p in param_dict.items() if 'tokenizer' in pn}
207+
208+ # # Heads的参数是那些既不属于transformer也不属于tokenizer的
209+ # head_params = {
210+ # pn: p for pn, p in param_dict.items()
211+ # if 'transformer' not in pn and 'tokenizer' not in pn
212+ # }
213+
214+ # # 3. 为每组设置不同的优化器参数(特别是学习率)
215+ # # 这里我们仍然使用AdamW,但学习率设置更合理
216+ # optim_groups = [
217+ # {
218+ # 'params': list(transformer_params.values()),
219+ # 'lr': learning_rate, # 1e-4
220+ # # 'lr': learning_rate * 0.2, # 为Transformer主干设置一个较小的学习率,例如 1e-5
221+ # 'weight_decay': weight_decay
222+ # # 'weight_decay': weight_decay * 5.0
223+ # },
224+ # {
225+ # 'params': list(tokenizer_params.values()),
226+ # 'lr': learning_rate, # Tokenizer使用基础学习率,例如 1e-4
227+ # # 'lr': learning_rate * 0.1, # 为encoder设置一个较小的学习率,例如 1e-5
228+ # 'weight_decay': weight_decay * 5.0 # <-- 为Encoder设置5倍的权重衰减!这是一个强力正则化
229+
230+ # },
231+ # {
232+ # 'params': list(head_params.values()),
233+ # 'lr': learning_rate, # Heads也使用基础学习率率,例如 1e-4
234+ # 'weight_decay': 0.0 # 通常Heads的权重不做衰减
235+ # # 'weight_decay': weight_decay
236+
237+ # }
238+ # ]
239+
240+ # print("--- Optimizer Groups ---")
241+ # print(f"Transformer LR: {learning_rate}")
242+ # print(f"Tokenizer/Heads LR: {learning_rate}")
243+
244+ # optimizer = torch.optim.AdamW(optim_groups, betas=betas)
245+ # return optimizer
246+
197247def configure_optimizer_unizero (model , learning_rate , weight_decay , device_type , betas ):
198248 """
199249 为UniZero模型配置带有差异化学习率的优化器。
250+ (修正版,确保参数组互斥)
200251 """
201- # 1. 定义需要特殊处理的参数
202- param_dict = {pn : p for pn , p in model .named_parameters () if p .requires_grad }
203-
204- # 2. 将参数分为三组:Transformer主干、Tokenizer、Heads
205- transformer_params = {pn : p for pn , p in param_dict .items () if 'transformer' in pn }
206- tokenizer_params = {pn : p for pn , p in param_dict .items () if 'tokenizer' in pn }
207-
208- # Heads的参数是那些既不属于transformer也不属于tokenizer的
209- head_params = {
210- pn : p for pn , p in param_dict .items ()
211- if 'transformer' not in pn and 'tokenizer' not in pn
212- }
213-
214- # 3. 为每组设置不同的优化器参数(特别是学习率)
252+ # 1. 创建空的参数列表用于分组
253+ transformer_params = []
254+ tokenizer_params = []
255+ head_params = []
256+
257+ # 2. 遍历所有可训练参数,并使用 if/elif/else 结构确保每个参数只被分配到一个组
258+ for name , param in model .named_parameters ():
259+ if not param .requires_grad :
260+ continue
261+
262+ if 'transformer' in name :
263+ transformer_params .append (param )
264+ elif 'tokenizer' in name :
265+ tokenizer_params .append (param )
266+ else :
267+ head_params .append (param )
268+
269+ # 3. 为每组设置不同的优化器参数
215270 # 这里我们仍然使用AdamW,但学习率设置更合理
216271 optim_groups = [
217272 {
218- 'params' : list ( transformer_params . values ()) ,
273+ 'params' : transformer_params ,
219274 'lr' : learning_rate , # 1e-4
220- # 'lr': learning_rate * 0.2, # 为Transformer主干设置一个较小的学习率,例如 1e-5
221275 'weight_decay' : weight_decay
222- # 'weight_decay': weight_decay * 5.0
223276 },
224277 {
225- 'params' : list ( tokenizer_params . values ()) ,
278+ 'params' : tokenizer_params ,
226279 'lr' : learning_rate , # Tokenizer使用基础学习率,例如 1e-4
227- # 'lr': learning_rate * 0.1, # 为encoder设置一个较小的学习率,例如 1e-5
228280 'weight_decay' : weight_decay * 5.0 # <-- 为Encoder设置5倍的权重衰减!这是一个强力正则化
229-
230281 },
231282 {
232- 'params' : list ( head_params . values ()) ,
283+ 'params' : head_params ,
233284 'lr' : learning_rate , # Heads也使用基础学习率率,例如 1e-4
234285 'weight_decay' : 0.0 # 通常Heads的权重不做衰减
235- # 'weight_decay': weight_decay
236-
237286 }
238287 ]
239288
240289 print ("--- Optimizer Groups ---" )
290+ # 打印每个组的参数数量以供调试
291+ print (f"Transformer params: { len (transformer_params )} " )
292+ print (f"Tokenizer params: { len (tokenizer_params )} " )
293+ print (f"Head params: { len (head_params )} " )
241294 print (f"Transformer LR: { learning_rate } " )
242295 print (f"Tokenizer/Heads LR: { learning_rate } " )
243296
0 commit comments