@@ -19,11 +19,27 @@ def add_cli(parser):
1919 parser .add_argument ("--log" , action = "store_true" )
2020 parser .add_argument ("--trials" , type = int , default = 1 )
2121 parser .add_argument ("--lake_size" , type = int , default = 5 )
22- parser .add_argument ("--k_episodes" , type = int , default = 15 )
22+ parser .add_argument ("--k_episodes" , type = int , default = 10 )
2323 parser .add_argument ("--hard_mode" , action = "store_true" )
2424 parser .add_argument ("--recover_mode" , action = "store_true" )
25- parser .add_argument ("--max_rollout_length" , type = int , default = 512 )
26- parser .add_argument ("--max_seq_len" , type = int , default = 512 )
25+ parser .add_argument ("--slip_chance" , type = float , default = 0.0 )
26+ parser .add_argument (
27+ "--max_episode_steps" ,
28+ type = int ,
29+ default = None ,
30+ help = "Max steps per attempt. Default: N² (standard) or 2*N² (hard)." ,
31+ )
32+ parser .add_argument (
33+ "--hide_k_progress" ,
34+ action = "store_true" ,
35+ help = "Hide current_k/k_episodes from observations (for length extrapolation tests)." ,
36+ )
37+ parser .add_argument (
38+ "--max_seq_len" ,
39+ type = int ,
40+ default = None ,
41+ help = "Training sequence length. Default: max_episode_steps * k_episodes (full trajectory)." ,
42+ )
2743 return parser
2844
2945
@@ -35,6 +51,19 @@ def add_cli(parser):
3551 if args .log :
3652 import wandb
3753
54+ lake_kwargs = dict (
55+ size = args .lake_size ,
56+ k_episodes = args .k_episodes ,
57+ hard_mode = args .hard_mode ,
58+ recover_mode = args .recover_mode ,
59+ max_episode_steps = args .max_episode_steps ,
60+ show_k_progress = not args .hide_k_progress ,
61+ slip_chance = args .slip_chance ,
62+ )
63+ max_ep_steps = MetaFrozenLake (** lake_kwargs ).max_episode_steps
64+ max_rollout_length = max_ep_steps * args .k_episodes
65+ max_seq_len = args .max_seq_len or max_rollout_length
66+
3867 config = {}
3968 # configure trajectory encoder (seq2seq memory model)
4069 traj_encoder_type = cli_utils .switch_traj_encoder (
@@ -47,7 +76,6 @@ def add_cli(parser):
4776 tstep_encoder_type = cli_utils .switch_tstep_encoder (
4877 config , arch = "ff" , n_layers = 1 , d_hidden = 128 , d_output = 64 , normalize_inputs = False
4978 )
50-
5179 # we're using the default exploration strategy but being overly verbose about it for the example
5280 exploration_wrapper_type = cli_utils .switch_exploration (
5381 config ,
@@ -70,28 +98,21 @@ def add_cli(parser):
7098 )
7199 # save checkpoints alongside the buffer
72100 ckpt_dir = args .buffer_dir
73-
74101 # wrap environment
75102 make_env = lambda : AMAGOEnv (
76- MetaFrozenLake (
77- k_episodes = args .k_episodes ,
78- size = args .lake_size ,
79- hard_mode = args .hard_mode ,
80- recover_mode = args .recover_mode ,
81- ),
103+ MetaFrozenLake (** lake_kwargs ),
82104 env_name = f"meta_frozen_lake_k{ args .k_episodes } _{ args .lake_size } x{ args .lake_size } "
83105 + ("_hard" if args .hard_mode else "_easy" )
84106 + ("_recover" if args .recover_mode else "_reset" ),
85107 )
86108
87- # create `Experiment`
88109 experiment = amago .Experiment (
89110 make_train_env = make_env ,
90111 make_val_env = make_env ,
91- max_seq_len = args . max_seq_len ,
92- traj_save_len = args . max_rollout_length ,
112+ max_seq_len = max_seq_len ,
113+ traj_save_len = max_rollout_length ,
93114 dataset = dset ,
94- ckpt_base_dir = ckpt_dir ,
115+ ckpt_base_dir = args . buffer_dir ,
95116 agent_type = agent_type ,
96117 exploration_wrapper_type = exploration_wrapper_type ,
97118 tstep_encoder_type = tstep_encoder_type ,
@@ -102,18 +123,17 @@ def add_cli(parser):
102123 wandb_group_name = group_name ,
103124 epochs = 700 if not args .hard_mode else 900 ,
104125 parallel_actors = 32 ,
105- train_timesteps_per_epoch = args . max_rollout_length ,
126+ train_timesteps_per_epoch = max_rollout_length ,
106127 train_batches_per_epoch = 1000 ,
107128 val_interval = 20 ,
108- val_timesteps_per_epoch = args . max_rollout_length * 2 ,
129+ val_timesteps_per_epoch = max_rollout_length * 2 ,
109130 ckpt_interval = 200 ,
110131 env_mode = "sync" ,
111132 )
112133
113- # start experiment (build envs, policies, etc.)
114134 experiment .start ()
115- # run training
116135 experiment .learn ()
117136 experiment .evaluate_test (make_env , timesteps = 10_000 )
118137 experiment .delete_buffer_from_disk ()
119- wandb .finish ()
138+ if args .log :
139+ wandb .finish ()
0 commit comments