Skip to content

Commit b1efa60

Browse files
committed
tmp
1 parent bf3cd12 commit b1efa60

File tree

2 files changed

+21
-19
lines changed

2 files changed

+21
-19
lines changed

lzero/policy/unizero.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,25 +57,25 @@ def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type,
5757
# 3. 为每组设置不同的优化器参数(特别是学习率)
5858
# 这里我们仍然使用AdamW,但学习率设置更合理
5959
optim_groups = [
60+
{
61+
'params': list(tokenizer_params.values()),
62+
'lr': learning_rate, # Tokenizer使用基础学习率,例如 1e-4
63+
# 'lr': learning_rate * 0.1, # 为encoder设置一个较小的学习率,例如 1e-5
64+
'weight_decay': weight_decay * 5.0 # <-- 为Encoder设置5倍的权重衰减!这是一个强力正则化
65+
# 'weight_decay': weight_decay # <-- 为Encoder设置5倍的权重衰减!这是一个强力正则化
66+
},
6067
{
6168
'params': list(transformer_params.values()),
6269
'lr': learning_rate, # 1e-4
6370
# 'lr': learning_rate * 0.2, # 为Transformer主干设置一个较小的学习率,例如 1e-5
6471
'weight_decay': weight_decay
6572
# 'weight_decay': weight_decay * 5.0
6673
},
67-
{
68-
'params': list(tokenizer_params.values()),
69-
'lr': learning_rate, # Tokenizer使用基础学习率,例如 1e-4
70-
# 'lr': learning_rate * 0.1, # 为encoder设置一个较小的学习率,例如 1e-5
71-
# 'weight_decay': weight_decay * 5.0 # <-- 为Encoder设置5倍的权重衰减!这是一个强力正则化
72-
'weight_decay': weight_decay # <-- 为Encoder设置5倍的权重衰减!这是一个强力正则化
73-
},
7474
{
7575
'params': list(head_params.values()),
7676
'lr': learning_rate, # Heads也使用基础学习率,例如 1e-4
77-
# 'weight_decay': 0.0 # 通常Heads的权重不做衰减
78-
'weight_decay': weight_decay
77+
'weight_decay': 0.0 # 通常Heads的权重不做衰减
78+
# 'weight_decay': weight_decay
7979

8080
}
8181
]

zoo/atari/config/atari_unizero_segment_config.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ def main(env_id, seed):
1111
collector_env_num = 8
1212
num_segments = 8
1313

14-
# game_segment_length = 20
15-
game_segment_length = 400 # TODO
14+
game_segment_length = 20
15+
# game_segment_length = 400 # TODO
1616

1717
evaluator_env_num = 3
1818
num_simulations = 50
@@ -76,10 +76,10 @@ def main(env_id, seed):
7676
reward_support_range=(-300., 301., 1.),
7777
value_support_range=(-300., 301., 1.),
7878
norm_type=norm_type,
79-
# num_res_blocks=1,
80-
# num_channels=64,
81-
num_res_blocks=2,
82-
num_channels=128,
79+
num_res_blocks=1,
80+
num_channels=64,
81+
# num_res_blocks=2,
82+
# num_channels=128,
8383
world_model_cfg=dict(
8484
norm_type=norm_type,
8585
final_norm_option_in_obs_head='LayerNorm',
@@ -161,8 +161,8 @@ def main(env_id, seed):
161161
# (float) 退火的结束 clip 值 (训练后期,较严格)。
162162
encoder_clip_end_value=10.0,
163163
# (int) 完成从起始值到结束值的退火所需的训练迭代步数。
164-
encoder_clip_anneal_steps=400000, # 例如,在400k次迭代后达到最终值
165-
# encoder_clip_anneal_steps=100000, # 例如,在100k次迭代后达到最终值
164+
# encoder_clip_anneal_steps=400000, # 例如,在400k次迭代后达到最终值
165+
encoder_clip_anneal_steps=100000, # 例如,在100k次迭代后达到最终值
166166

167167
# ==================== START: label smooth ====================
168168
policy_ls_eps_start=0.05, #TODO============= good start in Pong and MsPacman
@@ -225,7 +225,9 @@ def main(env_id, seed):
225225

226226
# ============ use muzero_segment_collector instead of muzero_collector =============
227227
from lzero.entry import train_unizero_segment
228-
main_config.exp_name = f'data_unizero_st_refactor1010/{env_id[:-14]}/{env_id[:-14]}_uz_ch128-res2_targetentropy-alpha-100k-098-07-encoder-clip30-10-400k_label-smooth_resnet-encoder_priority_adamw-wd1e-2-encoder1-trans1-head1_ln-inner-ln_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
228+
main_config.exp_name = f'data_unizero_st_refactor1010/{env_id[:-14]}/{env_id[:-14]}_uz_ch64-res1_targetentropy-alpha-100k-098-07-encoder-clip30-10-100k_label-smooth_resnet-encoder_priority_adamw-wd1e-2-encoder5-trans1-head0-true_ln-inner-ln_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
229+
230+
# main_config.exp_name = f'data_unizero_st_refactor1010/{env_id[:-14]}/{env_id[:-14]}_uz_ch128-res2_targetentropy-alpha-100k-098-07-encoder-clip30-10-400k_label-smooth_resnet-encoder_priority_adamw-wd1e-2-encoder1-trans1-head1_ln-inner-ln_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
229231
train_unizero_segment([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
230232

231233

@@ -262,7 +264,7 @@ def main(env_id, seed):
262264
tmux new -s uz-st-refactor-boxing
263265
264266
conda activate /mnt/nfs/zhangjinouwen/puyuan/conda_envs/lz
265-
export CUDA_VISIBLE_DEVICES=4
267+
export CUDA_VISIBLE_DEVICES=6
266268
cd /mnt/nfs/zhangjinouwen/puyuan/LightZero
267269
python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py 2>&1 | tee /mnt/nfs/zhangjinouwen/puyuan/LightZero/log/20251010_uz_st_ch128-res2_fix-encoder-clip_qbert.log
268270
"""

0 commit comments

Comments
 (0)