Skip to content

Commit 024adbe

Browse files
Merge pull request #88 from google-deepmind:apollo-joystick
PiperOrigin-RevId: 746679494 Change-Id: Ife4eb8f049791d8879d72b3509a022249ebb4b28
2 parents 9fe6c50 + ba5fc5f commit 024adbe

File tree

12 files changed

+2030
-0
lines changed

12 files changed

+2030
-0
lines changed

mujoco_playground/_src/locomotion/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from mujoco import mjx
2323

2424
from mujoco_playground._src import mjx_env
25+
from mujoco_playground._src.locomotion.apollo import joystick as apollo_joystick
2526
from mujoco_playground._src.locomotion.barkour import joystick as barkour_joystick
2627
from mujoco_playground._src.locomotion.berkeley_humanoid import joystick as berkeley_humanoid_joystick
2728
from mujoco_playground._src.locomotion.berkeley_humanoid import randomize as berkeley_humanoid_randomize
@@ -41,6 +42,9 @@
4142
from mujoco_playground._src.locomotion.t1 import randomize as t1_randomize
4243

4344
_envs = {
45+
"ApolloJoystickFlatTerrain": functools.partial(
46+
apollo_joystick.Joystick, task="flat_terrain"
47+
),
4448
"BarkourJoystick": barkour_joystick.Joystick,
4549
"BerkeleyHumanoidJoystickFlatTerrain": functools.partial(
4650
berkeley_humanoid_joystick.Joystick, task="flat_terrain"
@@ -82,6 +86,7 @@
8286
}
8387

8488
_cfgs = {
89+
"ApolloJoystickFlatTerrain": apollo_joystick.default_config,
8590
"BarkourJoystick": barkour_joystick.default_config,
8691
"BerkeleyHumanoidJoystickFlatTerrain": (
8792
berkeley_humanoid_joystick.default_config
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2025 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# Copyright 2025 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Base classes for Apollo."""
16+
17+
from typing import Any, Dict, Optional, Union
18+
19+
import jax
20+
import jax.numpy as jp
21+
import mujoco
22+
import numpy as np
23+
from etils import epath
24+
from ml_collections import config_dict
25+
from mujoco import mjx
26+
27+
from mujoco_playground._src import mjx_env
28+
from mujoco_playground._src.locomotion.apollo import constants as consts
29+
from mujoco_playground._src.collision import geoms_colliding
30+
31+
32+
def get_assets() -> Dict[str, bytes]:
33+
assets = {}
34+
# Playground assets.
35+
mjx_env.update_assets(assets, consts.XML_DIR, "*.xml")
36+
mjx_env.update_assets(assets, consts.XML_DIR / "assets")
37+
# Menagerie assets.
38+
path = mjx_env.MENAGERIE_PATH / "apptronik_apollo"
39+
mjx_env.update_assets(assets, path, "*.xml")
40+
mjx_env.update_assets(assets, path / "assets")
41+
mjx_env.update_assets(assets, path / "assets" / "ability_hand")
42+
return assets
43+
44+
45+
class ApolloEnv(mjx_env.MjxEnv):
46+
"""Base class for Apollo environments."""
47+
48+
def __init__(
49+
self,
50+
xml_path: str,
51+
config: config_dict.ConfigDict,
52+
config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
53+
) -> None:
54+
super().__init__(config, config_overrides)
55+
56+
self._mj_model = mujoco.MjModel.from_xml_string(
57+
epath.Path(xml_path).read_text(), assets=get_assets()
58+
)
59+
self._mj_model.opt.timestep = self.sim_dt
60+
61+
self._mj_model.vis.global_.offwidth = 3840
62+
self._mj_model.vis.global_.offheight = 2160
63+
64+
self._mjx_model = mjx.put_model(self._mj_model)
65+
self._xml_path = xml_path
66+
67+
self._init_q = jp.array(self._mj_model.keyframe("knees_bent").qpos)
68+
self._default_ctrl = jp.array(self._mj_model.keyframe("knees_bent").ctrl)
69+
self._default_pose = jp.array(self._mj_model.keyframe("knees_bent").qpos[7:])
70+
self._actuator_torques = self.mj_model.jnt_actfrcrange[1:, 1]
71+
72+
# Body IDs.
73+
self._torso_body_id = self._mj_model.body(consts.ROOT_BODY).id
74+
75+
# Geom IDs.
76+
self._floor_geom_id = self._mj_model.geom("floor").id
77+
self._left_feet_geom_id = np.array(
78+
[self._mj_model.geom(name).id for name in consts.LEFT_FEET_GEOMS]
79+
)
80+
self._right_feet_geom_id = np.array(
81+
[self._mj_model.geom(name).id for name in consts.RIGHT_FEET_GEOMS]
82+
)
83+
self._left_hand_geom_id = self._mj_model.geom("collision_l_hand_plate").id
84+
self._right_hand_geom_id = self._mj_model.geom("collision_r_hand_plate").id
85+
self._left_foot_geom_id = self._mj_model.geom("collision_l_sole").id
86+
self._right_foot_geom_id = self._mj_model.geom("collision_r_sole").id
87+
self._left_shin_geom_id = self._mj_model.geom("collision_capsule_body_l_shin").id
88+
self._right_shin_geom_id = self._mj_model.geom("collision_capsule_body_r_shin").id
89+
self._left_thigh_geom_id = self._mj_model.geom("collision_capsule_body_l_thigh").id
90+
self._right_thigh_geom_id = self._mj_model.geom("collision_capsule_body_r_thigh").id
91+
92+
# Site IDs.
93+
self._imu_site_id = self._mj_model.site("imu").id
94+
self._feet_site_id = np.array(
95+
[self._mj_model.site(name).id for name in consts.FEET_SITES]
96+
)
97+
98+
# Sensor readings.
99+
100+
def get_gravity(self, data: mjx.Data) -> jax.Array:
101+
"""Return the gravity vector in the world frame."""
102+
return mjx_env.get_sensor_data(self.mj_model, data, f"{consts.GRAVITY_SENSOR}")
103+
104+
def get_global_linvel(self, data: mjx.Data) -> jax.Array:
105+
"""Return the linear velocity of the robot in the world frame."""
106+
return mjx_env.get_sensor_data(
107+
self.mj_model, data, f"{consts.GLOBAL_LINVEL_SENSOR}"
108+
)
109+
110+
def get_global_angvel(self, data: mjx.Data) -> jax.Array:
111+
"""Return the angular velocity of the robot in the world frame."""
112+
return mjx_env.get_sensor_data(
113+
self.mj_model, data, f"{consts.GLOBAL_ANGVEL_SENSOR}"
114+
)
115+
116+
def get_local_linvel(self, data: mjx.Data) -> jax.Array:
117+
"""Return the linear velocity of the robot in the local frame."""
118+
return mjx_env.get_sensor_data(self.mj_model, data, f"{consts.LOCAL_LINVEL_SENSOR}")
119+
120+
def get_accelerometer(self, data: mjx.Data) -> jax.Array:
121+
"""Return the accelerometer readings in the local frame."""
122+
return mjx_env.get_sensor_data(
123+
self.mj_model, data, f"{consts.ACCELEROMETER_SENSOR}"
124+
)
125+
126+
def get_gyro(self, data: mjx.Data) -> jax.Array:
127+
"""Return the gyroscope readings in the local frame."""
128+
return mjx_env.get_sensor_data(self.mj_model, data, f"{consts.GYRO_SENSOR}")
129+
130+
def get_feet_ground_contacts(self, data: mjx.Data) -> jax.Array:
131+
"""Return an array indicating whether each foot is in contact with the ground."""
132+
left_feet_contact = jp.array(
133+
[
134+
geoms_colliding(data, geom_id, self._floor_geom_id)
135+
for geom_id in self._left_feet_geom_id
136+
]
137+
)
138+
right_feet_contact = jp.array(
139+
[
140+
geoms_colliding(data, geom_id, self._floor_geom_id)
141+
for geom_id in self._right_feet_geom_id
142+
]
143+
)
144+
return jp.hstack([jp.any(left_feet_contact), jp.any(right_feet_contact)])
145+
146+
# Accessors.
147+
148+
@property
149+
def xml_path(self) -> str:
150+
return self._xml_path
151+
152+
@property
153+
def action_size(self) -> int:
154+
return self._mjx_model.nu
155+
156+
@property
157+
def mj_model(self) -> mujoco.MjModel:
158+
return self._mj_model
159+
160+
@property
161+
def mjx_model(self) -> mjx.Model:
162+
return self._mjx_model
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright 2025 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Constants for Apollo."""
16+
17+
from etils import epath
18+
19+
from mujoco_playground._src import mjx_env
20+
21+
XML_DIR = mjx_env.ROOT_PATH / "locomotion" / "apollo" / "xmls"
22+
23+
FEET_ONLY_FLAT_TERRAIN_XML = XML_DIR / "scene_mjx_feetonly_flat_terrain.xml"
24+
25+
26+
def task_to_xml(task_name: str) -> epath.Path:
27+
return {
28+
"flat_terrain": FEET_ONLY_FLAT_TERRAIN_XML,
29+
}[task_name]
30+
31+
32+
FEET_SITES = [
33+
"l_foot",
34+
"r_foot",
35+
]
36+
37+
HAND_SITES = [
38+
"left_palm",
39+
"right_palm",
40+
]
41+
42+
LEFT_FEET_GEOMS = ["collision_l_sole"]
43+
RIGHT_FEET_GEOMS = ["collision_r_sole"]
44+
FEET_GEOMS = LEFT_FEET_GEOMS + RIGHT_FEET_GEOMS
45+
46+
ROOT_BODY = "torso_link"
47+
48+
GRAVITY_SENSOR = "upvector"
49+
GLOBAL_LINVEL_SENSOR = "global_linvel"
50+
GLOBAL_ANGVEL_SENSOR = "global_angvel"
51+
LOCAL_LINVEL_SENSOR = "local_linvel"
52+
ACCELEROMETER_SENSOR = "accelerometer"
53+
GYRO_SENSOR = "gyro"

0 commit comments

Comments
 (0)