Skip to content

Commit b9fc05b

Browse files
committed
Cleanup PR, bump minimal gymnasium version
1 parent 628666e commit b9fc05b

File tree

4 files changed

+108
-59
lines changed

4 files changed

+108
-59
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Vector wrappers."""
22

33
from mo_gymnasium.wrappers.vector.wrappers import (
4+
MOAsyncVectorEnv,
45
MORecordEpisodeStatistics,
56
MOSyncVectorEnv,
67
)

mo_gymnasium/wrappers/vector/wrappers.py

Lines changed: 38 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,29 @@
11
"""Vector wrappers."""
22

3-
import os
3+
import multiprocessing
44
import sys
55
import time
66
import traceback
77
from copy import deepcopy
8-
from typing import Any, Callable, Dict, Iterator, Tuple, Sequence, Union
9-
10-
import multiprocessing
118
from multiprocessing import Array, Queue
129
from multiprocessing.connection import Connection
13-
import numpy as np
10+
from typing import Any, Callable, Dict, Iterator, Sequence, Tuple, Union
1411

1512
import gymnasium as gym
1613
import numpy as np
1714
from gymnasium.core import ActType, ObsType
18-
from gymnasium.vector import SyncVectorEnv, AsyncVectorEnv
19-
from gymnasium.vector.async_vector_env import AsyncState
20-
from gymnasium.vector.utils import concatenate, iterate, create_empty_array, write_to_shared_memory
21-
from gymnasium.vector.vector_env import ArrayType, VectorEnv, AutoresetMode
22-
from gymnasium.wrappers.vector import RecordEpisodeStatistics
2315
from gymnasium.error import NoAsyncCallError
2416
from gymnasium.spaces.utils import is_space_dtype_shape_equiv
17+
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
18+
from gymnasium.vector.async_vector_env import AsyncState
19+
from gymnasium.vector.utils import (
20+
concatenate,
21+
create_empty_array,
22+
iterate,
23+
write_to_shared_memory,
24+
)
25+
from gymnasium.vector.vector_env import ArrayType, AutoresetMode, VectorEnv
26+
from gymnasium.wrappers.vector import RecordEpisodeStatistics
2527

2628

2729
class MOSyncVectorEnv(SyncVectorEnv):
@@ -112,7 +114,8 @@ def step(self, actions: ActType) -> Tuple[ObsType, ArrayType, ArrayType, ArrayTy
112114
np.copy(self._truncations),
113115
infos,
114116
)
115-
117+
118+
116119
def _mo_async_worker(
117120
index: int,
118121
env_fn: callable,
@@ -138,9 +141,7 @@ def _mo_async_worker(
138141
if command == "reset":
139142
observation, info = env.reset(**data)
140143
if shared_memory:
141-
write_to_shared_memory(
142-
observation_space, index, observation, shared_memory
143-
)
144+
write_to_shared_memory(observation_space, index, observation, shared_memory)
144145
observation = None
145146
autoreset = False
146147
pipe.send(((observation, info), True))
@@ -150,7 +151,11 @@ def _mo_async_worker(
150151
if autoreset_mode == AutoresetMode.NEXT_STEP:
151152
if autoreset:
152153
observation, info = env.reset()
153-
reward, terminated, truncated = np.zeros(reward_space.shape[0], dtype=np.float32), False, False
154+
reward, terminated, truncated = (
155+
np.zeros(reward_space.shape[0], dtype=np.float32),
156+
False,
157+
False,
158+
)
154159
else:
155160
(
156161
observation,
@@ -191,9 +196,7 @@ def _mo_async_worker(
191196
raise ValueError(f"Unexpected autoreset_mode: {autoreset_mode}")
192197

193198
if shared_memory:
194-
write_to_shared_memory(
195-
observation_space, index, observation, shared_memory
196-
)
199+
write_to_shared_memory(observation_space, index, observation, shared_memory)
197200
observation = None
198201

199202
pipe.send(((observation, reward, terminated, truncated, info), True))
@@ -203,9 +206,7 @@ def _mo_async_worker(
203206
elif command == "_call":
204207
name, args, kwargs = data
205208
if name in ["reset", "step", "close", "_setattr", "_check_spaces"]:
206-
raise ValueError(
207-
f"Trying to call function `{name}` with `call`, use `{name}` directly instead."
208-
)
209+
raise ValueError(f"Trying to call function `{name}` with `call`, use `{name}` directly instead.")
209210

210211
attr = env.get_wrapper_attr(name)
211212
if callable(attr):
@@ -225,9 +226,7 @@ def _mo_async_worker(
225226
(
226227
single_obs_space == observation_space
227228
if obs_mode == "same"
228-
else is_space_dtype_shape_equiv(
229-
single_obs_space, observation_space
230-
)
229+
else is_space_dtype_shape_equiv(single_obs_space, observation_space)
231230
),
232231
single_action_space == action_space,
233232
),
@@ -246,14 +245,15 @@ def _mo_async_worker(
246245
pipe.send((None, False))
247246
finally:
248247
env.close()
249-
248+
249+
250250
class MOAsyncVectorEnv(AsyncVectorEnv):
251251
"""Vectorized environment that runs multiple environments in parallel.
252252
253253
It uses ``multiprocessing`` processes, and pipes for communication.
254254
255-
Mofified from gymnasium.vector.async_vector_env.AsyncVectorEnv to allow for multi-objective rewards.
256-
255+
Modified from gymnasium.vector.async_vector_env.AsyncVectorEnv to allow for multi-objective rewards.
256+
257257
Example:
258258
>>> import mo_gymnasium as mo_gym
259259
>>> envs = mo_gym.wrappers.vector.MOAsyncVectorEnv([
@@ -274,11 +274,13 @@ class MOAsyncVectorEnv(AsyncVectorEnv):
274274
>>> terminateds
275275
array([False, True, False, False])
276276
"""
277-
def __init__(
278-
self,
279-
env_fns: Sequence[Callable[[], gym.Env]],
280-
**kwargs
281-
):
277+
278+
def __init__(self, env_fns: Sequence[Callable[[], gym.Env]], **kwargs):
279+
"""Vectorized environment that runs multiple environments in parallel.
280+
281+
Args:
282+
env_fns: env constructors
283+
"""
282284
super().__init__(env_fns=env_fns, worker=_mo_async_worker, **kwargs)
283285

284286
# extract reward space from first vector env and create 2d array to store vector rewards
@@ -288,10 +290,7 @@ def __init__(
288290
del dummy_env
289291
self.rewards = create_empty_array(self.reward_space, n=self.num_envs, fn=np.zeros)
290292

291-
292-
def step_wait(
293-
self, timeout: int | float | None = None
294-
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]:
293+
def step_wait(self, timeout: int | float | None = None) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]:
295294
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
296295
297296
Args:
@@ -308,15 +307,13 @@ def step_wait(
308307
self._assert_is_running()
309308
if self._state != AsyncState.WAITING_STEP:
310309
raise NoAsyncCallError(
311-
"Calling `step_wait` without any prior call " "to `step_async`.",
310+
"Calling `step_wait` without any prior call to `step_async`.",
312311
AsyncState.WAITING_STEP.value,
313312
)
314313

315314
if not self._poll_pipe_envs(timeout):
316315
self._state = AsyncState.DEFAULT
317-
raise multiprocessing.TimeoutError(
318-
f"The call to `step_wait` has timed out after {timeout} second(s)."
319-
)
316+
raise multiprocessing.TimeoutError(f"The call to `step_wait` has timed out after {timeout} second(s).")
320317

321318
observations, rewards, terminations, truncations, infos = [], [], [], [], {}
322319
successes = []
@@ -339,7 +336,7 @@ def step_wait(
339336
observations,
340337
self.observations,
341338
)
342-
339+
343340
# modify to allow return of vector rewards
344341
self.rewards = concatenate(
345342
self.reward_space,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ classifiers = [
2222
'Topic :: Scientific/Engineering :: Artificial Intelligence',
2323
]
2424
dependencies = [
25-
"gymnasium >=1.0.0",
25+
"gymnasium >=1.1.0",
2626
"numpy >=1.21.0,<2.0",
2727
"pygame >=2.1.3",
2828
"scipy >=1.7.3",

tests/test_vector_wrappers.py

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,39 @@
22
import numpy as np
33

44
import mo_gymnasium as mo_gym
5-
from mo_gymnasium.wrappers.vector import MORecordEpisodeStatistics, MOSyncVectorEnv
5+
from mo_gymnasium.wrappers.vector import (
6+
MOAsyncVectorEnv,
7+
MORecordEpisodeStatistics,
8+
MOSyncVectorEnv,
9+
)
610

711

8-
def test_mo_sync_wrapper():
9-
num_envs = 3
10-
envs = MOSyncVectorEnv([lambda: mo_gym.make("deep-sea-treasure-v0") for _ in range(num_envs)])
11-
12-
envs.reset()
12+
def _test_logic(envs, num_envs: int):
1313
obs, rewards, terminateds, truncateds, infos = envs.step(envs.action_space.sample())
1414
assert len(obs) == num_envs, "Number of observations do not match the number of envs"
1515
assert len(rewards) == num_envs, "Number of rewards do not match the number of envs"
1616
assert len(terminateds) == num_envs, "Number of terminateds do not match the number of envs"
1717
assert len(truncateds) == num_envs, "Number of truncateds do not match the number of envs"
18-
envs.close()
1918

2019

21-
def test_mo_sync_autoreset():
22-
num_envs = 2
20+
def test_mo_sync_wrapper():
21+
num_envs = 3
2322
envs = MOSyncVectorEnv([lambda: mo_gym.make("deep-sea-treasure-v0") for _ in range(num_envs)])
2423

24+
envs.reset()
25+
_test_logic(envs, num_envs)
26+
envs.close()
27+
28+
29+
def test_mo_async_wrapper():
30+
num_envs = 3
31+
envs = MOAsyncVectorEnv([lambda: mo_gym.make("deep-sea-treasure-v0") for _ in range(num_envs)])
32+
envs.reset()
33+
_test_logic(envs, num_envs)
34+
envs.close()
35+
36+
37+
def _test_autoreset_logic(envs):
2538
obs, infos = envs.reset()
2639
assert (obs[0] == [0, 0]).all()
2740
assert (obs[1] == [0, 0]).all()
@@ -42,14 +55,25 @@ def test_mo_sync_autoreset():
4255
assert (rewards[1] == [0.0, 0.0]).all() # Reset step
4356
assert not terminateds[0]
4457
assert not terminateds[1] # Not done anymore
45-
envs.close()
4658

4759

48-
def test_mo_record_ep_statistic_vector_env():
60+
def test_mo_sync_autoreset():
4961
num_envs = 2
5062
envs = MOSyncVectorEnv([lambda: mo_gym.make("deep-sea-treasure-v0") for _ in range(num_envs)])
51-
envs = MORecordEpisodeStatistics(envs, gamma=0.97)
5263

64+
_test_autoreset_logic(envs)
65+
envs.close()
66+
67+
68+
def test_mo_async_autoreset():
69+
num_envs = 2
70+
envs = MOAsyncVectorEnv([lambda: mo_gym.make("deep-sea-treasure-v0") for _ in range(num_envs)])
71+
72+
_test_autoreset_logic(envs)
73+
envs.close()
74+
75+
76+
def _test_record_ep_statistic_logic(envs, num_envs: int):
5377
envs.reset()
5478
terminateds = np.array([False] * num_envs)
5579
info = {}
@@ -69,21 +93,48 @@ def test_mo_record_ep_statistic_vector_env():
6993
assert isinstance(info["episode"]["l"], np.ndarray)
7094
np.testing.assert_almost_equal(info["episode"]["l"], np.array([0, 3], dtype=np.float32), decimal=2)
7195
assert isinstance(info["episode"]["t"], np.ndarray)
96+
97+
98+
def test_mo_record_ep_statistic_vector_env():
99+
num_envs = 2
100+
envs = MOSyncVectorEnv([lambda: mo_gym.make("deep-sea-treasure-v0") for _ in range(num_envs)])
101+
envs = MORecordEpisodeStatistics(envs, gamma=0.97)
102+
_test_record_ep_statistic_logic(envs, num_envs)
72103
envs.close()
73104

74105

75-
def test_gym_wrapper_and_vector():
76-
# This tests the integration of gym-wrapped envs with MO-Gymnasium vectorized envs
106+
def test_mo_record_ep_statistic_vector_env_async():
77107
num_envs = 2
78-
envs = MOSyncVectorEnv(
79-
[lambda: gym.wrappers.NormalizeObservation(mo_gym.make("deep-sea-treasure-v0")) for _ in range(num_envs)]
80-
)
108+
envs = MOAsyncVectorEnv([lambda: mo_gym.make("deep-sea-treasure-v0") for _ in range(num_envs)])
109+
envs = MORecordEpisodeStatistics(envs, gamma=0.97)
110+
_test_record_ep_statistic_logic(envs, num_envs)
111+
envs.close()
112+
81113

114+
def _test_gym_wrapper_and_vector_logic(envs, num_envs: int):
82115
envs.reset()
83116
for i in range(30):
84117
obs, rewards, terminateds, truncateds, infos = envs.step(envs.action_space.sample())
85118
assert len(obs) == num_envs, "Number of observations do not match the number of envs"
86119
assert len(rewards) == num_envs, "Number of rewards do not match the number of envs"
87120
assert len(terminateds) == num_envs, "Number of terminateds do not match the number of envs"
88121
assert len(truncateds) == num_envs, "Number of truncateds do not match the number of envs"
122+
123+
124+
def test_gym_wrapper_and_vector():
125+
# This tests the integration of gym-wrapped envs with MO-Gymnasium vectorized envs
126+
num_envs = 2
127+
envs = MOSyncVectorEnv(
128+
[lambda: gym.wrappers.NormalizeObservation(mo_gym.make("deep-sea-treasure-v0")) for _ in range(num_envs)]
129+
)
130+
_test_gym_wrapper_and_vector_logic(envs, num_envs)
131+
envs.close()
132+
133+
134+
def test_gym_wrapper_and_vector_async():
135+
num_envs = 2
136+
envs = MOAsyncVectorEnv(
137+
[lambda: gym.wrappers.NormalizeObservation(mo_gym.make("deep-sea-treasure-v0")) for _ in range(num_envs)]
138+
)
139+
_test_gym_wrapper_and_vector_logic(envs, num_envs)
89140
envs.close()

0 commit comments

Comments
 (0)