Skip to content

Extend testing and make generic the batch space for vector envs #1139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

4 changes: 2 additions & 2 deletions gymnasium/spaces/tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return all(space.is_np_flattenable for space in self.spaces)

def seed(self, seed: int | tuple[int] | None = None) -> tuple[int, ...]:
def seed(self, seed: int | typing.Sequence[int] | None = None) -> tuple[int, ...]:
"""Seed the PRNG of this space and all subspaces.

Depending on the type of seed, the subspaces will be seeded differently

* ``None`` - All the subspaces will use a random initial seed
* ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
* ``List`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.
* ``List`` / ``Tuple`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.

Args:
seed: An optional list of ints or int to seed the (sub-)spaces.
Expand Down
99 changes: 99 additions & 0 deletions gymnasium/spaces/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,3 +573,102 @@ def _flatten_space_oneof(space: OneOf) -> Box:

dtype = np.result_type(*[s.dtype for s in space.spaces if hasattr(s, "dtype")])
return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype)


@singledispatch
def is_space_dtype_shape_equiv(space_1: Space, space_2: Space) -> bool:
    """Returns if two spaces share a common dtype and shape (plus any critical variables).

    This function is primarily used to check for compatibility of different spaces in a vector environment.

    Args:
        space_1: A Gymnasium space
        space_2: A Gymnasium space

    Returns:
        If the two spaces share a common dtype and shape (plus any critical variables).

    Raises:
        NotImplementedError: If both inputs are Gymnasium spaces but no specialised
            implementation is registered for ``space_1``'s type.
        TypeError: If either input is not a Gymnasium space.
    """
    if isinstance(space_1, Space) and isinstance(space_2, Space):
        # Fix: the error previously named a non-existent function
        # (`check_dtype_shape_equivalence`) and ended with a dangling comma.
        raise NotImplementedError(
            "`is_space_dtype_shape_equiv` doesn't support generic Gymnasium spaces, "
            f"got space_1={space_1} and space_2={space_2}."
        )
    else:
        # Fix: raise with an informative message instead of a bare `TypeError()`.
        raise TypeError(
            "Expects both inputs to be Gymnasium spaces, "
            f"got space_1={type(space_1)} and space_2={type(space_2)}."
        )


@is_space_dtype_shape_equiv.register(Box)
@is_space_dtype_shape_equiv.register(Discrete)
@is_space_dtype_shape_equiv.register(MultiDiscrete)
@is_space_dtype_shape_equiv.register(MultiBinary)
def _is_space_fundamental_dtype_shape_equiv(space_1, space_2):
    """Fundamental spaces are equivalent when their exact type, shape and dtype all match."""
    # `singledispatch` only dispatches on the first argument and several types are
    # registered here, so the second space's exact type must be checked explicitly.
    if type(space_1) is not type(space_2):
        return False
    return space_1.shape == space_2.shape and space_1.dtype == space_2.dtype


@is_space_dtype_shape_equiv.register(Text)
def _is_space_text_dtype_shape_equiv(space_1: Text, space_2):
    """Text spaces are equivalent when their max length and character set both match."""
    if not isinstance(space_2, Text):
        return False
    return (
        space_1.max_length == space_2.max_length
        and space_1.character_set == space_2.character_set
    )


@is_space_dtype_shape_equiv.register(Dict)
def _is_space_dict_dtype_shape_equiv(space_1: Dict, space_2):
    """Dict spaces are equivalent when their keys match and each subspace is equivalent."""
    if not isinstance(space_2, Dict):
        return False
    if space_1.keys() != space_2.keys():
        return False
    # Recurse into every subspace; all of them must be equivalent.
    return all(
        is_space_dtype_shape_equiv(space_1[key], space_2[key])
        for key in space_1.keys()
    )


@is_space_dtype_shape_equiv.register(Tuple)
def _is_space_tuple_dtype_shape_equiv(space_1, space_2):
    """Tuple spaces are equivalent when they have the same length and pairwise-equivalent subspaces.

    Fix: the lengths were previously not compared (unlike the ``OneOf`` handler),
    so a shorter ``space_2`` raised ``IndexError`` and a longer ``space_2`` with a
    matching prefix was wrongly reported as equivalent.
    """
    return (
        isinstance(space_2, Tuple)
        and len(space_1) == len(space_2)
        and all(
            is_space_dtype_shape_equiv(space_1[i], space_2[i])
            for i in range(len(space_1))
        )
    )


@is_space_dtype_shape_equiv.register(Graph)
def _is_space_graph_dtype_shape_equiv(space_1: Graph, space_2):
    """Graph spaces are equivalent when node spaces (and optional edge spaces) are equivalent."""
    if not isinstance(space_2, Graph):
        return False
    if not is_space_dtype_shape_equiv(space_1.node_space, space_2.node_space):
        return False
    # Either both graphs lack an edge space, or both have one and they are equivalent.
    if space_1.edge_space is None or space_2.edge_space is None:
        return space_1.edge_space is None and space_2.edge_space is None
    return is_space_dtype_shape_equiv(space_1.edge_space, space_2.edge_space)


@is_space_dtype_shape_equiv.register(OneOf)
def _is_space_oneof_dtype_shape_equiv(space_1: OneOf, space_2):
    """OneOf spaces are equivalent when they have the same number of pairwise-equivalent options."""
    if not isinstance(space_2, OneOf):
        return False
    if len(space_1) != len(space_2):
        return False
    # Every option must be equivalent to the option at the same position.
    for index in range(len(space_1)):
        if not is_space_dtype_shape_equiv(space_1[index], space_2[index]):
            return False
    return True


@is_space_dtype_shape_equiv.register(Sequence)
def _is_space_sequence_dtype_shape_equiv(space_1: Sequence, space_2):
    """Sequence spaces are equivalent when stacking modes match and feature spaces are equivalent."""
    if not isinstance(space_2, Sequence):
        return False
    # `stack` is a boolean flag, so identity comparison matches the original equality check.
    if space_1.stack is not space_2.stack:
        return False
    return is_space_dtype_shape_equiv(space_1.feature_space, space_2.feature_space)
89 changes: 48 additions & 41 deletions gymnasium/vector/async_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
CustomSpaceError,
NoAsyncCallError,
)
from gymnasium.spaces.utils import is_space_dtype_shape_equiv
from gymnasium.vector.utils import (
CloudpickleWrapper,
batch_differing_spaces,
batch_space,
clear_mpi_env_vars,
concatenate,
Expand All @@ -33,10 +35,6 @@
read_from_shared_memory,
write_to_shared_memory,
)
from gymnasium.vector.utils.batched_spaces import (
all_spaces_have_same_shape,
batch_differing_spaces,
)
from gymnasium.vector.vector_env import ArrayType, VectorEnv


Expand Down Expand Up @@ -119,13 +117,14 @@ def __init__(
worker: If set, then use that worker in a subprocess instead of a default one.
Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.
observation_mode: Defines how environment observation spaces should be batched. 'same' defines that there should be ``n`` copies of identical spaces.
'different' defines that there can be multiple observation spaces with the same length but different high/low values batched together. Passing a ``Space`` object
allows the user to set some custom observation space mode not covered by 'same' or 'different.'
'different' defines that there can be multiple observation spaces with different parameters though requires the same shape and dtype,
warning, may raise unexpected errors. Passing a ``Tuple[Space, Space]`` object allows defining a custom ``single_observation_space`` and
``observation_space``, warning, may raise unexpected errors.

Warnings:
worker is an advanced mode option. It provides a high degree of flexibility and a high chance
to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start
from the code for ``_worker`` (or ``_worker_shared_memory``) method, and add changes.
from the code for ``_worker`` (or ``_async_worker``) method, and add changes.

Raises:
RuntimeError: If the observation space of some sub-environment does not match observation_space
Expand All @@ -136,6 +135,7 @@ def __init__(
self.env_fns = env_fns
self.shared_memory = shared_memory
self.copy = copy
self.observation_mode = observation_mode

self.num_envs = len(env_fns)

Expand All @@ -148,29 +148,29 @@ def __init__(
self.render_mode = dummy_env.render_mode

self.single_action_space = dummy_env.action_space
self.action_space = batch_space(self.single_action_space, self.num_envs)

if isinstance(observation_mode, Space):
self.observation_space = observation_mode
if isinstance(observation_mode, tuple) and len(observation_mode) == 2:
assert isinstance(observation_mode[0], Space)
assert isinstance(observation_mode[1], Space)
self.observation_space, self.single_observation_space = observation_mode
else:
if observation_mode == "same":
self.single_observation_space = dummy_env.observation_space
self.observation_space = batch_space(
self.single_observation_space, self.num_envs
)
elif observation_mode == "different":
current_spaces = [env().observation_space for env in self.env_fns]

assert all_spaces_have_same_shape(
current_spaces
), "Low & High values for observation spaces can be different but shapes need to be the same"

self.single_observation_space = batch_differing_spaces(current_spaces)

self.observation_space = self.single_observation_space
# the environment is created and instantly destroyed, which might cause issues for some environments,
# but I don't believe there is anything else we can do; for users with issues, pre-compute the spaces and use the custom option.
env_spaces = [env().observation_space for env in self.env_fns]

self.single_observation_space = env_spaces[0]
self.observation_space = batch_differing_spaces(env_spaces)
else:
raise ValueError("Need to pass in mode for batching observations")
self.action_space = batch_space(self.single_action_space, self.num_envs)
raise ValueError(
f"Invalid `observation_mode`, expected: 'same' or 'different' or tuple of single and batch observation space, actual got {observation_mode}"
)

dummy_env.close()
del dummy_env
Expand All @@ -187,9 +187,7 @@ def __init__(
)
except CustomSpaceError as e:
raise ValueError(
"Using `shared_memory=True` in `AsyncVectorEnv` is incompatible with non-standard Gymnasium observation spaces (i.e. custom spaces inheriting from `gymnasium.Space`), "
"and is only compatible with default Gymnasium spaces (e.g. `Box`, `Tuple`, `Dict`) for batching. "
"Set `shared_memory=False` if you use custom observation spaces."
"Using `AsyncVector(..., shared_memory=True)` caused an error, you can disable this feature with `shared_memory=False` however this is slower."
) from e
else:
_obs_buffer = None
Expand Down Expand Up @@ -616,20 +614,33 @@ def _poll_pipe_envs(self, timeout: int | None = None):

def _check_spaces(self):
self._assert_is_running()
spaces = (self.single_observation_space, self.single_action_space)

for pipe in self.parent_pipes:
pipe.send(("_check_spaces", spaces))
pipe.send(
(
"_check_spaces",
(
self.observation_mode,
self.single_observation_space,
self.single_action_space,
),
)
)

results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
same_observation_spaces, same_action_spaces = zip(*results)

if not all(same_observation_spaces):
raise RuntimeError(
f"Some environments have an observation space different from `{self.single_observation_space}`. "
"In order to batch observations, the observation spaces from all environments must be equal."
)
if self.observation_mode == "same":
raise RuntimeError(
"AsyncVectorEnv(..., observation_mode='same') however some of the sub-environments observation spaces are not equivalent. If this is intentional, use `observation_mode='different'` instead."
)
else:
raise RuntimeError(
"AsyncVectorEnv(..., observation_mode='different' or custom space) however the sub-environment's observation spaces do not share a common shape and dtype."
)

if not all(same_action_spaces):
raise RuntimeError(
f"Some environments have an action space different from `{self.single_action_space}`. "
Expand Down Expand Up @@ -739,23 +750,19 @@ def _async_worker(
env.set_wrapper_attr(name, value)
pipe.send((None, True))
elif command == "_check_spaces":
obs_mode, single_obs_space, single_action_space = data

pipe.send(
(
(
(data[0] == observation_space)
or (
hasattr(observation_space, "low")
and hasattr(observation_space, "high")
and np.any(
np.all(observation_space.low == data[0].low, axis=1)
)
and np.any(
np.all(
observation_space.high == data[0].high, axis=1
)
(
single_obs_space == observation_space
if obs_mode == "same"
else is_space_dtype_shape_equiv(
single_obs_space, observation_space
)
),
data[1] == action_space,
single_action_space == action_space,
),
True,
)
Expand Down
Loading
Loading