Skip to content

Commit 247b621 — a merge commit with 2 parents (a47d200 and 1ee19d7).

Browse files

File tree

3 files changed: +107 additions, -20 deletions

lzero/mcts/ctree/ctree_alphazero/make.sh

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,55 @@
77
# navigating into it, running cmake to generate build files suitable for the arm64 architecture,
88
# and running make to compile the project.
99

10+
# Function to find the ctree_alphazero directory
11+
find_ctree_alphazero_dir() {
12+
local script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
13+
14+
# Check if we're already in the ctree_alphazero directory
15+
if [[ "$script_dir" == */lzero/mcts/ctree/ctree_alphazero ]]; then
16+
echo "$script_dir"
17+
return 0
18+
fi
19+
20+
# Try to find the directory by searching upwards from script location
21+
local current_dir="$script_dir"
22+
while [[ "$current_dir" != "/" ]]; do
23+
if [[ -d "$current_dir/lzero/mcts/ctree/ctree_alphazero" ]]; then
24+
echo "$current_dir/lzero/mcts/ctree/ctree_alphazero"
25+
return 0
26+
fi
27+
current_dir="$(dirname "$current_dir")"
28+
done
29+
30+
# Try to find from current working directory
31+
if [[ -d "./lzero/mcts/ctree/ctree_alphazero" ]]; then
32+
echo "$(pwd)/lzero/mcts/ctree/ctree_alphazero"
33+
return 0
34+
fi
35+
36+
# Check if CMakeLists.txt exists in current directory (maybe we're already there)
37+
if [[ -f "./CMakeLists.txt" ]] && [[ -f "./alphazero_mcts_cpp.cpp" ]]; then
38+
echo "$(pwd)"
39+
return 0
40+
fi
41+
42+
return 1
43+
}
44+
1045
# Navigate to the project directory.
11-
# ========= NOTE: PLEASE MODIFY THE FOLLOWING DIRECTORY TO YOUR OWN. =========
12-
cd /YOUR_LightZero_DIR/LightZero/lzero/mcts/ctree/ctree_alphazero/ || exit
46+
CTREE_DIR=$(find_ctree_alphazero_dir)
47+
48+
if [[ -z "$CTREE_DIR" ]]; then
49+
echo "Error: Could not find the ctree_alphazero directory."
50+
echo "Please ensure you are running this script from within the LightZero project,"
51+
echo "or manually specify the correct path in the script."
52+
echo ""
53+
echo "Expected directory structure: LightZero/lzero/mcts/ctree/ctree_alphazero/"
54+
exit 1
55+
fi
56+
57+
echo "Found ctree_alphazero directory: $CTREE_DIR"
58+
cd "$CTREE_DIR" || exit
1359

1460
# Create a new directory named "build." The build directory is where the compiled files will be stored.
1561
mkdir -p build

lzero/policy/alphazero.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,15 @@ def _forward_collect(self, obs: Dict, temperature: float = 1) -> Dict[str, torch
264264
init_state=init_state[env_id],
265265
katago_policy_init=False,
266266
katago_game_state=katago_game_state[env_id]))
267-
action, mcts_probs, root = self._collect_mcts.get_next_action(state_config_for_simulation_env_reset, self._policy_value_fn, self.collect_mcts_temperature, True)
268-
267+
# Compatible with both ctree (returns 3 values) and ptree (returns 2 values) implementations
268+
result = self._collect_mcts.get_next_action(state_config_for_simulation_env_reset, self._policy_value_fn, self.collect_mcts_temperature, True)
269+
if len(result) == 3:
270+
# ctree implementation returns: action, mcts_probs, root
271+
action, mcts_probs, root = result
272+
else:
273+
# ptree implementation returns: action, mcts_probs
274+
action, mcts_probs = result
275+
269276
output[env_id] = {
270277
'action': action,
271278
'probs': mcts_probs,
@@ -327,9 +334,16 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
327334
init_state=init_state[env_id],
328335
katago_policy_init=False,
329336
katago_game_state=katago_game_state[env_id]))
330-
action, mcts_probs, root = self._eval_mcts.get_next_action(
337+
result = self._eval_mcts.get_next_action(
331338
state_config_for_simulation_env_reset, self._policy_value_fn, 1.0, False
332339
)
340+
if len(result) == 3:
341+
# ctree implementation returns: action, mcts_probs, root
342+
action, mcts_probs, root = result
343+
else:
344+
# ptree implementation returns: action, mcts_probs
345+
action, mcts_probs = result
346+
333347
output[env_id] = {
334348
'action': action,
335349
'probs': mcts_probs,

lzero/worker/alphazero_evaluator.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@ def eval(
175175
"""
176176
Overview:
177177
Execute the evaluation of the policy and determine if the stopping condition has been met.
178+
In a distributed setting, this method will block all processes except rank 0,
179+
which performs the evaluation. The results are then broadcasted to all other processes.
178180
Arguments:
179181
- save_ckpt_fn (:obj:`Optional[Callable]`): Callback function to save a checkpoint.
180182
- train_iter (:obj:`int`): Current number of training iterations completed.
@@ -183,11 +185,18 @@ def eval(
183185
- force_render (:obj:`bool`): Force rendering of the environment, if applicable.
184186
Returns:
185187
- stop_flag (:obj:`bool`): Whether the training process should stop based on evaluation results.
186-
- return_info (:obj:`dict`): Information about the evaluation results.
188+
- eval_info (:obj:`dict`): Information about the evaluation results.
187189
"""
188-
# the evaluator only works on rank0
189-
stop_flag, return_info = False, []
190+
# ==============================================================
191+
# FIX: Restructure the entire method for correct distributed handling.
192+
# ==============================================================
193+
194+
# Initialize placeholders for results on all ranks.
195+
stop_flag = False
196+
eval_info = {}
197+
190198
if get_rank() == 0:
199+
# --- Rank 0 performs the evaluation ---
191200
if n_episode is None:
192201
n_episode = self._default_n_episode
193202
assert n_episode is not None, "please indicate eval n_episode"
@@ -199,17 +208,19 @@ def eval(
199208
with self._timer:
200209
while not eval_monitor.is_finished():
201210
obs = self._env.ready_obs
202-
211+
203212
# ==============================================================
204213
# policy forward
205214
# ==============================================================
206215
policy_output = self._policy.forward(obs)
207216
actions = {env_id: output['action'] for env_id, output in policy_output.items()}
217+
208218
# ==============================================================
209219
# Interact with env.
210220
# ==============================================================
211221
timesteps = self._env.step(actions)
212222
timesteps = to_tensor(timesteps, dtype=torch.float32)
223+
213224
for env_id, t in timesteps.items():
214225
if t.info.get('abnormal', False):
215226
# If there is an abnormal timestep, reset all the related variables(including this env).
@@ -224,15 +235,17 @@ def eval(
224235
saved_info.update(t.info['episode_info'])
225236
eval_monitor.update_info(env_id, saved_info)
226237
eval_monitor.update_reward(env_id, reward)
227-
return_info.append(t.info)
228238
self._logger.info(
229239
"[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format(
230240
env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode()
231241
)
232242
)
233243
envstep_count += 1
244+
234245
duration = self._timer.value
235246
episode_return = eval_monitor.get_episode_return()
247+
248+
# Prepare the results dictionary
236249
info = {
237250
'train_iter': train_iter,
238251
'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter),
@@ -248,11 +261,13 @@ def eval(
248261
'reward_min': np.min(episode_return),
249262
# 'each_reward': episode_return,
250263
}
251-
episode_info = eval_monitor.get_episode_info()
252-
if episode_info is not None:
253-
info.update(episode_info)
264+
episode_info_from_monitor = eval_monitor.get_episode_info()
265+
if episode_info_from_monitor is not None:
266+
info.update(episode_info_from_monitor)
267+
254268
self._logger.info(self._logger.get_tabulate_vars_hor(info))
255-
# self._logger.info(self._logger.get_tabulate_vars(info))
269+
270+
# Log to TensorBoard
256271
for k, v in info.items():
257272
if k in ['train_iter', 'ckpt_name', 'each_reward']:
258273
continue
@@ -266,6 +281,8 @@ def eval(
266281
if save_ckpt_fn:
267282
save_ckpt_fn('ckpt_best.pth.tar')
268283
self._max_eval_reward = eval_reward
284+
285+
# Set the final results for rank 0
269286
stop_flag = eval_reward >= self._stop_value and train_iter > 0
270287
if stop_flag:
271288
self._logger.info(
@@ -274,11 +291,21 @@ def eval(
274291
", so your AlphaZero agent is converged, you can refer to " +
275292
"'log/evaluator/evaluator_logger.txt' for details."
276293
)
294+
295+
# The final information to be returned and broadcasted
296+
eval_info = to_item(info)
277297

278-
if get_world_size() > 1:
279-
objects = [stop_flag, episode_info]
280-
broadcast_object_list(objects, src=0)
281-
stop_flag, episode_info = objects
298+
# --- Synchronization for all ranks ---
299+
if get_world_size() > 1:
300+
# All processes must participate in the broadcast.
301+
# `src=0` means rank 0 sends, and all other ranks receive.
302+
# The `objects` list on rank 0 contains the data to be sent.
303+
# On other ranks, it contains placeholders that will be overwritten.
304+
objects = [stop_flag, eval_info]
305+
broadcast_object_list(objects, src=0)
306+
# After broadcast, all processes' `objects` list is updated.
307+
stop_flag, eval_info = objects
282308

283-
episode_info = to_item(episode_info)
284-
return stop_flag, episode_info
309+
# All ranks now have the same `stop_flag` and `eval_info`.
310+
# All ranks return a valid tuple.
311+
return stop_flag, eval_info

0 commit comments

Comments (0)