Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
e501e80
Added num_envs parameter in GymEnv to call multiple environments in o…
ParamThakkar123 Jan 19, 2026
835634e
Merge branch 'main' of https://github.com/pytorch/rl into add/num-env…
ParamThakkar123 Jan 19, 2026
305fe34
Merge branch 'main' of https://github.com/pytorch/rl into add/num-env…
ParamThakkar123 Jan 19, 2026
1a24ab5
Added num_envs parameter in GymEnv to call multiple environments in o…
ParamThakkar123 Jan 19, 2026
4d7f816
Merge branch 'main' of https://github.com/pytorch/rl into add/num-env…
ParamThakkar123 Jan 19, 2026
3c75f82
Merge branch 'add/num-envs-gym' of https://github.com/ParamThakkar123…
ParamThakkar123 Jan 19, 2026
c688def
Merge branch 'main' of https://github.com/pytorch/rl into add/num-env…
ParamThakkar123 Jan 19, 2026
3f13f50
Merge branch 'main' of https://github.com/pytorch/rl into add/num-env…
ParamThakkar123 Jan 20, 2026
4b8419c
changed num_envs to num_workers
ParamThakkar123 Jan 20, 2026
2df6f03
changed num_envs to num_workers
ParamThakkar123 Jan 20, 2026
00b807b
Updates
ParamThakkar123 Jan 20, 2026
e9f76ab
Merge branch 'main' of https://github.com/pytorch/rl into add/num-env…
ParamThakkar123 Jan 20, 2026
f6bcafd
Fixed tests
ParamThakkar123 Jan 20, 2026
5d5724c
Merge branch 'main' of https://github.com/pytorch/rl into add/num-env…
ParamThakkar123 Jan 20, 2026
774f573
Test fixes
ParamThakkar123 Jan 20, 2026
387963f
Merge branch 'main' of https://github.com/pytorch/rl into add/num-env…
ParamThakkar123 Jan 20, 2026
7294c49
[BugFix] Fix test_gym_kwargs_preserved_with_seed
vmoens Jan 20, 2026
d4ddb43
[BugFix] Use PENDULUM_VERSIONED() for gym version compatibility
vmoens Jan 21, 2026
ab9119e
[Style] Fix pre-commit issues in gym.py
vmoens Jan 21, 2026
929b922
Merge remote-tracking branch 'origin/main' into add/num-envs-gym
vmoens Jan 21, 2026
bced29d
cloudpickle refactor
vmoens Jan 21, 2026
2d34456
cloudpickle global
vmoens Jan 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 80 additions & 1 deletion test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
RenameTransform,
StepCounter,
)
from torchrl.envs.batched_envs import SerialEnv
from torchrl.envs.batched_envs import ParallelEnv, SerialEnv
from torchrl.envs.libs.brax import _has_brax, BraxEnv, BraxWrapper
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv, DMControlWrapper
from torchrl.envs.libs.envpool import _has_envpool, MultiThreadedEnvWrapper
Expand Down Expand Up @@ -1878,6 +1878,85 @@ def __init__(self):
result is False
), f"Expected False for Dict environment without pixels, got {result}"

def test_num_workers_returns_parallel_env(self):
    """Check that ``num_workers`` and ``num_envs`` select different backends.

    ``GymEnv(..., num_workers=3)`` (TorchRL-managed parallelism) must return a
    lazy ``ParallelEnv`` with three workers, whereas ``GymEnv(..., num_envs=3)``
    (gym-native vectorization) must not be converted into a ``ParallelEnv``.
    """
    # TorchRL-managed parallelism: should return ParallelEnv
    env = GymEnv("CartPole-v1", num_workers=3)
    try:
        assert isinstance(env, ParallelEnv)
        # accept either attribute name used by ParallelEnv implementations
        nworkers = getattr(env, "num_workers", None)
        if nworkers is None:
            nworkers = getattr(env, "num_envs", None)
        assert nworkers == 3
        # start workers on first use
        env.reset()
        assert env.batch_size == torch.Size([3])
    finally:
        # always release the worker processes, even if an assertion failed
        env.close()

    # Gym-native vectorization should NOT be converted implicitly by TorchRL
    env_gymvec = GymEnv("CartPole-v1", num_envs=3)
    try:
        assert not isinstance(env_gymvec, ParallelEnv)
    finally:
        env_gymvec.close()

def test_num_workers_kwargs_modifiable(self):
    """Check that ``configure_parallel`` can be called before the workers start.

    A ``GymEnv`` built with ``num_workers=3`` is expected to be a
    ``ParallelEnv``; it is reconfigured through ``configure_parallel`` and
    then reset, which must yield a ``TensorDict``.
    """
    parallel_env = GymEnv("CartPole-v1", num_workers=3)
    try:
        # the factory is expected to hand back a (lazy) ParallelEnv
        assert isinstance(parallel_env, ParallelEnv)

        # reconfigure before any worker has been started
        parallel_env.configure_parallel(use_buffers=True, num_threads=1)

        # the first reset starts the environment with the new configuration
        reset_td = parallel_env.reset()
        assert isinstance(reset_td, TensorDict)
    finally:
        parallel_env.close()

def test_set_seed_and_reset_works(self):
    """Smoke test that ``set_seed`` followed by ``reset`` works.

    Exercised twice: on a plain single ``GymEnv`` and on the TorchRL-managed
    parallel variant obtained with ``num_workers=2``. In both cases
    ``set_seed`` must return a non-None value and ``reset`` must return a
    ``TensorDict``.
    """
    env = GymEnv("CartPole-v1")
    try:
        final_seed = env.set_seed(0)
        assert final_seed is not None
        td = env.reset()
        assert isinstance(td, TensorDict)
    finally:
        # close even when an assertion fails, same as the parallel case below
        env.close()

    # Also verify behavior for TorchRL-managed parallel envs
    penv = GymEnv("CartPole-v1", num_workers=2)
    try:
        final_seed = penv.set_seed(0)
        assert final_seed is not None
        td = penv.reset()
        assert isinstance(td, TensorDict)
    finally:
        penv.close()

def test_gym_kwargs_preserved_with_seed(self):
    """Check that constructor kwargs such as ``frame_skip`` are preserved.

    Regression test for a bug where ``kwargs`` were overwritten when ``_seed``
    was not None.

    NOTE(review): this test does not set a seed explicitly; presumably the
    ``_seed`` code path is reached during construction/reset -- confirm.
    """
    # Pendulum is used rather than CartPole: CartPole can terminate early
    # (pole falling), especially with frame_skip=4, which would shorten the
    # rollout checked below.
    pendulum_env = GymEnv(PENDULUM_VERSIONED(), frame_skip=4, from_pixels=False)
    try:
        td = pendulum_env.reset()
        rollout = pendulum_env.rollout(max_steps=5)
        # all five requested steps must have been collected
        assert rollout.shape[0] == 5
        assert "observation" in td.keys()
    finally:
        pendulum_env.close()

def test_is_from_pixels_wrapper_env(self):
"""Test that _is_from_pixels correctly identifies wrapped environments."""
from torchrl.envs.libs.gym import _is_from_pixels
Expand Down
9 changes: 3 additions & 6 deletions torchrl/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections.abc import Callable
from typing import Any, Union

import cloudpickle
import numpy as np
import torch
from torch import Tensor
Expand Down Expand Up @@ -246,15 +247,11 @@ def __init__(self, fn: Callable, **kwargs):
functools.update_wrapper(self, getattr(fn, "forward", fn))

def __getstate__(self) -> bytes:
    """Serialize the wrapped callable and its bound kwargs.

    Returns:
        bytes: the cloudpickle payload of the ``(self.fn, self.kwargs)`` tuple.
    """
    # cloudpickle is imported at module level, so no function-local import
    # is needed here.
    return cloudpickle.dumps((self.fn, self.kwargs))

def __setstate__(self, ob: bytes):
    """Restore the wrapper from a payload produced by ``__getstate__``.

    Args:
        ob (bytes): cloudpickle-serialized ``(fn, kwargs)`` tuple.
    """
    # Deserialize exactly once, with the same library that produced the
    # payload.
    # NOTE: unpickling executes arbitrary code; only load trusted payloads.
    self.fn, self.kwargs = cloudpickle.loads(ob)
    # Same metadata source as in __init__: prefer ``fn.forward`` when the
    # wrapped object exposes it, otherwise use the callable itself.
    functools.update_wrapper(self, getattr(self.fn, "forward", self.fn))

def __call__(self, *args, **kwargs) -> Any:
kwargs.update(self.kwargs)
Expand Down
59 changes: 50 additions & 9 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import warnings
from contextlib import nullcontext
from copy import copy
from functools import partial
from types import ModuleType
from warnings import warn

Expand All @@ -36,7 +37,6 @@
Unbounded,
)
from torchrl.data.utils import numpy_to_torch_dtype_dict, torch_to_numpy_dtype_dict
from torchrl.envs.batched_envs import CloudpickleWrapper
from torchrl.envs.common import _EnvPostInit
from torchrl.envs.gym_like import default_info_dict_reader, GymLikeEnv
from torchrl.envs.utils import _classproperty
Expand Down Expand Up @@ -818,6 +818,17 @@ class PixelObservationWrapper:
class _GymAsyncMeta(_EnvPostInit):
def __call__(cls, *args, **kwargs):
missing_obs_value = kwargs.pop("missing_obs_value", None)
num_workers = kwargs.pop("num_workers", 1)

if cls.__name__ == "GymEnv" and num_workers > 1:
from torchrl.envs import EnvCreator, ParallelEnv

env_name = args[0] if args else kwargs.get("env_name")
env_kwargs = kwargs.copy()
env_kwargs.pop("env_name", None)
make_env = partial(cls, env_name, **env_kwargs)
return ParallelEnv(num_workers, EnvCreator(make_env))

instance: GymWrapper = super().__call__(*args, **kwargs)

# before gym 0.22, there was no final_observation
Expand Down Expand Up @@ -1719,6 +1730,15 @@ class GymEnv(GymWrapper):
num_envs (int, optional): the number of envs to run in parallel. Defaults to
``None`` (a single env is to be run). :class:`~gym.vector.AsyncVectorEnv`
will be used by default.
num_workers (int, optional): number of top-level worker subprocesses used to create/run
multiple :class:`GymEnv` instances in parallel (handled by the metaclass
:class:`_GymAsyncMeta`). When ``num_workers > 1``, a lazy
:class:`~torchrl.envs.ParallelEnv` is returned whose factory preserves the original
`GymEnv` kwargs. You can modify the ParallelEnv construction/configuration before
it starts by calling :meth:`~torchrl.envs.batched_envs.BatchedEnvBase.configure_parallel`
on the returned object (for example: ``env.configure_parallel(use_buffers=True, num_threads=2)``).
When both ``num_workers`` and ``num_envs`` are greater than 1, the total number of
environments executed in parallel is ``num_workers * num_envs``. Defaults to ``1``.
disable_env_checker (bool, optional): for gym > 0.24 only. If ``True`` (default
for these versions), the environment checker won't be run.
from_pixels (bool, optional): if ``True``, an attempt to return the pixel
Expand Down Expand Up @@ -1784,6 +1804,33 @@ class GymEnv(GymWrapper):
>>> print(env.available_envs)
['ALE/Adventure-ram-v5', 'ALE/Adventure-v5', 'ALE/AirRaid-ram-v5', 'ALE/AirRaid-v5', 'ALE/Alien-ram-v5', 'ALE/Alien-v5',

To run multiple environments in parallel, pass ``num_workers``:

>>> from torchrl.envs import GymEnv
>>> env = GymEnv("Pendulum-v1", num_workers=4)
>>> td_reset = env.reset()
>>> td = env.rand_step(td_reset)
>>> print(td)
TensorDict(
fields={
action: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
observation: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
reward: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([4]),
device=None,
is_shared=False),
observation: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
truncated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([4]),
device=None,
is_shared=False)

.. note::
If both `OpenAI/gym` and `gymnasium` are present in the virtual environment,
one can swap backend using :func:`~torchrl.envs.libs.gym.set_gym_backend`:
Expand Down Expand Up @@ -1922,14 +1969,8 @@ def _build_env(
raise err
env = super()._build_env(env, pixels_only=pixels_only, from_pixels=from_pixels)
if num_envs > 0:
try:
env = self._async_env([CloudpickleWrapper(lambda: env)] * num_envs)
except RuntimeError:
# It would fail if the environment is not pickable. In that case,
# delegating environment instantiation to each subprocess as a fallback.
env = self._async_env(
[lambda: self.lib.make(env_name, **kwargs)] * num_envs
)
make_fn = partial(self.lib.make, env_name, **kwargs)
env = self._async_env([make_fn] * num_envs)
self.batch_size = torch.Size([num_envs, *self.batch_size])
return env

Expand Down
Loading