Skip to content
This repository was archived by the owner on May 6, 2024. It is now read-only.

Commit fc2483a

Browse files
authored
Merge pull request #156 from facebookresearch/eric/competition
Eric/competition
2 parents e464dba + dbf53f0 commit fc2483a

File tree

8 files changed

+197
-28
lines changed

8 files changed

+197
-28
lines changed

nle/env/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,8 @@
1616
registration.register(id="NetHackGold-v0", entry_point="nle.env.tasks:NetHackGold")
1717
registration.register(id="NetHackEat-v0", entry_point="nle.env.tasks:NetHackEat")
1818
registration.register(id="NetHackScout-v0", entry_point="nle.env.tasks:NetHackScout")
19+
registration.register(
20+
id="NetHackChallenge-v0", entry_point="nle.env.tasks:NetHackChallenge"
21+
)
1922

2023
__all__ = ["NLE", "DUNGEON_SHAPE"]

nle/env/base.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@
9494
),
9595
(
9696
"inv_strs",
97-
gym.spaces.Box(low=0, high=127, **nethack.OBSERVATION_DESC["inv_strs"]),
97+
gym.spaces.Box(low=0, high=255, **nethack.OBSERVATION_DESC["inv_strs"]),
9898
),
9999
(
100100
"inv_letters",
@@ -116,13 +116,13 @@
116116
),
117117
(
118118
"tty_chars",
119-
gym.spaces.Box(low=0, high=127, **nethack.OBSERVATION_DESC["tty_chars"]),
119+
gym.spaces.Box(low=0, high=255, **nethack.OBSERVATION_DESC["tty_chars"]),
120120
),
121121
(
122122
"tty_colors",
123123
gym.spaces.Box(
124-
low=-15,
125-
high=15,
124+
low=0,
125+
high=31,
126126
**nethack.OBSERVATION_DESC["tty_colors"],
127127
),
128128
),
@@ -210,6 +210,7 @@ def __init__(
210210
options=None,
211211
wizard=False,
212212
allow_all_yn_questions=False,
213+
allow_all_modes=False,
213214
space_dict=None,
214215
):
215216
"""Constructs a new NLE environment.
@@ -235,11 +236,15 @@ def __init__(
235236
If set to True, no y/n questions in step() are declined.
236237
If set to False, only elements of SKIP_EXCEPTIONS are not declined.
237238
Defaults to False.
239+
allow_all_modes (bool):
240+
If set to True, menus, text input and 'MORE' prompts are not declined or skipped automatically, except once the game is over.
241+
If set to False, they are always handled automatically. Defaults to False.
238242
"""
239243

240244
self.character = character
241245
self._max_episode_steps = max_episode_steps
242246
self._allow_all_yn_questions = allow_all_yn_questions
247+
self._allow_all_modes = allow_all_modes
243248

244249
if actions is None:
245250
actions = FULL_ACTIONS
@@ -339,6 +344,9 @@ def print_action_meanings(self):
339344
for a_idx, a in enumerate(self._actions):
340345
print(a_idx, a)
341346

347+
def _check_abort(self, observation):
348+
return self._steps >= self._max_episode_steps
349+
342350
def step(self, action: int):
343351
"""Steps the environment.
344352
@@ -360,15 +368,17 @@ def step(self, action: int):
360368
last_observation = tuple(a.copy() for a in self.last_observation)
361369

362370
observation, done = self.env.step(self._actions[action])
363-
observation, done = self._perform_known_steps(
364-
observation, done, exceptions=True
365-
)
371+
is_game_over = observation[self._program_state_index][0] == 1
372+
if is_game_over or not self._allow_all_modes:
373+
observation, done = self._perform_known_steps(
374+
observation, done, exceptions=True
375+
)
366376

367377
self._steps += 1
368378

369379
self.last_observation = observation
370380

371-
if self._steps >= self._max_episode_steps:
381+
if self._check_abort(observation):
372382
end_status = self.StepStatus.ABORTED
373383
else:
374384
end_status = self._is_episode_end(observation)

nle/env/tasks.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,3 +278,76 @@ def _reward_fn(self, last_observation, observation, end_status):
278278
self.dungeon_explored[key] = explored
279279
time_penalty = self._get_time_penalty(last_observation, observation)
280280
return reward + time_penalty
281+
282+
283+
class NetHackChallenge(NetHackScore):
284+
"""Environment for the NetHack Challenge.
285+
286+
The task is an augmentation of the standard NLE task. This is the NLE Score Task
287+
but with some subtle differences:
288+
* the action space is fixed to include the full keyboard
289+
* menus and "<More>" tokens are not skipped
290+
* starting character is randomly assigned
291+
"""
292+
293+
def __init__(
294+
self,
295+
*args,
296+
character="@",
297+
allow_all_yn_questions=True,
298+
allow_all_modes=True,
299+
penalty_mode="constant",
300+
penalty_step: float = -0.00,
301+
penalty_time: float = -0.0,
302+
max_episode_steps: int = 1e6,
303+
observation_keys=(
304+
"glyphs",
305+
"chars",
306+
"colors",
307+
"specials",
308+
"blstats",
309+
"message",
310+
"inv_glyphs",
311+
"inv_strs",
312+
"inv_letters",
313+
"inv_oclasses",
314+
"tty_chars",
315+
"tty_colors",
316+
"tty_cursor",
317+
),
318+
no_progress_timeout: int = 10_000,
319+
**kwargs,
320+
):
321+
actions = nethack.ACTIONS
322+
super().__init__(
323+
*args,
324+
actions=actions,
325+
character=character,
326+
allow_all_yn_questions=allow_all_yn_questions,
327+
allow_all_modes=allow_all_modes,
328+
penalty_mode=penalty_mode,
329+
penalty_step=penalty_step,
330+
penalty_time=penalty_time,
331+
max_episode_steps=max_episode_steps,
332+
observation_keys=observation_keys,
333+
**kwargs,
334+
)
335+
# If the in-game turn count doesn't change for `no_progress_timeout` steps, we abort
336+
self._turns = None
337+
self._no_progress_count = 0
338+
self.no_progress_timeout = no_progress_timeout
339+
340+
def _check_abort(self, observation):
341+
"""Check whether the in-game turn counter has stayed unchanged long enough
342+
to trigger an abort."""
343+
344+
turns = observation[self._blstats_index][20]
345+
if self._turns == turns:
346+
self._no_progress_count += 1
347+
else:
348+
self._turns = turns
349+
self._no_progress_count = 0
350+
return (
351+
self._steps >= self._max_episode_steps
352+
or self._no_progress_count >= self.no_progress_timeout
353+
)

nle/nethack/actions.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,22 @@ def C(c):
1414
return 0x1F & c
1515

1616

17-
# Missing here:
18-
# Some characters for text input (e.g., +).
19-
# General menu handling isn't part of this either.
17+
class TextCharacters(enum.IntEnum):
18+
PLUS = ord("+")
19+
MINUS = ord("-")
20+
SPACE = ord(" ")
21+
APOS = ord("'")
22+
QUOTE = ord('"')
23+
NUM_0 = ord("0")
24+
NUM_1 = ord("1")
25+
NUM_2 = ord("2")
26+
NUM_3 = ord("3")
27+
NUM_4 = ord("4")
28+
NUM_5 = ord("5")
29+
NUM_6 = ord("6")
30+
NUM_7 = ord("7")
31+
NUM_8 = ord("8")
32+
NUM_9 = ord("9")
2033

2134

2235
class CompassCardinalDirection(enum.IntEnum):
@@ -76,6 +89,12 @@ class MiscAction(enum.IntEnum):
7689
MORE = ord("\r") # read the next message
7790

7891

92+
class UnsafeActions(enum.IntEnum):
93+
# currently these result in an error or undesirable behaviour
94+
HELP = ord("?") # give a help message
95+
PREVMSG = C("p") # view recent game messages
96+
97+
7998
class Command(enum.IntEnum):
8099
EXTCMD = ord("#") # perform an extended command
81100
EXTLIST = M("?") # list all extended commands
@@ -100,7 +119,6 @@ class Command(enum.IntEnum):
100119
FIGHT = ord("F") # Prefix: force fight even if you don't see a monster
101120
FORCE = M("f") # force a lock
102121
GLANCE = ord(";") # show what type of thing a map symbol corresponds to
103-
HELP = ord("?") # give a help message
104122
HISTORY = ord("V") # show long version and game history
105123
INVENTORY = ord("i") # show your inventory
106124
INVENTTYPE = ord("I") # inventory specific item types
@@ -121,7 +139,6 @@ class Command(enum.IntEnum):
121139
PAY = ord("p") # pay your shopping bill
122140
PICKUP = ord(",") # pick up things at the current location
123141
PRAY = M("p") # pray to the gods for help
124-
PREVMSG = C("p") # view recent game messages
125142
PUTON = ord("P") # put on an accessory (ring, amulet, etc)
126143
QUAFF = ord("q") # quaff (drink) something
127144
QUIT = M("q") # exit without saving current game
@@ -132,6 +149,7 @@ class Command(enum.IntEnum):
132149
RIDE = M("R") # mount or dismount a saddled steed
133150
RUB = M("r") # rub a lamp or a stone
134151
RUSH = ord("g") # Prefix: rush until something interesting is seen
152+
RUSH2 = ord("G") # Prefix: rush until something interesting is seen
135153
SAVE = ord("S") # save the game and exit
136154
SEARCH = ord("s") # search for traps and secret doors
137155
SEEALL = ord("*") # show all equipment in use
@@ -163,6 +181,7 @@ class Command(enum.IntEnum):
163181
+ list(MiscDirection)
164182
+ list(MiscAction)
165183
+ list(Command)
184+
+ list(TextCharacters)
166185
)
167186

168187
NON_RL_ACTIONS = (
@@ -172,13 +191,11 @@ class Command(enum.IntEnum):
172191
Command.EXTCMD, # Potentially useful for some wizard actions.
173192
Command.EXTLIST,
174193
Command.GLANCE,
175-
Command.HELP,
176194
Command.HISTORY,
177195
Command.KNOWN, # Could potentially be useful.
178196
Command.KNOWNCLASS, # Could potentially be useful.
179197
Command.OPTIONS,
180198
Command.OVERVIEW, # Could potentially be useful.
181-
Command.PREVMSG, # Could potentially be useful.
182199
Command.TELEPORT,
183200
Command.QUIT,
184201
Command.REDRAW,
@@ -191,13 +208,15 @@ class Command(enum.IntEnum):
191208
)
192209

193210
_USEFUL_ACTIONS = list(ACTIONS)
194-
for action in NON_RL_ACTIONS:
211+
for action in NON_RL_ACTIONS + tuple(TextCharacters):
195212
_USEFUL_ACTIONS.remove(action)
213+
_USEFUL_ACTIONS.append(TextCharacters.SPACE)
196214
USEFUL_ACTIONS = tuple(_USEFUL_ACTIONS)
197215
del _USEFUL_ACTIONS
198216

199217
_ACTIONS_DICT = {}
200218
for enum_class in [
219+
TextCharacters,
201220
CompassDirection,
202221
CompassDirectionLonger,
203222
MiscDirection,

nle/scripts/play.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,12 @@ def get_action(env, action_mode, is_raw_env):
4545
action = env.action_space.sample()
4646
else:
4747
action = random.choice(_ACTIONS)
48+
print(action)
4849
elif action_mode == "human":
4950
while True:
5051
with no_echo():
5152
ch = ord(os.read(0, 1))
52-
if ch in [nethack.C("c"), ord(b"q")]:
53+
if ch in [nethack.C("c")]:
5354
print("Received exit code {}. Aborting.".format(ch))
5455
return None
5556
try:
@@ -67,7 +68,18 @@ def get_action(env, action_mode, is_raw_env):
6768
return action
6869

6970

70-
def play(env, mode, ngames, max_steps, seeds, savedir, no_render, render_mode, debug):
71+
def play(
72+
env,
73+
mode,
74+
ngames,
75+
max_steps,
76+
seeds,
77+
savedir,
78+
no_render,
79+
render_mode,
80+
print_frames_separately,
81+
**kwargs,
82+
):
7183
env_name = env
7284
is_raw_env = env_name == "raw"
7385

@@ -100,10 +112,15 @@ def play(env, mode, ngames, max_steps, seeds, savedir, no_render, render_mode, d
100112
while True:
101113
if not no_render:
102114
if not is_raw_env:
103-
print("Previous reward:", reward)
104-
if action is not None:
105-
print("Previous action: %s" % repr(env._actions[action]))
115+
print("--------")
116+
print(f"Previous reward: {str(reward):64s}")
117+
act_str = repr(env._actions[action]) if action is not None else ""
118+
print(f"Previous action: {str(act_str):64s}")
119+
print("--------")
106120
env.render(render_mode)
121+
print("--------")
122+
if not print_frames_separately:
123+
print("\033[31A") # Go up 31 lines.
107124
else:
108125
print("Previous action:", action)
109126
_, chars, _, _, blstats, message, *_ = obs
@@ -114,6 +131,7 @@ def play(env, mode, ngames, max_steps, seeds, savedir, no_render, render_mode, d
114131
print(blstats)
115132

116133
action = get_action(env, mode, is_raw_env)
134+
117135
if action is None:
118136
break
119137

@@ -194,7 +212,7 @@ def main():
194212
parser.add_argument(
195213
"--max-steps",
196214
type=int,
197-
default=10000,
215+
default=1_000_000,
198216
help="Number of maximum steps per episode.",
199217
)
200218
parser.add_argument(
@@ -219,6 +237,12 @@ def main():
219237
choices=["human", "full", "ansi"],
220238
help="Render mode. Defaults to 'human'.",
221239
)
240+
parser.add_argument(
241+
"--print-frames-separately",
242+
"-p",
243+
action="store_true",
244+
help="Don't overwrite frames, print them all.",
245+
)
222246
flags = parser.parse_args()
223247

224248
if flags.debug:

0 commit comments

Comments
 (0)