Skip to content

Commit 980e1e5

Browse files
authored
Better handle "--More--" messages in NLE (#23)
* auto more wrapper * keep the behavior the same * fix order * add unwrapped to wrapper works
1 parent c3b6dbc commit 980e1e5

File tree

6 files changed

+55
-27
lines changed

6 files changed

+55
-27
lines changed

balrog/config/config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ envs:
4747
no_progress_timeout: 150 # Timeout for no progress in NLE
4848
savedir: null # Directory to save NLE data; null disables saving
4949
save_ttyrec_every: 0 # Frequency of saving TTY recordings
50-
skip_more: False # Whether to skip the 'more' prompt in NLE
50+
skip_more: True # Whether to skip the 'more' prompt in NLE
5151
minihack_kwargs:
5252
character: "@"
5353
max_episode_steps: 100
@@ -57,7 +57,7 @@ envs:
5757
savedir: null
5858
save_ttyrec_every: 0
5959
autopickup: False
60-
skip_more: False
60+
skip_more: True
6161
babyai_kwargs:
6262
num_dists: 0
6363
crafter_kwargs:

balrog/environments/minihack/minihack_env.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import gym
44
import minihack # NOQA: F401
55

6-
from balrog.environments.nle import NLELanguageWrapper
6+
from balrog.environments.nle import AutoMore, NLELanguageWrapper
77
from balrog.environments.wrappers import GymV21CompatibilityV0, NLETimeLimit
88

99
MINIHACK_ENVS = []
@@ -30,7 +30,9 @@ def make_minihack_env(env_name, task, config, render_mode: Optional[str] = None)
3030
],
3131
**minihack_kwargs,
3232
)
33-
env = NLELanguageWrapper(env, vlm=vlm, skip_more=skip_more)
33+
if skip_more:
34+
env = AutoMore(env)
35+
env = NLELanguageWrapper(env, vlm=vlm)
3436

3537
# wrap NLE with timeout
3638
env = NLETimeLimit(env)

balrog/environments/nle/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import enum
22

3+
from .auto_more import AutoMore
34
from .base import NLELanguageWrapper
45

56

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import gym
2+
from nle import nle_language_obsv
3+
from nle.nethack import actions as A
4+
5+
6+
class AutoMore(gym.Wrapper):
7+
def __init__(self, env):
8+
super().__init__(env)
9+
self.nle_language = nle_language_obsv.NLELanguageObsv()
10+
11+
def reset(self, **kwargs):
12+
obs = super().reset(**kwargs)
13+
obs["text_message"] = self.nle_language.text_message(obs["tty_chars"]).decode("latin-1")
14+
15+
return obs
16+
17+
def step(self, action):
18+
obs, reward, done, info = super().step(action)
19+
20+
message = self.nle_language.text_message(obs["tty_chars"]).decode("latin-1")
21+
22+
while "--More--" in message and not done:
23+
message = message.replace("--More--", "\n")
24+
25+
action_index = self.env.actions.index(A.MiscAction.MORE)
26+
obs, rew, done, info = super().step(action_index)
27+
add = self.nle_language.text_message(obs["tty_chars"]).decode("latin-1")
28+
message += add
29+
reward += rew
30+
31+
obs["text_message"] = message
32+
33+
return obs, reward, done, info

balrog/environments/nle/base.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,13 @@
1414

1515

1616
class NLELanguageWrapper(language_wrapper.NLELanguageWrapper):
17-
def __init__(self, env, vlm=False, skip_more=False):
17+
def __init__(self, env, vlm=False):
1818
super().__init__(env, use_language_action=True)
1919
self.nle_language = nle_language_obsv.NLELanguageObsv()
2020
self.language_action_space = self.create_action_space()
2121
self.env = env
2222
self.vlm = vlm
2323
self.done = False
24-
self.skip_more = skip_more
2524

2625
if not vlm:
2726
self.prompt_mode = "hybrid"
@@ -77,28 +76,15 @@ def nle_obsv_type(self, nle_obsv):
7776
else:
7877
raise ValueError(f'"{self.prompt_mode}" is not a valid prompt mode.')
7978

80-
def clean_message(self, nle_obsv):
81-
message = self.nle_language.text_message(nle_obsv["tty_chars"]).decode("latin-1")
82-
if not self.skip_more:
83-
while "--More--" in message and not self.done:
84-
message = message.replace("--More--", " ")
85-
message = message.replace("\n", " ")
86-
87-
nle_obsv, reward, self.done, info = self.step("more")
88-
add = self.nle_language.text_message(nle_obsv["obs"]["tty_chars"]).decode("latin-1")
89-
message += add
90-
return message, nle_obsv["obs"]
91-
return message, nle_obsv
92-
9379
def render(self, mode="human"):
9480
if mode == "tiles":
95-
obs = self.env.last_observation
96-
glyphs = obs[self.env._observation_keys.index("glyphs")]
81+
obs = self.env.unwrapped.last_observation
82+
glyphs = obs[self.env.unwrapped._observation_keys.index("glyphs")]
9783
return rgb_render_image(glyphs)
9884
elif mode == "tty_image":
99-
obs = self.env.last_observation
100-
tty_chars = obs[self.env._observation_keys.index("tty_chars")]
101-
tty_colors = obs[self.env._observation_keys.index("tty_colors")]
85+
obs = self.env.unwrapped.last_observation
86+
tty_chars = obs[self.env.unwrapped._observation_keys.index("tty_chars")]
87+
tty_colors = obs[self.env.unwrapped._observation_keys.index("tty_colors")]
10288
return tty_render_image(tty_chars, tty_colors)
10389
else:
10490
return super().render(mode)
@@ -150,7 +136,11 @@ def nle_obsv_to_language(self, nle_obsv):
150136
(dict): language observation
151137
"""
152138

153-
message, nle_obsv = self.clean_message(nle_obsv)
139+
message = (
140+
nle_obsv["text_message"]
141+
if "text_message" in nle_obsv
142+
else self.nle_language.text_message(nle_obsv["tty_chars"]).decode("latin-1")
143+
)
154144

155145
glyphs = nle_obsv["glyphs"]
156146
blstats = nle_obsv["blstats"]

balrog/environments/nle/nle_env.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import gym
44
import nle # NOQA: F401
55

6-
from balrog.environments.nle import NLELanguageWrapper
6+
from balrog.environments.nle import AutoMore, NLELanguageWrapper
77
from balrog.environments.wrappers import GymV21CompatibilityV0, NLETimeLimit
88

99
NETHACK_ENVS = []
@@ -18,7 +18,9 @@ def make_nle_env(env_name, task, config, render_mode: Optional[str] = None):
1818
skip_more = nle_kwargs.pop("skip_more", False)
1919
vlm = True if config.agent.max_image_history > 0 else False
2020
env = gym.make(task, **nle_kwargs)
21-
env = NLELanguageWrapper(env, vlm=vlm, skip_more=skip_more)
21+
if skip_more:
22+
env = AutoMore(env)
23+
env = NLELanguageWrapper(env, vlm=vlm)
2224

2325
# wrap NLE with timeout
2426
env = NLETimeLimit(env)

0 commit comments

Comments
 (0)