Skip to content

Commit 8356c84

Browse files
btabacopybara-github
authored andcommitted
Copybara import of the project:
-- 1daeb57 by Baruch Tabanpour <[email protected]>: fix for rsl-rl training with new tensordict in rsl-rl-lib>=3.0.0 -- 7984057 by Baruch Tabanpour <[email protected]>: small fix -- d87cb04 by Baruch Tabanpour <[email protected]>: deps COPYBARA_INTEGRATE_REVIEW=#210 from btaba:btaba-patch-1 d87cb04 PiperOrigin-RevId: 807321828 Change-Id: I2c51892f818c304b2ed0f384b61fb97561e7e81c
1 parent 33fa9e5 commit 8356c84

File tree

4 files changed

+25
-14
lines changed

4 files changed

+25
-14
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ __pycache__
55
MUJOCO_LOG.TXT
66
mujoco_menagerie
77
checkpoints/
8+
logs

learning/train_rsl_rl.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,12 @@ def render_callback(_, state):
167167
# Build RSL-RL config
168168
train_cfg = get_rl_config(_ENV_NAME.value)
169169

170+
obs_size = raw_env.observation_size
171+
if isinstance(obs_size, dict):
172+
train_cfg.obs_groups = {"policy": ["state"], "critic": ["privileged_state"]}
173+
else:
174+
train_cfg.obs_groups = {"policy": ["state"], "critic": ["state"]}
175+
170176
# Overwrite default config with flags
171177
train_cfg.seed = _SEED.value
172178
train_cfg.run_name = exp_name

mujoco_playground/_src/wrapper_torch.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from collections import deque
1818
import functools
1919
import os
20+
from typing import Any
2021

2122
import jax
2223
import numpy as np
@@ -31,6 +32,10 @@
3132
torch = None
3233

3334
from mujoco_playground._src import wrapper
35+
try:
36+
from tensordict import TensorDict # pytype: disable=import-error
37+
except ImportError:
38+
TensorDict = None
3439

3540

3641
def _jax_to_torch(tensor):
@@ -158,8 +163,10 @@ def step(self, action):
158163
if self.asymmetric_obs:
159164
obs = _jax_to_torch(self.env_state.obs["state"])
160165
critic_obs = _jax_to_torch(self.env_state.obs["privileged_state"])
166+
obs = {"state": obs, "privileged_state": critic_obs}
161167
else:
162168
obs = _jax_to_torch(self.env_state.obs)
169+
obs = {"state": obs}
163170
reward = _jax_to_torch(self.env_state.reward)
164171
done = _jax_to_torch(self.env_state.done)
165172
info = self.env_state.info
@@ -187,6 +194,7 @@ def step(self, action):
187194
if k not in info_ret["log"]:
188195
info_ret["log"][k] = _jax_to_torch(v).float().mean().item()
189196

197+
obs = TensorDict(obs, batch_size=[self.num_envs])
190198
return obs, reward, done, info_ret
191199

192200
def reset(self):
@@ -195,23 +203,15 @@ def reset(self):
195203

196204
if self.asymmetric_obs:
197205
obs = _jax_to_torch(self.env_state.obs["state"])
198-
# critic_obs = jax_to_torch(self.env_state.obs["privileged_state"])
206+
critic_obs = _jax_to_torch(self.env_state.obs["privileged_state"])
207+
obs = {"state": obs, "privileged_state": critic_obs}
199208
else:
200209
obs = _jax_to_torch(self.env_state.obs)
201-
return obs
202-
203-
def reset_with_critic_obs(self):
204-
self.env_state = self.reset_fn(self.key_reset)
205-
obs = _jax_to_torch(self.env_state.obs["state"])
206-
critic_obs = _jax_to_torch(self.env_state.obs["privileged_state"])
207-
return obs, critic_obs
210+
obs = {"state": obs}
211+
return TensorDict(obs, batch_size=[self.num_envs])
208212

209213
def get_observations(self):
210-
if self.asymmetric_obs:
211-
obs, critic_obs = self.reset_with_critic_obs()
212-
return obs, {"observations": {"critic": critic_obs}}
213-
else:
214-
return self.reset(), {"observations": {}}
214+
return self.reset()
215215

216216
def render(self, mode="human"): # pylint: disable=unused-argument
217217
if self.render_callback is not None:

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ dependencies = [
3535
"orbax-checkpoint>=0.11.22",
3636
"tqdm",
3737
"warp-lang>=1.9.0.dev",
38-
"wandb",
3938
]
4039
keywords = ["mjx", "mujoco", "sim2real", "reinforcement learning"]
4140

@@ -75,9 +74,14 @@ dev = [
7574
"pylint",
7675
"pytest-xdist",
7776
]
77+
learning = [
78+
"rsl-rl-lib>=3.0.0",
79+
"wandb",
80+
]
7881
all = [
7982
"playground[dev]",
8083
"playground[notebooks]",
84+
"playground[learning]",
8185
]
8286

8387
[tool.hatch.metadata]

0 commit comments

Comments
 (0)