diff --git a/gymnasium/spaces/tuple.py b/gymnasium/spaces/tuple.py index aae683ad21..05a1f652ab 100644 --- a/gymnasium/spaces/tuple.py +++ b/gymnasium/spaces/tuple.py @@ -47,14 +47,14 @@ def is_np_flattenable(self): """Checks whether this space can be flattened to a :class:`spaces.Box`.""" return all(space.is_np_flattenable for space in self.spaces) - def seed(self, seed: int | tuple[int] | None = None) -> tuple[int, ...]: + def seed(self, seed: int | typing.Sequence[int] | None = None) -> tuple[int, ...]: """Seed the PRNG of this space and all subspaces. Depending on the type of seed, the subspaces will be seeded differently * ``None`` - All the subspaces will use a random initial seed * ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces. - * ``List`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``. + * ``List`` / ``Tuple`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``. Args: seed: An optional list of ints or int to seed the (sub-)spaces. diff --git a/gymnasium/spaces/utils.py b/gymnasium/spaces/utils.py index 6a0cff4052..3ce5547070 100644 --- a/gymnasium/spaces/utils.py +++ b/gymnasium/spaces/utils.py @@ -573,3 +573,102 @@ def _flatten_space_oneof(space: OneOf) -> Box: dtype = np.result_type(*[s.dtype for s in space.spaces if hasattr(s, "dtype")]) return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype) + + +@singledispatch +def is_space_dtype_shape_equiv(space_1: Space, space_2: Space) -> bool: + """Returns if two spaces share a common dtype and shape (plus any critical variables). + + This function is primarily used to check for compatibility of different spaces in a vector environment. + + Args: + space_1: A Gymnasium space + space_2: A Gymnasium space + + Returns: + If the two spaces share a common dtype and shape (plus any critical variables). 
+ """ + if isinstance(space_1, Space) and isinstance(space_2, Space): + raise NotImplementedError( + "`check_dtype_shape_equivalence` doesn't support Generic Gymnasium Spaces, " + ) + else: + raise TypeError() + + +@is_space_dtype_shape_equiv.register(Box) +@is_space_dtype_shape_equiv.register(Discrete) +@is_space_dtype_shape_equiv.register(MultiDiscrete) +@is_space_dtype_shape_equiv.register(MultiBinary) +def _is_space_fundamental_dtype_shape_equiv(space_1, space_2): + return ( + # this check is necessary as singledispatch only checks the first variable and there are many options + type(space_1) is type(space_2) + and space_1.shape == space_2.shape + and space_1.dtype == space_2.dtype + ) + + +@is_space_dtype_shape_equiv.register(Text) +def _is_space_text_dtype_shape_equiv(space_1: Text, space_2): + return ( + isinstance(space_2, Text) + and space_1.max_length == space_2.max_length + and space_1.character_set == space_2.character_set + ) + + +@is_space_dtype_shape_equiv.register(Dict) +def _is_space_dict_dtype_shape_equiv(space_1: Dict, space_2): + return ( + isinstance(space_2, Dict) + and space_1.keys() == space_2.keys() + and all( + is_space_dtype_shape_equiv(space_1[key], space_2[key]) + for key in space_1.keys() + ) + ) + + +@is_space_dtype_shape_equiv.register(Tuple) +def _is_space_tuple_dtype_shape_equiv(space_1, space_2): + return isinstance(space_2, Tuple) and all( + is_space_dtype_shape_equiv(space_1[i], space_2[i]) for i in range(len(space_1)) + ) + + +@is_space_dtype_shape_equiv.register(Graph) +def _is_space_graph_dtype_shape_equiv(space_1: Graph, space_2): + return ( + isinstance(space_2, Graph) + and is_space_dtype_shape_equiv(space_1.node_space, space_2.node_space) + and ( + (space_1.edge_space is None and space_2.edge_space is None) + or ( + space_1.edge_space is not None + and space_2.edge_space is not None + and is_space_dtype_shape_equiv(space_1.edge_space, space_2.edge_space) + ) + ) + ) + + +@is_space_dtype_shape_equiv.register(OneOf) +def _is_space_oneof_dtype_shape_equiv(space_1: OneOf, space_2): + return ( + isinstance(space_2, OneOf) + and len(space_1) == len(space_2) + and all( + is_space_dtype_shape_equiv(space_1[i], space_2[i]) + for i in range(len(space_1)) + ) + ) + + +@is_space_dtype_shape_equiv.register(Sequence) +def _is_space_sequence_dtype_shape_equiv(space_1: Sequence, space_2): + return ( + isinstance(space_2, Sequence) + and space_1.stack is space_2.stack + and is_space_dtype_shape_equiv(space_1.feature_space, space_2.feature_space) + ) diff --git a/gymnasium/vector/async_vector_env.py b/gymnasium/vector/async_vector_env.py index 01d6d4e9bf..3d7e24b789 100644 --- a/gymnasium/vector/async_vector_env.py +++ b/gymnasium/vector/async_vector_env.py @@ -22,8 +22,10 @@ CustomSpaceError, NoAsyncCallError, ) +from gymnasium.spaces.utils import is_space_dtype_shape_equiv from gymnasium.vector.utils import ( CloudpickleWrapper, + batch_differing_spaces, batch_space, clear_mpi_env_vars, concatenate, @@ -33,10 +35,6 @@ read_from_shared_memory, write_to_shared_memory, ) -from gymnasium.vector.utils.batched_spaces import ( - all_spaces_have_same_shape, - batch_differing_spaces, -) from gymnasium.vector.vector_env import ArrayType, VectorEnv @@ -119,13 +117,14 @@ def __init__( worker: If set, then use that worker in a subprocess instead of a default one. Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled. observation_mode: Defines how environment observation spaces should be batched. 
'same' defines that there should be ``n`` copies of identical spaces. - 'different' defines that there can be multiple observation spaces with the same length but different high/low values batched together. Passing a ``Space`` object - allows the user to set some custom observation space mode not covered by 'same' or 'different.' + 'different' defines that there can be multiple observation spaces with different parameters though requires the same shape and dtype, + warning, may raise unexpected errors. Passing a ``Tuple[Space, Space]`` object allows defining a custom ``single_observation_space`` and + ``observation_space``, warning, may raise unexpected errors. Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start - from the code for ``_worker`` (or ``_worker_shared_memory``) method, and add changes. + from the code for ``_worker`` (or ``_async_worker``) method, and add changes. Raises: RuntimeError: If the observation space of some sub-environment does not match observation_space @@ -136,6 +135,7 @@ def __init__( self.env_fns = env_fns self.shared_memory = shared_memory self.copy = copy + self.observation_mode = observation_mode self.num_envs = len(env_fns) @@ -148,9 +148,12 @@ def __init__( self.render_mode = dummy_env.render_mode self.single_action_space = dummy_env.action_space + self.action_space = batch_space(self.single_action_space, self.num_envs) - if isinstance(observation_mode, Space): - self.observation_space = observation_mode + if isinstance(observation_mode, tuple) and len(observation_mode) == 2: + assert isinstance(observation_mode[0], Space) + assert isinstance(observation_mode[1], Space) + self.observation_space, self.single_observation_space = observation_mode else: if observation_mode == "same": self.single_observation_space = dummy_env.observation_space @@ -158,19 +161,16 @@ def __init__( self.single_observation_space, self.num_envs ) elif observation_mode == "different": - current_spaces = [env().observation_space for env in self.env_fns] - - assert all_spaces_have_same_shape( - current_spaces - ), "Low & High values for observation spaces can be different but shapes need to be the same" - - self.single_observation_space = batch_differing_spaces(current_spaces) - - self.observation_space = self.single_observation_space + # the environment is created and instantly destroy, might cause issues for some environment + # but I don't believe there is anything else we can do, for users with issues, pre-compute the spaces and use the custom option. + env_spaces = [env().observation_space for env in self.env_fns] + self.single_observation_space = env_spaces[0] + self.observation_space = batch_differing_spaces(env_spaces) else: - raise ValueError("Need to pass in mode for batching observations") - self.action_space = batch_space(self.single_action_space, self.num_envs) + raise ValueError( + f"Invalid `observation_mode`, expected: 'same' or 'different' or tuple of single and batch observation space, actual got {observation_mode}" + ) dummy_env.close() del dummy_env @@ -187,9 +187,7 @@ def __init__( ) except CustomSpaceError as e: raise ValueError( - "Using `shared_memory=True` in `AsyncVectorEnv` is incompatible with non-standard Gymnasium observation spaces (i.e. custom spaces inheriting from `gymnasium.Space`), " - "and is only compatible with default Gymnasium spaces (e.g. `Box`, `Tuple`, `Dict`) for batching. 
" - "Set `shared_memory=False` if you use custom observation spaces." + "Using `AsyncVector(..., shared_memory=True)` caused an error, you can disable this feature with `shared_memory=False` however this is slower." ) from e else: _obs_buffer = None @@ -616,20 +614,33 @@ def _poll_pipe_envs(self, timeout: int | None = None): def _check_spaces(self): self._assert_is_running() - spaces = (self.single_observation_space, self.single_action_space) for pipe in self.parent_pipes: - pipe.send(("_check_spaces", spaces)) + pipe.send( + ( + "_check_spaces", + ( + self.observation_mode, + self.single_observation_space, + self.single_action_space, + ), + ) + ) results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) self._raise_if_errors(successes) same_observation_spaces, same_action_spaces = zip(*results) if not all(same_observation_spaces): - raise RuntimeError( - f"Some environments have an observation space different from `{self.single_observation_space}`. " - "In order to batch observations, the observation spaces from all environments must be equal." - ) + if self.observation_mode == "same": + raise RuntimeError( + "AsyncVectorEnv(..., observation_mode='same') however some of the sub-environments observation spaces are not equivalent. If this is intentional, use `observation_mode='different'` instead." + ) + else: + raise RuntimeError( + "AsyncVectorEnv(..., observation_mode='different' or custom space) however the sub-environment's observation spaces do not share a common shape and dtype." + ) + if not all(same_action_spaces): raise RuntimeError( f"Some environments have an action space different from `{self.single_action_space}`. " @@ -739,23 +750,19 @@ def _async_worker( env.set_wrapper_attr(name, value) pipe.send((None, True)) elif command == "_check_spaces": + obs_mode, single_obs_space, single_action_space = data + pipe.send( ( ( - (data[0] == observation_space) - or ( - hasattr(observation_space, "low") - and hasattr(observation_space, "high") - and np.any( - np.all(observation_space.low == data[0].low, axis=1) - ) - and np.any( - np.all( - observation_space.high == data[0].high, axis=1 - ) + ( + single_obs_space == observation_space + if obs_mode == "same" + else is_space_dtype_shape_equiv( + single_obs_space, observation_space ) ), - data[1] == action_space, + single_action_space == action_space, ), True, ) diff --git a/gymnasium/vector/sync_vector_env.py b/gymnasium/vector/sync_vector_env.py index 709a625bb9..b92a268897 100644 --- a/gymnasium/vector/sync_vector_env.py +++ b/gymnasium/vector/sync_vector_env.py @@ -9,11 +9,13 @@ from gymnasium import Env, Space from gymnasium.core import ActType, ObsType, RenderFrame -from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate -from gymnasium.vector.utils.batched_spaces import ( - all_spaces_have_same_shape, - all_spaces_have_same_type, +from gymnasium.spaces.utils import is_space_dtype_shape_equiv +from gymnasium.vector.utils import ( batch_differing_spaces, + batch_space, + concatenate, + create_empty_array, + iterate, ) from gymnasium.vector.vector_env import ArrayType, VectorEnv @@ -78,6 +80,7 @@ def __init__( """ self.copy = copy self.env_fns = env_fns + self.observation_mode = observation_mode # Initialise all sub-environments self.envs = [env_fn() for env_fn in env_fns] @@ -89,39 +92,42 @@ def __init__( self.render_mode = self.envs[0].render_mode self.single_action_space = self.envs[0].action_space + self.action_space = batch_space(self.single_action_space, self.num_envs) - # Initialise 
the obs and action space based on the desired mode - - if isinstance(observation_mode, Space): - self.observation_space = observation_mode + if isinstance(observation_mode, tuple) and len(observation_mode) == 2: + assert isinstance(observation_mode[0], Space) + assert isinstance(observation_mode[1], Space) + self.observation_space, self.single_observation_space = observation_mode else: if observation_mode == "same": self.single_observation_space = self.envs[0].observation_space - self.single_action_space = self.envs[0].action_space - self.observation_space = batch_space( self.single_observation_space, self.num_envs ) elif observation_mode == "different": - current_spaces = [env.observation_space for env in self.envs] - - assert all_spaces_have_same_shape( - current_spaces - ), "Low & High values for observation spaces can be different but shapes need to be the same" - assert all_spaces_have_same_type( - current_spaces - ), "Observation spaces must have same Space type" - - self.observation_space = batch_differing_spaces(current_spaces) - - self.single_observation_space = self.observation_space - + self.single_observation_space = self.envs[0].observation_space + self.observation_space = batch_differing_spaces( + [env.observation_space for env in self.envs] + ) else: - raise ValueError("Need to pass in mode for batching observations") + raise ValueError( + f"Invalid `observation_mode`, expected: 'same' or 'different' or tuple of single and batch observation space, actual got {observation_mode}" + ) - self._check_spaces() + # check sub-environment obs and action spaces + for env in self.envs: + if observation_mode == "same": + assert ( + env.observation_space == self.single_observation_space + ), f"SyncVectorEnv(..., observation_mode='same') however the sub-environments observation spaces are not equivalent. single_observation_space={self.single_observation_space}, sub-environment observation_space={env.observation_space}. If this is intentional, use `observation_mode='different'` instead." 
+ else: + assert is_space_dtype_shape_equiv( + env.observation_space, self.single_observation_space + ), f"SyncVectorEnv(..., observation_mode='different' or custom space) however the sub-environments observation spaces do not share a common shape and dtype, single_observation_space={self.single_observation_space}, sub-environment observation space={env.observation_space}" - self.action_space = batch_space(self.single_action_space, self.num_envs) + assert ( + env.action_space == self.single_action_space + ), f"Sub-environment action space doesn't make the `single_action_space`, action_space={env.action_space}, single_action_space={self.single_action_space}" # Initialise attributes used in `step` and `reset` self._observations = create_empty_array( @@ -297,38 +303,3 @@ def close_extras(self, **kwargs: Any): """Close the environments.""" if hasattr(self, "envs"): [env.close() for env in self.envs] - - def _check_spaces(self) -> bool: - """Check that each of the environments obs and action spaces are equivalent to the single obs and action space.""" - for env in self.envs: - if not (env.observation_space == self.single_observation_space): - if not ( - hasattr(env.observation_space, "low") - and hasattr(env.observation_space, "high") - and np.any( - np.all( - env.observation_space.low - == self.single_observation_space.low, - axis=1, - ) - ) - and np.any( - np.all( - env.observation_space.high - == self.single_observation_space.high, - axis=1, - ) - ) - ): - raise RuntimeError( - f"Some environments have an observation space different from `{self.single_observation_space}`. " - "In order to batch observations, the observation spaces from all environments must be equal." - ) - - if not (env.action_space == self.single_action_space): - raise RuntimeError( - f"Some environments have an action space different from `{self.single_action_space}`. " - "In order to batch actions, the action spaces from all environments must be equal." 
- ) - - return True diff --git a/gymnasium/vector/utils/__init__.py b/gymnasium/vector/utils/__init__.py index a0ad58c3ed..53c989d5f7 100644 --- a/gymnasium/vector/utils/__init__.py +++ b/gymnasium/vector/utils/__init__.py @@ -7,6 +7,7 @@ write_to_shared_memory, ) from gymnasium.vector.utils.space_utils import ( + batch_differing_spaces, batch_space, concatenate, create_empty_array, @@ -16,6 +17,7 @@ __all__ = [ "batch_space", + "batch_differing_spaces", "iterate", "concatenate", "create_empty_array", diff --git a/gymnasium/vector/utils/batched_spaces.py b/gymnasium/vector/utils/batched_spaces.py deleted file mode 100644 index 5f7f597e6d..0000000000 --- a/gymnasium/vector/utils/batched_spaces.py +++ /dev/null @@ -1,141 +0,0 @@ -"""Batching support for Spaces of same type but possibly varying low/high values.""" - -from __future__ import annotations - -from copy import deepcopy -from functools import singledispatch - -import numpy as np - -from gymnasium import Space -from gymnasium.spaces import ( - Box, - Dict, - Discrete, - Graph, - MultiBinary, - MultiDiscrete, - OneOf, - Sequence, - Text, - Tuple, -) - - -@singledispatch -def batch_differing_spaces(spaces: list[Space]): - """Batch a Sequence of spaces that allows the subspaces to contain minor differences.""" - assert len(spaces) > 0 - assert all(isinstance(space, type(spaces[0])) for space in spaces) - assert type(spaces[0]) in batch_differing_spaces.registry - - return batch_differing_spaces.dispatch(type(spaces[0]))(spaces) - - -@batch_differing_spaces.register(Box) -def _batch_differing_spaces_box(spaces: list[Box]): - assert all(spaces[0].dtype == space for space in spaces) - - return Box( - low=np.array([space.low for space in spaces]), - high=np.array([space.high for space in spaces]), - dtype=spaces[0].dtype, - seed=deepcopy(spaces[0].np_random), - ) - - -@batch_differing_spaces.register(Discrete) -def _batch_differing_spaces_discrete(spaces: list[Discrete]): - return MultiDiscrete( - nvec=np.array([space.n for space in spaces]), - start=np.array([space.start for space in spaces]), - seed=deepcopy(spaces[0].np_random), - ) - - -@batch_differing_spaces.register(MultiDiscrete) -def _batch_differing_spaces_multi_discrete(spaces: list[MultiDiscrete]): - return Box( - low=np.array([space.start for space in spaces]), - high=np.array([space.start + space.nvec for space in spaces]) - 1, - dtype=spaces[0].dtype, - seed=deepcopy(spaces[0].np_random), - ) - - -@batch_differing_spaces.register(MultiBinary) -def _batch_differing_spaces_multi_binary(spaces: list[MultiBinary]): - assert all(spaces[0].shape == space.shape for space in spaces) - - return Box( - low=0, - high=1, - shape=(len(spaces),) + spaces[0].shape, - dtype=spaces[0].dtype, - seed=deepcopy(spaces[0].np_random), - ) - - -@batch_differing_spaces.register(Tuple) -def _batch_differing_spaces_tuple(spaces: list[Tuple]): - return Tuple( - tuple( - batch_differing_spaces(subspaces) - for subspaces in zip(*[space.spaces for space in spaces]) - ), - seed=deepcopy(spaces[0].np_random), - ) - - -@batch_differing_spaces.register(Dict) -def _batch_differing_spaces_dict(spaces: list[Dict]): - assert all(spaces[0].keys() == space.keys() for space in spaces) - - return Dict( - { - key: batch_differing_spaces([space[key] for space in spaces]) - for key in spaces[0].keys() - }, - seed=deepcopy(spaces[0].np_random), - ) - - -@batch_differing_spaces.register(Graph) -@batch_differing_spaces.register(Text) -@batch_differing_spaces.register(Sequence) -@batch_differing_spaces.register(OneOf) -def 
_batch_spaces_undefined(spaces: list[Graph | Text | Sequence | OneOf]): - return Tuple(spaces, seed=deepcopy(spaces[0].np_random)) - - -def all_spaces_have_same_shape(spaces): - """Check if all spaces have the same size.""" - if not spaces: - return True # An empty list is considered to have the same shape - - def get_space_shape(space): - if isinstance(space, Box): - return space.shape - elif isinstance(space, Discrete): - return () # Discrete spaces are considered scalar - elif isinstance(space, Dict): - return tuple(get_space_shape(s) for s in space.spaces.values()) - elif isinstance(space, Tuple): - return tuple(get_space_shape(s) for s in space.spaces) - else: - raise ValueError(f"Unsupported space type: {type(space)}") - - first_shape = get_space_shape(spaces[0]) - return all(get_space_shape(space) == first_shape for space in spaces[1:]) - - -def all_spaces_have_same_type(spaces): - """Check if all spaces have the same space type (Box, Discrete, etc).""" - if not spaces: - return True # An empty list is considered to have the same type - - # Get the type of the first space - first_type = type(spaces[0]) - - # Check if all spaces have the same type as the first one - return all(isinstance(space, first_type) for space in spaces) diff --git a/gymnasium/vector/utils/space_utils.py b/gymnasium/vector/utils/space_utils.py index 8eb0dd7dd0..c4d5ef68d8 100644 --- a/gymnasium/vector/utils/space_utils.py +++ b/gymnasium/vector/utils/space_utils.py @@ -1,6 +1,7 @@ """Space-based utility functions for vector environments. -- ``batch_space``: Create a (batched) space, containing multiple copies of a single space. +- ``batch_space``: Create a (batched) space containing multiple copies of a single space. +- ``batch_differing_spaces``: Create a (batched) space containing copies of different compatible spaces (share a common dtype and shape) - ``concatenate``: Concatenate multiple samples from (unbatched) space into a single object. - ``Iterate``: Iterate over the elements of a (batched) space and items. 
- ``create_empty_array``: Create an empty (possibly nested) (normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)`` @@ -32,7 +33,13 @@ from gymnasium.spaces.space import T_cov -__all__ = ["batch_space", "iterate", "concatenate", "create_empty_array"] +__all__ = [ + "batch_space", + "batch_differing_spaces", + "iterate", + "concatenate", + "create_empty_array", +] @singledispatch @@ -139,6 +146,116 @@ def _batch_space_custom(space: Graph | Text | Sequence | OneOf, n: int = 1): return batched_space +@singledispatch +def batch_differing_spaces(spaces: list[Space]): + """Batch a Sequence of spaces that allows the subspaces to contain minor differences.""" + assert len(spaces) > 0, "Expects a non-empty list of spaces" + assert all( + isinstance(space, type(spaces[0])) for space in spaces + ), f"Expects all spaces to be the same shape, actual types: {[type(space) for space in spaces]}" + assert ( + type(spaces[0]) in batch_differing_spaces.registry + ), f"Requires the Space type to have a registered `batch_differing_space`, current list: {batch_differing_spaces.registry}" + + return batch_differing_spaces.dispatch(type(spaces[0]))(spaces) + + +@batch_differing_spaces.register(Box) +def _batch_differing_spaces_box(spaces: list[Box]): + assert all( + spaces[0].dtype == space.dtype for space in spaces + ), f"Expected all dtypes to be equal, actually {[space.dtype for space in spaces]}" + assert all( + spaces[0].low.shape == space.low.shape for space in spaces + ), f"Expected all Box.low shape to be equal, actually {[space.low.shape for space in spaces]}" + assert all( + spaces[0].high.shape == space.high.shape for space in spaces + ), f"Expected all Box.high shape to be equal, actually {[space.high.shape for space in spaces]}" + + return Box( + low=np.array([space.low for space in spaces]), + high=np.array([space.high for space in spaces]), + dtype=spaces[0].dtype, + seed=deepcopy(spaces[0].np_random), + ) + + +@batch_differing_spaces.register(Discrete) +def _batch_differing_spaces_discrete(spaces: list[Discrete]): + return MultiDiscrete( + nvec=np.array([space.n for space in spaces]), + start=np.array([space.start for space in spaces]), + seed=deepcopy(spaces[0].np_random), + ) + + +@batch_differing_spaces.register(MultiDiscrete) +def _batch_differing_spaces_multi_discrete(spaces: list[MultiDiscrete]): + assert all( + spaces[0].dtype == space.dtype for space in spaces + ), f"Expected all dtypes to be equal, actually {[space.dtype for space in spaces]}" + assert all( + spaces[0].nvec.shape == space.nvec.shape for space in spaces + ), f"Expects all MultiDiscrete.nvec shape, actually {[space.nvec.shape for space in spaces]}" + assert all( + spaces[0].start.shape == space.start.shape for space in spaces + ), f"Expects all MultiDiscrete.start shape, actually {[space.start.shape for space in spaces]}" + + return Box( + low=np.array([space.start for space in spaces]), + high=np.array([space.start + space.nvec for space in spaces]) - 1, + dtype=spaces[0].dtype, + seed=deepcopy(spaces[0].np_random), + ) + + +@batch_differing_spaces.register(MultiBinary) +def _batch_differing_spaces_multi_binary(spaces: list[MultiBinary]): + assert all(spaces[0].shape == space.shape for space in spaces) + + return Box( + low=0, + high=1, + shape=(len(spaces),) + spaces[0].shape, + dtype=spaces[0].dtype, + seed=deepcopy(spaces[0].np_random), + ) + + +@batch_differing_spaces.register(Tuple) +def _batch_differing_spaces_tuple(spaces: list[Tuple]): + return Tuple( + tuple( + 
batch_differing_spaces(subspaces) + for subspaces in zip(*[space.spaces for space in spaces]) + ), + seed=deepcopy(spaces[0].np_random), + ) + + +@batch_differing_spaces.register(Dict) +def _batch_differing_spaces_dict(spaces: list[Dict]): + assert all(spaces[0].keys() == space.keys() for space in spaces) + + return Dict( + { + key: batch_differing_spaces([space[key] for space in spaces]) + for key in spaces[0].keys() + }, + seed=deepcopy(spaces[0].np_random), + ) + + +@batch_differing_spaces.register(Graph) +@batch_differing_spaces.register(Text) +@batch_differing_spaces.register(Sequence) +@batch_differing_spaces.register(OneOf) +def _batch_spaces_undefined(spaces: list[Graph | Text | Sequence | OneOf]): + return Tuple( + [deepcopy(space) for space in spaces], seed=deepcopy(spaces[0].np_random) + ) + + @singledispatch def iterate(space: Space[T_cov], items: Iterable[T_cov]) -> Iterator: """Iterate over the elements of a (batched) space. diff --git a/tests/spaces/test_utils.py b/tests/spaces/test_utils.py index 01e445da2b..baa86d2452 100644 --- a/tests/spaces/test_utils.py +++ b/tests/spaces/test_utils.py @@ -6,7 +6,13 @@ import gymnasium as gym from gymnasium.spaces import Box, Graph, Sequence, utils +from gymnasium.spaces.utils import is_space_dtype_shape_equiv from gymnasium.utils.env_checker import data_equivalence +from gymnasium.vector.utils import ( + create_shared_memory, + read_from_shared_memory, + write_to_shared_memory, +) from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS @@ -162,3 +168,42 @@ def test_unflatten_multidiscrete_error(): value = np.array([0, 0]) with pytest.raises(ValueError): utils.unflatten(gym.spaces.MultiDiscrete([1, 1]), value) + + +@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS) +def test_is_space_dtype_shape_equiv(space): + assert is_space_dtype_shape_equiv(space, space) is True + + +@pytest.mark.parametrize("space_1", TESTING_SPACES, ids=TESTING_SPACES_IDS) +def test_all_space_pairs_for_is_space_dtype_shape_equiv(space_1): + """Practically check that the `is_space_dtype_shape_equiv` works as expected for `shared_memory`.""" + for space_2 in TESTING_SPACES: + compatible = is_space_dtype_shape_equiv(space_1, space_2) + + if compatible: + try: + shared_memory = create_shared_memory(space_1, n=2) + except TypeError as err: + assert ( + "has a dynamic shape so its not possible to make a static shared memory." 
+ in str(err) + ) + pytest.skip("Skipping space with dynamic shape") + + space_1.seed(123) + space_2.seed(123) + sample_1 = space_1.sample() + sample_2 = space_2.sample() + + write_to_shared_memory(space_1, 0, sample_1, shared_memory) + write_to_shared_memory(space_2, 1, sample_2, shared_memory) + + read_sample_1, read_sample_2 = read_from_shared_memory( + space_1, shared_memory, n=2 + ) + + assert data_equivalence(sample_1, read_sample_1) + assert data_equivalence(sample_2, read_sample_2) + else: + pytest.skip("Not compatible") diff --git a/tests/vector/test_batch_spaces.py b/tests/vector/test_batch_spaces.py deleted file mode 100644 index a91276dca8..0000000000 --- a/tests/vector/test_batch_spaces.py +++ /dev/null @@ -1,76 +0,0 @@ -import pytest - -import gymnasium as gym -from gymnasium.spaces import Box, Dict, Discrete -from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv -from gymnasium.vector.utils import batch_space -from gymnasium.vector.utils.batched_spaces import batch_differing_spaces - - -class CustomEnv(gym.Env): - def __init__(self, observation_space): - super().__init__() - self.observation_space = observation_space - self.action_space = Discrete(2) # Dummy action space - - def reset(self, seed=None, options=None): - return self.observation_space.sample(), {} - - def step(self, action): - return self.observation_space.sample(), 0, False, False, {} - - -def create_env(obs_space): - return lambda: CustomEnv(obs_space) - - -# Test cases for both SyncVectorEnv and AsyncVectorEnv -@pytest.mark.parametrize("VectorEnv", [SyncVectorEnv, AsyncVectorEnv]) -class TestVectorEnvObservationModes: - - def test_invalid_observation_mode(self, VectorEnv): - with pytest.raises( - ValueError, match="Need to pass in mode for batching observations" - ): - VectorEnv( - [create_env(Box(low=0, high=1, shape=(5,))) for _ in range(3)], - observation_mode="invalid", - ) - - def test_mixed_observation_spaces(self, VectorEnv): - spaces = [ - Box(low=0, high=1, shape=(3,)), - Discrete(5), - Dict({"a": Discrete(2), "b": Box(low=0, high=1, shape=(2,))}), - ] - with pytest.raises( - AssertionError, - match="Low & High values for observation spaces can be different but shapes need to be the same", - ): - VectorEnv( - [create_env(space) for space in spaces], observation_mode="different" - ) - - def test_default_observation_mode(self, VectorEnv): - single_space = Box(low=0, high=1, shape=(5,)) - env = VectorEnv( - [create_env(single_space) for _ in range(3)] - ) # No observation_mode specified - assert env.observation_space == batch_space(single_space, 3) - - def test_different_observation_mode_same_shape(self, VectorEnv): - spaces = [Box(low=0, high=i, shape=(5,)) for i in range(1, 4)] - env = VectorEnv( - [create_env(space) for space in spaces], observation_mode="different" - ) - assert env.observation_space == batch_differing_spaces(spaces) - - def test_different_observation_mode_different_shapes(self, VectorEnv): - spaces = [Box(low=0, high=1, shape=(i + 1,)) for i in range(3)] - with pytest.raises( - AssertionError, - match="Low & High values for observation spaces can be different but shapes need to be the same", - ): - VectorEnv( - [create_env(space) for space in spaces], observation_mode="different" - ) diff --git a/tests/vector/test_observation_mode.py b/tests/vector/test_observation_mode.py new file mode 100644 index 0000000000..7aff2e08f7 --- /dev/null +++ b/tests/vector/test_observation_mode.py @@ -0,0 +1,121 @@ +import re +from functools import partial + +import numpy as np +import pytest + 
+from gymnasium.spaces import Box, Dict, Discrete +from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv +from gymnasium.vector.utils import batch_differing_spaces +from tests.testing_env import GenericTestEnv + + +def create_env(obs_space): + return lambda: GenericTestEnv(observation_space=obs_space) + + +# Test cases for both SyncVectorEnv and AsyncVectorEnv +@pytest.mark.parametrize( + "vector_env_fn", + [SyncVectorEnv, AsyncVectorEnv, partial(AsyncVectorEnv, shared_memory=False)], + ids=[ + "SyncVectorEnv", + "AsyncVectorEnv(shared_memory=True)", + "AsyncVectorEnv(shared_memory=False)", + ], +) +class TestVectorEnvObservationModes: + + def test_invalid_observation_mode(self, vector_env_fn): + with pytest.raises( + ValueError, + match=re.escape( + "Invalid `observation_mode`, expected: 'same' or 'different' or tuple of single and batch observation space, actual got invalid" + ), + ): + vector_env_fn( + [create_env(Box(low=0, high=1, shape=(5,))) for _ in range(3)], + observation_mode="invalid", + ) + + def test_obs_mode_same_different_spaces(self, vector_env_fn): + spaces = [Box(low=0, high=i, shape=(2,)) for i in range(1, 4)] + with pytest.raises( + (AssertionError, RuntimeError), + match="the sub-environments observation spaces are not equivalent. .*If this is intentional, use `observation_mode='different'` instead.", + ): + vector_env_fn( + [create_env(space) for space in spaces], observation_mode="same" + ) + + @pytest.mark.parametrize( + "observation_mode", + [ + "different", + ( + Box( + low=0, + high=np.repeat(np.arange(1, 4), 5).reshape((3, 5)), + shape=(3, 5), + ), + Box(low=0, high=1, shape=(5,)), + ), + ], + ) + def test_obs_mode_different_different_spaces(self, vector_env_fn, observation_mode): + spaces = [Box(low=0, high=i, shape=(5,)) for i in range(1, 4)] + envs = vector_env_fn( + [create_env(space) for space in spaces], observation_mode=observation_mode + ) + assert envs.observation_space == batch_differing_spaces(spaces) + assert envs.single_observation_space == spaces[0] + + envs.reset() + envs.step(envs.action_space.sample()) + envs.close() + + @pytest.mark.parametrize( + "observation_mode", + [ + "different", + (Box(low=0, high=4, shape=(3, 5)), Box(low=0, high=4, shape=(5,))), + ], + ) + def test_obs_mode_different_different_shapes(self, vector_env_fn, observation_mode): + spaces = [Box(low=0, high=1, shape=(i + 1,)) for i in range(3)] + with pytest.raises( + (AssertionError, RuntimeError), + # match=re.escape( + # "Expected all Box.low shape to be equal, actually [(1,), (2,), (3,)]" + # ), + ): + vector_env_fn( + [create_env(space) for space in spaces], + observation_mode=observation_mode, + ) + + @pytest.mark.parametrize( + "observation_mode", + [ + "same", + "different", + (Box(low=0, high=4, shape=(3, 5)), Box(low=0, high=4, shape=(5,))), + ], + ) + def test_mixed_observation_spaces(self, vector_env_fn, observation_mode): + spaces = [ + Box(low=0, high=1, shape=(3,)), + Discrete(5), + Dict({"a": Discrete(2), "b": Box(low=0, high=1, shape=(2,))}), + ] + + with pytest.raises( + (AssertionError, RuntimeError), + # match=re.escape( + # "Expects all spaces to be the same shape, actual types: [, , ]" + # ), + ): + vector_env_fn( + [create_env(space) for space in spaces], + observation_mode=observation_mode, + ) diff --git a/tests/vector/test_sync_vector_env.py b/tests/vector/test_sync_vector_env.py index 8ec187b981..cd0d685e7e 100644 --- a/tests/vector/test_sync_vector_env.py +++ b/tests/vector/test_sync_vector_env.py @@ -1,5 +1,7 @@ """Test the `SyncVectorEnv` 
implementation.""" +import re + import numpy as np import pytest @@ -139,7 +141,12 @@ def test_check_spaces_sync_vector_env(): env_fns = [make_env("CartPole-v1", i) for i in range(8)] # FrozenLake-v1 - Discrete(16), action_space: Discrete(4) env_fns[1] = make_env("FrozenLake-v1", 1) - with pytest.raises(RuntimeError): + with pytest.raises( + AssertionError, + match=re.escape( + "SyncVectorEnv(..., observation_mode='same') however the sub-environments observation spaces are not equivalent." + ), + ): env = SyncVectorEnv(env_fns) env.close() diff --git a/tests/vector/utils/test_space_utils.py b/tests/vector/utils/test_space_utils.py index 86ff1dae50..c03df12691 100644 --- a/tests/vector/utils/test_space_utils.py +++ b/tests/vector/utils/test_space_utils.py @@ -11,8 +11,13 @@ from gymnasium.error import CustomSpaceError from gymnasium.spaces import Box, Tuple from gymnasium.utils.env_checker import data_equivalence -from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate -from gymnasium.vector.utils.batched_spaces import batch_differing_spaces +from gymnasium.vector.utils import ( + batch_differing_spaces, + batch_space, + concatenate, + create_empty_array, + iterate, +) from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS, CustomSpace from tests.vector.utils.utils import is_rng_equal @@ -70,13 +75,13 @@ def test_batch_space_deterministic(space: Space, n: int, base_seed: int): space_a = space space_a.seed(base_seed) space_b = copy.deepcopy(space_a) - is_rng_equal(space_a.np_random, space_b.np_random) + assert is_rng_equal(space_a.np_random, space_b.np_random) assert space_a.np_random is not space_b.np_random # Batch the spaces and check that the np_random are not reference equal space_a_batched = batch_space(space_a, n) space_b_batched = batch_space(space_b, n) - is_rng_equal(space_a_batched.np_random, space_b_batched.np_random) + assert is_rng_equal(space_a_batched.np_random, space_b_batched.np_random) assert space_a_batched.np_random is not space_b_batched.np_random # Create that the batched space is not reference equal to the origin spaces assert space_a.np_random is not space_a_batched.np_random @@ -103,7 +108,7 @@ def test_batch_space_different_samples(space: Space, n: int, base_seed: int): batched_space = batch_space(space, n) assert space.np_random is not batched_space.np_random - is_rng_equal(space.np_random, batched_space.np_random) + assert is_rng_equal(space.np_random, batched_space.np_random) batched_sample = batched_space.sample() unbatched_samples = list(iterate(batched_space, batched_sample)) @@ -177,9 +182,42 @@ def test_varying_spaces(spaces: "list[Space]", expected_space): @pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS) @pytest.mark.parametrize("n", [1, 3]) -def test_batch_spaces_vs_batch_space(space, n): +def test_batch_differing_space_vs_batch_space(space, n): """Test the batch_spaces and batch_space functions.""" batched_space = batch_space(space, n) batched_spaces = batch_differing_spaces([copy.deepcopy(space) for _ in range(n)]) assert batched_space == batched_spaces, f"{batched_space=}, {batched_spaces=}" + + +@pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS) +@pytest.mark.parametrize("n", [1, 2, 5], ids=[f"n={n}" for n in [1, 2, 5]]) +@pytest.mark.parametrize( + "base_seed", [123, 456], ids=[f"seed={base_seed}" for base_seed in [123, 456]] +) +def test_batch_differing_spaces_deterministic(space: Space, n: int, base_seed: int): + """Tests the batched spaces are deterministic 
by using a copied version.""" + # Copy the spaces and check that the np_random are not reference equal + space_a = space + space_a.seed(base_seed) + space_b = copy.deepcopy(space_a) + assert is_rng_equal(space_a.np_random, space_b.np_random) + assert space_a.np_random is not space_b.np_random + + # Batch the spaces and check that the np_random are not reference equal + space_a_batched = batch_differing_spaces([space_a for _ in range(n)]) + space_b_batched = batch_differing_spaces([space_b for _ in range(n)]) + assert is_rng_equal(space_a_batched.np_random, space_b_batched.np_random) + assert space_a_batched.np_random is not space_b_batched.np_random + # Create that the batched space is not reference equal to the origin spaces + assert space_a.np_random is not space_a_batched.np_random + + # Check that batched space a and b random number generator are not effected by the original space + space_a.sample() + space_a_batched_sample = space_a_batched.sample() + space_b_batched_sample = space_b_batched.sample() + for a_sample, b_sample in zip( + iterate(space_a_batched, space_a_batched_sample), + iterate(space_b_batched, space_b_batched_sample), + ): + assert data_equivalence(a_sample, b_sample)
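
Below is a minimal, illustrative sketch (not part of the patch) of how the `observation_mode="different"` option and `batch_differing_spaces` added in this diff could be used. `ObsEnv` is a hypothetical stand-in for any environment whose Box bounds differ per instance while the shape and dtype stay the same.

    # Usage sketch for the new observation_mode API (assumes this PR is applied).
    import gymnasium as gym
    import numpy as np
    from gymnasium.spaces import Box, Discrete
    from gymnasium.vector import SyncVectorEnv
    from gymnasium.vector.utils import batch_differing_spaces


    class ObsEnv(gym.Env):
        """Hypothetical env: same observation shape/dtype, different high bound per instance."""

        def __init__(self, high):
            self.observation_space = Box(low=0.0, high=high, shape=(5,), dtype=np.float64)
            self.action_space = Discrete(2)

        def reset(self, *, seed=None, options=None):
            super().reset(seed=seed)
            return self.observation_space.sample(), {}

        def step(self, action):
            return self.observation_space.sample(), 0.0, False, False, {}


    envs = SyncVectorEnv(
        [lambda h=h: ObsEnv(high=h) for h in (1.0, 2.0, 3.0)],
        observation_mode="different",  # sub-spaces may differ as long as shape and dtype match
    )
    # Per the diff, `single_observation_space` is taken from the first sub-environment,
    # while `observation_space` is built with `batch_differing_spaces`.
    assert envs.observation_space == batch_differing_spaces(
        [Box(low=0.0, high=h, shape=(5,), dtype=np.float64) for h in (1.0, 2.0, 3.0)]
    )
    obs, info = envs.reset(seed=123)
    envs.close()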