11"""Vector wrappers."""
22
3- import os
3+ import multiprocessing
44import sys
55import time
66import traceback
77from copy import deepcopy
8- from typing import Any , Callable , Dict , Iterator , Tuple , Sequence , Union
9-
10- import multiprocessing
118from multiprocessing import Array , Queue
129from multiprocessing .connection import Connection
13- import numpy as np
10+ from typing import Any , Callable , Dict , Iterator , Sequence , Tuple , Union
1411
1512import gymnasium as gym
1613import numpy as np
1714from gymnasium .core import ActType , ObsType
18- from gymnasium .vector import SyncVectorEnv , AsyncVectorEnv
19- from gymnasium .vector .async_vector_env import AsyncState
20- from gymnasium .vector .utils import concatenate , iterate , create_empty_array , write_to_shared_memory
21- from gymnasium .vector .vector_env import ArrayType , VectorEnv , AutoresetMode
22- from gymnasium .wrappers .vector import RecordEpisodeStatistics
2315from gymnasium .error import NoAsyncCallError
2416from gymnasium .spaces .utils import is_space_dtype_shape_equiv
17+ from gymnasium .vector import AsyncVectorEnv , SyncVectorEnv
18+ from gymnasium .vector .async_vector_env import AsyncState
19+ from gymnasium .vector .utils import (
20+ concatenate ,
21+ create_empty_array ,
22+ iterate ,
23+ write_to_shared_memory ,
24+ )
25+ from gymnasium .vector .vector_env import ArrayType , AutoresetMode , VectorEnv
26+ from gymnasium .wrappers .vector import RecordEpisodeStatistics
2527
2628
2729class MOSyncVectorEnv (SyncVectorEnv ):
@@ -112,7 +114,8 @@ def step(self, actions: ActType) -> Tuple[ObsType, ArrayType, ArrayType, ArrayTy
112114 np .copy (self ._truncations ),
113115 infos ,
114116 )
115-
117+
118+
116119def _mo_async_worker (
117120 index : int ,
118121 env_fn : callable ,
@@ -138,9 +141,7 @@ def _mo_async_worker(
138141 if command == "reset" :
139142 observation , info = env .reset (** data )
140143 if shared_memory :
141- write_to_shared_memory (
142- observation_space , index , observation , shared_memory
143- )
144+ write_to_shared_memory (observation_space , index , observation , shared_memory )
144145 observation = None
145146 autoreset = False
146147 pipe .send (((observation , info ), True ))
@@ -150,7 +151,11 @@ def _mo_async_worker(
150151 if autoreset_mode == AutoresetMode .NEXT_STEP :
151152 if autoreset :
152153 observation , info = env .reset ()
153- reward , terminated , truncated = np .zeros (reward_space .shape [0 ], dtype = np .float32 ), False , False
154+ reward , terminated , truncated = (
155+ np .zeros (reward_space .shape [0 ], dtype = np .float32 ),
156+ False ,
157+ False ,
158+ )
154159 else :
155160 (
156161 observation ,
@@ -191,9 +196,7 @@ def _mo_async_worker(
191196 raise ValueError (f"Unexpected autoreset_mode: { autoreset_mode } " )
192197
193198 if shared_memory :
194- write_to_shared_memory (
195- observation_space , index , observation , shared_memory
196- )
199+ write_to_shared_memory (observation_space , index , observation , shared_memory )
197200 observation = None
198201
199202 pipe .send (((observation , reward , terminated , truncated , info ), True ))
@@ -203,9 +206,7 @@ def _mo_async_worker(
203206 elif command == "_call" :
204207 name , args , kwargs = data
205208 if name in ["reset" , "step" , "close" , "_setattr" , "_check_spaces" ]:
206- raise ValueError (
207- f"Trying to call function `{ name } ` with `call`, use `{ name } ` directly instead."
208- )
209+ raise ValueError (f"Trying to call function `{ name } ` with `call`, use `{ name } ` directly instead." )
209210
210211 attr = env .get_wrapper_attr (name )
211212 if callable (attr ):
@@ -225,9 +226,7 @@ def _mo_async_worker(
225226 (
226227 single_obs_space == observation_space
227228 if obs_mode == "same"
228- else is_space_dtype_shape_equiv (
229- single_obs_space , observation_space
230- )
229+ else is_space_dtype_shape_equiv (single_obs_space , observation_space )
231230 ),
232231 single_action_space == action_space ,
233232 ),
@@ -246,14 +245,15 @@ def _mo_async_worker(
246245 pipe .send ((None , False ))
247246 finally :
248247 env .close ()
249-
248+
249+
250250class MOAsyncVectorEnv (AsyncVectorEnv ):
251251 """Vectorized environment that runs multiple environments in parallel.
252252
253253 It uses ``multiprocessing`` processes, and pipes for communication.
254254
255- Mofified from gymnasium.vector.async_vector_env.AsyncVectorEnv to allow for multi-objective rewards.
256-
255+ Modified from gymnasium.vector.async_vector_env.AsyncVectorEnv to allow for multi-objective rewards.
256+
257257 Example:
258258 >>> import mo_gymnasium as mo_gym
259259 >>> envs = mo_gym.wrappers.vector.MOAsyncVectorEnv([
@@ -274,11 +274,13 @@ class MOAsyncVectorEnv(AsyncVectorEnv):
274274 >>> terminateds
275275 array([False, True, False, False])
276276 """
277- def __init__ (
278- self ,
279- env_fns : Sequence [Callable [[], gym .Env ]],
280- ** kwargs
281- ):
277+
278+ def __init__ (self , env_fns : Sequence [Callable [[], gym .Env ]], ** kwargs ):
279+ """Vectorized environment that runs multiple environments in parallel.
280+
281+ Args:
282+ env_fns: env constructors
283+ """
282284 super ().__init__ (env_fns = env_fns , worker = _mo_async_worker , ** kwargs )
283285
284286 # extract reward space from first vector env and create 2d array to store vector rewards
@@ -288,10 +290,7 @@ def __init__(
288290 del dummy_env
289291 self .rewards = create_empty_array (self .reward_space , n = self .num_envs , fn = np .zeros )
290292
291-
292- def step_wait (
293- self , timeout : int | float | None = None
294- ) -> tuple [np .ndarray , np .ndarray , np .ndarray , np .ndarray , dict ]:
293+ def step_wait (self , timeout : int | float | None = None ) -> tuple [np .ndarray , np .ndarray , np .ndarray , np .ndarray , dict ]:
295294 """Wait for the calls to :obj:`step` in each sub-environment to finish.
296295
297296 Args:
@@ -308,15 +307,13 @@ def step_wait(
308307 self ._assert_is_running ()
309308 if self ._state != AsyncState .WAITING_STEP :
310309 raise NoAsyncCallError (
311- "Calling `step_wait` without any prior call " " to `step_async`." ,
310+ "Calling `step_wait` without any prior call to `step_async`." ,
312311 AsyncState .WAITING_STEP .value ,
313312 )
314313
315314 if not self ._poll_pipe_envs (timeout ):
316315 self ._state = AsyncState .DEFAULT
317- raise multiprocessing .TimeoutError (
318- f"The call to `step_wait` has timed out after { timeout } second(s)."
319- )
316+ raise multiprocessing .TimeoutError (f"The call to `step_wait` has timed out after { timeout } second(s)." )
320317
321318 observations , rewards , terminations , truncations , infos = [], [], [], [], {}
322319 successes = []
@@ -339,7 +336,7 @@ def step_wait(
339336 observations ,
340337 self .observations ,
341338 )
342-
339+
343340 # modify to allow return of vector rewards
344341 self .rewards = concatenate (
345342 self .reward_space ,
0 commit comments