@@ -622,12 +622,15 @@ def train_unizero_multitask_segment_ddp(
         print('=' * 20)
         print(f'Starting collection for Rank {rank} task_id: {cfg.policy.task_id}...')
         print(f'Rank {rank}: cfg.policy.task_id={cfg.policy.task_id}')
+        logging.info(f'Rank {rank}: Starting data collection for task {cfg.policy.task_id} at train_iter {learner.train_iter}')

         # Reset initial data before each collection, crucial for multi-task settings.
         collector._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id)
         # Collect data.
         new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

+        logging.info(f'Rank {rank}: Finished data collection for task {cfg.policy.task_id}, collected {len(new_data[0]) if new_data else 0} segments')
+
         # Update the replay buffer.
         replay_buffer.push_game_segments(new_data)
         replay_buffer.remove_oldest_data_to_fit()
@@ -648,6 +651,19 @@ def train_unizero_multitask_segment_ddp(
         # Log after data collection.
         logging.info(f'Rank {rank}: Completed data collection for task {cfg.policy.task_id}')

+        # ========== CRITICAL FIX: Synchronize all ranks after data collection ==========
+        # Wait for all ranks to complete their data collection before proceeding.
+        # This prevents fast-collecting ranks from reaching barriers/all_gather calls
+        # while slow-collecting ranks are still in the collection loop.
+        try:
+            logging.info(f'Rank {rank}: Waiting at post-collection barrier...')
+            dist.barrier()
+            logging.info(f'Rank {rank}: All ranks completed data collection, proceeding...')
+        except Exception as e:
+            logging.error(f'Rank {rank}: Post-collection barrier failed, error: {e}')
+            raise e
+        # ===============================================================================
+
         # Check if there is enough data for training.
         not_enough_data = any(
             replay_buffer.get_num_of_transitions() < cfgs[0].policy.total_batch_size / world_size
@@ -662,7 +678,9 @@ def train_unizero_multitask_segment_ddp(
             # Calculate task weights.
             try:
                 # Gather task rewards.
+                logging.info(f'Rank {rank}: Entering evaluation synchronization barrier at train_iter {learner.train_iter}')
                 dist.barrier()
+                logging.info(f'Rank {rank}: Passed evaluation barrier, gathering task returns')
                 all_task_returns = [None for _ in range(world_size)]
                 dist.all_gather_object(all_task_returns, task_returns)
                 # Merge task rewards.
0 commit comments