diff --git a/src/scripts/profiling/scale_run.py b/src/scripts/profiling/scale_run.py
index 735d1e7ba3..c0409a88ce 100644
--- a/src/scripts/profiling/scale_run.py
+++ b/src/scripts/profiling/scale_run.py
@@ -15,6 +15,7 @@
 from tlo import Date, Simulation, logging
 from tlo.analysis.utils import parse_log_file as parse_log_file_fn
 from tlo.methods.fullmodel import fullmodel
+from tlo.threaded_simulation import ThreadedSimulation
 
 _TLO_ROOT: Path = Path(__file__).parents[3].resolve()
 _TLO_OUTPUT_DIR: Path = (_TLO_ROOT / "outputs").resolve()
@@ -55,6 +56,7 @@ def scale_run(
     ignore_warnings: bool = False,
     log_final_population_checksum: bool = True,
     profiler: Optional["Profiler"] = None,
+    n_threads: int = 0,
 ) -> Simulation:
     if ignore_warnings:
         warnings.filterwarnings("ignore")
@@ -74,12 +76,16 @@
         "suppress_stdout": disable_log_output_to_stdout,
     }
 
-    sim = Simulation(
-        start_date=start_date,
-        seed=seed,
-        log_config=log_config,
-        show_progress_bar=show_progress_bar,
-    )
+    sim_args = {
+        "start_date": start_date,
+        "seed": seed,
+        "log_config": log_config,
+        "show_progress_bar": show_progress_bar,
+    }
+    if n_threads:
+        sim = ThreadedSimulation(n_threads=n_threads, **sim_args)
+    else:
+        sim = Simulation(**sim_args)
 
     # Register the appropriate modules with the arguments passed through
     sim.register(
@@ -269,6 +275,12 @@
         ),
         action="store_true",
     )
+    parser.add_argument(
+        "--n-threads",
+        help="Run a threaded simulation using the given number of worker threads (the default of 0 runs the standard serial simulation)",
+        type=int,
+        default=0,
+    )
 
     args = parser.parse_args()
     args_dict = vars(args)
diff --git a/src/tlo/simulation.py b/src/tlo/simulation.py
index 219b1b8a6f..d5b863352b 100644
--- a/src/tlo/simulation.py
+++ b/src/tlo/simulation.py
@@ -1,4 +1,5 @@
 """The main simulation controller."""
+from __future__ import annotations
 
 import datetime
 import heapq
@@ -6,7 +7,7 @@
 import time
 from collections import OrderedDict
 from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
 
 import numpy as np
 
@@ -15,11 +16,14 @@
 from tlo.events import Event, IndividualScopeEventMixin
 from tlo.progressbar import ProgressBar
 
+if TYPE_CHECKING:
+    from tlo import Module
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
 
-class Simulation:
+class _BaseSimulation:
     """The main control centre for a simulation.
 
     This class contains the core simulation logic and event queue, and holds
@@ -41,8 +45,14 @@ class Simulation:
         The simulation-level random number generator.
         Note that individual modules also have their own random number generator
         with independent state.
+
+    The `step_through_events` method, which controls how the simulation's
+    events are fired, is implemented by the `Simulation` and
+    `ThreadedSimulation` subclasses.
     """
 
+    modules: OrderedDict[str, Module]
+
     def __init__(self, *, start_date: Date, seed: int = None, log_config: dict = None,
                  show_progress_bar=False):
         """Create a new simulation.
@@ -63,6 +74,7 @@ def __init__(self, *, start_date: Date, seed: int = None, log_config: dict = None,
         self.population: Optional[Population] = None
 
         self.show_progress_bar = show_progress_bar
+        self.progress_bar = None
 
         # logging
         if log_config is None:
@@ -205,37 +217,18 @@ def simulate(self, *, end_date):
         for module in self.modules.values():
             module.initialise_simulation(self)
 
-        progress_bar = None
         if self.show_progress_bar:
-            num_simulated_days = (end_date - self.start_date).days
-            progress_bar = ProgressBar(
+            num_simulated_days = (self.end_date - self.start_date).days
+            self.progress_bar = ProgressBar(
                 num_simulated_days, "Simulation progress", unit="day")
-            progress_bar.start()
-
-        while self.event_queue:
-            event, date = self.event_queue.next_event()
+            self.progress_bar.start()
 
-            if self.show_progress_bar:
-                simulation_day = (date - self.start_date).days
-                stats_dict = {
-                    "date": str(date.date()),
-                    "dataframe size": str(len(self.population.props)),
-                    "queued events": str(len(self.event_queue)),
-                }
-                if "HealthSystem" in self.modules:
-                    stats_dict["queued HSI events"] = str(
-                        len(self.modules["HealthSystem"].HSI_EVENT_QUEUE)
-                    )
-                progress_bar.update(simulation_day, stats_dict=stats_dict)
-
-            if date >= end_date:
-                self.date = end_date
-                break
-            self.fire_single_event(event, date)
+        # Run the simulation by firing events in the queue
+        self.step_through_events()
 
         # The simulation has ended.
         if self.show_progress_bar:
-            progress_bar.stop()
+            self.progress_bar.stop()
 
         for module in self.modules.values():
             module.on_simulation_end()
@@ -253,6 +246,17 @@ def simulate(self, *, end_date):
             finally:
                 self.output_file.release()
 
+    def step_through_events(self) -> None:
+        """
+        Advance the simulation by executing the scheduled
+        events in the queue. This must be overridden by
+        inheriting classes.
+        """
+        raise NotImplementedError(
+            f"{type(self).__name__} does not implement step_through_events; "
+            "use either Simulation or ThreadedSimulation to run a simulation."
+        )
+
     def schedule_event(self, event, date):
         """Schedule an event to happen on the given future date.
 
@@ -269,15 +273,6 @@ def schedule_event(self, event, date):
 
         self.event_queue.schedule(event=event, date=date)
 
-    def fire_single_event(self, event, date):
-        """Fires the event once for the given date
-
-        :param event: :py:class:`Event` to fire
-        :param date: the date of the event
-        """
-        self.date = date
-        event.run()
-
     def do_birth(self, mother_id):
         """Create a new child person.
 
@@ -308,6 +303,23 @@ def find_events_for_person(self, person_id: int):
 
         return person_events
 
+    def update_progress_bar(self, new_date: Date) -> None:
+        """
+        Update the simulation's progress bar, if it is in use.
+        """
+        if self.show_progress_bar:
+            simulation_day = (new_date - self.start_date).days
+            stats_dict = {
+                "date": str(new_date.date()),
+                "dataframe size": str(len(self.population.props)),
+                "queued events": str(len(self.event_queue)),
+            }
+            if "HealthSystem" in self.modules:
+                stats_dict["queued HSI events"] = str(
+                    len(self.modules["HealthSystem"].HSI_EVENT_QUEUE)
+                )
+            self.progress_bar.update(simulation_day, stats_dict=stats_dict)
+
 
 class EventQueue:
     """A simple priority queue for events.
@@ -329,7 +341,7 @@ def schedule(self, event, date):
         entry = (date, event.priority, next(self.counter), event)
         heapq.heappush(self.queue, entry)
 
-    def next_event(self):
+    def next_event(self) -> Tuple[Event, Date]:
         """Get the earliest event in the queue.
 
         :returns: an (event, date) pair
@@ -340,3 +352,35 @@ def next_event(self):
     def __len__(self):
         """:return: the length of the queue"""
         return len(self.queue)
+
+
+class Simulation(_BaseSimulation):
+    """
+    Default simulation type, which runs a serial simulation.
+    Events in the event_queue are executed one after the other,
+    in the order in which they appear in the queue.
+
+    See `_BaseSimulation` for more details.
+    """
+
+    def step_through_events(self) -> None:
+        """Serial simulation: events are executed in the
+        order they occur in the queue."""
+        while self.event_queue:
+            event, date = self.event_queue.next_event()
+
+            self.update_progress_bar(date)
+
+            if date >= self.end_date:
+                self.date = self.end_date
+                break
+            self.fire_single_event(event, date)
+
+    def fire_single_event(self, event, date):
+        """Fire the event once for the given date.
+
+        :param event: :py:class:`Event` to fire
+        :param date: the date of the event
+        """
+        self.date = date
+        event.run()
diff --git a/src/tlo/threaded_simulation.py b/src/tlo/threaded_simulation.py
new file mode 100644
index 0000000000..072e0bde83
--- /dev/null
+++ b/src/tlo/threaded_simulation.py
@@ -0,0 +1,230 @@
+from queue import Queue
+from threading import Lock, Thread
+from time import sleep
+from typing import Callable, List
+from warnings import warn
+
+from tlo.events import Event, IndividualScopeEventMixin
+from tlo.simulation import _BaseSimulation
+
+MAX_THREADS = 4  # TODO: determine this from the machine (e.g. os.cpu_count()) rather than hard-coding it
+
+class ThreadController:
+    """
+    Thread controllers serve an organisational role: they keep track of the
+    threads we create (which helps when debugging), provide convenience
+    methods to start all of the threads they control at once, and manage the
+    teardown of their own threads (workers are daemon threads, so they do
+    not outlive the main thread).
+
+    Threads spawned by the controller are intended to form a "pool" of
+    workers that routinely check a Queue object for tasks to perform, and
+    are otherwise idle. Worker targets should be functions that give the
+    thread access to the job queue for as long as the worker persists.
+    """
+    _n_threads: int
+    _thread_list: List[Thread]
+
+    _worker_name: str
+
+    @property
+    def n_threads(self) -> int:
+        """
+        Number of threads that this controller is operating.
+        """
+        return self._n_threads
+
+    def __init__(self, n_threads: int = 1, name: str = "Worker") -> None:
+        """
+        Create a new thread controller.
+
+        :param n_threads: Number of threads to be spawned, in addition to
+            the main thread.
+        :param name: Name to assign to worker threads that this controller
+            creates, for logging and internal ID purposes.
+        """
+        # Determine how many threads to use given the machine maximum and
+        # the user's request. Be sure to save one for the main thread!
+        self._n_threads = min(n_threads, MAX_THREADS - 1)
+        if self._n_threads < n_threads:
+            warn(
+                f"Requested {n_threads} worker threads, but this exceeds the maximum possible number ({MAX_THREADS - 1}). Restricting to {self._n_threads}."
+            )
+        assert (
+            self._n_threads > 0
+        ), f"Instructed to use {self._n_threads} worker threads, but at least 1 is required. Use a serial Simulation if you do not want to delegate event execution to threads."
+
+        # Prepare the list of threads, but do not initialise the threads yet
+        # since they need access to some of the Simulation properties
+        self._thread_list = []
+
+        self._worker_name = name
+
+    def create_all(self, target: Callable[[], None]) -> None:
+        """
+        Create the threads that will be managed by this controller,
+        and set their targets.
+
+        Targets are not executed until the start_all method is called.
+
+        Targets are functions that take no arguments and return no values.
+        Workers execute these functions with whatever context and access the
+        functions themselves carry: passing in something like foo.bar gives
+        the worker access to the foo object and runs its bar method, for
+        example.
+        """
+        for i in range(self._n_threads):
+            self._thread_list.append(
+                Thread(target=target, daemon=True, name=f"{self._worker_name}-{i}")
+            )
+
+    def start_all(self) -> None:
+        """
+        Start all threads managed by this controller.
+        """
+        for thread in self._thread_list:
+            thread.start()
+
+
+class ThreadedSimulation(_BaseSimulation):
+    """
+    Class for running threaded simulations. Events in the queue that can
+    be executed in parallel are delegated to a worker pool, to be executed
+    when resources become available.
+
+    Certain events cannot be executed in parallel threads safely (notably
+    population-level events, but also events that attempt to advance time).
+    When such an event is encountered, all workers complete the remaining
+    "thread-safe" events before the unsafe event is triggered.
+
+    The progress bar for a threaded simulation only advances when time
+    advances, and its statistics do not update dynamically as each event
+    is fired.
+
+    TODO: route print output through the logger
+    TODO: include the name of the worker thread that produced each message
+    """
+    # Tracks the job queue that will be dispatched to worker threads
+    _worker_queue: Queue
+    # Workers must always work on different individuals due to
+    # their ability to edit the population DataFrame.
+    _worker_patient_targets: set
+    # Provides programmatic access to the threads created for the
+    # simulation
+    thread_controller: ThreadController
+
+    # Safety-catch variable to ensure safe execution of events:
+    # a lock that makes checking and claiming an individual in
+    # _worker_patient_targets an atomic operation.
+    _target_lock: Lock
+
+    def __init__(self, n_threads: int = 1, **kwargs) -> None:
+        """
+        In addition to the usual simulation instantiation arguments,
+        threaded simulations must also be passed the number of
+        worker threads to be used.
+
+        :param n_threads: Number of threads to use, in addition to
+            the main thread, when running simulation events.
+        """
+        # Initialise as you would for any other simulation
+        super().__init__(**kwargs)
+
+        # Set up the thread controller
+        self.thread_controller = ThreadController(n_threads=n_threads, name="EventWorker")
+
+        # Set the target workflow of all workers
+        self.thread_controller.create_all(self._worker_target)
+
+        self._worker_queue = Queue()
+        # Initialise the set tracking which individuals the event workers
+        # are currently targeting, and the lock that guards it.
+        self._worker_patient_targets = set()
+        self._target_lock = Lock()
+
+    def _worker_target(self) -> None:
+        """
+        Workflow that threads will execute.
+
+        The workflow assumes that events added to the worker queue
+        are always safe to execute in any thread, alongside any
+        other events that might currently be in the queue.
+        """
+        # Workers persist for as long as the thread is alive. We would prefer
+        # not to loop forever; we could instead spawn threads only when they
+        # are needed (capping how many exist at once), but creating a thread
+        # is itself an expensive operation,
+        # and Queue.get() blocks an idle worker until a job arrives, so a
+        # persistent worker costs very little while it waits.
+        while True:
+            # Check for the next job in the queue (blocks until one arrives)
+            event_to_run: Event = self._worker_queue.get()
+            target = event_to_run.target
+            # Claim the target individual before running the event: two
+            # workers must never run events for the same individual at the
+            # same time. The lock makes the check-and-claim atomic.
+            while True:
+                with self._target_lock:
+                    if target not in self._worker_patient_targets:
+                        # Flag that this thread is running an event on this patient
+                        self._worker_patient_targets.add(target)
+                        break
+                # Another worker is running an event for this individual;
+                # stall briefly before trying to claim them again.
+                sleep(0.01)
+            try:
+                event_to_run.run()
+            finally:
+                # Release the individual and report the job as done even if
+                # the event raised, so that wait_for_workers cannot hang.
+                with self._target_lock:
+                    self._worker_patient_targets.discard(target)
+                self._worker_queue.task_done()
+
+    @staticmethod
+    def event_must_run_in_main_thread(event: Event) -> bool:
+        """
+        Return True if the event passed in must be run in the main thread, in serial.
+
+        Population-level events must always run in the main thread with no worker
+        events running in parallel, since they need to scan the state of the simulation
+        at that moment in time and workers have write access to simulation properties.
+        """
+        return not isinstance(event, IndividualScopeEventMixin)
+
+    def step_through_events(self) -> None:
+        # Start the worker threads
+        self.thread_controller.start_all()
+
+        # Whilst the event queue is not empty
+        while self.event_queue:
+            event, date_of_next_event = self.event_queue.next_event()
+            self.update_progress_bar(self.date)
+
+            # If the simulation should end, escape
+            if date_of_next_event >= self.end_date:
+                self.date = self.end_date
+                break
+            # If we want to advance time, we need to ensure that the worker
+            # queue is empty first. Otherwise, a worker might still be running
+            # an event from the previous date and call sim.date to get the
+            # "current" time, which would then be out of sync.
+            elif date_of_next_event > self.date:
+                # This event moves time forward, so wait until all jobs
+                # from the current date have finished before advancing time
+                self.wait_for_workers()
+                # All jobs from the previous date have ended.
+                # Advance time and continue.
+                self.date = date_of_next_event
+
+            # Next, determine if the event to be run can be delegated to the
+            # worker pool.
+            if self.event_must_run_in_main_thread(event):
+                # The event needs all workers to finish, and then to run in
+                # the main thread (this one)
+                self.wait_for_workers()
+                event.run()
+            else:
+                # This job can be delegated to the worker pool and run safely
+                self._worker_queue.put(event)
+
+        # We may have exhausted all the events in the queue, but the workers
+        # will still need time to process them all!
+        self.wait_for_workers()
+        self.update_progress_bar(self.date)
+
+    def wait_for_workers(self) -> None:
+        """
+        Pause simulation progression until all worker threads have finished
+        their current jobs and the job queue is empty.
+        """
+        self._worker_queue.join()
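
Usage note (not part of the patch): a minimal sketch of how the new code path can be exercised. Through the profiling script, threading is enabled with the new flag, for example

    python src/scripts/profiling/scale_run.py --n-threads 2

The class can also be driven directly, as below. This is an illustration only: the dates, seed, population size, log_config values and resource path are placeholder assumptions, and module registration simply follows the fullmodel pattern already used by scale_run.py.

    from pathlib import Path

    from tlo import Date
    from tlo.methods.fullmodel import fullmodel
    from tlo.threaded_simulation import ThreadedSimulation

    # Worker threads are in addition to the main thread, which keeps
    # firing population-level events and advancing simulation time.
    sim = ThreadedSimulation(
        n_threads=2,
        start_date=Date(2010, 1, 1),
        seed=0,
        log_config={"filename": "threaded_scale_run", "directory": "./outputs"},
        show_progress_bar=True,
    )
    sim.register(*fullmodel(resourcefilepath=Path("./resources")))
    sim.make_initial_population(n=20_000)
    sim.simulate(end_date=Date(2010, 4, 1))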