
Commit 417b030

btaba authored and copybara-github committed

Add full resets to AutoResetWrapper. #140 #179

PiperOrigin-RevId: 796491534
Change-Id: I187b492c92c2615ea9583587c9b0badf81322c62
1 parent 9a6e545 commit 417b030

File tree: 3 files changed (+156, -23 lines)

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -13,6 +13,9 @@ All notable changes to this project will be documented in this file.
 - Remove `mjx_env.init` in favor of `mjx_env.make_data` since `make_data`
   now requires an `MjModel` argument rather than an `mjx.Model` argument.
 - Add device to `mjx_env.make_data`, fixes #174.
+- Update AutoResetWrapper to allow full resets on done. Fixes #179. Also
+  provides a means for doing curriculum learning via
+  `state.info['AutoResetWrapper_done_count']`, see #140.
 
 ## [0.0.5] - 2025-06-23
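As context for the changelog entry: the new counter increments once per environment each time an episode ends, so a training setup can read it to scale task difficulty. A minimal sketch of that idea; the `curriculum_scale` helper and `max_count` threshold are illustrative, not part of this commit:

```python
import jax.numpy as jp


def curriculum_scale(state, max_count: int = 1000):
  """Hypothetical helper: map per-env episode counts to a difficulty in [0, 1]."""
  # Incremented by the wrapper once per env each time an episode finishes.
  done_count = state.info['AutoResetWrapper_done_count']
  return jp.clip(done_count / max_count, 0.0, 1.0)
```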

mujoco_playground/_src/wrapper.py

Lines changed: 72 additions & 12 deletions

@@ -93,18 +93,22 @@ def wrap_for_brax_training(
     randomization_fn: Optional[
         Callable[[mjx.Model], Tuple[mjx.Model, mjx.Model]]
     ] = None,
+    full_reset: bool = False,
 ) -> Wrapper:
   """Common wrapper pattern for all brax training agents.
 
   Args:
     env: environment to be wrapped
     vision: whether the environment will be vision based
-    num_vision_envs: number of environments the renderer should generate,
-      should equal the number of batched envs
+    num_vision_envs: number of environments the renderer should generate, should
+      equal the number of batched envs
     episode_length: length of episode
     action_repeat: how many repeated actions to take per step
     randomization_fn: randomization function that produces a vectorized model
       and in_axes to vmap over
+    full_reset: whether to call `env.reset` during `env.step` on done rather
+      than resetting to a cached first state. Setting full_reset=True may
+      increase wallclock time because it forces full resets to random states.
 
   Returns:
     An environment that is wrapped with Episode and AutoReset wrappers. If the
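A minimal usage sketch for the new flag, using the same CartpoleBalance environment as the tests below; the episode settings are placeholder values:

```python
from mujoco_playground._src import dm_control_suite
from mujoco_playground._src import wrapper

env = dm_control_suite.load('CartpoleBalance')
env = wrapper.wrap_for_brax_training(
    env,
    episode_length=1000,  # placeholder value
    action_repeat=1,
    full_reset=True,  # reset via env.reset on done instead of the cached first state
)
```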
@@ -118,24 +122,66 @@ def wrap_for_brax_training(
   else:
     env = BraxDomainRandomizationVmapWrapper(env, randomization_fn)
   env = brax_training.EpisodeWrapper(env, episode_length, action_repeat)
-  env = BraxAutoResetWrapper(env)
+  env = BraxAutoResetWrapper(env, full_reset=full_reset)
   return env
 
 
 class BraxAutoResetWrapper(Wrapper):
-  """Automatically resets Brax envs that are done."""
+  """Automatically resets Brax envs that are done.
+
+  If `full_reset` is disabled (default):
+    * the environment will reset to a cached first state.
+    * only data and obs are reset, not the environment info.
+
+  If `full_reset` is enabled:
+    * the environment will call env.reset during env.step on done.
+    * `full_reset` will thus incur a penalty in wallclock time depending on the
+      complexity of the reset function.
+    * info is fully reset, except for info under the key
+      `AutoResetWrapper_preserve_info`, which is passed through from the prior
+      step. This can be used for curriculum learning.
+
+  Attributes:
+    env: The wrapped environment.
+    full_reset: Whether to call `env.reset` during `env.step` on done.
+  """
+
+  def __init__(self, env: Any, full_reset: bool = False):
+    super().__init__(env)
+    self._full_reset = full_reset
+    self._info_key = 'AutoResetWrapper'
 
   def reset(self, rng: jax.Array) -> mjx_env.State:
-    state = self.env.reset(rng)
-    state.info['first_state'] = state.data
-    state.info['first_obs'] = state.obs
+    rng_key = jax.vmap(jax.random.split)(rng)
+    rng, key = rng_key[..., 0], rng_key[..., 1]
+    state = self.env.reset(key)
+    state.info[f'{self._info_key}_first_data'] = state.data
+    state.info[f'{self._info_key}_first_obs'] = state.obs
+    state.info[f'{self._info_key}_rng'] = rng
+    state.info[f'{self._info_key}_done_count'] = jp.zeros(
+        key.shape[:-1], dtype=int
+    )
     return state
 
   def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
+    # Grab the reset state.
+    reset_state = None
+    rng_key = jax.vmap(jax.random.split)(state.info[f'{self._info_key}_rng'])
+    reset_rng, reset_key = rng_key[..., 0], rng_key[..., 1]
+    if self._full_reset:
+      reset_state = self.reset(reset_key)
+      reset_data = reset_state.data
+      reset_obs = reset_state.obs
+    else:
+      reset_data = state.info[f'{self._info_key}_first_data']
+      reset_obs = state.info[f'{self._info_key}_first_obs']
+
     if 'steps' in state.info:
+      # Reset steps to 0 if done.
       steps = state.info['steps']
       steps = jp.where(state.done, jp.zeros_like(steps), steps)
      state.info.update(steps=steps)
+
     state = state.replace(done=jp.zeros_like(state.done))
     state = self.env.step(state, action)
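Per the docstring above, only values stored under `AutoResetWrapper_preserve_info` survive a full reset. A sketch modeled on the `DoneEnv` helper in the tests below, with an invented `difficulty` value, shows how an environment might stash curriculum state there:

```python
import jax.numpy as jp


class CurriculumEnv:
  """Illustrative env wrapper whose curriculum state survives full resets."""

  def __init__(self, env):
    self._env = env

  def reset(self, rng):
    state = self._env.reset(rng)
    # BraxAutoResetWrapper passes this key through from the prior step on done,
    # so it is not rebuilt when full_reset=True.
    state.info['AutoResetWrapper_preserve_info'] = {'difficulty': jp.zeros(())}
    return state

  def step(self, state, action):
    return self._env.step(state, action)
```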

@@ -147,11 +193,25 @@ def where_done(x, y):
       done = jp.reshape(done, [x.shape[0]] + [1] * (len(x.shape) - 1))
       return jp.where(done, x, y)
 
-    data = jax.tree.map(
-        where_done, state.info['first_state'], state.data
-    )
-    obs = jax.tree.map(where_done, state.info['first_obs'], state.obs)
-    return state.replace(data=data, obs=obs)
+    data = jax.tree.map(where_done, reset_data, state.data)
+    obs = jax.tree.map(where_done, reset_obs, state.obs)
+
+    next_info = state.info
+    done_count_key = f'{self._info_key}_done_count'
+    if self._full_reset and reset_state:
+      next_info = jax.tree.map(where_done, reset_state.info, state.info)
+      next_info[done_count_key] = state.info[done_count_key]
+
+      if 'steps' in next_info:
+        next_info['steps'] = state.info['steps']
+      preserve_info_key = f'{self._info_key}_preserve_info'
+      if preserve_info_key in next_info:
+        next_info[preserve_info_key] = state.info[preserve_info_key]
+
+    next_info[done_count_key] += state.done.astype(int)
+    next_info[f'{self._info_key}_rng'] = reset_rng
+
+    return state.replace(data=data, obs=obs, info=next_info)
 
 
 class BraxDomainRandomizationVmapWrapper(Wrapper):
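`where_done` reshapes the per-env done flags so they broadcast against every leaf of the state pytree. A standalone illustration of that masking, with made-up shapes (4 envs, a `(4, 3)` leaf):

```python
import jax.numpy as jp

done = jp.array([1.0, 0.0, 1.0, 0.0])  # per-env done flags
reset_leaf = jp.zeros((4, 3))          # stand-in for a reset qpos leaf
current_leaf = jp.ones((4, 3))         # stand-in for the current qpos leaf

# Reshape done to (4, 1) so it broadcasts over the trailing leaf dimensions.
mask = jp.reshape(done, [reset_leaf.shape[0]] + [1] * (reset_leaf.ndim - 1))
mixed = jp.where(mask, reset_leaf, current_leaf)
# Rows 0 and 2 come from reset_leaf; rows 1 and 3 keep current_leaf.
```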

mujoco_playground/_src/wrapper_test.py

Lines changed: 81 additions & 11 deletions

@@ -13,46 +13,116 @@
 # limitations under the License.
 # ==============================================================================
 """Tests for the wrapper module."""
+
 import functools
 
 from absl.testing import absltest
+from absl.testing import parameterized
+from brax.envs.wrappers import training as brax_training
 import jax
 import jax.numpy as jp
-import numpy as np
-
 from mujoco_playground._src import dm_control_suite
 from mujoco_playground._src import wrapper
+import numpy as np
 
 
-class WrapperTest(absltest.TestCase):
+class WrapperTest(parameterized.TestCase):
 
-  def test_auto_reset_wrapper(self):
+  @parameterized.named_parameters(
+      ('full_reset', True),
+      ('cache_reset', False),
+  )
+  def test_auto_reset_wrapper(self, full_reset):
+    """Tests the AutoResetWrapper."""
     class DoneEnv:
 
       def __init__(self, env):
        self._env = env
 
       def reset(self, key):
-        return self._env.reset(key)
+        state = self._env.reset(key)
+        state.info['AutoResetWrapper_preserve_info'] = 1
+        state.info['other_info'] = 1
+        return state
 
       def step(self, state, action):
        state = self._env.step(state, jp.ones_like(action))
        state = state.replace(done=action[0] > 0)
+        state.info['AutoResetWrapper_preserve_info'] = 2
+        state.info['other_info'] = 2
        return state
 
     env = wrapper.BraxAutoResetWrapper(
-        DoneEnv(dm_control_suite.load('CartpoleBalance'))
+        brax_training.VmapWrapper(
+            DoneEnv(dm_control_suite.load('CartpoleBalance'))
+        ),
+        full_reset=full_reset,
     )
 
     jit_reset = jax.jit(env.reset)
     jit_step = jax.jit(env.step)
-    state = jit_reset(jax.random.PRNGKey(0))
-    first_qpos = state.info['first_state'].qpos
+    state = jit_reset(jax.random.PRNGKey(0)[None])
+    first_qpos = state.data.qpos
 
-    state = jit_step(state, -jp.ones(env._env.action_size))
+    # First step should not be done.
+    state = jit_step(state, -jp.ones(env._env.action_size)[None])
+    np.testing.assert_allclose(state.info['AutoResetWrapper_done_count'], 0)
     self.assertGreater(np.linalg.norm(state.data.qpos - first_qpos), 1e-3)
-    state = jit_step(state, jp.ones(env._env.action_size))
-    np.testing.assert_allclose(state.data.qpos, first_qpos, atol=1e-6)
+    self.assertEqual(state.info['AutoResetWrapper_preserve_info'], 2)
+    self.assertEqual(state.info['other_info'], 2)
+
+    for i in range(1, 3):
+      state = jit_step(state, jp.ones(env._env.action_size)[None])
+      jax.tree.map(lambda x: x.block_until_ready(), state)
+      if full_reset:
+        self.assertTrue((state.data.qpos != first_qpos).all())
+      else:
+        np.testing.assert_allclose(state.data.qpos, first_qpos, atol=1e-6)
+      np.testing.assert_allclose(state.info['AutoResetWrapper_done_count'], i)
+      self.assertEqual(state.info['AutoResetWrapper_preserve_info'], 2)
+      expected_other_info = 1 if full_reset else 2
+      self.assertEqual(state.info['other_info'], expected_other_info)
+
+  @parameterized.named_parameters(
+      ('full_reset', True),
+      ('cache_reset', False),
+  )
+  def test_evalwrapper_with_reset(self, full_reset):
+    """Tests EvalWrapper with reset in the AutoResetWrapper."""
+    episode_length = 10
+    num_envs = 4
+
+    env = dm_control_suite.load('CartpoleBalance')
+    env = wrapper.wrap_for_brax_training(
+        env,
+        episode_length=episode_length,
+        full_reset=full_reset,
+    )
+    env = brax_training.EvalWrapper(env)
+
+    jit_reset = jax.jit(env.reset)
+    jit_step = jax.jit(env.step)
+
+    rng = jax.random.PRNGKey(0)
+    rng = jax.random.split(rng, num_envs)
+    state = jit_reset(rng)
+    first_obs = state.obs
+    action = jp.zeros((num_envs, env.action_size))
+
+    for _ in range(episode_length):
+      state = jit_step(state, action)
+
+    # All episodes should finish at episode_length.
+    avg_episode_length = state.info['eval_metrics'].episode_steps.mean()
+    np.testing.assert_allclose(avg_episode_length, episode_length, atol=1e-6)
+    active_episodes = state.info['eval_metrics'].active_episodes
+    self.assertTrue(np.all(active_episodes == 0))
+
+    np.testing.assert_array_equal(state.info['steps'], 10 * np.ones(num_envs))
+    if full_reset:
+      self.assertTrue((state.obs != first_obs).all())
+    else:
+      np.testing.assert_allclose(state.obs, first_obs, rtol=1e-6)
 
   def test_domain_randomization_wrapper(self):
     def randomization_fn(model, rng):
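One detail worth noting in the updated test: `reset` now expects a batch of PRNG keys (hence `jax.random.PRNGKey(0)[None]` for a single env), because the wrapper splits keys with `jax.vmap(jax.random.split)`. A small shape sketch of that convention:

```python
import jax

keys = jax.random.split(jax.random.PRNGKey(0), 4)  # batch of keys, shape (4, 2)
rng_key = jax.vmap(jax.random.split)(keys)         # shape (4, 2, 2)
rng, key = rng_key[..., 0], rng_key[..., 1]        # each (4, 2), as in reset()
```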
