Skip to content

Commit 44c0287

Browse files
committed
fix(pu): fix chess reset bug when using alphazero ctree
1 parent 2e98102 commit 44c0287

File tree

7 files changed

+38
-19
lines changed

7 files changed

+38
-19
lines changed

lzero/entry/train_alphazero.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def train_alphazero(
106106
)
107107

108108
# Evaluate policy performance
109-
if evaluator.should_eval(learner.train_iter) and learner.train_iter > 0:
109+
if evaluator.should_eval(learner.train_iter) or learner.train_iter == 0:
110110
stop, reward = evaluator.eval(
111111
learner.save_checkpoint,
112112
learner.train_iter,

lzero/mcts/ctree/ctree_alphazero/mcts_alphazero.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ class MCTS {
166166
if (!init_state.is_none()) {
167167
init_state = py::bytes(init_state.attr("tobytes")());
168168
}
169+
169170
py::object katago_game_state = state_config_for_env_reset["katago_game_state"];
170171
if (!katago_game_state.is_none()) {
171172
katago_game_state = py::module::import("pickle").attr("dumps")(katago_game_state);
Submodule pybind11 updated 286 files

lzero/mcts/ptree/ptree_az.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def get_next_action(
261261
action = actions[np.argmax(action_probs)]
262262

263263
# Return the selected action and the output probability of each action.
264-
return action, action_probs
264+
return action, action_probs, None
265265

266266
def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_forward_fn: Callable) -> None:
267267
"""

lzero/policy/alphazero.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -251,17 +251,23 @@ def _forward_collect(self, obs: Dict, temperature: float = 1) -> Dict[str, torch
251251
"""
252252
self.collect_mcts_temperature = temperature
253253
ready_env_id = list(obs.keys())
254-
init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
254+
if self._cfg.simulation_env_id == 'chess': # obs[env_id]['board'] is FEN str
255+
init_state = {env_id: obs[env_id]['board'].encode() for env_id in ready_env_id} # str → bytes
256+
else:
257+
init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
258+
255259
# If 'katago_game_state' is in the observation of the given environment ID, it's value is used.
256260
# If it's not present (which will raise a KeyError), None is used instead.
257261
# This approach is taken to maintain compatibility with the handling of 'katago' related parts of 'alphazero_mcts_ctree' in Go.
258262
katago_game_state = {env_id: obs[env_id].get('katago_game_state', None) for env_id in ready_env_id}
259263
start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id}
260264
output = {}
261265
self._policy_model = self._collect_model
266+
262267
for env_id in ready_env_id:
263268
state_config_for_simulation_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id],
264-
init_state=init_state[env_id],
269+
# init_state=init_state[env_id], # orig
270+
init_state=np.frombuffer(init_state[env_id], dtype=np.int8) if self._cfg.simulation_env_id == 'chess' else init_state[env_id],
265271
katago_policy_init=False,
266272
katago_game_state=katago_game_state[env_id]))
267273
action, mcts_probs, root = self._collect_mcts.get_next_action(state_config_for_simulation_env_reset, self._policy_value_fn, self.collect_mcts_temperature, True)
@@ -314,7 +320,11 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
314320
the corresponding policy output in this timestep, including action, probs and so on.
315321
"""
316322
ready_env_id = list(obs.keys())
317-
init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
323+
if self._cfg.simulation_env_id == 'chess': # obs[env_id]['board'] is FEN str
324+
init_state = {env_id: obs[env_id]['board'].encode() for env_id in ready_env_id} # str → bytes
325+
else:
326+
init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
327+
318328
# If 'katago_game_state' is in the observation of the given environment ID, it's value is used.
319329
# If it's not present (which will raise a KeyError), None is used instead.
320330
# This approach is taken to maintain compatibility with the handling of 'katago' related parts of 'alphazero_mcts_ctree' in Go.
@@ -324,7 +334,7 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
324334
self._policy_model = self._eval_model
325335
for env_id in ready_env_id:
326336
state_config_for_simulation_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id],
327-
init_state=init_state[env_id],
337+
init_state=np.frombuffer(init_state[env_id], dtype=np.int8) if self._cfg.simulation_env_id == 'chess' else init_state[env_id],
328338
katago_policy_init=False,
329339
katago_game_state=katago_game_state[env_id]))
330340
action, mcts_probs, root = self._eval_mcts.get_next_action(

zoo/board_games/chess/config/chess_alphazero_sp_mode_config.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,17 @@
1111
batch_size = 512
1212
max_env_step = int(1e6)
1313
mcts_ctree = True
14+
# mcts_ctree = False
15+
1416

1517
# TODO: for debug
16-
# collector_env_num = 2
17-
# n_episode = 2
18-
# evaluator_env_num = 2
19-
# num_simulations = 4
20-
# update_per_collect = 2
21-
# batch_size = 2
22-
# max_env_step = int(1e4)
18+
collector_env_num = 2
19+
n_episode = 2
20+
evaluator_env_num = 2
21+
num_simulations = 4
22+
update_per_collect = 2
23+
batch_size = 2
24+
max_env_step = int(1e4)
2325
# mcts_ctree = False
2426
# ==============================================================
2527
# end of the most frequently changed config specified by the user

zoo/board_games/chess/envs/chess_lightzero_env.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def observe(self, agent_index):
5757
observation = chess_utils.get_observation(self.board, agent_index).astype(float) # TODO
5858
except Exception as e:
5959
print('debug')
60+
print(f"self.board:{self.board}")
61+
6062

6163
# TODO:
6264
# observation = np.dstack((observation[:, :, :7], self.board_history))
@@ -109,10 +111,6 @@ def get_done_winner(self):
109111
return done, winner
110112

111113
def reset(self, start_player_index=0, init_state=None, katago_policy_init=False, katago_game_state=None):
112-
if self.alphazero_mcts_ctree and init_state is not None:
113-
# Convert byte string to np.ndarray
114-
init_state = np.frombuffer(init_state, dtype=np.int32)
115-
116114
if self.scale:
117115
self._observation_space = spaces.Dict(
118116
{
@@ -131,8 +129,16 @@ def reset(self, start_player_index=0, init_state=None, katago_policy_init=False,
131129
self._reward_space = spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
132130
self.start_player_index = start_player_index
133131
self._current_player = self.players[self.start_player_index]
132+
134133
if init_state is not None:
135-
self.board = chess.Board(init_state)
134+
if isinstance(init_state, np.ndarray):
135+
# ndarray → bytes → str
136+
fen = init_state.tobytes().decode()
137+
elif isinstance(init_state, (bytes, bytearray)):
138+
fen = init_state.decode()
139+
else: # init_state is str
140+
fen = init_state
141+
self.board = chess.Board(fen)
136142
else:
137143
self.board = chess.Board()
138144

0 commit comments

Comments (0)