Skip to content

Commit 9b45b8c

Browse files
committed
update frozen lake
1 parent 441d1ed commit 9b45b8c

2 files changed

Lines changed: 74 additions & 27 deletions

File tree

amago/envs/builtin/toy_gym.py

Lines changed: 34 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,14 @@ class MetaFrozenLake(gym.Env):
3636
recover_mode: If False, falling through the ice terminates the
3737
episode. If True, the agent is allowed to recover to its
3838
previous position but receives a penalty. Defaults to False.
39+
max_episode_steps: Maximum steps per attempt before a forced
40+
soft reset. Defaults to N² (standard) or 2*N² (hard), divided by (1 - slip_chance) and truncated to an int when slip_chance > 0.
41+
show_k_progress: If True, include current_k / k_episodes in
42+
observations. Set to False to hide trial progress (useful
43+
for testing length extrapolation with different k values).
44+
Defaults to True.
45+
slip_chance: Probability that a movement action is replaced by
46+
a no-op (agent stays in place). Defaults to 0.0.
3947
"""
4048

4149
def __init__(
@@ -44,13 +52,24 @@ def __init__(
4452
k_episodes: int = 10,
4553
hard_mode: bool = False,
4654
recover_mode: bool = False,
55+
max_episode_steps: int | None = None,
56+
show_k_progress: bool = True,
57+
slip_chance: float = 0.0,
4758
):
4859
self.size = size
4960
self.k_episodes = k_episodes
5061
self.action_space = gym.spaces.Discrete(5)
5162
self.observation_space = gym.spaces.Box(shape=(4,), low=0.0, high=1.0)
5263
self.hard_mode = hard_mode
5364
self.recover_mode = recover_mode
65+
base_steps = size * size * (2 if hard_mode else 1)
66+
if slip_chance > 0:
67+
base_steps = int(base_steps / (1.0 - slip_chance))
68+
self.max_episode_steps = (
69+
max_episode_steps if max_episode_steps is not None else base_steps
70+
)
71+
self.show_k_progress = show_k_progress
72+
self.slip_chance = slip_chance
5473
self.reset()
5574

5675
def reset(self, *args, **kwargs):
@@ -69,30 +88,30 @@ def make_obs(self, reset_signal: bool):
6988
y = min(max(self.y + random.choice([-1, 0, 1]), 0), self.size - 1)
7089
else:
7190
x, y = self.x, self.y
91+
k_obs = self.current_k / self.k_episodes if self.show_k_progress else 0.0
7292
return np.array(
73-
[
74-
x / self.size,
75-
y / self.size,
76-
reset_signal,
77-
self.current_k / self.k_episodes,
78-
],
93+
[x / self.size, y / self.size, reset_signal, k_obs],
7994
dtype=np.float32,
8095
)
8196

8297
def soft_reset(self):
8398
self.active_map = copy.deepcopy(self.current_map)
8499
self.x, self.y = 0, 0
100+
self.episode_steps = 0
85101
obs = self.make_obs(reset_signal=True)
86102
return obs, {}
87103

88104
def step(self, action):
89105
assert self.action_space.contains(action)
106+
self.episode_steps += 1
107+
if self.slip_chance > 0 and action != 0 and random.random() < self.slip_chance:
108+
action = 0
90109
move_x, move_y = self.action_mapping[action]
91110
next_x = max(min(self.x + move_x, self.size - 1), 0)
92111
next_y = max(min(self.y + move_y, self.size - 1), 0)
93112

94113
if (
95-
(self.x, self.y) != (next_y, next_y)
114+
(self.x, self.y) != (next_x, next_y)
96115
and self.hard_mode
97116
and random.random() < 0.33
98117
):
@@ -115,6 +134,14 @@ def step(self, action):
115134
self.x = next_x
116135
self.y = next_y
117136

137+
timed_out = (
138+
not soft_reset
139+
and self.max_episode_steps is not None
140+
and self.episode_steps >= self.max_episode_steps
141+
)
142+
if timed_out:
143+
soft_reset = True
144+
118145
if soft_reset:
119146
next_state, info = self.soft_reset()
120147
success = on == "G"

examples/00_meta_frozen_lake.py

Lines changed: 40 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -19,11 +19,27 @@ def add_cli(parser):
1919
parser.add_argument("--log", action="store_true")
2020
parser.add_argument("--trials", type=int, default=1)
2121
parser.add_argument("--lake_size", type=int, default=5)
22-
parser.add_argument("--k_episodes", type=int, default=15)
22+
parser.add_argument("--k_episodes", type=int, default=10)
2323
parser.add_argument("--hard_mode", action="store_true")
2424
parser.add_argument("--recover_mode", action="store_true")
25-
parser.add_argument("--max_rollout_length", type=int, default=512)
26-
parser.add_argument("--max_seq_len", type=int, default=512)
25+
parser.add_argument("--slip_chance", type=float, default=0.0)
26+
parser.add_argument(
27+
"--max_episode_steps",
28+
type=int,
29+
default=None,
30+
help="Max steps per attempt. Default: N² (standard) or 2*N² (hard).",
31+
)
32+
parser.add_argument(
33+
"--hide_k_progress",
34+
action="store_true",
35+
help="Hide current_k/k_episodes from observations (for length extrapolation tests).",
36+
)
37+
parser.add_argument(
38+
"--max_seq_len",
39+
type=int,
40+
default=None,
41+
help="Training sequence length. Default: max_episode_steps * k_episodes (full trajectory).",
42+
)
2743
return parser
2844

2945

@@ -35,6 +51,19 @@ def add_cli(parser):
3551
if args.log:
3652
import wandb
3753

54+
lake_kwargs = dict(
55+
size=args.lake_size,
56+
k_episodes=args.k_episodes,
57+
hard_mode=args.hard_mode,
58+
recover_mode=args.recover_mode,
59+
max_episode_steps=args.max_episode_steps,
60+
show_k_progress=not args.hide_k_progress,
61+
slip_chance=args.slip_chance,
62+
)
63+
max_ep_steps = MetaFrozenLake(**lake_kwargs).max_episode_steps
64+
max_rollout_length = max_ep_steps * args.k_episodes
65+
max_seq_len = args.max_seq_len or max_rollout_length
66+
3867
config = {}
3968
# configure trajectory encoder (seq2seq memory model)
4069
traj_encoder_type = cli_utils.switch_traj_encoder(
@@ -47,7 +76,6 @@ def add_cli(parser):
4776
tstep_encoder_type = cli_utils.switch_tstep_encoder(
4877
config, arch="ff", n_layers=1, d_hidden=128, d_output=64, normalize_inputs=False
4978
)
50-
5179
# we're using the default exploration strategy but being overly verbose about it for the example
5280
exploration_wrapper_type = cli_utils.switch_exploration(
5381
config,
@@ -70,28 +98,21 @@ def add_cli(parser):
7098
)
7199
# save checkpoints alongside the buffer
72100
ckpt_dir = args.buffer_dir
73-
74101
# wrap environment
75102
make_env = lambda: AMAGOEnv(
76-
MetaFrozenLake(
77-
k_episodes=args.k_episodes,
78-
size=args.lake_size,
79-
hard_mode=args.hard_mode,
80-
recover_mode=args.recover_mode,
81-
),
103+
MetaFrozenLake(**lake_kwargs),
82104
env_name=f"meta_frozen_lake_k{args.k_episodes}_{args.lake_size}x{args.lake_size}"
83105
+ ("_hard" if args.hard_mode else "_easy")
84106
+ ("_recover" if args.recover_mode else "_reset"),
85107
)
86108

87-
# create `Experiment`
88109
experiment = amago.Experiment(
89110
make_train_env=make_env,
90111
make_val_env=make_env,
91-
max_seq_len=args.max_seq_len,
92-
traj_save_len=args.max_rollout_length,
112+
max_seq_len=max_seq_len,
113+
traj_save_len=max_rollout_length,
93114
dataset=dset,
94-
ckpt_base_dir=ckpt_dir,
115+
ckpt_base_dir=args.buffer_dir,
95116
agent_type=agent_type,
96117
exploration_wrapper_type=exploration_wrapper_type,
97118
tstep_encoder_type=tstep_encoder_type,
@@ -102,18 +123,17 @@ def add_cli(parser):
102123
wandb_group_name=group_name,
103124
epochs=700 if not args.hard_mode else 900,
104125
parallel_actors=32,
105-
train_timesteps_per_epoch=args.max_rollout_length,
126+
train_timesteps_per_epoch=max_rollout_length,
106127
train_batches_per_epoch=1000,
107128
val_interval=20,
108-
val_timesteps_per_epoch=args.max_rollout_length * 2,
129+
val_timesteps_per_epoch=max_rollout_length * 2,
109130
ckpt_interval=200,
110131
env_mode="sync",
111132
)
112133

113-
# start experiment (build envs, policies, etc.)
114134
experiment.start()
115-
# run training
116135
experiment.learn()
117136
experiment.evaluate_test(make_env, timesteps=10_000)
118137
experiment.delete_buffer_from_disk()
119-
wandb.finish()
138+
if args.log:
139+
wandb.finish()

0 commit comments

Comments
 (0)