diff --git a/mujoco_playground/_src/manipulation/__init__.py b/mujoco_playground/_src/manipulation/__init__.py
index 7ed178d45..767d5f7ba 100644
--- a/mujoco_playground/_src/manipulation/__init__.py
+++ b/mujoco_playground/_src/manipulation/__init__.py
@@ -28,7 +28,7 @@
from mujoco_playground._src.manipulation.franka_emika_panda_robotiq import push_cube as robotiq_push_cube
from mujoco_playground._src.manipulation.leap_hand import reorient as leap_cube_reorient
from mujoco_playground._src.manipulation.leap_hand import rotate_z as leap_rotate_z
-
+from mujoco_playground._src.manipulation.aero_hand import rotate_z as aero_hand_rotate_z
_envs = {
"AlohaHandOver": aloha_handover.HandOver,
@@ -40,6 +40,7 @@
"PandaRobotiqPushCube": robotiq_push_cube.PandaRobotiqPushCube,
"LeapCubeReorient": leap_cube_reorient.CubeReorient,
"LeapCubeRotateZAxis": leap_rotate_z.CubeRotateZAxis,
+ "AeroCubeRotateZAxis": aero_hand_rotate_z.CubeRotateZAxis,
}
_cfgs = {
@@ -52,11 +53,13 @@
"PandaRobotiqPushCube": robotiq_push_cube.default_config,
"LeapCubeReorient": leap_cube_reorient.default_config,
"LeapCubeRotateZAxis": leap_rotate_z.default_config,
+ "AeroCubeRotateZAxis": aero_hand_rotate_z.default_config,
}
_randomizer = {
"LeapCubeRotateZAxis": leap_rotate_z.domain_randomize,
"LeapCubeReorient": leap_cube_reorient.domain_randomize,
+ "AeroCubeRotateZAxis": aero_hand_rotate_z.domain_randomize,
}
diff --git a/mujoco_playground/_src/manipulation/aero_hand/README.md b/mujoco_playground/_src/manipulation/aero_hand/README.md
new file mode 100644
index 000000000..5e2070744
--- /dev/null
+++ b/mujoco_playground/_src/manipulation/aero_hand/README.md
@@ -0,0 +1,114 @@
+# Tetheria Aero Hand Open with Tendon-Driven Actuation
+
+This directory introduces a tendon-driven manipulation example that extends MuJoCo Playground with support for tendon-level control and observation in reinforcement learning tasks.
+
+The model is adapted from the [Tetheria Aero Hand Open](https://docs.tetheria.ai/), featuring a physically accurate tendon system that emulates cable-driven actuation. In this setup, both the policy inputs and observations are defined in the tendon space, providing a complete example of training and deploying tendon-driven controllers and under-actuated fingers in MuJoCo.
+
+An overview of the hand is shown below:
+
+|  |  |  |
+|------------------------|------------------------|------------------------|
+
+
+## 1. Tendon-Driven MuJoCo Model
+
+### 1.1 Modeling
+
+The mechanical design is derived from URDF files, ensuring accurate representation of the real hand structure. The actuation system in the simulator models the cable design in the real hand through three key components:
+
+#### 1.1.1 Tendon Drives
+The drive tendons close the fingers and actuate the thumb. They are modeled as spatial tendons in MuJoCo that follow the exact routing paths of the real cables.
+
+#### 1.1.2 Springs
+The springs, which are also modeled by tendon components in MuJoCo, provide the forces to pull the fingers in the backward direction. This creates the restoring forces necessary for finger extension.
+
+#### 1.1.3 Pulleys
+The pulleys, which are modeled as cylinders, organize the cables and springs to ensure they are routed in a similar way to the real hand. Careful placement of these pulleys ensures accurate tendon routing.
+
+| front view| close-up of index|
+|------------------------|------------------------|
+|  | 
+
+### 1.2 Parameters
+
+#### 1.2.1 Mechanical Parameters
+- **Joint limits, mass, and inertia**: Come directly from URDF and are accurate to the real hand
+- **Pulley placement**: Positioned precisely where they are placed in the real hand, ensuring cable and spring routes match the real system
+- **Validation**: The tendon travel between the fully open and fully closed finger poses in simulation (0.0459454) closely matches the real hand (0.04553) without manual adjustment
+
+#### 1.2.2 Tendon and Spring Specifications
+- **Tendon properties**: Use the same specifications as those in the real hand
+- **Spring properties**: Match real hand specifications, except for the spring on the DIP joint, which is adjusted as a compromise to achieve similar joint space behavior as the real hand
+
+#### 1.2.3 Control Parameters
+All remaining parameters, including:
+- Joint damping values
+- Actuator gains
+- Joint-specific damping coefficients
+
+These are fine-tuned to satisfy both similar joint behaviors in simulation and the real world.
+
+
+## 2. Training your own policy
+
+We introduce a **z-axis rotation task** for the **Tetheria Aero Hand Open**, optimized using the following reward formulation:
+
+$$
+\text{reward} = 1.0 \times \text{angular velocity}
+ - 1.0 \times \text{action rate}
+ - 100.0 \times \text{termination}
+$$
+
+The policy's actions and observations are defined over the **tendon lengths** and the **thumb abduction joint**, which correspond to the real hand's actuation system.
+This setup ensures that the same control inputs and sensory data can be directly applied for **sim-to-real deployment** on the physical Tetheria Aero Hand Open.
+
+
+To train policies for the Tetheria Hand:
+
+```bash
+
+# Run the training script
+python learning/train_jax_ppo.py --env_name AeroCubeRotateZAxis
+```
+
+Although the reward curves from different training runs may vary due to stochasticity in the learning process, they consistently **converge toward a positive reward**.
+
+## 3. Running a pretrained policy
+
+
+To test trained policies in simulation:
+
+```bash
+# Run the simulation rollout script
+python learning/train_jax_ppo.py --env_name AeroCubeRotateZAxis --play_only --load_checkpoint_path path/to/checkpoints
+```
+
+This will:
+- Load the trained policy
+- Run episodes in the MuJoCo simulation
+- Display the hand performing manipulation tasks
+
+## File Structure
+
+### Core Implementation
+- **`aero_hand_constants.py`** - Constants and configuration
+- **`rotate_z.py`** - Cube rotation task implementation
+
+### XML Models
+- **`xmls/right_hand.xml`** - Main hand model with tendon system
+- **`xmls/scene_mjx_cube.xml`** - Manipulation scene
+- **`xmls/reorientation_cube.xml`** - Cube reorientation task
+
+## Key Features
+
+- **Accurate tendon modeling**: Direct translation from real hand cable system
+- **Precise pulley placement**: Matches real hand routing exactly
+- **Validated parameters**: Tendon ranges match real hand within 1%
+
+---
+
+*This implementation provides a high-fidelity tendon-driven hand model that closely matches the real robotic hand, enabling effective sim-to-real transfer for manipulation tasks.*
+
+## Acknowledgements
+Our code is built upon
+- MuJoCo playground - https://github.com/google-deepmind/mujoco_playground
\ No newline at end of file
diff --git a/mujoco_playground/_src/manipulation/aero_hand/__init__.py b/mujoco_playground/_src/manipulation/aero_hand/__init__.py
new file mode 100644
index 000000000..147be9334
--- /dev/null
+++ b/mujoco_playground/_src/manipulation/aero_hand/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 TetherIA Inc.
+# Copyright 2025 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/mujoco_playground/_src/manipulation/aero_hand/aero_hand_constants.py b/mujoco_playground/_src/manipulation/aero_hand/aero_hand_constants.py
new file mode 100644
index 000000000..00e6143ad
--- /dev/null
+++ b/mujoco_playground/_src/manipulation/aero_hand/aero_hand_constants.py
@@ -0,0 +1,87 @@
+# Copyright 2025 TetherIA Inc.
+# Copyright 2025 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Constants for TetherIA Aero Hand Open."""
+
+from mujoco_playground._src import mjx_env
+
+# Directory holding this environment's files, resolved under the playground
+# root path.
+ROOT_PATH = mjx_env.ROOT_PATH / "manipulation" / "aero_hand"
+# Scene XML that the cube z-rotation task loads.
+CUBE_XML = ROOT_PATH / "xmls" / "scene_mjx_cube.xml"
+
+# Hand dimensions. NQ and NV match the 16 entries of JOINT_NAMES and NU matches
+# the 7 entries of ACTUATOR_NAMES; keep them in sync when editing the lists.
+NQ = 16
+NV = 16
+NU = 7
+
+# Names of the hand joints, grouped per finger (three per finger, four for the
+# thumb).
+JOINT_NAMES = [
+ # index
+ "right_index_mcp_flex",
+ "right_index_pip",
+ "right_index_dip",
+ # middle
+ "right_middle_mcp_flex",
+ "right_middle_pip",
+ "right_middle_dip",
+ # ring
+ "right_ring_mcp_flex",
+ "right_ring_pip",
+ "right_ring_dip",
+ # pinky
+ "right_pinky_mcp_flex",
+ "right_pinky_pip",
+ "right_pinky_dip",
+ # thumb
+ "right_thumb_cmc_abd",
+ "right_thumb_cmc_flex",
+ "right_thumb_mcp",
+ "right_thumb_ip",
+]
+
+# Names of the actuators: one tendon actuator per finger, plus the thumb
+# abduction actuator and two thumb tendon actuators.
+ACTUATOR_NAMES = [
+ # index
+ "right_index_A_tendon",
+ # middle
+ "right_middle_A_tendon",
+ # ring
+ "right_ring_A_tendon",
+ # pinky
+ "right_pinky_A_tendon",
+ # thumb
+ "right_thumb_A_cmc_abd",
+ "right_th1_A_tendon",
+ "right_th2_A_tendon",
+]
+
+# Fingertip name prefixes; the environments read the "<name>_position" sensor
+# for each entry and use the same names to look up fingertip geoms.
+FINGERTIP_NAMES = [
+ "if_tip",
+ "mf_tip",
+ "rf_tip",
+ "pf_tip",
+ "th_tip",
+]
+
+
+# Sensors read as tendon lengths when building the policy observation.
+SENSOR_TENDON_NAMES = [
+ "len_if",
+ "len_mf",
+ "len_rf",
+ "len_pf",
+ "len_th1",
+ "len_th2",
+]
+
+# Sensors read as joint angles when building the policy observation
+# (presumably the thumb abduction joint, going by the name -- confirm against
+# the XML).
+SENSOR_JOINT_NAMES = [
+ "len_th_abd",
+]
diff --git a/mujoco_playground/_src/manipulation/aero_hand/base.py b/mujoco_playground/_src/manipulation/aero_hand/base.py
new file mode 100644
index 000000000..df845d17e
--- /dev/null
+++ b/mujoco_playground/_src/manipulation/aero_hand/base.py
@@ -0,0 +1,129 @@
+# Copyright 2025 TetherIA Inc.
+# Copyright 2025 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Base classes for TetherIA Aero Hand Open."""
+
+from typing import Any, Dict, Optional, Union
+
+from etils import epath
+import jax
+import jax.numpy as jp
+from ml_collections import config_dict
+import mujoco
+from mujoco import mjx
+
+from mujoco_playground._src import mjx_env
+from mujoco_playground._src.manipulation.aero_hand import aero_hand_constants as consts
+
+
+def get_assets() -> Dict[str, bytes]:
+  """Collect the asset files needed to compile the Aero Hand scenes.
+
+  Assets are gathered, in this order, from the menagerie hand's `assets`
+  directory, the local `xmls` directory (with the `*.xml` pattern), the cube
+  texture directory and the local `xmls/assets` directory.
+
+  Returns:
+    The dictionary filled in by `mjx_env.update_assets`.
+  """
+  menagerie_dir = mjx_env.MENAGERIE_PATH / "tetheria_aero_hand_open"
+  xml_dir = consts.ROOT_PATH / "xmls"
+
+  collected = {}
+  mjx_env.update_assets(collected, menagerie_dir / "assets")
+  mjx_env.update_assets(collected, xml_dir, "*.xml")
+  mjx_env.update_assets(collected, xml_dir / "reorientation_cube_textures")
+  mjx_env.update_assets(collected, xml_dir / "assets")
+  return collected
+
+
+class AeroHandEnv(mjx_env.MjxEnv):
+  """Shared functionality for environments built around the Aero Hand."""
+
+  def __init__(
+      self,
+      xml_path: str,
+      config: config_dict.ConfigDict,
+      config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
+  ) -> None:
+    """Compile the scene at `xml_path` and place the model on device.
+
+    Args:
+      xml_path: Path of the MuJoCo scene XML to load.
+      config: Environment configuration; `sim_dt` is used as the physics
+        timestep.
+      config_overrides: Optional overrides forwarded to the parent class.
+    """
+    super().__init__(config, config_overrides)
+    self._xml_path = xml_path
+    self._model_assets = get_assets()
+
+    xml_text = epath.Path(xml_path).read_text()
+    self._mj_model = mujoco.MjModel.from_xml_string(
+        xml_text, assets=self._model_assets
+    )
+    self._mj_model.opt.timestep = self._config.sim_dt
+
+    # Offscreen framebuffer dimensions (3840x2160).
+    global_vis = self._mj_model.vis.global_
+    global_vis.offwidth = 3840
+    global_vis.offheight = 2160
+
+    self._mjx_model = mjx.put_model(self._mj_model)
+
+  # Sensor readings.
+
+  def _read_sensor(self, data: mjx.Data, sensor_name: str) -> jax.Array:
+    """Return the reading of the named sensor from `data`."""
+    return mjx_env.get_sensor_data(self.mj_model, data, sensor_name)
+
+  def get_palm_position(self, data: mjx.Data) -> jax.Array:
+    """Return the `palm_position` sensor reading."""
+    return self._read_sensor(data, "palm_position")
+
+  def get_cube_position(self, data: mjx.Data) -> jax.Array:
+    """Return the `cube_position` sensor reading."""
+    return self._read_sensor(data, "cube_position")
+
+  def get_cube_orientation(self, data: mjx.Data) -> jax.Array:
+    """Return the `cube_orientation` sensor reading."""
+    return self._read_sensor(data, "cube_orientation")
+
+  def get_cube_linvel(self, data: mjx.Data) -> jax.Array:
+    """Return the `cube_linvel` sensor reading."""
+    return self._read_sensor(data, "cube_linvel")
+
+  def get_cube_angvel(self, data: mjx.Data) -> jax.Array:
+    """Return the `cube_angvel` sensor reading."""
+    return self._read_sensor(data, "cube_angvel")
+
+  def get_cube_angacc(self, data: mjx.Data) -> jax.Array:
+    """Return the `cube_angacc` sensor reading."""
+    return self._read_sensor(data, "cube_angacc")
+
+  def get_cube_upvector(self, data: mjx.Data) -> jax.Array:
+    """Return the `cube_upvector` sensor reading."""
+    return self._read_sensor(data, "cube_upvector")
+
+  def get_cube_goal_orientation(self, data: mjx.Data) -> jax.Array:
+    """Return the `cube_goal_orientation` sensor reading."""
+    return self._read_sensor(data, "cube_goal_orientation")
+
+  def get_cube_goal_upvector(self, data: mjx.Data) -> jax.Array:
+    """Return the `cube_goal_upvector` sensor reading."""
+    return self._read_sensor(data, "cube_goal_upvector")
+
+  def get_fingertip_positions(self, data: mjx.Data) -> jax.Array:
+    """Concatenate the `<tip>_position` readings of every fingertip.
+
+    The readings follow the order of `consts.FINGERTIP_NAMES`. They are
+    presumably expressed relative to the grasp site -- confirm against the
+    sensor definitions in the scene XML.
+    """
+    readings = [
+        self._read_sensor(data, f"{tip}_position")
+        for tip in consts.FINGERTIP_NAMES
+    ]
+    return jp.concatenate(readings)
+
+  # Accessors.
+
+  @property
+  def xml_path(self) -> str:
+    """Path of the scene XML this environment was built from."""
+    return self._xml_path
+
+  @property
+  def action_size(self) -> int:
+    """Number of actuators (`nu`) in the MJX model."""
+    return self._mjx_model.nu
+
+  @property
+  def mj_model(self) -> mujoco.MjModel:
+    """The compiled MuJoCo model."""
+    return self._mj_model
+
+  @property
+  def mjx_model(self) -> mjx.Model:
+    """The MJX copy of the model."""
+    return self._mjx_model
+
+
+def uniform_quat(rng: jax.Array) -> jax.Array:
+  """Sample a random unit quaternion.
+
+  Three uniform numbers in [0, 1) are mapped to four components whose squares
+  sum to one (the construction matches Shoemake's uniform-rotation formula).
+
+  Args:
+    rng: PRNG key consumed to draw the three uniform samples.
+
+  Returns:
+    Array of shape (4,) holding the quaternion components.
+  """
+  samples = jax.random.uniform(rng, (3,))
+  u, v, w = samples[0], samples[1], samples[2]
+
+  # Two radii with radius_a**2 + radius_b**2 == 1, each paired with an angle.
+  radius_a = jp.sqrt(1 - u)
+  radius_b = jp.sqrt(u)
+  angle_a = 2 * jp.pi * v
+  angle_b = 2 * jp.pi * w
+
+  return jp.array([
+      radius_a * jp.sin(angle_a),
+      radius_a * jp.cos(angle_a),
+      radius_b * jp.sin(angle_b),
+      radius_b * jp.cos(angle_b),
+  ])
diff --git a/mujoco_playground/_src/manipulation/aero_hand/imgs/index_close_up.png b/mujoco_playground/_src/manipulation/aero_hand/imgs/index_close_up.png
new file mode 100644
index 000000000..ecb1805a1
Binary files /dev/null and b/mujoco_playground/_src/manipulation/aero_hand/imgs/index_close_up.png differ
diff --git a/mujoco_playground/_src/manipulation/aero_hand/imgs/paper.png b/mujoco_playground/_src/manipulation/aero_hand/imgs/paper.png
new file mode 100644
index 000000000..464ce8602
Binary files /dev/null and b/mujoco_playground/_src/manipulation/aero_hand/imgs/paper.png differ
diff --git a/mujoco_playground/_src/manipulation/aero_hand/imgs/rock.png b/mujoco_playground/_src/manipulation/aero_hand/imgs/rock.png
new file mode 100644
index 000000000..39f4350b8
Binary files /dev/null and b/mujoco_playground/_src/manipulation/aero_hand/imgs/rock.png differ
diff --git a/mujoco_playground/_src/manipulation/aero_hand/imgs/scissor.png b/mujoco_playground/_src/manipulation/aero_hand/imgs/scissor.png
new file mode 100644
index 000000000..8c7e8c566
Binary files /dev/null and b/mujoco_playground/_src/manipulation/aero_hand/imgs/scissor.png differ
diff --git a/mujoco_playground/_src/manipulation/aero_hand/imgs/skeleton.png b/mujoco_playground/_src/manipulation/aero_hand/imgs/skeleton.png
new file mode 100644
index 000000000..83c99429e
Binary files /dev/null and b/mujoco_playground/_src/manipulation/aero_hand/imgs/skeleton.png differ
diff --git a/mujoco_playground/_src/manipulation/aero_hand/rotate_z.py b/mujoco_playground/_src/manipulation/aero_hand/rotate_z.py
new file mode 100644
index 000000000..4f99b8982
--- /dev/null
+++ b/mujoco_playground/_src/manipulation/aero_hand/rotate_z.py
@@ -0,0 +1,465 @@
+# Copyright 2025 TetherIA Inc.
+# Copyright 2025 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Rotate-z with TetherIA Aero Hand Open."""
+
+from typing import Any, Dict, Optional, Union
+
+import jax
+import jax.numpy as jp
+from ml_collections import config_dict
+from mujoco import mjx
+import numpy as np
+
+from mujoco_playground._src import mjx_env
+from mujoco_playground._src.manipulation.aero_hand import aero_hand_constants as consts
+from mujoco_playground._src.manipulation.aero_hand import base as aero_hand_base
+
+
+def default_config() -> config_dict.ConfigDict:
+  """Build the default configuration of the cube z-rotation task.
+
+  Returns:
+    A ConfigDict with the timing, action scaling, episode, observation-noise
+    and reward-weight settings.
+  """
+  # Half-widths of the uniform observation noise, per sensor kind.
+  noise_scales = config_dict.create(
+      joint_pos=0.05,
+      tendon_length=0.005,
+  )
+  # Weight of every reward term; terms weighted 0.0 are still computed and
+  # logged but do not affect the reward.
+  reward_scales = config_dict.create(
+      angvel=1.0,
+      linvel=0.0,
+      pose=0.0,
+      torques=0.0,
+      energy=0.0,
+      termination=-100.0,
+      action_rate=-1.0,
+  )
+  return config_dict.create(
+      ctrl_dt=0.05,
+      sim_dt=0.01,
+      # One scale per actuator (7 entries).
+      action_scale=[0.02, 0.02, 0.02, 0.02, 0.7, 0.003, 0.012],
+      action_repeat=1,
+      episode_length=500,
+      early_termination=True,
+      history_len=1,
+      noise_config=config_dict.create(level=1.0, scales=noise_scales),
+      reward_config=config_dict.create(scales=reward_scales),
+  )
+
+
+class CubeRotateZAxis(aero_hand_base.AeroHandEnv):
+  """Rotate a cube around the z-axis as fast as possible without dropping it.
+
+  Actions are offsets, scaled by `config.action_scale`, added to the control
+  values stored in the `home` keyframe. The policy observation (`state`) is a
+  history of noisy tendon-length and joint-angle sensor readings plus the
+  previous action; the value function additionally receives
+  `privileged_state`.
+  """
+
+  def __init__(
+      self,
+      config: Optional[config_dict.ConfigDict] = None,
+      config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
+  ):
+    """Create the environment.
+
+    Args:
+      config: Environment configuration. When omitted, a fresh
+        `default_config()` is built so that instances never share (and cannot
+        mutate) one default ConfigDict.
+      config_overrides: Optional overrides forwarded to the parent class.
+    """
+    if config is None:
+      config = default_config()
+    super().__init__(
+        xml_path=consts.CUBE_XML.as_posix(),
+        config=config,
+        config_overrides=config_overrides,
+    )
+    self._post_init()
+
+  def _post_init(self) -> None:
+    """Cache model indices, the home pose and the observation frame size."""
+    self._hand_qids = mjx_env.get_qpos_ids(self.mj_model, consts.JOINT_NAMES)
+    self._hand_dqids = mjx_env.get_qvel_ids(self.mj_model, consts.JOINT_NAMES)
+    self._cube_qids = mjx_env.get_qpos_ids(self.mj_model, ["cube_freejoint"])
+    self._floor_geom_id = self._mj_model.geom("floor").id
+    self._cube_geom_id = self._mj_model.geom("cube").id
+
+    home_key = self._mj_model.keyframe("home")
+    self._init_q = jp.array(home_key.qpos)
+    self._default_pose = self._init_q[self._hand_qids]
+    # NOTE(review): jnt_range is indexed here with qpos ids; this presumably
+    # relies on the hand joints being the first joints of the model -- confirm.
+    self._lowers, self._uppers = self.mj_model.jnt_range[self._hand_qids].T
+
+    # The home keyframe's controls are the reference the actions offset from.
+    self._init_tendon = jp.array(home_key.ctrl)
+    self._default_tendon = self._init_tendon
+
+    # Size of one policy observation frame: tendon-length sensors, joint-angle
+    # sensors and the previous action.
+    self._state_size = (
+        len(consts.SENSOR_TENDON_NAMES)
+        + len(consts.SENSOR_JOINT_NAMES)
+        + self.mjx_model.nu
+    )
+
+  def reset(self, rng: jax.Array) -> mjx_env.State:
+    """Start a new episode with a perturbed hand pose and a random cube pose.
+
+    Args:
+      rng: PRNG key for the episode.
+
+    Returns:
+      The initial environment state with zero reward and `done`.
+    """
+    # Randomize hand qpos around the home pose, within the joint limits.
+    rng, pos_rng, vel_rng = jax.random.split(rng, 3)
+    q_hand = jp.clip(
+        self._default_pose + 0.1 * jax.random.normal(pos_rng, (consts.NQ,)),
+        self._lowers,
+        self._uppers,
+    )
+    # The 0.0 factor makes the initial hand velocity exactly zero.
+    v_hand = 0.0 * jax.random.normal(vel_rng, (consts.NV,))
+
+    # Randomize cube position (+-1 cm per axis) and orientation; zero velocity.
+    rng, p_rng, quat_rng = jax.random.split(rng, 3)
+    start_pos = jp.array([0.1, 0.0, 0.05]) + jax.random.uniform(
+        p_rng, (3,), minval=-0.01, maxval=0.01
+    )
+    start_quat = aero_hand_base.uniform_quat(quat_rng)
+    q_cube = jp.array([*start_pos, *start_quat])
+    v_cube = jp.zeros(6)
+
+    qpos = jp.concatenate([q_hand, q_cube])
+    qvel = jp.concatenate([v_hand, v_cube])
+    data = mjx_env.make_data(
+        self.mjx_model,
+        qpos=qpos,
+        qvel=qvel,
+        ctrl=self._default_tendon,
+        mocap_pos=jp.array([-100, -100, -100]),  # Hide goal for this task.
+    )
+
+    info = {
+        "rng": rng,
+        "last_act": jp.zeros(self.mjx_model.nu),
+        "last_last_act": jp.zeros(self.mjx_model.nu),
+        "motor_targets": data.ctrl,
+        "last_cube_angvel": jp.zeros(3),
+    }
+
+    metrics = {}
+    for k in self._config.reward_config.scales.keys():
+      metrics[f"reward/{k}"] = jp.zeros(())
+
+    obs_history = jp.zeros(self._config.history_len * self._state_size)
+    obs = self._get_obs(data, info, obs_history)
+    reward, done = jp.zeros(2)  # pylint: disable=redefined-outer-name
+    return mjx_env.State(data, obs, reward, done, metrics, info)
+
+  def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
+    """Apply `action`, advance the simulation and compute reward and `done`.
+
+    Args:
+      state: Current environment state.
+      action: Policy output, one entry per actuator.
+
+    Returns:
+      The updated state.
+    """
+    action_scale_custom = jp.array(self._config.action_scale, dtype=jp.float32)
+    motor_targets = self._default_tendon + action * action_scale_custom
+    # NOTE: no clipping.
+    data = mjx_env.step(
+        self.mjx_model, state.data, motor_targets, self.n_substeps
+    )
+    state.info["motor_targets"] = motor_targets
+
+    obs = self._get_obs(data, state.info, state.obs["state"])
+    done = self._get_termination(data)
+
+    rewards = self._get_reward(data, action, state.info, state.metrics, done)
+    rewards = {
+        k: v * self._config.reward_config.scales[k] for k, v in rewards.items()
+    }
+    reward = sum(rewards.values()) * self.dt  # pylint: disable=redefined-outer-name
+
+    # Bookkeeping happens after the reward so the action-rate cost compares
+    # against the previous step's action.
+    state.info["last_last_act"] = state.info["last_act"]
+    state.info["last_act"] = action
+    state.info["last_cube_angvel"] = self.get_cube_angvel(data)
+    for k, v in rewards.items():
+      state.metrics[f"reward/{k}"] = v
+
+    done = done.astype(reward.dtype)
+    state = state.replace(data=data, obs=obs, reward=reward, done=done)
+    return state
+
+  def _get_termination(self, data: mjx.Data) -> jax.Array:
+    """Return True when the cube's z position is below -0.05."""
+    fall_termination = self.get_cube_position(data)[2] < -0.05
+    return fall_termination
+
+  def _read_scalar_sensors(self, data: mjx.Data, names) -> jax.Array:
+    """Stack the first element of each named sensor into a float32 vector."""
+    return jp.stack([
+        jp.ravel(mjx_env.get_sensor_data(self.mj_model, data, name))[0]
+        for name in names
+    ]).astype(jp.float32)
+
+  def _add_uniform_noise(
+      self, values: jax.Array, key: jax.Array, scale: float
+  ) -> jax.Array:
+    """Add noise drawn from U(-1, 1) * noise level * `scale` to `values`."""
+    return (
+        values
+        + (2 * jax.random.uniform(key, shape=values.shape) - 1)
+        * self._config.noise_config.level
+        * scale
+    )
+
+  def _get_obs(
+      self, data: mjx.Data, info: dict[str, Any], obs_history: jax.Array
+  ) -> Dict[str, jax.Array]:
+    """Build the policy observation history and the privileged observation.
+
+    Args:
+      data: Current simulation data.
+      info: State info dict; `info["rng"]` is advanced in place.
+      obs_history: Flat history of previous policy observation frames.
+
+    Returns:
+      Dict with `state` (updated history, newest frame first) and
+      `privileged_state`.
+    """
+    noise_scales = self._config.noise_config.scales
+
+    # ------- tendon length sensor -------
+    tendon_lengths = self._read_scalar_sensors(data, consts.SENSOR_TENDON_NAMES)
+    info["rng"], noise_rng = jax.random.split(info["rng"])
+    noisy_tendon_lengths = self._add_uniform_noise(
+        tendon_lengths, noise_rng, noise_scales.tendon_length
+    )
+
+    # ------- joint angle sensor -------
+    sensed_joint_angles = self._read_scalar_sensors(
+        data, consts.SENSOR_JOINT_NAMES
+    )
+    info["rng"], noise_rng = jax.random.split(info["rng"])
+    noisy_joint_angles = self._add_uniform_noise(
+        sensed_joint_angles, noise_rng, noise_scales.joint_pos
+    )
+
+    state = jp.concatenate([
+        noisy_tendon_lengths,
+        noisy_joint_angles,
+        info["last_act"],
+    ])
+
+    # Shift the history by one frame and write the newest frame at the front.
+    obs_history = jp.roll(obs_history, state.size)
+    obs_history = obs_history.at[: state.size].set(state)
+
+    cube_pos = self.get_cube_position(data)
+    palm_pos = self.get_palm_position(data)
+    cube_pos_error = palm_pos - cube_pos
+    cube_quat = self.get_cube_orientation(data)
+    cube_angvel = self.get_cube_angvel(data)
+    cube_linvel = self.get_cube_linvel(data)
+    fingertip_positions = self.get_fingertip_positions(data)
+    hand_qpos = data.qpos[self._hand_qids]
+    joint_torques = data.actuator_force
+
+    privileged_state = jp.concatenate([
+        state,
+        hand_qpos,
+        data.qvel[self._hand_dqids],
+        joint_torques,
+        fingertip_positions,
+        cube_pos_error,
+        cube_quat,
+        cube_angvel,
+        cube_linvel,
+    ])
+
+    return {
+        "state": obs_history,
+        "privileged_state": privileged_state,
+    }
+
+  def _get_reward(
+      self,
+      data: mjx.Data,
+      action: jax.Array,
+      info: dict[str, Any],
+      metrics: dict[str, Any],
+      done: jax.Array,
+  ) -> dict[str, jax.Array]:
+    """Compute the unscaled reward terms, keyed like the config's scales."""
+    del metrics  # Unused.
+    cube_pos = self.get_cube_position(data)
+    palm_pos = self.get_palm_position(data)
+    cube_pos_error = palm_pos - cube_pos
+    cube_angvel = self.get_cube_angvel(data)
+    cube_linvel = self.get_cube_linvel(data)
+    return {
+        "angvel": self._reward_angvel(cube_angvel, cube_pos_error),
+        "linvel": self._cost_linvel(cube_linvel),
+        "termination": done,
+        "action_rate": self._cost_action_rate(
+            action, info["last_act"], info["last_last_act"]
+        ),
+        "pose": self._cost_pose(data.qpos[self._hand_qids]),
+        "torques": self._cost_torques(data.actuator_force),
+        "energy": self._cost_energy(
+            data.qvel[self._hand_dqids], data.qfrc_actuator[self._hand_dqids]
+        ),
+    }
+
+  def _cost_torques(self, torques: jax.Array) -> jax.Array:
+    """Sum of squared actuator forces."""
+    return jp.sum(jp.square(torques))
+
+  def _cost_energy(
+      self, qvel: jax.Array, qfrc_actuator: jax.Array
+  ) -> jax.Array:
+    """Sum of |velocity| * |actuator force| over the DoFs passed in."""
+    return jp.sum(jp.abs(qvel) * jp.abs(qfrc_actuator))
+
+  def _cost_linvel(self, cube_linvel: jax.Array) -> jax.Array:
+    """L1 norm of the cube's linear velocity."""
+    return jp.linalg.norm(cube_linvel, ord=1, axis=-1)
+
+  def _reward_angvel(
+      self, cube_angvel: jax.Array, cube_pos_error: jax.Array
+  ) -> jax.Array:
+    """z component of the cube's angular velocity."""
+    # Unconditionally maximize angvel in the z-direction.
+    del cube_pos_error  # Unused.
+    return cube_angvel @ jp.array([0.0, 0.0, 1.0])
+
+  def _cost_action_rate(
+      self, act: jax.Array, last_act: jax.Array, last_last_act: jax.Array
+  ) -> jax.Array:
+    """Sum of squared differences between this action and the previous one."""
+    del last_last_act  # Unused.
+    return jp.sum(jp.square(act - last_act))
+
+  def _cost_pose(self, joint_angles: jax.Array) -> jax.Array:
+    """Sum of squared deviations of the hand joints from the home pose."""
+    return jp.sum(jp.square(joint_angles - self._default_pose))
+
+
+def domain_randomize(model: mjx.Model, rng: jax.Array):
+  """Randomize physical parameters of the model, one sample per PRNG key.
+
+  Randomized quantities: cube and fingertip friction, cube mass, inertia and
+  center-of-mass offset, hand `qpos0`, hand DoF friction loss, armature and
+  damping, hand link masses, and actuator gains.
+
+  Args:
+    model: The MJX model to randomize.
+    rng: A batch of PRNG keys; the sampling function is vmapped over it.
+
+  Returns:
+    A tuple `(model, in_axes)`: the model with the randomized fields replaced
+    by batched arrays, and a matching pytree that is 0 for those fields and
+    None elsewhere.
+  """
+  mj_model = CubeRotateZAxis().mj_model
+  cube_geom_id = mj_model.geom("cube").id
+  cube_body_id = mj_model.body("cube").id
+  hand_qids = mjx_env.get_qpos_ids(mj_model, consts.JOINT_NAMES)
+  # dof_* fields are indexed by DoF (qvel) address, not by qpos address.
+  hand_dof_ids = mjx_env.get_qvel_ids(mj_model, consts.JOINT_NAMES)
+  hand_body_names = [
+      "palm",
+      "right_index_f_link",
+      "right_index_proximal_link",
+      "right_index_middle_link",
+      "right_index_distal_link",
+      "right_middle_f_link",
+      "right_middle_proximal_link",
+      "right_middle_middle_link",
+      "right_middle_distal_link",
+      "right_ring_f_link",
+      "right_ring_proximal_link",
+      "right_ring_middle_link",
+      "right_ring_distal_link",
+      "right_pinky_f_link",
+      "right_pinky_proximal_link",
+      "right_pinky_middle_link",
+      "right_pinky_distal_link",
+      "right_t_link",
+      "right_thumb_mcp_link",
+      "right_thumb_proximal_link",
+      "right_thumb_distal_link",
+  ]
+  hand_body_ids = np.array([mj_model.body(n).id for n in hand_body_names])
+  fingertip_geoms = ["if_tip", "mf_tip", "rf_tip", "pf_tip", "th_tip"]
+  fingertip_geom_ids = np.array([mj_model.geom(g).id for g in fingertip_geoms])
+
+  @jax.vmap
+  def rand(rng):
+    # Cube friction: =U(0.1, 0.5).
+    rng, key = jax.random.split(rng)
+    cube_friction = jax.random.uniform(key, (1,), minval=0.1, maxval=0.5)
+    geom_friction = model.geom_friction.at[
+        cube_geom_id : cube_geom_id + 1, 0
+    ].set(cube_friction)
+
+    # Fingertip friction: =U(0.5, 1.0), one shared value for all fingertips.
+    # Uses its own key and builds on the array above so the cube friction
+    # sample is kept.
+    rng, key = jax.random.split(rng)
+    fingertip_friction = jax.random.uniform(key, (1,), minval=0.5, maxval=1.0)
+    geom_friction = geom_friction.at[fingertip_geom_ids, 0].set(
+        fingertip_friction
+    )
+
+    # Scale cube mass and inertia: *U(0.8, 1.2).
+    rng, key1, key2 = jax.random.split(rng, 3)
+    cube_scale = jax.random.uniform(key1, minval=0.8, maxval=1.2)
+    body_mass = model.body_mass.at[cube_body_id].set(
+        model.body_mass[cube_body_id] * cube_scale
+    )
+    body_inertia = model.body_inertia.at[cube_body_id].set(
+        model.body_inertia[cube_body_id] * cube_scale
+    )
+    # Jitter the cube's center of mass: +U(-5e-3, 5e-3) per axis.
+    dpos = jax.random.uniform(key2, (3,), minval=-5e-3, maxval=5e-3)
+    body_ipos = model.body_ipos.at[cube_body_id].set(
+        model.body_ipos[cube_body_id] + dpos
+    )
+
+    # Jitter qpos0: +U(-0.05, 0.05).
+    rng, key = jax.random.split(rng)
+    qpos0 = model.qpos0
+    qpos0 = qpos0.at[hand_qids].set(
+        qpos0[hand_qids]
+        + jax.random.uniform(
+            key, shape=(len(hand_qids),), minval=-0.05, maxval=0.05
+        )
+    )
+
+    # Scale static friction: *U(0.5, 2.0).
+    rng, key = jax.random.split(rng)
+    frictionloss = model.dof_frictionloss[hand_dof_ids] * jax.random.uniform(
+        key, shape=(len(hand_dof_ids),), minval=0.5, maxval=2.0
+    )
+    dof_frictionloss = model.dof_frictionloss.at[hand_dof_ids].set(
+        frictionloss
+    )
+
+    # Scale armature: *U(1.0, 1.05).
+    rng, key = jax.random.split(rng)
+    armature = model.dof_armature[hand_dof_ids] * jax.random.uniform(
+        key, shape=(len(hand_dof_ids),), minval=1.0, maxval=1.05
+    )
+    dof_armature = model.dof_armature.at[hand_dof_ids].set(armature)
+
+    # Scale all link masses: *U(0.9, 1.1). Builds on `body_mass` so the cube
+    # mass scaling above is kept.
+    rng, key = jax.random.split(rng)
+    link_scale = jax.random.uniform(
+        key, shape=(len(hand_body_ids),), minval=0.9, maxval=1.1
+    )
+    body_mass = body_mass.at[hand_body_ids].set(
+        body_mass[hand_body_ids] * link_scale
+    )
+
+    # Actuator stiffness: *U(0.8, 1.2). Writing -kp into biasprm[:, 1]
+    # presumably assumes position-style actuators -- confirm against the XML.
+    rng, key = jax.random.split(rng)
+    kp = model.actuator_gainprm[:, 0] * jax.random.uniform(
+        key, (model.nu,), minval=0.8, maxval=1.2
+    )
+    actuator_gainprm = model.actuator_gainprm.at[:, 0].set(kp)
+    actuator_biasprm = model.actuator_biasprm.at[:, 1].set(-kp)
+
+    # Joint damping: *U(0.8, 1.2).
+    rng, key = jax.random.split(rng)
+    kd = model.dof_damping[hand_dof_ids] * jax.random.uniform(
+        key, (len(hand_dof_ids),), minval=0.8, maxval=1.2
+    )
+    dof_damping = model.dof_damping.at[hand_dof_ids].set(kd)
+
+    return {
+        "geom_friction": geom_friction,
+        "body_mass": body_mass,
+        "body_inertia": body_inertia,
+        "body_ipos": body_ipos,
+        "qpos0": qpos0,
+        "dof_frictionloss": dof_frictionloss,
+        "dof_armature": dof_armature,
+        "dof_damping": dof_damping,
+        "actuator_gainprm": actuator_gainprm,
+        "actuator_biasprm": actuator_biasprm,
+    }
+
+  randomized = rand(rng)
+
+  # Mark exactly the randomized fields as batched along axis 0.
+  in_axes = jax.tree_util.tree_map(lambda x: None, model)
+  in_axes = in_axes.tree_replace({name: 0 for name in randomized})
+
+  model = model.tree_replace(randomized)
+
+  return model, in_axes
diff --git a/mujoco_playground/_src/manipulation/aero_hand/xmls/reorientation_cube.xml b/mujoco_playground/_src/manipulation/aero_hand/xmls/reorientation_cube.xml
new file mode 100644
index 000000000..3496382cf
--- /dev/null
+++ b/mujoco_playground/_src/manipulation/aero_hand/xmls/reorientation_cube.xml
@@ -0,0 +1,22 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/mujoco_playground/_src/manipulation/aero_hand/xmls/reorientation_cube_textures/dex_cube.png b/mujoco_playground/_src/manipulation/aero_hand/xmls/reorientation_cube_textures/dex_cube.png
new file mode 100644
index 000000000..3fa3f834f
Binary files /dev/null and b/mujoco_playground/_src/manipulation/aero_hand/xmls/reorientation_cube_textures/dex_cube.png differ
diff --git a/mujoco_playground/_src/manipulation/aero_hand/xmls/right_hand.xml b/mujoco_playground/_src/manipulation/aero_hand/xmls/right_hand.xml
new file mode 100644
index 000000000..4a71bd7d2
--- /dev/null
+++ b/mujoco_playground/_src/manipulation/aero_hand/xmls/right_hand.xml
@@ -0,0 +1,751 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/mujoco_playground/_src/manipulation/aero_hand/xmls/scene_mjx_cube.xml b/mujoco_playground/_src/manipulation/aero_hand/xmls/scene_mjx_cube.xml
new file mode 100644
index 000000000..91bd6bbf1
--- /dev/null
+++ b/mujoco_playground/_src/manipulation/aero_hand/xmls/scene_mjx_cube.xml
@@ -0,0 +1,64 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/mujoco_playground/_src/mjx_env.py b/mujoco_playground/_src/mjx_env.py
index b17413211..874c41324 100644
--- a/mujoco_playground/_src/mjx_env.py
+++ b/mujoco_playground/_src/mjx_env.py
@@ -37,7 +37,7 @@
# Resource paths do not have glob implemented, so we use a bare epath.Path.
MENAGERIE_PATH = EXTERNAL_DEPS_PATH / "mujoco_menagerie"
# Commit SHA of the menagerie repo.
-MENAGERIE_COMMIT_SHA = "14ceccf557cc47240202f2354d684eca58ff8de4"
+MENAGERIE_COMMIT_SHA = "1b86ece576591213e2b666ebf59508454200ca97"
def _clone_with_progress(
diff --git a/mujoco_playground/config/manipulation_params.py b/mujoco_playground/config/manipulation_params.py
index a8a72386e..9a84e476c 100644
--- a/mujoco_playground/config/manipulation_params.py
+++ b/mujoco_playground/config/manipulation_params.py
@@ -166,6 +166,24 @@ def brax_ppo_config(
value_obs_key="privileged_state",
)
rl_config.num_resets_per_eval = 1
+ elif env_name == "AeroCubeRotateZAxis":
+ rl_config.num_timesteps = 300_000_000
+ rl_config.num_evals = 10
+ rl_config.num_minibatches = 32
+ rl_config.unroll_length = 40
+ rl_config.num_updates_per_batch = 4
+ rl_config.discounting = 0.97
+ rl_config.learning_rate = 3e-4
+ rl_config.entropy_cost = 1e-2
+ rl_config.num_envs = 8192
+ rl_config.batch_size = 256
+ rl_config.num_resets_per_eval = 1
+ rl_config.network_factory = config_dict.create(
+ policy_hidden_layer_sizes=(512, 256, 128),
+ value_hidden_layer_sizes=(512, 256, 128),
+ policy_obs_key="state",
+ value_obs_key="privileged_state",
+ )
else:
raise ValueError(f"Unsupported env: {env_name}")