Skip to content

Commit 7841fdf

Browse files
committed
fix(pu): fix not_enough_data ddp bug
1 parent b0a69b6 commit 7841fdf

File tree

2 files changed

+28
-22
lines changed

2 files changed

+28
-22
lines changed

lzero/entry/train_unizero_multitask_segment_ddp.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,10 @@ def train_unizero_multitask_segment_ddp(
106106
new_RANDOM_SCORES = RANDOM_SCORES[new_order]
107107
new_HUMAN_SCORES = HUMAN_SCORES[new_order]
108108
# Log the reordered results
109-
print("Reordered RANDOM_SCORES:")
110-
print(new_RANDOM_SCORES)
111-
print("\nReordered HUMAN_SCORES:")
112-
print(new_HUMAN_SCORES)
109+
logging.info("Reordered RANDOM_SCORES:")
110+
logging.info(new_RANDOM_SCORES)
111+
logging.info("\nReordered HUMAN_SCORES:")
112+
logging.info(new_HUMAN_SCORES)
113113
# ------------------------------------------------------------------------------------
114114

115115
# Initialize the temperature scheduler for task weighting.
@@ -150,7 +150,7 @@ def train_unizero_multitask_segment_ddp(
150150
# Initialize empty lists to avoid errors later.
151151
cfgs, game_buffers, collector_envs, evaluator_envs, collectors, evaluators = [], [], [], [], [], []
152152
else:
153-
print(f"Rank {rank}/{world_size} processing tasks {start_idx} to {end_idx - 1}")
153+
logging.info(f"Rank {rank}/{world_size} processing tasks {start_idx} to {end_idx - 1}")
154154

155155
cfgs = []
156156
game_buffers = []
@@ -281,7 +281,7 @@ def train_unizero_multitask_segment_ddp(
281281
clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4)
282282
allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale)
283283
if rank == 0:
284-
print("Allocated batch_sizes: ", allocated_batch_sizes)
284+
logging.info("Allocated batch_sizes: ", allocated_batch_sizes)
285285
# Assign the corresponding batch size to each task config
286286
for idx, (cfg, collector, evaluator, replay_buffer) in enumerate(
287287
zip(cfgs, collectors, evaluators, game_buffers)):
@@ -323,11 +323,11 @@ def train_unizero_multitask_segment_ddp(
323323
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
324324

325325
# Check if it's time for evaluation.
326-
# if learner.train_iter > 10 and learner.train_iter % cfg.policy.eval_freq == 0:
327-
if learner.train_iter == 0 or learner.train_iter % cfg.policy.eval_freq == 0: # TODO: Only for debug
326+
if learner.train_iter > 10 and learner.train_iter % cfg.policy.eval_freq == 0:
327+
# if learner.train_iter == 0 or learner.train_iter % cfg.policy.eval_freq == 0: # TODO: Only for debug
328328

329-
print('=' * 20)
330-
print(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}...')
329+
logging.info('=' * 20)
330+
logging.info(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}...')
331331

332332
# TODO: Ensure policy reset logic is optimal for multi-task settings.
333333
evaluator._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id)
@@ -336,28 +336,27 @@ def train_unizero_multitask_segment_ddp(
336336
stop, reward = safe_eval(evaluator, learner, collector, rank, world_size)
337337
# Check if evaluation was successful.
338338
if stop is None or reward is None:
339-
print(f"Rank {rank} encountered issues during evaluation, continuing training...")
339+
logging.warning(f"Rank {rank} encountered issues during evaluation, continuing training...")
340340
task_returns[cfg.policy.task_id] = float('inf') # Set task difficulty to max if evaluation fails.
341341
else:
342342
# Extract 'eval_episode_return_mean' from the reward dictionary.
343343
try:
344344
eval_mean_reward = reward.get('eval_episode_return_mean', float('inf'))
345-
print(f"Task {cfg.policy.task_id} evaluation reward: {eval_mean_reward}")
345+
logging.info(f"Task {cfg.policy.task_id} evaluation reward: {eval_mean_reward}")
346346
task_returns[cfg.policy.task_id] = eval_mean_reward
347347
except Exception as e:
348-
print(f"Error extracting evaluation reward: {e}")
348+
logging.error(f"Error extracting evaluation reward: {e}")
349349
task_returns[cfg.policy.task_id] = float('inf') # Set reward to max on error.
350350

351-
print('=' * 20)
352-
print(f'Starting collection for Rank {rank} task_id: {cfg.policy.task_id}...')
353-
print(f'Rank {rank}: cfg.policy.task_id={cfg.policy.task_id} ')
351+
logging.info('=' * 20)
352+
logging.info(f'Starting collection for Rank {rank} task_id: {cfg.policy.task_id}...')
353+
logging.info(f'Rank {rank}: cfg.policy.task_id={cfg.policy.task_id} ')
354354
logging.info(f'Rank {rank}: Starting data collection for task {cfg.policy.task_id} at train_iter {learner.train_iter}')
355355

356356
# Reset initial data before each collection, crucial for multi-task settings.
357357
collector._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id)
358358
# Collect data.
359359
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
360-
361360
logging.info(f'Rank {rank}: Finished data collection for task {cfg.policy.task_id}, collected {len(new_data[0]) if new_data else 0} segments')
362361

363362
# Update the replay buffer.
@@ -380,7 +379,7 @@ def train_unizero_multitask_segment_ddp(
380379
# Log after data collection.
381380
logging.info(f'Rank {rank}: Completed data collection for task {cfg.policy.task_id}')
382381

383-
# ========== CRITICAL FIX: Synchronize all ranks after data collection ==========
382+
# ========== Synchronize all ranks after data collection ==========
384383
# Wait for all ranks to complete their data collection before proceeding.
385384
# This prevents fast-collecting ranks from reaching barriers/all_gather calls
386385
# while slow-collecting ranks are still in the collection loop.
@@ -394,12 +393,17 @@ def train_unizero_multitask_segment_ddp(
394393
# ===============================================================================
395394

396395
# Check if there is enough data for training.
397-
not_enough_data = any(
396+
local_not_enough_data = any(
398397
replay_buffer.get_num_of_transitions() < cfgs[0].policy.total_batch_size / world_size
399398
for replay_buffer in game_buffers
400399
)
401-
402-
print(f"not_enough_data:{not_enough_data}")
400+
logging.info(f"Rank {rank} local_not_enough_data:{local_not_enough_data}")
401+
flag_tensor = torch.tensor(1.0 if local_not_enough_data else 0.0, device=cfg.policy.device)
402+
dist.all_reduce(flag_tensor, op=dist.ReduceOp.MAX)
403+
not_enough_data = (flag_tensor.item() > 0.5)
404+
if rank == 0:
405+
logging.info(f"Global not_enough_data status: {not_enough_data}")
406+
403407
# Get the current temperature for task weighting.
404408
current_temperature_task_weight = temperature_scheduler.get_temperature(learner.train_iter)
405409

@@ -526,7 +530,7 @@ def train_unizero_multitask_segment_ddp(
526530
)
527531
# Broadcast task weights to all processes.
528532
dist.broadcast_object_list([task_exploitation_weight], src=0)
529-
print(
533+
logging.info(
530534
f"rank{rank}, task_exploitation_weight (sorted by task_id): {task_exploitation_weight}")
531535
else:
532536
logging.warning(f"Rank {rank}: Unable to compute global obs_loss task weights, obs_loss data is empty.")

lzero/policy/unizero.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,8 @@ class UniZeroPolicy(MuZeroPolicy):
466466
priority_prob_beta=0.4,
467467
# (int) The initial Env Steps for training.
468468
train_start_after_envsteps=int(0),
469+
# (bool) Whether to use task_exploitation_weight.
470+
use_task_exploitation_weight=False,
469471

470472
# ****** UCB ******
471473
# (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree.

0 commit comments

Comments (0)