Skip to content

Commit 5ed77bf

Browse files
committed
fix(pu): fix unizero_multitask ddp barrier bug
1 parent cb5ae6b commit 5ed77bf

File tree

4 files changed

+31
-5
lines changed

4 files changed

+31
-5
lines changed

lzero/entry/train_unizero_multitask_segment_ddp.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -622,12 +622,15 @@ def train_unizero_multitask_segment_ddp(
622622
print('=' * 20)
623623
print(f'Starting collection for Rank {rank} task_id: {cfg.policy.task_id}...')
624624
print(f'Rank {rank}: cfg.policy.task_id={cfg.policy.task_id} ')
625+
logging.info(f'Rank {rank}: Starting data collection for task {cfg.policy.task_id} at train_iter {learner.train_iter}')
625626

626627
# Reset initial data before each collection, crucial for multi-task settings.
627628
collector._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id)
628629
# Collect data.
629630
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
630631

632+
logging.info(f'Rank {rank}: Finished data collection for task {cfg.policy.task_id}, collected {len(new_data[0]) if new_data else 0} segments')
633+
631634
# Update the replay buffer.
632635
replay_buffer.push_game_segments(new_data)
633636
replay_buffer.remove_oldest_data_to_fit()
@@ -648,6 +651,19 @@ def train_unizero_multitask_segment_ddp(
648651
# Log after data collection.
649652
logging.info(f'Rank {rank}: Completed data collection for task {cfg.policy.task_id}')
650653

654+
# ========== CRITICAL FIX: Synchronize all ranks after data collection ==========
655+
# Wait for all ranks to complete their data collection before proceeding.
656+
# This prevents fast-collecting ranks from reaching barriers/all_gather calls
657+
# while slow-collecting ranks are still in the collection loop.
658+
try:
659+
logging.info(f'Rank {rank}: Waiting at post-collection barrier...')
660+
dist.barrier()
661+
logging.info(f'Rank {rank}: All ranks completed data collection, proceeding...')
662+
except Exception as e:
663+
logging.error(f'Rank {rank}: Post-collection barrier failed, error: {e}')
664+
raise e
665+
# ===============================================================================
666+
651667
# Check if there is enough data for training.
652668
not_enough_data = any(
653669
replay_buffer.get_num_of_transitions() < cfgs[0].policy.total_batch_size / world_size
@@ -662,7 +678,9 @@ def train_unizero_multitask_segment_ddp(
662678
# Calculate task weights.
663679
try:
664680
# Gather task rewards.
681+
logging.info(f'Rank {rank}: Entering evaluation synchronization barrier at train_iter {learner.train_iter}')
665682
dist.barrier()
683+
logging.info(f'Rank {rank}: Passed evaluation barrier, gathering task returns')
666684
all_task_returns = [None for _ in range(world_size)]
667685
dist.all_gather_object(all_task_returns, task_returns)
668686
# Merge task rewards.

lzero/model/unizero_world_models/lpips.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313
from torchvision import models
1414
from tqdm import tqdm
1515

16-
1716
class LPIPS(nn.Module):
1817
# Learned perceptual metric
1918
def __init__(self, use_dropout: bool = True):

zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -384,6 +384,13 @@ def create_env_manager() -> EasyDict:
384384
reanalyze_partition = 0.75
385385

386386
# ==================== Training Loop ====================
387+
# Set NCCL timeout to prevent watchdog hang due to unbalanced data collection speeds
388+
# Different games (e.g., Pong vs Seaquest) have vastly different episode lengths,
389+
# which can cause some ranks to finish collection much faster than others.
390+
# The default timeout is assumed to be 30 minutes (it varies by PyTorch version and backend; verify for your setup); we increase it to 60 minutes for safety.
391+
os.environ.setdefault('NCCL_TIMEOUT', '3600') # 60 minutes in seconds
392+
os.environ.setdefault('NCCL_BLOCKING_WAIT', '1') # Enable blocking wait for better error messages
393+
387394
for seed in [0]:
388395
configs = generate_configs(
389396
env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num,

zoo/atari/config/atari_unizero_segment_config.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,8 @@ def main(env_id, seed):
6868
world_model_cfg=dict(
6969
latent_recon_loss_weight=0.1, # TODO
7070
perceptual_loss_weight=0.1,
71-
use_new_cache_manager=False, # TODO
71+
# use_new_cache_manager=False, # TODO
72+
use_new_cache_manager=True, # ==============TODO==============
7273

7374
norm_type=norm_type,
7475
final_norm_option_in_obs_head='LayerNorm',
@@ -260,7 +261,7 @@ def main(env_id, seed):
260261

261262
# ============ use muzero_segment_collector instead of muzero_collector =============
262263
from lzero.entry import train_unizero_segment
263-
main_config.exp_name = f'data_unizero_st_1226/{env_id[3:-3]}/{env_id[3:-3]}_uz_head-clip-p_target005_allhead4_targetentropy-alpha-500k-098-005-min005_mse-loss2_rec01_poli-clip10_pol-smo-005_pol-loss-tmp-1.5_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
264+
main_config.exp_name = f'data_unizero_st_1226_2/{env_id[3:-3]}/{env_id[3:-3]}_uz_newkv_head-clip-p_target005_allhead4_targetentropy-alpha-500k-098-005-min005_mse-loss2_rec01_poli-clip10_pol-smo-005_pol-loss-tmp-1.5_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
264265
# main_config.exp_name = f'data_unizero/{env_id[3:-3]}/{env_id[3:-3]}_uz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
265266

266267
train_unizero_segment([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
@@ -274,9 +275,10 @@ def main(env_id, seed):
274275
args = parser.parse_args()
275276

276277
# Test environments from atari8 base set
277-
# args.env = 'ALE/Pong-v5' # Memory-planning environment with sparse rewards
278+
args.env = 'ALE/Pong-v5' # Memory-planning environment with sparse rewards
278279
# args.env = 'ALE/Qbert-v5' # Memory-planning environment with sparse rewards
279-
args.env = 'ALE/MsPacman-v5' # Memory-planning environment with sparse rewards
280+
281+
# args.env = 'ALE/MsPacman-v5' # Memory-planning environment with sparse rewards
280282

281283
main(args.env, args.seed)
282284

0 commit comments

Comments
 (0)