
Commit b18f892

fix(pu): fix encoder-clip bug and num_channel/res bug

1 parent: bf91ca2

3 files changed: +30 −9 lines changed

lzero/model/unizero_world_models/world_model.py

Lines changed: 8 additions & 0 deletions
@@ -1885,6 +1885,10 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
         discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum()
         discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum()
 
+        # To let the outer training loop access the encoder output, we add it to the returned dict.
+        # .detach() is used because this tensor is only needed for the subsequent clip operation and should not affect gradient computation.
+        detached_obs_embeddings = obs_embeddings.detach()
+
         if self.continuous_action_space:
             return LossWithIntermediateLosses(
                 latent_recon_loss_weight=self.latent_recon_loss_weight,
@@ -1913,8 +1917,10 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
                 policy_mu=mu,
                 policy_sigma=sigma,
                 target_sampled_actions=target_sampled_actions,
+
                 value_priority=value_priority,
                 intermediate_tensor_x=intermediate_tensor_x,
+                obs_embeddings=detached_obs_embeddings,  # <-- newly added
             )
         else:
             return LossWithIntermediateLosses(
@@ -1941,8 +1947,10 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
                 e_rank_last_linear = e_rank_last_linear,
                 e_rank_sim_norm = e_rank_sim_norm,
                 latent_state_l2_norms=latent_state_l2_norms,
+
                 value_priority=value_priority,
                 intermediate_tensor_x=intermediate_tensor_x,
+                obs_embeddings=detached_obs_embeddings,  # <-- newly added
             )
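For context, a minimal sketch (not part of this commit) of how an outer training loop might consume the detached obs_embeddings that compute_loss now returns. The exp_name in the config change below suggests the clip threshold is annealed from 30 to 10 over 100k iterations ("encoder-clip30-10-100k"); the function names and standalone form here are illustrative assumptions, not LightZero's actual implementation.

import torch

def encoder_clip_threshold(train_iter: int, start: float = 30.0, end: float = 10.0,
                           decay_steps: int = 100_000) -> float:
    # Linearly anneal the clip threshold from `start` to `end` over `decay_steps`,
    # matching the "encoder-clip30-10-100k" tag in the exp_name.
    frac = min(train_iter / decay_steps, 1.0)
    return start + (end - start) * frac

def clip_obs_embeddings(obs_embeddings: torch.Tensor, train_iter: int) -> torch.Tensor:
    # `obs_embeddings` is already detached, so clamping it here cannot affect gradients.
    threshold = encoder_clip_threshold(train_iter)
    return torch.clamp(obs_embeddings, -threshold, threshold)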

lzero/model/unizero_world_models/world_model_multitask.py

Lines changed: 9 additions & 0 deletions
@@ -1899,6 +1899,10 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
         discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum()
         discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum()
 
+        # To let the outer training loop access the encoder output, we add it to the returned dict.
+        # .detach() is used because this tensor is only needed for the subsequent clip operation and should not affect gradient computation.
+        detached_obs_embeddings = obs_embeddings.detach()
+
         if self.continuous_action_space:
             return LossWithIntermediateLosses(
                 latent_recon_loss_weight=self.latent_recon_loss_weight,
@@ -1927,7 +1931,9 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
                 policy_mu=mu,
                 policy_sigma=sigma,
                 target_sampled_actions=target_sampled_actions,
+
                 value_priority=value_priority,
+                obs_embeddings=detached_obs_embeddings,  # <-- newly added
 
             )
         else:
@@ -1955,7 +1961,10 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
                 e_rank_last_linear = e_rank_last_linear,
                 e_rank_sim_norm = e_rank_sim_norm,
                 latent_state_l2_norms=latent_state_l2_norms,
+
                 value_priority=value_priority,
+                obs_embeddings=detached_obs_embeddings,  # <-- newly added
+
 
             )

zoo/atari/config/atari_unizero_segment_config.py

Lines changed: 13 additions & 9 deletions
@@ -73,6 +73,10 @@ def main(env_id, seed):
             reward_support_range=(-300., 301., 1.),
             value_support_range=(-300., 301., 1.),
             norm_type=norm_type,
+            # num_res_blocks=1,
+            # num_channels=64,
+            num_res_blocks=2,
+            num_channels=128,
             world_model_cfg=dict(
                 norm_type=norm_type,
                 final_norm_option_in_obs_head='LayerNorm',
@@ -138,10 +142,10 @@ def main(env_id, seed):
                 # adaptive_entropy_alpha_lr=1e-3,
                 target_entropy_start_ratio =0.98,
                 # target_entropy_end_ratio =0.9,
-                target_entropy_end_ratio =0.7,
-                target_entropy_decay_steps = 100000, # e.g., reaches the final value after 100k iterations; needs to be tuned jointly with the replay ratio
-                # target_entropy_end_ratio =0.5, # TODO=====
-                # target_entropy_decay_steps = 400000, # e.g., reaches the final value after 400k iterations; needs to be tuned jointly with the replay ratio
+                # target_entropy_end_ratio =0.7,
+                # target_entropy_decay_steps = 100000, # e.g., reaches the final value after 100k iterations; needs to be tuned jointly with the replay ratio
+                target_entropy_end_ratio =0.5, # TODO=====
+                target_entropy_decay_steps = 400000, # e.g., reaches the final value after 400k iterations; needs to be tuned jointly with the replay ratio
 
 
                 # ==================== START: Encoder-Clip Annealing Config ====================
@@ -217,7 +221,7 @@ def main(env_id, seed):
 
     # ============ use muzero_segment_collector instead of muzero_collector =============
     from lzero.entry import train_unizero_segment
-    main_config.exp_name = f'data_unizero_st_refactor1010/{env_id[:-14]}/{env_id[:-14]}_uz_targetentropy-alpha-100k-098-07-encoder-clip30-10-100k_label-smooth_resnet-encoder_priority_adamw-wd1e-2-encoder1-trans1-head1_ln-inner-ln_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+    main_config.exp_name = f'data_unizero_st_refactor1010/{env_id[:-14]}/{env_id[:-14]}_uz_ch128-res2_targetentropy-alpha-400k-098-05-encoder-clip30-10-100k-true_label-smooth_resnet-encoder_priority_adamw-wd1e-2-encoder1-trans1-head1_ln-inner-ln_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
     train_unizero_segment([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
@@ -240,21 +244,21 @@ def main(env_id, seed):
     # args.env = 'AlienNoFrameskip-v4'
 
     # Below are 2 representative environments outside the atari8 set
-    # args.env = 'QbertNoFrameskip-v4' # memory/planning-type environment, sparse reward
+    args.env = 'QbertNoFrameskip-v4' # memory/planning-type environment, sparse reward
     # args.env = 'SpaceInvadersNoFrameskip-v4' # memory/planning-type environment, sparse reward
 
     # Below are environments that already perform well
     # args.env = 'BoxingNoFrameskip-v4' # reactive environment, dense reward
     # args.env = 'ChopperCommandNoFrameskip-v4'
-    args.env = 'RoadRunnerNoFrameskip-v4'
+    # args.env = 'RoadRunnerNoFrameskip-v4'
 
     main(args.env, args.seed)
 
     """
     tmux new -s uz-st-refactor-boxing
 
     conda activate /mnt/nfs/zhangjinouwen/puyuan/conda_envs/lz
-    export CUDA_VISIBLE_DEVICES=5
+    export CUDA_VISIBLE_DEVICES=1
     cd /mnt/nfs/zhangjinouwen/puyuan/LightZero
-    python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py 2>&1 | tee /mnt/nfs/zhangjinouwen/puyuan/LightZero/log/20251010_fix_uz_st_road.log
+    python /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_segment_config.py 2>&1 | tee /mnt/nfs/zhangjinouwen/puyuan/LightZero/log/20251010_uz_st_ch128-res2_fix-encoder-clip_qbert.log
     """
