-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy path00_meta_frozen_lake.py
More file actions
139 lines (127 loc) · 4.77 KB
/
00_meta_frozen_lake.py
File metadata and controls
139 lines (127 loc) · 4.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from argparse import ArgumentParser
import amago
from amago.envs.builtin.toy_gym import MetaFrozenLake
from amago.envs import AMAGOEnv
from amago.loading import DiskTrajDataset
from amago import cli_utils
def add_cli(parser):
    """Register the Meta-Frozen-Lake experiment's command-line options.

    Every option is added to ``parser`` in place, and the same parser object
    is handed back so the call can be chained.

    Args:
        parser: An ``argparse.ArgumentParser`` (or any object exposing a
            compatible ``add_argument``) to extend.

    Returns:
        The ``parser`` that was passed in.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        # Sequence model used by the trajectory encoder.
        (
            "--seq_model",
            dict(
                type=str,
                choices=["ff", "transformer", "rnn", "mamba"],
                required=True,
            ),
        ),
        # Run bookkeeping / logging.
        ("--run_name", dict(type=str, required=True)),
        ("--buffer_dir", dict(type=str, required=True)),
        ("--log", dict(action="store_true")),
        ("--trials", dict(type=int, default=1)),
        # Environment layout and dynamics.
        ("--lake_size", dict(type=int, default=5)),
        ("--k_episodes", dict(type=int, default=10)),
        ("--hard_mode", dict(action="store_true")),
        ("--recover_mode", dict(action="store_true")),
        ("--slip_chance", dict(type=float, default=0.0)),
        (
            "--max_episode_steps",
            dict(
                type=int,
                default=None,
                help="Max steps per attempt. Default: N² (standard) or 2*N² (hard).",
            ),
        ),
        (
            "--hide_k_progress",
            dict(
                action="store_true",
                help="Hide current_k/k_episodes from observations (for length extrapolation tests).",
            ),
        ),
        # Training context length.
        (
            "--max_seq_len",
            dict(
                type=int,
                default=None,
                help="Training sequence length. Default: max_episode_steps * k_episodes (full trajectory).",
            ),
        ),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser
if __name__ == "__main__":
    parser = ArgumentParser()
    add_cli(parser)
    args = parser.parse_args()

    # wandb is only imported (and later finished) when logging is requested.
    if args.log:
        import wandb

    # MetaFrozenLake settings shared by the step-limit probe below and by every
    # env instance the experiment creates.
    lake_kwargs = dict(
        size=args.lake_size,
        k_episodes=args.k_episodes,
        hard_mode=args.hard_mode,
        recover_mode=args.recover_mode,
        max_episode_steps=args.max_episode_steps,
        show_k_progress=not args.hide_k_progress,
        slip_chance=args.slip_chance,
    )
    # --max_episode_steps defaults to None, so build one throwaway env and read
    # back the per-attempt step limit it actually ended up with.
    max_ep_steps = MetaFrozenLake(**lake_kwargs).max_episode_steps
    # A full rollout spans k_episodes attempts of at most max_ep_steps each.
    max_rollout_length = max_ep_steps * args.k_episodes
    # Train on the whole rollout unless --max_seq_len asks for something else
    # (note: a falsy value such as 0 also falls back to the full rollout).
    max_seq_len = args.max_seq_len or max_rollout_length

    # ``config`` is passed to each switch_* helper below and then handed to
    # cli_utils.use_config.
    config = {}
    # configure trajectory encoder (seq2seq memory model)
    traj_encoder_type = cli_utils.switch_traj_encoder(
        config,
        arch=args.seq_model,
        memory_size=128,
        layers=3,
    )
    # configure timestep encoder
    tstep_encoder_type = cli_utils.switch_tstep_encoder(
        config, arch="ff", n_layers=1, d_hidden=128, d_output=64, normalize_inputs=False
    )
    # we're using the default exploration strategy but being overly verbose about it for the example
    exploration_wrapper_type = cli_utils.switch_exploration(
        config,
        strategy="egreedy",
        eps_start=1.0,
        eps_end=0.05,
        steps_anneal=1_000_000,
        randomize_eps=True,
    )
    agent_type = cli_utils.switch_agent(config, "agent", tau=0.004)
    cli_utils.use_config(config)

    # The env name depends only on CLI args, so build it once rather than on
    # every env construction.
    env_name = (
        f"meta_frozen_lake_k{args.k_episodes}_{args.lake_size}x{args.lake_size}"
        + ("_hard" if args.hard_mode else "_easy")
        + ("_recover" if args.recover_mode else "_reset")
    )

    def make_env():
        """Create a fresh MetaFrozenLake instance wrapped as an AMAGOEnv."""
        return AMAGOEnv(MetaFrozenLake(**lake_kwargs), env_name=env_name)

    group_name = f"{args.run_name}_{args.seq_model}"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        # create a dataset on disk. envs will write finished episodes here
        dset = DiskTrajDataset(
            dset_root=args.buffer_dir, dset_name=run_name, dset_max_size=12_500
        )
        experiment = amago.Experiment(
            make_train_env=make_env,
            make_val_env=make_env,
            max_seq_len=max_seq_len,
            traj_save_len=max_rollout_length,
            dataset=dset,
            # save checkpoints alongside the buffer
            ckpt_base_dir=args.buffer_dir,
            agent_type=agent_type,
            exploration_wrapper_type=exploration_wrapper_type,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            run_name=run_name,
            dloader_workers=10,
            log_to_wandb=args.log,
            wandb_group_name=group_name,
            # hard mode gets a longer training budget
            epochs=700 if not args.hard_mode else 900,
            parallel_actors=32,
            train_timesteps_per_epoch=max_rollout_length,
            train_batches_per_epoch=1000,
            val_interval=20,
            val_timesteps_per_epoch=max_rollout_length * 2,
            ckpt_interval=200,
            env_mode="sync",
        )
        experiment.start()
        experiment.learn()
        experiment.evaluate_test(make_env, timesteps=10_000)
        # Clean up this trial's on-disk buffer before starting the next trial.
        experiment.delete_buffer_from_disk()
        # Close this trial's wandb run so the next trial starts a new one.
        if args.log:
            wandb.finish()