Skip to content

Commit ad2226a

Browse files
committed
fix(pu): fix final_norm_option and predict_latent_loss_type default config bug
1 parent 7841fdf commit ad2226a

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

lzero/policy/unizero_multitask.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,17 @@ class UniZeroMTPolicy(UniZeroPolicy):
243243
perceptual_loss_weight=0.,
244244
# (float) The weight of the policy entropy.
245245
policy_entropy_weight=1e-4,
246-
# (str) The type of loss for predicting latent variables. Options could be ['group_kl', 'mse'].
247-
predict_latent_loss_type='group_kl',
246+
# (str) The normalization type for the final layer in both the head and the encoder.
247+
# 'final_norm_option_in_head' and 'final_norm_option_in_encoder' must be set to the same value.
248+
# Valid options are 'LayerNorm' and 'SimNorm'.
249+
# When set to 'LayerNorm', the 'predict_latent_loss_type' should be 'mse'.
250+
# When set to 'SimNorm', the 'predict_latent_loss_type' should be 'group_kl'.
251+
final_norm_option_in_head="LayerNorm",
252+
final_norm_option_in_encoder="LayerNorm",
253+
# (str) The type of loss function for predicting latent variables.
254+
# Options are 'mse' (Mean Squared Error) or 'group_kl' (Group Kullback-Leibler divergence).
255+
# This choice depends on the final normalization option: use 'mse' with 'LayerNorm' and 'group_kl' with 'SimNorm'.
256+
predict_latent_loss_type='mse',
248257
# (str) The type of observation. Options are ['image', 'vector'].
249258
obs_type='image',
250259
# (float) The discount factor for future rewards.

zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ def create_config(
129129
max_blocks=num_unroll_steps,
130130
max_tokens=2 * num_unroll_steps,
131131
context_length=2 * infer_context_length,
132+
final_norm_option_in_obs_head='LayerNorm',
133+
final_norm_option_in_encoder='LayerNorm',
134+
predict_latent_loss_type='mse',
132135
encoder_type='vit',
133136
device='cuda',
134137
game_segment_length=20,

0 commit comments

Comments
 (0)