Skip to content

Commit 20933c1

Browse files
fix(pu): fix compatibility in sampled unizero policy (#430)
Co-authored-by: jasper <1157507000@qq.com>
1 parent 7d7c4da commit 20933c1

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

lzero/policy/sampled_unizero.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,6 @@ def _init_learn(self) -> None:
360360
)
361361
self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device)
362362
self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device)
363-
assert self.value_support.size == self._learn_model.value_support_size # if these assertions fail, somebody introduced...
364-
assert self.reward_support.size == self._learn_model.reward_support_size # ...incoherence between policy and model
365363
self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution)
366364
self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution)
367365

@@ -370,6 +368,8 @@ def _init_learn(self) -> None:
370368
self.l2_norm_after = 0.
371369
self.grad_norm_before = 0.
372370
self.grad_norm_after = 0.
371+
self.pad_token_id = 0 # for compatibility
372+
373373

374374
# @profile
375375
def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
@@ -850,11 +850,10 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [
850850
network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep)
851851
latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
852852

853-
if not self._eval_model.training:
854-
# if not in training, obtain the scalars of the value/reward
855-
pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1)
856-
latent_state_roots = latent_state_roots.detach().cpu().numpy()
857-
policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A)
853+
# if not in training, obtain the scalars of the value/reward
854+
pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1)
855+
latent_state_roots = latent_state_roots.detach().cpu().numpy()
856+
policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A)
858857

859858
if self._cfg.model.continuous_action_space is True:
860859
# when the action space of the environment is continuous, action_mask[:] is None.
@@ -885,8 +884,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [
885884
# ==============================================================
886885
# sampled related core code
887886
# ==============================================================
888-
roots_sampled_actions = roots.get_sampled_actions(
889-
) # shape: ``{list: batch_size} ->{list: action_space_size}``
887+
roots_sampled_actions = roots.get_sampled_actions() # shape: ``{list: batch_size} ->{list: action_space_size}``
890888
batch_action = []
891889

892890
for i, env_id in enumerate(ready_env_id):

0 commit comments

Comments (0)