@@ -243,8 +243,17 @@ class UniZeroMTPolicy(UniZeroPolicy):
243243 perceptual_loss_weight = 0. ,
244244 # (float) The weight of the policy entropy.
245245 policy_entropy_weight = 1e-4 ,
246- # (str) The type of loss for predicting latent variables. Options could be ['group_kl', 'mse'].
247- predict_latent_loss_type = 'group_kl' ,
246+ # (str) The normalization type for the final layer in both the head and the encoder.
247+ # This option must be the same for both 'final_norm_option_in_head' and 'final_norm_option_in_encoder'.
248+ # Valid options are 'LayerNorm' and 'SimNorm'.
249+ # When set to 'LayerNorm', the 'predict_latent_loss_type' should be 'mse'.
250+ # When set to 'SimNorm', the 'predict_latent_loss_type' should be 'group_kl'.
251+ final_norm_option_in_head = "LayerNorm" ,
252+ final_norm_option_in_encoder = "LayerNorm" ,
253+ # (str) The type of loss function for predicting latent variables.
254+ # Options are 'mse' (Mean Squared Error) or 'group_kl' (Group Kullback-Leibler divergence).
255+ # This choice is dependent on the normalization method selected above.
256+ predict_latent_loss_type = 'mse' ,
248257 # (str) The type of observation. Options are ['image', 'vector'].
249258 obs_type = 'image' ,
250259 # (float) The discount factor for future rewards.
0 commit comments