diff --git a/metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py b/metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py index 211770656..8fd91f3e5 100644 --- a/metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py +++ b/metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py @@ -172,13 +172,8 @@ def __init__( np.array([+1, +1, +1, +1]), dtype=np.float64, ) - - # Technically these observation lengths are different between v1 and v2, - # but we handle that elsewhere and just stick with v2 numbers here self._obs_obj_max_len = 14 - self._set_task_called = False - self.hand_init_pos = None # OVERRIDE ME self._target_pos = None # OVERRIDE ME self._random_reset_space = None # OVERRIDE ME @@ -189,6 +184,8 @@ def __init__( # doesn't seem to matter (it will only effect frame-stacking for the # very first observation) + self.init_qpos = np.copy(self.data.qpos) + self.init_qvel = np.copy(self.data.qvel) self._prev_obs = self._get_curr_obs_combined_no_goal() EzPickle.__init__( @@ -538,10 +535,15 @@ def evaluate_state(self, obs, action): # V1 environments don't have to implement it raise NotImplementedError + def reset_model(self): + qpos = self.init_qpos + qvel = self.init_qvel + self.set_state(qpos, qvel) + def reset(self, seed=None, options=None): self.curr_path_length = 0 + self.reset_model() obs, info = super().reset() - mujoco.mj_forward(self.model, self.data) self._prev_obs = obs[:18].copy() obs[18:36] = self._prev_obs obs = np.float64(obs) diff --git a/metaworld/envs/mujoco/sawyer_xyz/v2/sawyer_hammer_v2.py b/metaworld/envs/mujoco/sawyer_xyz/v2/sawyer_hammer_v2.py index 28eafcd4a..281fdd131 100644 --- a/metaworld/envs/mujoco/sawyer_xyz/v2/sawyer_hammer_v2.py +++ b/metaworld/envs/mujoco/sawyer_xyz/v2/sawyer_hammer_v2.py @@ -104,7 +104,6 @@ def reset_model(self): self.nail_init_pos = self._get_site_pos("nailHead") self.obj_init_pos = self.hammer_init_pos.copy() self._set_hammer_xyz(self.hammer_init_pos) - self.model.site("goal").pos = self._target_pos return self._get_obs() @staticmethod