diff --git a/balrog/config/config.yaml b/balrog/config/config.yaml index fa8dd367..dd09eec0 100644 --- a/balrog/config/config.yaml +++ b/balrog/config/config.yaml @@ -41,6 +41,7 @@ envs: names: babyai-babaisai-textworld-crafter-nle-minihack # Environments to evaluate, separated by hyphens env_kwargs: seed: null # Random seed; null means a random seed is used + render_mode: null nle_kwargs: character: "@" # Character representing the agent in NLE max_episode_steps: 100_000 # Max steps per episode in NLE diff --git a/balrog/environments/env_wrapper.py b/balrog/environments/env_wrapper.py index 161f3a93..1cf0f299 100644 --- a/balrog/environments/env_wrapper.py +++ b/balrog/environments/env_wrapper.py @@ -16,7 +16,7 @@ def __init__(self, env, env_name, task_name): @property def max_steps(self): - return self.env.max_steps + return int(self.env.max_steps) def reset(self, **kwargs): obs, info = self.env.reset(**kwargs) @@ -55,7 +55,7 @@ def get_instruction_prompt(self, instructions=None): if self.env_name == "nle": from balrog.environments.nle import get_instruction_prompt - return get_instruction_prompt() + return get_instruction_prompt(self.env, self.task_name) elif self.env_name == "minihack": from balrog.environments.minihack import get_instruction_prompt diff --git a/balrog/environments/minihack/__init__.py b/balrog/environments/minihack/__init__.py index eafa5209..18d563ce 100644 --- a/balrog/environments/minihack/__init__.py +++ b/balrog/environments/minihack/__init__.py @@ -1,51 +1,5 @@ from nle.language_wrapper.wrappers.nle_language_wrapper import NLELanguageWrapper -ACTIONS = { - "north": "move north", - "east": "move east", - "south": "move south", - "west": "move west", - "northeast": "move northeast", - "southeast": "move southeast", - "southwest": "move southwest", - "northwest": "move northwest", - "far north": "move far north", - "far east": "move far east", - "far south": "move far south", - "far west": "move far west", - "far northeast": "move far northeast", - "far southeast": "move far southeast", - "far southwest": "move far southwest", - "far northwest": "move far northwest", - "up": "go up the stairs", - "down": "go down the stairs", - "wait": "rest one move while doing nothing", - "more": "display more of the message", - "apply": "apply (use) a tool", - "close": "close an adjacent door", - "open": "open an adjacent door", - "eat": "eat something", - "force": "force a lock", - "kick": "kick an enemy or a locked door or chest", - "loot": "loot a box on the floor", - "pickup": "pick up things at the current location if there are any", - "pray": "pray to the gods for help", - "puton": "put on an accessory", - "quaff": "quaff (drink) something", - "search": "search for hidden doors and passages", - "zap": "zap a wand", -} - - -def get_available_actions(env): - available_actions = {} - for action in env.actions: - action_key = NLELanguageWrapper.all_nle_action_map[action][0] - if action_key not in ACTIONS: - continue - available_actions[action_key] = ACTIONS[action_key] - return available_actions - def get_instruction_prompt(env, task="MiniHack-ExploreMaze-Hard-Mapped-v0"): if "corridor" in task.lower(): @@ -57,7 +11,7 @@ def get_instruction_prompt(env, task="MiniHack-ExploreMaze-Hard-Mapped-v0"): else: goal = "Your goal is to get as far as possible in the game." - available_actions = get_available_actions(env) + available_actions = env.action_str_desc_map action_strings = ",\n".join(f"{action}: {description}" for action, description in available_actions.items()) instruction_prompt = f""" You are an agent playing MiniHack. The following are the possible actions you can take in the game, followed by a short description of each action: diff --git a/balrog/environments/minihack/minihack_env.py b/balrog/environments/minihack/minihack_env.py index 98287315..e61b6c3b 100644 --- a/balrog/environments/minihack/minihack_env.py +++ b/balrog/environments/minihack/minihack_env.py @@ -1,16 +1,12 @@ from typing import Optional -import gym -import minihack # NOQA: F401 +import gymnasium as gym +from gymnasium import registry +import minihack # NOQA: F401 from balrog.environments.nle import AutoMore, NLELanguageWrapper -from balrog.environments.wrappers import GymV21CompatibilityV0, NLETimeLimit -MINIHACK_ENVS = [] -for env_spec in gym.envs.registry.all(): - id = env_spec.id - if id.split("-")[0] == "MiniHack": - MINIHACK_ENVS.append(id) +MINIHACK_ENVS = [env_spec.id for env_spec in registry.values() if "MiniHack" in env_spec.id] def make_minihack_env(env_name, task, config, render_mode: Optional[str] = None): @@ -29,14 +25,11 @@ def make_minihack_env(env_name, task, config, render_mode: Optional[str] = None) "tty_colors", ], **minihack_kwargs, + render_mode=render_mode, ) if skip_more: env = AutoMore(env) - env = NLELanguageWrapper(env, vlm=vlm) - # wrap NLE with timeout - env = NLETimeLimit(env) - - env = GymV21CompatibilityV0(env=env, render_mode=render_mode) + env = NLELanguageWrapper(env, vlm=vlm) return env diff --git a/balrog/environments/nle/__init__.py b/balrog/environments/nle/__init__.py index 7ddb3ece..568a66e9 100644 --- a/balrog/environments/nle/__init__.py +++ b/balrog/environments/nle/__init__.py @@ -20,92 +20,8 @@ class Role(enum.Enum): WIZARD = "wiz" -ACTIONS = { - "north": "move north", - "east": "move east", - "south": "move south", - "west": "move west", - "northeast": "move northeast", - "southeast": "move southeast", - "southwest": "move southwest", - "northwest": "move northwest", - "far north": "move far north", - "far east": "move far east", - "far south": "move far south", - "far west": "move far west", - "far northeast": "move far northeast", - "far southeast": "move far southeast", - "far southwest": "move far southwest", - "far northwest": "move far northwest", - "up": "go up a staircase", - "down": "go down a staircase (tip: you can only go down if you are standing on the stairs)", - "wait": "rest one move while doing nothing", - "more": "display more of the message (tip: ONLY ever use when current message ends with --More--)", - "annotate": "leave a note about the level", - "apply": "apply (use) a tool", - "call": "name a monster or object, or add an annotation", - "cast": "cast a spell", - "close": "close an adjacent door", - "open": "open an adjacent door", - "dip": "dip an object into something", - "drop": "drop an item", - "droptype": "drop specific item types (specify in the next prompt)", - "eat": "eat something (tip: replenish food when hungry)", - "esc": "exit menu or message", - "engrave": "engrave writing on the floor (tip: Elbereth)", - "enhance": "advance or check weapons skills", - "fire": "fire ammunition from quiver", - "fight": "fight a monster (even if you only guess one is there)", - "force": "force a lock", - "inventory": "show your inventory", - "invoke": "invoke ", - "jump": "jump to a location", - "kick": "kick an enemy or a locked door or chest", - "look": "look at what is under you", - "loot": "loot a box on the floor", - "monster": "use a monster's special ability (when polymorphed)", - "offer": "offer a sacrifice to the gods (tip: on an aligned altar)", - "overview": "display an overview of the dungeon", - "pay": "pay your shopping bill", - "pickup": "pick up things at the current location", - "pray": "pray to the gods for help", - "puton": "put on an accessory", - "quaff": "quaff (drink) something", - "quiver": "select ammunition for quiver", - "read": "read a scroll or spellbook", - "remove": "remove an accessory", - "rub": "rub a lamp or a stone", - "search": "search for hidden doors and passages", - "swap": "swap wielded and secondary weapons", - "takeoff": "take off one piece of armor", - "takeoffall": "take off all armor", - "teleport": "teleport to another level (if you have the ability)", - "throw": "throw something (e.g. a dagger or dart)", - "travel": "travel to a specific location on the map (tip: in the next action, specify > or < for stairs, { for fountain, and _ for altar)", - "twoweapon": "toggle two-weapon combat", - "untrap": "untrap something", - "wear": "wear a piece of armor", - "wield": "wield a weapon", - "wipe": "wipe off your face", - "zap": "zap a wand", - "minus": "-", - "space": " ", - "apos": "'", - "0": "0", - "1": "1", - "2": "2", - "3": "3", - "4": "4", - "5": "5", - "6": "6", - "7": "7", - "8": "8", - "9": "9", -} - - -def get_instruction_prompt(task=None): - action_strings = ",\n".join(f"{action}: {description}" for action, description in ACTIONS.items()) +def get_instruction_prompt(env, task=None): + action_strings = ",\n".join(f"{action}: {description}" for action, description in env.action_str_desc_map.items()) instruction_prompt = f""" You are an agent playing NetHack. The following are the possible actions you can take in the game, followed by a short description of each action: diff --git a/balrog/environments/nle/auto_more.py b/balrog/environments/nle/auto_more.py index 92b71048..c78caba0 100644 --- a/balrog/environments/nle/auto_more.py +++ b/balrog/environments/nle/auto_more.py @@ -1,4 +1,4 @@ -import gym +import gymnasium as gym from nle import nle_language_obsv from nle.nethack import actions as A @@ -9,25 +9,26 @@ def __init__(self, env): self.nle_language = nle_language_obsv.NLELanguageObsv() def reset(self, **kwargs): - obs = super().reset(**kwargs) + obs, info = self.env.reset(**kwargs) obs["text_message"] = self.nle_language.text_message(obs["tty_chars"]).decode("latin-1") - return obs + return obs, info def step(self, action): - obs, reward, done, info = super().step(action) + obs, reward, term, trun, info = self.env.step(action) message = self.nle_language.text_message(obs["tty_chars"]).decode("latin-1") - + done = term or trun while "--More--" in message and not done: message = message.replace("--More--", "\n") action_index = self.env.actions.index(A.MiscAction.MORE) - obs, rew, done, info = super().step(action_index) + obs, rew, term, trun, info = self.env.step(action_index) + done = term or trun add = self.nle_language.text_message(obs["tty_chars"]).decode("latin-1") message += add reward += rew obs["text_message"] = message - return obs, reward, done, info + return obs, reward, term, trun, info diff --git a/balrog/environments/nle/base.py b/balrog/environments/nle/base.py index 134050f4..3c63f7be 100644 --- a/balrog/environments/nle/base.py +++ b/balrog/environments/nle/base.py @@ -1,24 +1,22 @@ import random +from typing import Any, Dict +import gymnasium as gym from nle import nle_language_obsv -from nle.language_wrapper.wrappers import nle_language_wrapper as language_wrapper +from nle.language_wrapper.wrappers import nle_language_wrapper from nle.nethack import USEFUL_ACTIONS from PIL import Image from balrog.environments import Strings +from balrog.environments.nle.progress import get_progress_system +from balrog.environments.nle.render import tty_render_image +from balrog.environments.nle.render_rgb import rgb_render_image -from ..minihack import ACTIONS as MINIHACK_ACTIONS -from .progress import get_progress_system -from .render import tty_render_image -from .render_rgb import rgb_render_image - -class NLELanguageWrapper(language_wrapper.NLELanguageWrapper): +class NLELanguageWrapper(gym.Wrapper): def __init__(self, env, vlm=False): - super().__init__(env, use_language_action=True) + super().__init__(env) self.nle_language = nle_language_obsv.NLELanguageObsv() - self.language_action_space = self.create_action_space() - self.env = env self.vlm = vlm self.done = False @@ -27,25 +25,92 @@ def __init__(self, env, vlm=False): else: self.prompt_mode = "language" + self._setup_action_space() self.progress = get_progress_system(self.env) self.max_steps = self.env.unwrapped._max_episode_steps + def _setup_action_space(self): + """Determines the action set and initializes the action space.""" + env_id_lower = self.env.spec.id.lower() + if "minihack" in env_id_lower: + actions_to_descr = MINIHACK_ACTIONS_TO_DESCR + elif "nethack" in env_id_lower: + actions_to_descr = NLE_ACTIONS_TO_DESCR + else: + raise ValueError(f"Unsupported environment: {self.env.spec.id}") + + self._initialize_action_mappings(actions_to_descr) + self.language_action_space = Strings(list(self.action_str_enum_map.keys())) + + def _initialize_action_mappings(self, actions_to_descr: Dict[str, str]): + """Builds mappings from action strings to NLE enums and indices.""" + self.action_str_enum_map: Dict[str, Any] = {} + self.action_enum_index_map: Dict[Any, int] = {} + self.action_str_desc_map: Dict[str, str] = {} + + # Pre-calculate all possible action strings from the base environment + all_action_strs = { + action_str + for action_strs in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.values() + for action_str in action_strs + } + + # Validate that all keys in our description map are valid actions + missing_keys = [key for key in actions_to_descr if key not in all_action_strs] + if missing_keys: + raise KeyError(f"Action keys not found in NLE's action map: {', '.join(missing_keys)}") + + for action_enum in self.env.unwrapped.actions: + action_index = self.env.unwrapped.actions.index(action_enum) + possible_strs = nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.get(action_enum, []) + + for action_str in possible_strs: + if action_str in actions_to_descr: + self.action_str_enum_map[action_str] = action_enum + self.action_enum_index_map[action_enum] = action_index + self.action_str_desc_map[action_str] = actions_to_descr[action_str] + + def pre_reset(self): + self.progress = get_progress_system(self.env) + + def reset(self, **kwargs): + self.pre_reset() + obs, info = self.env.reset(**kwargs) + + return self.post_reset(obs), info + + def post_reset(self, nle_obs): + return self.nle_process_obs(nle_obs) + + def pre_step(self, action): + nle_action_enum = self.action_str_enum_map[action] + nle_action_idx = self.action_enum_index_map[nle_action_enum] + + return nle_action_idx + def step(self, action): - obs, reward, done, info = super().step(action) + action = self.pre_step(action) + + obs, reward, term, trun, info = self.env.step(action) + + done = term or trun self.done = done if not self.done else self.done - self.progress.update(obs["obs"], reward, self.done, info) - return obs, reward, self.done, info + self.progress.update(obs, reward, self.done, info) - def post_reset(self, obsv): - return self.post_step(obsv) + return self.post_step(obs), reward, term, trun, info - def reset(self, **kwargs): - self.progress = get_progress_system(self.env) - obsv = self.env.reset(**kwargs) - return self.post_reset(obsv) + def post_step(self, nle_obs): + return self.nle_process_obs(nle_obs) - def post_step(self, nle_obsv): - return self.nle_process_obsv(nle_obsv) + def render(self): + mode = self.env.render_mode + + if mode == "tiles": + return self.render_tiles_from_obs(self.env.unwrapped._last_obs) + elif mode == "tty_image": + return self.render_tty_from_obs(self.env.unwrapped._last_obs) + else: + return self.env.render() @property def default_action(self): @@ -55,69 +120,39 @@ def default_action(self): return "esc" def get_text_action(self, action): - return NLELanguageWrapper.all_nle_action_map[self.env.actions[action]][0] + return self.action_str_enum_map[action] - def nle_process_obsv(self, nle_obsv): - img = Image.fromarray(self.render("tiles")).convert("RGB") if self.vlm else None - text = self.nle_obsv_type(nle_obsv) + def nle_process_obs(self, nle_obs): + img = Image.fromarray(self.render_tiles_from_obs(nle_obs)).convert("RGB") if self.vlm else None + text = self.nle_obs_type(nle_obs) return { "text": text, "image": img, - "obs": nle_obsv, + "obs": nle_obs, } - def nle_obsv_type(self, nle_obsv): - nle_obsv = self.nle_obsv_to_language(nle_obsv) + def nle_obs_type(self, nle_obs): + nle_obs = self.nle_obs_to_language(nle_obs) if self.prompt_mode == "language": - return self.render_text(nle_obsv) + return self.render_text(nle_obs) elif self.prompt_mode == "hybrid": - return self.render_hybrid(nle_obsv) + return self.render_hybrid(nle_obs) else: raise ValueError(f'"{self.prompt_mode}" is not a valid prompt mode.') - def render(self, mode="human"): - if mode == "tiles": - obs = self.env.unwrapped.last_observation - glyphs = obs[self.env.unwrapped._observation_keys.index("glyphs")] - return rgb_render_image(glyphs) - elif mode == "tty_image": - obs = self.env.unwrapped.last_observation - tty_chars = obs[self.env.unwrapped._observation_keys.index("tty_chars")] - tty_colors = obs[self.env.unwrapped._observation_keys.index("tty_colors")] - return tty_render_image(tty_chars, tty_colors) - else: - return super().render(mode) + def render_tiles_from_obs(self, obs): + # Custom tiles rendering from latest observation + key_idx = self.env.unwrapped._observation_keys.index + return rgb_render_image(obs[key_idx("glyphs")]) + + def render_tty_from_obs(self, obs): + key_idx = self.env.unwrapped._observation_keys.index + return tty_render_image(obs[key_idx("tty_chars")], obs[key_idx("tty_colors")]) def get_stats(self): return self.progress.__dict__ - def create_action_space(self): - if "minihack" in self.env.spec.id.lower(): - available_actions = {} - for action in self.env.actions: - action_key = NLELanguageWrapper.all_nle_action_map[action][0] - if action_key not in MINIHACK_ACTIONS: - continue - available_actions[action_key] = MINIHACK_ACTIONS[action_key] - - all_actions = [action for action, _ in available_actions.items()] - - else: - available_actions = [ - action_strs[0] - for action, action_strs in NLELanguageWrapper.all_nle_action_map.items() - if action in USEFUL_ACTIONS - ] - single_chars = [chr(i) for i in range(ord("a"), ord("z") + 1)] + [ - chr(i) for i in range(ord("A"), ord("Z") + 1) - ] - single_digits = [str(i) for i in range(10)] - double_digits = [f"{i:02d}" for i in range(100)] - all_actions = available_actions + single_chars + single_digits + double_digits - - return Strings(all_actions) - def ascii_render(self, chars): rows, cols = chars.shape result = "" @@ -128,7 +163,7 @@ def ascii_render(self, chars): result += "\n" return result - def nle_obsv_to_language(self, nle_obsv): + def nle_obs_to_language(self, nle_obsv): """Translate NLE Observation into a language observation. Args: nle_obsv (dict): NLE observation from the base environment @@ -204,3 +239,124 @@ def render_hybrid(self, nle_obsv): "long_term_context": long_term_context, "short_term_context": short_term_context, } + + +NLE_ACTIONS_TO_DESCR = { + "north": "move north", + "east": "move east", + "south": "move south", + "west": "move west", + "northeast": "move northeast", + "southeast": "move southeast", + "southwest": "move southwest", + "northwest": "move northwest", + "far north": "move far north", + "far east": "move far east", + "far south": "move far south", + "far west": "move far west", + "far northeast": "move far northeast", + "far southeast": "move far southeast", + "far southwest": "move far southwest", + "far northwest": "move far northwest", + "up": "go up a staircase", + "down": "go down a staircase (tip: you can only go down if you are standing on the stairs)", + "wait": "rest one move while doing nothing", + "more": "display more of the message (tip: ONLY ever use when current message ends with --More--)", + "annotate": "leave a note about the level", + "apply": "apply (use) a tool", + "call": "name a monster or object, or add an annotation", + "cast": "cast a spell", + "close": "close an adjacent door", + "open": "open an adjacent door", + "dip": "dip an object into something", + "drop": "drop an item", + "droptype": "drop specific item types (specify in the next prompt)", + "eat": "eat something (tip: replenish food when hungry)", + "esc": "exit menu or message", + "engrave": "engrave writing on the floor (tip: Elbereth)", + "enhance": "advance or check weapons skills", + "fire": "fire ammunition from quiver", + "fight": "fight a monster (even if you only guess one is there)", + "force": "force a lock", + "inventory": "show your inventory", + "invoke": "invoke ", + "jump": "jump to a location", + "kick": "kick an enemy or a locked door or chest", + "look": "look at what is under you", + "loot": "loot a box on the floor", + "monster": "use a monster's special ability (when polymorphed)", + "offer": "offer a sacrifice to the gods (tip: on an aligned altar)", + # "overview": "display an overview of the dungeon", + "pay": "pay your shopping bill", + "pickup": "pick up things at the current location", + "pray": "pray to the gods for help", + "puton": "put on an accessory", + "quaff": "quaff (drink) something", + "quiver": "select ammunition for quiver", + "read": "read a scroll or spellbook", + "remove": "remove an accessory", + "rub": "rub a lamp or a stone", + "search": "search for hidden doors and passages", + "swap": "swap wielded and secondary weapons", + "takeoff": "take off one piece of armor", + "takeoffall": "take off all armor", + "teleport": "teleport to another level (if you have the ability)", + "throw": "throw something (e.g. a dagger or dart)", + "travel": "travel to a specific location on the map (tip: in the next action, specify > or < for stairs, { for fountain, and _ for altar)", + "twoweapon": "toggle two-weapon combat", + "untrap": "untrap something", + "wear": "wear a piece of armor", + "wield": "wield a weapon", + "wipe": "wipe off your face", + "zap": "zap a wand", + "minus": "-", + "space": " ", + "apos": "'", + "0": "0", + "1": "1", + "2": "2", + "3": "3", + "4": "4", + "5": "5", + "6": "6", + "7": "7", + "8": "8", + "9": "9", +} + + +MINIHACK_ACTIONS_TO_DESCR = { + "north": "move north", + "east": "move east", + "south": "move south", + "west": "move west", + "northeast": "move northeast", + "southeast": "move southeast", + "southwest": "move southwest", + "northwest": "move northwest", + "far north": "move far north", + "far east": "move far east", + "far south": "move far south", + "far west": "move far west", + "far northeast": "move far northeast", + "far southeast": "move far southeast", + "far southwest": "move far southwest", + "far northwest": "move far northwest", + "up": "go up the stairs", + "down": "go down the stairs", + "wait": "rest one move while doing nothing", + "more": "display more of the message", + "apply": "apply (use) a tool", + "close": "close an adjacent door", + "open": "open an adjacent door", + "eat": "eat something", + "force": "force a lock", + "kick": "kick an enemy or a locked door or chest", + "loot": "loot a box on the floor", + "pickup": "pick up things at the current location if there are any", + "pray": "pray to the gods for help", + "puton": "put on an accessory", + "quaff": "quaff (drink) something", + "search": "search for hidden doors and passages", + "zap": "zap a wand", +} diff --git a/balrog/environments/nle/nle_env.py b/balrog/environments/nle/nle_env.py index affb61a2..aba509db 100644 --- a/balrog/environments/nle/nle_env.py +++ b/balrog/environments/nle/nle_env.py @@ -1,30 +1,23 @@ from typing import Optional -import gym +import gymnasium as gym import nle # NOQA: F401 +import nle_progress # NOQA: F401 +from gymnasium import registry from balrog.environments.nle import AutoMore, NLELanguageWrapper -from balrog.environments.wrappers import GymV21CompatibilityV0, NLETimeLimit -NETHACK_ENVS = [] -for env_spec in gym.envs.registry.all(): - id = env_spec.id - if "NetHack" in id: - NETHACK_ENVS.append(id) +NETHACK_ENVS = [env_spec.id for env_spec in registry.values() if "NetHack" in env_spec.id] def make_nle_env(env_name, task, config, render_mode: Optional[str] = None): nle_kwargs = dict(config.envs.nle_kwargs) skip_more = nle_kwargs.pop("skip_more", False) vlm = True if config.agent.max_image_history > 0 else False - env = gym.make(task, **nle_kwargs) + env = gym.make(task, **nle_kwargs, render_mode=render_mode) if skip_more: env = AutoMore(env) - env = NLELanguageWrapper(env, vlm=vlm) - - # wrap NLE with timeout - env = NLETimeLimit(env) - env = GymV21CompatibilityV0(env=env, render_mode=render_mode) + env = NLELanguageWrapper(env, vlm=vlm) return env diff --git a/balrog/evaluator.py b/balrog/evaluator.py index a9a937a9..76af039f 100644 --- a/balrog/evaluator.py +++ b/balrog/evaluator.py @@ -254,7 +254,7 @@ def run_episode(self, task, agent, process_num=None, position=0, episode_idx=0): Returns: dict: Log of the episode containing statistics and results. """ - env = make_env(self.env_name, task, self.config) + env = make_env(self.env_name, task, self.config, render_mode=self.config.envs.env_kwargs.render_mode) agent.reset() seed = self.config.envs.env_kwargs.seed diff --git a/play.py b/play.py new file mode 100644 index 00000000..970d399a --- /dev/null +++ b/play.py @@ -0,0 +1,118 @@ +import os +import random +import readline +import timeit +from datetime import datetime +from functools import partial +from pathlib import Path +from pprint import pprint + +import hydra +import numpy as np +from hydra.utils import get_original_cwd +from omegaconf import DictConfig + +from balrog.agents import AgentFactory +from balrog.environments import make_env +from balrog.evaluator import EvaluatorManager +from balrog.utils import get_unique_seed, setup_environment + + +def completer(text, state, commands=[]): + options = [cmd for cmd in commands if cmd.startswith(text)] + return options[state] if state < len(options) else None + + +def setup_autocomplete(completer_fn): + readline.parse_and_bind("tab: complete") + print("Type commands and use TAB to autocomplete.") + print("To see strategies use command: `help`") + readline.set_completer(completer_fn) + + +def get_action(env, obs): + language_action_space = env.get_wrapper_attr("language_action_space") + setup_autocomplete(partial(completer, commands=language_action_space)) + + while True: + command = input("> ") + + if command == "help": + print(language_action_space) + continue + else: + try: + assert command in language_action_space + break + except Exception: + print(f"Selected action '{command}' is not in action list. Please try again.") + continue + + return command + + +@hydra.main(config_path="balrog/config", config_name="config", version_base="1.1") +def main(config: DictConfig): + original_cwd = get_original_cwd() + setup_environment(original_cwd=original_cwd) + + # Determine output directory + if config.eval.resume_from is not None: + output_dir = config.eval.resume_from + else: + now = datetime.now() + timestamp = now.strftime("%Y-%m-%d_%H-%M-%S") + run_name = f"{timestamp}_{config.agent.type}_{config.client.model_id.replace('/', '_')}" + output_dir = os.path.join(config.eval.output_dir, run_name) + + # Create the directory if it doesn't exist + Path(output_dir).mkdir(parents=True, exist_ok=True) + + env_name = random.choice(config.envs.names.split("-")) + task = random.choice(config.tasks[f"{env_name}_tasks"]) + print(f"Selected environment: {env_name}, task: {task}") + + env = make_env(env_name, task, config, render_mode="human") + + seed = config.envs.env_kwargs.seed + if seed is None: + seed = get_unique_seed(process_num=None, episode_idx=0) + random.seed(seed) + np.random.seed(seed) + obs, info = env.reset(seed=seed) + env.render() + + steps = 0 + reward = 0.0 + total_reward = 0.0 + action = None + + total_start_time = timeit.default_timer() + start_time = total_start_time + + while True: + action = get_action(env, obs) + if action is None: + break + + obs, reward, terminated, truncated, info = env.step(action) + env.render() + + steps += 1 + total_reward += reward + + if not (terminated or truncated): + continue + + time_delta = timeit.default_timer() - start_time + + print("Final reward:", reward) + print(f"Total reward: {total_reward}, Steps: {steps}, SPS: {steps / time_delta}", total_reward) + pprint.pprint(info) + + break + env.close() + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index 5999890f..4ff890f7 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ "gym==0.23", "requests", "balrog-nle", - "minihack @ git+https://github.com/balrog-ai/minihack.git", + "minihack @ git+https://github.com/BartekCupial/minihack.git", "textworld @ git+https://github.com/balrog-ai/TextWorld.git", "tatsu==5.8.3", "minigrid @ git+https://github.com/BartekCupial/Minigrid.git",