diff --git a/balrog/config/config.yaml b/balrog/config/config.yaml index 24be0f68..76b18fbc 100644 --- a/balrog/config/config.yaml +++ b/balrog/config/config.yaml @@ -47,7 +47,7 @@ envs: no_progress_timeout: 150 # Timeout for no progress in NLE savedir: null # Directory to save NLE data; null disables saving save_ttyrec_every: 0 # Frequency of saving TTY recordings - skip_more: False # Whether to skip the 'more' prompt in NLE + skip_more: True # Whether to skip the 'more' prompt in NLE minihack_kwargs: character: "@" max_episode_steps: 100 @@ -57,7 +57,7 @@ envs: savedir: null save_ttyrec_every: 0 autopickup: False - skip_more: False + skip_more: True babyai_kwargs: num_dists: 0 crafter_kwargs: diff --git a/balrog/environments/minihack/minihack_env.py b/balrog/environments/minihack/minihack_env.py index 4cca3cb7..98287315 100644 --- a/balrog/environments/minihack/minihack_env.py +++ b/balrog/environments/minihack/minihack_env.py @@ -3,7 +3,7 @@ import gym import minihack # NOQA: F401 -from balrog.environments.nle import NLELanguageWrapper +from balrog.environments.nle import AutoMore, NLELanguageWrapper from balrog.environments.wrappers import GymV21CompatibilityV0, NLETimeLimit MINIHACK_ENVS = [] @@ -30,7 +30,9 @@ def make_minihack_env(env_name, task, config, render_mode: Optional[str] = None) ], **minihack_kwargs, ) - env = NLELanguageWrapper(env, vlm=vlm, skip_more=skip_more) + if skip_more: + env = AutoMore(env) + env = NLELanguageWrapper(env, vlm=vlm) # wrap NLE with timeout env = NLETimeLimit(env) diff --git a/balrog/environments/nle/__init__.py b/balrog/environments/nle/__init__.py index 99f28a87..7ddb3ece 100644 --- a/balrog/environments/nle/__init__.py +++ b/balrog/environments/nle/__init__.py @@ -1,5 +1,6 @@ import enum +from .auto_more import AutoMore from .base import NLELanguageWrapper diff --git a/balrog/environments/nle/auto_more.py b/balrog/environments/nle/auto_more.py new file mode 100644 index 00000000..92b71048 --- /dev/null +++ b/balrog/environments/nle/auto_more.py @@ -0,0 +1,33 @@ +import gym +from nle import nle_language_obsv +from nle.nethack import actions as A + + +class AutoMore(gym.Wrapper): + def __init__(self, env): + super().__init__(env) + self.nle_language = nle_language_obsv.NLELanguageObsv() + + def reset(self, **kwargs): + obs = super().reset(**kwargs) + obs["text_message"] = self.nle_language.text_message(obs["tty_chars"]).decode("latin-1") + + return obs + + def step(self, action): + obs, reward, done, info = super().step(action) + + message = self.nle_language.text_message(obs["tty_chars"]).decode("latin-1") + + while "--More--" in message and not done: + message = message.replace("--More--", "\n") + + action_index = self.env.actions.index(A.MiscAction.MORE) + obs, rew, done, info = super().step(action_index) + add = self.nle_language.text_message(obs["tty_chars"]).decode("latin-1") + message += add + reward += rew + + obs["text_message"] = message + + return obs, reward, done, info diff --git a/balrog/environments/nle/base.py b/balrog/environments/nle/base.py index 90b91451..134050f4 100644 --- a/balrog/environments/nle/base.py +++ b/balrog/environments/nle/base.py @@ -14,14 +14,13 @@ class NLELanguageWrapper(language_wrapper.NLELanguageWrapper): - def __init__(self, env, vlm=False, skip_more=False): + def __init__(self, env, vlm=False): super().__init__(env, use_language_action=True) self.nle_language = nle_language_obsv.NLELanguageObsv() self.language_action_space = self.create_action_space() self.env = env self.vlm = vlm self.done = False - self.skip_more = skip_more if not vlm: self.prompt_mode = "hybrid" @@ -77,28 +76,15 @@ def nle_obsv_type(self, nle_obsv): else: raise ValueError(f'"{self.prompt_mode}" is not a valid prompt mode.') - def clean_message(self, nle_obsv): - message = self.nle_language.text_message(nle_obsv["tty_chars"]).decode("latin-1") - if not self.skip_more: - while "--More--" in message and not self.done: - message = message.replace("--More--", " ") - message = message.replace("\n", " ") - - nle_obsv, reward, self.done, info = self.step("more") - add = self.nle_language.text_message(nle_obsv["obs"]["tty_chars"]).decode("latin-1") - message += add - return message, nle_obsv["obs"] - return message, nle_obsv - def render(self, mode="human"): if mode == "tiles": - obs = self.env.last_observation - glyphs = obs[self.env._observation_keys.index("glyphs")] + obs = self.env.unwrapped.last_observation + glyphs = obs[self.env.unwrapped._observation_keys.index("glyphs")] return rgb_render_image(glyphs) elif mode == "tty_image": - obs = self.env.last_observation - tty_chars = obs[self.env._observation_keys.index("tty_chars")] - tty_colors = obs[self.env._observation_keys.index("tty_colors")] + obs = self.env.unwrapped.last_observation + tty_chars = obs[self.env.unwrapped._observation_keys.index("tty_chars")] + tty_colors = obs[self.env.unwrapped._observation_keys.index("tty_colors")] return tty_render_image(tty_chars, tty_colors) else: return super().render(mode) @@ -150,7 +136,11 @@ def nle_obsv_to_language(self, nle_obsv): (dict): language observation """ - message, nle_obsv = self.clean_message(nle_obsv) + message = ( + nle_obsv["text_message"] + if "text_message" in nle_obsv + else self.nle_language.text_message(nle_obsv["tty_chars"]).decode("latin-1") + ) glyphs = nle_obsv["glyphs"] blstats = nle_obsv["blstats"] diff --git a/balrog/environments/nle/nle_env.py b/balrog/environments/nle/nle_env.py index b0c0dadc..affb61a2 100644 --- a/balrog/environments/nle/nle_env.py +++ b/balrog/environments/nle/nle_env.py @@ -3,7 +3,7 @@ import gym import nle # NOQA: F401 -from balrog.environments.nle import NLELanguageWrapper +from balrog.environments.nle import AutoMore, NLELanguageWrapper from balrog.environments.wrappers import GymV21CompatibilityV0, NLETimeLimit NETHACK_ENVS = [] @@ -18,7 +18,9 @@ def make_nle_env(env_name, task, config, render_mode: Optional[str] = None): skip_more = nle_kwargs.pop("skip_more", False) vlm = True if config.agent.max_image_history > 0 else False env = gym.make(task, **nle_kwargs) - env = NLELanguageWrapper(env, vlm=vlm, skip_more=skip_more) + if skip_more: + env = AutoMore(env) + env = NLELanguageWrapper(env, vlm=vlm) # wrap NLE with timeout env = NLETimeLimit(env)