Skip to content

Commit 1eedd34

Browse files
authored
Fix zero-life reset error in Atari env (#175)
Add `info["terminated"]` to the Atari env's step info as an indicator of `env.game_over()`.
1 parent ea86c2b commit 1eedd34

File tree

5 files changed

+54
-7
lines changed

5 files changed

+54
-7
lines changed

envpool/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
register,
2424
)
2525

26-
__version__ = "0.6.2.post2"
26+
__version__ = "0.6.3"
2727
__all__ = [
2828
"register",
2929
"make",

envpool/atari/api_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def test_lowlevel_step(self) -> None:
259259
np.testing.assert_allclose(done.shape, (num_envs,))
260260
self.assertEqual(done.dtype, np.bool_)
261261
self.assertIsInstance(info, dict)
262-
self.assertEqual(len(info), 5)
262+
self.assertEqual(len(info), 6)
263263
self.assertEqual(info["env_id"].dtype, np.int32)
264264
self.assertEqual(info["lives"].dtype, np.int32)
265265
self.assertEqual(info["players"]["env_id"].dtype, np.int32)
@@ -295,7 +295,7 @@ def test_highlevel_step(self) -> None:
295295
np.testing.assert_allclose(done.shape, (num_envs,))
296296
self.assertEqual(done.dtype, np.bool_)
297297
self.assertIsInstance(info, dict)
298-
self.assertEqual(len(info), 5)
298+
self.assertEqual(len(info), 6)
299299
self.assertEqual(info["env_id"].dtype, np.int32)
300300
self.assertEqual(info["lives"].dtype, np.int32)
301301
self.assertEqual(info["players"]["env_id"].dtype, np.int32)

envpool/atari/atari_env.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,9 @@ class AtariEnvFns {
6565
conf["img_height"_], conf["img_width"_]},
6666
{0, 255})),
6767
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})),
68-
"info:lives"_.Bind(Spec<int>({-1}, {0, 5})),
69-
"info:reward"_.Bind(Spec<float>({-1})));
68+
"info:lives"_.Bind(Spec<int>({-1})),
69+
"info:reward"_.Bind(Spec<float>({-1})),
70+
"info:terminated"_.Bind(Spec<int>({-1}, {0, 1})));
7071
}
7172
template <typename Config>
7273
static decltype(auto) ActionSpec(const Config& conf) {
@@ -199,7 +200,7 @@ class AtariEnv : public Env<AtariEnvSpec> {
199200
PushStack(false, skip_id == 0);
200201
++elapsed_step_;
201202
done_ |= (elapsed_step_ >= max_episode_steps_);
202-
if (episodic_life_ && env_->lives() < lives_) {
203+
if (episodic_life_ && 0 < env_->lives() && env_->lives() < lives_) {
203204
done_ = true;
204205
}
205206
float discount;
@@ -229,6 +230,7 @@ class AtariEnv : public Env<AtariEnvSpec> {
229230
state["reward"_] = reward;
230231
state["info:lives"_] = lives_;
231232
state["info:reward"_] = info_reward;
233+
state["info:terminated"_] = env_->game_over();
232234
for (int i = 0; i < stack_num_; ++i) {
233235
state["obs"_]
234236
.Slice(gray_scale_ ? i : i * 3, gray_scale_ ? i + 1 : (i + 1) * 3)

envpool/atari/atari_envpool_test.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,51 @@ def test_align(self) -> None:
7878
np.testing.assert_allclose(obs0, obs1)
7979
# cv2.imwrite(f"/tmp/log/align{i}.png", obs0[0, 1:].transpose(1, 2, 0))
8080

81+
def test_reset_life(self) -> None:
82+
"""Issue 171."""
83+
for env_id in [
84+
"atlantis", "backgammon", "breakout", "pong", "wizard_of_wor"
85+
]:
86+
np.random.seed(0)
87+
env = AtariGymEnvPool(
88+
AtariEnvSpec(
89+
AtariEnvSpec.gen_config(task=env_id, num_envs=1, episodic_life=True)
90+
)
91+
)
92+
action_num = env.action_space.n # type: ignore
93+
env.reset()
94+
info = env.step(np.array([0]))[-1]
95+
if info["lives"].sum() == 0:
96+
# this game reports zero lives (no life counter), so there is nothing to test
97+
continue
98+
for _ in range(10000):
99+
_, _, done, info = env.step(np.random.randint(0, action_num, 1))
100+
if info["lives"][0] == 0:
101+
break
102+
else:
103+
self.assertFalse(info["terminated"][0])
104+
if info["lives"][0] > 0:
105+
# lives were not exhausted within the 10000-step budget; skip this game
106+
continue
107+
# for normal atari (e.g., breakout)
108+
# take an additional step after all lives are exhausted
109+
_, _, next_done, next_info = env.step(
110+
np.random.randint(0, action_num, 1)
111+
)
112+
if done[0] and next_info["lives"][0] > 0:
113+
self.assertTrue(info["terminated"][0])
114+
continue
115+
self.assertFalse(done[0])
116+
self.assertFalse(info["terminated"][0])
117+
while not done[0]:
118+
self.assertFalse(info["terminated"][0])
119+
_, _, done, info = env.step(np.random.randint(0, action_num, 1))
120+
_, _, next_done, next_info = env.step(
121+
np.random.randint(0, action_num, 1)
122+
)
123+
self.assertTrue(next_info["lives"][0] > 0)
124+
self.assertTrue(info["terminated"][0])
125+
81126
def test_partial_step(self) -> None:
82127
num_envs = 5
83128
max_episode_steps = 10

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[metadata]
22
name = envpool
3-
version = 0.6.2.post2
3+
version = 0.6.3
44
author = "EnvPool Contributors"
55
author_email = "[email protected]"
66
description = "C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments."

0 commit comments

Comments
 (0)