Skip to content

Commit d20ac56

Browse files
Add (A)syncVectorEnv support for sub-envs with different obs spaces (#1126)
1 parent 1a92702 commit d20ac56

File tree

5 files changed

+360
-19
lines changed

5 files changed

+360
-19
lines changed

Diff for: gymnasium/vector/async_vector_env.py

+46-6
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import numpy as np
1616

17-
from gymnasium import logger
17+
from gymnasium import Space, logger
1818
from gymnasium.core import ActType, Env, ObsType, RenderFrame
1919
from gymnasium.error import (
2020
AlreadyPendingCallError,
@@ -33,6 +33,10 @@
3333
read_from_shared_memory,
3434
write_to_shared_memory,
3535
)
36+
from gymnasium.vector.utils.batched_spaces import (
37+
all_spaces_have_same_shape,
38+
batch_differing_spaces,
39+
)
3640
from gymnasium.vector.vector_env import ArrayType, VectorEnv
3741

3842

@@ -98,6 +102,7 @@ def __init__(
98102
]
99103
| None
100104
) = None,
105+
observation_mode: str | Space = "same",
101106
):
102107
"""Vectorized environment that runs multiple environments in parallel.
103108
@@ -113,6 +118,9 @@ def __init__(
113118
so for some environments you may want to have it set to ``False``.
114119
worker: If set, then use that worker in a subprocess instead of a default one.
115120
Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.
121+
observation_mode: Defines how environment observation spaces should be batched. 'same' defines that there should be ``n`` copies of identical spaces.
122+
'different' defines that there can be multiple observation spaces with the same length but different high/low values batched together. Passing a ``Space`` object
123+
allows the user to set some custom observation space mode not covered by 'same' or 'different.'
116124
117125
Warnings:
118126
worker is an advanced mode option. It provides a high degree of flexibility and a high chance
@@ -139,12 +147,29 @@ def __init__(
139147
self.metadata = dummy_env.metadata
140148
self.render_mode = dummy_env.render_mode
141149

142-
self.single_observation_space = dummy_env.observation_space
143150
self.single_action_space = dummy_env.action_space
144151

145-
self.observation_space = batch_space(
146-
self.single_observation_space, self.num_envs
147-
)
152+
if isinstance(observation_mode, Space):
153+
self.observation_space = observation_mode
154+
else:
155+
if observation_mode == "same":
156+
self.single_observation_space = dummy_env.observation_space
157+
self.observation_space = batch_space(
158+
self.single_observation_space, self.num_envs
159+
)
160+
elif observation_mode == "different":
161+
current_spaces = [env().observation_space for env in self.env_fns]
162+
163+
assert all_spaces_have_same_shape(
164+
current_spaces
165+
), "Low & High values for observation spaces can be different but shapes need to be the same"
166+
167+
self.single_observation_space = batch_differing_spaces(current_spaces)
168+
169+
self.observation_space = self.single_observation_space
170+
171+
else:
172+
raise ValueError("Need to pass in mode for batching observations")
148173
self.action_space = batch_space(self.single_action_space, self.num_envs)
149174

150175
dummy_env.close()
@@ -716,7 +741,22 @@ def _async_worker(
716741
elif command == "_check_spaces":
717742
pipe.send(
718743
(
719-
(data[0] == observation_space, data[1] == action_space),
744+
(
745+
(data[0] == observation_space)
746+
or (
747+
hasattr(observation_space, "low")
748+
and hasattr(observation_space, "high")
749+
and np.any(
750+
np.all(observation_space.low == data[0].low, axis=1)
751+
)
752+
and np.any(
753+
np.all(
754+
observation_space.high == data[0].high, axis=1
755+
)
756+
)
757+
),
758+
data[1] == action_space,
759+
),
720760
True,
721761
)
722762
)

Diff for: gymnasium/vector/sync_vector_env.py

+62-12
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,14 @@
77

88
import numpy as np
99

10-
from gymnasium import Env
10+
from gymnasium import Env, Space
1111
from gymnasium.core import ActType, ObsType, RenderFrame
1212
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
13+
from gymnasium.vector.utils.batched_spaces import (
14+
all_spaces_have_same_shape,
15+
all_spaces_have_same_type,
16+
batch_differing_spaces,
17+
)
1318
from gymnasium.vector.vector_env import ArrayType, VectorEnv
1419

1520

@@ -57,13 +62,16 @@ def __init__(
5762
self,
5863
env_fns: Iterator[Callable[[], Env]] | Sequence[Callable[[], Env]],
5964
copy: bool = True,
65+
observation_mode: str | Space = "same",
6066
):
6167
"""Vectorized environment that serially runs multiple environments.
6268
6369
Args:
6470
env_fns: iterable of callable functions that create the environments.
6571
copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.
66-
72+
observation_mode: Defines how environment observation spaces should be batched. 'same' defines that there should be ``n`` copies of identical spaces.
73+
'different' defines that there can be multiple observation spaces with the same length but different high/low values batched together. Passing a ``Space`` object
74+
allows the user to set some custom observation space mode not covered by 'same' or 'different.'
6775
Raises:
6876
RuntimeError: If the observation space of some sub-environment does not match observation_space
6977
(or, by default, the observation space of the first sub-environment).
@@ -80,15 +88,39 @@ def __init__(
8088
self.metadata = self.envs[0].metadata
8189
self.render_mode = self.envs[0].render_mode
8290

83-
# Initialises the single spaces from the sub-environments
84-
self.single_observation_space = self.envs[0].observation_space
8591
self.single_action_space = self.envs[0].action_space
92+
93+
# Initialise the obs and action space based on the desired mode
94+
95+
if isinstance(observation_mode, Space):
96+
self.observation_space = observation_mode
97+
else:
98+
if observation_mode == "same":
99+
self.single_observation_space = self.envs[0].observation_space
100+
self.single_action_space = self.envs[0].action_space
101+
102+
self.observation_space = batch_space(
103+
self.single_observation_space, self.num_envs
104+
)
105+
elif observation_mode == "different":
106+
current_spaces = [env.observation_space for env in self.envs]
107+
108+
assert all_spaces_have_same_shape(
109+
current_spaces
110+
), "Low & High values for observation spaces can be different but shapes need to be the same"
111+
assert all_spaces_have_same_type(
112+
current_spaces
113+
), "Observation spaces must have same Space type"
114+
115+
self.observation_space = batch_differing_spaces(current_spaces)
116+
117+
self.single_observation_space = self.observation_space
118+
119+
else:
120+
raise ValueError("Need to pass in mode for batching observations")
121+
86122
self._check_spaces()
87123

88-
# Initialise the obs and action space based on the single versions and num of sub-environments
89-
self.observation_space = batch_space(
90-
self.single_observation_space, self.num_envs
91-
)
92124
self.action_space = batch_space(self.single_action_space, self.num_envs)
93125

94126
# Initialise attributes used in `step` and `reset`
@@ -270,10 +302,28 @@ def _check_spaces(self) -> bool:
270302
"""Check that each of the environments obs and action spaces are equivalent to the single obs and action space."""
271303
for env in self.envs:
272304
if not (env.observation_space == self.single_observation_space):
273-
raise RuntimeError(
274-
f"Some environments have an observation space different from `{self.single_observation_space}`. "
275-
"In order to batch observations, the observation spaces from all environments must be equal."
276-
)
305+
if not (
306+
hasattr(env.observation_space, "low")
307+
and hasattr(env.observation_space, "high")
308+
and np.any(
309+
np.all(
310+
env.observation_space.low
311+
== self.single_observation_space.low,
312+
axis=1,
313+
)
314+
)
315+
and np.any(
316+
np.all(
317+
env.observation_space.high
318+
== self.single_observation_space.high,
319+
axis=1,
320+
)
321+
)
322+
):
323+
raise RuntimeError(
324+
f"Some environments have an observation space different from `{self.single_observation_space}`. "
325+
"In order to batch observations, the observation spaces from all environments must be equal."
326+
)
277327

278328
if not (env.action_space == self.single_action_space):
279329
raise RuntimeError(

Diff for: gymnasium/vector/utils/batched_spaces.py

+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
"""Batching support for Spaces of same type but possibly varying low/high values."""
2+
3+
from __future__ import annotations
4+
5+
from copy import deepcopy
6+
from functools import singledispatch
7+
8+
import numpy as np
9+
10+
from gymnasium import Space
11+
from gymnasium.spaces import (
12+
Box,
13+
Dict,
14+
Discrete,
15+
Graph,
16+
MultiBinary,
17+
MultiDiscrete,
18+
OneOf,
19+
Sequence,
20+
Text,
21+
Tuple,
22+
)
23+
24+
25+
@singledispatch
26+
def batch_differing_spaces(spaces: list[Space]):
27+
"""Batch a Sequence of spaces that allows the subspaces to contain minor differences."""
28+
assert len(spaces) > 0
29+
assert all(isinstance(space, type(spaces[0])) for space in spaces)
30+
assert type(spaces[0]) in batch_differing_spaces.registry
31+
32+
return batch_differing_spaces.dispatch(type(spaces[0]))(spaces)
33+
34+
35+
@batch_differing_spaces.register(Box)
36+
def _batch_differing_spaces_box(spaces: list[Box]):
37+
assert all(spaces[0].dtype == space for space in spaces)
38+
39+
return Box(
40+
low=np.array([space.low for space in spaces]),
41+
high=np.array([space.high for space in spaces]),
42+
dtype=spaces[0].dtype,
43+
seed=deepcopy(spaces[0].np_random),
44+
)
45+
46+
47+
@batch_differing_spaces.register(Discrete)
48+
def _batch_differing_spaces_discrete(spaces: list[Discrete]):
49+
return MultiDiscrete(
50+
nvec=np.array([space.n for space in spaces]),
51+
start=np.array([space.start for space in spaces]),
52+
seed=deepcopy(spaces[0].np_random),
53+
)
54+
55+
56+
@batch_differing_spaces.register(MultiDiscrete)
57+
def _batch_differing_spaces_multi_discrete(spaces: list[MultiDiscrete]):
58+
return Box(
59+
low=np.array([space.start for space in spaces]),
60+
high=np.array([space.start + space.nvec for space in spaces]) - 1,
61+
dtype=spaces[0].dtype,
62+
seed=deepcopy(spaces[0].np_random),
63+
)
64+
65+
66+
@batch_differing_spaces.register(MultiBinary)
67+
def _batch_differing_spaces_multi_binary(spaces: list[MultiBinary]):
68+
assert all(spaces[0].shape == space.shape for space in spaces)
69+
70+
return Box(
71+
low=0,
72+
high=1,
73+
shape=(len(spaces),) + spaces[0].shape,
74+
dtype=spaces[0].dtype,
75+
seed=deepcopy(spaces[0].np_random),
76+
)
77+
78+
79+
@batch_differing_spaces.register(Tuple)
80+
def _batch_differing_spaces_tuple(spaces: list[Tuple]):
81+
return Tuple(
82+
tuple(
83+
batch_differing_spaces(subspaces)
84+
for subspaces in zip(*[space.spaces for space in spaces])
85+
),
86+
seed=deepcopy(spaces[0].np_random),
87+
)
88+
89+
90+
@batch_differing_spaces.register(Dict)
91+
def _batch_differing_spaces_dict(spaces: list[Dict]):
92+
assert all(spaces[0].keys() == space.keys() for space in spaces)
93+
94+
return Dict(
95+
{
96+
key: batch_differing_spaces([space[key] for space in spaces])
97+
for key in spaces[0].keys()
98+
},
99+
seed=deepcopy(spaces[0].np_random),
100+
)
101+
102+
103+
@batch_differing_spaces.register(Graph)
104+
@batch_differing_spaces.register(Text)
105+
@batch_differing_spaces.register(Sequence)
106+
@batch_differing_spaces.register(OneOf)
107+
def _batch_spaces_undefined(spaces: list[Graph | Text | Sequence | OneOf]):
108+
return Tuple(spaces, seed=deepcopy(spaces[0].np_random))
109+
110+
111+
def all_spaces_have_same_shape(spaces):
112+
"""Check if all spaces have the same size."""
113+
if not spaces:
114+
return True # An empty list is considered to have the same shape
115+
116+
def get_space_shape(space):
117+
if isinstance(space, Box):
118+
return space.shape
119+
elif isinstance(space, Discrete):
120+
return () # Discrete spaces are considered scalar
121+
elif isinstance(space, Dict):
122+
return tuple(get_space_shape(s) for s in space.spaces.values())
123+
elif isinstance(space, Tuple):
124+
return tuple(get_space_shape(s) for s in space.spaces)
125+
else:
126+
raise ValueError(f"Unsupported space type: {type(space)}")
127+
128+
first_shape = get_space_shape(spaces[0])
129+
return all(get_space_shape(space) == first_shape for space in spaces[1:])
130+
131+
132+
def all_spaces_have_same_type(spaces):
133+
"""Check if all spaces have the same space type (Box, Discrete, etc)."""
134+
if not spaces:
135+
return True # An empty list is considered to have the same type
136+
137+
# Get the type of the first space
138+
first_type = type(spaces[0])
139+
140+
# Check if all spaces have the same type as the first one
141+
return all(isinstance(space, first_type) for space in spaces)

0 commit comments

Comments
 (0)