Skip to content

Commit b04e9f4

Browse files
committed
change saving and loading
1 parent 81fe4b0 commit b04e9f4

File tree

4 files changed

+93
-52
lines changed

4 files changed

+93
-52
lines changed

nle/env/base.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,6 @@ def __init__(
193193
allow_all_yn_questions=False,
194194
allow_all_modes=False,
195195
spawn_monsters=True,
196-
gamesavedir=None,
197-
gameloaddir=None,
198196
):
199197
"""Constructs a new NLE environment.
200198
@@ -310,8 +308,6 @@ def __init__(
310308
wizard=wizard,
311309
spawn_monsters=spawn_monsters,
312310
scoreprefix=scoreprefix,
313-
gamesavedir=gamesavedir,
314-
gameloaddir=gameloaddir,
315311
)
316312
self._close_nethack = weakref.finalize(self, self.nethack.close)
317313

@@ -521,8 +517,25 @@ def render(self, mode="human"):
521517

522518
return super().render(mode=mode)
523519

524-
def save(self, gamesavedir=None):
525-
return self.nethack.save(gamesavedir=gamesavedir)
520+
def save(self, gamesavedir):
521+
# save calls function `dosave0` (responsible for saving the game) and copies save to `gamesavedir`
522+
# we have to call this function explicitly because action combination ord(S) + ord(y) are overriden
523+
self.nethack.save(gamesavedir)
524+
525+
# we use reset to turn on the game, since saving turns it off
526+
# TODO: can we dump save files without turning off the game???
527+
self.last_observation = self.nethack.reset(self.ttyrec)
528+
529+
# get out of the menu after we load new game
530+
# TODO: we have one more problem, after resetting the game we have left all open menus (like inventory)
531+
# additionally we see new message "Welcome back to NetHack"
532+
self.nethack.step(ASCII_ESC)
533+
534+
def load(self, gameloaddir):
535+
self.last_observation = self.nethack.load(gameloaddir)
536+
537+
# get out of the menu after we load new game
538+
self.nethack.step(ASCII_ESC)
526539

527540
def __repr__(self):
528541
return "<%s>" % self.__class__.__name__

nle/nethack/nethack.py

Lines changed: 29 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,7 @@ def __init__(
166166
hackdir=HACKDIR,
167167
spawn_monsters=True,
168168
scoreprefix="",
169-
gamesavedir=None,
170-
gameloaddir=None,
171169
):
172-
self.gamesavedir = gamesavedir
173-
self.gameloaddir = gameloaddir
174170
self._copy = copy
175171

176172
if not os.path.exists(hackdir) or not os.path.exists(
@@ -184,31 +180,24 @@ def __init__(
184180
self._tempdir = tempfile.TemporaryDirectory(prefix="nle")
185181
self._vardir = self._tempdir.name
186182

187-
if self.gameloaddir:
188-
# restore files (save) from directory
189-
shutil.copytree(self.gameloaddir, self._vardir, dirs_exist_ok=True)
183+
# Symlink a nhdat.
184+
os.symlink(os.path.join(hackdir, "nhdat"), os.path.join(self._vardir, "nhdat"))
190185

191-
self.dlpath = os.path.join(self._vardir, "libnethack.so")
192-
self._dl = open(self.dlpath, "r")
186+
# Touch files, so lock_file() in files.c passes.
187+
for fn in ["perm", "record", "logfile"]:
188+
os.close(os.open(os.path.join(self._vardir, fn), os.O_CREAT))
189+
if scoreprefix:
190+
os.close(os.open(scoreprefix + "xlogfile", os.O_CREAT))
193191
else:
194-
# Symlink a nhdat.
195-
os.symlink(os.path.join(hackdir, "nhdat"), os.path.join(self._vardir, "nhdat"))
196-
197-
# Touch files, so lock_file() in files.c passes.
198-
for fn in ["perm", "record", "logfile"]:
199-
os.close(os.open(os.path.join(self._vardir, fn), os.O_CREAT))
200-
if scoreprefix:
201-
os.close(os.open(scoreprefix + "xlogfile", os.O_CREAT))
202-
else:
203-
os.close(os.open(os.path.join(self._vardir, "xlogfile"), os.O_CREAT))
192+
os.close(os.open(os.path.join(self._vardir, "xlogfile"), os.O_CREAT))
204193

205-
os.mkdir(os.path.join(self._vardir, "save"))
194+
os.mkdir(os.path.join(self._vardir, "save"))
206195

207-
# An assortment of hacks:
208-
# Copy our .so into self._vardir to load several copies of the dl.
209-
# (Or use a memfd_create hack to create a file that gets deleted on
210-
# process exit.)
211-
self._dl, self.dlpath = _new_dl(self._vardir)
196+
# An assortment of hacks:
197+
# Copy our .so into self._vardir to load several copies of the dl.
198+
# (Or use a memfd_create hack to create a file that gets deleted on
199+
# process exit.)
200+
self._dl, self.dlpath = _new_dl(self._vardir)
212201

213202
# Finalize even when the rest of this constructor fails.
214203
self._finalizer = weakref.finalize(self, _close, None, self._dl, self._tempdir)
@@ -332,14 +321,19 @@ def in_normal_game(self):
332321
def how_done(self):
333322
return self._pynethack.how_done()
334323

335-
def save(self, gamesavedir=None):
336-
if gamesavedir:
337-
savedir = gamesavedir
338-
else:
339-
savedir = self.gamesavedir
340-
341-
assert savedir is not None
324+
def save(self, gamesavedir):
325+
assert gamesavedir is not None
326+
327+
self._pynethack.save()
328+
shutil.copytree(self._vardir, gamesavedir, dirs_exist_ok=True)
329+
self._pynethack.set_use_seed_init(True)
330+
331+
def load(self, gameloaddir):
332+
assert gameloaddir is not None
333+
334+
self._pynethack.end()
335+
shutil.copytree(gameloaddir, self._vardir, dirs_exist_ok=True)
336+
self._pynethack.set_use_seed_init(True)
337+
self._pynethack.start()
342338

343-
success = self._pynethack.save()
344-
shutil.copytree(self._vardir, savedir, dirs_exist_ok=True)
345-
return success
339+
return self._step_return()

sys/unix/nledl.c

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,7 @@ int
156156
nle_save(nledl_ctx *nledl)
157157
{
158158
int success;
159-
void *(*dosave0)();
160-
159+
int *(*dosave0)();
161160
dosave0 = dlsym(nledl->dlhandle, "dosave0");
162161
success = dosave0();
163162

win/rl/pynethack.cc

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,6 @@ checked_conversion(py::handle h, const std::vector<ssize_t> &shape)
101101
class Nethack
102102
{
103103
public:
104-
int save() {
105-
return nle_save(nle_);
106-
}
107-
108104
Nethack(std::string dlpath, std::string ttyrec, std::string hackdir,
109105
std::string nethackoptions, bool spawn_monsters,
110106
std::string scoreprefix)
@@ -205,6 +201,42 @@ class Nethack
205201
ttyrec_ = f;
206202
}
207203

204+
int save()
205+
{
206+
int success;
207+
success = nle_save(nle_);
208+
return success;
209+
}
210+
211+
void
212+
set_use_seed_init(bool use_seed_init)
213+
{
214+
use_seed_init_ = use_seed_init;
215+
}
216+
217+
void
218+
end()
219+
{
220+
// We use this function for loading the game from save
221+
// 1. end()
222+
// 2. we copy save to hackdir in `Nethack`
223+
// 3. start()
224+
if (nle_) {
225+
nle_end(nle_);
226+
nle_ = nullptr;
227+
}
228+
}
229+
230+
void
231+
start()
232+
{
233+
py::gil_scoped_release gil;
234+
235+
nle_ =
236+
nle_start(dlpath_.c_str(), &obs_, ttyrec_,
237+
use_seed_init_ ? &seed_init_ : nullptr, &settings_);
238+
}
239+
208240
void
209241
set_buffers(py::object glyphs, py::object chars, py::object colors,
210242
py::object specials, py::object blstats, py::object message,
@@ -282,7 +314,7 @@ class Nethack
282314
seed_init_.seeds[0] = core;
283315
seed_init_.seeds[1] = disp;
284316
seed_init_.reseed = reseed;
285-
use_seed_init = true;
317+
use_seed_init_ = true;
286318
#else
287319
throw std::runtime_error("Seeding not enabled");
288320
#endif
@@ -351,11 +383,11 @@ class Nethack
351383
if (!nle_) {
352384
nle_ =
353385
nle_start(dlpath_.c_str(), &obs_, ttyrec ? ttyrec : ttyrec_,
354-
use_seed_init ? &seed_init_ : nullptr, &settings_);
386+
use_seed_init_ ? &seed_init_ : nullptr, &settings_);
355387
} else
356388
nle_reset(nle_, &obs_, ttyrec,
357-
use_seed_init ? &seed_init_ : nullptr, &settings_);
358-
use_seed_init = false;
389+
use_seed_init_ ? &seed_init_ : nullptr, &settings_);
390+
use_seed_init_ = false;
359391

360392
if (obs_.done)
361393
throw std::runtime_error("NetHack done right after reset");
@@ -365,7 +397,7 @@ class Nethack
365397
nle_obs obs_;
366398
std::vector<py::object> py_buffers_;
367399
nle_seeds_init_t seed_init_;
368-
bool use_seed_init = false;
400+
bool use_seed_init_ = false;
369401
nledl_ctx *nle_ = nullptr;
370402
std::FILE *ttyrec_ = nullptr;
371403
nle_settings settings_;
@@ -409,6 +441,9 @@ PYBIND11_MODULE(_pynethack, m)
409441
.def("in_normal_game", &Nethack::in_normal_game)
410442
.def("how_done", &Nethack::how_done)
411443
.def("save", &Nethack::save)
444+
.def("start", &Nethack::start)
445+
.def("end", &Nethack::end)
446+
.def("set_use_seed_init", &Nethack::set_use_seed_init, py::arg("use_seed_init"))
412447
.def("set_wizkit", &Nethack::set_wizkit);
413448

414449
py::module mn = m.def_submodule(

0 commit comments

Comments
 (0)