@@ -93,6 +93,7 @@ def create_config(
9393 """
9494 return EasyDict (dict (
9595 env = dict (
96+ frame_skip = 1 , # TODO
9697 stop_value = int (1e6 ),
9798 env_id = env_id ,
9899 observation_shape = (3 , 64 , 64 ),
@@ -162,8 +163,8 @@ def create_config(
162163 # use_priority=False, # TODO=====
163164 priority_prob_alpha = 1 ,
164165 priority_prob_beta = 1 ,
165- # encoder_type='vit',
166- encoder_type = 'resnet' ,
166+ encoder_type = 'vit' ,
167+ # encoder_type='resnet',
167168 use_normal_head = True ,
168169 use_softmoe_head = False ,
169170 use_moe_head = False ,
@@ -195,7 +196,8 @@ def create_config(
195196 # use_adaptive_entropy_weight=False,
196197
197198 # (float) 自适应alpha优化器的学习率
198- adaptive_entropy_alpha_lr = 1e-4 ,
199+ # adaptive_entropy_alpha_lr=1e-4,
200+ adaptive_entropy_alpha_lr = 1e-3 ,
199201 target_entropy_start_ratio = 0.98 ,
200202 # target_entropy_end_ratio =0.9, # TODO=====
201203 # target_entropy_end_ratio =0.7,
@@ -289,15 +291,18 @@ def generate_configs(
289291 # --- Experiment Name Template ---
290292 # Replace placeholders like [BENCHMARK_TAG] and [MODEL_TAG] to define the experiment name.
291293 # benchmark_tag = "data_unizero_mt_refactor1010_debug" # e.g., unizero_atari_mt_20250612
292- benchmark_tag = "data_unizero_mt_refactor1012 " # e.g., unizero_atari_mt_20250612
294+ benchmark_tag = "data_unizero_mt_refactor1024 " # e.g., unizero_atari_mt_20250612
293295
294296 # model_tag = f"vit-small_moe8_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}_not-share-head"
295297 # model_tag = f"resnet_noprior_noalpha_nomoe_head-inner-ln_adamw-wd1e-2_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}"
296298
297299 # model_tag = f"vit_prior_alpha-100k-098-07_encoder-100k-30-10_moe8_head-inner-ln_adamw-wd1e-2_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}"
298300
299301 # model_tag = f"resnet_encoder-100k-30-10-true_label-smooth_prior_alpha-100k-098-07_moe8_head-inner-ln_adamw-wd1e-2-all_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}"
300- model_tag = f"resnet_tran-nlayer{ num_layers } _moe8_encoder-100k-30-10-true_alpha-100k-098-05_prior_adamw-wd1e-2-all_tbs512_brf{ buffer_reanalyze_freq } _label-smooth_head-inner-ln"
302+ model_tag = f"vit_tran-nlayer{ num_layers } _moe8_encoder-100k-30-10-true_alpha-100k-098-05_prior_adamw-wd1e-2-all_tbs512_brf{ buffer_reanalyze_freq } _label-smooth_head-inner-ln"
303+
304+ # model_tag = f"resnet_tran-nlayer{num_layers}_moe8_encoder-100k-30-10-true_alpha-100k-098-05_prior_adamw-wd1e-2-all_tbs512_brf{buffer_reanalyze_freq}_label-smooth_head-inner-ln"
305+
301306 # model_tag = f"resnet_encoder-100k-30-10-true_label-smooth_prior_alpha-150k-098-05_moe8_head-inner-ln_adamw-wd1e-2-all_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}"
302307
303308 exp_name_prefix = f'{ benchmark_tag } /atari_{ len (env_id_list )} games_{ model_tag } _seed{ seed } /'
@@ -309,7 +314,10 @@ def generate_configs(
309314 buffer_reanalyze_freq , reanalyze_batch_size , reanalyze_partition , num_segments , total_batch_size , num_layers
310315 )
311316 config .policy .task_id = task_id
312- config .exp_name = exp_name_prefix + f"{ env_id .split ('NoFrameskip' )[0 ]} _seed{ seed } "
317+ # --- MODIFIED LINE ---
318+ # Correctly extract the game name from 'ALE/GameName-v5' format.
319+ game_name = env_id .split ('/' )[1 ].split ('-' )[0 ]
320+ config .exp_name = exp_name_prefix + f"{ game_name } _seed{ seed } "
313321 configs .append ([task_id , [config , create_env_manager ()]])
314322 return configs
315323
@@ -348,6 +356,8 @@ def create_env_manager() -> EasyDict:
348356 export CUDA_VISIBLE_DEVICES=4,5,6,7
349357
350358 cd /path/to/your/project/
359+ /mnt/shared-storage-user/puyuan/lz/bin/python -m torch.distributed.launch --nproc_per_node=4 --master_port=29502 /mnt/shared-storage-user/puyuan/code_20250828/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py 2>&1 | tee /mnt/shared-storage-user/puyuan/code_20250828/LightZero/log/20251024_vit_nlayer4_alpha-100k-098-05.log
360+
351361 python -m torch.distributed.launch --nproc_per_node=6 --master_port=29502 /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py 2>&1 | tee /mnt/nfs/zhangjinouwen/puyuan/LightZero/log/20251012_resnet_nlayer4_alpha-100k-098-05.log
352362 /path/to/this/script.py 2>&1 | tee /path/to/your/log/file.log
353363 """
@@ -370,22 +380,23 @@ def create_env_manager() -> EasyDict:
370380 max_env_step = int (5e6 ) # TODO
371381 reanalyze_ratio = 0.0
372382
383+ # --- MODIFIED SECTION: Standardized env_id_list formats ---
373384 if num_games == 3 :
374- env_id_list = ['PongNoFrameskip-v4 ' , 'MsPacmanNoFrameskip-v4 ' , 'SeaquestNoFrameskip-v4 ' ]
385+ env_id_list = ['ALE/Pong-v5 ' , 'ALE/MsPacman-v5 ' , 'ALE/Seaquest-v5 ' ]
375386 elif num_games == 8 :
376387 env_id_list = [
377- 'PongNoFrameskip-v4 ' , 'MsPacmanNoFrameskip-v4 ' , 'SeaquestNoFrameskip-v4 ' , 'BoxingNoFrameskip-v4 ' ,
378- 'AlienNoFrameskip-v4 ' , 'ChopperCommandNoFrameskip-v4 ' , 'HeroNoFrameskip-v4 ' , 'RoadRunnerNoFrameskip-v4 ' ,
388+ 'ALE/Pong-v5 ' , 'ALE/MsPacman-v5 ' , 'ALE/Seaquest-v5 ' , 'ALE/Boxing-v5 ' ,
389+ 'ALE/Alien-v5 ' , 'ALE/ChopperCommand-v5 ' , 'ALE/Hero-v5 ' , 'ALE/RoadRunner-v5 ' ,
379390 ]
380391 elif num_games == 26 :
381392 env_id_list = [
382- 'PongNoFrameskip-v4 ' , 'MsPacmanNoFrameskip-v4 ' , 'SeaquestNoFrameskip-v4 ' , 'BoxingNoFrameskip-v4 ' ,
383- 'AlienNoFrameskip-v4 ' , 'ChopperCommandNoFrameskip-v4 ' , 'HeroNoFrameskip-v4 ' , 'RoadRunnerNoFrameskip-v4 ' ,
384- 'AmidarNoFrameskip-v4 ' , 'AssaultNoFrameskip-v4 ' , 'AsterixNoFrameskip-v4 ' , 'BankHeistNoFrameskip-v4 ' ,
385- 'BattleZoneNoFrameskip-v4 ' , 'CrazyClimberNoFrameskip-v4 ' , 'DemonAttackNoFrameskip-v4 ' , 'FreewayNoFrameskip-v4 ' ,
386- 'FrostbiteNoFrameskip-v4 ' , 'GopherNoFrameskip-v4 ' , 'JamesbondNoFrameskip-v4 ' , 'KangarooNoFrameskip-v4 ' ,
387- 'KrullNoFrameskip-v4 ' , 'KungFuMasterNoFrameskip-v4 ' , 'PrivateEyeNoFrameskip-v4 ' , 'UpNDownNoFrameskip-v4 ' ,
388- 'QbertNoFrameskip-v4 ' , 'BreakoutNoFrameskip-v4 ' ,
393+ 'ALE/Pong-v5 ' , 'ALE/MsPacman-v5 ' , 'ALE/Seaquest-v5 ' , 'ALE/Boxing-v5 ' ,
394+ 'ALE/Alien-v5 ' , 'ALE/ChopperCommand-v5 ' , 'ALE/Hero-v5 ' , 'ALE/RoadRunner-v5 ' ,
395+ 'ALE/Amidar-v5 ' , 'ALE/Assault-v5 ' , 'ALE/Asterix-v5 ' , 'ALE/BankHeist-v5 ' ,
396+ 'ALE/BattleZone-v5 ' , 'ALE/CrazyClimber-v5 ' , 'ALE/DemonAttack-v5 ' , 'ALE/Freeway-v5 ' ,
397+ 'ALE/Frostbite-v5 ' , 'ALE/Gopher-v5 ' , 'ALE/Jamesbond-v5 ' , 'ALE/Kangaroo-v5 ' ,
398+ 'ALE/Krull-v5 ' , 'ALE/KungFuMaster-v5 ' , 'ALE/PrivateEye-v5 ' , 'ALE/UpNDown-v5 ' ,
399+ 'ALE/Qbert-v5 ' , 'ALE/Breakout-v5 ' ,
389400 ]
390401 else :
391402 raise ValueError (f"Unsupported number of environments: { num_games } " )
0 commit comments