Skip to content

Commit 09f18d4

Browse files
committed
auto more wrapper
1 parent c3b6dbc commit 09f18d4

File tree

5 files changed

+48
-20
lines changed

5 files changed

+48
-20
lines changed

balrog/environments/minihack/minihack_env.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import gym
44
import minihack # NOQA: F401
55

6-
from balrog.environments.nle import NLELanguageWrapper
6+
from balrog.environments.nle import AutoMore, NLELanguageWrapper
77
from balrog.environments.wrappers import GymV21CompatibilityV0, NLETimeLimit
88

99
MINIHACK_ENVS = []
@@ -30,7 +30,9 @@ def make_minihack_env(env_name, task, config, render_mode: Optional[str] = None)
3030
],
3131
**minihack_kwargs,
3232
)
33-
env = NLELanguageWrapper(env, vlm=vlm, skip_more=skip_more)
33+
env = NLELanguageWrapper(env, vlm=vlm)
34+
if skip_more:
35+
env = AutoMore(env)
3436

3537
# wrap NLE with timeout
3638
env = NLETimeLimit(env)

balrog/environments/nle/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import enum
22

3+
from .auto_more import AutoMore
34
from .base import NLELanguageWrapper
45

56

balrog/environments/nle/auto_more.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import gym
2+
from nle import nle_language_obsv
3+
from nle.nethack import actions as A
4+
5+
6+
class AutoMore(gym.Wrapper):
    """Gym wrapper that automatically dismisses NetHack ``--More--`` prompts.

    After every ``step`` the top-line message is read from the observation's
    ``tty_chars``. While that message contains ``--More--`` and the episode is
    not done, the wrapper sends the MORE action itself, concatenates the
    follow-up messages and sums the rewards. The combined text is exposed to
    outer wrappers under the ``"text_message"`` observation key.

    NOTE(review): assumes the wrapped env returns dict observations with a
    ``"tty_chars"`` entry and exposes an ``actions`` sequence containing
    ``A.MiscAction.MORE`` -- confirm against the env construction order.
    """

    def __init__(self, env):
        """Wrap ``env`` and create the helper used to extract tty messages."""
        super().__init__(env)
        self.nle_language = nle_language_obsv.NLELanguageObsv()

    def _read_message(self, obs):
        """Return the text message parsed from ``obs["tty_chars"]`` as ``str``."""
        return self.nle_language.text_message(obs["tty_chars"]).decode("latin-1")

    def reset(self, **kwargs):
        """Reset the wrapped env and attach ``"text_message"`` to the observation.

        ``--More--`` prompts are not skipped here; only ``step`` skips them.
        """
        obs = super().reset(**kwargs)
        obs["text_message"] = self._read_message(obs)
        return obs

    def step(self, action):
        """Step the env, auto-advancing through any ``--More--`` prompts.

        Args:
            action: action forwarded unchanged to the wrapped env.

        Returns:
            ``(obs, reward, done, info)`` where ``obs`` and ``info`` come from
            the last underlying step, ``reward`` is the sum over all underlying
            steps taken, and ``obs["text_message"]`` holds the accumulated
            message with every ``--More--`` replaced by a newline.
        """
        obs, reward, done, info = super().step(action)
        message = self._read_message(obs)

        # Resolved lazily and only once: the index is loop-invariant, and envs
        # without a MORE action must only fail if a prompt actually shows up.
        more_index = None
        while "--More--" in message and not done:
            if more_index is None:
                more_index = self.env.actions.index(A.MiscAction.MORE)

            message = message.replace("--More--", "\n")
            obs, extra_reward, done, info = super().step(more_index)
            message += self._read_message(obs)
            reward += extra_reward

        obs["text_message"] = message
        return obs, reward, done, info

balrog/environments/nle/base.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,13 @@
1414

1515

1616
class NLELanguageWrapper(language_wrapper.NLELanguageWrapper):
17-
def __init__(self, env, vlm=False, skip_more=False):
17+
def __init__(self, env, vlm=False):
1818
super().__init__(env, use_language_action=True)
1919
self.nle_language = nle_language_obsv.NLELanguageObsv()
2020
self.language_action_space = self.create_action_space()
2121
self.env = env
2222
self.vlm = vlm
2323
self.done = False
24-
self.skip_more = skip_more
2524

2625
if not vlm:
2726
self.prompt_mode = "hybrid"
@@ -77,19 +76,6 @@ def nle_obsv_type(self, nle_obsv):
7776
else:
7877
raise ValueError(f'"{self.prompt_mode}" is not a valid prompt mode.')
7978

80-
def clean_message(self, nle_obsv):
81-
message = self.nle_language.text_message(nle_obsv["tty_chars"]).decode("latin-1")
82-
if not self.skip_more:
83-
while "--More--" in message and not self.done:
84-
message = message.replace("--More--", " ")
85-
message = message.replace("\n", " ")
86-
87-
nle_obsv, reward, self.done, info = self.step("more")
88-
add = self.nle_language.text_message(nle_obsv["obs"]["tty_chars"]).decode("latin-1")
89-
message += add
90-
return message, nle_obsv["obs"]
91-
return message, nle_obsv
92-
9379
def render(self, mode="human"):
9480
if mode == "tiles":
9581
obs = self.env.last_observation
@@ -150,7 +136,11 @@ def nle_obsv_to_language(self, nle_obsv):
150136
(dict): language observation
151137
"""
152138

153-
message, nle_obsv = self.clean_message(nle_obsv)
139+
message = (
140+
nle_obsv["text_message"]
141+
if "text_message" in nle_obsv
142+
else self.nle_language.text_message(nle_obsv["tty_chars"]).decode("latin-1")
143+
)
154144

155145
glyphs = nle_obsv["glyphs"]
156146
blstats = nle_obsv["blstats"]

balrog/environments/nle/nle_env.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import gym
44
import nle # NOQA: F401
55

6-
from balrog.environments.nle import NLELanguageWrapper
6+
from balrog.environments.nle import AutoMore, NLELanguageWrapper
77
from balrog.environments.wrappers import GymV21CompatibilityV0, NLETimeLimit
88

99
NETHACK_ENVS = []
@@ -18,7 +18,9 @@ def make_nle_env(env_name, task, config, render_mode: Optional[str] = None):
1818
skip_more = nle_kwargs.pop("skip_more", False)
1919
vlm = True if config.agent.max_image_history > 0 else False
2020
env = gym.make(task, **nle_kwargs)
21-
env = NLELanguageWrapper(env, vlm=vlm, skip_more=skip_more)
21+
if skip_more:
22+
env = AutoMore(env)
23+
env = NLELanguageWrapper(env, vlm=vlm)
2224

2325
# wrap NLE with timeout
2426
env = NLETimeLimit(env)

0 commit comments

Comments
 (0)