Skip to content

Commit bd67cdf

Browse files
author
wangshulun
committed
fix(pu): fix exp_name and task_id bug in dmc pipeline, fix some configs
1 parent b9b8d26 commit bd67cdf

19 files changed

+976
-93
lines changed

lzero/entry/train_unizero_multitask_balance_segment_ddp.py

Lines changed: 96 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,27 @@ def train_unizero_multitask_balance_segment_ddp(
282282

283283
# --- Environment, Policy, and Worker Initialization ---
284284
task_configs, replay_buffers, collectors, evaluators = [], [], [], []
285-
285+
286286
# Use the first task's config to create the shared policy and learner
287287
_, [main_cfg, main_create_cfg] = tasks_for_this_rank[0]
288288
for _, [cfg, _] in tasks_for_this_rank:
289289
cfg.policy.task_num = len(tasks_for_this_rank)
290290

291+
# ==================== START: Robust exp_name Fix ====================
292+
# Ensure main_cfg has a valid exp_name before calling compile_config.
293+
# If exp_name is missing, None, or too long, set a safe default.
294+
if not hasattr(main_cfg, 'exp_name') or main_cfg.exp_name is None or len(str(main_cfg.exp_name)) > 200:
295+
# Use a simplified experiment name for the main config
296+
safe_exp_name = f'data_unizero_mt_balance/dmc_multitask_seed{seed}'
297+
logging.warning(
298+
f"Rank {rank}: main_cfg.exp_name is missing, None, or too long. "
299+
f"Setting to safe default: {safe_exp_name}"
300+
)
301+
main_cfg.exp_name = safe_exp_name
302+
else:
303+
logging.info(f"Rank {rank}: Using exp_name from config: {main_cfg.exp_name}")
304+
# ==================== END: Robust exp_name Fix ====================
305+
291306
assert main_create_cfg.policy.type in ['unizero_multitask', 'sampled_unizero_multitask'], \
292307
"This entry only supports 'unizero_multitask' or 'sampled_unizero_multitask' policies."
293308

@@ -299,12 +314,37 @@ def train_unizero_multitask_balance_segment_ddp(
299314

300315
main_cfg.policy.device = 'cuda' if torch.cuda.is_available() else 'cpu'
301316
compiled_cfg = compile_config(main_cfg, seed=seed, auto=True, create_cfg=main_create_cfg, save_cfg=True)
302-
317+
303318
policy = create_policy(compiled_cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
319+
320+
# Log initial model architecture info BEFORE loading checkpoint
321+
if rank == 0:
322+
num_layers_config = compiled_cfg.policy.model.world_model_cfg.num_layers
323+
initial_params = sum(p.numel() for p in policy._learn_model.world_model.parameters())
324+
initial_trainable = sum(p.numel() for p in policy._learn_model.world_model.parameters() if p.requires_grad)
325+
logging.info(f"=" * 80)
326+
logging.info(f"Model Architecture Configuration:")
327+
logging.info(f" - num_layers from config: {num_layers_config}")
328+
logging.info(f" - Total parameters (before checkpoint load): {initial_params:,}")
329+
logging.info(f" - Trainable parameters (before checkpoint load): {initial_trainable:,}")
330+
logging.info(f"=" * 80)
331+
304332
if model_path:
305333
logging.info(f'Loading pre-trained model from: {model_path}')
306334
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=compiled_cfg.policy.device))
307335
logging.info('Model loading complete.')
336+
if rank == 0:
337+
loaded_params = sum(p.numel() for p in policy._learn_model.world_model.parameters())
338+
loaded_trainable = sum(p.numel() for p in policy._learn_model.world_model.parameters() if p.requires_grad)
339+
logging.info(f"Model Parameters After Loading Checkpoint:")
340+
logging.info(f" - Total parameters (after checkpoint load): {loaded_params:,}")
341+
logging.info(f" - Trainable parameters (after checkpoint load): {loaded_trainable:,}")
342+
if initial_params != loaded_params:
343+
logging.warning(f"⚠️ WARNING: Parameter count mismatch!")
344+
logging.warning(f" Config specifies {initial_params:,} params, but loaded model has {loaded_params:,} params")
345+
logging.warning(f" This usually means the checkpoint was trained with different num_layers!")
346+
logging.warning(f" The loaded checkpoint architecture will override your config settings.")
347+
308348

309349
tb_logger = SummaryWriter(os.path.join(f'./{compiled_cfg.exp_name}/log', f'rank_{rank}'))
310350
learner = BaseLearner(compiled_cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=compiled_cfg.exp_name)
@@ -314,6 +354,19 @@ def train_unizero_multitask_balance_segment_ddp(
314354
for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks_for_this_rank):
315355
task_seed = seed + task_id
316356
cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu'
357+
358+
# ==================== START: Robust exp_name Fix for Task Config ====================
359+
# Ensure each task config has a valid exp_name before calling compile_config
360+
if not hasattr(cfg, 'exp_name') or cfg.exp_name is None:
361+
# Extract env_id from config if available, otherwise use task_id
362+
env_id = getattr(cfg.env, 'env_id', f'task{task_id}')
363+
cfg.exp_name = f'data_unizero_mt_balance/task_{env_id}_seed{task_seed}'
364+
logging.warning(
365+
f"Rank {rank}: Task {task_id} config missing exp_name. "
366+
f"Setting to: {cfg.exp_name}"
367+
)
368+
# ==================== END: Robust exp_name Fix for Task Config ====================
369+
317370
compiled_task_cfg = compile_config(cfg, seed=task_seed, auto=True, create_cfg=create_cfg, save_cfg=True)
318371

319372
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(compiled_task_cfg.env)
@@ -324,8 +377,28 @@ def train_unizero_multitask_balance_segment_ddp(
324377
set_pkg_seed(task_seed, use_cuda=compiled_task_cfg.policy.cuda)
325378

326379
replay_buffers.append(GameBuffer(compiled_task_cfg.policy))
327-
collectors.append(Collector(collector_env, policy.collect_mode, tb_logger, compiled_task_cfg.exp_name, compiled_task_cfg.policy, task_id))
328-
evaluators.append(Evaluator(compiled_task_cfg.policy.eval_freq, compiled_task_cfg.env.n_evaluator_episode, compiled_task_cfg.env.stop_value, evaluator_env, policy.eval_mode, tb_logger, compiled_task_cfg.exp_name, compiled_task_cfg.policy, task_id))
380+
collectors.append(Collector(
381+
collect_print_freq=100,
382+
env=collector_env,
383+
policy=policy.collect_mode,
384+
tb_logger=tb_logger,
385+
exp_name=compiled_task_cfg.exp_name,
386+
instance_name=f'collector_task{task_id}',
387+
policy_config=compiled_task_cfg.policy,
388+
task_id=task_id
389+
))
390+
evaluators.append(Evaluator(
391+
eval_freq=compiled_task_cfg.policy.eval_freq,
392+
n_evaluator_episode=compiled_task_cfg.env.n_evaluator_episode,
393+
stop_value=compiled_task_cfg.env.stop_value,
394+
env=evaluator_env,
395+
policy=policy.eval_mode,
396+
tb_logger=tb_logger,
397+
exp_name=compiled_task_cfg.exp_name,
398+
instance_name=f'evaluator_task{task_id}',
399+
policy_config=compiled_task_cfg.policy,
400+
task_id=task_id
401+
))
329402
task_configs.append(compiled_task_cfg)
330403

331404
# --- Curriculum and Training Loop Initialization ---
@@ -348,8 +421,17 @@ def train_unizero_multitask_balance_segment_ddp(
348421
allocated_batch_sizes = allocate_batch_size(task_configs, replay_buffers, alpha=1.0, clip_scale=clip_scale)
349422
if rank == 0:
350423
logging.info(f"Dynamically allocated batch sizes: {allocated_batch_sizes}")
424+
# Assign the corresponding batch size to each task config
351425
for i, cfg in enumerate(task_configs):
352-
cfg.policy.batch_size = allocated_batch_sizes
426+
task_id = cfg.policy.task_id
427+
if isinstance(allocated_batch_sizes, dict):
428+
cfg.policy.batch_size = allocated_batch_sizes.get(task_id, cfg.policy.batch_size)
429+
elif isinstance(allocated_batch_sizes, list):
430+
# Index by position in the list; keep the current batch size if the index is out of range
431+
cfg.policy.batch_size = allocated_batch_sizes[i] if i < len(allocated_batch_sizes) else cfg.policy.batch_size
432+
else:
433+
logging.warning(f"Unexpected type for allocated_batch_sizes: {type(allocated_batch_sizes)}")
434+
# Also update the policy config (use the full list for compatibility)
353435
policy._cfg.batch_size = allocated_batch_sizes
354436

355437
# --- 2. Data Collection and Evaluation for each task on this rank ---
@@ -505,7 +587,15 @@ def train_unizero_multitask_balance_segment_ddp(
505587
train_data_list = []
506588
total_envstep = sum(c.envstep for c in collectors)
507589
for cfg, replay_buffer in zip(unsolved_cfgs, unsolved_buffers):
508-
batch_size = cfg.policy.batch_size[cfg.policy.task_id]
590+
# Handle batch_size whether it's an int, list, or dict
591+
if isinstance(cfg.policy.batch_size, (list, tuple)):
592+
batch_size = cfg.policy.batch_size[cfg.policy.task_id]
593+
elif isinstance(cfg.policy.batch_size, dict):
594+
batch_size = cfg.policy.batch_size[cfg.policy.task_id]
595+
else:
596+
# batch_size is already an integer
597+
batch_size = cfg.policy.batch_size
598+
509599
if replay_buffer.get_num_of_transitions() >= batch_size:
510600
train_data = replay_buffer.sample(batch_size, policy)
511601
train_data.append(cfg.policy.task_id)

lzero/entry/train_unizero_multitask_segment_ddp.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,13 @@ def train_unizero_multitask_segment_ddp(
554554
)
555555

556556
cfgs.append(cfg)
557-
replay_buffer.batch_size = cfg.policy.batch_size[task_id]
557+
# Handle batch_size robustly - it may be a list/tuple or dict indexed by task_id, or already an integer
558+
if isinstance(cfg.policy.batch_size, (list, tuple)):
559+
replay_buffer.batch_size = cfg.policy.batch_size[task_id]
560+
elif isinstance(cfg.policy.batch_size, dict):
561+
replay_buffer.batch_size = cfg.policy.batch_size[task_id]
562+
else:
563+
replay_buffer.batch_size = cfg.policy.batch_size
558564

559565
game_buffers.append(replay_buffer)
560566
collector_envs.append(collector_env)
@@ -583,10 +589,19 @@ def train_unizero_multitask_segment_ddp(
583589
allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale)
584590
if rank == 0:
585591
print("分配后的 batch_sizes: ", allocated_batch_sizes)
592+
# Assign the corresponding batch size to each task config
586593
for idx, (cfg, collector, evaluator, replay_buffer) in enumerate(
587594
zip(cfgs, collectors, evaluators, game_buffers)):
588-
cfg.policy.batch_size = allocated_batch_sizes
589-
policy._cfg.batch_size = allocated_batch_sizes
595+
task_id = cfg.policy.task_id
596+
if isinstance(allocated_batch_sizes, dict):
597+
cfg.policy.batch_size = allocated_batch_sizes.get(task_id, cfg.policy.batch_size)
598+
elif isinstance(allocated_batch_sizes, list):
599+
# Index by position in the list; keep the current batch size if the index is out of range
600+
cfg.policy.batch_size = allocated_batch_sizes[idx] if idx < len(allocated_batch_sizes) else cfg.policy.batch_size
601+
else:
602+
logging.warning(f"Unexpected type for allocated_batch_sizes: {type(allocated_batch_sizes)}")
603+
# Also update the policy config (use the full list for compatibility)
604+
policy._cfg.batch_size = allocated_batch_sizes
590605

591606
# For each task on the current rank, perform data collection and evaluation.
592607
for idx, (cfg, collector, evaluator, replay_buffer) in enumerate(
@@ -737,7 +752,14 @@ def train_unizero_multitask_segment_ddp(
737752
envstep_multi_task = 0
738753
for idx, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)):
739754
envstep_multi_task += collector.envstep
740-
batch_size = cfg.policy.batch_size[cfg.policy.task_id]
755+
# Handle batch_size robustly - it may be a list/tuple or dict indexed by task_id, or already an integer
756+
if isinstance(cfg.policy.batch_size, (list, tuple)):
757+
batch_size = cfg.policy.batch_size[cfg.policy.task_id]
758+
elif isinstance(cfg.policy.batch_size, dict):
759+
batch_size = cfg.policy.batch_size[cfg.policy.task_id]
760+
else:
761+
batch_size = cfg.policy.batch_size
762+
741763
if replay_buffer.get_num_of_transitions() > batch_size:
742764
if cfg.policy.buffer_reanalyze_freq >= 1:
743765
if i % reanalyze_interval == 0 and \

lzero/model/common.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -587,8 +587,12 @@ def __init__(
587587
"""
588588
super().__init__()
589589
assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']"
590-
logging.info(f"Using norm type: {norm_type}")
591-
logging.info(f"Using activation type: {activation}")
590+
591+
# Only log from rank 0 to avoid excessive output in distributed training
592+
from ding.utils import get_rank
593+
if get_rank() == 0:
594+
logging.info(f"Using norm type: {norm_type}")
595+
logging.info(f"Using activation type: {activation}")
592596

593597
self.observation_shape = observation_shape
594598
self.downsample = downsample

0 commit comments

Comments
 (0)