@@ -172,13 +172,8 @@ def __init__(
172
172
np .array ([+ 1 , + 1 , + 1 , + 1 ]),
173
173
dtype = np .float64 ,
174
174
)
175
-
176
- # Technically these observation lengths are different between v1 and v2,
177
- # but we handle that elsewhere and just stick with v2 numbers here
178
175
self ._obs_obj_max_len = 14
179
-
180
176
self ._set_task_called = False
181
-
182
177
self .hand_init_pos = None # OVERRIDE ME
183
178
self ._target_pos = None # OVERRIDE ME
184
179
self ._random_reset_space = None # OVERRIDE ME
@@ -189,6 +184,8 @@ def __init__(
189
184
# doesn't seem to matter (it will only affect frame-stacking for the
190
185
# very first observation)
191
186
187
+ self .init_qpos = np .copy (self .data .qpos )
188
+ self .init_qvel = np .copy (self .data .qvel )
192
189
self ._prev_obs = self ._get_curr_obs_combined_no_goal ()
193
190
194
191
EzPickle .__init__ (
@@ -538,10 +535,15 @@ def evaluate_state(self, obs, action):
538
535
# V1 environments don't have to implement it
539
536
raise NotImplementedError
540
537
538
def reset_model(self):
    """Restore the simulator to the stored initial joint state.

    Reads ``self.init_qpos`` and ``self.init_qvel`` and passes them,
    in that order, to ``self.set_state``. Returns nothing.
    """
    # Read both stored arrays before dispatching to set_state.
    initial_state = (self.init_qpos, self.init_qvel)
    self.set_state(*initial_state)
542
+
541
543
def reset (self , seed = None , options = None ):
542
544
self .curr_path_length = 0
545
+ self .reset_model ()
543
546
obs , info = super ().reset ()
544
- mujoco .mj_forward (self .model , self .data )
545
547
self ._prev_obs = obs [:18 ].copy ()
546
548
obs [18 :36 ] = self ._prev_obs
547
549
obs = np .float64 (obs )
0 commit comments