@@ -360,8 +360,6 @@ def _init_learn(self) -> None:
360360 )
361361 self .value_support = DiscreteSupport (* self ._cfg .model .value_support_range , self ._cfg .device )
362362 self .reward_support = DiscreteSupport (* self ._cfg .model .reward_support_range , self ._cfg .device )
363- assert self .value_support .size == self ._learn_model .value_support_size # if these assertions fails, somebody introduced...
364- assert self .reward_support .size == self ._learn_model .reward_support_size # ...incoherence between policy and model
365363 self .value_inverse_scalar_transform_handle = InverseScalarTransform (self .value_support , self ._cfg .model .categorical_distribution )
366364 self .reward_inverse_scalar_transform_handle = InverseScalarTransform (self .reward_support , self ._cfg .model .categorical_distribution )
367365
@@ -370,6 +368,8 @@ def _init_learn(self) -> None:
370368 self .l2_norm_after = 0.
371369 self .grad_norm_before = 0.
372370 self .grad_norm_after = 0.
371+ self .pad_token_id = 0 # for compatibility
372+
373373
374374 # @profile
375375 def _forward_learn (self , data : Tuple [torch .Tensor ]) -> Dict [str , Union [float , int ]]:
@@ -850,11 +850,10 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [
850850 network_output = self ._eval_model .initial_inference (self .last_batch_obs , self .last_batch_action , data , timestep )
851851 latent_state_roots , reward_roots , pred_values , policy_logits = mz_network_output_unpack (network_output )
852852
853- if not self ._eval_model .training :
854- # if not in training, obtain the scalars of the value/reward
855- pred_values = self .value_inverse_scalar_transform_handle (pred_values ).detach ().cpu ().numpy () # shape(B, 1)
856- latent_state_roots = latent_state_roots .detach ().cpu ().numpy ()
857- policy_logits = policy_logits .detach ().cpu ().numpy ().tolist () # list shape(B, A)
853+ # obtain the scalar of the predicted value and move the outputs to CPU (now done unconditionally, without checking the eval model's training flag)
854+ pred_values = self .value_inverse_scalar_transform_handle (pred_values ).detach ().cpu ().numpy () # shape(B, 1)
855+ latent_state_roots = latent_state_roots .detach ().cpu ().numpy ()
856+ policy_logits = policy_logits .detach ().cpu ().numpy ().tolist () # list shape(B, A)
858857
859858 if self ._cfg .model .continuous_action_space is True :
860859 # when the action space of the environment is continuous, action_mask[:] is None.
@@ -885,8 +884,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [
885884 # ==============================================================
886885 # sampled related core code
887886 # ==============================================================
888- roots_sampled_actions = roots .get_sampled_actions (
889- ) # shape: ``{list: batch_size} ->{list: action_space_size}``
887+ roots_sampled_actions = roots .get_sampled_actions () # shape: ``{list: batch_size} ->{list: action_space_size}``
890888 batch_action = []
891889
892890 for i , env_id in enumerate (ready_env_id ):
0 commit comments