Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions balrog/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ envs:
no_progress_timeout: 150 # Timeout for no progress in NLE
savedir: null # Directory to save NLE data; null disables saving
save_ttyrec_every: 0 # Frequency of saving TTY recordings
skip_more: False # Whether to skip the 'more' prompt in NLE
skip_more: True # Whether to skip the 'more' prompt in NLE
minihack_kwargs:
character: "@"
max_episode_steps: 100
Expand All @@ -57,7 +57,7 @@ envs:
savedir: null
save_ttyrec_every: 0
autopickup: False
skip_more: False
skip_more: True
babyai_kwargs:
num_dists: 0
crafter_kwargs:
Expand Down
6 changes: 4 additions & 2 deletions balrog/environments/minihack/minihack_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import gym
import minihack # NOQA: F401

from balrog.environments.nle import NLELanguageWrapper
from balrog.environments.nle import AutoMore, NLELanguageWrapper
from balrog.environments.wrappers import GymV21CompatibilityV0, NLETimeLimit

MINIHACK_ENVS = []
Expand All @@ -30,7 +30,9 @@ def make_minihack_env(env_name, task, config, render_mode: Optional[str] = None)
],
**minihack_kwargs,
)
env = NLELanguageWrapper(env, vlm=vlm, skip_more=skip_more)
if skip_more:
env = AutoMore(env)
env = NLELanguageWrapper(env, vlm=vlm)

# wrap NLE with timeout
env = NLETimeLimit(env)
Expand Down
1 change: 1 addition & 0 deletions balrog/environments/nle/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import enum

from .auto_more import AutoMore
from .base import NLELanguageWrapper


Expand Down
33 changes: 33 additions & 0 deletions balrog/environments/nle/auto_more.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import gym
from nle import nle_language_obsv
from nle.nethack import actions as A


class AutoMore(gym.Wrapper):
    """Gym wrapper that dismisses NetHack ``--More--`` prompts automatically.

    After every ``step`` the top-line message is decoded from
    ``obs["tty_chars"]``. While that message contains ``--More--`` and the
    episode is not done, the wrapper sends the MORE action to the wrapped
    environment, concatenating the messages and summing the rewards of the
    extra steps. The combined text is stored in ``obs["text_message"]``.

    Note that ``reset`` only decodes the initial message; it does not
    dismiss prompts.
    """

    # Marker searched for in the decoded message text.
    _MORE_PROMPT = "--More--"

    def __init__(self, env):
        """Wrap ``env`` and create the helper used to decode tty messages."""
        super().__init__(env)
        self.nle_language = nle_language_obsv.NLELanguageObsv()

    def _text_message(self, obs):
        """Return the message decoded from ``obs["tty_chars"]`` as ``str``."""
        return self.nle_language.text_message(obs["tty_chars"]).decode("latin-1")

    def _more_action_index(self):
        """Return the index of the MORE action in the wrapped env's actions.

        Raises:
            ValueError: if the lookup in ``self.env.actions`` fails because
                the MORE action is not part of the action set.
        """
        try:
            return self.env.actions.index(A.MiscAction.MORE)
        except ValueError as err:
            # Same exception type as the bare ``.index`` failure, but with a
            # message that points at the actual configuration problem.
            raise ValueError(
                "AutoMore requires MiscAction.MORE in the wrapped environment's "
                "action set; add it to the env's actions or disable skip_more."
            ) from err

    def reset(self, **kwargs):
        """Reset the wrapped env and attach ``obs["text_message"]``."""
        obs = super().reset(**kwargs)
        obs["text_message"] = self._text_message(obs)

        return obs

    def step(self, action):
        """Step the env, auto-answering any ``--More--`` prompts.

        Returns:
            ``(obs, reward, done, info)`` where ``obs``, ``done`` and ``info``
            come from the last underlying step taken, ``reward`` is the sum
            over all underlying steps, and ``obs["text_message"]`` holds the
            concatenated messages with each ``--More--`` replaced by a
            newline.
        """
        obs, reward, done, info = super().step(action)
        message = self._text_message(obs)

        if self._MORE_PROMPT in message and not done:
            # Loop-invariant: look the action index up once, and only when a
            # prompt actually has to be dismissed.
            more_index = self._more_action_index()

            while self._MORE_PROMPT in message and not done:
                message = message.replace(self._MORE_PROMPT, "\n")

                obs, extra_reward, done, info = super().step(more_index)
                message += self._text_message(obs)
                reward += extra_reward

        obs["text_message"] = message

        return obs, reward, done, info
32 changes: 11 additions & 21 deletions balrog/environments/nle/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@


class NLELanguageWrapper(language_wrapper.NLELanguageWrapper):
def __init__(self, env, vlm=False, skip_more=False):
def __init__(self, env, vlm=False):
super().__init__(env, use_language_action=True)
self.nle_language = nle_language_obsv.NLELanguageObsv()
self.language_action_space = self.create_action_space()
self.env = env
self.vlm = vlm
self.done = False
self.skip_more = skip_more

if not vlm:
self.prompt_mode = "hybrid"
Expand Down Expand Up @@ -77,28 +76,15 @@ def nle_obsv_type(self, nle_obsv):
else:
raise ValueError(f'"{self.prompt_mode}" is not a valid prompt mode.')

def clean_message(self, nle_obsv):
message = self.nle_language.text_message(nle_obsv["tty_chars"]).decode("latin-1")
if not self.skip_more:
while "--More--" in message and not self.done:
message = message.replace("--More--", " ")
message = message.replace("\n", " ")

nle_obsv, reward, self.done, info = self.step("more")
add = self.nle_language.text_message(nle_obsv["obs"]["tty_chars"]).decode("latin-1")
message += add
return message, nle_obsv["obs"]
return message, nle_obsv

def render(self, mode="human"):
if mode == "tiles":
obs = self.env.last_observation
glyphs = obs[self.env._observation_keys.index("glyphs")]
obs = self.env.unwrapped.last_observation
glyphs = obs[self.env.unwrapped._observation_keys.index("glyphs")]
return rgb_render_image(glyphs)
elif mode == "tty_image":
obs = self.env.last_observation
tty_chars = obs[self.env._observation_keys.index("tty_chars")]
tty_colors = obs[self.env._observation_keys.index("tty_colors")]
obs = self.env.unwrapped.last_observation
tty_chars = obs[self.env.unwrapped._observation_keys.index("tty_chars")]
tty_colors = obs[self.env.unwrapped._observation_keys.index("tty_colors")]
return tty_render_image(tty_chars, tty_colors)
else:
return super().render(mode)
Expand Down Expand Up @@ -150,7 +136,11 @@ def nle_obsv_to_language(self, nle_obsv):
(dict): language observation
"""

message, nle_obsv = self.clean_message(nle_obsv)
message = (
nle_obsv["text_message"]
if "text_message" in nle_obsv
else self.nle_language.text_message(nle_obsv["tty_chars"]).decode("latin-1")
)

glyphs = nle_obsv["glyphs"]
blstats = nle_obsv["blstats"]
Expand Down
6 changes: 4 additions & 2 deletions balrog/environments/nle/nle_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import gym
import nle # NOQA: F401

from balrog.environments.nle import NLELanguageWrapper
from balrog.environments.nle import AutoMore, NLELanguageWrapper
from balrog.environments.wrappers import GymV21CompatibilityV0, NLETimeLimit

NETHACK_ENVS = []
Expand All @@ -18,7 +18,9 @@ def make_nle_env(env_name, task, config, render_mode: Optional[str] = None):
skip_more = nle_kwargs.pop("skip_more", False)
vlm = True if config.agent.max_image_history > 0 else False
env = gym.make(task, **nle_kwargs)
env = NLELanguageWrapper(env, vlm=vlm, skip_more=skip_more)
if skip_more:
env = AutoMore(env)
env = NLELanguageWrapper(env, vlm=vlm)

# wrap NLE with timeout
env = NLETimeLimit(env)
Expand Down