
Commit f674037

Support Gym 0.26.0 (#205)
Co-authored-by: Jiayi Weng <[email protected]>
1 parent 3aa3657 commit f674037

21 files changed (+237, -127 lines)

.bazelrc

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 build --action_env=BAZEL_LINKLIBS=-l%:libstdc++.a:-lm
 build --action_env=BAZEL_LINKOPTS=-static-libgcc
 build --action_env=CUDA_DIR=/usr/local/cuda
+build --action_env=LD_LIBRARY_PATH=/home/ubuntu/.mujoco/mujoco210/bin
 build --incompatible_strict_action_env --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 --client_env=BAZEL_CXXOPTS=-std=c++17
 build:debug --cxxopt=-DENVPOOL_TEST --compilation_mode=dbg -s
 build:test --cxxopt=-DENVPOOL_TEST --copt=-g0 --copt=-O3 --copt=-DNDEBUG --copt=-msse --copt=-msse2 --copt=-mmmx

benchmark/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@ mujoco_py==2.1.2.14
 tqdm
 opencv-python-headless
 dm_control==1.0.3.post1
+packaging

docker/release.dockerfile

Lines changed: 2 additions & 2 deletions
@@ -36,7 +36,7 @@ RUN go install github.com/bazelbuild/bazelisk@latest && ln -sf $HOME/go/bin/baze
 
 # install big wheels
 
-RUN for i in 7 8 9; do ln -sf /usr/bin/python3.$i /usr/bin/python3; pip3 install torch opencv-python-headless; done
+RUN for i in 7 8 9 10; do ln -sf /usr/bin/python3.$i /usr/bin/python3; pip3 install torch opencv-python-headless; done
 
 RUN bazel version
 
@@ -45,4 +45,4 @@ COPY . .
 
 # compile and test release wheels
 
-RUN for i in 7 8 9; do ln -sf /usr/bin/python3.$i /usr/bin/python3; make pypi-wheel BAZELOPT="--remote_cache=http://bazel-cache.sail:8080"; pip3 install wheelhouse/*cp3$i*.whl; rm dist/*.whl; make release-test; done
+RUN for i in 7 8 9 10; do ln -sf /usr/bin/python3.$i /usr/bin/python3; make pypi-wheel BAZELOPT="--remote_cache=http://bazel-cache.sail:8080"; pip3 install wheelhouse/*cp3$i*.whl; rm dist/*.whl; make release-test; done

docs/content/python_interface.rst

Lines changed: 13 additions & 11 deletions
@@ -31,7 +31,8 @@ batched environments:
   ``gym.Env``, while some environments may not have such an option;
 * ``gym_reset_return_info (bool)``: whether to return a tuple of
   ``(obs, info)`` instead of only ``obs`` when calling reset in ``gym.Env``,
-  default to ``False``; this option is to adapt the newest version of gym's
+  defaults to ``False`` if you are using Gym<0.26.0, otherwise it defaults
+  to ``True``; this option is to adapt the newest version of gym's
   interface;
 * other configurations such as ``img_height`` / ``img_width`` / ``stack_num``
   / ``frame_skip`` / ``noop_max`` in Atari env, ``reward_metric`` /
@@ -115,16 +116,17 @@ third case, use ``env.step(action)`` where action is a dictionary.
 Data Output Format
 ------------------
 
-+----------+------------------------------------------------------------------+------------------------------------------------------------------+
-| function | gym                                                              | dm                                                               |
-|          |                                                                  |                                                                  |
-+==========+==================================================================+==================================================================+
-| reset    | | env_id -> obs array (single observation)                       | env_id -> TimeStep(FIRST, obs|info|env_id, rew=0, discount or 1) |
-|          | | or an obs dict (multi observation)                             |                                                                  |
-|          | | or (obs, info) tuple (when ``gym_reset_return_info`` == True)  |                                                                  |
-+----------+------------------------------------------------------------------+------------------------------------------------------------------+
-| step     | (obs, rew, done, info|env_id)                                    | TimeStep(StepType, obs|info|env_id, rew, discount or 1 - done)   |
-+----------+------------------------------------------------------------------+------------------------------------------------------------------+
++----------+----------------------------------------------------------------------+------------------------------------------------------------------+
+| function | gym                                                                  | dm                                                               |
+|          |                                                                      |                                                                  |
++==========+======================================================================+==================================================================+
+| reset    | | env_id -> obs array (single observation)                           | env_id -> TimeStep(FIRST, obs|info|env_id, rew=0, discount or 1) |
+|          | | or an obs dict (multi observation)                                 |                                                                  |
+|          | | or (obs, info) tuple (when ``gym_reset_return_info`` == True)      |                                                                  |
++----------+----------------------------------------------------------------------+------------------------------------------------------------------+
+| step     | (obs, rew, done, info|env_id) or                                     | TimeStep(StepType, obs|info|env_id, rew, discount or 1 - done)   |
+|          | (obs, rew, terminated, truncated, info|env_id) (when Gym >= 0.26.0)  |                                                                  |
++----------+----------------------------------------------------------------------+------------------------------------------------------------------+
 
 Note: ``gym.reset()`` doesn't support async step setting because it cannot get
 ``env_id`` from ``reset()`` function, so it's better to use low-level APIs such
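
For reference, a minimal usage sketch of the convention documented above, assuming Gym >= 0.26.0 is installed and using the ``Pong-v5`` task id (both illustrative choices, not guaranteed by this diff):

import numpy as np
import envpool

env = envpool.make("Pong-v5", env_type="gym", num_envs=4)
# with Gym >= 0.26.0, gym_reset_return_info defaults to True,
# so reset() returns an (obs, info) tuple
obs, info = env.reset()
for _ in range(10):
  action = np.random.randint(6, size=4)  # Pong has 6 discrete actions
  obs, rew, terminated, truncated, info = env.step(action)
  done = np.logical_or(terminated, truncated)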

envpool/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@
   register,
 )
 
-__version__ = "0.6.4"
+__version__ = "0.6.5"
 __all__ = [
   "register",
   "make",

envpool/atari/api_test.py

Lines changed: 29 additions & 14 deletions
@@ -20,6 +20,7 @@
 import numpy as np
 from absl import logging
 from absl.testing import absltest
+from packaging import version
 
 from envpool.atari import AtariDMEnvPool, AtariEnvSpec, AtariGymEnvPool
 
@@ -250,63 +251,77 @@ def test_lowlevel_step(self) -> None:
     self.assertTrue(isinstance(env, gym.Env))
     logging.info(env)
     env.async_reset()
-    obs, rew, done, info = env.recv()
+    obs, rew, terminated, truncated, info = env.recv()
+    done = np.logical_or(terminated, truncated)
     # check shape
     self.assertIsInstance(obs, np.ndarray)
     self.assertEqual(obs.dtype, np.uint8)
     np.testing.assert_allclose(rew.shape, (num_envs,))
     self.assertEqual(rew.dtype, np.float32)
     np.testing.assert_allclose(done.shape, (num_envs,))
     self.assertEqual(done.dtype, np.bool_)
+    self.assertEqual(terminated.dtype, np.bool_)
+    self.assertEqual(truncated.dtype, np.bool_)
     self.assertIsInstance(info, dict)
-    self.assertEqual(len(info), 7)
+    self.assertEqual(len(info), 6)
     self.assertEqual(info["env_id"].dtype, np.int32)
     self.assertEqual(info["lives"].dtype, np.int32)
     self.assertEqual(info["players"]["env_id"].dtype, np.int32)
-    self.assertEqual(info["TimeLimit.truncated"].dtype, np.bool_)
     np.testing.assert_allclose(info["env_id"], np.arange(num_envs))
     np.testing.assert_allclose(info["lives"].shape, (num_envs,))
     np.testing.assert_allclose(info["players"]["env_id"].shape, (num_envs,))
-    np.testing.assert_allclose(info["TimeLimit.truncated"].shape, (num_envs,))
+    np.testing.assert_allclose(truncated.shape, (num_envs,))
     while not np.any(done):
       env.send(np.random.randint(6, size=num_envs))
-      obs, rew, done, info = env.recv()
+      obs, rew, terminated, truncated, info = env.recv()
+      done = np.logical_or(terminated, truncated)
     env.send(np.random.randint(6, size=num_envs))
-    obs1, rew1, done1, info1 = env.recv()
+    obs1, rew1, terminated1, truncated1, info1 = env.recv()
+    done1 = np.logical_or(terminated1, truncated1)
     index = np.where(done)[0]
     self.assertTrue(np.all(~done1[index]))
 
   def test_highlevel_step(self) -> None:
+    assert version.parse(gym.__version__) >= version.parse("0.26.0")
     num_envs = 4
     config = AtariEnvSpec.gen_config(task="pong", num_envs=num_envs)
     spec = AtariEnvSpec(config)
     env = AtariGymEnvPool(spec)
     self.assertTrue(isinstance(env, gym.Env))
     logging.info(env)
-    obs = env.reset()
+    obs, _ = env.reset()
     # check shape
     self.assertIsInstance(obs, np.ndarray)
-    self.assertEqual(obs.dtype, np.uint8)  # type: ignore
-    obs, rew, done, info = env.step(np.random.randint(6, size=num_envs))
+    self.assertEqual(obs.dtype, np.uint8)
+    obs, rew, terminated, truncated, info = env.step(
+      np.random.randint(6, size=num_envs)
+    )
+    done = np.logical_or(terminated, truncated)
     self.assertIsInstance(obs, np.ndarray)
     self.assertEqual(obs.dtype, np.uint8)
     np.testing.assert_allclose(rew.shape, (num_envs,))
     self.assertEqual(rew.dtype, np.float32)
     np.testing.assert_allclose(done.shape, (num_envs,))
     self.assertEqual(done.dtype, np.bool_)
     self.assertIsInstance(info, dict)
-    self.assertEqual(len(info), 7)
+    self.assertEqual(len(info), 6)
     self.assertEqual(info["env_id"].dtype, np.int32)
     self.assertEqual(info["lives"].dtype, np.int32)
     self.assertEqual(info["players"]["env_id"].dtype, np.int32)
-    self.assertEqual(info["TimeLimit.truncated"].dtype, np.bool_)
+    self.assertEqual(truncated.dtype, np.bool_)
     np.testing.assert_allclose(info["env_id"], np.arange(num_envs))
     np.testing.assert_allclose(info["lives"].shape, (num_envs,))
     np.testing.assert_allclose(info["players"]["env_id"].shape, (num_envs,))
-    np.testing.assert_allclose(info["TimeLimit.truncated"].shape, (num_envs,))
+    np.testing.assert_allclose(truncated.shape, (num_envs,))
     while not np.any(done):
-      obs, rew, done, info = env.step(np.random.randint(6, size=num_envs))
-    obs1, rew1, done1, info1 = env.step(np.random.randint(6, size=num_envs))
+      obs, rew, terminated, truncated, info = env.step(
+        np.random.randint(6, size=num_envs)
+      )
+      done = np.logical_or(terminated, truncated)
+    obs1, rew1, terminated1, truncated1, info1 = env.step(
+      np.random.randint(6, size=num_envs)
+    )
+    done1 = np.logical_or(terminated1, truncated1)
     index = np.where(done)[0]
     self.assertTrue(np.all(~done1[index]))
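
The new ``packaging`` import (also added to benchmark/requirements.txt above) lets the test gate on the installed gym version. A hedged sketch of the same check for downstream code that must support both step conventions; ``step_compat`` is a hypothetical helper, and the ``info["TimeLimit.truncated"]`` fallback for the old API is an assumption based on the assertions removed above:

import gym
import numpy as np
from packaging import version

def step_compat(env, action):
  # normalize step() output to the 5-tuple convention
  if version.parse(gym.__version__) >= version.parse("0.26.0"):
    return env.step(action)
  obs, rew, done, info = env.step(action)
  truncated = np.asarray(info.get("TimeLimit.truncated", np.zeros_like(done)))
  terminated = np.logical_and(done, ~truncated)
  return obs, rew, terminated, truncated, info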

envpool/atari/atari_envpool_test.py

Lines changed: 22 additions & 16 deletions
@@ -68,7 +68,7 @@ def test_align(self) -> None:
     spec = AtariEnvSpec(config)
     env0 = AtariGymEnvPool(spec)
     env1 = AtariDMEnvPool(spec)
-    obs0 = env0.reset()
+    obs0, _ = env0.reset()
     obs1 = env1.reset().observation.obs  # type: ignore
     np.testing.assert_allclose(obs0, obs1)
     for _ in range(1000):
@@ -96,7 +96,10 @@ def test_reset_life(self) -> None:
         # no life in this game
         continue
       for _ in range(10000):
-        _, _, done, info = env.step(np.random.randint(0, action_num, 1))
+        _, _, terminated, truncated, info = env.step(
+          np.random.randint(0, action_num, 1)
+        )
+        done = np.logical_or(terminated, truncated)
         if info["lives"][0] == 0:
           break
       else:
@@ -106,7 +109,7 @@ def test_reset_life(self) -> None:
         continue
       # for normal atari (e.g., breakout)
       # take an additional step after all lives are exhausted
-      _, _, next_done, next_info = env.step(
+      _, _, next_terminated, next_truncated, next_info = env.step(
        np.random.randint(0, action_num, 1)
       )
       if done[0] and next_info["lives"][0] > 0:
@@ -116,8 +119,11 @@ def test_reset_life(self) -> None:
       self.assertFalse(info["terminated"][0])
       while not done[0]:
         self.assertFalse(info["terminated"][0])
-        _, _, done, info = env.step(np.random.randint(0, action_num, 1))
-      _, _, next_done, next_info = env.step(
+        _, _, terminated, truncated, info = env.step(
+          np.random.randint(0, action_num, 1)
+        )
+        done = np.logical_or(terminated, truncated)
+      _, _, next_terminated, next_truncated, next_info = env.step(
         np.random.randint(0, action_num, 1)
       )
       self.assertTrue(next_info["lives"][0] > 0)
@@ -137,21 +143,21 @@ def test_partial_step(self) -> None:
     partial_ids = [np.arange(num_envs)[::2], np.arange(num_envs)[1::2]]
     env.step(np.zeros(len(partial_ids[1]), dtype=int), env_id=partial_ids[1])
     for _ in range(max_episode_steps - 2):
-      info = env.step(
+      _, _, _, truncated, info = env.step(
         np.zeros(num_envs, dtype=int), env_id=np.arange(num_envs)
-      )[-1]
-      assert np.all(~info["TimeLimit.truncated"])
-    info = env.step(
+      )
+      assert np.all(~truncated)
+    _, _, _, truncated, info = env.step(
       np.zeros(num_envs, dtype=int), env_id=np.arange(num_envs)
-    )[-1]
+    )
     env_id = np.array(info["env_id"])
-    done_id = np.array(sorted(env_id[info["TimeLimit.truncated"]]))
+    done_id = np.array(sorted(env_id[truncated]))
     assert np.all(done_id == partial_ids[1])
-    info = env.step(
+    _, _, _, truncated, info = env.step(
       np.zeros(len(partial_ids[0]), dtype=int),
       env_id=partial_ids[0],
-    )[-1]
-    assert np.all(info["TimeLimit.truncated"])
+    )
+    assert np.all(truncated)
 
   def test_xla_api(self) -> None:
     num_envs = 10
@@ -216,15 +222,15 @@ def test_no_gray_scale(self) -> None:
     spec = AtariEnvSpec(config)
     env = AtariGymEnvPool(spec)
     self.assertTrue(env.observation_space.shape, ref_shape)
-    obs = env.reset()
+    obs, _ = env.reset()
     self.assertTrue(obs.shape, ref_shape)
     config = AtariEnvSpec.gen_config(
       task="breakout", gray_scale=False, img_height=210, img_width=160
     )
     spec = AtariEnvSpec(config)
     env = AtariGymEnvPool(spec)
     self.assertTrue(env.observation_space.shape, raw_shape)
-    obs1 = env.reset()
+    obs1, _ = env.reset()
     self.assertTrue(obs1.shape, raw_shape)
     for i in range(0, 12, 3):
       obs_ = cv2.resize(
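
As ``test_partial_step`` above shows, time-limit cutoffs now arrive through the ``truncated`` return value instead of ``info["TimeLimit.truncated"]``. A minimal sketch of the resulting bookkeeping, assuming the ``Breakout-v5`` task id (an illustrative choice, as before):

import numpy as np
import envpool

num_envs = 4
env = envpool.make("Breakout-v5", env_type="gym", num_envs=num_envs)
returns = np.zeros(num_envs)
obs, info = env.reset()
for _ in range(1000):
  obs, rew, terminated, truncated, info = env.step(
    np.zeros(num_envs, dtype=int)
  )
  env_id = info["env_id"]
  returns[env_id] += rew
  # an episode ends on either flag, but only `terminated` marks a true
  # environment termination (e.g., all lives lost)
  done = np.logical_or(terminated, truncated)
  returns[env_id[done]] = 0.0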

envpool/atari/atari_pretrain_test.py

Lines changed: 3 additions & 2 deletions
@@ -59,13 +59,14 @@ def eval_qrdqn(
     policy.eval()
     ids = np.arange(num_envs)
     reward = np.zeros(num_envs)
-    obs = env.reset()
+    obs, _ = env.reset()
     for _ in range(25000):
       if np.random.rand() < 5e-3:
         act = np.random.randint(action_shape, size=len(ids))
       else:
         act = policy(Batch(obs=obs, info={})).act
-      obs, rew, done, info = env.step(act, ids)
+      obs, rew, terminated, truncated, info = env.step(act, ids)
+      done = np.logical_or(terminated, truncated)
       ids = np.asarray(info["env_id"])
       reward[ids] += rew
       obs = obs[~done]

envpool/box2d/box2d_correctness_test.py

Lines changed: 6 additions & 4 deletions
@@ -109,13 +109,14 @@ def solve_lunar_lander(self, num_envs: int, continuous: bool) -> None:
     for _ in range(2):
       env_id = np.arange(num_envs)
       done = np.array([False] * num_envs)
-      obs = env.reset(env_id)
+      obs, _ = env.reset(env_id)
       rewards = np.zeros(num_envs)
       while not np.all(done):
         action = np.array(
           [self.heuristic_lunar_lander_policy(s, continuous) for s in obs]
         )
-        obs, rew, done, info = env.step(action, env_id)
+        obs, rew, terminated, truncated, info = env.step(action, env_id)
+        done = np.logical_or(terminated, truncated)
         env_id = info["env_id"]
         rewards[env_id] += rew
         obs = obs[~done]
@@ -228,11 +229,12 @@ def solve_bipedal_walker(
     )
     env_id = np.arange(num_envs)
     done = np.array([False] * num_envs)
-    obs = env.reset(env_id)
+    obs, _ = env.reset(env_id)
     rewards = np.zeros(num_envs)
     action = np.zeros([num_envs, 4])
     for _ in range(max_episode_steps):
-      obs, rew, done, info = env.step(action, env_id)
+      obs, rew, terminated, truncated, info = env.step(action, env_id)
+      done = np.logical_or(terminated, truncated)
       if render:
         self.render_bpw(info)
       env_id = info["env_id"]

envpool/classic_control/classic_control_test.py

Lines changed: 7 additions & 3 deletions
@@ -82,11 +82,15 @@ def run_align_check(self, env0: gym.Env, env1: Any, reset_fn: Any) -> None:
     d0 = False
     while not d0:
       a = env0.action_space.sample()
-      o0, r0, d0, _ = env0.step(a)
-      o1, r1, d1, _ = env1.step(np.array([a]), np.array([0]))
+      o0, r0, term0, trunc0, _ = env0.step(a)
+      d0 = np.logical_or(term0, trunc0)
+      o1, r1, term1, trunc1, _ = env1.step(np.array([a]), np.array([0]))
+      d1 = np.logical_or(term1, trunc1)
       np.testing.assert_allclose(o0, o1[0], atol=1e-4)
       np.testing.assert_allclose(r0, r1[0])
       np.testing.assert_allclose(d0, d1[0])
+      np.testing.assert_allclose(term0, term1[0])
+      np.testing.assert_allclose(trunc0, trunc1[0])
 
   def test_cartpole(self) -> None:
     env0 = gym.make("CartPole-v1")
@@ -109,7 +113,7 @@ def test_mountain_car(self) -> None:
     @no_type_check
     def reset_fn(env0: gym.Env, env1: Any) -> None:
       env0.reset()
-      obs = env1.reset()
+      obs, _ = env1.reset()
       env0.unwrapped.state = obs[0]
 
     env0 = gym.make("MountainCar-v0")
