Skip to content

Commit e17493e

Browse files
authored
fix(pu): fix incompatibility between final_norm_option_in_encoder and predict_latent_loss_type in sampled unizero
1 parent 20933c1 commit e17493e

File tree

3 files changed

+23
-5
lines changed

3 files changed

+23
-5
lines changed

lzero/model/sampled_unizero_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __init__(
7777
activation=self.activation,
7878
norm_type=norm_type,
7979
group_size=world_model_cfg.group_size,
80+
final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder
8081
)
8182
# TODO: only for MemoryEnv now
8283
self.decoder_network = VectorDecoderForMemoryEnv(embedding_dim=world_model_cfg.embed_dim, output_shape=25, norm_type=norm_type)
@@ -98,6 +99,7 @@ def __init__(
9899
norm_type=norm_type,
99100
embedding_dim=world_model_cfg.embed_dim,
100101
group_size=world_model_cfg.group_size,
102+
final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder
101103
)
102104
# TODO: we should change the output_shape to the real observation shape
103105
self.decoder_network = LatentDecoder(embedding_dim=world_model_cfg.embed_dim, output_shape=(3, 64, 64))
@@ -127,6 +129,7 @@ def __init__(
127129
strides=[1, 1, 1],
128130
activation=self.activation,
129131
group_size=world_model_cfg.group_size,
132+
final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder
130133
)
131134
self.decoder_network = LatentDecoderForMemoryEnv(
132135
image_shape=(3, 5, 5),

lzero/policy/sampled_unizero.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,17 @@ class SampledUniZeroPolicy(UniZeroPolicy):
121121
perceptual_loss_weight=0.,
122122
# (float) The weight of the policy entropy loss.
123123
policy_entropy_weight=5e-3,
124-
# (str) The type of loss for predicting latent variables. Options could be ['group_kl', 'mse'].
125-
predict_latent_loss_type='group_kl',
124+
# (str) The normalization type for the final layer in both the head and the encoder.
125+
# This option must be the same for both 'final_norm_option_in_head' and 'final_norm_option_in_encoder'.
126+
# Valid options are 'LayerNorm' and 'SimNorm'.
127+
# When set to 'LayerNorm', the 'predict_latent_loss_type' should be 'mse'.
128+
# When set to 'SimNorm', the 'predict_latent_loss_type' should be 'group_kl'.
129+
final_norm_option_in_head="LayerNorm",
130+
final_norm_option_in_encoder="LayerNorm",
131+
# (str) The type of loss function for predicting latent variables.
132+
# Options are 'mse' (Mean Squared Error) or 'group_kl' (Group Kullback-Leibler divergence).
133+
# This choice is dependent on the normalization method selected above.
134+
predict_latent_loss_type='mse',
126135
# (str) The type of observation. Options are ['image', 'vector'].
127136
obs_type='image',
128137
# (float) The discount factor for future rewards.

zoo/classic_control/pendulum/config/pendulum_cont_sampled_unizero_config.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
world_model_cfg=dict(
5050
obs_type='vector',
5151
num_unroll_steps=num_unroll_steps,
52-
policy_entropy_weight=1e-4,
52+
policy_entropy_weight=5e-2,
5353
continuous_action_space=continuous_action_space,
5454
num_of_sampled_actions=K,
5555
sigma_type='conditioned',
@@ -80,11 +80,17 @@
8080
batch_size=batch_size,
8181
optim_type='AdamW',
8282
piecewise_decay_lr_scheduler=False,
83-
learning_rate=0.0001,
83+
discount_factor=0.99,
84+
td_steps=5,
85+
learning_rate=1e-4,
86+
grad_clip_value=5,
87+
manual_temperature_decay=True,
88+
threshold_training_steps_for_final_temperature=int(2.5e4),
89+
cos_lr_scheduler=True,
8490
num_simulations=num_simulations,
8591
reanalyze_ratio=reanalyze_ratio,
8692
n_episode=n_episode,
87-
eval_freq=int(1e3),
93+
eval_freq=int(2e3),
8894
replay_buffer_size=int(1e6),
8995
collector_env_num=collector_env_num,
9096
evaluator_env_num=evaluator_env_num,

0 commit comments

Comments (0)