Skip to content

Commit a23e57d

Browse files
matt-graham and tamuri
authored
Add support for saving and loading simulation state to / from files (#1227)
* Factor out parts of simulate method * Further refactoring of Simulation * Add methods for saving and loading simulations * Add initial test for simulation saving and loading * Factor out and add additional simulation test checks * Explicitly set logger output file when loading from pickle Avoids deadlock on trying to acquire lock on loaded file handler * Check next date in event queue before popping Ensures event not lost in partial simulations * Make pytest seed parameter session scoped Allows use in fixtures with non-function scope * Don't use next on counter in test check Has side effect of mutating counter * Refactor global constants to fixtures in simulation tests + additional tests * Move logging configuration out of load_from_pickle Better to be explicit * Add test for exception when simulation past end date * Add docstrings for new methods * Add errors when running without initialising or initialising multiple times * Add dill to dependencies * Sort imports * Fix fenceposting error in simulation end date * Fix explicit comparison to type * Add option to configure logging when loading from pickle * Move check for open log file in close_output_file method * Tidy up docstrings and type hints * Remove use of configure_logging in test * Update scenario to allow suspending and resuming Co-authored-by: Asif Tamuri <[email protected]> * Add utility function to merge log files * Add test to check equality of parsed log files in suspend-resume * Fix import sort order * Update pinned dill version to 0.3.8 * Adding log message when loading suspended simulation * Adding log message when saving suspended simulation * Increase simulation pop size and duration in test * Avoid reading in log files to be merged all at once * Add tests for merge_log_files function * Fix import order sorting * Fix import order sorting (second attempt) --------- Co-authored-by: Asif Tamuri <[email protected]>
1 parent 5b4f328 commit a23e57d

File tree

10 files changed

+825
-172
lines changed

10 files changed

+825
-172
lines changed

docs/conf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@
106106
'exclude-members': '__dict__, name, rng, sim' # , read_parameters',
107107
}
108108

109+
# Include both class level and __init__ docstring content in class documentation
110+
autoclass_content = 'both'
111+
109112
# The checker can't see private repos
110113
linkcheck_ignore = ['^https://github.com/UCL/TLOmodel.*',
111114
'https://www.who.int/bulletin/volumes/88/8/09-068213/en/nn']

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ dependencies = [
3333
"azure-identity",
3434
"azure-keyvault",
3535
"azure-storage-file-share",
36+
# For saving and loading simulation state
37+
"dill",
3638
]
3739
description = "Thanzi la Onse Epidemiology Model"
3840
dynamic = ["version"]

requirements/base.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ cryptography==41.0.3
5656
# pyjwt
5757
cycler==0.11.0
5858
# via matplotlib
59+
dill==0.3.8
60+
# via tlo (pyproject.toml)
5961
et-xmlfile==1.1.0
6062
# via openpyxl
6163
fonttools==4.42.1
@@ -112,6 +114,7 @@ pyjwt[crypto]==2.8.0
112114
# via
113115
# adal
114116
# msal
117+
# pyjwt
115118
pyparsing==3.1.1
116119
# via matplotlib
117120
pyshp==2.3.1

requirements/dev.txt

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#
2-
# This file is autogenerated by pip-compile with Python 3.8
2+
# This file is autogenerated by pip-compile with Python 3.11
33
# by the following command:
44
#
55
# pip-compile --extra=dev --output-file=requirements/dev.txt
@@ -61,7 +61,9 @@ colorama==0.4.6
6161
contourpy==1.1.1
6262
# via matplotlib
6363
coverage[toml]==7.3.1
64-
# via pytest-cov
64+
# via
65+
# coverage
66+
# pytest-cov
6567
cryptography==41.0.3
6668
# via
6769
# adal
@@ -72,14 +74,14 @@ cryptography==41.0.3
7274
# pyjwt
7375
cycler==0.11.0
7476
# via matplotlib
75-
dill==0.3.7
76-
# via pylint
77+
dill==0.3.8
78+
# via
79+
# pylint
80+
# tlo (pyproject.toml)
7781
distlib==0.3.7
7882
# via virtualenv
7983
et-xmlfile==1.1.0
8084
# via openpyxl
81-
exceptiongroup==1.1.3
82-
# via pytest
8385
execnet==2.0.2
8486
# via pytest-xdist
8587
filelock==3.12.4
@@ -94,10 +96,6 @@ gitpython==3.1.36
9496
# via tlo (pyproject.toml)
9597
idna==3.4
9698
# via requests
97-
importlib-metadata==6.8.0
98-
# via build
99-
importlib-resources==6.1.1
100-
# via matplotlib
10199
iniconfig==2.0.0
102100
# via pytest
103101
isodate==0.6.1
@@ -172,6 +170,7 @@ pyjwt[crypto]==2.8.0
172170
# via
173171
# adal
174172
# msal
173+
# pyjwt
175174
pylint==3.0.1
176175
# via tlo (pyproject.toml)
177176
pyparsing==3.1.1
@@ -221,29 +220,17 @@ smmap==5.0.1
221220
# via gitdb
222221
squarify==0.4.3
223222
# via tlo (pyproject.toml)
224-
tomli==2.0.1
225-
# via
226-
# build
227-
# coverage
228-
# pip-tools
229-
# pylint
230-
# pyproject-api
231-
# pyproject-hooks
232-
# pytest
233-
# tox
234223
tomlkit==0.12.1
235224
# via pylint
236225
tox==4.11.3
237226
# via tlo (pyproject.toml)
238227
typing-extensions==4.8.0
239228
# via
240-
# astroid
241229
# azure-core
242230
# azure-keyvault-certificates
243231
# azure-keyvault-keys
244232
# azure-keyvault-secrets
245233
# azure-storage-file-share
246-
# pylint
247234
tzdata==2023.3
248235
# via pandas
249236
urllib3==2.0.4
@@ -254,10 +241,6 @@ virtualenv==20.24.5
254241
# tox
255242
wheel==0.41.2
256243
# via pip-tools
257-
zipp==3.17.0
258-
# via
259-
# importlib-metadata
260-
# importlib-resources
261244

262245
# The following packages are considered to be unsafe in a requirements file:
263246
# pip

src/tlo/analysis/utils.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
General utility functions for TLO analysis
33
"""
4+
import fileinput
45
import gzip
56
import json
67
import os
@@ -86,6 +87,40 @@ def parse_log_file(log_filepath, level: int = logging.INFO):
8687
return LogsDict({name: handle.name for name, handle in module_name_to_filehandle.items()}, level)
8788

8889

90+
def merge_log_files(log_path_1: Path, log_path_2: Path, output_path: Path) -> None:
91+
"""Merge two log files, skipping any repeated header lines.
92+
93+
:param log_path_1: Path to first log file to merge. Records from this log file will
94+
appear first in merged log file.
95+
:param log_path_2: Path to second log file to merge. Records from this log file will
96+
appear after those in log file at `log_path_1` and any header lines in this file
97+
which are also present in log file at `log_path_1` will be skipped.
98+
:param output_path: Path to write merged log file to. Must not be one of `log_path_1`
99+
or `log_path_2` as data is read from files while writing to this path.
100+
"""
101+
if output_path == log_path_1 or output_path == log_path_2:
102+
msg = "output_path must not be equal to log_path_1 or log_path_2"
103+
raise ValueError(msg)
104+
with fileinput.input(files=(log_path_1, log_path_2), mode="r") as log_lines:
105+
with output_path.open("w") as output_file:
106+
written_header_lines = {}
107+
for log_line in log_lines:
108+
log_data = json.loads(log_line)
109+
if "type" in log_data and log_data["type"] == "header":
110+
if log_data["uuid"] in written_header_lines:
111+
previous_header_line = written_header_lines[log_data["uuid"]]
112+
if previous_header_line == log_line:
113+
continue
114+
else:
115+
msg = (
116+
"Inconsistent header lines with matching UUIDs found when merging logs:\n"
117+
f"{previous_header_line}\n{log_line}\n"
118+
)
119+
raise RuntimeError(msg)
120+
written_header_lines[log_data["uuid"]] = log_line
121+
output_file.write(log_line)
122+
123+
89124
def write_log_to_excel(filename, log_dataframes):
90125
"""Takes the output of parse_log_file() and creates an Excel file from dataframes"""
91126
metadata = list()
@@ -1131,7 +1166,7 @@ def get_parameters_for_status_quo() -> Dict:
11311166
"equip_availability": "all", # <--- NB. Existing calibration is assuming all equipment is available
11321167
},
11331168
}
1134-
1169+
11351170
def get_parameters_for_standard_mode2_runs() -> Dict:
11361171
"""
11371172
Returns a dictionary of parameters and their updated values to indicate

src/tlo/scenario.py

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def draw_parameters(self, draw_number, rng):
7373

7474
from tlo import Date, Simulation, logging
7575
from tlo.analysis.utils import parse_log_file
76+
from tlo.util import str_to_pandas_date
7677

7778
logger = logging.getLogger(__name__)
7879
logger.setLevel(logging.INFO)
@@ -141,6 +142,16 @@ def parse_arguments(self, extra_arguments: List[str]) -> None:
141142
self.arguments = extra_arguments
142143

143144
parser = argparse.ArgumentParser()
145+
parser.add_argument(
146+
"--resume-simulation",
147+
type=str,
148+
help="Directory containing suspended state files to resume simulation from",
149+
)
150+
parser.add_argument(
151+
"--suspend-date",
152+
type=str_to_pandas_date,
153+
help="Date to suspend the simulation at",
154+
)
144155

145156
# add arguments from the subclass
146157
self.add_arguments(parser)
@@ -382,20 +393,58 @@ def run_sample_by_number(self, output_directory, draw_number, sample_number):
382393
sample = self.get_sample(draw, sample_number)
383394
log_config = self.scenario.get_log_config(output_directory)
384395

385-
logger.info(key="message", data=f"Running draw {sample['draw_number']}, sample {sample['sample_number']}")
386-
387-
sim = Simulation(
388-
start_date=self.scenario.start_date,
389-
seed=sample["simulation_seed"],
390-
log_config=log_config
396+
logger.info(
397+
key="message",
398+
data=f"Running draw {sample['draw_number']}, sample {sample['sample_number']}",
391399
)
392-
sim.register(*self.scenario.modules())
393400

394-
if sample["parameters"] is not None:
395-
self.override_parameters(sim, sample["parameters"])
396-
397-
sim.make_initial_population(n=self.scenario.pop_size)
398-
sim.simulate(end_date=self.scenario.end_date)
401+
# if user has specified a restore simulation, we load it from a pickle file
402+
if (
403+
hasattr(self.scenario, "resume_simulation")
404+
and self.scenario.resume_simulation is not None
405+
):
406+
suspended_simulation_path = (
407+
Path(self.scenario.resume_simulation)
408+
/ str(draw_number)
409+
/ str(sample_number)
410+
/ "suspended_simulation.pickle"
411+
)
412+
logger.info(
413+
key="message",
414+
data=f"Loading pickled suspended simulation from {suspended_simulation_path}",
415+
)
416+
sim = Simulation.load_from_pickle(pickle_path=suspended_simulation_path, log_config=log_config)
417+
else:
418+
sim = Simulation(
419+
start_date=self.scenario.start_date,
420+
seed=sample["simulation_seed"],
421+
log_config=log_config,
422+
)
423+
sim.register(*self.scenario.modules())
424+
425+
if sample["parameters"] is not None:
426+
self.override_parameters(sim, sample["parameters"])
427+
428+
sim.make_initial_population(n=self.scenario.pop_size)
429+
sim.initialise(end_date=self.scenario.end_date)
430+
431+
# if user has specified a suspend date, we run the simulation to that date and
432+
# save it to a pickle file
433+
if (
434+
hasattr(self.scenario, "suspend_date")
435+
and self.scenario.suspend_date is not None
436+
):
437+
sim.run_simulation_to(to_date=self.scenario.suspend_date)
438+
suspended_simulation_path = Path(log_config["directory"]) / "suspended_simulation.pickle"
439+
sim.save_to_pickle(pickle_path=suspended_simulation_path)
440+
sim.close_output_file()
441+
logger.info(
442+
key="message",
443+
data=f"Simulation suspended at {self.scenario.suspend_date} and saved to {suspended_simulation_path}",
444+
)
445+
else:
446+
sim.run_simulation_to(to_date=self.scenario.end_date)
447+
sim.finalise()
399448

400449
if sim.log_filepath is not None:
401450
outputs = parse_log_file(sim.log_filepath)

0 commit comments

Comments (0)