Skip to content

Commit 54c25ab

Browse files
authored
Merge pull request #104 from UT-Austin-RPL/kd
key door
2 parents 0974781 + 0cbb0fc commit 54c25ab

2 files changed

Lines changed: 147 additions & 22 deletions

File tree

amago/envs/builtin/toy_gym.py

Lines changed: 112 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@ class MetaFrozenLake(gym.Env):
4444
Defaults to True.
4545
slip_chance: Probability that a movement action is replaced by
4646
a no-op (agent stays in place). Defaults to 0.0.
47+
use_truncation_for_k_limit: If True, signal the k-episode limit
48+
with ``(terminated=False, truncated=True)`` instead of the
49+
default ``(terminated=True, truncated=False)``. Routing the
50+
meta-trial cap through ``truncated`` keeps ``Batch.dones``
51+
(and therefore the n-step bootstrap) live at the k-limit, so
52+
the value function learns the true infinite-horizon
53+
discounted return rather than collapsing to zero at the
54+
training horizon. Essential when ``show_k_progress=False``
55+
and the test rollout uses a different ``k_episodes`` than
56+
training. Defaults to False.
4757
"""
4858

4959
def __init__(
@@ -55,6 +65,7 @@ def __init__(
5565
max_episode_steps: int | None = None,
5666
show_k_progress: bool = True,
5767
slip_chance: float = 0.0,
68+
use_truncation_for_k_limit: bool = False,
5869
):
5970
self.size = size
6071
self.k_episodes = k_episodes
@@ -70,6 +81,7 @@ def __init__(
7081
)
7182
self.show_k_progress = show_k_progress
7283
self.slip_chance = slip_chance
84+
self.use_truncation_for_k_limit = use_truncation_for_k_limit
7385
self.reset()
7486

7587
def reset(self, *args, **kwargs):
@@ -150,8 +162,12 @@ def step(self, action):
150162
else:
151163
next_state, info = self.make_obs(False), {}
152164

153-
terminated = self.current_k >= self.k_episodes
154-
return next_state, reward, terminated, False, info
165+
end_of_meta_trial = self.current_k >= self.k_episodes
166+
if self.use_truncation_for_k_limit:
167+
terminated, truncated = False, end_of_meta_trial
168+
else:
169+
terminated, truncated = end_of_meta_trial, False
170+
return next_state, reward, terminated, truncated, info
155171

156172
def render(self, *args, **kwargs):
157173
render_map = copy.deepcopy(self.active_map)
@@ -177,7 +193,11 @@ class RoomKeyDoor(gym.Env):
177193
meta_rollout_horizon: The agent has this many timsteps to adapt to
178194
each world layout. The best solution is to infer the key and door locations
179195
and then solve the task as many times as possible within this time limit.
180-
Defaults to 500.
196+
Defaults to 500. Ignored if k_episodes is set.
197+
k_episodes: If set, the meta-rollout lasts exactly this many episodes
198+
instead of a fixed number of timesteps. The effective maximum
199+
sequence length becomes k_episodes * max_episode_steps (plus one soft-reset step between consecutive episodes; see ``meta_horizon``). Defaults
200+
to None (use meta_rollout_horizon).
181201
start_location: The starting location of the agent. Defaults to
182202
"random". Can also be set to a specific (x, y) coordinate.
183203
key_location: The location of the key. Defaults to "random". Can
@@ -186,6 +206,12 @@ class RoomKeyDoor(gym.Env):
186206
Can also be set to a specific (x, y) coordinate.
187207
randomize_actions: If True, the discrete action indices are
188208
randomly shuffled on each reset. Defaults to False.
209+
horizon_type: Either "finite" or "infinite". In "finite" mode, the
210+
normalized episode timestep is included in the observation and
211+
meta-done is signaled as terminated (the agent knows the horizon).
212+
In "infinite" mode, the timestep is excluded from the observation
213+
and meta-done is signaled as truncated (the agent does not know
214+
when the meta-rollout will end). Defaults to "infinite".
189215
"""
190216

191217
def __init__(
@@ -194,24 +220,56 @@ def __init__(
194220
size: int = 9,
195221
max_episode_steps: int = 50,
196222
meta_rollout_horizon: int = 500,
223+
k_episodes: int | None = None,
197224
start_location: tuple[int, int] | str = "random",
198225
key_location: tuple[int, int] | str = "random",
199226
goal_location: tuple[int, int] | str = "random",
200227
randomize_actions: bool = False,
228+
horizon_type: str = "infinite",
201229
):
230+
assert horizon_type in (
231+
"finite",
232+
"infinite",
233+
), f"horizon_type must be 'finite' or 'infinite', got '{horizon_type}'"
202234
self.dark = dark
203235
self.size = size
204236
self.H = max_episode_steps
205-
self.H_meta = meta_rollout_horizon
206-
self.observation_space = gym.spaces.Box(
207-
low=0.0, high=1.0, shape=(4 if self.dark else 8,)
237+
self.k_episodes = k_episodes
238+
if k_episodes is not None:
239+
self.H_meta = k_episodes * max_episode_steps
240+
else:
241+
self.H_meta = meta_rollout_horizon
242+
self._meta_rollout_horizon = meta_rollout_horizon
243+
self.horizon_type = horizon_type
244+
n_actions = 5
245+
time_dim = 1 if self.horizon_type == "finite" else 0
246+
obs_dim = (3 if self.dark else 7) + time_dim
247+
max_k = (
248+
k_episodes
249+
if k_episodes is not None
250+
else meta_rollout_horizon // max_episode_steps
208251
)
209-
self.action_space = gym.spaces.Discrete(5)
252+
self.observation_space = gym.spaces.Dict(
253+
{
254+
"observed": gym.spaces.Box(low=0.0, high=1.0, shape=(obs_dim,)),
255+
"episode_id": gym.spaces.Box(0, max_k, shape=(), dtype=np.int32),
256+
"prev_action": gym.spaces.Box(low=0.0, high=1.0, shape=(n_actions,)),
257+
"prev_reward": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(1,)),
258+
}
259+
)
260+
self.action_space = gym.spaces.Discrete(n_actions)
210261
self.goal_location = goal_location
211262
self.key_location = key_location
212263
self.start_location = start_location
213264
self.randomize_actions = randomize_actions
214265

266+
@property
267+
def meta_horizon(self) -> int:
268+
"""Max trajectory length including soft reset steps between episodes."""
269+
if self.k_episodes is not None:
270+
return self.k_episodes * (self.H + 1) - 1
271+
return self._meta_rollout_horizon
272+
215273
def reset_same_task(self):
216274
self.pos = self.start
217275
self.episode_time = 0
@@ -220,6 +278,10 @@ def reset_same_task(self):
220278
def reset(self, *args, **kwargs):
221279
self.generate_task()
222280
self.global_time = 0
281+
self.episode_number = 0
282+
self.episode_return = 0.0
283+
self._prev_action = np.zeros(self.action_space.n, dtype=np.float32)
284+
self._prev_reward = np.array([0.0], dtype=np.float32)
223285
self.reset_same_task()
224286
self.reset_next_step = False
225287
return self.obs(), {}
@@ -246,7 +308,17 @@ def generate_task(self):
246308

247309
def step(self, action: int):
248310
self.global_time += 1
311+
info = {}
312+
249313
if self.reset_next_step:
314+
info[f"{AMAGO_ENV_LOG_PREFIX}Episode {self.episode_number} Return"] = (
315+
self.episode_return
316+
)
317+
info[f"{AMAGO_ENV_LOG_PREFIX}Episode {self.episode_number} Length"] = (
318+
self.episode_time
319+
)
320+
self.episode_number += 1
321+
self.episode_return = 0.0
250322
self.reset_same_task()
251323
self.reset_next_step = False
252324
reward = 0.0
@@ -262,19 +334,47 @@ def step(self, action: int):
262334
self.has_key = True
263335
if self.episode_time >= self.H:
264336
self.reset_next_step = True
265-
metadone = self.global_time >= self.H_meta
266-
return self.obs(), reward, metadone, metadone, {}
337+
self.episode_return += reward
338+
339+
action_onehot = np.zeros(self.action_space.n, dtype=np.float32)
340+
action_onehot[action] = 1.0
341+
self._prev_action = action_onehot
342+
self._prev_reward = np.array([reward], dtype=np.float32)
343+
344+
if self.k_episodes is not None:
345+
completed = self.episode_number + (1 if self.reset_next_step else 0)
346+
metadone = completed >= self.k_episodes
347+
else:
348+
metadone = self.global_time >= self.H_meta
349+
if metadone and self.reset_next_step:
350+
info[f"{AMAGO_ENV_LOG_PREFIX}Episode {self.episode_number} Return"] = (
351+
self.episode_return
352+
)
353+
info[f"{AMAGO_ENV_LOG_PREFIX}Episode {self.episode_number} Length"] = (
354+
self.episode_time
355+
)
356+
if self.horizon_type == "finite":
357+
terminated, truncated = metadone, False
358+
else:
359+
terminated, truncated = False, metadone
360+
return self.obs(), reward, terminated, truncated, info
267361

268362
def obs(self):
269363
x, y = self.pos
270364
norm = lambda j: float(j) / self.size
271-
# time and has_key keep this fully observed
272-
base = [norm(x), norm(y), self.has_key, float(self.episode_time) / self.H]
365+
base = [norm(x), norm(y), self.has_key]
366+
if self.horizon_type == "finite":
367+
base.append(float(self.episode_time) / self.H)
273368
if not self.dark:
274369
goal_x, goal_y = self.goal
275370
key_x, key_y = self.key
276371
base += [norm(goal_x), norm(goal_y), norm(key_x), norm(key_y)]
277-
return np.array(base, dtype=np.float32)
372+
return {
373+
"observed": np.array(base, dtype=np.float32),
374+
"episode_id": np.int32(self.episode_number),
375+
"prev_action": self._prev_action.copy(),
376+
"prev_reward": self._prev_reward.copy(),
377+
}
278378

279379
def render(self, *args, **kwargs):
280380
img = [["." for _ in range(self.size)] for _ in range(self.size)]

examples/05_dark_key_door.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99

1010
def add_cli(parser):
1111
parser.add_argument(
12-
"--meta_horizon",
12+
"--k_episodes",
1313
type=int,
14-
default=500,
15-
help="Total meta-adaptation timestep budget for the agent to explore the same room layout.",
14+
default=8,
15+
help="Number of episodes per meta-rollout. Effective sequence length = k_episodes * episode_length.",
1616
)
1717
parser.add_argument(
1818
"--room_size",
@@ -36,6 +36,11 @@ def add_cli(parser):
3636
action="store_true",
3737
help="Randomize the agent's action space to make the task harder.",
3838
)
39+
parser.add_argument(
40+
"--finite_horizon",
41+
action="store_true",
42+
help="Use finite-horizon mode: include time in observations and signal meta-done as terminated. Default is infinite-horizon (no time in obs, meta-done as truncated).",
43+
)
3944
return parser
4045

4146

@@ -47,22 +52,39 @@ def add_cli(parser):
4752

4853
config = {}
4954
tstep_encoder_type = cli_utils.switch_tstep_encoder(
50-
config, arch="ff", n_layers=2, d_hidden=128, d_output=64
55+
config,
56+
arch="ff",
57+
n_layers=2,
58+
d_hidden=128,
59+
d_output=64,
60+
specify_obs_keys=["observed", "prev_action", "prev_reward"],
61+
hide_rl2s=True,
62+
normalize_inputs=False,
5163
)
5264
traj_encoder_type = cli_utils.switch_traj_encoder(
5365
config,
5466
arch=args.traj_encoder,
5567
memory_size=args.memory_size,
5668
layers=args.memory_layers,
69+
pos_emb="rope",
5770
)
5871
agent_type = cli_utils.switch_agent(
5972
config, args.agent_type, reward_multiplier=100.0
6073
)
74+
horizon_type = "finite" if args.finite_horizon else "infinite"
75+
dummy_env = RoomKeyDoor(
76+
size=args.room_size,
77+
max_episode_steps=args.episode_length,
78+
k_episodes=args.k_episodes,
79+
horizon_type=horizon_type,
80+
)
81+
meta_horizon = dummy_env.meta_horizon
82+
args.timesteps_per_epoch = meta_horizon
6183
# the fancier exploration schedule mentioned in the appendix can help
6284
# when the domain is a true meta-RL problem and the "horizon" time limit
6385
# (above) is actually relevant for resetting the task.
6486
exploration_type = cli_utils.switch_exploration(
65-
config, "bilevel", steps_anneal=500_000, rollout_horizon=args.meta_horizon
87+
config, "bilevel", steps_anneal=500_000, rollout_horizon=meta_horizon
6688
)
6789
cli_utils.use_config(config, args.configs)
6890

@@ -73,11 +95,12 @@ def add_cli(parser):
7395
env=RoomKeyDoor(
7496
size=args.room_size,
7597
max_episode_steps=args.episode_length,
76-
meta_rollout_horizon=args.meta_horizon,
98+
k_episodes=args.k_episodes,
7799
dark=not args.light_room_observation,
78100
randomize_actions=args.randomize_actions,
101+
horizon_type=horizon_type,
79102
),
80-
env_name=f"Dark-Key-To-Door-{args.room_size}x{args.room_size}",
103+
env_name=f"Dark-Key-To-Door-{args.room_size}x{args.room_size}-{horizon_type}",
81104
)
82105
experiment = cli_utils.create_experiment_from_cli(
83106
args,
@@ -86,12 +109,14 @@ def add_cli(parser):
86109
traj_encoder_type=traj_encoder_type,
87110
make_train_env=make_train_env,
88111
make_val_env=make_train_env,
89-
max_seq_len=args.meta_horizon,
90-
traj_save_len=args.meta_horizon,
112+
max_seq_len=meta_horizon,
113+
traj_save_len=meta_horizon * 10,
91114
group_name=group_name,
92115
run_name=run_name,
93-
val_timesteps_per_epoch=args.meta_horizon * 4,
116+
val_timesteps_per_epoch=meta_horizon * 4,
94117
exploration_wrapper_type=exploration_type,
118+
stagger_traj_file_lengths=False,
119+
wandb_project="z-room-key-door",
95120
)
96121
experiment = cli_utils.switch_async_mode(experiment, args.mode)
97122
experiment.start()

0 commit comments

Comments
 (0)