Skip to content

Commit 4424fa1

Browse files
kevinzakkacopybara-github
authored andcommitted
Replace hardcoded observation size with jax.eval_shape.
PiperOrigin-RevId: 720599097 Change-Id: I71b7914417b389155c716b08c08cf5781adfdfe5
1 parent 5305951 commit 4424fa1

37 files changed

+3
-161
lines changed

mujoco_playground/_src/dm_control_suite/acrobot.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,6 @@ def xml_path(self) -> str:
142142
def action_size(self) -> int:
143143
return self.mjx_model.nu
144144

145-
@property
146-
def observation_size(self) -> mjx_env.ObservationSize:
147-
return 6
148-
149145
@property
150146
def mj_model(self) -> mujoco.MjModel:
151147
return self._mj_model

mujoco_playground/_src/dm_control_suite/ball_in_cup.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,6 @@ def xml_path(self) -> str:
116116
def action_size(self) -> int:
117117
return self.mjx_model.nu
118118

119-
@property
120-
def observation_size(self) -> mjx_env.ObservationSize:
121-
return 8
122-
123119
@property
124120
def mj_model(self) -> mujoco.MjModel:
125121
return self._mj_model

mujoco_playground/_src/dm_control_suite/cartpole.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -278,10 +278,6 @@ def xml_path(self) -> str:
278278
def action_size(self) -> int:
279279
return self.mjx_model.nu
280280

281-
@property
282-
def observation_size(self) -> mjx_env.ObservationSize:
283-
return 5
284-
285281
@property
286282
def mj_model(self) -> mujoco.MjModel:
287283
return self._mj_model

mujoco_playground/_src/dm_control_suite/cheetah.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,6 @@ def xml_path(self) -> str:
135135
def action_size(self) -> int:
136136
return self.mjx_model.nu
137137

138-
@property
139-
def observation_size(self) -> mjx_env.ObservationSize:
140-
return 17
141-
142138
@property
143139
def mj_model(self) -> mujoco.MjModel:
144140
return self._mj_model

mujoco_playground/_src/dm_control_suite/dm_control_suite_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_can_create_all_environments(self, env_name: str) -> None:
3333
state = jax.jit(env.reset)(jax.random.PRNGKey(42))
3434
state = jax.jit(env.step)(state, jp.zeros(env.action_size))
3535
self.assertIsNotNone(state)
36-
self.assertEqual(state.obs.shape[0], env.observation_size)
36+
self.assertEqual(state.obs.shape, env.observation_size)
3737
self.assertFalse(jp.isnan(state.data.qpos).any())
3838

3939

mujoco_playground/_src/dm_control_suite/finger.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,6 @@ def sim_dt(self) -> float:
188188
def action_size(self) -> int:
189189
return self.mjx_model.nu
190190

191-
@property
192-
def observation_size(self) -> mjx_env.ObservationSize:
193-
return 9
194-
195191
@property
196192
def mj_model(self) -> mujoco.MjModel:
197193
return self._mj_model
@@ -344,10 +340,6 @@ def xml_path(self) -> str:
344340
def action_size(self) -> int:
345341
return self.mjx_model.nu
346342

347-
@property
348-
def observation_size(self) -> mjx_env.ObservationSize:
349-
return 12
350-
351343
@property
352344
def mj_model(self) -> mujoco.MjModel:
353345
return self._mj_model

mujoco_playground/_src/dm_control_suite/fish.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,6 @@ def xml_path(self) -> str:
187187
def action_size(self) -> int:
188188
return self.mjx_model.nu
189189

190-
@property
191-
def observation_size(self) -> mjx_env.ObservationSize:
192-
return 24
193-
194190
@property
195191
def mj_model(self) -> mujoco.MjModel:
196192
return self._mj_model

mujoco_playground/_src/dm_control_suite/hopper.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,6 @@ def xml_path(self) -> str:
194194
def action_size(self) -> int:
195195
return self.mjx_model.nu
196196

197-
@property
198-
def observation_size(self) -> mjx_env.ObservationSize:
199-
return 15
200-
201197
@property
202198
def mj_model(self) -> mujoco.MjModel:
203199
return self._mj_model

mujoco_playground/_src/dm_control_suite/humanoid.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,10 +215,6 @@ def xml_path(self) -> str:
215215
def action_size(self) -> int:
216216
return self.mjx_model.nu
217217

218-
@property
219-
def observation_size(self) -> mjx_env.ObservationSize:
220-
return 67
221-
222218
@property
223219
def mj_model(self) -> mujoco.MjModel:
224220
return self._mj_model

mujoco_playground/_src/dm_control_suite/pendulum.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,6 @@ def xml_path(self) -> str:
132132
def action_size(self) -> int:
133133
return self.mjx_model.nu
134134

135-
@property
136-
def observation_size(self) -> mjx_env.ObservationSize:
137-
return 3
138-
139135
@property
140136
def mj_model(self) -> mujoco.MjModel:
141137
return self._mj_model

0 commit comments

Comments
 (0)