Skip to content

Commit ff66b08

Browse files
authored
Fix MuJoCo add_markers for mujoco>=3.2 (#1329)
1 parent d4dcc21 commit ff66b08

File tree

3 files changed

+60
-4
lines changed

3 files changed

+60
-4
lines changed

Diff for: gymnasium/envs/mujoco/mujoco_rendering.py

+36-3
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,16 @@
66
import imageio
77
import mujoco
88
import numpy as np
9+
from packaging.version import Version
910

1011
from gymnasium.logger import warn
1112

1213

14+
# The marker API changed in MuJoCo 3.2.0, so we check the mujoco version and set a flag that
15+
# determines which function we use when adding markers to the scene.
16+
_MUJOCO_MARKER_LEGACY_MODE = Version(mujoco.__version__) < Version("3.2.0")
17+
18+
1319
def _import_egl(width, height):
1420
from mujoco.egl import GLContext
1521

@@ -88,8 +94,37 @@ def add_marker(self, **marker_params):
8894

8995
def _add_marker_to_scene(self, marker: dict):
9096
if self.scn.ngeom >= self.scn.maxgeom:
91-
raise RuntimeError("Ran out of geoms. maxgeom: %d" % self.scn.maxgeom)
97+
raise RuntimeError(f"Ran out of geoms. maxgeom: {self.scn.maxgeom}")
98+
99+
if _MUJOCO_MARKER_LEGACY_MODE: # Old API for markers requires special handling
100+
self._legacy_add_marker_to_scene(marker)
101+
else:
102+
geom_type = marker.get("type", mujoco.mjtGeom.mjGEOM_SPHERE)
103+
size = marker.get("size", np.array([0.01, 0.01, 0.01]))
104+
pos = marker.get("pos", np.array([0.0, 0.0, 0.0]))
105+
mat = marker.get("mat", np.eye(3).flatten())
106+
rgba = marker.get("rgba", np.array([1.0, 1.0, 1.0, 1.0]))
107+
mujoco.mjv_initGeom(
108+
self.scn.geoms[self.scn.ngeom],
109+
geom_type,
110+
size=size,
111+
pos=pos,
112+
mat=mat,
113+
rgba=rgba,
114+
)
115+
116+
self.scn.ngeom += 1
92117

118+
def _legacy_add_marker_to_scene(self, marker: dict):
119+
"""Add a marker to the scene compatible with older versions of MuJoCo.
120+
121+
MuJoCo 3.2 introduced breaking changes to the visual geometries API. To maintain
122+
compatibility with older versions, we use the legacy API when an older version of MuJoCo is
123+
detected.
124+
125+
Args:
126+
marker: A dictionary containing the marker parameters.
127+
"""
93128
g = self.scn.geoms[self.scn.ngeom]
94129
# default values.
95130
g.dataid = -1
@@ -130,8 +165,6 @@ def _add_marker_to_scene(self, marker: dict):
130165
else:
131166
raise ValueError("mjtGeom doesn't have field %s" % key)
132167

133-
self.scn.ngeom += 1
134-
135168
def close(self):
136169
"""Override close in your rendering subclass to perform any necessary cleanup
137170
after env.close() is called.

Diff for: pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ classic-control = ["pygame >=2.1.3"]
4141
classic_control = ["pygame >=2.1.3"] # kept for backward compatibility
4242
mujoco-py = ["mujoco-py >=2.1,<2.2", "cython<3"]
4343
mujoco_py = ["mujoco-py >=2.1,<2.2", "cython<3"] # kept for backward compatibility
44-
mujoco = ["mujoco >=2.1.5", "imageio >=2.14.1"]
44+
mujoco = ["mujoco >=2.1.5", "imageio >=2.14.1", "packaging >=23.0"]
4545
toy-text = ["pygame >=2.1.3"]
4646
toy_text = ["pygame >=2.1.3"] # kept for backward compatibility
4747
jax = ["jax >=0.4.16", "jaxlib >=0.4.16", "flax >=0.5.0"]
@@ -64,6 +64,7 @@ all = [
6464
# mujoco
6565
"mujoco >=2.1.5",
6666
"imageio >=2.14.1",
67+
"packaging >=23.0",
6768
# toy-text
6869
"pygame >=2.1.3",
6970
# jax

Diff for: tests/envs/mujoco/test_mujoco_rendering.py

+22
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,28 @@ def test_max_geom_attribute(
117117
viewer.close()
118118

119119

120+
@pytest.mark.parametrize(
121+
"render_mode", ["human", "rgb_array", "depth_array", "rgbd_tuple"]
122+
)
123+
def test_add_markers(model: mujoco.MjModel, data: mujoco.MjData, render_mode: str):
124+
"""Test that the add_markers function works correctly."""
125+
# initialize renderer
126+
renderer = ExposedViewerRenderer(
127+
model, data, width=DEFAULT_SIZE, height=DEFAULT_SIZE, max_geom=10
128+
)
129+
# initialize viewer via render
130+
viewer = renderer.get_viewer(render_mode)
131+
viewer.add_marker(
132+
pos=np.array([0, 0, 0]),
133+
size=np.array([1, 1, 1]),
134+
rgba=np.array([1, 0, 0, 1]),
135+
)
136+
args = tuple() if render_mode == "human" else (render_mode,)
137+
viewer.render(*args) # We need to render to trigger the marker addition in MuJoCo
138+
# close viewer after usage
139+
viewer.close()
140+
141+
120142
@pytest.mark.parametrize(
121143
"render_mode", ["human", "rgb_array", "depth_array", "rgbd_tuple"]
122144
)

0 commit comments

Comments
 (0)