Skip to content

Commit 3f172d1

Browse files
fixing slight movement of objects and CI
1 parent 6f46079 commit 3f172d1

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,8 @@ def __init__(
172172
np.array([+1, +1, +1, +1]),
173173
dtype=np.float64,
174174
)
175-
176-
# Technically these observation lengths are different between v1 and v2,
177-
# but we handle that elsewhere and just stick with v2 numbers here
178175
self._obs_obj_max_len = 14
179-
180176
self._set_task_called = False
181-
182177
self.hand_init_pos = None # OVERRIDE ME
183178
self._target_pos = None # OVERRIDE ME
184179
self._random_reset_space = None # OVERRIDE ME
@@ -189,6 +184,8 @@ def __init__(
189184
# doesn't seem to matter (it will only effect frame-stacking for the
190185
# very first observation)
191186

187+
self.init_qpos = np.copy(self.data.qpos)
188+
self.init_qvel = np.copy(self.data.qvel)
192189
self._prev_obs = self._get_curr_obs_combined_no_goal()
193190

194191
EzPickle.__init__(
@@ -538,10 +535,15 @@ def evaluate_state(self, obs, action):
538535
# V1 environments don't have to implement it
539536
raise NotImplementedError
540537

538+
def reset_model(self):
539+
qpos = self.init_qpos
540+
qvel = self.init_qvel
541+
self.set_state(qpos, qvel)
542+
541543
def reset(self, seed=None, options=None):
542544
self.curr_path_length = 0
545+
self.reset_model()
543546
obs, info = super().reset()
544-
mujoco.mj_forward(self.model, self.data)
545547
self._prev_obs = obs[:18].copy()
546548
obs[18:36] = self._prev_obs
547549
obs = np.float64(obs)

metaworld/envs/mujoco/sawyer_xyz/v2/sawyer_hammer_v2.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ def reset_model(self):
104104
self.nail_init_pos = self._get_site_pos("nailHead")
105105
self.obj_init_pos = self.hammer_init_pos.copy()
106106
self._set_hammer_xyz(self.hammer_init_pos)
107-
self.model.site("goal").pos = self._target_pos
108107
return self._get_obs()
109108

110109
@staticmethod

0 commit comments

Comments
 (0)