Skip to content

Commit 1afeee3

Browse files
committed
Add temporary testing files
1 parent 50f0640 commit 1afeee3

15 files changed

Lines changed: 763 additions & 420 deletions

File tree

config/hardware/2080_rtx.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# @package _global_
2+
3+
hydra:
4+
launcher:
5+
additional_parameters: { "gpus": "rtx_2080_ti:1", "account": "ls_krausea" }
6+

config/hardware/3090_rtx.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# @package _global_
2+
3+
hydra:
4+
launcher:
5+
additional_parameters: { "gpus": "rtx_3090:1", "account": "ls_krausea" }
6+

config/hardware/4090_rtx.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# @package _global_
2+
3+
hydra:
4+
launcher:
5+
additional_parameters: { "gpus": "rtx_4090:1", "account": "ls_krausea" }
6+

config/hydra/launcher/slurm.yaml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
submitit_folder: ${hydra.sweep.dir}/.submitit/%j
2+
timeout_min: 30
3+
cpus_per_task: 10
4+
tasks_per_node: 1
5+
mem_gb: null
6+
nodes: 1
7+
name: ${hydra.job.name}
8+
_target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher
9+
mem_per_gpu: null
10+
mem_per_cpu: 10240
11+
account: ls_krausea
12+
additional_parameters: {"gpus": "rtx_4090:1", "account": "ls_krausea"}
13+
array_parallelism: 256
14+
max_num_timeout: 100
15+
setup:
16+
- '#SBATCH --requeue'

config/train_brax.yaml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
defaults:
2+
- _self_
3+
4+
hydra:
5+
run:
6+
dir: ${log_dir}/${now:%Y-%m-%d}/${now:%H-%M-%S}
7+
sweep:
8+
dir: ${log_dir}/${hydra.job.name}
9+
subdir: ${hydra.job.override_dirname}/seed=${training.seed}
10+
job:
11+
config:
12+
override_dirname:
13+
exclude_keys:
14+
- log_dir
15+
- training.seed
16+
- wandb
17+
chdir: true
18+
19+
20+
wandb:
21+
group: null
22+
notes: null
23+
name: ${hydra:job.override_dirname}
24+
25+
jit: true
26+
27+
training:
28+
seed: 0
29+
render: true

mujoco_playground/_src/dm_control_suite/quadruped.py

Lines changed: 151 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919

2020
import jax
2121
import jax.numpy as jp
22-
import mujoco
2322
from ml_collections import config_dict
23+
import mujoco
2424
from mujoco import mjx
25-
from mujoco_playground._src import mjx_env, reward
25+
26+
from mujoco_playground._src import mjx_env
27+
from mujoco_playground._src import reward
2628
from mujoco_playground._src.dm_control_suite import common
2729

2830
_XML_PATH = mjx_env.ROOT_PATH / "dm_control_suite" / "xmls" / "quadruped.xml"
@@ -32,163 +34,163 @@
3234

3335

3436
def default_config() -> config_dict.ConfigDict:
35-
return config_dict.create(
36-
ctrl_dt=0.02,
37-
sim_dt=0.005,
38-
episode_length=1000,
39-
action_repeat=1,
40-
vision=False,
41-
)
42-
43-
44-
def _find_non_contacting_height(mjx_model, data, orientation, x_pos=0.0, y_pos=0.0):
45-
def body_fn(state):
46-
z_pos, num_contacts, num_attempts, _ = state
47-
qpos = data.qpos.at[:3].set(jp.array([x_pos, y_pos, z_pos]))
48-
qpos = qpos.at[3:7].set(jp.array(orientation))
49-
ndata = data.replace(qpos=qpos)
50-
ndata = mjx.forward(mjx_model, ndata)
51-
num_contacts = ndata.ncon
52-
z_pos += 0.01
53-
num_attempts += 1
54-
return (z_pos, num_contacts, num_attempts, ndata)
55-
56-
initial_state = (0.0, 1, 0, data) # (z_pos, num_contacts, num_attempts)
57-
*_, num_attemps, ndata = jax.lax.while_loop(
58-
lambda state: jp.greater(state[1], 0) & jp.less_equal(state[2], 10000),
59-
body_fn,
60-
initial_state,
61-
)
62-
ndata = jax.tree_map(
63-
lambda x, y: jp.where(jp.less(num_attemps, 10000), x, y), ndata, data
64-
)
65-
return ndata
37+
return config_dict.create(
38+
ctrl_dt=0.02,
39+
sim_dt=0.005,
40+
episode_length=1000,
41+
action_repeat=1,
42+
vision=False,
43+
)
44+
45+
46+
def _find_non_contacting_height(
    mjx_model, data, orientation, x_pos=0.0, y_pos=0.0
):
  """Lifts the root pose along z until the forwarded data reports no contacts.

  Starting at z = 0, each attempt writes `(x_pos, y_pos, z)` into `qpos[:3]`
  and `orientation` into `qpos[3:7]`, runs `mjx.forward`, and reads `ncon`.
  The height is increased by 0.01 per attempt while `ncon` is positive.

  Args:
    mjx_model: Model handed to `mjx.forward`.
    data: Data used as the template for every candidate pose; only its `qpos`
      is overwritten.
    orientation: Four values written to `qpos[3:7]` (presumably the root
      quaternion -- confirm against the model's joint layout).
    x_pos: Value written to `qpos[0]`.
    y_pos: Value written to `qpos[1]`.

  Returns:
    The forwarded data of the first contact-free pose, or `data` unchanged if
    every attempt still reported contacts.
  """
  max_attempts = 10000

  def cond_fn(state):
    _, num_contacts, num_attempts, _ = state
    return jp.greater(num_contacts, 0) & jp.less_equal(
        num_attempts, max_attempts
    )

  def body_fn(state):
    z_pos, _, num_attempts, _ = state
    qpos = data.qpos.at[:3].set(jp.array([x_pos, y_pos, z_pos]))
    qpos = qpos.at[3:7].set(jp.array(orientation))
    ndata = mjx.forward(mjx_model, data.replace(qpos=qpos))
    return (z_pos + 0.01, ndata.ncon, num_attempts + 1, ndata)

  # Carry layout: (z_pos, num_contacts, num_attempts, data). The contact count
  # starts at 1 so the loop body always runs at least once.
  initial_state = (0.0, 1, 0, data)
  _, num_contacts, _, ndata = jax.lax.while_loop(
      cond_fn, body_fn, initial_state
  )

  # Decide success from the final contact count rather than the attempt
  # counter: the loop guard allows attempts up to and including
  # `max_attempts`, so a counter comparison can reject a valid last-attempt
  # result.
  found = jp.equal(num_contacts, 0)
  # `jax.tree_util.tree_map` is the stable spelling; the top-level
  # `jax.tree_map` alias is deprecated.
  return jax.tree_util.tree_map(
      lambda x, y: jp.where(found, x, y), ndata, data
  )
6670

6771

6872
class Quadruped(mjx_env.MjxEnv):
69-
"""Quadruped environment."""
70-
71-
def __init__(
72-
self,
73-
desired_speed: float,
74-
config: config_dict.ConfigDict = default_config(),
75-
config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
76-
):
77-
super().__init__(config, config_overrides)
78-
if self._config.vision:
79-
raise NotImplementedError("Vision not implemented for Quadruped.")
80-
self._desired_speed = desired_speed
81-
self._xml_path = _XML_PATH.as_posix()
82-
self._mj_model = mujoco.MjModel.from_xml_string(
83-
_XML_PATH.read_text(), common.get_assets()
84-
)
85-
self._mj_model.opt.timestep = self.sim_dt
86-
self._mjx_model = mjx.put_model(self._mj_model)
87-
self._post_init()
88-
89-
def _post_init(self):
90-
self._force_torque_names = [
91-
f"{f}_toe_{pos}_{side}"
92-
for (f, pos, side) in product(
93-
("force", "torque"), ("front", "back"), ("left", "right")
94-
)
95-
]
96-
self._torso_id = self._mj_model.body("torso").id
97-
98-
def reset(self, rng: jax.Array) -> mjx_env.State:
99-
data = mjx_env.init(self.mjx_model)
100-
metrics = {"reward/upright": jp.zeros(()), "reward/move": jp.zeros(())}
101-
info = {"rng": rng}
102-
reward, done = jp.zeros(2)
103-
obs = self._get_obs(data, info)
104-
return mjx_env.State(data, obs, reward, done, metrics, info)
105-
106-
def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
107-
lower, upper = (
108-
self._mj_model.actuator_ctrlrange[:, 0],
109-
self._mj_model.actuator_ctrlrange[:, 1],
110-
)
111-
action = (action + 1.0) / 2.0 * (upper - lower) + lower
112-
data = mjx_env.step(self.mjx_model, state.data, action, self.n_substeps)
113-
reward = self._get_reward(data, action, state.info, state.metrics)
114-
obs = self._get_obs(data, state.info)
115-
done = jp.isnan(data.qpos).any() | jp.isnan(data.qvel).any()
116-
done = done.astype(float)
117-
return mjx_env.State(data, obs, reward, done, state.metrics, state.info)
118-
119-
def _get_obs(self, data: mjx.Data, info: dict[str, Any]) -> jax.Array:
120-
del info
121-
ego = self._egocentric_state(data)
122-
torso_vel = self.torso_velocity(data)
123-
upright = self.torso_upright(data)
124-
imu = self.imu(data)
125-
force_torque = self.force_torque(data)
126-
return jp.hstack((ego, torso_vel, upright, imu, force_torque))
127-
128-
def _get_reward(
129-
self,
130-
data: mjx.Data,
131-
action: jax.Array,
132-
info: dict[str, Any],
133-
metrics: dict[str, Any],
134-
) -> jax.Array:
135-
del info, action
136-
move_reward = reward.tolerance(
137-
self.torso_velocity(data)[0],
138-
bounds=(self._desired_speed, float("inf")),
139-
sigmoid="linear",
140-
margin=self._desired_speed,
141-
value_at_margin=0.5,
142-
)
143-
upright_reward = self._upright_reward(data)
144-
metrics["reward/move"] = move_reward
145-
metrics["reward/upright"] = upright_reward
146-
return move_reward * upright_reward
147-
148-
def _upright_reward(self, data: mjx.Data) -> jax.Array:
149-
upright = self.torso_upright(data)
150-
return reward.tolerance(
151-
upright,
152-
bounds=(1, float("inf")),
153-
sigmoid="linear",
154-
margin=2,
155-
value_at_margin=0,
73+
"""Quadruped environment."""
74+
75+
def __init__(
76+
self,
77+
desired_speed: float,
78+
config: config_dict.ConfigDict = default_config(),
79+
config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
80+
):
81+
super().__init__(config, config_overrides)
82+
if self._config.vision:
83+
raise NotImplementedError("Vision not implemented for Quadruped.")
84+
self._desired_speed = desired_speed
85+
self._xml_path = _XML_PATH.as_posix()
86+
self._mj_model = mujoco.MjModel.from_xml_string(
87+
_XML_PATH.read_text(), common.get_assets()
88+
)
89+
self._mj_model.opt.timestep = self.sim_dt
90+
self._mjx_model = mjx.put_model(self._mj_model)
91+
self._post_init()
92+
93+
def _post_init(self):
94+
self._force_torque_names = [
95+
f"{f}_toe_{pos}_{side}"
96+
for (f, pos, side) in product(
97+
("force", "torque"), ("front", "back"), ("left", "right")
15698
)
99+
]
100+
self._torso_id = self._mj_model.body("torso").id
101+
102+
def reset(self, rng: jax.Array) -> mjx_env.State:
103+
data = mjx_env.init(self.mjx_model)
104+
metrics = {"reward/upright": jp.zeros(()), "reward/move": jp.zeros(())}
105+
info = {"rng": rng}
106+
reward, done = jp.zeros(2)
107+
obs = self._get_obs(data, info)
108+
return mjx_env.State(data, obs, reward, done, metrics, info)
109+
110+
def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
111+
lower, upper = (
112+
self._mj_model.actuator_ctrlrange[:, 0],
113+
self._mj_model.actuator_ctrlrange[:, 1],
114+
)
115+
action = (action + 1.0) / 2.0 * (upper - lower) + lower
116+
data = mjx_env.step(self.mjx_model, state.data, action, self.n_substeps)
117+
reward = self._get_reward(data, action, state.info, state.metrics)
118+
obs = self._get_obs(data, state.info)
119+
done = jp.isnan(data.qpos).any() | jp.isnan(data.qvel).any()
120+
done = done.astype(float)
121+
return mjx_env.State(data, obs, reward, done, state.metrics, state.info)
122+
123+
def _get_obs(self, data: mjx.Data, info: dict[str, Any]) -> jax.Array:
124+
del info
125+
ego = self._egocentric_state(data)
126+
torso_vel = self.torso_velocity(data)
127+
upright = self.torso_upright(data)
128+
imu = self.imu(data)
129+
force_torque = self.force_torque(data)
130+
return jp.hstack((ego, torso_vel, upright, imu, force_torque))
131+
132+
def _get_reward(
133+
self,
134+
data: mjx.Data,
135+
action: jax.Array,
136+
info: dict[str, Any],
137+
metrics: dict[str, Any],
138+
) -> jax.Array:
139+
del info, action
140+
move_reward = reward.tolerance(
141+
self.torso_velocity(data)[0],
142+
bounds=(self._desired_speed, float("inf")),
143+
sigmoid="linear",
144+
margin=self._desired_speed,
145+
value_at_margin=0.5,
146+
)
147+
upright_reward = self._upright_reward(data)
148+
metrics["reward/move"] = move_reward
149+
metrics["reward/upright"] = upright_reward
150+
return move_reward * upright_reward
151+
152+
def _upright_reward(self, data: mjx.Data) -> jax.Array:
153+
upright = self.torso_upright(data)
154+
return reward.tolerance(
155+
upright,
156+
bounds=(1, float("inf")),
157+
sigmoid="linear",
158+
margin=2,
159+
value_at_margin=0,
160+
)
157161

158-
def _egocentric_state(self, data: mjx.Data) -> jax.Array:
159-
return jp.hstack((data.qpos[7:], data.qvel[7:], data.act))
162+
def _egocentric_state(self, data: mjx.Data) -> jax.Array:
163+
return jp.hstack((data.qpos[7:], data.qvel[7:], data.act))
160164

161-
def torso_upright(self, data: mjx.Data) -> jax.Array:
162-
return data.xmat[self._torso_id, 2, 2]
165+
def torso_upright(self, data: mjx.Data) -> jax.Array:
166+
return data.xmat[self._torso_id, 2, 2]
163167

164-
def torso_velocity(self, data: mjx.Data) -> jax.Array:
165-
return mjx_env.get_sensor_data(self.mj_model, data, "velocimeter")
168+
def torso_velocity(self, data: mjx.Data) -> jax.Array:
169+
return mjx_env.get_sensor_data(self.mj_model, data, "velocimeter")
166170

167-
def imu(self, data: mjx.Data) -> jax.Array:
168-
gyro = mjx_env.get_sensor_data(self.mj_model, data, "imu_gyro")
169-
accelerometer = mjx_env.get_sensor_data(self.mj_model, data, "imu_accel")
170-
return jp.hstack((gyro, accelerometer))
171+
def imu(self, data: mjx.Data) -> jax.Array:
172+
gyro = mjx_env.get_sensor_data(self.mj_model, data, "imu_gyro")
173+
accelerometer = mjx_env.get_sensor_data(self.mj_model, data, "imu_accel")
174+
return jp.hstack((gyro, accelerometer))
171175

172-
def force_torque(self, data: mjx.Data) -> jax.Array:
173-
return jp.hstack(
174-
[
175-
mjx_env.get_sensor_data(self.mj_model, data, name)
176-
for name in self._force_torque_names
177-
]
178-
)
176+
def force_torque(self, data: mjx.Data) -> jax.Array:
177+
return jp.hstack([
178+
mjx_env.get_sensor_data(self.mj_model, data, name)
179+
for name in self._force_torque_names
180+
])
179181

180-
@property
181-
def xml_path(self) -> str:
182-
return self._xml_path
182+
@property
183+
def xml_path(self) -> str:
184+
return self._xml_path
183185

184-
@property
185-
def action_size(self) -> int:
186-
return self.mjx_model.nu
186+
@property
187+
def action_size(self) -> int:
188+
return self.mjx_model.nu
187189

188-
@property
189-
def mj_model(self) -> mujoco.MjModel:
190-
return self._mj_model
190+
@property
191+
def mj_model(self) -> mujoco.MjModel:
192+
return self._mj_model
191193

192-
@property
193-
def mjx_model(self) -> mjx.Model:
194-
return self._mjx_model
194+
@property
195+
def mjx_model(self) -> mjx.Model:
196+
return self._mjx_model

0 commit comments

Comments
 (0)