Skip to content

Commit 05da638

Browse files
committed
fix(pu): fix configure_optimizer_unizero in unizero_mt
1 parent 84e6094 commit 05da638

File tree

4 files changed

+99
-38
lines changed

4 files changed

+99
-38
lines changed

lzero/model/unizero_world_models/world_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,9 @@ def _initialize_patterns(self) -> None:
334334
def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head:
335335
"""Create head modules for the transformer."""
336336
modules = [
337+
nn.LayerNorm(self.config.embed_dim), # <-- 核心优化! # TODO
337338
nn.Linear(self.config.embed_dim, self.config.embed_dim),
339+
nn.LayerNorm(self.config.embed_dim), # 2. <-- 新增!稳定内部激活
338340
nn.GELU(approximate='tanh'),
339341
nn.Linear(self.config.embed_dim, output_dim)
340342
]

lzero/policy/unizero_multitask.py

Lines changed: 76 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -194,50 +194,103 @@ def zero_grad(self, set_to_none: bool = False) -> None:
194194
self.act_embedding_table.zero_grad(set_to_none=set_to_none)
195195

196196

197+
# def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type, betas):
198+
# """
199+
# 为UniZero模型配置带有差异化学习率的优化器。
200+
# """
201+
# # 1. 定义需要特殊处理的参数
202+
# param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
203+
204+
# # 2. 将参数分为三组:Transformer主干、Tokenizer、Heads
205+
# transformer_params = {pn: p for pn, p in param_dict.items() if 'transformer' in pn}
206+
# tokenizer_params = {pn: p for pn, p in param_dict.items() if 'tokenizer' in pn}
207+
208+
# # Heads的参数是那些既不属于transformer也不属于tokenizer的
209+
# head_params = {
210+
# pn: p for pn, p in param_dict.items()
211+
# if 'transformer' not in pn and 'tokenizer' not in pn
212+
# }
213+
214+
# # 3. 为每组设置不同的优化器参数(特别是学习率)
215+
# # 这里我们仍然使用AdamW,但学习率设置更合理
216+
# optim_groups = [
217+
# {
218+
# 'params': list(transformer_params.values()),
219+
# 'lr': learning_rate, # 1e-4
220+
# # 'lr': learning_rate * 0.2, # 为Transformer主干设置一个较小的学习率,例如 1e-5
221+
# 'weight_decay': weight_decay
222+
# # 'weight_decay': weight_decay * 5.0
223+
# },
224+
# {
225+
# 'params': list(tokenizer_params.values()),
226+
# 'lr': learning_rate, # Tokenizer使用基础学习率,例如 1e-4
227+
# # 'lr': learning_rate * 0.1, # 为encoder设置一个较小的学习率,例如 1e-5
228+
# 'weight_decay': weight_decay * 5.0 # <-- 为Encoder设置5倍的权重衰减!这是一个强力正则化
229+
230+
# },
231+
# {
232+
# 'params': list(head_params.values()),
233+
# 'lr': learning_rate, # Heads也使用基础学习率,例如 1e-4
234+
# 'weight_decay': 0.0 # 通常Heads的权重不做衰减
235+
# # 'weight_decay': weight_decay
236+
237+
# }
238+
# ]
239+
240+
# print("--- Optimizer Groups ---")
241+
# print(f"Transformer LR: {learning_rate}")
242+
# print(f"Tokenizer/Heads LR: {learning_rate}")
243+
244+
# optimizer = torch.optim.AdamW(optim_groups, betas=betas)
245+
# return optimizer
246+
197247
def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type, betas):
198248
"""
199249
为UniZero模型配置带有差异化学习率的优化器。
250+
(修正版,确保参数组互斥)
200251
"""
201-
# 1. 定义需要特殊处理的参数
202-
param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
203-
204-
# 2. 将参数分为三组:Transformer主干、Tokenizer、Heads
205-
transformer_params = {pn: p for pn, p in param_dict.items() if 'transformer' in pn}
206-
tokenizer_params = {pn: p for pn, p in param_dict.items() if 'tokenizer' in pn}
207-
208-
# Heads的参数是那些既不属于transformer也不属于tokenizer的
209-
head_params = {
210-
pn: p for pn, p in param_dict.items()
211-
if 'transformer' not in pn and 'tokenizer' not in pn
212-
}
213-
214-
# 3. 为每组设置不同的优化器参数(特别是学习率)
252+
# 1. 创建空的参数列表用于分组
253+
transformer_params = []
254+
tokenizer_params = []
255+
head_params = []
256+
257+
# 2. 遍历所有可训练参数,并使用 if/elif/else 结构确保每个参数只被分配到一个组
258+
for name, param in model.named_parameters():
259+
if not param.requires_grad:
260+
continue
261+
262+
if 'transformer' in name:
263+
transformer_params.append(param)
264+
elif 'tokenizer' in name:
265+
tokenizer_params.append(param)
266+
else:
267+
head_params.append(param)
268+
269+
# 3. 为每组设置不同的优化器参数
215270
# 这里我们仍然使用AdamW,但学习率设置更合理
216271
optim_groups = [
217272
{
218-
'params': list(transformer_params.values()),
273+
'params': transformer_params,
219274
'lr': learning_rate, # 1e-4
220-
# 'lr': learning_rate * 0.2, # 为Transformer主干设置一个较小的学习率,例如 1e-5
221275
'weight_decay': weight_decay
222-
# 'weight_decay': weight_decay * 5.0
223276
},
224277
{
225-
'params': list(tokenizer_params.values()),
278+
'params': tokenizer_params,
226279
'lr': learning_rate, # Tokenizer使用基础学习率,例如 1e-4
227-
# 'lr': learning_rate * 0.1, # 为encoder设置一个较小的学习率,例如 1e-5
228280
'weight_decay': weight_decay * 5.0 # <-- 为Encoder设置5倍的权重衰减!这是一个强力正则化
229-
230281
},
231282
{
232-
'params': list(head_params.values()),
283+
'params': head_params,
233284
'lr': learning_rate, # Heads也使用基础学习率,例如 1e-4
234285
'weight_decay': 0.0 # 通常Heads的权重不做衰减
235-
# 'weight_decay': weight_decay
236-
237286
}
238287
]
239288

240289
print("--- Optimizer Groups ---")
290+
# 打印每个组的参数数量以供调试
291+
print(f"Transformer params: {len(transformer_params)}")
292+
print(f"Tokenizer params: {len(tokenizer_params)}")
293+
print(f"Head params: {len(head_params)}")
241294
print(f"Transformer LR: {learning_rate}")
242295
print(f"Tokenizer/Heads LR: {learning_rate}")
243296

zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -155,20 +155,20 @@ def create_config(
155155
task_num=len(env_id_list),
156156
# game_segment_length=game_segment_length,
157157
game_segment_length=20, # TODO
158-
# use_priority=True,
159-
use_priority=False, # TODO=====
158+
use_priority=True,
159+
# use_priority=False, # TODO=====
160160
priority_prob_alpha=1,
161161
priority_prob_beta=1,
162-
# encoder_type='vit',
163-
encoder_type='resnet',
162+
encoder_type='vit',
163+
# encoder_type='resnet',
164164
use_normal_head=True,
165165
use_softmoe_head=False,
166166
use_moe_head=False,
167167
num_experts_in_moe_head=4,
168168
moe_in_transformer=False,
169169

170-
# multiplication_moe_in_transformer=True,
171-
multiplication_moe_in_transformer=False, # TODO=====
170+
multiplication_moe_in_transformer=True,
171+
# multiplication_moe_in_transformer=False, # TODO=====
172172

173173
n_shared_experts=1,
174174
num_experts_per_tok=1,
@@ -188,8 +188,8 @@ def create_config(
188188
learning_rate=0.0001,
189189

190190
# (bool) 是否启用自适应策略熵权重 (alpha)
191-
# use_adaptive_entropy_weight=True,
192-
use_adaptive_entropy_weight=False,
191+
use_adaptive_entropy_weight=True,
192+
# use_adaptive_entropy_weight=False,
193193

194194
# (float) 自适应alpha优化器的学习率
195195
adaptive_entropy_alpha_lr=1e-4,
@@ -216,8 +216,8 @@ def create_config(
216216
total_batch_size=total_batch_size,
217217
allocated_batch_sizes=False,
218218
train_start_after_envsteps=int(0),
219-
use_priority=False, # TODO=====
220-
# use_priority=True,
219+
# use_priority=False, # TODO=====
220+
use_priority=True,
221221
priority_prob_alpha=1,
222222
priority_prob_beta=1,
223223
print_task_priority_logs=False,
@@ -271,7 +271,10 @@ def generate_configs(
271271
# Replace placeholders like [BENCHMARK_TAG] and [MODEL_TAG] to define the experiment name.
272272
benchmark_tag = "data_unizero_mt_refactor0929" # e.g., unizero_atari_mt_20250612
273273
# model_tag = f"vit-small_moe8_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}_not-share-head"
274-
model_tag = f"resnet_noprior_noalpha_nomoe_head-inner-ln_adamw-wd1e-2_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}"
274+
# model_tag = f"resnet_noprior_noalpha_nomoe_head-inner-ln_adamw-wd1e-2_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}"
275+
276+
model_tag = f"vit_prior_alpha-100k-098-07_encoder-100k-30-10_moe8_head-inner-ln_adamw-wd1e-2_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}"
277+
275278
exp_name_prefix = f'{benchmark_tag}/atari_{len(env_id_list)}games_{model_tag}_seed{seed}/'
276279

277280
for task_id, env_id in enumerate(env_id_list):

zoo/atari/config/atari_unizero_segment_config.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@ def main(env_id, seed):
1414
evaluator_env_num = 3
1515
num_simulations = 50
1616
# max_env_step = int(4e5)
17-
max_env_step = int(10e6) # TODO
17+
max_env_step = int(5e6) # TODO
1818

19-
batch_size = 64
19+
# batch_size = 64
20+
batch_size = 256
2021
num_layers = 2
21-
replay_ratio = 0.25
22+
replay_ratio = 0.1
23+
# replay_ratio = 0.25
2224
num_unroll_steps = 10
2325
infer_context_length = 4
2426

@@ -131,6 +133,7 @@ def main(env_id, seed):
131133
use_adaptive_entropy_weight=True,
132134
# (float) 自适应alpha优化器的学习率
133135
adaptive_entropy_alpha_lr=1e-4,
136+
# adaptive_entropy_alpha_lr=1e-3,
134137
target_entropy_start_ratio =0.98,
135138
# target_entropy_end_ratio =0.9,
136139
target_entropy_end_ratio =0.7,
@@ -200,7 +203,7 @@ def main(env_id, seed):
200203

201204
# ============ use muzero_segment_collector instead of muzero_collector =============
202205
from lzero.entry import train_unizero_segment
203-
main_config.exp_name = f'data_unizero_st_refactor0929/{env_id[:-14]}/{env_id[:-14]}_uz_resnet-encoder_priority_adamw-wd1e-2_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
206+
main_config.exp_name = f'data_unizero_st_refactor0929/{env_id[:-14]}/{env_id[:-14]}_uz_resnet-encoder_priority_adamw-wd1e-2_ln-inner-ln_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
204207
train_unizero_segment([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
205208

206209

0 commit comments

Comments
 (0)