Skip to content

Commit 25f570d

Browse files
btabacopybara-github
authored andcommitted
Fix #174.
PiperOrigin-RevId: 795717368 Change-Id: Ibc2a4e619bcafde355f8e3379471325e4eb37346
1 parent 4981d05 commit 25f570d

File tree

3 files changed

+8
-37
lines changed

3 files changed

+8
-37
lines changed

CHANGELOG.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ All notable changes to this project will be documented in this file.
66

77
- Pass through the [MuJoCo Warp](https://github.com/google-deepmind/mujoco_warp)
88
(MjWarp) implementation to MJX, so that MuJoCo Playground environments can
9-
train with MuJoCo Warp! DM Control Suite and Locomotion environments now
10-
support MjWarp. You can pass through the implementation via the config
9+
train with MuJoCo Warp! You can pass through the implementation via the config
1110
override
1211
`registry.load('CartpoleBalance', config_overrides={'impl': 'warp'})`.
1312
- Update environments to utilize contact sensors and remove `collision.py`.
13+
- Remove `mjx_env.init` in favor of `mjx_env.make_data` since `make_data`
14+
now requires an `MjModel` argument rather than an `mjx.Model` argument.
15+
- Add device to `mjx_env.make_data`, fixes #174.
1416

1517
## [0.0.5] - 2025-06-23
1618

mujoco_playground/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from mujoco_playground._src import wrapper
2121
from mujoco_playground._src import wrapper_torch
2222
# pylint: disable=g-importing-member
23-
from mujoco_playground._src.mjx_env import init
2423
from mujoco_playground._src.mjx_env import MjxEnv
2524
from mujoco_playground._src.mjx_env import render_array
2625
from mujoco_playground._src.mjx_env import State
@@ -30,7 +29,6 @@
3029

3130
__all__ = [
3231
"dm_control_suite",
33-
"init",
3432
"locomotion",
3533
"manipulation",
3634
"MjxEnv",

mujoco_playground/_src/mjx_env.py

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -126,38 +126,6 @@ def update_assets(
126126
update_assets(assets, f, glob, recursive)
127127

128128

129-
def init(
130-
model: mjx.Model,
131-
qpos: Optional[jax.Array] = None,
132-
qvel: Optional[jax.Array] = None,
133-
ctrl: Optional[jax.Array] = None,
134-
act: Optional[jax.Array] = None,
135-
mocap_pos: Optional[jax.Array] = None,
136-
mocap_quat: Optional[jax.Array] = None,
137-
) -> mjx.Data:
138-
"""Initialize MJX Data."""
139-
warnings.warn(
140-
"`init` will be removed in the next major release.",
141-
DeprecationWarning,
142-
stacklevel=2,
143-
)
144-
data = mjx.make_data(model)
145-
if qpos is not None:
146-
data = data.replace(qpos=qpos)
147-
if qvel is not None:
148-
data = data.replace(qvel=qvel)
149-
if ctrl is not None:
150-
data = data.replace(ctrl=ctrl)
151-
if act is not None:
152-
data = data.replace(act=act)
153-
if mocap_pos is not None:
154-
data = data.replace(mocap_pos=mocap_pos.reshape(model.nmocap, -1))
155-
if mocap_quat is not None:
156-
data = data.replace(mocap_quat=mocap_quat.reshape(model.nmocap, -1))
157-
data = mjx.forward(model, data)
158-
return data
159-
160-
161129
def make_data(
162130
model: mujoco.MjModel,
163131
qpos: Optional[jax.Array] = None,
@@ -169,9 +137,12 @@ def make_data(
169137
impl: Optional[str] = None,
170138
nconmax: Optional[int] = None,
171139
njmax: Optional[int] = None,
140+
device: Optional[jax.Device] = None,
172141
) -> mjx.Data:
173142
"""Initialize MJX Data."""
174-
data = mjx.make_data(model, impl=impl, nconmax=nconmax, njmax=njmax)
143+
data = mjx.make_data(
144+
model, impl=impl, nconmax=nconmax, njmax=njmax, device=device
145+
)
175146
if qpos is not None:
176147
data = data.replace(qpos=qpos)
177148
if qvel is not None:

0 commit comments

Comments
 (0)