|
1 | 1 | """Vector wrappers.""" |
2 | 2 |
|
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import multiprocessing |
| 6 | +import sys |
3 | 7 | import time |
| 8 | +import traceback |
4 | 9 | from copy import deepcopy |
5 | | -from typing import Any, Dict, Iterator, Tuple |
| 10 | +from multiprocessing import Array, Queue |
| 11 | +from multiprocessing.connection import Connection |
| 12 | +from typing import Any, Callable, Iterator, Sequence |
6 | 13 |
|
7 | 14 | import gymnasium as gym |
8 | 15 | import numpy as np |
9 | 16 | from gymnasium.core import ActType, ObsType |
10 | | -from gymnasium.vector import SyncVectorEnv |
11 | | -from gymnasium.vector.utils import concatenate, iterate |
12 | | -from gymnasium.vector.vector_env import ArrayType, VectorEnv |
| 17 | +from gymnasium.error import NoAsyncCallError |
| 18 | +from gymnasium.spaces.utils import is_space_dtype_shape_equiv |
| 19 | +from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv |
| 20 | +from gymnasium.vector.async_vector_env import AsyncState |
| 21 | +from gymnasium.vector.utils import ( |
| 22 | + concatenate, |
| 23 | + create_empty_array, |
| 24 | + iterate, |
| 25 | + write_to_shared_memory, |
| 26 | +) |
| 27 | +from gymnasium.vector.vector_env import ArrayType, AutoresetMode, VectorEnv |
13 | 28 | from gymnasium.wrappers.vector import RecordEpisodeStatistics |
14 | 29 |
|
15 | 30 |
|
@@ -60,7 +75,7 @@ def __init__( |
60 | 75 | dtype=np.float32, |
61 | 76 | ) |
62 | 77 |
|
63 | | - def step(self, actions: ActType) -> Tuple[ObsType, ArrayType, ArrayType, ArrayType, Dict[str, Any]]: |
| 78 | + def step(self, actions: ActType) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]: |
64 | 79 | """Steps through each of the environments returning the batched results. |
65 | 80 |
|
66 | 81 | Returns: |
@@ -103,6 +118,244 @@ def step(self, actions: ActType) -> Tuple[ObsType, ArrayType, ArrayType, ArrayTy |
103 | 118 | ) |
104 | 119 |
|
105 | 120 |
|
def _mo_async_worker(
    index: int,
    env_fn: Callable[[], gym.Env],
    pipe: Connection,
    parent_pipe: Connection,
    shared_memory: Array | dict[str, Any] | tuple[Any, ...],
    error_queue: Queue,
    autoreset_mode: AutoresetMode,
):
    """Worker process servicing one sub-environment of :class:`MOAsyncVectorEnv`.

    Modified from ``gymnasium.vector.async_vector_env`` default worker so that
    the per-step *vector* (multi-objective) reward is forwarded back unchanged,
    and a zero reward of the correct length is produced on autoreset.

    Args:
        index: Index of this sub-environment within the vector env.
        env_fn: Zero-argument constructor for the sub-environment.
        pipe: Worker end of the pipe used to exchange commands/results.
        parent_pipe: Parent end of the pipe; closed immediately in the worker.
        shared_memory: Shared-memory buffer(s) observations are written into,
            or a falsy value when observations are sent through the pipe.
        error_queue: Queue used to report exceptions to the main process.
        autoreset_mode: Strategy for resetting a sub-env once an episode ends.
    """
    env = env_fn()
    observation_space = env.observation_space
    action_space = env.action_space
    # reward_space carries the vector-reward length (one entry per objective).
    reward_space = env.unwrapped.reward_space
    autoreset = False
    observation = None

    parent_pipe.close()

    try:
        while True:
            command, data = pipe.recv()

            if command == "reset":
                observation, info = env.reset(**data)
                if shared_memory:
                    write_to_shared_memory(observation_space, index, observation, shared_memory)
                    observation = None
                autoreset = False
                pipe.send(((observation, info), True))
            elif command == "reset-noop":
                # Partial-reset protocol: this sub-env is not reset, reply with current obs.
                pipe.send(((observation, {}), True))
            elif command == "step":
                if autoreset_mode == AutoresetMode.NEXT_STEP:
                    if autoreset:
                        observation, info = env.reset()
                        # Zero *vector* reward — sized to the number of objectives.
                        reward, terminated, truncated = (
                            np.zeros(reward_space.shape[0], dtype=np.float32),
                            False,
                            False,
                        )
                    else:
                        (
                            observation,
                            reward,
                            terminated,
                            truncated,
                            info,
                        ) = env.step(data)
                    autoreset = terminated or truncated
                elif autoreset_mode == AutoresetMode.SAME_STEP:
                    (
                        observation,
                        reward,
                        terminated,
                        truncated,
                        info,
                    ) = env.step(data)

                    if terminated or truncated:
                        reset_observation, reset_info = env.reset()

                        # Stash the terminal obs/info so callers can still see them.
                        info = {
                            "final_info": info,
                            "final_obs": observation,
                            **reset_info,
                        }
                        observation = reset_observation
                elif autoreset_mode == AutoresetMode.DISABLED:
                    assert autoreset is False
                    (
                        observation,
                        reward,
                        terminated,
                        truncated,
                        info,
                    ) = env.step(data)
                else:
                    raise ValueError(f"Unexpected autoreset_mode: {autoreset_mode}")

                if shared_memory:
                    write_to_shared_memory(observation_space, index, observation, shared_memory)
                    observation = None

                pipe.send(((observation, reward, terminated, truncated, info), True))
            elif command == "close":
                pipe.send((None, True))
                break
            elif command == "_call":
                name, args, kwargs = data
                if name in ["reset", "step", "close", "_setattr", "_check_spaces"]:
                    raise ValueError(f"Trying to call function `{name}` with `call`, use `{name}` directly instead.")

                attr = env.get_wrapper_attr(name)
                if callable(attr):
                    pipe.send((attr(*args, **kwargs), True))
                else:
                    pipe.send((attr, True))
            elif command == "_setattr":
                name, value = data
                env.set_wrapper_attr(name, value)
                pipe.send((None, True))
            elif command == "_check_spaces":
                obs_mode, single_obs_space, single_action_space = data

                pipe.send(
                    (
                        (
                            (
                                single_obs_space == observation_space
                                if obs_mode == "same"
                                else is_space_dtype_shape_equiv(single_obs_space, observation_space)
                            ),
                            single_action_space == action_space,
                        ),
                        True,
                    )
                )
            else:
                raise RuntimeError(
                    f"Received unknown command `{command}`. Must be one of [`reset`, `step`, `close`, `_call`, `_setattr`, `_check_spaces`]."
                )
    except (KeyboardInterrupt, Exception):
        # KeyboardInterrupt is a BaseException, so it must be listed explicitly.
        error_type, error_message, _ = sys.exc_info()
        trace = traceback.format_exc()

        error_queue.put((index, error_type, error_message, trace))
        pipe.send((None, False))
    finally:
        env.close()
| 251 | + |
class MOAsyncVectorEnv(AsyncVectorEnv):
    """Vectorized environment that runs multiple environments in parallel.

    It uses ``multiprocessing`` processes, and pipes for communication.

    Modified from gymnasium.vector.async_vector_env.AsyncVectorEnv to allow for multi-objective rewards.

    Example:
        >>> import mo_gymnasium as mo_gym
        >>> envs = mo_gym.wrappers.vector.MOAsyncVectorEnv([
        ...     lambda: mo_gym.make("deep-sea-treasure-v0") for _ in range(4)
        ... ])
        >>> envs
        MOAsyncVectorEnv(num_envs=4)
        >>> obs, infos = envs.reset()
        >>> obs
        array([[0, 0], [0, 0], [0, 0], [0, 0]], dtype=int32)
        >>> _ = envs.action_space.seed(42)
        >>> actions = envs.action_space.sample()
        >>> obs, rewards, terminateds, truncateds, infos = envs.step([0, 1, 2, 3])
        >>> obs
        array([[0, 0], [1, 0], [0, 0], [0, 3]], dtype=int32)
        >>> rewards
        array([[0., -1.], [0.7, -1.], [0., -1.], [0., -1.]], dtype=float32)
        >>> terminateds
        array([False, True, False, False])
    """

    def __init__(self, env_fns: Sequence[Callable[[], gym.Env]], **kwargs):
        """Vectorized environment that runs multiple environments in parallel.

        Args:
            env_fns: env constructors
            **kwargs: forwarded to :class:`gymnasium.vector.AsyncVectorEnv`
        """
        super().__init__(env_fns=env_fns, worker=_mo_async_worker, **kwargs)

        # Extract the reward space from a throwaway env. Close it in `finally`
        # so the env (and any resources it opened) is not leaked if reading
        # `reward_space` raises.
        dummy_env = env_fns[0]()
        try:
            self.reward_space = dummy_env.unwrapped.reward_space
        finally:
            dummy_env.close()
        # 2D buffer (num_envs, reward_dim) that batched vector rewards are written into.
        self.rewards = create_empty_array(self.reward_space, n=self.num_envs, fn=np.zeros)

    def step_wait(self, timeout: int | float | None = None) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]:
        """Wait for the calls to :obj:`step` in each sub-environment to finish.

        Args:
            timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.

        Returns:
            The batched environment step information, (obs, reward, terminated, truncated, info)

        Raises:
            ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
            NoAsyncCallError: If :meth:`step_wait` was called without any prior call to :meth:`step_async`.
            TimeoutError: If :meth:`step_wait` timed out.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_STEP:
            raise NoAsyncCallError(
                "Calling `step_wait` without any prior call to `step_async`.",
                AsyncState.WAITING_STEP.value,
            )

        if not self._poll_pipe_envs(timeout):
            self._state = AsyncState.DEFAULT
            raise multiprocessing.TimeoutError(f"The call to `step_wait` has timed out after {timeout} second(s).")

        observations, rewards, terminations, truncations, infos = [], [], [], [], {}
        successes = []
        # Receive each sub-env's result in pipe order; failed workers reply
        # (None, False) and are reported collectively by `_raise_if_errors`.
        for env_idx, pipe in enumerate(self.parent_pipes):
            env_step_return, success = pipe.recv()

            successes.append(success)
            if success:
                observations.append(env_step_return[0])
                rewards.append(env_step_return[1])
                terminations.append(env_step_return[2])
                truncations.append(env_step_return[3])
                infos = self._add_info(infos, env_step_return[4], env_idx)

        self._raise_if_errors(successes)

        # With shared memory, workers already wrote observations in place.
        if not self.shared_memory:
            self.observations = concatenate(
                self.single_observation_space,
                observations,
                self.observations,
            )

        # Modified from the base class: batch the per-env *vector* rewards
        # into the (num_envs, reward_dim) buffer instead of a 1D float array.
        self.rewards = concatenate(
            self.reward_space,
            rewards,
            self.rewards,
        )

        self._state = AsyncState.DEFAULT
        return (
            deepcopy(self.observations) if self.copy else self.observations,
            deepcopy(self.rewards) if self.copy else self.rewards,
            np.array(terminations, dtype=np.bool_),
            np.array(truncations, dtype=np.bool_),
            infos,
        )
| 357 | + |
| 358 | + |
106 | 359 | class MORecordEpisodeStatistics(RecordEpisodeStatistics): |
107 | 360 | """This wrapper will keep track of cumulative rewards and episode lengths. |
108 | 361 |
|
@@ -162,7 +415,7 @@ def reset(self, **kwargs): |
162 | 415 |
|
163 | 416 | return obs, info |
164 | 417 |
|
165 | | - def step(self, actions: ActType) -> Tuple[ObsType, ArrayType, ArrayType, ArrayType, Dict[str, Any]]: |
| 418 | + def step(self, actions: ActType) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]: |
166 | 419 | """Steps through the environment, recording the episode statistics.""" |
167 | 420 | ( |
168 | 421 | observations, |
|
0 commit comments