Skip to content

Commit d4c0f0b

Browse files
committed
add saving and loading the game with simple test
1 parent 2dba5de commit d4c0f0b

File tree

6 files changed

+93
-19
lines changed

6 files changed

+93
-19
lines changed

include/nledl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,6 @@ void nle_end(nledl_ctx *);
2828
void nle_set_seed(nledl_ctx *, unsigned long, unsigned long, char);
2929
void nle_get_seed(nledl_ctx *, unsigned long *, unsigned long *, char *);
3030

31+
int nle_save(nledl_ctx *);
32+
3133
#endif /* NLEDL_H */

nle/env/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@ def __init__(
198198
allow_all_modes=False,
199199
spawn_monsters=True,
200200
render_mode="human",
201+
gamesavedir=None,
202+
gameloaddir=None,
201203
):
202204
"""Constructs a new NLE environment.
203205
@@ -317,6 +319,8 @@ def __init__(
317319
wizard=wizard,
318320
spawn_monsters=spawn_monsters,
319321
scoreprefix=scoreprefix,
322+
gamesavedir=gamesavedir,
323+
gameloaddir=gameloaddir,
320324
)
321325
self._close_nethack = weakref.finalize(self, self.nethack.close)
322326

@@ -545,6 +549,9 @@ def render(self):
545549

546550
return "\nInvalid render mode: " + mode
547551

552+
def save(self, gamesavedir=None):
    """Save the current game by delegating to the wrapped NetHack object.

    Args:
        gamesavedir: Optional save directory, forwarded unchanged as the
            ``gamesavedir`` keyword argument of ``self.nethack.save``.

    Returns:
        Whatever ``self.nethack.save`` returns.
    """
    backend = self.nethack
    return backend.save(gamesavedir=gamesavedir)
554+
548555
def __repr__(self):
549556
return "<%s>" % self.__class__.__name__
550557

nle/nethack/nethack.py

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,20 @@ def _new_dl_linux(vardir):
9696

9797
def _new_dl(vardir):
9898
"""Creates a copied .so file to allow for multiple independent NLE instances"""
99-
if sys.platform == "linux":
100-
return _new_dl_linux(vardir)
99+
# if sys.platform == "linux":
100+
# return _new_dl_linux(vardir)
101101

102102
# MacOS has no memfd_create or O_TMPFILE. Using /dev/fd/{FD} as an argument
103103
# to dlopen doesn't work after unlinking from the file system. So let's copy
104104
# instead and hope vardir gets properly deleted at some point.
105-
dl = tempfile.NamedTemporaryFile(suffix="libnethack.so", dir=vardir)
105+
# dl = tempfile.NamedTemporaryFile(suffix="libnethack.so", dir=vardir)
106+
# shutil.copyfile(DLPATH, dl.name) # Might use fcopyfile.
107+
# return dl, dl.name
108+
109+
dlpath = os.path.join(vardir, "libnethack.so")
110+
dl = open(dlpath, "w")
106111
shutil.copyfile(DLPATH, dl.name) # Might use fcopyfile.
107-
return dl, dl.name
112+
return dl, dlpath
108113

109114

110115
def _close(pynethack, dl, tempdir, warn=True):
@@ -168,7 +173,11 @@ def __init__(
168173
hackdir=HACKDIR,
169174
spawn_monsters=True,
170175
scoreprefix="",
176+
gamesavedir=None,
177+
gameloaddir=None,
171178
):
179+
self.gamesavedir = gamesavedir
180+
self.gameloaddir = gameloaddir
172181
self._copy = copy
173182

174183
if not os.path.exists(hackdir) or not os.path.exists(
@@ -182,24 +191,31 @@ def __init__(
182191
self._tempdir = tempfile.TemporaryDirectory(prefix="nle")
183192
self._vardir = self._tempdir.name
184193

185-
# Symlink a nhdat.
186-
os.symlink(os.path.join(hackdir, "nhdat"), os.path.join(self._vardir, "nhdat"))
194+
if self.gameloaddir:
195+
# restore files (save) from directory
196+
shutil.copytree(self.gameloaddir, self._vardir, dirs_exist_ok=True)
187197

188-
# Touch files, so lock_file() in files.c passes.
189-
for fn in ["perm", "record", "logfile"]:
190-
os.close(os.open(os.path.join(self._vardir, fn), os.O_CREAT))
191-
if scoreprefix:
192-
os.close(os.open(scoreprefix + "xlogfile", os.O_CREAT))
198+
self.dlpath = os.path.join(self._vardir, "libnethack.so")
199+
self._dl = open(self.dlpath, "r")
193200
else:
194-
os.close(os.open(os.path.join(self._vardir, "xlogfile"), os.O_CREAT))
201+
# Symlink a nhdat.
202+
os.symlink(os.path.join(hackdir, "nhdat"), os.path.join(self._vardir, "nhdat"))
203+
204+
# Touch files, so lock_file() in files.c passes.
205+
for fn in ["perm", "record", "logfile"]:
206+
os.close(os.open(os.path.join(self._vardir, fn), os.O_CREAT))
207+
if scoreprefix:
208+
os.close(os.open(scoreprefix + "xlogfile", os.O_CREAT))
209+
else:
210+
os.close(os.open(os.path.join(self._vardir, "xlogfile"), os.O_CREAT))
195211

196-
os.mkdir(os.path.join(self._vardir, "save"))
212+
os.mkdir(os.path.join(self._vardir, "save"))
197213

198-
# An assortment of hacks:
199-
# Copy our .so into self._vardir to load several copies of the dl.
200-
# (Or use a memfd_create hack to create a file that gets deleted on
201-
# process exit.)
202-
self._dl, self.dlpath = _new_dl(self._vardir)
214+
# An assortment of hacks:
215+
# Copy our .so into self._vardir to load several copies of the dl.
216+
# (Or use a memfd_create hack to create a file that gets deleted on
217+
# process exit.)
218+
self._dl, self.dlpath = _new_dl(self._vardir)
203219

204220
# Finalize even when the rest of this constructor fails.
205221
self._finalizer = weakref.finalize(self, _close, None, self._dl, self._tempdir)
@@ -323,3 +339,15 @@ def in_normal_game(self):
323339

324340
def how_done(self):
325341
return self._pynethack.how_done()
342+
343+
def save(self, gamesavedir=None):
    """Save the running game and copy its working directory to a save directory.

    Args:
        gamesavedir: Directory that receives a copy of ``self._vardir``.
            When falsy, ``self.gamesavedir`` (set in the constructor) is
            used instead.

    Returns:
        The value returned by ``self._pynethack.save()``.

    Raises:
        ValueError: If no save directory was given here and
            ``self.gamesavedir`` is ``None``.
    """
    savedir = gamesavedir if gamesavedir else self.gamesavedir

    if savedir is None:
        # Raise explicitly instead of using ``assert``: assertions are
        # stripped under ``python -O``, which would let the game be saved
        # and only then fail inside ``copytree`` with an unrelated error.
        raise ValueError(
            "No save directory: pass gamesavedir to save() or to the constructor."
        )

    success = self._pynethack.save()
    # Copy the whole working directory into the save directory; existing
    # directories are merged rather than rejected.
    shutil.copytree(self._vardir, savedir, dirs_exist_ok=True)
    return success

nle/tests/test_envs.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,26 @@ def test_render_ansi(self, env_name, rollout_len):
340340
assert isinstance(output, str)
341341
assert len(output.replace("\n", "")) == np.prod(nle.env.DUNGEON_SHAPE)
342342

343-
343+
def test_save_and_load(self, env_name, rollout_len):
    """Roll out a game, save it, and compare against a reloaded environment.

    An environment is created with a temporary ``gamesavedir``, stepped
    with random actions for ``rollout_len`` steps (resetting whenever an
    episode ends), and saved. A second environment is then created with
    ``gameloaddir`` pointing at the same directory; the observation from
    its first ``reset()`` must equal the last observation seen before
    saving in both the ``blstats`` and ``glyphs`` entries.
    """
    with tempfile.TemporaryDirectory() as gamesavedir:
        env = gym.make(env_name, gamesavedir=gamesavedir)

        obs = env.reset()
        for _step in range(rollout_len):
            obs, _, done, _ = env.step(env.action_space.sample())
            if done:
                obs = env.reset()

        env.save()

        # Deliberately rebind ``env`` so the first environment is no longer
        # referenced once the loaded one exists.
        env = gym.make(env_name, gameloaddir=gamesavedir)
        obsload = env.reset()

        for key in ("blstats", "glyphs"):
            assert (obsload[key] == obs[key]).all()
361+
362+
344363
class TestGymDynamics:
345364
"""Tests a few game dynamics."""
346365

sys/unix/nledl.c

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,16 @@ nle_get_seed(nledl_ctx *nledl, unsigned long *core, unsigned long *disp,
150150
get_seed(nledl->nle_ctx, core, disp, reseed);
151151
}
152152
#endif
153+
154+
155+
int
156+
nle_save(nledl_ctx *nledl)
157+
{
158+
int success;
159+
void *(*dosave0)();
160+
161+
dosave0 = dlsym(nledl->dlhandle, "dosave0");
162+
success = dosave0();
163+
164+
return success;
165+
}

win/rl/pynethack.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ checked_conversion(py::handle h, const std::vector<ssize_t> &shape)
101101
class Nethack
102102
{
103103
public:
104+
int save() {
105+
return nle_save(nle_);
106+
}
107+
104108
Nethack(std::string dlpath, std::string ttyrec, std::string hackdir,
105109
std::string nethackoptions, bool spawn_monsters,
106110
std::string scoreprefix)
@@ -404,6 +408,7 @@ PYBIND11_MODULE(_pynethack, m)
404408
.def("get_seeds", &Nethack::get_seeds)
405409
.def("in_normal_game", &Nethack::in_normal_game)
406410
.def("how_done", &Nethack::how_done)
411+
.def("save", &Nethack::save)
407412
.def("set_wizkit", &Nethack::set_wizkit);
408413

409414
py::module mn = m.def_submodule(

0 commit comments

Comments
 (0)