Skip to content

Commit 95615fc

Browse files
JaydenTeoh and ffelten authored
Add MOAsyncVectorEnv wrapper (#115)
Co-authored-by: Florian Felten <felten.florian@hotmail.fr>
1 parent f2324f8 commit 95615fc

File tree

4 files changed

+329
-24
lines changed

4 files changed

+329
-24
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Vector wrappers."""
22

33
from mo_gymnasium.wrappers.vector.wrappers import (
4+
MOAsyncVectorEnv,
45
MORecordEpisodeStatistics,
56
MOSyncVectorEnv,
67
)

mo_gymnasium/wrappers/vector/wrappers.py

Lines changed: 259 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,30 @@
11
"""Vector wrappers."""
22

3+
from __future__ import annotations
4+
5+
import multiprocessing
6+
import sys
37
import time
8+
import traceback
49
from copy import deepcopy
5-
from typing import Any, Dict, Iterator, Tuple
10+
from multiprocessing import Array, Queue
11+
from multiprocessing.connection import Connection
12+
from typing import Any, Callable, Iterator, Sequence
613

714
import gymnasium as gym
815
import numpy as np
916
from gymnasium.core import ActType, ObsType
10-
from gymnasium.vector import SyncVectorEnv
11-
from gymnasium.vector.utils import concatenate, iterate
12-
from gymnasium.vector.vector_env import ArrayType, VectorEnv
17+
from gymnasium.error import NoAsyncCallError
18+
from gymnasium.spaces.utils import is_space_dtype_shape_equiv
19+
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
20+
from gymnasium.vector.async_vector_env import AsyncState
21+
from gymnasium.vector.utils import (
22+
concatenate,
23+
create_empty_array,
24+
iterate,
25+
write_to_shared_memory,
26+
)
27+
from gymnasium.vector.vector_env import ArrayType, AutoresetMode, VectorEnv
1328
from gymnasium.wrappers.vector import RecordEpisodeStatistics
1429

1530

@@ -60,7 +75,7 @@ def __init__(
6075
dtype=np.float32,
6176
)
6277

63-
def step(self, actions: ActType) -> Tuple[ObsType, ArrayType, ArrayType, ArrayType, Dict[str, Any]]:
78+
def step(self, actions: ActType) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
6479
"""Steps through each of the environments returning the batched results.
6580
6681
Returns:
@@ -103,6 +118,244 @@ def step(self, actions: ActType) -> Tuple[ObsType, ArrayType, ArrayType, ArrayTy
103118
)
104119

105120

121+
def _mo_async_worker(
122+
index: int,
123+
env_fn: callable,
124+
pipe: Connection,
125+
parent_pipe: Connection,
126+
shared_memory: Array | dict[str, Any] | tuple[Any, ...],
127+
error_queue: Queue,
128+
autoreset_mode: AutoresetMode,
129+
):
130+
env = env_fn()
131+
observation_space = env.observation_space
132+
action_space = env.action_space
133+
reward_space = env.unwrapped.reward_space
134+
autoreset = False
135+
observation = None
136+
137+
parent_pipe.close()
138+
139+
try:
140+
while True:
141+
command, data = pipe.recv()
142+
143+
if command == "reset":
144+
observation, info = env.reset(**data)
145+
if shared_memory:
146+
write_to_shared_memory(observation_space, index, observation, shared_memory)
147+
observation = None
148+
autoreset = False
149+
pipe.send(((observation, info), True))
150+
elif command == "reset-noop":
151+
pipe.send(((observation, {}), True))
152+
elif command == "step":
153+
if autoreset_mode == AutoresetMode.NEXT_STEP:
154+
if autoreset:
155+
observation, info = env.reset()
156+
reward, terminated, truncated = (
157+
np.zeros(reward_space.shape[0], dtype=np.float32),
158+
False,
159+
False,
160+
)
161+
else:
162+
(
163+
observation,
164+
reward,
165+
terminated,
166+
truncated,
167+
info,
168+
) = env.step(data)
169+
autoreset = terminated or truncated
170+
elif autoreset_mode == AutoresetMode.SAME_STEP:
171+
(
172+
observation,
173+
reward,
174+
terminated,
175+
truncated,
176+
info,
177+
) = env.step(data)
178+
179+
if terminated or truncated:
180+
reset_observation, reset_info = env.reset()
181+
182+
info = {
183+
"final_info": info,
184+
"final_obs": observation,
185+
**reset_info,
186+
}
187+
observation = reset_observation
188+
elif autoreset_mode == AutoresetMode.DISABLED:
189+
assert autoreset is False
190+
(
191+
observation,
192+
reward,
193+
terminated,
194+
truncated,
195+
info,
196+
) = env.step(data)
197+
else:
198+
raise ValueError(f"Unexpected autoreset_mode: {autoreset_mode}")
199+
200+
if shared_memory:
201+
write_to_shared_memory(observation_space, index, observation, shared_memory)
202+
observation = None
203+
204+
pipe.send(((observation, reward, terminated, truncated, info), True))
205+
elif command == "close":
206+
pipe.send((None, True))
207+
break
208+
elif command == "_call":
209+
name, args, kwargs = data
210+
if name in ["reset", "step", "close", "_setattr", "_check_spaces"]:
211+
raise ValueError(f"Trying to call function `{name}` with `call`, use `{name}` directly instead.")
212+
213+
attr = env.get_wrapper_attr(name)
214+
if callable(attr):
215+
pipe.send((attr(*args, **kwargs), True))
216+
else:
217+
pipe.send((attr, True))
218+
elif command == "_setattr":
219+
name, value = data
220+
env.set_wrapper_attr(name, value)
221+
pipe.send((None, True))
222+
elif command == "_check_spaces":
223+
obs_mode, single_obs_space, single_action_space = data
224+
225+
pipe.send(
226+
(
227+
(
228+
(
229+
single_obs_space == observation_space
230+
if obs_mode == "same"
231+
else is_space_dtype_shape_equiv(single_obs_space, observation_space)
232+
),
233+
single_action_space == action_space,
234+
),
235+
True,
236+
)
237+
)
238+
else:
239+
raise RuntimeError(
240+
f"Received unknown command `{command}`. Must be one of [`reset`, `step`, `close`, `_call`, `_setattr`, `_check_spaces`]."
241+
)
242+
except (KeyboardInterrupt, Exception):
243+
error_type, error_message, _ = sys.exc_info()
244+
trace = traceback.format_exc()
245+
246+
error_queue.put((index, error_type, error_message, trace))
247+
pipe.send((None, False))
248+
finally:
249+
env.close()
250+
251+
252+
class MOAsyncVectorEnv(AsyncVectorEnv):
253+
"""Vectorized environment that runs multiple environments in parallel.
254+
255+
It uses ``multiprocessing`` processes, and pipes for communication.
256+
257+
Modified from gymnasium.vector.async_vector_env.AsyncVectorEnv to allow for multi-objective rewards.
258+
259+
Example:
260+
>>> import mo_gymnasium as mo_gym
261+
>>> envs = mo_gym.wrappers.vector.MOAsyncVectorEnv([
262+
... lambda: mo_gym.make("deep-sea-treasure-v0") for _ in range(4)
263+
... ])
264+
>>> envs
265+
MOAsyncVectorEnv(num_envs=4)
266+
>>> obs, infos = envs.reset()
267+
>>> obs
268+
array([[0, 0], [0, 0], [0, 0], [0, 0]], dtype=int32)
269+
>>> _ = envs.action_space.seed(42)
270+
>>> actions = envs.action_space.sample()
271+
>>> obs, rewards, terminateds, truncateds, infos = envs.step([0, 1, 2, 3])
272+
>>> obs
273+
array([[0, 0], [1, 0], [0, 0], [0, 3]], dtype=int32)
274+
>>> rewards
275+
array([[0., -1.], [0.7, -1.], [0., -1.], [0., -1.]], dtype=float32)
276+
>>> terminateds
277+
array([False, True, False, False])
278+
"""
279+
280+
def __init__(self, env_fns: Sequence[Callable[[], gym.Env]], **kwargs):
281+
"""Vectorized environment that runs multiple environments in parallel.
282+
283+
Args:
284+
env_fns: env constructors
285+
"""
286+
super().__init__(env_fns=env_fns, worker=_mo_async_worker, **kwargs)
287+
288+
# extract reward space from first vector env and create 2d array to store vector rewards
289+
dummy_env = env_fns[0]()
290+
self.reward_space = dummy_env.unwrapped.reward_space
291+
dummy_env.close()
292+
del dummy_env
293+
self.rewards = create_empty_array(self.reward_space, n=self.num_envs, fn=np.zeros)
294+
295+
def step_wait(self, timeout: int | float | None = None) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]:
296+
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
297+
298+
Args:
299+
timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.
300+
301+
Returns:
302+
The batched environment step information, (obs, reward, terminated, truncated, info)
303+
304+
Raises:
305+
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
306+
NoAsyncCallError: If :meth:`step_wait` was called without any prior call to :meth:`step_async`.
307+
TimeoutError: If :meth:`step_wait` timed out.
308+
"""
309+
self._assert_is_running()
310+
if self._state != AsyncState.WAITING_STEP:
311+
raise NoAsyncCallError(
312+
"Calling `step_wait` without any prior call to `step_async`.",
313+
AsyncState.WAITING_STEP.value,
314+
)
315+
316+
if not self._poll_pipe_envs(timeout):
317+
self._state = AsyncState.DEFAULT
318+
raise multiprocessing.TimeoutError(f"The call to `step_wait` has timed out after {timeout} second(s).")
319+
320+
observations, rewards, terminations, truncations, infos = [], [], [], [], {}
321+
successes = []
322+
for env_idx, pipe in enumerate(self.parent_pipes):
323+
env_step_return, success = pipe.recv()
324+
325+
successes.append(success)
326+
if success:
327+
observations.append(env_step_return[0])
328+
rewards.append(env_step_return[1])
329+
terminations.append(env_step_return[2])
330+
truncations.append(env_step_return[3])
331+
infos = self._add_info(infos, env_step_return[4], env_idx)
332+
333+
self._raise_if_errors(successes)
334+
335+
if not self.shared_memory:
336+
self.observations = concatenate(
337+
self.single_observation_space,
338+
observations,
339+
self.observations,
340+
)
341+
342+
# modify to allow return of vector rewards
343+
self.rewards = concatenate(
344+
self.reward_space,
345+
rewards,
346+
self.rewards,
347+
)
348+
349+
self._state = AsyncState.DEFAULT
350+
return (
351+
deepcopy(self.observations) if self.copy else self.observations,
352+
deepcopy(self.rewards) if self.copy else self.rewards,
353+
np.array(terminations, dtype=np.bool_),
354+
np.array(truncations, dtype=np.bool_),
355+
infos,
356+
)
357+
358+
106359
class MORecordEpisodeStatistics(RecordEpisodeStatistics):
107360
"""This wrapper will keep track of cumulative rewards and episode lengths.
108361
@@ -162,7 +415,7 @@ def reset(self, **kwargs):
162415

163416
return obs, info
164417

165-
def step(self, actions: ActType) -> Tuple[ObsType, ArrayType, ArrayType, ArrayType, Dict[str, Any]]:
418+
def step(self, actions: ActType) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
166419
"""Steps through the environment, recording the episode statistics."""
167420
(
168421
observations,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ classifiers = [
2222
'Topic :: Scientific/Engineering :: Artificial Intelligence',
2323
]
2424
dependencies = [
25-
"gymnasium >=1.0.0",
25+
"gymnasium >=1.1.0",
2626
"numpy >=1.21.0,<2.0",
2727
"pygame >=2.1.3",
2828
"scipy >=1.7.3",

0 commit comments

Comments
 (0)