@@ -282,12 +282,27 @@ def train_unizero_multitask_balance_segment_ddp(
282282
283283 # --- Environment, Policy, and Worker Initialization ---
284284 task_configs , replay_buffers , collectors , evaluators = [], [], [], []
285-
285+
286286 # Use the first task's config to create the shared policy and learner
287287 _ , [main_cfg , main_create_cfg ] = tasks_for_this_rank [0 ]
288288 for _ , [cfg , _ ] in tasks_for_this_rank :
289289 cfg .policy .task_num = len (tasks_for_this_rank )
290290
291+ # ==================== START: Robust exp_name Fix ====================
292+ # Ensure main_cfg has a valid exp_name before calling compile_config.
293+ # If exp_name is missing, None, or too long, set a safe default.
294+ if not hasattr (main_cfg , 'exp_name' ) or main_cfg .exp_name is None or len (str (main_cfg .exp_name )) > 200 :
295+ # Use a simplified experiment name for the main config
296+ safe_exp_name = f'data_unizero_mt_balance/dmc_multitask_seed{ seed } '
297+ logging .warning (
298+ f"Rank { rank } : main_cfg.exp_name is missing, None, or too long. "
299+ f"Setting to safe default: { safe_exp_name } "
300+ )
301+ main_cfg .exp_name = safe_exp_name
302+ else :
303+ logging .info (f"Rank { rank } : Using exp_name from config: { main_cfg .exp_name } " )
304+ # ==================== END: Robust exp_name Fix ====================
305+
291306 assert main_create_cfg .policy .type in ['unizero_multitask' , 'sampled_unizero_multitask' ], \
292307 "This entry only supports 'unizero_multitask' or 'sampled_unizero_multitask' policies."
293308
@@ -299,12 +314,37 @@ def train_unizero_multitask_balance_segment_ddp(
299314
300315 main_cfg .policy .device = 'cuda' if torch .cuda .is_available () else 'cpu'
301316 compiled_cfg = compile_config (main_cfg , seed = seed , auto = True , create_cfg = main_create_cfg , save_cfg = True )
302-
317+
303318 policy = create_policy (compiled_cfg .policy , model = model , enable_field = ['learn' , 'collect' , 'eval' ])
319+
320+ # Log initial model architecture info BEFORE loading checkpoint
321+ if rank == 0 :
322+ num_layers_config = compiled_cfg .policy .model .world_model_cfg .num_layers
323+ initial_params = sum (p .numel () for p in policy ._learn_model .world_model .parameters ())
324+ initial_trainable = sum (p .numel () for p in policy ._learn_model .world_model .parameters () if p .requires_grad )
325+ logging .info (f"=" * 80 )
326+ logging .info (f"Model Architecture Configuration:" )
327+ logging .info (f" - num_layers from config: { num_layers_config } " )
328+ logging .info (f" - Total parameters (before checkpoint load): { initial_params :,} " )
329+ logging .info (f" - Trainable parameters (before checkpoint load): { initial_trainable :,} " )
330+ logging .info (f"=" * 80 )
331+
304332 if model_path :
305333 logging .info (f'Loading pre-trained model from: { model_path } ' )
306334 policy .learn_mode .load_state_dict (torch .load (model_path , map_location = compiled_cfg .policy .device ))
307335 logging .info ('Model loading complete.' )
336+ if rank == 0 :
337+ loaded_params = sum (p .numel () for p in policy ._learn_model .world_model .parameters ())
338+ loaded_trainable = sum (p .numel () for p in policy ._learn_model .world_model .parameters () if p .requires_grad )
339+ logging .info (f"Model Parameters After Loading Checkpoint:" )
340+ logging .info (f" - Total parameters (after checkpoint load): { loaded_params :,} " )
341+ logging .info (f" - Trainable parameters (after checkpoint load): { loaded_trainable :,} " )
342+ if initial_params != loaded_params :
343+ logging .warning (f"⚠️ WARNING: Parameter count mismatch!" )
344+ logging .warning (f" Config specifies { initial_params :,} params, but loaded model has { loaded_params :,} params" )
345+ logging .warning (f" This usually means the checkpoint was trained with different num_layers!" )
346+ logging .warning (f" The loaded checkpoint architecture will override your config settings." )
347+
308348
309349 tb_logger = SummaryWriter (os .path .join (f'./{ compiled_cfg .exp_name } /log' , f'rank_{ rank } ' ))
310350 learner = BaseLearner (compiled_cfg .policy .learn .learner , policy .learn_mode , tb_logger , exp_name = compiled_cfg .exp_name )
@@ -314,6 +354,19 @@ def train_unizero_multitask_balance_segment_ddp(
314354 for local_task_id , (task_id , [cfg , create_cfg ]) in enumerate (tasks_for_this_rank ):
315355 task_seed = seed + task_id
316356 cfg .policy .device = 'cuda' if cfg .policy .cuda and torch .cuda .is_available () else 'cpu'
357+
358+ # ==================== START: Robust exp_name Fix for Task Config ====================
359+ # Ensure each task config has a valid exp_name before calling compile_config
360+ if not hasattr (cfg , 'exp_name' ) or cfg .exp_name is None :
361+ # Extract env_id from config if available, otherwise use task_id
362+ env_id = getattr (cfg .env , 'env_id' , f'task{ task_id } ' )
363+ cfg .exp_name = f'data_unizero_mt_balance/task_{ env_id } _seed{ task_seed } '
364+ logging .warning (
365+ f"Rank { rank } : Task { task_id } config missing exp_name. "
366+ f"Setting to: { cfg .exp_name } "
367+ )
368+ # ==================== END: Robust exp_name Fix for Task Config ====================
369+
317370 compiled_task_cfg = compile_config (cfg , seed = task_seed , auto = True , create_cfg = create_cfg , save_cfg = True )
318371
319372 env_fn , collector_env_cfg , evaluator_env_cfg = get_vec_env_setting (compiled_task_cfg .env )
@@ -324,8 +377,28 @@ def train_unizero_multitask_balance_segment_ddp(
324377 set_pkg_seed (task_seed , use_cuda = compiled_task_cfg .policy .cuda )
325378
326379 replay_buffers .append (GameBuffer (compiled_task_cfg .policy ))
327- collectors .append (Collector (collector_env , policy .collect_mode , tb_logger , compiled_task_cfg .exp_name , compiled_task_cfg .policy , task_id ))
328- evaluators .append (Evaluator (compiled_task_cfg .policy .eval_freq , compiled_task_cfg .env .n_evaluator_episode , compiled_task_cfg .env .stop_value , evaluator_env , policy .eval_mode , tb_logger , compiled_task_cfg .exp_name , compiled_task_cfg .policy , task_id ))
380+ collectors .append (Collector (
381+ collect_print_freq = 100 ,
382+ env = collector_env ,
383+ policy = policy .collect_mode ,
384+ tb_logger = tb_logger ,
385+ exp_name = compiled_task_cfg .exp_name ,
386+ instance_name = f'collector_task{ task_id } ' ,
387+ policy_config = compiled_task_cfg .policy ,
388+ task_id = task_id
389+ ))
390+ evaluators .append (Evaluator (
391+ eval_freq = compiled_task_cfg .policy .eval_freq ,
392+ n_evaluator_episode = compiled_task_cfg .env .n_evaluator_episode ,
393+ stop_value = compiled_task_cfg .env .stop_value ,
394+ env = evaluator_env ,
395+ policy = policy .eval_mode ,
396+ tb_logger = tb_logger ,
397+ exp_name = compiled_task_cfg .exp_name ,
398+ instance_name = f'evaluator_task{ task_id } ' ,
399+ policy_config = compiled_task_cfg .policy ,
400+ task_id = task_id
401+ ))
329402 task_configs .append (compiled_task_cfg )
330403
331404 # --- Curriculum and Training Loop Initialization ---
@@ -348,8 +421,17 @@ def train_unizero_multitask_balance_segment_ddp(
348421 allocated_batch_sizes = allocate_batch_size (task_configs , replay_buffers , alpha = 1.0 , clip_scale = clip_scale )
349422 if rank == 0 :
350423 logging .info (f"Dynamically allocated batch sizes: { allocated_batch_sizes } " )
424+ # Assign the corresponding batch size to each task config
351425 for i , cfg in enumerate (task_configs ):
352- cfg .policy .batch_size = allocated_batch_sizes
426+ task_id = cfg .policy .task_id
427+ if isinstance (allocated_batch_sizes , dict ):
428+ cfg .policy .batch_size = allocated_batch_sizes .get (task_id , cfg .policy .batch_size )
429+ elif isinstance (allocated_batch_sizes , list ):
430+ # Use the index in the list or task_id as fallback
431+ cfg .policy .batch_size = allocated_batch_sizes [i ] if i < len (allocated_batch_sizes ) else cfg .policy .batch_size
432+ else :
433+ logging .warning (f"Unexpected type for allocated_batch_sizes: { type (allocated_batch_sizes )} " )
434+ # Also update the policy config (use the full list for compatibility)
353435 policy ._cfg .batch_size = allocated_batch_sizes
354436
355437 # --- 2. Data Collection and Evaluation for each task on this rank ---
@@ -505,7 +587,15 @@ def train_unizero_multitask_balance_segment_ddp(
505587 train_data_list = []
506588 total_envstep = sum (c .envstep for c in collectors )
507589 for cfg , replay_buffer in zip (unsolved_cfgs , unsolved_buffers ):
508- batch_size = cfg .policy .batch_size [cfg .policy .task_id ]
590+ # Handle batch_size whether it's an int, list, or dict
591+ if isinstance (cfg .policy .batch_size , (list , tuple )):
592+ batch_size = cfg .policy .batch_size [cfg .policy .task_id ]
593+ elif isinstance (cfg .policy .batch_size , dict ):
594+ batch_size = cfg .policy .batch_size [cfg .policy .task_id ]
595+ else :
596+ # batch_size is already an integer
597+ batch_size = cfg .policy .batch_size
598+
509599 if replay_buffer .get_num_of_transitions () >= batch_size :
510600 train_data = replay_buffer .sample (batch_size , policy )
511601 train_data .append (cfg .policy .task_id )
0 commit comments