Skip to content

Commit

Permalink
fixing slight movement of objects and CI
Browse files Browse the repository at this point in the history
  • Loading branch information
reginald-mclean committed Apr 22, 2024
1 parent 6f46079 commit 3f172d1
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
14 changes: 8 additions & 6 deletions metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,8 @@ def __init__(
np.array([+1, +1, +1, +1]),
dtype=np.float64,
)

# Technically these observation lengths are different between v1 and v2,
# but we handle that elsewhere and just stick with v2 numbers here
self._obs_obj_max_len = 14

self._set_task_called = False

self.hand_init_pos = None # OVERRIDE ME
self._target_pos = None # OVERRIDE ME
self._random_reset_space = None # OVERRIDE ME
Expand All @@ -189,6 +184,8 @@ def __init__(
# doesn't seem to matter (it will only effect frame-stacking for the
# very first observation)

self.init_qpos = np.copy(self.data.qpos)
self.init_qvel = np.copy(self.data.qvel)
self._prev_obs = self._get_curr_obs_combined_no_goal()

EzPickle.__init__(
Expand Down Expand Up @@ -538,10 +535,15 @@ def evaluate_state(self, obs, action):
# V1 environments don't have to implement it
raise NotImplementedError

def reset_model(self):
qpos = self.init_qpos
qvel = self.init_qvel
self.set_state(qpos, qvel)

def reset(self, seed=None, options=None):
self.curr_path_length = 0
self.reset_model()
obs, info = super().reset()
mujoco.mj_forward(self.model, self.data)
self._prev_obs = obs[:18].copy()
obs[18:36] = self._prev_obs
obs = np.float64(obs)
Expand Down
1 change: 0 additions & 1 deletion metaworld/envs/mujoco/sawyer_xyz/v2/sawyer_hammer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def reset_model(self):
self.nail_init_pos = self._get_site_pos("nailHead")
self.obj_init_pos = self.hammer_init_pos.copy()
self._set_hammer_xyz(self.hammer_init_pos)
self.model.site("goal").pos = self._target_pos
return self._get_obs()

@staticmethod
Expand Down

0 comments on commit 3f172d1

Please sign in to comment.