|
20 | 20 | import numpy as np |
21 | 21 | from absl import logging |
22 | 22 | from absl.testing import absltest |
| 23 | +from packaging import version |
23 | 24 |
|
24 | 25 | from envpool.atari import AtariDMEnvPool, AtariEnvSpec, AtariGymEnvPool |
25 | 26 |
|
@@ -250,63 +251,77 @@ def test_lowlevel_step(self) -> None: |
250 | 251 | self.assertTrue(isinstance(env, gym.Env)) |
251 | 252 | logging.info(env) |
252 | 253 | env.async_reset() |
253 | | - obs, rew, done, info = env.recv() |
| 254 | + obs, rew, terminated, truncated, info = env.recv() |
| 255 | + done = np.logical_or(terminated, truncated) |
254 | 256 | # check shape |
255 | 257 | self.assertIsInstance(obs, np.ndarray) |
256 | 258 | self.assertEqual(obs.dtype, np.uint8) |
257 | 259 | np.testing.assert_allclose(rew.shape, (num_envs,)) |
258 | 260 | self.assertEqual(rew.dtype, np.float32) |
259 | 261 | np.testing.assert_allclose(done.shape, (num_envs,)) |
260 | 262 | self.assertEqual(done.dtype, np.bool_) |
| 263 | + self.assertEqual(terminated.dtype, np.bool_) |
| 264 | + self.assertEqual(truncated.dtype, np.bool_) |
261 | 265 | self.assertIsInstance(info, dict) |
262 | | - self.assertEqual(len(info), 7) |
| 266 | + self.assertEqual(len(info), 6) |
263 | 267 | self.assertEqual(info["env_id"].dtype, np.int32) |
264 | 268 | self.assertEqual(info["lives"].dtype, np.int32) |
265 | 269 | self.assertEqual(info["players"]["env_id"].dtype, np.int32) |
266 | | - self.assertEqual(info["TimeLimit.truncated"].dtype, np.bool_) |
267 | 270 | np.testing.assert_allclose(info["env_id"], np.arange(num_envs)) |
268 | 271 | np.testing.assert_allclose(info["lives"].shape, (num_envs,)) |
269 | 272 | np.testing.assert_allclose(info["players"]["env_id"].shape, (num_envs,)) |
270 | | - np.testing.assert_allclose(info["TimeLimit.truncated"].shape, (num_envs,)) |
| 273 | + np.testing.assert_allclose(truncated.shape, (num_envs,)) |
271 | 274 | while not np.any(done): |
272 | 275 | env.send(np.random.randint(6, size=num_envs)) |
273 | | - obs, rew, done, info = env.recv() |
| 276 | + obs, rew, terminated, truncated, info = env.recv() |
| 277 | + done = np.logical_or(terminated, truncated) |
274 | 278 | env.send(np.random.randint(6, size=num_envs)) |
275 | | - obs1, rew1, done1, info1 = env.recv() |
| 279 | + obs1, rew1, terminated1, truncated1, info1 = env.recv() |
| 280 | + done1 = np.logical_or(terminated1, truncated1) |
276 | 281 | index = np.where(done)[0] |
277 | 282 | self.assertTrue(np.all(~done1[index])) |
278 | 283 |
|
279 | 284 | def test_highlevel_step(self) -> None: |
| 285 | + assert version.parse(gym.__version__) >= version.parse("0.26.0") |
280 | 286 | num_envs = 4 |
281 | 287 | config = AtariEnvSpec.gen_config(task="pong", num_envs=num_envs) |
282 | 288 | spec = AtariEnvSpec(config) |
283 | 289 | env = AtariGymEnvPool(spec) |
284 | 290 | self.assertTrue(isinstance(env, gym.Env)) |
285 | 291 | logging.info(env) |
286 | | - obs = env.reset() |
| 292 | + obs, _ = env.reset() |
287 | 293 | # check shape |
288 | 294 | self.assertIsInstance(obs, np.ndarray) |
289 | | - self.assertEqual(obs.dtype, np.uint8) # type: ignore |
290 | | - obs, rew, done, info = env.step(np.random.randint(6, size=num_envs)) |
| 295 | + self.assertEqual(obs.dtype, np.uint8) |
| 296 | + obs, rew, terminated, truncated, info = env.step( |
| 297 | + np.random.randint(6, size=num_envs) |
| 298 | + ) |
| 299 | + done = np.logical_or(terminated, truncated) |
291 | 300 | self.assertIsInstance(obs, np.ndarray) |
292 | 301 | self.assertEqual(obs.dtype, np.uint8) |
293 | 302 | np.testing.assert_allclose(rew.shape, (num_envs,)) |
294 | 303 | self.assertEqual(rew.dtype, np.float32) |
295 | 304 | np.testing.assert_allclose(done.shape, (num_envs,)) |
296 | 305 | self.assertEqual(done.dtype, np.bool_) |
297 | 306 | self.assertIsInstance(info, dict) |
298 | | - self.assertEqual(len(info), 7) |
| 307 | + self.assertEqual(len(info), 6) |
299 | 308 | self.assertEqual(info["env_id"].dtype, np.int32) |
300 | 309 | self.assertEqual(info["lives"].dtype, np.int32) |
301 | 310 | self.assertEqual(info["players"]["env_id"].dtype, np.int32) |
302 | | - self.assertEqual(info["TimeLimit.truncated"].dtype, np.bool_) |
| 311 | + self.assertEqual(truncated.dtype, np.bool_) |
303 | 312 | np.testing.assert_allclose(info["env_id"], np.arange(num_envs)) |
304 | 313 | np.testing.assert_allclose(info["lives"].shape, (num_envs,)) |
305 | 314 | np.testing.assert_allclose(info["players"]["env_id"].shape, (num_envs,)) |
306 | | - np.testing.assert_allclose(info["TimeLimit.truncated"].shape, (num_envs,)) |
| 315 | + np.testing.assert_allclose(truncated.shape, (num_envs,)) |
307 | 316 | while not np.any(done): |
308 | | - obs, rew, done, info = env.step(np.random.randint(6, size=num_envs)) |
309 | | - obs1, rew1, done1, info1 = env.step(np.random.randint(6, size=num_envs)) |
| 317 | + obs, rew, terminated, truncated, info = env.step( |
| 318 | + np.random.randint(6, size=num_envs) |
| 319 | + ) |
| 320 | + done = np.logical_or(terminated, truncated) |
| 321 | + obs1, rew1, terminated1, truncated1, info1 = env.step( |
| 322 | + np.random.randint(6, size=num_envs) |
| 323 | + ) |
| 324 | + done1 = np.logical_or(terminated1, truncated1) |
310 | 325 | index = np.where(done)[0] |
311 | 326 | self.assertTrue(np.all(~done1[index])) |
312 | 327 |
|
|
0 commit comments