Skip to content
This repository was archived by the owner on May 6, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions nle/env/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,8 @@
registration.register(id="NetHackGold-v0", entry_point="nle.env.tasks:NetHackGold")
registration.register(id="NetHackEat-v0", entry_point="nle.env.tasks:NetHackEat")
registration.register(id="NetHackScout-v0", entry_point="nle.env.tasks:NetHackScout")
registration.register(
id="NetHackChallenge-v0", entry_point="nle.env.tasks:NetHackChallenge"
)

__all__ = ["NLE", "DUNGEON_SHAPE"]
26 changes: 18 additions & 8 deletions nle/env/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
),
(
"inv_strs",
gym.spaces.Box(low=0, high=127, **nethack.OBSERVATION_DESC["inv_strs"]),
gym.spaces.Box(low=0, high=255, **nethack.OBSERVATION_DESC["inv_strs"]),
),
(
"inv_letters",
Expand All @@ -116,13 +116,13 @@
),
(
"tty_chars",
gym.spaces.Box(low=0, high=127, **nethack.OBSERVATION_DESC["tty_chars"]),
gym.spaces.Box(low=0, high=255, **nethack.OBSERVATION_DESC["tty_chars"]),
),
(
"tty_colors",
gym.spaces.Box(
low=-15,
high=15,
low=0,
high=31,
**nethack.OBSERVATION_DESC["tty_colors"],
),
),
Expand Down Expand Up @@ -210,6 +210,7 @@ def __init__(
options=None,
wizard=False,
allow_all_yn_questions=False,
allow_all_modes=False,
space_dict=None,
):
"""Constructs a new NLE environment.
Expand All @@ -235,11 +236,15 @@ def __init__(
If set to True, no y/n questions in step() are declined.
If set to False, only elements of SKIP_EXCEPTIONS are not declined.
Defaults to False.
allow_all_modes (bool):
If set to True, do not decline menus, text input or auto 'MORE'.
If set to False, only skip click through 'MORE' on death.
"""

self.character = character
self._max_episode_steps = max_episode_steps
self._allow_all_yn_questions = allow_all_yn_questions
self._allow_all_modes = allow_all_modes

if actions is None:
actions = FULL_ACTIONS
Expand Down Expand Up @@ -339,6 +344,9 @@ def print_action_meanings(self):
for a_idx, a in enumerate(self._actions):
print(a_idx, a)

def _check_abort(self, observation):
    """Return True once the episode has used up its step budget.

    `step()` calls this hook after every environment step and marks the
    episode as ABORTED when it returns True. `observation` is not consulted
    here; it is accepted so that overriding subclasses can base the decision
    on the latest observation.
    """
    return self._steps >= self._max_episode_steps

def step(self, action: int):
"""Steps the environment.

Expand All @@ -360,15 +368,17 @@ def step(self, action: int):
last_observation = tuple(a.copy() for a in self.last_observation)

observation, done = self.env.step(self._actions[action])
observation, done = self._perform_known_steps(
observation, done, exceptions=True
)
is_game_over = observation[self._program_state_index][0] == 1
if is_game_over or not self._allow_all_modes:
observation, done = self._perform_known_steps(
observation, done, exceptions=True
)

self._steps += 1

self.last_observation = observation

if self._steps >= self._max_episode_steps:
if self._check_abort(observation):
end_status = self.StepStatus.ABORTED
else:
end_status = self._is_episode_end(observation)
Expand Down
73 changes: 73 additions & 0 deletions nle/env/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,76 @@ def _reward_fn(self, last_observation, observation, end_status):
self.dungeon_explored[key] = explored
time_penalty = self._get_time_penalty(last_observation, observation)
return reward + time_penalty


class NetHackChallenge(NetHackScore):
    """Environment for the NetHack Challenge.

    The task is an augmentation of the standard NLE task. This is the NLE Score Task
    but with some subtle differences:
        * the action space is fixed to include the full keyboard
        * menus and "<More>" tokens are not skipped
        * starting character is randomly assigned
        * the episode is aborted when the in-game turn counter has not changed
          for `no_progress_timeout` consecutive steps
    """

    def __init__(
        self,
        *args,
        character="@",
        allow_all_yn_questions=True,
        allow_all_modes=True,
        penalty_mode="constant",
        penalty_step: float = -0.00,
        penalty_time: float = -0.0,
        max_episode_steps: int = 1_000_000,
        observation_keys=(
            "glyphs",
            "chars",
            "colors",
            "specials",
            "blstats",
            "message",
            "inv_glyphs",
            "inv_strs",
            "inv_letters",
            "inv_oclasses",
            "tty_chars",
            "tty_colors",
            "tty_cursor",
        ),
        no_progress_timeout: int = 10_000,
        **kwargs,
    ):
        """Constructs a new NetHack Challenge environment.

        All arguments except `no_progress_timeout` are forwarded unchanged to
        the parent constructor; only their defaults differ here. The action
        set is always `nethack.ACTIONS` and cannot be overridden by the caller.

        Args:
            no_progress_timeout (int):
                Number of consecutive steps during which the in-game turn
                counter may stay unchanged before the episode is aborted.
                Defaults to 10_000.
        """
        actions = nethack.ACTIONS
        super().__init__(
            *args,
            actions=actions,
            character=character,
            allow_all_yn_questions=allow_all_yn_questions,
            allow_all_modes=allow_all_modes,
            penalty_mode=penalty_mode,
            penalty_step=penalty_step,
            penalty_time=penalty_time,
            max_episode_steps=max_episode_steps,
            observation_keys=observation_keys,
            **kwargs,
        )
        # If the in-game turn count doesn't change for `no_progress_timeout`
        # consecutive steps, we abort the episode (see `_check_abort`).
        self._turns = None
        self._no_progress_count = 0
        self.no_progress_timeout = no_progress_timeout

    def reset(self, *args, **kwargs):
        """Resets the environment and clears the no-progress bookkeeping.

        Without this, a counter left at the timeout by the previous episode
        would abort the next episode immediately whenever the new game's turn
        counter happens to equal the last one seen.

        All arguments are forwarded to the parent `reset`, whose return value
        is returned unchanged.
        """
        self._turns = None
        self._no_progress_count = 0
        return super().reset(*args, **kwargs)

    def _check_abort(self, observation):
        """Check whether the episode should be aborted.

        Returns True when the step budget is exhausted or when the in-game
        turn counter has not changed for `no_progress_timeout` consecutive
        calls. The no-progress counter is updated as a side effect.
        """
        # Entry 20 of blstats is used as the in-game turn counter.
        turns = observation[self._blstats_index][20]
        if self._turns == turns:
            self._no_progress_count += 1
        else:
            self._turns = turns
            self._no_progress_count = 0
        return (
            self._steps >= self._max_episode_steps
            or self._no_progress_count >= self.no_progress_timeout
        )
35 changes: 27 additions & 8 deletions nle/nethack/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,22 @@ def C(c):
return 0x1F & c


# Missing here:
# Some characters for text input (e.g., +).
# General menu handling isn't part of this either.
class TextCharacters(enum.IntEnum):
    """Characters for text input, each valued by the `ord()` of its key."""

    PLUS = ord("+")
    MINUS = ord("-")
    SPACE = ord(" ")
    APOS = ord("'")
    QUOTE = ord('"')
    NUM_0 = ord("0")
    NUM_1 = ord("1")
    NUM_2 = ord("2")
    NUM_3 = ord("3")
    NUM_4 = ord("4")
    NUM_5 = ord("5")
    NUM_6 = ord("6")
    NUM_7 = ord("7")
    NUM_8 = ord("8")
    NUM_9 = ord("9")


class CompassCardinalDirection(enum.IntEnum):
Expand Down Expand Up @@ -76,6 +89,12 @@ class MiscAction(enum.IntEnum):
MORE = ord("\r") # read the next message


class UnsafeActions(enum.IntEnum):
    """Key codes that are known to misbehave when sent to the game."""

    # currently these result in an error or undesirable behaviour
    HELP = ord("?")  # give a help message
    PREVMSG = C("p")  # view recent game messages


class Command(enum.IntEnum):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, this was "all commands" (or something), and the "RL relevant ones" we gathered in their own enum. What's the philosophy now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This enum wasn't really "all commands", because it didn't contain movement etc. I figured that if we were splitting things out, let's at least clearly mark the very dangerous ones.

EXTCMD = ord("#") # perform an extended command
EXTLIST = M("?") # list all extended commands
Expand All @@ -100,7 +119,6 @@ class Command(enum.IntEnum):
FIGHT = ord("F") # Prefix: force fight even if you don't see a monster
FORCE = M("f") # force a lock
GLANCE = ord(";") # show what type of thing a map symbol corresponds to
HELP = ord("?") # give a help message
HISTORY = ord("V") # show long version and game history
INVENTORY = ord("i") # show your inventory
INVENTTYPE = ord("I") # inventory specific item types
Expand All @@ -121,7 +139,6 @@ class Command(enum.IntEnum):
PAY = ord("p") # pay your shopping bill
PICKUP = ord(",") # pick up things at the current location
PRAY = M("p") # pray to the gods for help
PREVMSG = C("p") # view recent game messages
PUTON = ord("P") # put on an accessory (ring, amulet, etc)
QUAFF = ord("q") # quaff (drink) something
QUIT = M("q") # exit without saving current game
Expand All @@ -132,6 +149,7 @@ class Command(enum.IntEnum):
RIDE = M("R") # mount or dismount a saddled steed
RUB = M("r") # rub a lamp or a stone
RUSH = ord("g") # Prefix: rush until something interesting is seen
RUSH2 = ord("G") # Prefix: rush until something interesting is seen
SAVE = ord("S") # save the game and exit
SEARCH = ord("s") # search for traps and secret doors
SEEALL = ord("*") # show all equipment in use
Expand Down Expand Up @@ -163,6 +181,7 @@ class Command(enum.IntEnum):
+ list(MiscDirection)
+ list(MiscAction)
+ list(Command)
+ list(TextCharacters)
)

NON_RL_ACTIONS = (
Expand All @@ -172,13 +191,11 @@ class Command(enum.IntEnum):
Command.EXTCMD, # Potentially useful for some wizard actions.
Command.EXTLIST,
Command.GLANCE,
Command.HELP,
Command.HISTORY,
Command.KNOWN, # Could potentially be useful.
Command.KNOWNCLASS, # Could potentially be useful.
Command.OPTIONS,
Command.OVERVIEW, # Could potentially be useful.
Command.PREVMSG, # Could potentially be useful.
Command.TELEPORT,
Command.QUIT,
Command.REDRAW,
Expand All @@ -191,13 +208,15 @@ class Command(enum.IntEnum):
)

_USEFUL_ACTIONS = list(ACTIONS)
for action in NON_RL_ACTIONS:
for action in NON_RL_ACTIONS + tuple(TextCharacters):
_USEFUL_ACTIONS.remove(action)
_USEFUL_ACTIONS.append(TextCharacters.SPACE)
USEFUL_ACTIONS = tuple(_USEFUL_ACTIONS)
del _USEFUL_ACTIONS

_ACTIONS_DICT = {}
for enum_class in [
TextCharacters,
CompassDirection,
CompassDirectionLonger,
MiscDirection,
Expand Down
36 changes: 30 additions & 6 deletions nle/scripts/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ def get_action(env, action_mode, is_raw_env):
action = env.action_space.sample()
else:
action = random.choice(_ACTIONS)
print(action)
elif action_mode == "human":
while True:
with no_echo():
ch = ord(os.read(0, 1))
if ch in [nethack.C("c"), ord(b"q")]:
if ch in [nethack.C("c")]:
print("Received exit code {}. Aborting.".format(ch))
return None
try:
Expand All @@ -67,7 +68,18 @@ def get_action(env, action_mode, is_raw_env):
return action


def play(env, mode, ngames, max_steps, seeds, savedir, no_render, render_mode, debug):
def play(
env,
mode,
ngames,
max_steps,
seeds,
savedir,
no_render,
render_mode,
print_frames_separately,
**kwargs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do we use the **kwargs for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The debug flag is created but does nothing in this function because it has already been used earlier. However, this function is called with play(**vars(flags)), so the flag needs a home here. I thought **kwargs would be the clearer home, but I almost put **_.

):
env_name = env
is_raw_env = env_name == "raw"

Expand Down Expand Up @@ -100,10 +112,15 @@ def play(env, mode, ngames, max_steps, seeds, savedir, no_render, render_mode, d
while True:
if not no_render:
if not is_raw_env:
print("Previous reward:", reward)
if action is not None:
print("Previous action: %s" % repr(env._actions[action]))
print("--------")
print(f"Previous reward: {str(reward):64s}")
act_str = repr(env._actions[action]) if action is not None else ""
print(f"Previous action: {str(act_str):64s}")
print("--------")
env.render(render_mode)
print("--------")
if not print_frames_separately:
print("\033[31A") # Go up 31 lines.
else:
print("Previous action:", action)
_, chars, _, _, blstats, message, *_ = obs
Expand All @@ -114,6 +131,7 @@ def play(env, mode, ngames, max_steps, seeds, savedir, no_render, render_mode, d
print(blstats)

action = get_action(env, mode, is_raw_env)

if action is None:
break

Expand Down Expand Up @@ -194,7 +212,7 @@ def main():
parser.add_argument(
"--max-steps",
type=int,
default=10000,
default=1_000_000,
help="Number of maximum steps per episode.",
)
parser.add_argument(
Expand All @@ -219,6 +237,12 @@ def main():
choices=["human", "full", "ansi"],
help="Render mode. Defaults to 'human'.",
)
parser.add_argument(
"--print-frames-separately",
"-p",
action="store_true",
help="Don't overwrite frames, print them all.",
)
flags = parser.parse_args()

if flags.debug:
Expand Down
Loading