Skip to content
This repository was archived by the owner on May 6, 2024. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ endif()

message(STATUS "Building nle backend version: ${NLE_VERSION}")

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# We use this to decide where the root of the nle/ package is. Normally it
Expand Down Expand Up @@ -99,7 +100,7 @@ target_link_directories(nethack PUBLIC /usr/local/lib)
target_link_libraries(nethack PUBLIC m fcontext bz2)

# dlopen wrapper library
add_library(nethackdl STATIC "sys/unix/nledl.c")
add_library(nethackdl STATIC "sys/unix/nledl.c" "sys/unix/nleshared.cc")
target_include_directories(nethackdl PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
target_link_libraries(nethackdl PUBLIC dl)

Expand Down
14 changes: 5 additions & 9 deletions include/nledl.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,9 @@
#include "nleobs.h"

/* TODO: Don't call this nle_ctx_t as well. */
typedef struct nledl_ctx {
char dlpath[1024];
void *dlhandle;
void *nle_ctx;
void *(*step)(void *, nle_obs *);
FILE *ttyrec;
} nle_ctx_t;

nle_ctx_t *nle_start(const char *, nle_obs *, FILE *, nle_seeds_init_t *);
typedef struct nledl_ctx nle_ctx_t;

nle_ctx_t *nle_start(const char *, nle_obs *, FILE *, nle_seeds_init_t *, int shared);
nle_ctx_t *nle_step(nle_ctx_t *, nle_obs *);

void nle_reset(nle_ctx_t *, nle_obs *, FILE *, nle_seeds_init_t *);
Expand All @@ -27,4 +21,6 @@ void nle_end(nle_ctx_t *);
void nle_set_seed(nle_ctx_t *, unsigned long, unsigned long, char);
void nle_get_seed(nle_ctx_t *, unsigned long *, unsigned long *, char *);

int nle_supports_shared(void);

#endif /* NLEDL_H */
72 changes: 45 additions & 27 deletions nle/nethack/nethack.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@ def _set_env_vars(options, hackdir, wizkit=None):
# which should allow several instances of this. On MacOS, that seems
# a tough call.
class Nethack:
_instances = 0

def __init__(
self,
observation_keys=OBSERVATION_DESC.keys(),
Expand All @@ -102,25 +100,38 @@ def __init__(
"Couldn't find NetHack installation at '%s'." % hackdir
)

# Create a HACKDIR for us.
self._tempdir = tempfile.TemporaryDirectory(prefix="nle")
self._vardir = self._tempdir.name
self.shared = False

if _pynethack.supports_shared():
# "shared" mode does some hacky things to enable using a
# shared libnethack.so, prevents writing to any files, and does
# not chdir.
self.shared = True
dlpath = DLPATH
self._hackdir = hackdir
else:

# Create a HACKDIR for us.
self._tempdir = tempfile.TemporaryDirectory(prefix="nle")
self._vardir = self._tempdir.name

self._hackdir = self._vardir

# Save cwd and restore later. Currently libnethack changes
# directory on loading.
self._oldcwd = os.getcwd()
# Save cwd and restore later. Currently libnethack changes
# directory on loading.
self._oldcwd = os.getcwd()

# Symlink a few files.
for fn in ["nhdat", "sysconf"]:
os.symlink(os.path.join(hackdir, fn), os.path.join(self._vardir, fn))
# Touch a few files.
for fn in ["perm", "logfile", "xlogfile"]:
os.close(os.open(os.path.join(self._vardir, fn), os.O_CREAT))
os.mkdir(os.path.join(self._vardir, "save"))
# Symlink a few files.
for fn in ["nhdat", "sysconf"]:
os.symlink(os.path.join(hackdir, fn), os.path.join(self._vardir, fn))
# Touch a few files.
for fn in ["perm", "logfile", "xlogfile"]:
os.close(os.open(os.path.join(self._vardir, fn), os.O_CREAT))
os.mkdir(os.path.join(self._vardir, "save"))

# Hacky AF: Copy our so into this directory to load several copies ...
dlpath = os.path.join(self._vardir, "libnethack.so")
shutil.copyfile(DLPATH, dlpath)
# Hacky AF: Copy our so into this directory to load several copies ...
dlpath = os.path.join(self._vardir, "libnethack.so")
shutil.copyfile(DLPATH, dlpath)

if options is None:
options = NETHACKOPTIONS
Expand All @@ -129,10 +140,10 @@ def __init__(
self._options.append("playmode:debug")
self._wizard = wizard

_set_env_vars(self._options, self._vardir)
_set_env_vars(self._options, self._hackdir)
self._ttyrec = ttyrec

self._pynethack = _pynethack.Nethack(dlpath, ttyrec)
self._pynethack = _pynethack.Nethack(dlpath, ttyrec, self.shared)

self._obs_buffers = {}

Expand All @@ -154,6 +165,11 @@ def step(self, action):
return self._step_return(), self._pynethack.done()

def _write_wizkit_file(self, wizkit_items):
if self._vardir is None:
raise RuntimeError(
"FIXME: shared wizkit: can't write to HACKDIR as "
"it is a shared directory"
)
# TODO ideally we need to check the validity of the requested items
with open(os.path.join(self._vardir, WIZKIT_FNAME), "w") as f:
for item in wizkit_items:
Expand All @@ -164,9 +180,9 @@ def reset(self, new_ttyrec=None, wizkit_items=None):
if not self._wizard:
raise ValueError("Set wizard=True to use the wizkit option.")
self._write_wizkit_file(wizkit_items)
_set_env_vars(self._options, self._vardir, wizkit=WIZKIT_FNAME)
_set_env_vars(self._options, self._hackdir, wizkit=WIZKIT_FNAME)
else:
_set_env_vars(self._options, self._vardir)
_set_env_vars(self._options, self._hackdir)
if new_ttyrec is None:
self._pynethack.reset()
else:
Expand All @@ -178,11 +194,13 @@ def reset(self, new_ttyrec=None, wizkit_items=None):

def close(self):
self._pynethack.close()
try:
os.chdir(self._oldcwd)
except IOError:
os.chdir(os.path.dirname(os.path.realpath(__file__)))
self._tempdir.cleanup()
if not self.shared:
try:
os.chdir(self._oldcwd)
except IOError:
os.chdir(os.path.dirname(os.path.realpath(__file__)))
if self._tempdir is not None:
self._tempdir.cleanup()

def set_initial_seeds(self, core, disp, reseed=False):
self._pynethack.set_initial_seeds(core, disp, reseed)
Expand Down
3 changes: 2 additions & 1 deletion nle/tests/test_nethack.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def test_run_n_episodes(self, tmpdir, game, episodes=3):

nethackdir = tmpdir.chdir()

assert nethackdir.fnmatch("nle*")
if not game.shared:
assert nethackdir.fnmatch("nle*")
assert tmpdir.ensure("nle.ttyrec")

if mean_sps < 15000:
Expand Down
156 changes: 100 additions & 56 deletions sys/unix/nledl.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,74 +6,113 @@

#include "nledl.h"

void
nledl_init(nle_ctx_t *nledl, nle_obs *obs, nle_seeds_init_t *seed_init)
{
nledl->dlhandle = dlopen(nledl->dlpath, RTLD_LAZY);
void *nleshared_open(const char *dlpath);
void nleshared_close(void *handle);
void nleshared_reset(void *handle);
void *nleshared_sym(void *handle, const char *symname);
void nleshared_set_current(void *handle);
int nleshared_supported(void);

typedef struct nledl_ctx {
void *shared;
char dlpath[1024];
void *dlhandle;
void *nle_ctx;
void *(*start)(nle_obs *, FILE *, nle_seeds_init_t *);
void *(*step)(void *, nle_obs *);
void (*end)(void *);
FILE *ttyrec;
} nle_ctx_t;

if (!nledl->dlhandle) {
fprintf(stderr, "%s\n", dlerror());
exit(EXIT_FAILURE);
static void *
sym(nle_ctx_t *nledl, const char *name)
{
if (nledl->shared) {
return nleshared_sym(nledl->shared, name);
} else {
dlerror(); /* Clear any existing error */
void *r = dlsym(nledl->dlhandle, name);
char *error = dlerror();
if (error != NULL) {
fprintf(stderr, "%s\n", error);
exit(EXIT_FAILURE);
}
return r;
}
}

dlerror(); /* Clear any existing error */

void *(*start)(nle_obs *, FILE *, nle_seeds_init_t *);
start = dlsym(nledl->dlhandle, "nle_start");
nledl->nle_ctx = start(obs, nledl->ttyrec, seed_init);

char *error = dlerror();
if (error != NULL) {
fprintf(stderr, "%s\n", error);
exit(EXIT_FAILURE);
void
nledl_init(nle_ctx_t *nledl, nle_obs *obs, nle_seeds_init_t *seed_init,
int shared)
{
nledl->shared = NULL;
if (shared) {
if (nleshared_supported()) {
nledl->shared = nleshared_open(nledl->dlpath);
nleshared_set_current(nledl->shared);
} else {
fprintf(stderr, "Shared mode not supported on this system!\n");
exit(EXIT_FAILURE);
}
} else {
nledl->dlhandle = dlopen(nledl->dlpath, RTLD_LAZY);
if (!nledl->dlhandle) {
fprintf(stderr, "%s\n", dlerror());
exit(EXIT_FAILURE);
}
}

nledl->step = dlsym(nledl->dlhandle, "nle_step");
nledl->start = sym(nledl, "nle_start");
nledl->step = sym(nledl, "nle_step");
nledl->end = sym(nledl, "nle_end");

error = dlerror();
if (error != NULL) {
fprintf(stderr, "%s\n", error);
exit(EXIT_FAILURE);
}
nledl->nle_ctx = nledl->start(obs, nledl->ttyrec, seed_init);
}

void
nledl_close(nle_ctx_t *nledl)
{
void (*end)(void *);
if (nledl->shared) {
nleshared_set_current(nledl->shared);
}
nledl->end(nledl->nle_ctx);

end = dlsym(nledl->dlhandle, "nle_end");
end(nledl->nle_ctx);
if (nledl->shared) {
nleshared_close(nledl->shared);
} else {
if (dlclose(nledl->dlhandle)) {
fprintf(stderr, "Error in dlclose: %s\n", dlerror());
exit(EXIT_FAILURE);
}

if (dlclose(nledl->dlhandle)) {
fprintf(stderr, "Error in dlclose: %s\n", dlerror());
exit(EXIT_FAILURE);
dlerror();
}

dlerror();
}

nle_ctx_t *
nle_start(const char *dlpath, nle_obs *obs, FILE *ttyrec,
nle_seeds_init_t *seed_init)
nle_seeds_init_t *seed_init, int shared)
{
/* TODO: Consider getting ttyrec path from caller? */
struct nledl_ctx *nledl = malloc(sizeof(struct nledl_ctx));
nledl->ttyrec = ttyrec;
strncpy(nledl->dlpath, dlpath, sizeof(nledl->dlpath));

nledl_init(nledl, obs, seed_init);
nledl_init(nledl, obs, seed_init, shared);
return nledl;
};

nle_ctx_t *
nle_step(nle_ctx_t *nledl, nle_obs *obs)
{
if (!nledl || !nledl->dlhandle || !nledl->nle_ctx) {
if (!nledl || (!nledl->dlhandle && !nledl->shared) || !nledl->nle_ctx) {
fprintf(stderr, "Illegal nledl_ctx\n");
exit(EXIT_FAILURE);
}

if (nledl->shared) {
nleshared_set_current(nledl->shared);
}
nledl->step(nledl->nle_ctx, obs);

return nledl;
Expand All @@ -85,14 +124,25 @@ void
nle_reset(nle_ctx_t *nledl, nle_obs *obs, FILE *ttyrec,
nle_seeds_init_t *seed_init)
{
nledl_close(nledl);
/* Reset file only if not-NULL. */
if (ttyrec)
nledl->ttyrec = ttyrec;

// TODO: Consider refactoring nledl.h such that we expose this init
// function but drop reset.
nledl_init(nledl, obs, seed_init);
if (nledl->shared) {
if (nledl->shared) {
nleshared_set_current(nledl->shared);
}
nledl->end(nledl->nle_ctx);
nleshared_reset(nledl->shared);
if (ttyrec)
nledl->ttyrec = ttyrec;
nledl->nle_ctx = nledl->start(obs, ttyrec, seed_init);
} else {
nledl_close(nledl);
/* Reset file only if not-NULL. */
if (ttyrec)
nledl->ttyrec = ttyrec;

// TODO: Consider refactoring nledl.h such that we expose this init
// function but drop reset.
nledl_init(nledl, obs, seed_init, 0);
}
}

void
Expand All @@ -108,13 +158,7 @@ nle_set_seed(nle_ctx_t *nledl, unsigned long core, unsigned long disp,
{
void (*set_seed)(void *, unsigned long, unsigned long, char);

set_seed = dlsym(nledl->dlhandle, "nle_set_seed");

char *error = dlerror();
if (error != NULL) {
fprintf(stderr, "%s\n", error);
exit(EXIT_FAILURE);
}
set_seed = sym(nledl, "nle_set_seed");

set_seed(nledl->nle_ctx, core, disp, reseed);
}
Expand All @@ -125,16 +169,16 @@ nle_get_seed(nle_ctx_t *nledl, unsigned long *core, unsigned long *disp,
{
void (*get_seed)(void *, unsigned long *, unsigned long *, char *);

get_seed = dlsym(nledl->dlhandle, "nle_get_seed");

char *error = dlerror();
if (error != NULL) {
fprintf(stderr, "%s\n", error);
exit(EXIT_FAILURE);
}
get_seed = sym(nledl, "nle_get_seed");

/* Careful here. NetHack has different ideas of what a boolean is
* than C++ (see global.h and SKIP_BOOLEAN). But one byte should be fine.
*/
get_seed(nledl->nle_ctx, core, disp, reseed);
}

int
nle_supports_shared(void)
{
return nleshared_supported();
}
Loading