@@ -175,6 +175,8 @@ def eval(
175175 """
176176 Overview:
177177 Execute the evaluation of the policy and determine if the stopping condition has been met.
178+ In a distributed setting, this method will block all processes except rank 0,
179+            which performs the evaluation. The results are then broadcast to all other processes.
178180 Arguments:
179181 - save_ckpt_fn (:obj:`Optional[Callable]`): Callback function to save a checkpoint.
180182 - train_iter (:obj:`int`): Current number of training iterations completed.
@@ -183,11 +185,18 @@ def eval(
183185 - force_render (:obj:`bool`): Force rendering of the environment, if applicable.
184186 Returns:
185187 - stop_flag (:obj:`bool`): Whether the training process should stop based on evaluation results.
186- - return_info (:obj:`dict`): Information about the evaluation results.
188+ - eval_info (:obj:`dict`): Information about the evaluation results.
187189 """
188- # the evaluator only works on rank0
189- stop_flag , return_info = False , []
190+ # ==============================================================
191+ # FIX: Restructure the entire method for correct distributed handling.
192+ # ==============================================================
193+
194+ # Initialize placeholders for results on all ranks.
195+ stop_flag = False
196+ eval_info = {}
197+
190198 if get_rank () == 0 :
199+ # --- Rank 0 performs the evaluation ---
191200 if n_episode is None :
192201 n_episode = self ._default_n_episode
193202 assert n_episode is not None , "please indicate eval n_episode"
@@ -199,17 +208,19 @@ def eval(
199208 with self ._timer :
200209 while not eval_monitor .is_finished ():
201210 obs = self ._env .ready_obs
202-
211+
203212 # ==============================================================
204213 # policy forward
205214 # ==============================================================
206215 policy_output = self ._policy .forward (obs )
207216 actions = {env_id : output ['action' ] for env_id , output in policy_output .items ()}
217+
208218 # ==============================================================
209219 # Interact with env.
210220 # ==============================================================
211221 timesteps = self ._env .step (actions )
212222 timesteps = to_tensor (timesteps , dtype = torch .float32 )
223+
213224 for env_id , t in timesteps .items ():
214225 if t .info .get ('abnormal' , False ):
215226 # If there is an abnormal timestep, reset all the related variables(including this env).
@@ -224,15 +235,17 @@ def eval(
224235 saved_info .update (t .info ['episode_info' ])
225236 eval_monitor .update_info (env_id , saved_info )
226237 eval_monitor .update_reward (env_id , reward )
227- return_info .append (t .info )
228238 self ._logger .info (
229239 "[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}" .format (
230240 env_id , eval_monitor .get_latest_reward (env_id ), eval_monitor .get_current_episode ()
231241 )
232242 )
233243 envstep_count += 1
244+
234245 duration = self ._timer .value
235246 episode_return = eval_monitor .get_episode_return ()
247+
248+ # Prepare the results dictionary
236249 info = {
237250 'train_iter' : train_iter ,
238251 'ckpt_name' : 'iteration_{}.pth.tar' .format (train_iter ),
@@ -248,11 +261,13 @@ def eval(
248261 'reward_min' : np .min (episode_return ),
249262 # 'each_reward': episode_return,
250263 }
251- episode_info = eval_monitor .get_episode_info ()
252- if episode_info is not None :
253- info .update (episode_info )
264+ episode_info_from_monitor = eval_monitor .get_episode_info ()
265+ if episode_info_from_monitor is not None :
266+ info .update (episode_info_from_monitor )
267+
254268 self ._logger .info (self ._logger .get_tabulate_vars_hor (info ))
255- # self._logger.info(self._logger.get_tabulate_vars(info))
269+
270+ # Log to TensorBoard
256271 for k , v in info .items ():
257272 if k in ['train_iter' , 'ckpt_name' , 'each_reward' ]:
258273 continue
@@ -266,6 +281,8 @@ def eval(
266281 if save_ckpt_fn :
267282 save_ckpt_fn ('ckpt_best.pth.tar' )
268283 self ._max_eval_reward = eval_reward
284+
285+ # Set the final results for rank 0
269286 stop_flag = eval_reward >= self ._stop_value and train_iter > 0
270287 if stop_flag :
271288 self ._logger .info (
@@ -274,11 +291,21 @@ def eval(
274291 ", so your AlphaZero agent is converged, you can refer to " +
275292 "'log/evaluator/evaluator_logger.txt' for details."
276293 )
294+
295+ # The final information to be returned and broadcasted
296+ eval_info = to_item (info )
277297
278- if get_world_size () > 1 :
279- objects = [stop_flag , episode_info ]
280- broadcast_object_list (objects , src = 0 )
281- stop_flag , episode_info = objects
298+ # --- Synchronization for all ranks ---
299+ if get_world_size () > 1 :
300+ # All processes must participate in the broadcast.
301+ # `src=0` means rank 0 sends, and all other ranks receive.
302+ # The `objects` list on rank 0 contains the data to be sent.
303+ # On other ranks, it contains placeholders that will be overwritten.
304+ objects = [stop_flag , eval_info ]
305+ broadcast_object_list (objects , src = 0 )
306+ # After broadcast, all processes' `objects` list is updated.
307+ stop_flag , eval_info = objects
282308
283- episode_info = to_item (episode_info )
284- return stop_flag , episode_info
309+ # All ranks now have the same `stop_flag` and `eval_info`.
310+ # All ranks return a valid tuple.
311+ return stop_flag , eval_info
0 commit comments