-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy path05_dark_key_door.py
More file actions
128 lines (120 loc) · 4.5 KB
/
05_dark_key_door.py
File metadata and controls
128 lines (120 loc) · 4.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from argparse import ArgumentParser
import wandb
from amago.envs.builtin.toy_gym import RoomKeyDoor
from amago.envs import AMAGOEnv
from amago import cli_utils
def add_cli(parser):
    """Register the Dark Key-To-Door example's command-line options on ``parser``.

    Registration order is: the integer options ``--k_episodes``,
    ``--room_size`` and ``--episode_length``, then the boolean switches
    ``--light_room_observation``, ``--randomize_actions`` and
    ``--finite_horizon`` (each defaults to False).

    Args:
        parser: An ``argparse.ArgumentParser`` (or any object exposing a
            compatible ``add_argument``) that is extended in place.

    Returns:
        The same ``parser`` object, to allow chaining.
    """
    # (flag, default, help) for the integer-valued options.
    integer_options = (
        (
            "--k_episodes",
            8,
            "Number of episodes per meta-rollout. Effective sequence length = k_episodes * episode_length.",
        ),
        (
            "--room_size",
            8,
            "Size of the room. Exploration is sparse and difficulty scales quickly with room size.",
        ),
        (
            "--episode_length",
            50,
            "Maximum length of a single episode in the environment.",
        ),
    )
    # (flag, help) for the on/off switches.
    boolean_switches = (
        (
            "--light_room_observation",
            "Demonstrate how meta-RL relies on partial observability by revealing the goal location as part of the observation. This version of the environment can be solved without memory!",
        ),
        (
            "--randomize_actions",
            "Randomize the agent's action space to make the task harder.",
        ),
        (
            "--finite_horizon",
            "Use finite-horizon mode: include time in observations and signal meta-done as terminated. Default is infinite-horizon (no time in obs, meta-done as truncated).",
        ),
    )

    for flag, default_value, help_text in integer_options:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)
    for flag, help_text in boolean_switches:
        parser.add_argument(flag, action="store_true", help=help_text)
    return parser
if __name__ == "__main__":
    # Command line = AMAGO's shared options + this example's env options.
    parser = ArgumentParser()
    cli_utils.add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # `config` is handed to every cli_utils.switch_* helper below and finally
    # to cli_utils.use_config together with args.configs.
    config = {}

    # Per-timestep encoder: a small feed-forward net over the listed obs keys.
    tstep_encoder_type = cli_utils.switch_tstep_encoder(
        config,
        arch="ff",
        n_layers=2,
        d_hidden=128,
        d_output=64,
        specify_obs_keys=["observed", "prev_action", "prev_reward"],
        hide_rl2s=True,
        normalize_inputs=False,
    )
    # Trajectory (memory) encoder chosen from the CLI.
    traj_encoder_type = cli_utils.switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
        pos_emb="rope",
    )
    agent_type = cli_utils.switch_agent(
        config, args.agent_type, reward_multiplier=100.0
    )

    horizon_type = "finite" if args.finite_horizon else "infinite"

    # This env instance is only used to read `meta_horizon`; the envs used for
    # training/evaluation are built by `make_train_env` below.
    dummy_env = RoomKeyDoor(
        size=args.room_size,
        max_episode_steps=args.episode_length,
        k_episodes=args.k_episodes,
        horizon_type=horizon_type,
    )
    meta_horizon = dummy_env.meta_horizon
    # NOTE: overrides any --timesteps_per_epoch value given on the command line.
    args.timesteps_per_epoch = meta_horizon

    # The fancier exploration schedule mentioned in the appendix can help when
    # the domain is a true meta-RL problem and the meta-rollout horizon
    # computed above is actually relevant for resetting the task.
    exploration_type = cli_utils.switch_exploration(
        config, "bilevel", steps_anneal=500_000, rollout_horizon=meta_horizon
    )
    cli_utils.use_config(config, args.configs)

    def make_train_env():
        """Build one wrapped Key-To-Door env from the parsed CLI args.

        Used for training, validation and the final test evaluation. It
        depends only on `args` and `horizon_type`, so one definition serves
        every trial.
        """
        return AMAGOEnv(
            env=RoomKeyDoor(
                size=args.room_size,
                max_episode_steps=args.episode_length,
                k_episodes=args.k_episodes,
                dark=not args.light_room_observation,
                randomize_actions=args.randomize_actions,
                horizon_type=horizon_type,
            ),
            env_name=f"Dark-Key-To-Door-{args.room_size}x{args.room_size}-{horizon_type}",
        )

    group_name = f"{args.run_name}_dark_key_door"
    for trial in range(args.trials):
        run_name = f"{group_name}_trial_{trial}"
        experiment = cli_utils.create_experiment_from_cli(
            args,
            agent_type=agent_type,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            make_train_env=make_train_env,
            make_val_env=make_train_env,
            # sequence length tied to the meta-rollout horizon
            # (presumably k_episodes * episode_length -- see --k_episodes help)
            max_seq_len=meta_horizon,
            traj_save_len=meta_horizon * 10,
            group_name=group_name,
            run_name=run_name,
            val_timesteps_per_epoch=meta_horizon * 4,
            exploration_wrapper_type=exploration_type,
            stagger_traj_file_lengths=False,
            wandb_project="z-room-key-door",
        )
        experiment = cli_utils.switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_train_env, timesteps=20_000, render=False)
        experiment.delete_buffer_from_disk()
        wandb.finish()