Skip to content

Commit 21d5ea3

Browse files
Extend testing and make generic the batch space for vector envs (#1139)
Co-authored-by: Reggie McLean <[email protected]>
1 parent d20ac56 commit 21d5ea3

12 files changed

+520
-330
lines changed

Diff for: gymnasium/spaces/tuple.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,14 @@ def is_np_flattenable(self):
4747
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
4848
return all(space.is_np_flattenable for space in self.spaces)
4949

50-
def seed(self, seed: int | tuple[int] | None = None) -> tuple[int, ...]:
50+
def seed(self, seed: int | typing.Sequence[int] | None = None) -> tuple[int, ...]:
5151
"""Seed the PRNG of this space and all subspaces.
5252
5353
Depending on the type of seed, the subspaces will be seeded differently
5454
5555
* ``None`` - All the subspaces will use a random initial seed
5656
* ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces.
57-
* ``List`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.
57+
* ``List`` / ``Tuple`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``.
5858
5959
Args:
6060
seed: An optional list of ints or int to seed the (sub-)spaces.

Diff for: gymnasium/spaces/utils.py

+99
Original file line numberDiff line numberDiff line change
@@ -573,3 +573,102 @@ def _flatten_space_oneof(space: OneOf) -> Box:
573573

574574
dtype = np.result_type(*[s.dtype for s in space.spaces if hasattr(s, "dtype")])
575575
return Box(low=low, high=high, shape=(max_flatdim,), dtype=dtype)
576+
577+
578+
@singledispatch
def is_space_dtype_shape_equiv(space_1: Space, space_2: Space) -> bool:
    """Returns if two spaces share a common dtype and shape (plus any critical variables).

    This function is primarily used to check for compatibility of different spaces in a vector environment.

    This is the ``singledispatch`` fallback: it only runs when no implementation has been
    registered for the type of ``space_1``, and therefore it always raises.

    Args:
        space_1: A Gymnasium space
        space_2: A Gymnasium space

    Returns:
        If the two spaces share a common dtype and shape (plus any critical variables).

    Raises:
        NotImplementedError: If both arguments are spaces but no implementation is registered for ``space_1``.
        TypeError: If at least one of the arguments is not a ``Space``.
    """
    if isinstance(space_1, Space) and isinstance(space_2, Space):
        # the message names this function (and the offending types) so users can find what to register
        raise NotImplementedError(
            "`is_space_dtype_shape_equiv` doesn't support generic Gymnasium spaces, "
            f"actual types: {type(space_1)} and {type(space_2)}"
        )
    else:
        raise TypeError(
            "`is_space_dtype_shape_equiv` expects both arguments to be Gymnasium spaces, "
            f"actual types: {type(space_1)} and {type(space_2)}"
        )
597+
598+
599+
@is_space_dtype_shape_equiv.register(Box)
@is_space_dtype_shape_equiv.register(Discrete)
@is_space_dtype_shape_equiv.register(MultiDiscrete)
@is_space_dtype_shape_equiv.register(MultiBinary)
def _is_space_fundamental_dtype_shape_equiv(space_1, space_2):
    """Checks that two fundamental spaces have the same type, shape and dtype.

    Args:
        space_1: The space that was dispatched on (one of the registered fundamental spaces)
        space_2: The space to compare against

    Returns:
        If both spaces are of the exact same type and share their shape and dtype.
    """
    # singledispatch only looks at the first argument, so the second argument's
    # type has to be compared explicitly before touching its attributes
    if type(space_1) is not type(space_2):
        return False
    if not (space_1.shape == space_2.shape):
        return False
    return space_1.dtype == space_2.dtype
610+
611+
612+
@is_space_dtype_shape_equiv.register(Text)
def _is_space_text_dtype_shape_equiv(space_1: Text, space_2):
    """Checks that two text spaces share their maximum length and character set.

    Args:
        space_1: The text space that was dispatched on
        space_2: The space to compare against

    Returns:
        If ``space_2`` is also a ``Text`` space with the same ``max_length`` and ``character_set``.
    """
    if not isinstance(space_2, Text):
        return False

    same_max_length = space_1.max_length == space_2.max_length
    # the character sets are only compared once the lengths are known to match
    return same_max_length and space_1.character_set == space_2.character_set
619+
620+
621+
@is_space_dtype_shape_equiv.register(Dict)
def _is_space_dict_dtype_shape_equiv(space_1: Dict, space_2):
    """Checks that two dict spaces have the same keys and pairwise equivalent subspaces.

    Args:
        space_1: The dict space that was dispatched on
        space_2: The space to compare against

    Returns:
        If ``space_2`` is also a ``Dict`` space with identical keys whose subspaces are all equivalent.
    """
    if not isinstance(space_2, Dict):
        return False
    if not (space_1.keys() == space_2.keys()):
        return False

    # the keys are identical at this point, so indexing ``space_2`` with them is safe
    for key in space_1.keys():
        if not is_space_dtype_shape_equiv(space_1[key], space_2[key]):
            return False
    return True
631+
632+
633+
@is_space_dtype_shape_equiv.register(Tuple)
def _is_space_tuple_dtype_shape_equiv(space_1: Tuple, space_2):
    """Checks that two tuple spaces have the same length and pairwise equivalent subspaces.

    Args:
        space_1: The tuple space that was dispatched on
        space_2: The space to compare against

    Returns:
        If ``space_2`` is also a ``Tuple`` space of the same length whose subspaces are all equivalent.
    """
    return (
        isinstance(space_2, Tuple)
        # without this check a longer ``space_2`` with a matching prefix would be reported
        # as equivalent, and a shorter ``space_2`` would be indexed past its end
        and len(space_1) == len(space_2)
        and all(
            is_space_dtype_shape_equiv(space_1[i], space_2[i])
            for i in range(len(space_1))
        )
    )
638+
639+
640+
@is_space_dtype_shape_equiv.register(Graph)
def _is_space_graph_dtype_shape_equiv(space_1: Graph, space_2):
    """Checks that two graph spaces have equivalent node spaces and equivalent edge spaces.

    Args:
        space_1: The graph space that was dispatched on
        space_2: The space to compare against

    Returns:
        If ``space_2`` is also a ``Graph`` space, the node spaces are equivalent and the edge
        spaces are either both ``None`` or both defined and equivalent.
    """
    if not isinstance(space_2, Graph):
        return False
    if not is_space_dtype_shape_equiv(space_1.node_space, space_2.node_space):
        return False

    # the edge space is optional: both missing is a match, only one missing is a mismatch
    if space_1.edge_space is None and space_2.edge_space is None:
        return True
    if space_1.edge_space is None or space_2.edge_space is None:
        return False
    return is_space_dtype_shape_equiv(space_1.edge_space, space_2.edge_space)
654+
655+
656+
@is_space_dtype_shape_equiv.register(OneOf)
def _is_space_oneof_dtype_shape_equiv(space_1: OneOf, space_2):
    """Checks that two one-of spaces have the same number of pairwise equivalent subspaces.

    Args:
        space_1: The one-of space that was dispatched on
        space_2: The space to compare against

    Returns:
        If ``space_2`` is also a ``OneOf`` space of the same length whose subspaces are all equivalent.
    """
    if not isinstance(space_2, OneOf):
        return False
    if len(space_1) != len(space_2):
        return False

    # the lengths match at this point, so every index is valid for both spaces
    for index in range(len(space_1)):
        if not is_space_dtype_shape_equiv(space_1[index], space_2[index]):
            return False
    return True
666+
667+
668+
@is_space_dtype_shape_equiv.register(Sequence)
def _is_space_sequence_dtype_shape_equiv(space_1: Sequence, space_2):
    """Checks that two sequence spaces share their ``stack`` setting and have equivalent feature spaces.

    Args:
        space_1: The sequence space that was dispatched on
        space_2: The space to compare against

    Returns:
        If ``space_2`` is also a ``Sequence`` space with the same ``stack`` value and an equivalent feature space.
    """
    if not isinstance(space_2, Sequence):
        return False
    # ``stack`` is compared by identity, exactly one of the spaces stacking is a mismatch
    if space_1.stack is not space_2.stack:
        return False
    return is_space_dtype_shape_equiv(space_1.feature_space, space_2.feature_space)

Diff for: gymnasium/vector/async_vector_env.py

+48-41
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
CustomSpaceError,
2323
NoAsyncCallError,
2424
)
25+
from gymnasium.spaces.utils import is_space_dtype_shape_equiv
2526
from gymnasium.vector.utils import (
2627
CloudpickleWrapper,
28+
batch_differing_spaces,
2729
batch_space,
2830
clear_mpi_env_vars,
2931
concatenate,
@@ -33,10 +35,6 @@
3335
read_from_shared_memory,
3436
write_to_shared_memory,
3537
)
36-
from gymnasium.vector.utils.batched_spaces import (
37-
all_spaces_have_same_shape,
38-
batch_differing_spaces,
39-
)
4038
from gymnasium.vector.vector_env import ArrayType, VectorEnv
4139

4240

@@ -119,13 +117,14 @@ def __init__(
119117
worker: If set, then use that worker in a subprocess instead of a default one.
120118
Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.
121119
observation_mode: Defines how environment observation spaces should be batched. 'same' defines that there should be ``n`` copies of identical spaces.
122-
'different' defines that there can be multiple observation spaces with the same length but different high/low values batched together. Passing a ``Space`` object
123-
allows the user to set some custom observation space mode not covered by 'same' or 'different.'
120+
'different' defines that the sub-environments may have observation spaces with different parameters (e.g. bounds), though they must share the same shape and dtype;
121+
warning, this may raise unexpected errors. Passing a ``Tuple[Space, Space]`` of ``(observation_space, single_observation_space)`` allows defining custom batched and
122+
single observation spaces; warning, this may also raise unexpected errors.
124123
125124
Warnings:
126125
worker is an advanced mode option. It provides a high degree of flexibility and a high chance
127126
to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start
128-
from the code for ``_worker`` (or ``_worker_shared_memory``) method, and add changes.
127+
from the code for ``_worker`` (or ``_async_worker``) method, and add changes.
129128
130129
Raises:
131130
RuntimeError: If the observation space of some sub-environment does not match observation_space
@@ -136,6 +135,7 @@ def __init__(
136135
self.env_fns = env_fns
137136
self.shared_memory = shared_memory
138137
self.copy = copy
138+
self.observation_mode = observation_mode
139139

140140
self.num_envs = len(env_fns)
141141

@@ -148,29 +148,29 @@ def __init__(
148148
self.render_mode = dummy_env.render_mode
149149

150150
self.single_action_space = dummy_env.action_space
151+
self.action_space = batch_space(self.single_action_space, self.num_envs)
151152

152-
if isinstance(observation_mode, Space):
153-
self.observation_space = observation_mode
153+
if isinstance(observation_mode, tuple) and len(observation_mode) == 2:
154+
assert isinstance(observation_mode[0], Space)
155+
assert isinstance(observation_mode[1], Space)
156+
self.observation_space, self.single_observation_space = observation_mode
154157
else:
155158
if observation_mode == "same":
156159
self.single_observation_space = dummy_env.observation_space
157160
self.observation_space = batch_space(
158161
self.single_observation_space, self.num_envs
159162
)
160163
elif observation_mode == "different":
161-
current_spaces = [env().observation_space for env in self.env_fns]
162-
163-
assert all_spaces_have_same_shape(
164-
current_spaces
165-
), "Low & High values for observation spaces can be different but shapes need to be the same"
166-
167-
self.single_observation_space = batch_differing_spaces(current_spaces)
168-
169-
self.observation_space = self.single_observation_space
164+
# each environment is created and instantly destroyed, which might cause issues for some environments
165+
# but I don't believe there is anything else we can do, for users with issues, pre-compute the spaces and use the custom option.
166+
env_spaces = [env().observation_space for env in self.env_fns]
170167

168+
self.single_observation_space = env_spaces[0]
169+
self.observation_space = batch_differing_spaces(env_spaces)
171170
else:
172-
raise ValueError("Need to pass in mode for batching observations")
173-
self.action_space = batch_space(self.single_action_space, self.num_envs)
171+
raise ValueError(
172+
f"Invalid `observation_mode`, expected: 'same' or 'different' or tuple of single and batch observation space, actual got {observation_mode}"
173+
)
174174

175175
dummy_env.close()
176176
del dummy_env
@@ -187,9 +187,7 @@ def __init__(
187187
)
188188
except CustomSpaceError as e:
189189
raise ValueError(
190-
"Using `shared_memory=True` in `AsyncVectorEnv` is incompatible with non-standard Gymnasium observation spaces (i.e. custom spaces inheriting from `gymnasium.Space`), "
191-
"and is only compatible with default Gymnasium spaces (e.g. `Box`, `Tuple`, `Dict`) for batching. "
192-
"Set `shared_memory=False` if you use custom observation spaces."
190+
"Using `AsyncVector(..., shared_memory=True)` caused an error, you can disable this feature with `shared_memory=False` however this is slower."
193191
) from e
194192
else:
195193
_obs_buffer = None
@@ -616,20 +614,33 @@ def _poll_pipe_envs(self, timeout: int | None = None):
616614

617615
def _check_spaces(self):
618616
self._assert_is_running()
619-
spaces = (self.single_observation_space, self.single_action_space)
620617

621618
for pipe in self.parent_pipes:
622-
pipe.send(("_check_spaces", spaces))
619+
pipe.send(
620+
(
621+
"_check_spaces",
622+
(
623+
self.observation_mode,
624+
self.single_observation_space,
625+
self.single_action_space,
626+
),
627+
)
628+
)
623629

624630
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
625631
self._raise_if_errors(successes)
626632
same_observation_spaces, same_action_spaces = zip(*results)
627633

628634
if not all(same_observation_spaces):
629-
raise RuntimeError(
630-
f"Some environments have an observation space different from `{self.single_observation_space}`. "
631-
"In order to batch observations, the observation spaces from all environments must be equal."
632-
)
635+
if self.observation_mode == "same":
636+
raise RuntimeError(
637+
"AsyncVectorEnv(..., observation_mode='same') however some of the sub-environments observation spaces are not equivalent. If this is intentional, use `observation_mode='different'` instead."
638+
)
639+
else:
640+
raise RuntimeError(
641+
"AsyncVectorEnv(..., observation_mode='different' or custom space) however the sub-environment's observation spaces do not share a common shape and dtype."
642+
)
643+
633644
if not all(same_action_spaces):
634645
raise RuntimeError(
635646
f"Some environments have an action space different from `{self.single_action_space}`. "
@@ -739,23 +750,19 @@ def _async_worker(
739750
env.set_wrapper_attr(name, value)
740751
pipe.send((None, True))
741752
elif command == "_check_spaces":
753+
obs_mode, single_obs_space, single_action_space = data
754+
742755
pipe.send(
743756
(
744757
(
745-
(data[0] == observation_space)
746-
or (
747-
hasattr(observation_space, "low")
748-
and hasattr(observation_space, "high")
749-
and np.any(
750-
np.all(observation_space.low == data[0].low, axis=1)
751-
)
752-
and np.any(
753-
np.all(
754-
observation_space.high == data[0].high, axis=1
755-
)
758+
(
759+
single_obs_space == observation_space
760+
if obs_mode == "same"
761+
else is_space_dtype_shape_equiv(
762+
single_obs_space, observation_space
756763
)
757764
),
758-
data[1] == action_space,
765+
single_action_space == action_space,
759766
),
760767
True,
761768
)

0 commit comments

Comments
 (0)