|
19 | 19 |
|
20 | 20 | import jax |
21 | 21 | import jax.numpy as jp |
22 | | -import mujoco |
23 | 22 | from ml_collections import config_dict |
| 23 | +import mujoco |
24 | 24 | from mujoco import mjx |
25 | | -from mujoco_playground._src import mjx_env, reward |
| 25 | + |
| 26 | +from mujoco_playground._src import mjx_env |
| 27 | +from mujoco_playground._src import reward |
26 | 28 | from mujoco_playground._src.dm_control_suite import common |
27 | 29 |
|
28 | 30 | _XML_PATH = mjx_env.ROOT_PATH / "dm_control_suite" / "xmls" / "quadruped.xml" |
|
32 | 34 |
|
33 | 35 |
|
34 | 36 | def default_config() -> config_dict.ConfigDict: |
35 | | - return config_dict.create( |
36 | | - ctrl_dt=0.02, |
37 | | - sim_dt=0.005, |
38 | | - episode_length=1000, |
39 | | - action_repeat=1, |
40 | | - vision=False, |
41 | | - ) |
42 | | - |
43 | | - |
44 | | -def _find_non_contacting_height(mjx_model, data, orientation, x_pos=0.0, y_pos=0.0): |
45 | | - def body_fn(state): |
46 | | - z_pos, num_contacts, num_attempts, _ = state |
47 | | - qpos = data.qpos.at[:3].set(jp.array([x_pos, y_pos, z_pos])) |
48 | | - qpos = qpos.at[3:7].set(jp.array(orientation)) |
49 | | - ndata = data.replace(qpos=qpos) |
50 | | - ndata = mjx.forward(mjx_model, ndata) |
51 | | - num_contacts = ndata.ncon |
52 | | - z_pos += 0.01 |
53 | | - num_attempts += 1 |
54 | | - return (z_pos, num_contacts, num_attempts, ndata) |
55 | | - |
56 | | - initial_state = (0.0, 1, 0, data) # (z_pos, num_contacts, num_attempts) |
57 | | - *_, num_attemps, ndata = jax.lax.while_loop( |
58 | | - lambda state: jp.greater(state[1], 0) & jp.less_equal(state[2], 10000), |
59 | | - body_fn, |
60 | | - initial_state, |
61 | | - ) |
62 | | - ndata = jax.tree_map( |
63 | | - lambda x, y: jp.where(jp.less(num_attemps, 10000), x, y), ndata, data |
64 | | - ) |
65 | | - return ndata |
| 37 | + return config_dict.create( |
| 38 | + ctrl_dt=0.02, |
| 39 | + sim_dt=0.005, |
| 40 | + episode_length=1000, |
| 41 | + action_repeat=1, |
| 42 | + vision=False, |
| 43 | + ) |
| 44 | + |
| 45 | + |
| 46 | +def _find_non_contacting_height( |
| 47 | + mjx_model, data, orientation, x_pos=0.0, y_pos=0.0 |
| 48 | +): |
| 49 | + def body_fn(state): |
| 50 | + z_pos, num_contacts, num_attempts, _ = state |
| 51 | + qpos = data.qpos.at[:3].set(jp.array([x_pos, y_pos, z_pos])) |
| 52 | + qpos = qpos.at[3:7].set(jp.array(orientation)) |
| 53 | + ndata = data.replace(qpos=qpos) |
| 54 | + ndata = mjx.forward(mjx_model, ndata) |
| 55 | + num_contacts = ndata.ncon |
| 56 | + z_pos += 0.01 |
| 57 | + num_attempts += 1 |
| 58 | + return (z_pos, num_contacts, num_attempts, ndata) |
| 59 | + |
| 60 | + initial_state = (0.0, 1, 0, data) # (z_pos, num_contacts, num_attempts) |
| 61 | + *_, num_attemps, ndata = jax.lax.while_loop( |
| 62 | + lambda state: jp.greater(state[1], 0) & jp.less_equal(state[2], 10000), |
| 63 | + body_fn, |
| 64 | + initial_state, |
| 65 | + ) |
| 66 | + ndata = jax.tree_map( |
| 67 | + lambda x, y: jp.where(jp.less(num_attemps, 10000), x, y), ndata, data |
| 68 | + ) |
| 69 | + return ndata |
66 | 70 |
|
67 | 71 |
|
68 | 72 | class Quadruped(mjx_env.MjxEnv): |
69 | | - """Quadruped environment.""" |
70 | | - |
71 | | - def __init__( |
72 | | - self, |
73 | | - desired_speed: float, |
74 | | - config: config_dict.ConfigDict = default_config(), |
75 | | - config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None, |
76 | | - ): |
77 | | - super().__init__(config, config_overrides) |
78 | | - if self._config.vision: |
79 | | - raise NotImplementedError("Vision not implemented for Quadruped.") |
80 | | - self._desired_speed = desired_speed |
81 | | - self._xml_path = _XML_PATH.as_posix() |
82 | | - self._mj_model = mujoco.MjModel.from_xml_string( |
83 | | - _XML_PATH.read_text(), common.get_assets() |
84 | | - ) |
85 | | - self._mj_model.opt.timestep = self.sim_dt |
86 | | - self._mjx_model = mjx.put_model(self._mj_model) |
87 | | - self._post_init() |
88 | | - |
89 | | - def _post_init(self): |
90 | | - self._force_torque_names = [ |
91 | | - f"{f}_toe_{pos}_{side}" |
92 | | - for (f, pos, side) in product( |
93 | | - ("force", "torque"), ("front", "back"), ("left", "right") |
94 | | - ) |
95 | | - ] |
96 | | - self._torso_id = self._mj_model.body("torso").id |
97 | | - |
98 | | - def reset(self, rng: jax.Array) -> mjx_env.State: |
99 | | - data = mjx_env.init(self.mjx_model) |
100 | | - metrics = {"reward/upright": jp.zeros(()), "reward/move": jp.zeros(())} |
101 | | - info = {"rng": rng} |
102 | | - reward, done = jp.zeros(2) |
103 | | - obs = self._get_obs(data, info) |
104 | | - return mjx_env.State(data, obs, reward, done, metrics, info) |
105 | | - |
106 | | - def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State: |
107 | | - lower, upper = ( |
108 | | - self._mj_model.actuator_ctrlrange[:, 0], |
109 | | - self._mj_model.actuator_ctrlrange[:, 1], |
110 | | - ) |
111 | | - action = (action + 1.0) / 2.0 * (upper - lower) + lower |
112 | | - data = mjx_env.step(self.mjx_model, state.data, action, self.n_substeps) |
113 | | - reward = self._get_reward(data, action, state.info, state.metrics) |
114 | | - obs = self._get_obs(data, state.info) |
115 | | - done = jp.isnan(data.qpos).any() | jp.isnan(data.qvel).any() |
116 | | - done = done.astype(float) |
117 | | - return mjx_env.State(data, obs, reward, done, state.metrics, state.info) |
118 | | - |
119 | | - def _get_obs(self, data: mjx.Data, info: dict[str, Any]) -> jax.Array: |
120 | | - del info |
121 | | - ego = self._egocentric_state(data) |
122 | | - torso_vel = self.torso_velocity(data) |
123 | | - upright = self.torso_upright(data) |
124 | | - imu = self.imu(data) |
125 | | - force_torque = self.force_torque(data) |
126 | | - return jp.hstack((ego, torso_vel, upright, imu, force_torque)) |
127 | | - |
128 | | - def _get_reward( |
129 | | - self, |
130 | | - data: mjx.Data, |
131 | | - action: jax.Array, |
132 | | - info: dict[str, Any], |
133 | | - metrics: dict[str, Any], |
134 | | - ) -> jax.Array: |
135 | | - del info, action |
136 | | - move_reward = reward.tolerance( |
137 | | - self.torso_velocity(data)[0], |
138 | | - bounds=(self._desired_speed, float("inf")), |
139 | | - sigmoid="linear", |
140 | | - margin=self._desired_speed, |
141 | | - value_at_margin=0.5, |
142 | | - ) |
143 | | - upright_reward = self._upright_reward(data) |
144 | | - metrics["reward/move"] = move_reward |
145 | | - metrics["reward/upright"] = upright_reward |
146 | | - return move_reward * upright_reward |
147 | | - |
148 | | - def _upright_reward(self, data: mjx.Data) -> jax.Array: |
149 | | - upright = self.torso_upright(data) |
150 | | - return reward.tolerance( |
151 | | - upright, |
152 | | - bounds=(1, float("inf")), |
153 | | - sigmoid="linear", |
154 | | - margin=2, |
155 | | - value_at_margin=0, |
| 73 | + """Quadruped environment.""" |
| 74 | + |
| 75 | + def __init__( |
| 76 | + self, |
| 77 | + desired_speed: float, |
| 78 | + config: config_dict.ConfigDict = default_config(), |
| 79 | + config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None, |
| 80 | + ): |
| 81 | + super().__init__(config, config_overrides) |
| 82 | + if self._config.vision: |
| 83 | + raise NotImplementedError("Vision not implemented for Quadruped.") |
| 84 | + self._desired_speed = desired_speed |
| 85 | + self._xml_path = _XML_PATH.as_posix() |
| 86 | + self._mj_model = mujoco.MjModel.from_xml_string( |
| 87 | + _XML_PATH.read_text(), common.get_assets() |
| 88 | + ) |
| 89 | + self._mj_model.opt.timestep = self.sim_dt |
| 90 | + self._mjx_model = mjx.put_model(self._mj_model) |
| 91 | + self._post_init() |
| 92 | + |
| 93 | + def _post_init(self): |
| 94 | + self._force_torque_names = [ |
| 95 | + f"{f}_toe_{pos}_{side}" |
| 96 | + for (f, pos, side) in product( |
| 97 | + ("force", "torque"), ("front", "back"), ("left", "right") |
156 | 98 | ) |
| 99 | + ] |
| 100 | + self._torso_id = self._mj_model.body("torso").id |
| 101 | + |
| 102 | + def reset(self, rng: jax.Array) -> mjx_env.State: |
| 103 | + data = mjx_env.init(self.mjx_model) |
| 104 | + metrics = {"reward/upright": jp.zeros(()), "reward/move": jp.zeros(())} |
| 105 | + info = {"rng": rng} |
| 106 | + reward, done = jp.zeros(2) |
| 107 | + obs = self._get_obs(data, info) |
| 108 | + return mjx_env.State(data, obs, reward, done, metrics, info) |
| 109 | + |
| 110 | + def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State: |
| 111 | + lower, upper = ( |
| 112 | + self._mj_model.actuator_ctrlrange[:, 0], |
| 113 | + self._mj_model.actuator_ctrlrange[:, 1], |
| 114 | + ) |
| 115 | + action = (action + 1.0) / 2.0 * (upper - lower) + lower |
| 116 | + data = mjx_env.step(self.mjx_model, state.data, action, self.n_substeps) |
| 117 | + reward = self._get_reward(data, action, state.info, state.metrics) |
| 118 | + obs = self._get_obs(data, state.info) |
| 119 | + done = jp.isnan(data.qpos).any() | jp.isnan(data.qvel).any() |
| 120 | + done = done.astype(float) |
| 121 | + return mjx_env.State(data, obs, reward, done, state.metrics, state.info) |
| 122 | + |
| 123 | + def _get_obs(self, data: mjx.Data, info: dict[str, Any]) -> jax.Array: |
| 124 | + del info |
| 125 | + ego = self._egocentric_state(data) |
| 126 | + torso_vel = self.torso_velocity(data) |
| 127 | + upright = self.torso_upright(data) |
| 128 | + imu = self.imu(data) |
| 129 | + force_torque = self.force_torque(data) |
| 130 | + return jp.hstack((ego, torso_vel, upright, imu, force_torque)) |
| 131 | + |
| 132 | + def _get_reward( |
| 133 | + self, |
| 134 | + data: mjx.Data, |
| 135 | + action: jax.Array, |
| 136 | + info: dict[str, Any], |
| 137 | + metrics: dict[str, Any], |
| 138 | + ) -> jax.Array: |
| 139 | + del info, action |
| 140 | + move_reward = reward.tolerance( |
| 141 | + self.torso_velocity(data)[0], |
| 142 | + bounds=(self._desired_speed, float("inf")), |
| 143 | + sigmoid="linear", |
| 144 | + margin=self._desired_speed, |
| 145 | + value_at_margin=0.5, |
| 146 | + ) |
| 147 | + upright_reward = self._upright_reward(data) |
| 148 | + metrics["reward/move"] = move_reward |
| 149 | + metrics["reward/upright"] = upright_reward |
| 150 | + return move_reward * upright_reward |
| 151 | + |
| 152 | + def _upright_reward(self, data: mjx.Data) -> jax.Array: |
| 153 | + upright = self.torso_upright(data) |
| 154 | + return reward.tolerance( |
| 155 | + upright, |
| 156 | + bounds=(1, float("inf")), |
| 157 | + sigmoid="linear", |
| 158 | + margin=2, |
| 159 | + value_at_margin=0, |
| 160 | + ) |
157 | 161 |
|
158 | | - def _egocentric_state(self, data: mjx.Data) -> jax.Array: |
159 | | - return jp.hstack((data.qpos[7:], data.qvel[7:], data.act)) |
| 162 | + def _egocentric_state(self, data: mjx.Data) -> jax.Array: |
| 163 | + return jp.hstack((data.qpos[7:], data.qvel[7:], data.act)) |
160 | 164 |
|
161 | | - def torso_upright(self, data: mjx.Data) -> jax.Array: |
162 | | - return data.xmat[self._torso_id, 2, 2] |
| 165 | + def torso_upright(self, data: mjx.Data) -> jax.Array: |
| 166 | + return data.xmat[self._torso_id, 2, 2] |
163 | 167 |
|
164 | | - def torso_velocity(self, data: mjx.Data) -> jax.Array: |
165 | | - return mjx_env.get_sensor_data(self.mj_model, data, "velocimeter") |
| 168 | + def torso_velocity(self, data: mjx.Data) -> jax.Array: |
| 169 | + return mjx_env.get_sensor_data(self.mj_model, data, "velocimeter") |
166 | 170 |
|
167 | | - def imu(self, data: mjx.Data) -> jax.Array: |
168 | | - gyro = mjx_env.get_sensor_data(self.mj_model, data, "imu_gyro") |
169 | | - accelerometer = mjx_env.get_sensor_data(self.mj_model, data, "imu_accel") |
170 | | - return jp.hstack((gyro, accelerometer)) |
| 171 | + def imu(self, data: mjx.Data) -> jax.Array: |
| 172 | + gyro = mjx_env.get_sensor_data(self.mj_model, data, "imu_gyro") |
| 173 | + accelerometer = mjx_env.get_sensor_data(self.mj_model, data, "imu_accel") |
| 174 | + return jp.hstack((gyro, accelerometer)) |
171 | 175 |
|
172 | | - def force_torque(self, data: mjx.Data) -> jax.Array: |
173 | | - return jp.hstack( |
174 | | - [ |
175 | | - mjx_env.get_sensor_data(self.mj_model, data, name) |
176 | | - for name in self._force_torque_names |
177 | | - ] |
178 | | - ) |
| 176 | + def force_torque(self, data: mjx.Data) -> jax.Array: |
| 177 | + return jp.hstack([ |
| 178 | + mjx_env.get_sensor_data(self.mj_model, data, name) |
| 179 | + for name in self._force_torque_names |
| 180 | + ]) |
179 | 181 |
|
180 | | - @property |
181 | | - def xml_path(self) -> str: |
182 | | - return self._xml_path |
| 182 | + @property |
| 183 | + def xml_path(self) -> str: |
| 184 | + return self._xml_path |
183 | 185 |
|
184 | | - @property |
185 | | - def action_size(self) -> int: |
186 | | - return self.mjx_model.nu |
| 186 | + @property |
| 187 | + def action_size(self) -> int: |
| 188 | + return self.mjx_model.nu |
187 | 189 |
|
188 | | - @property |
189 | | - def mj_model(self) -> mujoco.MjModel: |
190 | | - return self._mj_model |
| 190 | + @property |
| 191 | + def mj_model(self) -> mujoco.MjModel: |
| 192 | + return self._mj_model |
191 | 193 |
|
192 | | - @property |
193 | | - def mjx_model(self) -> mjx.Model: |
194 | | - return self._mjx_model |
| 194 | + @property |
| 195 | + def mjx_model(self) -> mjx.Model: |
| 196 | + return self._mjx_model |
0 commit comments