@@ -26,6 +26,10 @@ class MultiThreadedEnvWrapper(_EnvWrapper):
2626
2727 Paper: https://arxiv.org/abs/2206.10558
2828
29+ EnvPool environments auto-reset internally when episodes end. This wrapper
30+ handles that behavior by caching the auto-reset observations and returning
31+ them appropriately in step_and_maybe_reset.
32+
2933 Args:
3034 env (envpool.python.envpool.EnvPoolMixin): the envpool to wrap.
3135 categorical_action_encoding (bool, optional): if ``True``, categorical
@@ -138,6 +142,39 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
138142 tensordict_out = self ._transform_step_output (step_output )
139143 return tensordict_out
140144
145+ def step_and_maybe_reset (
146+ self , tensordict : TensorDictBase
147+ ) -> tuple [TensorDictBase , TensorDictBase ]:
148+ """Runs a step and handles envpool's internal auto-reset.
149+
150+ EnvPool auto-resets internally when episodes end. When done=True:
151+ - The observation returned is the final observation of the ending episode
152+ - The NEXT call to step() returns the first observation of a new episode
153+
154+ This method handles this by skipping explicit reset() calls for done
155+ environments. EnvPool maintains its own internal state, so the next
156+ step() will automatically return the reset observation.
157+
158+ Note: The observation in tensordict_ for done envs will be the final
159+ observation (not the reset observation). This is acceptable because
160+ envpool ignores the input observation and uses its internal state.
161+ """
162+ # Perform the step
163+ tensordict = self .step (tensordict )
164+
165+ # Move data from "next" to root for the next iteration
166+ tensordict_ = self ._step_mdp (tensordict )
167+
168+ # EnvPool auto-resets internally, so we skip calling reset().
169+ # However, we need to clear the done flags in tensordict_ since envpool
170+ # has already reset those environments. The next step() will return
171+ # the reset observations automatically.
172+ for key in self .done_keys :
173+ if key in tensordict_ .keys (True ):
174+ tensordict_ .set (key , torch .zeros_like (tensordict_ .get (key )))
175+
176+ return tensordict , tensordict_
177+
141178 def _get_action_spec (self ) -> TensorSpec :
142179 # local import to avoid importing gym in the script
143180 from torchrl .envs .libs .gym import _gym_to_torchrl_spec_transform
@@ -378,6 +415,9 @@ def _build_env(
378415 import envpool
379416
380417 create_env_kwargs = create_env_kwargs or {}
418+ # EnvPool requires max_num_players to be set for single-player environments
419+ if "max_num_players" not in create_env_kwargs :
420+ create_env_kwargs ["max_num_players" ] = 1
381421 env = envpool .make (
382422 task_id = env_name ,
383423 env_type = "gym" ,
0 commit comments