@@ -251,17 +251,23 @@ def _forward_collect(self, obs: Dict, temperature: float = 1) -> Dict[str, torch
251251 """
252252 self .collect_mcts_temperature = temperature
253253 ready_env_id = list (obs .keys ())
254- init_state = {env_id : obs [env_id ]['board' ] for env_id in ready_env_id }
254+ if self ._cfg .simulation_env_id == 'chess' : # obs[env_id]['board'] is FEN str
255+ init_state = {env_id : obs [env_id ]['board' ].encode () for env_id in ready_env_id } # str → bytes
256+ else :
257+ init_state = {env_id : obs [env_id ]['board' ] for env_id in ready_env_id }
258+
255259 # If 'katago_game_state' is in the observation of the given environment ID, its value is used.
256260 # If it's not present, `dict.get` returns None instead (no KeyError is raised).
257261 # This approach is taken to maintain compatibility with the handling of 'katago' related parts of 'alphazero_mcts_ctree' in Go.
258262 katago_game_state = {env_id : obs [env_id ].get ('katago_game_state' , None ) for env_id in ready_env_id }
259263 start_player_index = {env_id : obs [env_id ]['current_player_index' ] for env_id in ready_env_id }
260264 output = {}
261265 self ._policy_model = self ._collect_model
266+
262267 for env_id in ready_env_id :
263268 state_config_for_simulation_env_reset = EasyDict (dict (start_player_index = start_player_index [env_id ],
264- init_state = init_state [env_id ],
269+ # init_state=init_state[env_id], # orig
270+ init_state = np .frombuffer (init_state [env_id ], dtype = np .int8 ) if self ._cfg .simulation_env_id == 'chess' else init_state [env_id ],
265271 katago_policy_init = False ,
266272 katago_game_state = katago_game_state [env_id ]))
267273 action , mcts_probs , root = self ._collect_mcts .get_next_action (state_config_for_simulation_env_reset , self ._policy_value_fn , self .collect_mcts_temperature , True )
@@ -314,7 +320,11 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
314320 the corresponding policy output in this timestep, including action, probs and so on.
315321 """
316322 ready_env_id = list (obs .keys ())
317- init_state = {env_id : obs [env_id ]['board' ] for env_id in ready_env_id }
323+ if self ._cfg .simulation_env_id == 'chess' : # obs[env_id]['board'] is FEN str
324+ init_state = {env_id : obs [env_id ]['board' ].encode () for env_id in ready_env_id } # str → bytes
325+ else :
326+ init_state = {env_id : obs [env_id ]['board' ] for env_id in ready_env_id }
327+
318328 # If 'katago_game_state' is in the observation of the given environment ID, its value is used.
319329 # If it's not present (which will raise a KeyError), None is used instead.
320330 # This approach is taken to maintain compatibility with the handling of 'katago' related parts of 'alphazero_mcts_ctree' in Go.
@@ -324,7 +334,7 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
324334 self ._policy_model = self ._eval_model
325335 for env_id in ready_env_id :
326336 state_config_for_simulation_env_reset = EasyDict (dict (start_player_index = start_player_index [env_id ],
327- init_state = init_state [env_id ],
337+ init_state = np . frombuffer ( init_state [ env_id ], dtype = np . int8 ) if self . _cfg . simulation_env_id == 'chess' else init_state [env_id ],
328338 katago_policy_init = False ,
329339 katago_game_state = katago_game_state [env_id ]))
330340 action , mcts_probs , root = self ._eval_mcts .get_next_action (
0 commit comments