@@ -82,6 +82,8 @@ class MujocoEnv(metaclass=EnvMeta):
8282 ignore_done (bool): True if never terminating the environment (ignore @horizon).
8383 hard_reset (bool): If True, re-loads model, sim, and render object upon a reset call, else,
8484 only calls sim.reset and resets all robosuite-internal variables
85+ load_model_on_init (bool): If True, load and initialize the model and renderer in __init__ constructor,
86+ else, initialize these components in the first call to reset()
8587 renderer (str): string for the renderer to use
8688 renderer_config (dict): dictionary for the renderer configurations
8789 seed (int): environment seed. Default is None, where environment is unseeded, ie. random
@@ -102,6 +104,7 @@ def __init__(
102104 horizon = 1000 ,
103105 ignore_done = False ,
104106 hard_reset = True ,
107+ load_model_on_init = True ,
105108 renderer = "mjviewer" ,
106109 renderer_config = None ,
107110 seed = None ,
@@ -143,31 +146,44 @@ def __init__(
143146
144147 self ._ep_meta = {}
145148
146- # Load the model
147- self ._load_model ()
149+ self .load_model_on_init = load_model_on_init
148150
149- # Initialize the simulation
150- self ._initialize_sim ()
151+ # variable to keep track of whether the env has been fully initialized
152+ self ._env_is_initialized = False
151153
152- # initializes the rendering
153- self .initialize_renderer ()
154+ if self .load_model_on_init :
155+ # Load the model
156+ self ._load_model ()
154157
155- # the variables will be set later.
156- # need to set to None, in case these variables are referenced before being set
157- self .viewer = None
158- self .viewer_get_obs = None
158+ # Initialize the simulation
159+ self ._initialize_sim ()
159160
160- # Run all further internal (re-)initialization required
161- self ._reset_internal ()
161+ # initializes the rendering
162+ self .initialize_renderer ()
162163
163- # Load observables
164- if hasattr (self .viewer , "_setup_observables" ):
165- self ._observables = self .viewer ._setup_observables ()
166- else :
167- self ._observables = self ._setup_observables ()
164+ # the variables will be set later.
165+ # need to set to None, in case these variables are referenced before being set
166+ self .viewer = None
167+ self .viewer_get_obs = None
168168
169- # check if viewer has get observations method and set a flag for future use.
170- self .viewer_get_obs = hasattr (self .viewer , "_get_observations" )
169+ # Run all further internal (re-)initialization required
170+ self ._reset_internal ()
171+
172+ # Load observables
173+ if hasattr (self .viewer , "_setup_observables" ):
174+ self ._observables = self .viewer ._setup_observables ()
175+ else :
176+ self ._observables = self ._setup_observables ()
177+
178+ # check if viewer has get observations method and set a flag for future use.
179+ self .viewer_get_obs = hasattr (self .viewer , "_get_observations" )
180+ self ._env_is_initialized = True
181+ else :
182+ # the variables will be set later.
183+ # need to set to None, in case these variables are referenced before being set
184+ self .sim = None
185+ self .viewer = None
186+ self .viewer_get_obs = None
171187
172188 def initialize_renderer (self ):
173189 self .renderer = self .renderer .lower ()
@@ -271,7 +287,7 @@ def reset(self):
271287 if self .renderer == "mjviewer" :
272288 self ._destroy_viewer ()
273289
274- if self .hard_reset and not self .deterministic_reset :
290+ if ( self .sim is None ) or ( self . hard_reset and not self .deterministic_reset ) :
275291 if self .renderer == "mujoco" :
276292 self ._destroy_viewer ()
277293 self ._destroy_sim ()
@@ -281,9 +297,33 @@ def reset(self):
281297 else :
282298 self .sim .reset ()
283299
284- # Reset necessary robosuite-centric variables
285- self ._reset_internal ()
286- self .sim .forward ()
300+ if self ._env_is_initialized is True :
301+ # Reset necessary robosuite-centric variables
302+ self ._reset_internal ()
303+ self .sim .forward ()
304+ else :
305+ # initializes the rendering
306+ self .initialize_renderer ()
307+
308+ # the variables will be set later.
309+ # need to set to None, in case these variables are referenced before being set
310+ self .viewer = None
311+ self .viewer_get_obs = None
312+
313+ # Run all further internal (re-)initialization required
314+ self ._reset_internal ()
315+ self .sim .forward ()
316+
317+ # Load observables
318+ if hasattr (self .viewer , "_setup_observables" ):
319+ self ._observables = self .viewer ._setup_observables ()
320+ else :
321+ self ._observables = self ._setup_observables ()
322+
323+ # check if viewer has get observations method and set a flag for future use.
324+ self .viewer_get_obs = hasattr (self .viewer , "_get_observations" )
325+ self ._env_is_initialized = True
326+
287327 # Setup observables, reloading if
288328 self ._obs_cache = {}
289329 self ._reset_observables ()
0 commit comments