Skip to content

Commit 50f0640

Browse files
committed
Add quadruped
1 parent 024adbe commit 50f0640

3 files changed

Lines changed: 467 additions & 45 deletions

File tree

mujoco_playground/_src/dm_control_suite/__init__.py

Lines changed: 42 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from mujoco_playground._src.dm_control_suite import humanoid
3131
from mujoco_playground._src.dm_control_suite import pendulum
3232
from mujoco_playground._src.dm_control_suite import point_mass
33+
from mujoco_playground._src.dm_control_suite import quadruped
3334
from mujoco_playground._src.dm_control_suite import reacher
3435
from mujoco_playground._src.dm_control_suite import swimmer
3536
from mujoco_playground._src.dm_control_suite import walker
@@ -39,21 +40,13 @@
3940
"AcrobotSwingupSparse": partial(acrobot.Balance, sparse=True),
4041
"BallInCup": ball_in_cup.BallInCup,
4142
"CartpoleBalance": partial(cartpole.Balance, swing_up=False, sparse=False),
42-
"CartpoleBalanceSparse": partial(
43-
cartpole.Balance, swing_up=False, sparse=True
44-
),
43+
"CartpoleBalanceSparse": partial(cartpole.Balance, swing_up=False, sparse=True),
4544
"CartpoleSwingup": partial(cartpole.Balance, swing_up=True, sparse=False),
46-
"CartpoleSwingupSparse": partial(
47-
cartpole.Balance, swing_up=True, sparse=True
48-
),
45+
"CartpoleSwingupSparse": partial(cartpole.Balance, swing_up=True, sparse=True),
4946
"CheetahRun": cheetah.Run,
5047
"FingerSpin": finger.Spin,
51-
"FingerTurnEasy": partial(
52-
finger.Turn, target_radius=finger.EASY_TARGET_SIZE
53-
),
54-
"FingerTurnHard": partial(
55-
finger.Turn, target_radius=finger.HARD_TARGET_SIZE
56-
),
48+
"FingerTurnEasy": partial(finger.Turn, target_radius=finger.EASY_TARGET_SIZE),
49+
"FingerTurnHard": partial(finger.Turn, target_radius=finger.HARD_TARGET_SIZE),
5750
"FishSwim": fish.Swim,
5851
"HopperHop": partial(hopper.Hopper, hopping=True),
5952
"HopperStand": partial(hopper.Hopper, hopping=False),
@@ -62,6 +55,8 @@
6255
"HumanoidRun": partial(humanoid.Humanoid, move_speed=humanoid.RUN_SPEED),
6356
"PendulumSwingup": pendulum.SwingUp,
6457
"PointMass": point_mass.PointMass,
58+
"QuadrupedWalk": partial(quadruped.Quadruped, desired_speed=quadruped.WALK_SPEED),
59+
"QuadrupedRun": partial(quadruped.Quadruped, desired_speed=quadruped.RUN_SPEED),
6560
"ReacherEasy": partial(reacher.Reacher, target_size=reacher.BIG_TARGET),
6661
"ReacherHard": partial(reacher.Reacher, target_size=reacher.SMALL_TARGET),
6762
"SwimmerSwimmer6": partial(swimmer.Swim, n_links=6),
@@ -91,6 +86,8 @@
9186
"HumanoidWalk": humanoid.default_config,
9287
"PendulumSwingup": pendulum.default_config,
9388
"PointMass": point_mass.default_config,
89+
"QuadrupedWalk": quadruped.default_config,
90+
"QuadrupedRun": quadruped.default_config,
9491
"ReacherEasy": reacher.default_config,
9592
"ReacherHard": reacher.default_config,
9693
"SwimmerSwimmer6": swimmer.default_config,
@@ -102,54 +99,54 @@
10299

103100

104101
def __getattr__(name):
    """Resolve module attributes that are computed on access.

    ``ALL_ENVS`` is the only supported name. It is rebuilt from the ``_envs``
    registry on every lookup, so it reflects the registry's contents at the
    time of access.

    Args:
        name: The attribute being looked up on this module.

    Returns:
        A tuple of the keys currently in ``_envs`` when ``name`` is
        ``"ALL_ENVS"``.

    Raises:
        AttributeError: For any other attribute name.
    """
    if name != "ALL_ENVS":
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
    return tuple(_envs)
108105

109106

110107
def register_environment(
    env_name: str,
    env_class: Type[mjx_env.MjxEnv],
    cfg_class: Callable[[], config_dict.ConfigDict],
) -> None:
    """Add an environment and its config factory to the module registries.

    An existing entry under the same name is overwritten.

    Args:
        env_name: Key under which the environment is registered.
        env_class: The environment class stored in ``_envs``.
        cfg_class: Zero-argument callable returning the default configuration,
            stored in ``_cfgs``.
    """
    for registry, entry in ((_envs, env_class), (_cfgs, cfg_class)):
        registry[env_name] = entry
124121

125122

126123
def get_default_config(env_name: str) -> config_dict.ConfigDict:
    """Return a default configuration for the named environment.

    The registered config factory is called on every invocation, so each call
    produces whatever that factory returns at that moment.

    Args:
        env_name: Name of a registered environment.

    Returns:
        The result of calling the config factory registered for ``env_name``.

    Raises:
        ValueError: If no config factory is registered under ``env_name``.
    """
    if env_name in _cfgs:
        return _cfgs[env_name]()
    raise ValueError(
        f"Env '{env_name}' not found in default configs. Available configs:"
        f" {list(_cfgs.keys())}"
    )
134131

135132

136133
def load(
    env_name: str,
    config: Optional[config_dict.ConfigDict] = None,
    config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
) -> mjx_env.MjxEnv:
    """Get an environment instance with the given configuration.

    Args:
        env_name: The name of the environment.
        config: The configuration to use. If not provided, the default
            configuration is used.
        config_overrides: A dictionary of overrides for the configuration.

    Returns:
        An instance of the environment.

    Raises:
        ValueError: If ``env_name`` is not a registered environment.
    """
    if env_name not in _envs:
        # Report the registry that was actually searched (`_envs`), rendered as
        # a plain list rather than a `dict_keys(...)` view.
        raise ValueError(
            f"Env '{env_name}' not found. Available envs: {list(_envs.keys())}"
        )
    # A falsy `config` (e.g. None) falls back to the registered default.
    config = config or get_default_config(env_name)
    return _envs[env_name](config=config, config_overrides=config_overrides)
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# Copyright 2025 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Quadruped environment."""
16+
17+
from itertools import product
18+
from typing import Any, Dict, Optional, Union
19+
20+
import jax
21+
import jax.numpy as jp
22+
import mujoco
23+
from ml_collections import config_dict
24+
from mujoco import mjx
25+
from mujoco_playground._src import mjx_env, reward
26+
from mujoco_playground._src.dm_control_suite import common
27+
28+
# Path to the quadruped XML model file, built relative to mjx_env.ROOT_PATH.
_XML_PATH = mjx_env.ROOT_PATH / "dm_control_suite" / "xmls" / "quadruped.xml"
29+
30+
# Desired speed for the walk task; used as the lower bound of the move reward.
WALK_SPEED = 0.5
31+
# Desired speed for the run task; used as the lower bound of the move reward.
RUN_SPEED = 5.0
32+
33+
34+
def default_config() -> config_dict.ConfigDict:
    """Build the default configuration for the quadruped environment.

    Returns:
        A new ``ConfigDict`` holding the control and simulation timesteps,
        episode length, action repeat and the vision flag.
    """
    settings = {
        "ctrl_dt": 0.02,
        "sim_dt": 0.005,
        "episode_length": 1000,
        "action_repeat": 1,
        "vision": False,
    }
    return config_dict.create(**settings)
42+
43+
44+
def _find_non_contacting_height(mjx_model, data, orientation, x_pos=0.0, y_pos=0.0):
    """Search upward for a root height at which ``mjx.forward`` reports no contacts.

    Starting from ``z = 0.0`` and increasing by 0.01 per attempt, the first
    three entries of ``qpos`` are set to ``(x_pos, y_pos, z)`` and entries
    ``3:7`` to ``orientation``; ``mjx.forward`` is then run and ``ncon`` of the
    result is inspected. The search stops once ``ncon`` is zero or the attempt
    budget is exhausted.

    Args:
        mjx_model: The MJX model passed to ``mjx.forward``.
        data: The MJX data whose ``qpos`` is used as the template; also the
            fallback return value.
        orientation: Four values written to ``qpos[3:7]``.
        x_pos: Value written to ``qpos[0]``.
        y_pos: Value written to ``qpos[1]``.

    Returns:
        The data from the last attempt if the search finished in fewer than
        ``max_attempts`` attempts, otherwise the unmodified ``data``.
    """
    max_attempts = 10000

    def cond_fn(state):
        _, num_contacts, num_attempts, _ = state
        return jp.greater(num_contacts, 0) & jp.less_equal(num_attempts, max_attempts)

    def body_fn(state):
        # The contact count from the previous attempt is discarded and
        # recomputed from the freshly evaluated data.
        z_pos, _, num_attempts, _ = state
        qpos = data.qpos.at[:3].set(jp.array([x_pos, y_pos, z_pos]))
        qpos = qpos.at[3:7].set(jp.array(orientation))
        ndata = mjx.forward(mjx_model, data.replace(qpos=qpos))
        # NOTE(review): this relies on `ncon` varying with the pose; confirm
        # that holds for the MJX version in use.
        return (z_pos + 0.01, ndata.ncon, num_attempts + 1, ndata)

    # (z_pos, num_contacts, num_attempts, data); num_contacts starts at 1 so
    # the loop body runs at least once.
    initial_state = (0.0, 1, 0, data)
    *_, num_attempts, ndata = jax.lax.while_loop(cond_fn, body_fn, initial_state)

    # `jax.tree_map` is deprecated/removed in recent JAX releases;
    # `jax.tree_util.tree_map` is the long-standing equivalent.
    # NOTE(review): the loop allows up to max_attempts + 1 iterations while this
    # test requires strictly fewer than max_attempts, so a contact-free pose
    # found on attempt 10000 or 10001 is still discarded. Behavior is kept
    # as-is; confirm whether that boundary is intended.
    return jax.tree_util.tree_map(
        lambda x, y: jp.where(jp.less(num_attempts, max_attempts), x, y),
        ndata,
        data,
    )
66+
67+
68+
class Quadruped(mjx_env.MjxEnv):
    """Quadruped environment.

    The reward is the product of a forward-speed term (relative to
    ``desired_speed``) and a torso-upright term.
    """

    def __init__(
        self,
        desired_speed: float,
        config: Optional[config_dict.ConfigDict] = None,
        config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
    ):
        """Load the quadruped model and prepare it for MJX.

        Args:
            desired_speed: Speed used as the lower bound (and margin) of the
                move reward.
            config: Environment configuration. When omitted, a fresh
                ``default_config()`` is created for this instance.
            config_overrides: Optional overrides forwarded to the base class.

        Raises:
            NotImplementedError: If the configuration enables vision.
        """
        # Create the default per instance: a default evaluated in the
        # signature would be one object shared by every instance built
        # without an explicit config.
        if config is None:
            config = default_config()
        super().__init__(config, config_overrides)
        if self._config.vision:
            raise NotImplementedError("Vision not implemented for Quadruped.")
        self._desired_speed = desired_speed
        self._xml_path = _XML_PATH.as_posix()
        self._mj_model = mujoco.MjModel.from_xml_string(
            _XML_PATH.read_text(), common.get_assets()
        )
        # Override the timestep from the XML with the configured one.
        self._mj_model.opt.timestep = self.sim_dt
        self._mjx_model = mjx.put_model(self._mj_model)
        self._post_init()

    def _post_init(self):
        """Cache the force/torque sensor names and the torso body id."""
        # Yields e.g. "force_toe_front_left", ..., "torque_toe_back_right".
        self._force_torque_names = [
            f"{f}_toe_{pos}_{side}"
            for (f, pos, side) in product(
                ("force", "torque"), ("front", "back"), ("left", "right")
            )
        ]
        self._torso_id = self._mj_model.body("torso").id

    def reset(self, rng: jax.Array) -> mjx_env.State:
        """Return the initial state with zeroed reward, done flag and metrics.

        Args:
            rng: Random key; stored in ``info["rng"]`` and otherwise unused here.

        Returns:
            The initial environment state.
        """
        data = mjx_env.init(self.mjx_model)
        metrics = {"reward/upright": jp.zeros(()), "reward/move": jp.zeros(())}
        info = {"rng": rng}
        # Named to avoid shadowing the imported `reward` module.
        initial_reward, done = jp.zeros(2)
        obs = self._get_obs(data, info)
        return mjx_env.State(data, obs, initial_reward, done, metrics, info)

    def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
        """Advance the simulation by one control step.

        Args:
            state: The current environment state.
            action: Action with components in ``[-1, 1]``; rescaled to each
                actuator's control range before being applied.

        Returns:
            The next state. ``done`` is 1.0 if any ``qpos``/``qvel`` entry is
            NaN and 0.0 otherwise.
        """
        lower, upper = (
            self._mj_model.actuator_ctrlrange[:, 0],
            self._mj_model.actuator_ctrlrange[:, 1],
        )
        # Affine map from [-1, 1] to [lower, upper].
        action = (action + 1.0) / 2.0 * (upper - lower) + lower
        data = mjx_env.step(self.mjx_model, state.data, action, self.n_substeps)
        # Named to avoid shadowing the imported `reward` module.
        step_reward = self._get_reward(data, action, state.info, state.metrics)
        obs = self._get_obs(data, state.info)
        done = jp.isnan(data.qpos).any() | jp.isnan(data.qvel).any()
        done = done.astype(float)
        return mjx_env.State(
            data, obs, step_reward, done, state.metrics, state.info
        )

    def _get_obs(self, data: mjx.Data, info: dict[str, Any]) -> jax.Array:
        """Concatenate joint state, torso velocity, uprightness, IMU and toe sensors."""
        del info  # Unused.
        ego = self._egocentric_state(data)
        torso_vel = self.torso_velocity(data)
        upright = self.torso_upright(data)
        imu = self.imu(data)
        force_torque = self.force_torque(data)
        return jp.hstack((ego, torso_vel, upright, imu, force_torque))

    def _get_reward(
        self,
        data: mjx.Data,
        action: jax.Array,
        info: dict[str, Any],
        metrics: dict[str, Any],
    ) -> jax.Array:
        """Compute the reward and record its two factors in ``metrics``.

        Args:
            data: Simulation data after the step.
            action: Unused.
            info: Unused.
            metrics: Updated in place with ``reward/move`` and
                ``reward/upright``.

        Returns:
            The product of the move reward and the upright reward.
        """
        del info, action  # Unused.
        move_reward = reward.tolerance(
            self.torso_velocity(data)[0],
            bounds=(self._desired_speed, float("inf")),
            sigmoid="linear",
            margin=self._desired_speed,
            value_at_margin=0.5,
        )
        upright_reward = self._upright_reward(data)
        metrics["reward/move"] = move_reward
        metrics["reward/upright"] = upright_reward
        return move_reward * upright_reward

    def _upright_reward(self, data: mjx.Data) -> jax.Array:
        """Return the tolerance reward on the torso's uprightness value."""
        upright = self.torso_upright(data)
        return reward.tolerance(
            upright,
            bounds=(1, float("inf")),
            sigmoid="linear",
            margin=2,
            value_at_margin=0,
        )

    def _egocentric_state(self, data: mjx.Data) -> jax.Array:
        """Return joint positions, joint velocities and actuator activations."""
        # NOTE(review): qpos[7:] skips the 7 position coordinates of a free-joint
        # root, but a free joint has only 6 velocity DoFs, so qvel[7:] also drops
        # the first non-root velocity. `qvel[6:]` looks like the intended slice;
        # left unchanged here because fixing it alters the observation size.
        return jp.hstack((data.qpos[7:], data.qvel[7:], data.act))

    def torso_upright(self, data: mjx.Data) -> jax.Array:
        """Return the [2, 2] entry of the torso body's ``xmat`` matrix."""
        return data.xmat[self._torso_id, 2, 2]

    def torso_velocity(self, data: mjx.Data) -> jax.Array:
        """Return the reading of the ``velocimeter`` sensor."""
        return mjx_env.get_sensor_data(self.mj_model, data, "velocimeter")

    def imu(self, data: mjx.Data) -> jax.Array:
        """Return the ``imu_gyro`` and ``imu_accel`` sensor readings, concatenated."""
        gyro = mjx_env.get_sensor_data(self.mj_model, data, "imu_gyro")
        accelerometer = mjx_env.get_sensor_data(self.mj_model, data, "imu_accel")
        return jp.hstack((gyro, accelerometer))

    def force_torque(self, data: mjx.Data) -> jax.Array:
        """Return all toe force and torque sensor readings, concatenated."""
        return jp.hstack(
            [
                mjx_env.get_sensor_data(self.mj_model, data, name)
                for name in self._force_torque_names
            ]
        )

    @property
    def xml_path(self) -> str:
        """POSIX-style path of the XML file the model was loaded from."""
        return self._xml_path

    @property
    def action_size(self) -> int:
        """Number of actuators (``nu``) in the MJX model."""
        return self.mjx_model.nu

    @property
    def mj_model(self) -> mujoco.MjModel:
        """The underlying MuJoCo model."""
        return self._mj_model

    @property
    def mjx_model(self) -> mjx.Model:
        """The MJX model created from the MuJoCo model."""
        return self._mjx_model

0 commit comments

Comments
 (0)