@@ -93,18 +93,22 @@ def wrap_for_brax_training(
9393 randomization_fn : Optional [
9494 Callable [[mjx .Model ], Tuple [mjx .Model , mjx .Model ]]
9595 ] = None ,
96+ full_reset : bool = False ,
9697) -> Wrapper :
9798 """Common wrapper pattern for all brax training agents.
9899
99100 Args:
100101 env: environment to be wrapped
101102 vision: whether the environment will be vision based
102- num_vision_envs: number of environments the renderer should generate,
103- should equal the number of batched envs
103+ num_vision_envs: number of environments the renderer should generate, should
104+ equal the number of batched envs
104105 episode_length: length of episode
105106 action_repeat: how many repeated actions to take per step
106107 randomization_fn: randomization function that produces a vectorized model
107108 and in_axes to vmap over
109+ full_reset: whether to call `env.reset` during `env.step` on done rather
110+ than resetting to a cached first state. Setting full_reset=True may
111+ increase wallclock time because it forces full resets to random states.
108112
109113 Returns:
110114 An environment that is wrapped with Episode and AutoReset wrappers. If the
@@ -118,24 +122,66 @@ def wrap_for_brax_training(
118122 else :
119123 env = BraxDomainRandomizationVmapWrapper (env , randomization_fn )
120124 env = brax_training .EpisodeWrapper (env , episode_length , action_repeat )
121- env = BraxAutoResetWrapper (env )
125+ env = BraxAutoResetWrapper (env , full_reset = full_reset )
122126 return env
123127
124128
125129class BraxAutoResetWrapper (Wrapper ):
126- """Automatically resets Brax envs that are done."""
130+ """Automatically resets Brax envs that are done.
131+
132+ If `full_reset` is disabled (default):
133+ * the environment will reset to a cached first state.
134+ * only data and obs are reset, not the environment info.
135+
136+ If `full_reset` is enabled:
137+ * the environment will call env.reset during env.step on done.
138+ * `full_reset` will thus incur a penalty in wallclock time depending on the
139+ complexity of the reset function.
140+ * info is fully reset, except for info under the key
141+ `AutoResetWrapper_preserve_info`, which is passed through from the prior
142+ step. This can be used for curriculum learning.
143+
144+ Attributes:
145+ env: The wrapped environment.
146+ full_reset: Whether to call `env.reset` during `env.step` on done.
147+ """
148+
149+ def __init__ (self , env : Any , full_reset : bool = False ):
150+ super ().__init__ (env )
151+ self ._full_reset = full_reset
152+ self ._info_key = 'AutoResetWrapper'
127153
128154 def reset (self , rng : jax .Array ) -> mjx_env .State :
129- state = self .env .reset (rng )
130- state .info ['first_state' ] = state .data
131- state .info ['first_obs' ] = state .obs
155+ rng_key = jax .vmap (jax .random .split )(rng )
156+ rng , key = rng_key [..., 0 ], rng_key [..., 1 ]
157+ state = self .env .reset (key )
158+ state .info [f'{ self ._info_key } _first_data' ] = state .data
159+ state .info [f'{ self ._info_key } _first_obs' ] = state .obs
160+ state .info [f'{ self ._info_key } _rng' ] = rng
161+ state .info [f'{ self ._info_key } _done_count' ] = jp .zeros (
162+ key .shape [:- 1 ], dtype = int
163+ )
132164 return state
133165
134166 def step (self , state : mjx_env .State , action : jax .Array ) -> mjx_env .State :
167+ # grab the reset state.
168+ reset_state = None
169+ rng_key = jax .vmap (jax .random .split )(state .info [f'{ self ._info_key } _rng' ])
170+ reset_rng , reset_key = rng_key [..., 0 ], rng_key [..., 1 ]
171+ if self ._full_reset :
172+ reset_state = self .reset (reset_key )
173+ reset_data = reset_state .data
174+ reset_obs = reset_state .obs
175+ else :
176+ reset_data = state .info [f'{ self ._info_key } _first_data' ]
177+ reset_obs = state .info [f'{ self ._info_key } _first_obs' ]
178+
135179 if 'steps' in state .info :
180+ # reset steps to 0 if done.
136181 steps = state .info ['steps' ]
137182 steps = jp .where (state .done , jp .zeros_like (steps ), steps )
138183 state .info .update (steps = steps )
184+
139185 state = state .replace (done = jp .zeros_like (state .done ))
140186 state = self .env .step (state , action )
141187
@@ -147,11 +193,25 @@ def where_done(x, y):
147193 done = jp .reshape (done , [x .shape [0 ]] + [1 ] * (len (x .shape ) - 1 ))
148194 return jp .where (done , x , y )
149195
150- data = jax .tree .map (
151- where_done , state .info ['first_state' ], state .data
152- )
153- obs = jax .tree .map (where_done , state .info ['first_obs' ], state .obs )
154- return state .replace (data = data , obs = obs )
196+ data = jax .tree .map (where_done , reset_data , state .data )
197+ obs = jax .tree .map (where_done , reset_obs , state .obs )
198+
199+ next_info = state .info
200+ done_count_key = f'{ self ._info_key } _done_count'
201+ if self ._full_reset and reset_state :
202+ next_info = jax .tree .map (where_done , reset_state .info , state .info )
203+ next_info [done_count_key ] = state .info [done_count_key ]
204+
205+ if 'steps' in next_info :
206+ next_info ['steps' ] = state .info ['steps' ]
207+ preserve_info_key = f'{ self ._info_key } _preserve_info'
208+ if preserve_info_key in next_info :
209+ next_info [preserve_info_key ] = state .info [preserve_info_key ]
210+
211+ next_info [done_count_key ] += state .done .astype (int )
212+ next_info [f'{ self ._info_key } _rng' ] = reset_rng
213+
214+ return state .replace (data = data , obs = obs , info = next_info )
155215
156216
157217class BraxDomainRandomizationVmapWrapper (Wrapper ):
0 commit comments