@@ -106,10 +106,10 @@ def train_unizero_multitask_segment_ddp(
106106 new_RANDOM_SCORES = RANDOM_SCORES [new_order ]
107107 new_HUMAN_SCORES = HUMAN_SCORES [new_order ]
108108 # Log the reordered results
109- print ("Reordered RANDOM_SCORES:" )
110- print (new_RANDOM_SCORES )
111- print ("\n Reordered HUMAN_SCORES:" )
112- print (new_HUMAN_SCORES )
109+ logging . info ("Reordered RANDOM_SCORES:" )
110+ logging . info (new_RANDOM_SCORES )
111+ logging . info ("\n Reordered HUMAN_SCORES:" )
112+ logging . info (new_HUMAN_SCORES )
113113 # ------------------------------------------------------------------------------------
114114
115115 # Initialize the temperature scheduler for task weighting.
@@ -150,7 +150,7 @@ def train_unizero_multitask_segment_ddp(
150150 # Initialize empty lists to avoid errors later.
151151 cfgs , game_buffers , collector_envs , evaluator_envs , collectors , evaluators = [], [], [], [], [], []
152152 else :
153- print (f"Rank { rank } /{ world_size } processing tasks { start_idx } to { end_idx - 1 } " )
153+ logging . info (f"Rank { rank } /{ world_size } processing tasks { start_idx } to { end_idx - 1 } " )
154154
155155 cfgs = []
156156 game_buffers = []
@@ -281,7 +281,7 @@ def train_unizero_multitask_segment_ddp(
281281 clip_scale = np .clip (1 + (3 * train_epoch / 1000 ), 1 , 4 )
282282 allocated_batch_sizes = allocate_batch_size (cfgs , game_buffers , alpha = 1.0 , clip_scale = clip_scale )
283283 if rank == 0 :
284- print ("Allocated batch_sizes: " , allocated_batch_sizes )
284+ logging.info("Allocated batch_sizes: %s", allocated_batch_sizes)
285285 # Assign the corresponding batch size to each task config
286286 for idx , (cfg , collector , evaluator , replay_buffer ) in enumerate (
287287 zip (cfgs , collectors , evaluators , game_buffers )):
@@ -323,11 +323,11 @@ def train_unizero_multitask_segment_ddp(
323323 collect_kwargs ['epsilon' ] = epsilon_greedy_fn (collector .envstep )
324324
325325 # Check if it's time for evaluation.
326- # if learner.train_iter > 10 and learner.train_iter % cfg.policy.eval_freq == 0:
327- if learner .train_iter == 0 or learner .train_iter % cfg .policy .eval_freq == 0 : # TODO: Only for debug
326+ if learner .train_iter > 10 and learner .train_iter % cfg .policy .eval_freq == 0 :
327+ # if learner.train_iter == 0 or learner.train_iter % cfg.policy.eval_freq == 0: # TODO: Only for debug
328328
329- print ('=' * 20 )
330- print (f'Rank { rank } evaluating task_id: { cfg .policy .task_id } ...' )
329+ logging . info ('=' * 20 )
330+ logging . info (f'Rank { rank } evaluating task_id: { cfg .policy .task_id } ...' )
331331
332332 # TODO: Ensure policy reset logic is optimal for multi-task settings.
333333 evaluator ._policy .reset (reset_init_data = True , task_id = cfg .policy .task_id )
@@ -336,28 +336,27 @@ def train_unizero_multitask_segment_ddp(
336336 stop , reward = safe_eval (evaluator , learner , collector , rank , world_size )
337337 # Check if evaluation was successful.
338338 if stop is None or reward is None :
339- print (f"Rank { rank } encountered issues during evaluation, continuing training..." )
339+ logging . warning (f"Rank { rank } encountered issues during evaluation, continuing training..." )
340340 task_returns [cfg .policy .task_id ] = float ('inf' ) # Set task difficulty to max if evaluation fails.
341341 else :
342342 # Extract 'eval_episode_return_mean' from the reward dictionary.
343343 try :
344344 eval_mean_reward = reward .get ('eval_episode_return_mean' , float ('inf' ))
345- print (f"Task { cfg .policy .task_id } evaluation reward: { eval_mean_reward } " )
345+ logging . info (f"Task { cfg .policy .task_id } evaluation reward: { eval_mean_reward } " )
346346 task_returns [cfg .policy .task_id ] = eval_mean_reward
347347 except Exception as e :
348- print (f"Error extracting evaluation reward: { e } " )
348+ logging . error (f"Error extracting evaluation reward: { e } " )
349349 task_returns [cfg .policy .task_id ] = float ('inf' ) # Set reward to max on error.
350350
351- print ('=' * 20 )
352- print (f'Starting collection for Rank { rank } task_id: { cfg .policy .task_id } ...' )
353- print (f'Rank { rank } : cfg.policy.task_id={ cfg .policy .task_id } ' )
351+ logging . info ('=' * 20 )
352+ logging . info (f'Starting collection for Rank { rank } task_id: { cfg .policy .task_id } ...' )
353+ logging . info (f'Rank { rank } : cfg.policy.task_id={ cfg .policy .task_id } ' )
354354 logging .info (f'Rank { rank } : Starting data collection for task { cfg .policy .task_id } at train_iter { learner .train_iter } ' )
355355
356356 # Reset initial data before each collection, crucial for multi-task settings.
357357 collector ._policy .reset (reset_init_data = True , task_id = cfg .policy .task_id )
358358 # Collect data.
359359 new_data = collector .collect (train_iter = learner .train_iter , policy_kwargs = collect_kwargs )
360-
361360 logging .info (f'Rank { rank } : Finished data collection for task { cfg .policy .task_id } , collected { len (new_data [0 ]) if new_data else 0 } segments' )
362361
363362 # Update the replay buffer.
@@ -380,7 +379,7 @@ def train_unizero_multitask_segment_ddp(
380379 # Log after data collection.
381380 logging .info (f'Rank { rank } : Completed data collection for task { cfg .policy .task_id } ' )
382381
383- # ========== CRITICAL FIX: Synchronize all ranks after data collection ==========
382+ # ========== Synchronize all ranks after data collection ==========
384383 # Wait for all ranks to complete their data collection before proceeding.
385384 # This prevents fast-collecting ranks from reaching barriers/all_gather calls
386385 # while slow-collecting ranks are still in the collection loop.
@@ -394,12 +393,17 @@ def train_unizero_multitask_segment_ddp(
394393 # ===============================================================================
395394
396395 # Check if there is enough data for training.
397- not_enough_data = any (
396+ local_not_enough_data = any (
398397 replay_buffer .get_num_of_transitions () < cfgs [0 ].policy .total_batch_size / world_size
399398 for replay_buffer in game_buffers
400399 )
401-
402- print (f"not_enough_data:{ not_enough_data } " )
400+ logging .info (f"Rank { rank } local_not_enough_data:{ local_not_enough_data } " )
401+ flag_tensor = torch .tensor (1.0 if local_not_enough_data else 0.0 , device = cfg .policy .device )
402+ dist .all_reduce (flag_tensor , op = dist .ReduceOp .MAX )
403+ not_enough_data = (flag_tensor .item () > 0.5 )
404+ if rank == 0 :
405+ logging .info (f"Global not_enough_data status: { not_enough_data } " )
406+
403407 # Get the current temperature for task weighting.
404408 current_temperature_task_weight = temperature_scheduler .get_temperature (learner .train_iter )
405409
@@ -526,7 +530,7 @@ def train_unizero_multitask_segment_ddp(
526530 )
527531 # Broadcast task weights to all processes.
528532 dist .broadcast_object_list ([task_exploitation_weight ], src = 0 )
529- print (
533+ logging . info (
530534 f"rank{ rank } , task_exploitation_weight (sorted by task_id): { task_exploitation_weight } " )
531535 else :
532536 logging .warning (f"Rank { rank } : Unable to compute global obs_loss task weights, obs_loss data is empty." )