Skip to content

Commit 160ad41

Browse files
Merge pull request #183 from google-deepmind/B7D4C144F4B8D76FD563D63AE71881C5
Split MJX tests from MuJoCo tests in Menagerie
2 parents f347540 + c6d9721 commit 160ad41

File tree

2 files changed

+84
-41
lines changed

2 files changed

+84
-41
lines changed

test/mjx_model_test.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2022 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for all MJX models."""
15+
16+
import pathlib
17+
from typing import List
18+
19+
from absl.testing import absltest
20+
from absl.testing import parameterized
21+
import jax
22+
import jax.numpy as jp
23+
import mujoco
24+
from mujoco import mjx
25+
26+
# Internal import.
27+
28+
29+
_ROOT_DIR = pathlib.Path(__file__).parent.parent
30+
_MODEL_DIRS = [f for f in _ROOT_DIR.iterdir() if f.is_dir()]
31+
_MJX_MODEL_XMLS: List[pathlib.Path] = []
32+
33+
34+
def _get_xmls(pattern: str) -> List[pathlib.Path]:
35+
for d in _MODEL_DIRS:
36+
# Produce tuples of test name and XML path.
37+
for f in d.glob(pattern):
38+
test_name = str(f).removeprefix(str(f.parent.parent))
39+
yield (test_name, f)
40+
41+
_MJX_MODEL_XMLS = list(_get_xmls('scene*mjx.xml'))
42+
43+
# Total simulation duration, in seconds.
44+
_MAX_SIM_TIME = 0.1
45+
46+
47+
class MjxModelsTest(parameterized.TestCase):
48+
"""Tests that MJX models load and do not return NaNs."""
49+
50+
@parameterized.named_parameters(_MJX_MODEL_XMLS)
51+
def test_compiles_and_steps(self, xml_path: pathlib.Path) -> None:
52+
model = mujoco.MjModel.from_xml_path(str(xml_path))
53+
model = mjx.put_model(model)
54+
data = mjx.make_data(model)
55+
ctrlrange = jp.where(
56+
model.actuator_ctrllimited[:, None],
57+
model.actuator_ctrlrange,
58+
jp.array([-10.0, 10.0]),
59+
)
60+
61+
def step(x, _):
62+
data, rng = x
63+
rng, key = jax.random.split(rng)
64+
ctrl = jax.random.uniform(
65+
key,
66+
shape=(model.nu,),
67+
minval=ctrlrange[:, 0],
68+
maxval=ctrlrange[:, 1],
69+
)
70+
data = mjx.step(model, data.replace(ctrl=ctrl))
71+
return (data, rng), ()
72+
73+
(data, _), _ = jax.lax.scan(
74+
step,
75+
(data, jax.random.PRNGKey(0)),
76+
(),
77+
length=min(_MAX_SIM_TIME // model.opt.timestep, 100),
78+
)
79+
80+
self.assertFalse(jp.isnan(data.qpos).any())
81+
82+
83+
if __name__ == '__main__':
84+
absltest.main()

test/model_test.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,14 @@
1818

1919
from absl.testing import absltest
2020
from absl.testing import parameterized
21-
import jax
22-
import jax.numpy as jp
2321
import mujoco
24-
from mujoco import mjx
2522

2623
# Internal import.
2724

2825

2926
_ROOT_DIR = pathlib.Path(__file__).parent.parent
3027
_MODEL_DIRS = [f for f in _ROOT_DIR.iterdir() if f.is_dir()]
3128
_MODEL_XMLS: List[pathlib.Path] = []
32-
_MJX_MODEL_XMLS: List[pathlib.Path] = []
3329

3430

3531
def _get_xmls(pattern: str) -> List[pathlib.Path]:
@@ -40,7 +36,6 @@ def _get_xmls(pattern: str) -> List[pathlib.Path]:
4036
yield (test_name, f)
4137

4238
_MODEL_XMLS = list(_get_xmls('scene*.xml'))
43-
_MJX_MODEL_XMLS = list(_get_xmls('scene*mjx.xml'))
4439

4540
# Total simulation duration, in seconds.
4641
_MAX_SIM_TIME = 0.1
@@ -86,41 +81,5 @@ def test_compiles_and_steps(self, xml_path: pathlib.Path) -> None:
8681
self.fail(f'MuJoCo warning(s) encountered:\n{warning_info}')
8782

8883

89-
class MjxModelsTest(parameterized.TestCase):
90-
"""Tests that MJX models load and do not return NaNs."""
91-
92-
@parameterized.named_parameters(_MJX_MODEL_XMLS)
93-
def test_compiles_and_steps(self, xml_path: pathlib.Path) -> None:
94-
model = mujoco.MjModel.from_xml_path(str(xml_path))
95-
model = mjx.put_model(model)
96-
data = mjx.make_data(model)
97-
ctrlrange = jp.where(
98-
model.actuator_ctrllimited[:, None],
99-
model.actuator_ctrlrange,
100-
jp.array([-10.0, 10.0]),
101-
)
102-
103-
def step(x, _):
104-
data, rng = x
105-
rng, key = jax.random.split(rng)
106-
ctrl = jax.random.uniform(
107-
key,
108-
shape=(model.nu,),
109-
minval=ctrlrange[:, 0],
110-
maxval=ctrlrange[:, 1],
111-
)
112-
data = mjx.step(model, data.replace(ctrl=ctrl))
113-
return (data, rng), ()
114-
115-
(data, _), _ = jax.lax.scan(
116-
step,
117-
(data, jax.random.PRNGKey(0)),
118-
(),
119-
length=min(_MAX_SIM_TIME // model.opt.timestep, 100),
120-
)
121-
122-
self.assertFalse(jp.isnan(data.qpos).any())
123-
124-
12584
if __name__ == '__main__':
12685
absltest.main()

0 commit comments

Comments
 (0)