@@ -126,38 +126,6 @@ def update_assets(
126126 update_assets (assets , f , glob , recursive )
127127
128128
129- def init (
130- model : mjx .Model ,
131- qpos : Optional [jax .Array ] = None ,
132- qvel : Optional [jax .Array ] = None ,
133- ctrl : Optional [jax .Array ] = None ,
134- act : Optional [jax .Array ] = None ,
135- mocap_pos : Optional [jax .Array ] = None ,
136- mocap_quat : Optional [jax .Array ] = None ,
137- ) -> mjx .Data :
138- """Initialize MJX Data."""
139- warnings .warn (
140- "`init` will be removed in the next major release." ,
141- DeprecationWarning ,
142- stacklevel = 2 ,
143- )
144- data = mjx .make_data (model )
145- if qpos is not None :
146- data = data .replace (qpos = qpos )
147- if qvel is not None :
148- data = data .replace (qvel = qvel )
149- if ctrl is not None :
150- data = data .replace (ctrl = ctrl )
151- if act is not None :
152- data = data .replace (act = act )
153- if mocap_pos is not None :
154- data = data .replace (mocap_pos = mocap_pos .reshape (model .nmocap , - 1 ))
155- if mocap_quat is not None :
156- data = data .replace (mocap_quat = mocap_quat .reshape (model .nmocap , - 1 ))
157- data = mjx .forward (model , data )
158- return data
159-
160-
161129def make_data (
162130 model : mujoco .MjModel ,
163131 qpos : Optional [jax .Array ] = None ,
@@ -169,9 +137,12 @@ def make_data(
169137 impl : Optional [str ] = None ,
170138 nconmax : Optional [int ] = None ,
171139 njmax : Optional [int ] = None ,
140+ device : Optional [jax .Device ] = None ,
172141) -> mjx .Data :
173142 """Initialize MJX Data."""
174- data = mjx .make_data (model , impl = impl , nconmax = nconmax , njmax = njmax )
143+ data = mjx .make_data (
144+ model , impl = impl , nconmax = nconmax , njmax = njmax , device = device
145+ )
175146 if qpos is not None :
176147 data = data .replace (qpos = qpos )
177148 if qvel is not None :
0 commit comments