Skip to content

Commit

Permalink
Add support for saving and loading simulation state to / from files (#…
Browse files Browse the repository at this point in the history
…1227)

* Factor out parts of simulate method

* Further refactoring of Simulation

* Add methods for saving and loading simulations

* Add initial test for simulation saving and loading

* Factor out and add additional simulation test checks

* Explicitly set logger output file when loading from pickle

Avoids deadlock on trying to acquire lock on loaded file handler

* Check next date in event queue before popping

Ensures event not lost in partial simulations

* Make pytest seed parameter session scoped

Allows use in fixtures with non-function scope

* Don't use next on counter in test check

Has side effect of mutating counter

* Refactor global constants to fixtures in simulation tests + additional tests

* Move logging configuration out of load_from_pickle

Better to be explicit

* Add test for exception when simulation past end date

* Add docstrings for new methods

* Add errors when running without initialising or initialising multiple times

* Add dill to dependencies

* Sort imports

* Fix fenceposting error in simulation end date

* Fix explicit comparison to type

* Add option to configure logging when loading from pickle

* Move check for open log file in close_output_file method

* Tidy up docstrings and type hints

* Remove use of configure_logging in test

* Update scenario to allow suspending and resuming

Co-authored-by: Asif Tamuri <[email protected]>

* Add utility function to merge log files

* Add test to check equality of parsed log files in suspend-resume

* Fix import sort order

* Update pinned dill version to 0.3.8

* Add log message when loading suspended simulation

* Add log message when saving suspended simulation

* Increase simulation pop size and duration in test

* Avoid reading in log files to be merged all at once

* Add tests for merge_log_files function

* Fix import order sorting

* Fix import order sorting (second attempt)

---------

Co-authored-by: Asif Tamuri <[email protected]>
  • Loading branch information
matt-graham and tamuri authored Sep 26, 2024
1 parent 5b4f328 commit a23e57d
Show file tree
Hide file tree
Showing 10 changed files with 825 additions and 172 deletions.
3 changes: 3 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@
'exclude-members': '__dict__, name, rng, sim' # , read_parameters',
}

# Include both class level and __init__ docstring content in class documentation
autoclass_content = 'both'

# The checker can't see private repos
linkcheck_ignore = ['^https://github.com/UCL/TLOmodel.*',
'https://www.who.int/bulletin/volumes/88/8/09-068213/en/nn']
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ dependencies = [
"azure-identity",
"azure-keyvault",
"azure-storage-file-share",
# For saving and loading simulation state
"dill",
]
description = "Thanzi la Onse Epidemiology Model"
dynamic = ["version"]
Expand Down
3 changes: 3 additions & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ cryptography==41.0.3
# pyjwt
cycler==0.11.0
# via matplotlib
dill==0.3.8
# via tlo (pyproject.toml)
et-xmlfile==1.1.0
# via openpyxl
fonttools==4.42.1
Expand Down Expand Up @@ -112,6 +114,7 @@ pyjwt[crypto]==2.8.0
# via
# adal
# msal
# pyjwt
pyparsing==3.1.1
# via matplotlib
pyshp==2.3.1
Expand Down
35 changes: 9 additions & 26 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile with Python 3.8
# This file is autogenerated by pip-compile with Python 3.11
# by the following command:
#
# pip-compile --extra=dev --output-file=requirements/dev.txt
Expand Down Expand Up @@ -61,7 +61,9 @@ colorama==0.4.6
contourpy==1.1.1
# via matplotlib
coverage[toml]==7.3.1
# via pytest-cov
# via
# coverage
# pytest-cov
cryptography==41.0.3
# via
# adal
Expand All @@ -72,14 +74,14 @@ cryptography==41.0.3
# pyjwt
cycler==0.11.0
# via matplotlib
dill==0.3.7
# via pylint
dill==0.3.8
# via
# pylint
# tlo (pyproject.toml)
distlib==0.3.7
# via virtualenv
et-xmlfile==1.1.0
# via openpyxl
exceptiongroup==1.1.3
# via pytest
execnet==2.0.2
# via pytest-xdist
filelock==3.12.4
Expand All @@ -94,10 +96,6 @@ gitpython==3.1.36
# via tlo (pyproject.toml)
idna==3.4
# via requests
importlib-metadata==6.8.0
# via build
importlib-resources==6.1.1
# via matplotlib
iniconfig==2.0.0
# via pytest
isodate==0.6.1
Expand Down Expand Up @@ -172,6 +170,7 @@ pyjwt[crypto]==2.8.0
# via
# adal
# msal
# pyjwt
pylint==3.0.1
# via tlo (pyproject.toml)
pyparsing==3.1.1
Expand Down Expand Up @@ -221,29 +220,17 @@ smmap==5.0.1
# via gitdb
squarify==0.4.3
# via tlo (pyproject.toml)
tomli==2.0.1
# via
# build
# coverage
# pip-tools
# pylint
# pyproject-api
# pyproject-hooks
# pytest
# tox
tomlkit==0.12.1
# via pylint
tox==4.11.3
# via tlo (pyproject.toml)
typing-extensions==4.8.0
# via
# astroid
# azure-core
# azure-keyvault-certificates
# azure-keyvault-keys
# azure-keyvault-secrets
# azure-storage-file-share
# pylint
tzdata==2023.3
# via pandas
urllib3==2.0.4
Expand All @@ -254,10 +241,6 @@ virtualenv==20.24.5
# tox
wheel==0.41.2
# via pip-tools
zipp==3.17.0
# via
# importlib-metadata
# importlib-resources

# The following packages are considered to be unsafe in a requirements file:
# pip
Expand Down
37 changes: 36 additions & 1 deletion src/tlo/analysis/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
General utility functions for TLO analysis
"""
import fileinput
import gzip
import json
import os
Expand Down Expand Up @@ -86,6 +87,40 @@ def parse_log_file(log_filepath, level: int = logging.INFO):
return LogsDict({name: handle.name for name, handle in module_name_to_filehandle.items()}, level)


def merge_log_files(log_path_1: Path, log_path_2: Path, output_path: Path) -> None:
    """Merge two log files, skipping any repeated header lines.

    :param log_path_1: Path to first log file to merge. Records from this log file will
        appear first in merged log file.
    :param log_path_2: Path to second log file to merge. Records from this log file will
        appear after those in log file at `log_path_1` and any header lines in this file
        which are also present in log file at `log_path_1` will be skipped.
    :param output_path: Path to write merged log file to. Must not be one of `log_path_1`
        or `log_path_2` as data is read from files while writing to this path.
    :raises ValueError: If `output_path` refers to the same file as either input path.
    :raises RuntimeError: If two header lines share a UUID but their content differs.
    """
    # Resolve paths so that equivalent spellings of the same file (for example a
    # relative versus an absolute path) are also rejected, not only paths that
    # compare equal textually - otherwise we would read and write the same file.
    if output_path.resolve() in {log_path_1.resolve(), log_path_2.resolve()}:
        msg = "output_path must not be equal to log_path_1 or log_path_2"
        raise ValueError(msg)
    with fileinput.input(files=(log_path_1, log_path_2), mode="r") as log_lines:
        with output_path.open("w") as output_file:
            # Map from header UUID to the exact header line already written to the
            # output; used both to skip duplicate headers and to detect headers
            # that share a UUID but disagree in content.
            written_header_lines = {}
            for log_line in log_lines:
                log_data = json.loads(log_line)
                if log_data.get("type") == "header":
                    previous_header_line = written_header_lines.get(log_data["uuid"])
                    if previous_header_line is not None:
                        if previous_header_line == log_line:
                            # Exact duplicate of a header already written - skip.
                            continue
                        msg = (
                            "Inconsistent header lines with matching UUIDs found when merging logs:\n"
                            f"{previous_header_line}\n{log_line}\n"
                        )
                        raise RuntimeError(msg)
                    written_header_lines[log_data["uuid"]] = log_line
                output_file.write(log_line)


def write_log_to_excel(filename, log_dataframes):
"""Takes the output of parse_log_file() and creates an Excel file from dataframes"""
metadata = list()
Expand Down Expand Up @@ -1131,7 +1166,7 @@ def get_parameters_for_status_quo() -> Dict:
"equip_availability": "all", # <--- NB. Existing calibration is assuming all equipment is available
},
}

def get_parameters_for_standard_mode2_runs() -> Dict:
"""
Returns a dictionary of parameters and their updated values to indicate
Expand Down
73 changes: 61 additions & 12 deletions src/tlo/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def draw_parameters(self, draw_number, rng):

from tlo import Date, Simulation, logging
from tlo.analysis.utils import parse_log_file
from tlo.util import str_to_pandas_date

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -141,6 +142,16 @@ def parse_arguments(self, extra_arguments: List[str]) -> None:
self.arguments = extra_arguments

parser = argparse.ArgumentParser()
parser.add_argument(
"--resume-simulation",
type=str,
help="Directory containing suspended state files to resume simulation from",
)
parser.add_argument(
"--suspend-date",
type=str_to_pandas_date,
help="Date to suspend the simulation at",
)

# add arguments from the subclass
self.add_arguments(parser)
Expand Down Expand Up @@ -382,20 +393,58 @@ def run_sample_by_number(self, output_directory, draw_number, sample_number):
sample = self.get_sample(draw, sample_number)
log_config = self.scenario.get_log_config(output_directory)

logger.info(key="message", data=f"Running draw {sample['draw_number']}, sample {sample['sample_number']}")

sim = Simulation(
start_date=self.scenario.start_date,
seed=sample["simulation_seed"],
log_config=log_config
logger.info(
key="message",
data=f"Running draw {sample['draw_number']}, sample {sample['sample_number']}",
)
sim.register(*self.scenario.modules())

if sample["parameters"] is not None:
self.override_parameters(sim, sample["parameters"])

sim.make_initial_population(n=self.scenario.pop_size)
sim.simulate(end_date=self.scenario.end_date)
# if user has specified a restore simulation, we load it from a pickle file
if (
hasattr(self.scenario, "resume_simulation")
and self.scenario.resume_simulation is not None
):
suspended_simulation_path = (
Path(self.scenario.resume_simulation)
/ str(draw_number)
/ str(sample_number)
/ "suspended_simulation.pickle"
)
logger.info(
key="message",
data=f"Loading pickled suspended simulation from {suspended_simulation_path}",
)
sim = Simulation.load_from_pickle(pickle_path=suspended_simulation_path, log_config=log_config)
else:
sim = Simulation(
start_date=self.scenario.start_date,
seed=sample["simulation_seed"],
log_config=log_config,
)
sim.register(*self.scenario.modules())

if sample["parameters"] is not None:
self.override_parameters(sim, sample["parameters"])

sim.make_initial_population(n=self.scenario.pop_size)
sim.initialise(end_date=self.scenario.end_date)

# if user has specified a suspend date, we run the simulation to that date and
# save it to a pickle file
if (
hasattr(self.scenario, "suspend_date")
and self.scenario.suspend_date is not None
):
sim.run_simulation_to(to_date=self.scenario.suspend_date)
suspended_simulation_path = Path(log_config["directory"]) / "suspended_simulation.pickle"
sim.save_to_pickle(pickle_path=suspended_simulation_path)
sim.close_output_file()
logger.info(
key="message",
data=f"Simulation suspended at {self.scenario.suspend_date} and saved to {suspended_simulation_path}",
)
else:
sim.run_simulation_to(to_date=self.scenario.end_date)
sim.finalise()

if sim.log_filepath is not None:
outputs = parse_log_file(sim.log_filepath)
Expand Down
Loading

0 comments on commit a23e57d

Please sign in to comment.