From d00ed6127f3e2af29e33ab6b86470e77bce26cc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Cupia=C5=82?= Date: Sat, 2 Aug 2025 14:01:03 +0200 Subject: [PATCH 1/7] update nle and minihack to gymnasium --- balrog/environments/env_wrapper.py | 4 +- balrog/environments/minihack/__init__.py | 48 +--- balrog/environments/minihack/minihack_env.py | 20 +- .../minihack/minihack_progress.py | 45 +++ balrog/environments/nle/__init__.py | 88 +----- balrog/environments/nle/auto_more.py | 15 +- balrog/environments/nle/base.py | 272 ++++++++++++++---- balrog/environments/nle/nle_env.py | 19 +- balrog/environments/nle/progress.py | 169 ----------- 9 files changed, 283 insertions(+), 397 deletions(-) create mode 100644 balrog/environments/minihack/minihack_progress.py delete mode 100644 balrog/environments/nle/progress.py diff --git a/balrog/environments/env_wrapper.py b/balrog/environments/env_wrapper.py index 161f3a93..1cf0f299 100644 --- a/balrog/environments/env_wrapper.py +++ b/balrog/environments/env_wrapper.py @@ -16,7 +16,7 @@ def __init__(self, env, env_name, task_name): @property def max_steps(self): - return self.env.max_steps + return int(self.env.max_steps) def reset(self, **kwargs): obs, info = self.env.reset(**kwargs) @@ -55,7 +55,7 @@ def get_instruction_prompt(self, instructions=None): if self.env_name == "nle": from balrog.environments.nle import get_instruction_prompt - return get_instruction_prompt() + return get_instruction_prompt(self.env, self.task_name) elif self.env_name == "minihack": from balrog.environments.minihack import get_instruction_prompt diff --git a/balrog/environments/minihack/__init__.py b/balrog/environments/minihack/__init__.py index eafa5209..18d563ce 100644 --- a/balrog/environments/minihack/__init__.py +++ b/balrog/environments/minihack/__init__.py @@ -1,51 +1,5 @@ from nle.language_wrapper.wrappers.nle_language_wrapper import NLELanguageWrapper -ACTIONS = { - "north": "move north", - "east": "move east", - "south": "move south", - "west": "move west", - "northeast": "move northeast", - "southeast": "move southeast", - "southwest": "move southwest", - "northwest": "move northwest", - "far north": "move far north", - "far east": "move far east", - "far south": "move far south", - "far west": "move far west", - "far northeast": "move far northeast", - "far southeast": "move far southeast", - "far southwest": "move far southwest", - "far northwest": "move far northwest", - "up": "go up the stairs", - "down": "go down the stairs", - "wait": "rest one move while doing nothing", - "more": "display more of the message", - "apply": "apply (use) a tool", - "close": "close an adjacent door", - "open": "open an adjacent door", - "eat": "eat something", - "force": "force a lock", - "kick": "kick an enemy or a locked door or chest", - "loot": "loot a box on the floor", - "pickup": "pick up things at the current location if there are any", - "pray": "pray to the gods for help", - "puton": "put on an accessory", - "quaff": "quaff (drink) something", - "search": "search for hidden doors and passages", - "zap": "zap a wand", -} - - -def get_available_actions(env): - available_actions = {} - for action in env.actions: - action_key = NLELanguageWrapper.all_nle_action_map[action][0] - if action_key not in ACTIONS: - continue - available_actions[action_key] = ACTIONS[action_key] - return available_actions - def get_instruction_prompt(env, task="MiniHack-ExploreMaze-Hard-Mapped-v0"): if "corridor" in task.lower(): @@ -57,7 +11,7 @@ def get_instruction_prompt(env, task="MiniHack-ExploreMaze-Hard-Mapped-v0"): else: goal = "Your goal is to get as far as possible in the game." - available_actions = get_available_actions(env) + available_actions = env.action_str_desc_map action_strings = ",\n".join(f"{action}: {description}" for action, description in available_actions.items()) instruction_prompt = f""" You are an agent playing MiniHack. The following are the possible actions you can take in the game, followed by a short description of each action: diff --git a/balrog/environments/minihack/minihack_env.py b/balrog/environments/minihack/minihack_env.py index 98287315..c6e224ef 100644 --- a/balrog/environments/minihack/minihack_env.py +++ b/balrog/environments/minihack/minihack_env.py @@ -1,16 +1,13 @@ from typing import Optional -import gym -import minihack # NOQA: F401 +import gymnasium as gym +from gymnasium import registry +import minihack # NOQA: F401 +from balrog.environments.minihack.minihack_progress import MiniHackProgressWrapper from balrog.environments.nle import AutoMore, NLELanguageWrapper -from balrog.environments.wrappers import GymV21CompatibilityV0, NLETimeLimit -MINIHACK_ENVS = [] -for env_spec in gym.envs.registry.all(): - id = env_spec.id - if id.split("-")[0] == "MiniHack": - MINIHACK_ENVS.append(id) +MINIHACK_ENVS = [env_spec.id for env_spec in registry.values() if "MiniHack" in env_spec.id] def make_minihack_env(env_name, task, config, render_mode: Optional[str] = None): @@ -32,11 +29,8 @@ def make_minihack_env(env_name, task, config, render_mode: Optional[str] = None) ) if skip_more: env = AutoMore(env) - env = NLELanguageWrapper(env, vlm=vlm) - # wrap NLE with timeout - env = NLETimeLimit(env) - - env = GymV21CompatibilityV0(env=env, render_mode=render_mode) + env = MiniHackProgressWrapper(env, progression_on_done_only=False) + env = NLELanguageWrapper(env, vlm=vlm) return env diff --git a/balrog/environments/minihack/minihack_progress.py b/balrog/environments/minihack/minihack_progress.py new file mode 100644 index 00000000..98381a32 --- /dev/null +++ b/balrog/environments/minihack/minihack_progress.py @@ -0,0 +1,45 @@ +from typing import Optional + +import gymnasium as gym + + +class MiniHackProgress: + episode_return: float = 0.0 + progression: float = 0.0 + end_reason: Optional[str] = None + + def update(self, reward, info): + self.episode_return += reward + if reward >= 1.0: + self.progression = 1.0 + else: + self.progression = 0.0 + self.end_reason = info["end_status"] + + +class MiniHackProgressWrapper(gym.Wrapper): + def __init__(self, env, progression_on_done_only: bool = True): + super().__init__(env) + self.progression_on_done_only = progression_on_done_only + + def reset(self, **kwargs): + self.progress = MiniHackProgress() + return self.env.reset(**kwargs) + + def step(self, action): + obs, reward, term, trun, info = self.env.step(action) + self.progress.update(reward, info) + + done = term or trun + if not self.progression_on_done_only or done: + info["episode_extra_stats"] = self.episode_extra_stats(info) + + return obs, reward, term, trun, info + + def episode_extra_stats(self, info): + extra_stats = info.get("episode_extra_stats", {}) + new_extra_stats = { + "progression": self.progress.progression, + } + + return {**extra_stats, **new_extra_stats} diff --git a/balrog/environments/nle/__init__.py b/balrog/environments/nle/__init__.py index 7ddb3ece..568a66e9 100644 --- a/balrog/environments/nle/__init__.py +++ b/balrog/environments/nle/__init__.py @@ -20,92 +20,8 @@ class Role(enum.Enum): WIZARD = "wiz" -ACTIONS = { - "north": "move north", - "east": "move east", - "south": "move south", - "west": "move west", - "northeast": "move northeast", - "southeast": "move southeast", - "southwest": "move southwest", - "northwest": "move northwest", - "far north": "move far north", - "far east": "move far east", - "far south": "move far south", - "far west": "move far west", - "far northeast": "move far northeast", - "far southeast": "move far southeast", - "far southwest": "move far southwest", - "far northwest": "move far northwest", - "up": "go up a staircase", - "down": "go down a staircase (tip: you can only go down if you are standing on the stairs)", - "wait": "rest one move while doing nothing", - "more": "display more of the message (tip: ONLY ever use when current message ends with --More--)", - "annotate": "leave a note about the level", - "apply": "apply (use) a tool", - "call": "name a monster or object, or add an annotation", - "cast": "cast a spell", - "close": "close an adjacent door", - "open": "open an adjacent door", - "dip": "dip an object into something", - "drop": "drop an item", - "droptype": "drop specific item types (specify in the next prompt)", - "eat": "eat something (tip: replenish food when hungry)", - "esc": "exit menu or message", - "engrave": "engrave writing on the floor (tip: Elbereth)", - "enhance": "advance or check weapons skills", - "fire": "fire ammunition from quiver", - "fight": "fight a monster (even if you only guess one is there)", - "force": "force a lock", - "inventory": "show your inventory", - "invoke": "invoke ", - "jump": "jump to a location", - "kick": "kick an enemy or a locked door or chest", - "look": "look at what is under you", - "loot": "loot a box on the floor", - "monster": "use a monster's special ability (when polymorphed)", - "offer": "offer a sacrifice to the gods (tip: on an aligned altar)", - "overview": "display an overview of the dungeon", - "pay": "pay your shopping bill", - "pickup": "pick up things at the current location", - "pray": "pray to the gods for help", - "puton": "put on an accessory", - "quaff": "quaff (drink) something", - "quiver": "select ammunition for quiver", - "read": "read a scroll or spellbook", - "remove": "remove an accessory", - "rub": "rub a lamp or a stone", - "search": "search for hidden doors and passages", - "swap": "swap wielded and secondary weapons", - "takeoff": "take off one piece of armor", - "takeoffall": "take off all armor", - "teleport": "teleport to another level (if you have the ability)", - "throw": "throw something (e.g. a dagger or dart)", - "travel": "travel to a specific location on the map (tip: in the next action, specify > or < for stairs, { for fountain, and _ for altar)", - "twoweapon": "toggle two-weapon combat", - "untrap": "untrap something", - "wear": "wear a piece of armor", - "wield": "wield a weapon", - "wipe": "wipe off your face", - "zap": "zap a wand", - "minus": "-", - "space": " ", - "apos": "'", - "0": "0", - "1": "1", - "2": "2", - "3": "3", - "4": "4", - "5": "5", - "6": "6", - "7": "7", - "8": "8", - "9": "9", -} - - -def get_instruction_prompt(task=None): - action_strings = ",\n".join(f"{action}: {description}" for action, description in ACTIONS.items()) +def get_instruction_prompt(env, task=None): + action_strings = ",\n".join(f"{action}: {description}" for action, description in env.action_str_desc_map.items()) instruction_prompt = f""" You are an agent playing NetHack. The following are the possible actions you can take in the game, followed by a short description of each action: diff --git a/balrog/environments/nle/auto_more.py b/balrog/environments/nle/auto_more.py index 92b71048..c78caba0 100644 --- a/balrog/environments/nle/auto_more.py +++ b/balrog/environments/nle/auto_more.py @@ -1,4 +1,4 @@ -import gym +import gymnasium as gym from nle import nle_language_obsv from nle.nethack import actions as A @@ -9,25 +9,26 @@ def __init__(self, env): self.nle_language = nle_language_obsv.NLELanguageObsv() def reset(self, **kwargs): - obs = super().reset(**kwargs) + obs, info = self.env.reset(**kwargs) obs["text_message"] = self.nle_language.text_message(obs["tty_chars"]).decode("latin-1") - return obs + return obs, info def step(self, action): - obs, reward, done, info = super().step(action) + obs, reward, term, trun, info = self.env.step(action) message = self.nle_language.text_message(obs["tty_chars"]).decode("latin-1") - + done = term or trun while "--More--" in message and not done: message = message.replace("--More--", "\n") action_index = self.env.actions.index(A.MiscAction.MORE) - obs, rew, done, info = super().step(action_index) + obs, rew, term, trun, info = self.env.step(action_index) + done = term or trun add = self.nle_language.text_message(obs["tty_chars"]).decode("latin-1") message += add reward += rew obs["text_message"] = message - return obs, reward, done, info + return obs, reward, term, trun, info diff --git a/balrog/environments/nle/base.py b/balrog/environments/nle/base.py index 134050f4..0ad5028d 100644 --- a/balrog/environments/nle/base.py +++ b/balrog/environments/nle/base.py @@ -1,51 +1,102 @@ import random +import gymnasium as gym from nle import nle_language_obsv -from nle.language_wrapper.wrappers import nle_language_wrapper as language_wrapper -from nle.nethack import USEFUL_ACTIONS +from nle.language_wrapper.wrappers import nle_language_wrapper from PIL import Image from balrog.environments import Strings +from balrog.environments.nle.render import tty_render_image +from balrog.environments.nle.render_rgb import rgb_render_image -from ..minihack import ACTIONS as MINIHACK_ACTIONS -from .progress import get_progress_system -from .render import tty_render_image -from .render_rgb import rgb_render_image - -class NLELanguageWrapper(language_wrapper.NLELanguageWrapper): +class NLELanguageWrapper(gym.Wrapper): def __init__(self, env, vlm=False): - super().__init__(env, use_language_action=True) - self.nle_language = nle_language_obsv.NLELanguageObsv() - self.language_action_space = self.create_action_space() - self.env = env - self.vlm = vlm - self.done = False + super().__init__(env) + self.vlm = vlm if not vlm: self.prompt_mode = "hybrid" else: self.prompt_mode = "language" - self.progress = get_progress_system(self.env) + self.action_str_enum_map = {} + self.action_enum_index_map = {} + self.action_str_desc_map = {} + + if "minihack" in self.env.spec.id.lower(): + all_action_strs = [ + action_str + for action_strs in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.values() + for action_str in action_strs + ] + assert all(key in all_action_strs for key in MINIHACK_ACTIONS_TO_DESCR), ", ".join( + [key for key in MINIHACK_ACTIONS_TO_DESCR if key not in all_action_strs] + ) + + for action_enum in self.env.unwrapped.actions: + for action_str in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map[action_enum]: + if action_str not in MINIHACK_ACTIONS_TO_DESCR: + continue + + self.action_str_enum_map[action_str] = action_enum + self.action_enum_index_map[action_enum] = self.env.unwrapped.actions.index(action_enum) + self.action_str_desc_map[action_str] = MINIHACK_ACTIONS_TO_DESCR[action_str] + + elif "nethack" in self.env.spec.id.lower(): + all_action_strs = [ + action_str + for action_strs in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.values() + for action_str in action_strs + ] + assert all(key in all_action_strs for key in NLE_ACTIONS_TO_DESCR), ", ".join( + [key for key in NLE_ACTIONS_TO_DESCR if key not in all_action_strs] + ) + + for action_enum in self.env.unwrapped.actions: + for action_str in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map[action_enum]: + if action_str not in NLE_ACTIONS_TO_DESCR: + continue + + self.action_str_enum_map[action_str] = action_enum + self.action_enum_index_map[action_enum] = self.env.unwrapped.actions.index(action_enum) + self.action_str_desc_map[action_str] = NLE_ACTIONS_TO_DESCR[action_str] + + else: + raise ValueError(f"Unsupported environment: {self.env.spec.id}") + + self.nle_language = nle_language_obsv.NLELanguageObsv() + self.language_action_space = self.create_action_space() + self.done = False self.max_steps = self.env.unwrapped._max_episode_steps + def pre_reset(self): + pass + + def reset(self, **kwargs): + self.pre_reset() + self.obs, self.info = self.env.reset(**kwargs) + + return self.post_reset(self.obs), self.info + + def post_reset(self, nle_obs): + return self.nle_process_obs(nle_obs) + + def pre_step(self, action): + nle_action_enum = self.action_str_enum_map[action] + nle_action_idx = self.action_enum_index_map[nle_action_enum] + + return nle_action_idx + def step(self, action): - obs, reward, done, info = super().step(action) - self.done = done if not self.done else self.done - self.progress.update(obs["obs"], reward, self.done, info) - return obs, reward, self.done, info + action = self.pre_step(action) - def post_reset(self, obsv): - return self.post_step(obsv) + self.obs, reward, term, trun, self.info = self.env.step(action) - def reset(self, **kwargs): - self.progress = get_progress_system(self.env) - obsv = self.env.reset(**kwargs) - return self.post_reset(obsv) + return self.post_step(self.obs), reward, term, trun, self.info - def post_step(self, nle_obsv): - return self.nle_process_obsv(nle_obsv) + def post_step(self, nle_obs): + return self.nle_process_obs(nle_obs) @property def default_action(self): @@ -55,24 +106,24 @@ def default_action(self): return "esc" def get_text_action(self, action): - return NLELanguageWrapper.all_nle_action_map[self.env.actions[action]][0] + return self.action_str_enum_map[action] - def nle_process_obsv(self, nle_obsv): + def nle_process_obs(self, nle_obs): img = Image.fromarray(self.render("tiles")).convert("RGB") if self.vlm else None - text = self.nle_obsv_type(nle_obsv) + text = self.nle_obs_type(nle_obs) return { "text": text, "image": img, - "obs": nle_obsv, + "obs": nle_obs, } - def nle_obsv_type(self, nle_obsv): - nle_obsv = self.nle_obsv_to_language(nle_obsv) + def nle_obs_type(self, nle_obs): + nle_obs = self.nle_obs_to_language(nle_obs) if self.prompt_mode == "language": - return self.render_text(nle_obsv) + return self.render_text(nle_obs) elif self.prompt_mode == "hybrid": - return self.render_hybrid(nle_obsv) + return self.render_hybrid(nle_obs) else: raise ValueError(f'"{self.prompt_mode}" is not a valid prompt mode.') @@ -87,35 +138,13 @@ def render(self, mode="human"): tty_colors = obs[self.env.unwrapped._observation_keys.index("tty_colors")] return tty_render_image(tty_chars, tty_colors) else: - return super().render(mode) + return self.env.render(mode) def get_stats(self): - return self.progress.__dict__ + return self.info.get("episode_extra_stats", {}) def create_action_space(self): - if "minihack" in self.env.spec.id.lower(): - available_actions = {} - for action in self.env.actions: - action_key = NLELanguageWrapper.all_nle_action_map[action][0] - if action_key not in MINIHACK_ACTIONS: - continue - available_actions[action_key] = MINIHACK_ACTIONS[action_key] - - all_actions = [action for action, _ in available_actions.items()] - - else: - available_actions = [ - action_strs[0] - for action, action_strs in NLELanguageWrapper.all_nle_action_map.items() - if action in USEFUL_ACTIONS - ] - single_chars = [chr(i) for i in range(ord("a"), ord("z") + 1)] + [ - chr(i) for i in range(ord("A"), ord("Z") + 1) - ] - single_digits = [str(i) for i in range(10)] - double_digits = [f"{i:02d}" for i in range(100)] - all_actions = available_actions + single_chars + single_digits + double_digits - + all_actions = list(self.action_str_enum_map.keys()) return Strings(all_actions) def ascii_render(self, chars): @@ -128,7 +157,7 @@ def ascii_render(self, chars): result += "\n" return result - def nle_obsv_to_language(self, nle_obsv): + def nle_obs_to_language(self, nle_obsv): """Translate NLE Observation into a language observation. Args: nle_obsv (dict): NLE observation from the base environment @@ -204,3 +233,124 @@ def render_hybrid(self, nle_obsv): "long_term_context": long_term_context, "short_term_context": short_term_context, } + + +NLE_ACTIONS_TO_DESCR = { + "north": "move north", + "east": "move east", + "south": "move south", + "west": "move west", + "northeast": "move northeast", + "southeast": "move southeast", + "southwest": "move southwest", + "northwest": "move northwest", + "far north": "move far north", + "far east": "move far east", + "far south": "move far south", + "far west": "move far west", + "far northeast": "move far northeast", + "far southeast": "move far southeast", + "far southwest": "move far southwest", + "far northwest": "move far northwest", + "up": "go up a staircase", + "down": "go down a staircase (tip: you can only go down if you are standing on the stairs)", + "wait": "rest one move while doing nothing", + "more": "display more of the message (tip: ONLY ever use when current message ends with --More--)", + "annotate": "leave a note about the level", + "apply": "apply (use) a tool", + "call": "name a monster or object, or add an annotation", + "cast": "cast a spell", + "close": "close an adjacent door", + "open": "open an adjacent door", + "dip": "dip an object into something", + "drop": "drop an item", + "droptype": "drop specific item types (specify in the next prompt)", + "eat": "eat something (tip: replenish food when hungry)", + "esc": "exit menu or message", + "engrave": "engrave writing on the floor (tip: Elbereth)", + "enhance": "advance or check weapons skills", + "fire": "fire ammunition from quiver", + "fight": "fight a monster (even if you only guess one is there)", + "force": "force a lock", + "inventory": "show your inventory", + "invoke": "invoke ", + "jump": "jump to a location", + "kick": "kick an enemy or a locked door or chest", + "look": "look at what is under you", + "loot": "loot a box on the floor", + "monster": "use a monster's special ability (when polymorphed)", + "offer": "offer a sacrifice to the gods (tip: on an aligned altar)", + # "overview": "display an overview of the dungeon", + "pay": "pay your shopping bill", + "pickup": "pick up things at the current location", + "pray": "pray to the gods for help", + "puton": "put on an accessory", + "quaff": "quaff (drink) something", + "quiver": "select ammunition for quiver", + "read": "read a scroll or spellbook", + "remove": "remove an accessory", + "rub": "rub a lamp or a stone", + "search": "search for hidden doors and passages", + "swap": "swap wielded and secondary weapons", + "takeoff": "take off one piece of armor", + "takeoffall": "take off all armor", + "teleport": "teleport to another level (if you have the ability)", + "throw": "throw something (e.g. a dagger or dart)", + "travel": "travel to a specific location on the map (tip: in the next action, specify > or < for stairs, { for fountain, and _ for altar)", + "twoweapon": "toggle two-weapon combat", + "untrap": "untrap something", + "wear": "wear a piece of armor", + "wield": "wield a weapon", + "wipe": "wipe off your face", + "zap": "zap a wand", + "minus": "-", + "space": " ", + "apos": "'", + "0": "0", + "1": "1", + "2": "2", + "3": "3", + "4": "4", + "5": "5", + "6": "6", + "7": "7", + "8": "8", + "9": "9", +} + + +MINIHACK_ACTIONS_TO_DESCR = { + "north": "move north", + "east": "move east", + "south": "move south", + "west": "move west", + "northeast": "move northeast", + "southeast": "move southeast", + "southwest": "move southwest", + "northwest": "move northwest", + "far north": "move far north", + "far east": "move far east", + "far south": "move far south", + "far west": "move far west", + "far northeast": "move far northeast", + "far southeast": "move far southeast", + "far southwest": "move far southwest", + "far northwest": "move far northwest", + "up": "go up the stairs", + "down": "go down the stairs", + "wait": "rest one move while doing nothing", + "more": "display more of the message", + "apply": "apply (use) a tool", + "close": "close an adjacent door", + "open": "open an adjacent door", + "eat": "eat something", + "force": "force a lock", + "kick": "kick an enemy or a locked door or chest", + "loot": "loot a box on the floor", + "pickup": "pick up things at the current location if there are any", + "pray": "pray to the gods for help", + "puton": "put on an accessory", + "quaff": "quaff (drink) something", + "search": "search for hidden doors and passages", + "zap": "zap a wand", +} diff --git a/balrog/environments/nle/nle_env.py b/balrog/environments/nle/nle_env.py index affb61a2..de707fd2 100644 --- a/balrog/environments/nle/nle_env.py +++ b/balrog/environments/nle/nle_env.py @@ -1,16 +1,14 @@ from typing import Optional -import gym +import gymnasium as gym import nle # NOQA: F401 +import nle_progress # NOQA: F401 +from gymnasium import registry +from nle_progress import NLEProgressWrapper from balrog.environments.nle import AutoMore, NLELanguageWrapper -from balrog.environments.wrappers import GymV21CompatibilityV0, NLETimeLimit -NETHACK_ENVS = [] -for env_spec in gym.envs.registry.all(): - id = env_spec.id - if "NetHack" in id: - NETHACK_ENVS.append(id) +NETHACK_ENVS = [env_spec.id for env_spec in registry.values() if "NetHack" in env_spec.id] def make_nle_env(env_name, task, config, render_mode: Optional[str] = None): @@ -20,11 +18,8 @@ def make_nle_env(env_name, task, config, render_mode: Optional[str] = None): env = gym.make(task, **nle_kwargs) if skip_more: env = AutoMore(env) - env = NLELanguageWrapper(env, vlm=vlm) - - # wrap NLE with timeout - env = NLETimeLimit(env) - env = GymV21CompatibilityV0(env=env, render_mode=render_mode) + env = NLEProgressWrapper(env, progression_on_done_only=False) + env = NLELanguageWrapper(env, vlm=vlm) return env diff --git a/balrog/environments/nle/progress.py b/balrog/environments/nle/progress.py deleted file mode 100644 index 590ee832..00000000 --- a/balrog/environments/nle/progress.py +++ /dev/null @@ -1,169 +0,0 @@ -import json -import os -from dataclasses import dataclass, field -from typing import Optional - -with open(os.path.join(os.path.dirname(__file__), "achievements.json"), "r") as f: - ACHIEVEMENTS = json.load(f) - - -def get_progress_system(env): - if "NetHackChallenge" in env.spec.id: - return Progress() - elif "MiniHack" in env.spec.id: - return BaseProgress() - else: - raise ValueError(f"Unsupported environment type: {type(env)}") - - -@dataclass -class Progress: - episode_return: float = 0.0 - score: int = 0 - depth: int = 1 - gold: int = 0 - experience_level: int = 1 - time: int = 0 - dlvl_list: list = field(default_factory=list) - xplvl_list: list = field(default_factory=list) - highest_achievement: Optional[str] = None - progression: float = 0.0 - end_reason: Optional[str] = None - - def update(self, nle_obsv, reward, done, info): - """ - Update the progress of the player given a message and stats. - - Returns: - float: The progression of the player. - """ - self.episode_return += reward - - stats = self._update_stats(nle_obsv["blstats"]) - - if done: - tty_chars = bytes(nle_obsv["tty_chars"].reshape(-1)).decode(errors="ignore") - self.end_reason = self._get_end_reason(tty_chars, info["end_status"]) - - xp = self._get_xp(stats) - if xp not in self.xplvl_list and xp in ACHIEVEMENTS.keys(): - self.xplvl_list.append(xp) - if ACHIEVEMENTS[xp] > self.progression: - self.progression = ACHIEVEMENTS[xp] - self.highest_achievement = xp - - dlvl = self._get_dlvl(stats) - if dlvl not in self.dlvl_list and dlvl in ACHIEVEMENTS.keys(): - self.dlvl_list.append(dlvl) - if ACHIEVEMENTS[dlvl] > self.progression: - self.progression = ACHIEVEMENTS[dlvl] - self.highest_achievement = dlvl - - def _update_stats(self, blstats): - # see: https://arxiv.org/pdf/2006.13760#page=16 - stats_names = [ - "x_pos", - "y_pos", - "strength_percentage", - "strength", - "dexterity", - "constitution", - "intelligence", - "wisdom", - "charisma", - "score", - "hitpoints", - "max_hitpoints", - "depth", - "gold", - "energy", - "max_energy", - "armor_class", - "monster_level", - "experience_level", - "experience_points", - "time", - "hunger_state", - "carrying_capacity", - "dungeon_number", - "level_number", - ] - stats = {name: value for name, value in zip(stats_names, blstats)} - - self.score = int(stats["score"]) - self.depth = int(stats["depth"]) - self.gold = int(stats["gold"]) - self.experience_level = int(stats["experience_level"]) - self.time = int(stats["time"]) - - return stats - - def _get_end_reason(self, tty_chars, end_status): - end_reason_words = tty_chars.replace("You made the top ten list!", "").split() - - if len(end_reason_words) > 7 and end_reason_words[7].startswith("Agent"): - end_reason = " ".join(end_reason_words[8:-2]) - else: - end_reason = " ".join(end_reason_words[7:-2]) - sentences = end_reason.split(".") - first_sentence = sentences[0].split() - - if "in" in first_sentence: - index_in = first_sentence.index("in") - first_part = " ".join(first_sentence[:index_in]) - else: - first_part = " ".join(first_sentence) - - remaining_sentences = ".".join(sentences[1:]).strip() - end_reason_final = f"{end_status.name}: " f"{first_part}." f" {remaining_sentences}".strip() - - return end_reason_final - - def _get_dlvl(self, stats): - """ - Get the dungeong lvl from the stats string. - - Args: - string (str): The stats string. - Returns: - str: The dungeong lvl - """ - # dlvl = string.split("$")[0] - dlvl = f"Dlvl:{stats['depth']}" - return dlvl - - def _get_xp(self, stats): - """ - Get the experience points from the stats string. - - Args: - string (str): The stats string. - Returns: - str: The experience points - """ - xp = f"Xp:{stats['experience_level']}" - return xp - - -class BaseProgress: - episode_return: float = 0.0 - progression: float = 0.0 - end_reason: Optional[str] = None - - def update(self, nle_obsv, reward, done, info): - """ - Update the progress of the player given a message and stats. - - Args: - message (str): The message to check for achievements. - stats (str): The stats to check for achievements. - - Returns: - float: The progression of the player. - """ - self.episode_return += reward - if reward >= 1.0: - self.progression = 1.0 - else: - self.progression = 0.0 - self.end_reason = info["end_status"] From bfa6d116c4b1ad63c6850809daab40090fa623b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Cupia=C5=82?= Date: Sat, 2 Aug 2025 14:02:33 +0200 Subject: [PATCH 2/7] update setup py --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 5999890f..92593dd0 100644 --- a/setup.py +++ b/setup.py @@ -29,8 +29,9 @@ "crafter", "gym==0.23", "requests", + "nle-progress @ git+https://github.com/BartekCupial/nle-progress.git", "balrog-nle", - "minihack @ git+https://github.com/balrog-ai/minihack.git", + "minihack @ git+https://github.com/BartekCupial/minihack.git", "textworld @ git+https://github.com/balrog-ai/TextWorld.git", "tatsu==5.8.3", "minigrid @ git+https://github.com/BartekCupial/Minigrid.git", From 8ada2bb2c899373686a71dbe99b4d2b4bf82bd28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Cupia=C5=82?= Date: Wed, 13 Aug 2025 14:53:54 +0200 Subject: [PATCH 3/7] use og progress --- balrog/environments/minihack/minihack_env.py | 2 - .../minihack/minihack_progress.py | 45 ----- balrog/environments/nle/base.py | 117 ++++++------ balrog/environments/nle/nle_env.py | 2 - balrog/environments/nle/progress.py | 169 ++++++++++++++++++ setup.py | 1 - 6 files changed, 231 insertions(+), 105 deletions(-) delete mode 100644 balrog/environments/minihack/minihack_progress.py create mode 100644 balrog/environments/nle/progress.py diff --git a/balrog/environments/minihack/minihack_env.py b/balrog/environments/minihack/minihack_env.py index c6e224ef..dbdc59d7 100644 --- a/balrog/environments/minihack/minihack_env.py +++ b/balrog/environments/minihack/minihack_env.py @@ -4,7 +4,6 @@ from gymnasium import registry import minihack # NOQA: F401 -from balrog.environments.minihack.minihack_progress import MiniHackProgressWrapper from balrog.environments.nle import AutoMore, NLELanguageWrapper MINIHACK_ENVS = [env_spec.id for env_spec in registry.values() if "MiniHack" in env_spec.id] @@ -30,7 +29,6 @@ def make_minihack_env(env_name, task, config, render_mode: Optional[str] = None) if skip_more: env = AutoMore(env) - env = MiniHackProgressWrapper(env, progression_on_done_only=False) env = NLELanguageWrapper(env, vlm=vlm) return env diff --git a/balrog/environments/minihack/minihack_progress.py b/balrog/environments/minihack/minihack_progress.py deleted file mode 100644 index 98381a32..00000000 --- a/balrog/environments/minihack/minihack_progress.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import Optional - -import gymnasium as gym - - -class MiniHackProgress: - episode_return: float = 0.0 - progression: float = 0.0 - end_reason: Optional[str] = None - - def update(self, reward, info): - self.episode_return += reward - if reward >= 1.0: - self.progression = 1.0 - else: - self.progression = 0.0 - self.end_reason = info["end_status"] - - -class MiniHackProgressWrapper(gym.Wrapper): - def __init__(self, env, progression_on_done_only: bool = True): - super().__init__(env) - self.progression_on_done_only = progression_on_done_only - - def reset(self, **kwargs): - self.progress = MiniHackProgress() - return self.env.reset(**kwargs) - - def step(self, action): - obs, reward, term, trun, info = self.env.step(action) - self.progress.update(reward, info) - - done = term or trun - if not self.progression_on_done_only or done: - info["episode_extra_stats"] = self.episode_extra_stats(info) - - return obs, reward, term, trun, info - - def episode_extra_stats(self, info): - extra_stats = info.get("episode_extra_stats", {}) - new_extra_stats = { - "progression": self.progress.progression, - } - - return {**extra_stats, **new_extra_stats} diff --git a/balrog/environments/nle/base.py b/balrog/environments/nle/base.py index 0ad5028d..5cd337b3 100644 --- a/balrog/environments/nle/base.py +++ b/balrog/environments/nle/base.py @@ -3,9 +3,11 @@ import gymnasium as gym from nle import nle_language_obsv from nle.language_wrapper.wrappers import nle_language_wrapper +from nle.nethack import USEFUL_ACTIONS from PIL import Image from balrog.environments import Strings +from balrog.environments.nle.progress import get_progress_system from balrog.environments.nle.render import tty_render_image from balrog.environments.nle.render_rgb import rgb_render_image @@ -13,71 +15,27 @@ class NLELanguageWrapper(gym.Wrapper): def __init__(self, env, vlm=False): super().__init__(env) - + self.nle_language = nle_language_obsv.NLELanguageObsv() + self.language_action_space = self.create_action_space() self.vlm = vlm + self.done = False + if not vlm: self.prompt_mode = "hybrid" else: self.prompt_mode = "language" - self.action_str_enum_map = {} - self.action_enum_index_map = {} - self.action_str_desc_map = {} - - if "minihack" in self.env.spec.id.lower(): - all_action_strs = [ - action_str - for action_strs in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.values() - for action_str in action_strs - ] - assert all(key in all_action_strs for key in MINIHACK_ACTIONS_TO_DESCR), ", ".join( - [key for key in MINIHACK_ACTIONS_TO_DESCR if key not in all_action_strs] - ) - - for action_enum in self.env.unwrapped.actions: - for action_str in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map[action_enum]: - if action_str not in MINIHACK_ACTIONS_TO_DESCR: - continue - - self.action_str_enum_map[action_str] = action_enum - self.action_enum_index_map[action_enum] = self.env.unwrapped.actions.index(action_enum) - self.action_str_desc_map[action_str] = MINIHACK_ACTIONS_TO_DESCR[action_str] - - elif "nethack" in self.env.spec.id.lower(): - all_action_strs = [ - action_str - for action_strs in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.values() - for action_str in action_strs - ] - assert all(key in all_action_strs for key in NLE_ACTIONS_TO_DESCR), ", ".join( - [key for key in NLE_ACTIONS_TO_DESCR if key not in all_action_strs] - ) - - for action_enum in self.env.unwrapped.actions: - for action_str in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map[action_enum]: - if action_str not in NLE_ACTIONS_TO_DESCR: - continue - - self.action_str_enum_map[action_str] = action_enum - self.action_enum_index_map[action_enum] = self.env.unwrapped.actions.index(action_enum) - self.action_str_desc_map[action_str] = NLE_ACTIONS_TO_DESCR[action_str] - - else: - raise ValueError(f"Unsupported environment: {self.env.spec.id}") - - self.nle_language = nle_language_obsv.NLELanguageObsv() - self.language_action_space = self.create_action_space() - self.done = False + self.progress = get_progress_system(self.env) self.max_steps = self.env.unwrapped._max_episode_steps def pre_reset(self): - pass + self.progress = get_progress_system(self.env) def reset(self, **kwargs): self.pre_reset() - self.obs, self.info = self.env.reset(**kwargs) + obs, info = self.env.reset(**kwargs) - return self.post_reset(self.obs), self.info + return self.post_reset(obs), info def post_reset(self, nle_obs): return self.nle_process_obs(nle_obs) @@ -91,9 +49,13 @@ def pre_step(self, action): def step(self, action): action = self.pre_step(action) - self.obs, reward, term, trun, self.info = self.env.step(action) + obs, reward, term, trun, info = self.env.step(action) + + done = term or trun + self.done = done if not self.done else self.done + self.progress.update(obs, reward, self.done, info) - return self.post_step(self.obs), reward, term, trun, self.info + return self.post_step(obs), reward, term, trun, info def post_step(self, nle_obs): return self.nle_process_obs(nle_obs) @@ -141,9 +103,54 @@ def render(self, mode="human"): return self.env.render(mode) def get_stats(self): - return self.info.get("episode_extra_stats", {}) + return self.progress.__dict__ def create_action_space(self): + self.action_str_enum_map = {} + self.action_enum_index_map = {} + self.action_str_desc_map = {} + + if "minihack" in self.env.spec.id.lower(): + all_action_strs = [ + action_str + for action_strs in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.values() + for action_str in action_strs + ] + assert all(key in all_action_strs for key in MINIHACK_ACTIONS_TO_DESCR), ", ".join( + [key for key in MINIHACK_ACTIONS_TO_DESCR if key not in all_action_strs] + ) + + for action_enum in self.env.unwrapped.actions: + for action_str in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map[action_enum]: + if action_str not in MINIHACK_ACTIONS_TO_DESCR: + continue + + self.action_str_enum_map[action_str] = action_enum + self.action_enum_index_map[action_enum] = self.env.unwrapped.actions.index(action_enum) + self.action_str_desc_map[action_str] = MINIHACK_ACTIONS_TO_DESCR[action_str] + + elif "nethack" in self.env.spec.id.lower(): + all_action_strs = [ + action_str + for action_strs in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.values() + for action_str in action_strs + ] + assert all(key in all_action_strs for key in NLE_ACTIONS_TO_DESCR), ", ".join( + [key for key in NLE_ACTIONS_TO_DESCR if key not in all_action_strs] + ) + + for action_enum in self.env.unwrapped.actions: + for action_str in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map[action_enum]: + if action_str not in NLE_ACTIONS_TO_DESCR: + continue + + self.action_str_enum_map[action_str] = action_enum + self.action_enum_index_map[action_enum] = self.env.unwrapped.actions.index(action_enum) + self.action_str_desc_map[action_str] = NLE_ACTIONS_TO_DESCR[action_str] + + else: + raise ValueError(f"Unsupported environment: {self.env.spec.id}") + all_actions = list(self.action_str_enum_map.keys()) return Strings(all_actions) diff --git a/balrog/environments/nle/nle_env.py b/balrog/environments/nle/nle_env.py index de707fd2..1652511d 100644 --- a/balrog/environments/nle/nle_env.py +++ b/balrog/environments/nle/nle_env.py @@ -4,7 +4,6 @@ import nle # NOQA: F401 import nle_progress # NOQA: F401 from gymnasium import registry -from nle_progress import NLEProgressWrapper from balrog.environments.nle import AutoMore, NLELanguageWrapper @@ -19,7 +18,6 @@ def make_nle_env(env_name, task, config, render_mode: Optional[str] = None): if skip_more: env = AutoMore(env) - env = NLEProgressWrapper(env, progression_on_done_only=False) env = NLELanguageWrapper(env, vlm=vlm) return env diff --git a/balrog/environments/nle/progress.py b/balrog/environments/nle/progress.py new file mode 100644 index 00000000..590ee832 --- /dev/null +++ b/balrog/environments/nle/progress.py @@ -0,0 +1,169 @@ +import json +import os +from dataclasses import dataclass, field +from typing import Optional + +with open(os.path.join(os.path.dirname(__file__), "achievements.json"), "r") as f: + ACHIEVEMENTS = json.load(f) + + +def get_progress_system(env): + if "NetHackChallenge" in env.spec.id: + return Progress() + elif "MiniHack" in env.spec.id: + return BaseProgress() + else: + raise ValueError(f"Unsupported environment type: {type(env)}") + + +@dataclass +class Progress: + episode_return: float = 0.0 + score: int = 0 + depth: int = 1 + gold: int = 0 + experience_level: int = 1 + time: int = 0 + dlvl_list: list = field(default_factory=list) + xplvl_list: list = field(default_factory=list) + highest_achievement: Optional[str] = None + progression: float = 0.0 + end_reason: Optional[str] = None + + def update(self, nle_obsv, reward, done, info): + """ + Update the progress of the player given a message and stats. + + Returns: + float: The progression of the player. + """ + self.episode_return += reward + + stats = self._update_stats(nle_obsv["blstats"]) + + if done: + tty_chars = bytes(nle_obsv["tty_chars"].reshape(-1)).decode(errors="ignore") + self.end_reason = self._get_end_reason(tty_chars, info["end_status"]) + + xp = self._get_xp(stats) + if xp not in self.xplvl_list and xp in ACHIEVEMENTS.keys(): + self.xplvl_list.append(xp) + if ACHIEVEMENTS[xp] > self.progression: + self.progression = ACHIEVEMENTS[xp] + self.highest_achievement = xp + + dlvl = self._get_dlvl(stats) + if dlvl not in self.dlvl_list and dlvl in ACHIEVEMENTS.keys(): + self.dlvl_list.append(dlvl) + if ACHIEVEMENTS[dlvl] > self.progression: + self.progression = ACHIEVEMENTS[dlvl] + self.highest_achievement = dlvl + + def _update_stats(self, blstats): + # see: https://arxiv.org/pdf/2006.13760#page=16 + stats_names = [ + "x_pos", + "y_pos", + "strength_percentage", + "strength", + "dexterity", + "constitution", + "intelligence", + "wisdom", + "charisma", + "score", + "hitpoints", + "max_hitpoints", + "depth", + "gold", + "energy", + "max_energy", + "armor_class", + "monster_level", + "experience_level", + "experience_points", + "time", + "hunger_state", + "carrying_capacity", + "dungeon_number", + "level_number", + ] + stats = {name: value for name, value in zip(stats_names, blstats)} + + self.score = int(stats["score"]) + self.depth = int(stats["depth"]) + self.gold = int(stats["gold"]) + self.experience_level = int(stats["experience_level"]) + self.time = int(stats["time"]) + + return stats + + def _get_end_reason(self, tty_chars, end_status): + end_reason_words = tty_chars.replace("You made the top ten list!", "").split() + + if len(end_reason_words) > 7 and end_reason_words[7].startswith("Agent"): + end_reason = " ".join(end_reason_words[8:-2]) + else: + end_reason = " ".join(end_reason_words[7:-2]) + sentences = end_reason.split(".") + first_sentence = sentences[0].split() + + if "in" in first_sentence: + index_in = first_sentence.index("in") + first_part = " ".join(first_sentence[:index_in]) + else: + first_part = " ".join(first_sentence) + + remaining_sentences = ".".join(sentences[1:]).strip() + end_reason_final = f"{end_status.name}: " f"{first_part}." f" {remaining_sentences}".strip() + + return end_reason_final + + def _get_dlvl(self, stats): + """ + Get the dungeong lvl from the stats string. + + Args: + string (str): The stats string. + Returns: + str: The dungeong lvl + """ + # dlvl = string.split("$")[0] + dlvl = f"Dlvl:{stats['depth']}" + return dlvl + + def _get_xp(self, stats): + """ + Get the experience points from the stats string. + + Args: + string (str): The stats string. + Returns: + str: The experience points + """ + xp = f"Xp:{stats['experience_level']}" + return xp + + +class BaseProgress: + episode_return: float = 0.0 + progression: float = 0.0 + end_reason: Optional[str] = None + + def update(self, nle_obsv, reward, done, info): + """ + Update the progress of the player given a message and stats. + + Args: + message (str): The message to check for achievements. + stats (str): The stats to check for achievements. + + Returns: + float: The progression of the player. + """ + self.episode_return += reward + if reward >= 1.0: + self.progression = 1.0 + else: + self.progression = 0.0 + self.end_reason = info["end_status"] diff --git a/setup.py b/setup.py index 92593dd0..4ff890f7 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,6 @@ "crafter", "gym==0.23", "requests", - "nle-progress @ git+https://github.com/BartekCupial/nle-progress.git", "balrog-nle", "minihack @ git+https://github.com/BartekCupial/minihack.git", "textworld @ git+https://github.com/balrog-ai/TextWorld.git", From 33e9b72583c5dac66e1ae324575f71123b31d40f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Cupia=C5=82?= Date: Wed, 13 Aug 2025 14:58:23 +0200 Subject: [PATCH 4/7] simplify render --- balrog/environments/nle/base.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/balrog/environments/nle/base.py b/balrog/environments/nle/base.py index 5cd337b3..6c951c0b 100644 --- a/balrog/environments/nle/base.py +++ b/balrog/environments/nle/base.py @@ -90,17 +90,14 @@ def nle_obs_type(self, nle_obs): raise ValueError(f'"{self.prompt_mode}" is not a valid prompt mode.') def render(self, mode="human"): - if mode == "tiles": + if mode in ("tiles", "tty_image"): obs = self.env.unwrapped.last_observation - glyphs = obs[self.env.unwrapped._observation_keys.index("glyphs")] - return rgb_render_image(glyphs) - elif mode == "tty_image": - obs = self.env.unwrapped.last_observation - tty_chars = obs[self.env.unwrapped._observation_keys.index("tty_chars")] - tty_colors = obs[self.env.unwrapped._observation_keys.index("tty_colors")] - return tty_render_image(tty_chars, tty_colors) - else: - return self.env.render(mode) + key_idx = self.env.unwrapped._observation_keys.index + if mode == "tiles": + return rgb_render_image(obs[key_idx("glyphs")]) + else: + return tty_render_image(obs[key_idx("tty_chars")], obs[key_idx("tty_colors")]) + return self.env.render(mode) def get_stats(self): return self.progress.__dict__ From 97f10855ced8b019a100245949060e4ad76e082e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Cupia=C5=82?= Date: Wed, 13 Aug 2025 15:07:31 +0200 Subject: [PATCH 5/7] small refactor for action mappings --- balrog/environments/nle/base.py | 93 +++++++++++++++------------------ 1 file changed, 43 insertions(+), 50 deletions(-) diff --git a/balrog/environments/nle/base.py b/balrog/environments/nle/base.py index 6c951c0b..a021ece2 100644 --- a/balrog/environments/nle/base.py +++ b/balrog/environments/nle/base.py @@ -1,4 +1,5 @@ import random +from typing import Any, Dict import gymnasium as gym from nle import nle_language_obsv @@ -16,7 +17,6 @@ class NLELanguageWrapper(gym.Wrapper): def __init__(self, env, vlm=False): super().__init__(env) self.nle_language = nle_language_obsv.NLELanguageObsv() - self.language_action_space = self.create_action_space() self.vlm = vlm self.done = False @@ -25,9 +25,51 @@ def __init__(self, env, vlm=False): else: self.prompt_mode = "language" + self._setup_action_space() self.progress = get_progress_system(self.env) self.max_steps = self.env.unwrapped._max_episode_steps + def _setup_action_space(self): + """Determines the action set and initializes the action space.""" + env_id_lower = self.env.spec.id.lower() + if "minihack" in env_id_lower: + actions_to_descr = MINIHACK_ACTIONS_TO_DESCR + elif "nethack" in env_id_lower: + actions_to_descr = NLE_ACTIONS_TO_DESCR + else: + raise ValueError(f"Unsupported environment: {self.env.spec.id}") + + self._initialize_action_mappings(actions_to_descr) + self.language_action_space = Strings(list(self.action_str_enum_map.keys())) + + def _initialize_action_mappings(self, actions_to_descr: Dict[str, str]): + """Builds mappings from action strings to NLE enums and indices.""" + self.action_str_enum_map: Dict[str, Any] = {} + self.action_enum_index_map: Dict[Any, int] = {} + self.action_str_desc_map: Dict[str, str] = {} + + # Pre-calculate all possible action strings from the base environment + all_action_strs = { + action_str + for action_strs in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.values() + for action_str in action_strs + } + + # Validate that all keys in our description map are valid actions + missing_keys = [key for key in actions_to_descr if key not in all_action_strs] + if missing_keys: + raise KeyError(f"Action keys not found in NLE's action map: {', '.join(missing_keys)}") + + for action_enum in self.env.unwrapped.actions: + action_index = self.env.unwrapped.actions.index(action_enum) + possible_strs = nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.get(action_enum, []) + + for action_str in possible_strs: + if action_str in actions_to_descr: + self.action_str_enum_map[action_str] = action_enum + self.action_enum_index_map[action_enum] = action_index + self.action_str_desc_map[action_str] = actions_to_descr[action_str] + def pre_reset(self): self.progress = get_progress_system(self.env) @@ -102,55 +144,6 @@ def render(self, mode="human"): def get_stats(self): return self.progress.__dict__ - def create_action_space(self): - self.action_str_enum_map = {} - self.action_enum_index_map = {} - self.action_str_desc_map = {} - - if "minihack" in self.env.spec.id.lower(): - all_action_strs = [ - action_str - for action_strs in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.values() - for action_str in action_strs - ] - assert all(key in all_action_strs for key in MINIHACK_ACTIONS_TO_DESCR), ", ".join( - [key for key in MINIHACK_ACTIONS_TO_DESCR if key not in all_action_strs] - ) - - for action_enum in self.env.unwrapped.actions: - for action_str in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map[action_enum]: - if action_str not in MINIHACK_ACTIONS_TO_DESCR: - continue - - self.action_str_enum_map[action_str] = action_enum - self.action_enum_index_map[action_enum] = self.env.unwrapped.actions.index(action_enum) - self.action_str_desc_map[action_str] = MINIHACK_ACTIONS_TO_DESCR[action_str] - - elif "nethack" in self.env.spec.id.lower(): - all_action_strs = [ - action_str - for action_strs in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.values() - for action_str in action_strs - ] - assert all(key in all_action_strs for key in NLE_ACTIONS_TO_DESCR), ", ".join( - [key for key in NLE_ACTIONS_TO_DESCR if key not in all_action_strs] - ) - - for action_enum in self.env.unwrapped.actions: - for action_str in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map[action_enum]: - if action_str not in NLE_ACTIONS_TO_DESCR: - continue - - self.action_str_enum_map[action_str] = action_enum - self.action_enum_index_map[action_enum] = self.env.unwrapped.actions.index(action_enum) - self.action_str_desc_map[action_str] = NLE_ACTIONS_TO_DESCR[action_str] - - else: - raise ValueError(f"Unsupported environment: {self.env.spec.id}") - - all_actions = list(self.action_str_enum_map.keys()) - return Strings(all_actions) - def ascii_render(self, chars): rows, cols = chars.shape result = "" From 0d07d3fb6cc163cd3c4de5d34f95119ee8d56db9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Cupia=C5=82?= Date: Wed, 13 Aug 2025 19:21:42 +0200 Subject: [PATCH 6/7] change nle language wrapper so we can set render_mode --- balrog/config/config.yaml | 1 + balrog/environments/minihack/minihack_env.py | 1 + balrog/environments/nle/base.py | 29 +++++++++++++------- balrog/environments/nle/nle_env.py | 2 +- balrog/evaluator.py | 2 +- 5 files changed, 23 insertions(+), 12 deletions(-) diff --git a/balrog/config/config.yaml b/balrog/config/config.yaml index fa8dd367..dd09eec0 100644 --- a/balrog/config/config.yaml +++ b/balrog/config/config.yaml @@ -41,6 +41,7 @@ envs: names: babyai-babaisai-textworld-crafter-nle-minihack # Environments to evaluate, separated by hyphens env_kwargs: seed: null # Random seed; null means a random seed is used + render_mode: null nle_kwargs: character: "@" # Character representing the agent in NLE max_episode_steps: 100_000 # Max steps per episode in NLE diff --git a/balrog/environments/minihack/minihack_env.py b/balrog/environments/minihack/minihack_env.py index dbdc59d7..e61b6c3b 100644 --- a/balrog/environments/minihack/minihack_env.py +++ b/balrog/environments/minihack/minihack_env.py @@ -25,6 +25,7 @@ def make_minihack_env(env_name, task, config, render_mode: Optional[str] = None) "tty_colors", ], **minihack_kwargs, + render_mode=render_mode, ) if skip_more: env = AutoMore(env) diff --git a/balrog/environments/nle/base.py b/balrog/environments/nle/base.py index a021ece2..3c63f7be 100644 --- a/balrog/environments/nle/base.py +++ b/balrog/environments/nle/base.py @@ -102,6 +102,16 @@ def step(self, action): def post_step(self, nle_obs): return self.nle_process_obs(nle_obs) + def render(self): + mode = self.env.render_mode + + if mode == "tiles": + return self.render_tiles_from_obs(self.env.unwrapped._last_obs) + elif mode == "tty_image": + return self.render_tty_from_obs(self.env.unwrapped._last_obs) + else: + return self.env.render() + @property def default_action(self): if "minihack" in self.env.spec.id.lower(): @@ -113,7 +123,7 @@ def get_text_action(self, action): return self.action_str_enum_map[action] def nle_process_obs(self, nle_obs): - img = Image.fromarray(self.render("tiles")).convert("RGB") if self.vlm else None + img = Image.fromarray(self.render_tiles_from_obs(nle_obs)).convert("RGB") if self.vlm else None text = self.nle_obs_type(nle_obs) return { @@ -131,15 +141,14 @@ def nle_obs_type(self, nle_obs): else: raise ValueError(f'"{self.prompt_mode}" is not a valid prompt mode.') - def render(self, mode="human"): - if mode in ("tiles", "tty_image"): - obs = self.env.unwrapped.last_observation - key_idx = self.env.unwrapped._observation_keys.index - if mode == "tiles": - return rgb_render_image(obs[key_idx("glyphs")]) - else: - return tty_render_image(obs[key_idx("tty_chars")], obs[key_idx("tty_colors")]) - return self.env.render(mode) + def render_tiles_from_obs(self, obs): + # Custom tiles rendering from latest observation + key_idx = self.env.unwrapped._observation_keys.index + return rgb_render_image(obs[key_idx("glyphs")]) + + def render_tty_from_obs(self, obs): + key_idx = self.env.unwrapped._observation_keys.index + return tty_render_image(obs[key_idx("tty_chars")], obs[key_idx("tty_colors")]) def get_stats(self): return self.progress.__dict__ diff --git a/balrog/environments/nle/nle_env.py b/balrog/environments/nle/nle_env.py index 1652511d..aba509db 100644 --- a/balrog/environments/nle/nle_env.py +++ b/balrog/environments/nle/nle_env.py @@ -14,7 +14,7 @@ def make_nle_env(env_name, task, config, render_mode: Optional[str] = None): nle_kwargs = dict(config.envs.nle_kwargs) skip_more = nle_kwargs.pop("skip_more", False) vlm = True if config.agent.max_image_history > 0 else False - env = gym.make(task, **nle_kwargs) + env = gym.make(task, **nle_kwargs, render_mode=render_mode) if skip_more: env = AutoMore(env) diff --git a/balrog/evaluator.py b/balrog/evaluator.py index a9a937a9..76af039f 100644 --- a/balrog/evaluator.py +++ b/balrog/evaluator.py @@ -254,7 +254,7 @@ def run_episode(self, task, agent, process_num=None, position=0, episode_idx=0): Returns: dict: Log of the episode containing statistics and results. """ - env = make_env(self.env_name, task, self.config) + env = make_env(self.env_name, task, self.config, render_mode=self.config.envs.env_kwargs.render_mode) agent.reset() seed = self.config.envs.env_kwargs.seed From 8851290dd74659bf88f5944d2510e2a204a9d491 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Cupia=C5=82?= Date: Wed, 13 Aug 2025 19:29:48 +0200 Subject: [PATCH 7/7] quick script for playing with nle --- play.py | 118 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 play.py diff --git a/play.py b/play.py new file mode 100644 index 00000000..970d399a --- /dev/null +++ b/play.py @@ -0,0 +1,118 @@ +import os +import random +import readline +import timeit +from datetime import datetime +from functools import partial +from pathlib import Path +from pprint import pprint + +import hydra +import numpy as np +from hydra.utils import get_original_cwd +from omegaconf import DictConfig + +from balrog.agents import AgentFactory +from balrog.environments import make_env +from balrog.evaluator import EvaluatorManager +from balrog.utils import get_unique_seed, setup_environment + + +def completer(text, state, commands=[]): + options = [cmd for cmd in commands if cmd.startswith(text)] + return options[state] if state < len(options) else None + + +def setup_autocomplete(completer_fn): + readline.parse_and_bind("tab: complete") + print("Type commands and use TAB to autocomplete.") + print("To see strategies use command: `help`") + readline.set_completer(completer_fn) + + +def get_action(env, obs): + language_action_space = env.get_wrapper_attr("language_action_space") + setup_autocomplete(partial(completer, commands=language_action_space)) + + while True: + command = input("> ") + + if command == "help": + print(language_action_space) + continue + else: + try: + assert command in language_action_space + break + except Exception: + print(f"Selected action '{command}' is not in action list. Please try again.") + continue + + return command + + +@hydra.main(config_path="balrog/config", config_name="config", version_base="1.1") +def main(config: DictConfig): + original_cwd = get_original_cwd() + setup_environment(original_cwd=original_cwd) + + # Determine output directory + if config.eval.resume_from is not None: + output_dir = config.eval.resume_from + else: + now = datetime.now() + timestamp = now.strftime("%Y-%m-%d_%H-%M-%S") + run_name = f"{timestamp}_{config.agent.type}_{config.client.model_id.replace('/', '_')}" + output_dir = os.path.join(config.eval.output_dir, run_name) + + # Create the directory if it doesn't exist + Path(output_dir).mkdir(parents=True, exist_ok=True) + + env_name = random.choice(config.envs.names.split("-")) + task = random.choice(config.tasks[f"{env_name}_tasks"]) + print(f"Selected environment: {env_name}, task: {task}") + + env = make_env(env_name, task, config, render_mode="human") + + seed = config.envs.env_kwargs.seed + if seed is None: + seed = get_unique_seed(process_num=None, episode_idx=0) + random.seed(seed) + np.random.seed(seed) + obs, info = env.reset(seed=seed) + env.render() + + steps = 0 + reward = 0.0 + total_reward = 0.0 + action = None + + total_start_time = timeit.default_timer() + start_time = total_start_time + + while True: + action = get_action(env, obs) + if action is None: + break + + obs, reward, terminated, truncated, info = env.step(action) + env.render() + + steps += 1 + total_reward += reward + + if not (terminated or truncated): + continue + + time_delta = timeit.default_timer() - start_time + + print("Final reward:", reward) + print(f"Total reward: {total_reward}, Steps: {steps}, SPS: {steps / time_delta}", total_reward) + pprint.pprint(info) + + break + env.close() + + +if __name__ == "__main__": + main()