@@ -44,6 +44,16 @@ class MetaFrozenLake(gym.Env):
4444 Defaults to True.
4545 slip_chance: Probability that a movement action is replaced by
4646 a no-op (agent stays in place). Defaults to 0.0.
47+ use_truncation_for_k_limit: If True, signal the k-episode limit
48+ with ``(terminated=False, truncated=True)`` instead of the
49+ default ``(terminated=True, truncated=False)``. Routing the
50+ meta-trial cap through ``truncated`` keeps ``Batch.dones``
51+ (and therefore the n-step bootstrap) live at the k-limit, so
52+ the value function learns the true infinite-horizon
53+ discounted return rather than collapsing to zero at the
54+ training horizon. Essential when ``show_k_progress=False``
55+ and the test rollout uses a different ``k_episodes`` than
56+ training. Defaults to False.
4757 """
4858
4959 def __init__ (
@@ -55,6 +65,7 @@ def __init__(
5565 max_episode_steps : int | None = None ,
5666 show_k_progress : bool = True ,
5767 slip_chance : float = 0.0 ,
68+ use_truncation_for_k_limit : bool = False ,
5869 ):
5970 self .size = size
6071 self .k_episodes = k_episodes
@@ -70,6 +81,7 @@ def __init__(
7081 )
7182 self .show_k_progress = show_k_progress
7283 self .slip_chance = slip_chance
84+ self .use_truncation_for_k_limit = use_truncation_for_k_limit
7385 self .reset ()
7486
7587 def reset (self , * args , ** kwargs ):
@@ -150,8 +162,12 @@ def step(self, action):
150162 else :
151163 next_state , info = self .make_obs (False ), {}
152164
153- terminated = self .current_k >= self .k_episodes
154- return next_state , reward , terminated , False , info
165+ end_of_meta_trial = self .current_k >= self .k_episodes
166+ if self .use_truncation_for_k_limit :
167+ terminated , truncated = False , end_of_meta_trial
168+ else :
169+ terminated , truncated = end_of_meta_trial , False
170+ return next_state , reward , terminated , truncated , info
155171
156172 def render (self , * args , ** kwargs ):
157173 render_map = copy .deepcopy (self .active_map )
@@ -177,7 +193,11 @@ class RoomKeyDoor(gym.Env):
177193 meta_rollout_horizon: The agent has this many timsteps to adapt to
178194 each world layout. The best solution is to infer the key and door locations
179195 and then solve the task as many times as possible within this time limit.
180- Defaults to 500.
196+ Defaults to 500. Ignored if k_episodes is set.
197+ k_episodes: If set, the meta-rollout lasts exactly this many episodes
198+ instead of a fixed number of timesteps. The effective maximum
199+ sequence length becomes k_episodes * max_episode_steps. Defaults
200+ to None (use meta_rollout_horizon).
181201 start_location: The starting location of the agent. Defaults to
182202 "random". Can also be set to a specific (x, y) coordinate.
183203 key_location: The location of the key. Defaults to "random". Can
@@ -186,6 +206,12 @@ class RoomKeyDoor(gym.Env):
186206 Can also be set to a specific (x, y) coordinate.
187207 randomize_actions: If True, the discrete action indices are
188208 randomly shuffled on each reset. Defaults to False.
209+ horizon_type: Either "finite" or "infinite". In "finite" mode, the
210+ normalized episode timestep is included in the observation and
211+ meta-done is signaled as terminated (the agent knows the horizon).
212+ In "infinite" mode, the timestep is excluded from the observation
213+ and meta-done is signaled as truncated (the agent does not know
214+ when the meta-rollout will end). Defaults to "infinite".
189215 """
190216
191217 def __init__ (
@@ -194,24 +220,56 @@ def __init__(
194220 size : int = 9 ,
195221 max_episode_steps : int = 50 ,
196222 meta_rollout_horizon : int = 500 ,
223+ k_episodes : int | None = None ,
197224 start_location : tuple [int , int ] | str = "random" ,
198225 key_location : tuple [int , int ] | str = "random" ,
199226 goal_location : tuple [int , int ] | str = "random" ,
200227 randomize_actions : bool = False ,
228+ horizon_type : str = "infinite" ,
201229 ):
230+ assert horizon_type in (
231+ "finite" ,
232+ "infinite" ,
233+ ), f"horizon_type must be 'finite' or 'infinite', got '{ horizon_type } '"
202234 self .dark = dark
203235 self .size = size
204236 self .H = max_episode_steps
205- self .H_meta = meta_rollout_horizon
206- self .observation_space = gym .spaces .Box (
207- low = 0.0 , high = 1.0 , shape = (4 if self .dark else 8 ,)
237+ self .k_episodes = k_episodes
238+ if k_episodes is not None :
239+ self .H_meta = k_episodes * max_episode_steps
240+ else :
241+ self .H_meta = meta_rollout_horizon
242+ self ._meta_rollout_horizon = meta_rollout_horizon
243+ self .horizon_type = horizon_type
244+ n_actions = 5
245+ time_dim = 1 if self .horizon_type == "finite" else 0
246+ obs_dim = (3 if self .dark else 7 ) + time_dim
247+ max_k = (
248+ k_episodes
249+ if k_episodes is not None
250+ else meta_rollout_horizon // max_episode_steps
208251 )
209- self .action_space = gym .spaces .Discrete (5 )
252+ self .observation_space = gym .spaces .Dict (
253+ {
254+ "observed" : gym .spaces .Box (low = 0.0 , high = 1.0 , shape = (obs_dim ,)),
255+ "episode_id" : gym .spaces .Box (0 , max_k , shape = (), dtype = np .int32 ),
256+ "prev_action" : gym .spaces .Box (low = 0.0 , high = 1.0 , shape = (n_actions ,)),
257+ "prev_reward" : gym .spaces .Box (low = - np .inf , high = np .inf , shape = (1 ,)),
258+ }
259+ )
260+ self .action_space = gym .spaces .Discrete (n_actions )
210261 self .goal_location = goal_location
211262 self .key_location = key_location
212263 self .start_location = start_location
213264 self .randomize_actions = randomize_actions
214265
266+ @property
267+ def meta_horizon (self ) -> int :
268+ """Max trajectory length including soft reset steps between episodes."""
269+ if self .k_episodes is not None :
270+ return self .k_episodes * (self .H + 1 ) - 1
271+ return self ._meta_rollout_horizon
272+
215273 def reset_same_task (self ):
216274 self .pos = self .start
217275 self .episode_time = 0
@@ -220,6 +278,10 @@ def reset_same_task(self):
220278 def reset (self , * args , ** kwargs ):
221279 self .generate_task ()
222280 self .global_time = 0
281+ self .episode_number = 0
282+ self .episode_return = 0.0
283+ self ._prev_action = np .zeros (self .action_space .n , dtype = np .float32 )
284+ self ._prev_reward = np .array ([0.0 ], dtype = np .float32 )
223285 self .reset_same_task ()
224286 self .reset_next_step = False
225287 return self .obs (), {}
@@ -246,7 +308,17 @@ def generate_task(self):
246308
247309 def step (self , action : int ):
248310 self .global_time += 1
311+ info = {}
312+
249313 if self .reset_next_step :
314+ info [f"{ AMAGO_ENV_LOG_PREFIX } Episode { self .episode_number } Return" ] = (
315+ self .episode_return
316+ )
317+ info [f"{ AMAGO_ENV_LOG_PREFIX } Episode { self .episode_number } Length" ] = (
318+ self .episode_time
319+ )
320+ self .episode_number += 1
321+ self .episode_return = 0.0
250322 self .reset_same_task ()
251323 self .reset_next_step = False
252324 reward = 0.0
@@ -262,19 +334,47 @@ def step(self, action: int):
262334 self .has_key = True
263335 if self .episode_time >= self .H :
264336 self .reset_next_step = True
265- metadone = self .global_time >= self .H_meta
266- return self .obs (), reward , metadone , metadone , {}
337+ self .episode_return += reward
338+
339+ action_onehot = np .zeros (self .action_space .n , dtype = np .float32 )
340+ action_onehot [action ] = 1.0
341+ self ._prev_action = action_onehot
342+ self ._prev_reward = np .array ([reward ], dtype = np .float32 )
343+
344+ if self .k_episodes is not None :
345+ completed = self .episode_number + (1 if self .reset_next_step else 0 )
346+ metadone = completed >= self .k_episodes
347+ else :
348+ metadone = self .global_time >= self .H_meta
349+ if metadone and self .reset_next_step :
350+ info [f"{ AMAGO_ENV_LOG_PREFIX } Episode { self .episode_number } Return" ] = (
351+ self .episode_return
352+ )
353+ info [f"{ AMAGO_ENV_LOG_PREFIX } Episode { self .episode_number } Length" ] = (
354+ self .episode_time
355+ )
356+ if self .horizon_type == "finite" :
357+ terminated , truncated = metadone , False
358+ else :
359+ terminated , truncated = False , metadone
360+ return self .obs (), reward , terminated , truncated , info
267361
268362 def obs (self ):
269363 x , y = self .pos
270364 norm = lambda j : float (j ) / self .size
271- # time and has_key keep this fully observed
272- base = [norm (x ), norm (y ), self .has_key , float (self .episode_time ) / self .H ]
365+ base = [norm (x ), norm (y ), self .has_key ]
366+ if self .horizon_type == "finite" :
367+ base .append (float (self .episode_time ) / self .H )
273368 if not self .dark :
274369 goal_x , goal_y = self .goal
275370 key_x , key_y = self .key
276371 base += [norm (goal_x ), norm (goal_y ), norm (key_x ), norm (key_y )]
277- return np .array (base , dtype = np .float32 )
372+ return {
373+ "observed" : np .array (base , dtype = np .float32 ),
374+ "episode_id" : np .int32 (self .episode_number ),
375+ "prev_action" : self ._prev_action .copy (),
376+ "prev_reward" : self ._prev_reward .copy (),
377+ }
278378
279379 def render (self , * args , ** kwargs ):
280380 img = [["." for _ in range (self .size )] for _ in range (self .size )]
0 commit comments