From aa009cfeaffb708568506c0bc7f58158319f0f2b Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 25 Apr 2024 15:39:45 +0100 Subject: [PATCH 1/8] Barebones threaded simulation class and controllers --- src/tlo/simulation.py | 10 +- src/tlo/threaded_simulation.py | 209 +++++++++++++++++++++++++++++++++ 2 files changed, 217 insertions(+), 2 deletions(-) create mode 100644 src/tlo/threaded_simulation.py diff --git a/src/tlo/simulation.py b/src/tlo/simulation.py index 219b1b8a6f..a1149e5614 100644 --- a/src/tlo/simulation.py +++ b/src/tlo/simulation.py @@ -1,4 +1,5 @@ """The main simulation controller.""" +from __future__ import annotations import datetime import heapq @@ -6,7 +7,7 @@ import time from collections import OrderedDict from pathlib import Path -from typing import Dict, Optional, Union +from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union import numpy as np @@ -15,6 +16,9 @@ from tlo.events import Event, IndividualScopeEventMixin from tlo.progressbar import ProgressBar +if TYPE_CHECKING: + from tlo import Module + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -43,6 +47,8 @@ class Simulation: with independent state. """ + modules: OrderedDict[str, Module] + def __init__(self, *, start_date: Date, seed: int = None, log_config: dict = None, show_progress_bar=False): """Create a new simulation. @@ -329,7 +335,7 @@ def schedule(self, event, date): entry = (date, event.priority, next(self.counter), event) heapq.heappush(self.queue, entry) - def next_event(self): + def next_event(self) -> Tuple[Event, Date]: """Get the earliest event in the queue. :returns: an (event, date) pair diff --git a/src/tlo/threaded_simulation.py b/src/tlo/threaded_simulation.py new file mode 100644 index 0000000000..c785fafe76 --- /dev/null +++ b/src/tlo/threaded_simulation.py @@ -0,0 +1,209 @@ +from threading import Thread +from time import sleep +from typing import Callable, List +from queue import Queue +from warnings import warn + +from tlo import Simulation +from tlo.events import Event, IndividualScopeEventMixin + +MAX_THREADS = 2 # make more elegant, probably examine the OS + +class ThreadController: + """ + """ + _n_threads: int + _thread_list: List[Thread] + + _worker_name: str + + @property + def n_threads(self) -> int: + """ + Number of threads that this controller is operating. + """ + return self._n_threads + + def __init__(self, n_threads: int = 1, name: str = "Worker") -> None: + """ + """ + # Determine how many threads to use given the machine maximum, + # and the user's request + self._n_threads = min(n_threads, MAX_THREADS) + if self._n_threads < n_threads: + warn( + f"Requested {n_threads} but this exceeds the maximum possible number of threads ({MAX_THREADS}). Restricting to {self._n_threads}." + ) + assert ( + self._n_threads > 0 + ), f"Instructed to use {self._n_threads} threads, which must be non-negative. Use a serial simulation if you do not want to delegate event execution to threads." + + # Prepare the list of threads, but do not initialise threads yet + # since they need access to some of the Simulation properties + self._thread_list = [] + + self._worker_name = name + + def create_all(self, target: Callable[[], None]) -> None: + """ + Creates the threads that will be managed by this controller, + and sets their targets. + """ + for i in range(self._n_threads): + self._thread_list.append( + Thread(target=target, daemon=True, name=f"{self._worker_name}-{i}") + ) + + def start_all(self) -> None: + """ + Start all threads managed by this controller. + """ + for thread in self._thread_list: + thread.start() + + +class ThreadedSimulation(Simulation): + """ + Class for running threaded simulations. WIP. + + - No logging support + - No progress bar + + NOTE: Re-implementing the simulate() method is the easiest way to overwrite for the purposes of PoC. Later, we can refactor a base simulation class and from there have two separate simulate methods in alternative subclasses. + """ + # Tracks the job queue that will be dispatched to worker threads + _worker_queue: Queue + # Workers must always work on different individuals due to + # their ability to edit the population DataFrame. + _worker_patient_targets: set + # Provides programmatic access to the threads created for the + # simulation + thread_controller: ThreadController + + # Safety-catch variables to ensure safe execution of events. + _individuals_currently_being_examined: set + + def __init__(self, n_threads: int = 1, **kwargs) -> None: + # Initialise as you would for any other simulation + super().__init__(**kwargs) + + # Progress bar currently not supported + self.show_progress_bar = False + + # Setup the thread controller + self.thread_controller = ThreadController(n_threads=n_threads, name = "EventWorker-") + + # Set the target workflow of all workers + self.thread_controller.create_all(self._worker_target) + + self._worker_queue = Queue() + # Initialise the set tracking which individuals the event workers + # are currently targeting. + self._worker_patient_targets = set() + + def _worker_target(self) -> None: + """ + Workflow that threads will execute. + + The workflow assumes that events added to the worker queue + are always safe to execute in any thread, alongside any + other events that might currently be in the queue. + """ + # While thread/worker is alive + # WOULD LIKE TO NOT HAVE THIS. We could spawn threads only when they're needed + # and then limit the number we have spawned at once, but creating a thread is also an expensive operation. + # Plus, the .get() method puts the thread to sleep until it gets something, so this should be fine. + while True: + # Check for the next job in the queue + event_to_run: Event = self._worker_queue.get() + target = event_to_run.target + # Wait for other events targeting the same individual to complete + while target in self._worker_patient_targets: + # Stall if another thread is currently executing an event + # which targets the same individual. + # Add some sleep time here to avoid near-misses + sleep(0.01) + # Flag that this thread is running an event on this patient + self._worker_patient_targets.add(target) + event_to_run.run() + self._worker_patient_targets.remove(target) + # Report success and await next task + self._worker_queue.task_done() + + @staticmethod + def event_must_run_in_main_thread(event: Event) -> bool: + """ + Return True if the event passed in must be run in the main thread, in serial. + + Population-level events must always run in the main thread with no worker + events running in parallel, since they need to scan the state of the simulation + at that moment in time and workers have write access to simulation properties. + """ + if not isinstance(event, IndividualScopeEventMixin): + return True + return False + + def simulate(self, *, end_date): + """Simulation until the given end date + + :param end_date: when to stop simulating. Only events strictly before this + date will be allowed to occur. + Must be given as a keyword parameter for clarity. + """ + self.end_date = end_date + + for module in self.modules.values(): + module.initialise_simulation(self) + + # Start the threads + self.thread_controller.start_all() + + # Whilst the event queue is not empty + while self.event_queue: + event, date = self.event_queue.next_event() + + # If the simulation should end, escape + if date >= self.end_date: + break + # If we want to advance time, we need to ensure that + # the worker queue. Otherwise, a worker might be running an + # event from the previous date but may still call sim.date + # to get the "current" time, which would then be out-of-sync. + elif date != self.date: + # This event moves time forward, wait until all jobs + # from the current day have finished before advancing time + self.wait_for_workers() + # All jobs from the previous day have ended. + # Advance time and continue. + self.date = date + + # Next, determine if the event to be run can be delegated to the + # worker pool. + if self.event_must_run_in_main_thread(event): + # Event needs all workers to finish, then to run in + # the main thread (this one) + self.wait_for_workers() + event.run() + else: + # This job can be delegated to the worker pool, and run safely + self._worker_queue.put(event) + + self.wait_for_workers() + # The simulation has ended. + + for module in self.modules.values(): + module.on_simulation_end() + + # From Python logging.shutdown + if self.output_file: + try: + self.output_file.acquire() + self.output_file.flush() + self.output_file.close() + except (OSError, ValueError): + pass + finally: + self.output_file.release() + + def wait_for_workers(self) -> None: + self._worker_queue.join() From 93658f1e27a1a60b9eb683e0536b676db61b9d0e Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 25 Apr 2024 15:43:51 +0100 Subject: [PATCH 2/8] Add some print msgs to the main thread for some info --- src/tlo/threaded_simulation.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/tlo/threaded_simulation.py b/src/tlo/threaded_simulation.py index c785fafe76..bacb63cd22 100644 --- a/src/tlo/threaded_simulation.py +++ b/src/tlo/threaded_simulation.py @@ -7,7 +7,7 @@ from tlo import Simulation from tlo.events import Event, IndividualScopeEventMixin -MAX_THREADS = 2 # make more elegant, probably examine the OS +MAX_THREADS = 4 # make more elegant, probably examine the OS class ThreadController: """ @@ -28,11 +28,11 @@ def __init__(self, n_threads: int = 1, name: str = "Worker") -> None: """ """ # Determine how many threads to use given the machine maximum, - # and the user's request - self._n_threads = min(n_threads, MAX_THREADS) + # and the user's request. Be sure to save one for the main thread! + self._n_threads = min(n_threads, MAX_THREADS - 1) if self._n_threads < n_threads: warn( - f"Requested {n_threads} but this exceeds the maximum possible number of threads ({MAX_THREADS}). Restricting to {self._n_threads}." + f"Requested {n_threads} but this exceeds the maximum possible number of worker threads ({MAX_THREADS - 1}). Restricting to {self._n_threads}." ) assert ( self._n_threads > 0 @@ -170,26 +170,33 @@ def simulate(self, *, end_date): # event from the previous date but may still call sim.date # to get the "current" time, which would then be out-of-sync. elif date != self.date: + print("MAIN THREAD: Waiting to advance time...", flush=True, end="") # This event moves time forward, wait until all jobs # from the current day have finished before advancing time self.wait_for_workers() # All jobs from the previous day have ended. # Advance time and continue. self.date = date + print("done") # Next, determine if the event to be run can be delegated to the # worker pool. if self.event_must_run_in_main_thread(event): + print("MAIN THREAD: Waiting to run population level event...") # Event needs all workers to finish, then to run in # the main thread (this one) self.wait_for_workers() + print("running", flush=True, end="...") event.run() + print("done") else: # This job can be delegated to the worker pool, and run safely self._worker_queue.put(event) + # We may have exhausted all the events in the queue, but the workers will + # still need time to process them all! self.wait_for_workers() - # The simulation has ended. + print("MAIN THREAD: Simulation has now ended, worker queue empty.") for module in self.modules.values(): module.on_simulation_end() From 0064baff4128b310a6377bc8f83010c3b70190d3 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Tue, 30 Apr 2024 09:59:35 +0100 Subject: [PATCH 3/8] Write threaded simulation with very basic functionality --- src/tlo/threaded_simulation.py | 48 ++++++++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/src/tlo/threaded_simulation.py b/src/tlo/threaded_simulation.py index bacb63cd22..cf8f0be278 100644 --- a/src/tlo/threaded_simulation.py +++ b/src/tlo/threaded_simulation.py @@ -11,6 +11,17 @@ class ThreadController: """ + Thread controllers serve an organisational role, and allow us to + keep track of threads that we create for debugging purposes. + They also provide convenient wrapper functions to batch start the + threads they control all at once, and will manage the teardown of + their own threads when this is ready. + + Threads spawned by the controller are intended to form a "pool" of + workers that will routinely check a Queue object for tasks to + perform, and otherwise will be idle. Worker targets should be + functions that allow the thread access to the job queue, whilst + they persist. """ _n_threads: int _thread_list: List[Thread] @@ -26,6 +37,12 @@ def n_threads(self) -> int: def __init__(self, n_threads: int = 1, name: str = "Worker") -> None: """ + Create a new thread controller. + + :param n_threads: Number of threads to be spawned, in addition to + the main thread. + :param name: Name to assign to worker threads that this controller + creates, for logging and internal ID purposes. """ # Determine how many threads to use given the machine maximum, # and the user's request. Be sure to save one for the main thread! @@ -48,6 +65,15 @@ def create_all(self, target: Callable[[], None]) -> None: """ Creates the threads that will be managed by this controller, and sets their targets. + + Targets are not executed until the start_all method is called. + + Targets are functions that take no arguments and return + no values. Workers will execute these functions - preserving + context and access of the functions that are passed in. + Passing in something like foo.bar will provide access to the + foo object and attempt to run the bar method on said object, + for example. """ for i in range(self._n_threads): self._thread_list.append( @@ -84,6 +110,14 @@ class ThreadedSimulation(Simulation): _individuals_currently_being_examined: set def __init__(self, n_threads: int = 1, **kwargs) -> None: + """ + In addition to the usual simulation instantiation arguments, + threaded simulations must also be passed the number of + worker threads to be used. + + :param n_threads: Number of threads to use - in addition to + the main thread - when running simulation events. + """ # Initialise as you would for any other simulation super().__init__(**kwargs) @@ -121,7 +155,7 @@ def _worker_target(self) -> None: while target in self._worker_patient_targets: # Stall if another thread is currently executing an event # which targets the same individual. - # Add some sleep time here to avoid near-misses + # Add some sleep time here to avoid near-misses. sleep(0.01) # Flag that this thread is running an event on this patient self._worker_patient_targets.add(target) @@ -144,11 +178,11 @@ def event_must_run_in_main_thread(event: Event) -> bool: return False def simulate(self, *, end_date): - """Simulation until the given end date + """Simulate until the given end date, utilising threads to process events + that can be run simultaneously. - :param end_date: when to stop simulating. Only events strictly before this - date will be allowed to occur. - Must be given as a keyword parameter for clarity. + :param end_date: when to stop simulating. Only events strictly before this + date will be allowed to occur. Must be given as a keyword parameter for clarity. """ self.end_date = end_date @@ -213,4 +247,8 @@ def simulate(self, *, end_date): self.output_file.release() def wait_for_workers(self) -> None: + """ + Pauses simulation progression until all worker threads + are ready and waiting to receive a new job. + """ self._worker_queue.join() From fa5690bc04f23d944f97335a365a82abf6c01389 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Tue, 30 Apr 2024 10:39:24 +0100 Subject: [PATCH 4/8] Allow scale_run to run a threaded simulation --- src/scripts/profiling/scale_run.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/scripts/profiling/scale_run.py b/src/scripts/profiling/scale_run.py index 735d1e7ba3..680386b701 100644 --- a/src/scripts/profiling/scale_run.py +++ b/src/scripts/profiling/scale_run.py @@ -13,6 +13,7 @@ from shared import print_checksum, schedule_profile_log from tlo import Date, Simulation, logging +from tlo.threaded_simulation import ThreadedSimulation from tlo.analysis.utils import parse_log_file as parse_log_file_fn from tlo.methods.fullmodel import fullmodel @@ -55,6 +56,7 @@ def scale_run( ignore_warnings: bool = False, log_final_population_checksum: bool = True, profiler: Optional["Profiler"] = None, + n_threads: Optional[int] = 0, ) -> Simulation: if ignore_warnings: warnings.filterwarnings("ignore") @@ -74,12 +76,16 @@ def scale_run( "suppress_stdout": disable_log_output_to_stdout, } - sim = Simulation( - start_date=start_date, - seed=seed, - log_config=log_config, - show_progress_bar=show_progress_bar, - ) + sim_args = { + "start_date": start_date, + "seed": seed, + "log_config": log_config, + "show_progress_bar": show_progress_bar, + } + if n_threads: + sim = ThreadedSimulation(n_threads=n_threads, **sim_args) + else: + sim = Simulation(**sim_args) # Register the appropriate modules with the arguments passed through sim.register( @@ -269,6 +275,12 @@ def scale_run( ), action="store_true", ) + parser.add_argument( + "--n-threads", + help="Run a threaded simulation using the given number of threaded workers", + type=int, + default=0, + ) args = parser.parse_args() args_dict = vars(args) From 459b437f865b84408c95359a01caa1ea8afc84c5 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Tue, 30 Apr 2024 11:38:39 +0100 Subject: [PATCH 5/8] Create _BaseSimulation to avoid code repetition --- src/tlo/simulation.py | 107 ++++++++++++++++++++++----------- src/tlo/threaded_simulation.py | 52 ++++++---------- 2 files changed, 89 insertions(+), 70 deletions(-) diff --git a/src/tlo/simulation.py b/src/tlo/simulation.py index a1149e5614..bf5348d0fd 100644 --- a/src/tlo/simulation.py +++ b/src/tlo/simulation.py @@ -23,7 +23,7 @@ logger.setLevel(logging.INFO) -class Simulation: +class _BaseSimulation: """The main control centre for a simulation. This class contains the core simulation logic and event queue, and holds @@ -45,6 +45,10 @@ class Simulation: The simulation-level random number generator. Note that individual modules also have their own random number generator with independent state. + + The `step_through_events` method is implemented by the `Simulation` and + `ThreadedSimulation` classes, which controls how the simulation events are + fired. """ modules: OrderedDict[str, Module] @@ -69,6 +73,7 @@ def __init__(self, *, start_date: Date, seed: int = None, log_config: dict = Non self.population: Optional[Population] = None self.show_progress_bar = show_progress_bar + self.progress_bar = None # logging if log_config is None: @@ -211,37 +216,18 @@ def simulate(self, *, end_date): for module in self.modules.values(): module.initialise_simulation(self) - progress_bar = None if self.show_progress_bar: - num_simulated_days = (end_date - self.start_date).days - progress_bar = ProgressBar( + num_simulated_days = (self.end_date - self.start_date).days + self.progress_bar = ProgressBar( num_simulated_days, "Simulation progress", unit="day") - progress_bar.start() + self.progress_bar.start() - while self.event_queue: - event, date = self.event_queue.next_event() - - if self.show_progress_bar: - simulation_day = (date - self.start_date).days - stats_dict = { - "date": str(date.date()), - "dataframe size": str(len(self.population.props)), - "queued events": str(len(self.event_queue)), - } - if "HealthSystem" in self.modules: - stats_dict["queued HSI events"] = str( - len(self.modules["HealthSystem"].HSI_EVENT_QUEUE) - ) - progress_bar.update(simulation_day, stats_dict=stats_dict) - - if date >= end_date: - self.date = end_date - break - self.fire_single_event(event, date) + # Run the simulation by firing events in the queue + self.step_through_events() # The simulation has ended. if self.show_progress_bar: - progress_bar.stop() + self.progress_bar.stop() for module in self.modules.values(): module.on_simulation_end() @@ -259,6 +245,17 @@ def simulate(self, *, end_date): finally: self.output_file.release() + def step_through_events(self) -> None: + """ + Method for forward-propagating the simulation, by executing + the scheduled events in the queue. This is overwritten by + inheriting classes. + """ + raise NotImplementedError( + f"{self.__name__} is not intended to be simulated, " + "use either Simulation or ThreadedSimulation to run a simulation." + ) + def schedule_event(self, event, date): """Schedule an event to happen on the given future date. @@ -275,15 +272,6 @@ def schedule_event(self, event, date): self.event_queue.schedule(event=event, date=date) - def fire_single_event(self, event, date): - """Fires the event once for the given date - - :param event: :py:class:`Event` to fire - :param date: the date of the event - """ - self.date = date - event.run() - def do_birth(self, mother_id): """Create a new child person. @@ -314,6 +302,23 @@ def find_events_for_person(self, person_id: int): return person_events + def update_progress_bar(self, new_date: Date): + """ + Updates the simulation's progress bar, if this is in use. + """ + if self.show_progress_bar: + simulation_day = (new_date - self.start_date).days + stats_dict = { + "date": str(new_date.date()), + "dataframe size": str(len(self.population.props)), + "queued events": str(len(self.event_queue)), + } + if "HealthSystem" in self.modules: + stats_dict["queued HSI events"] = str( + len(self.modules["HealthSystem"].HSI_EVENT_QUEUE) + ) + self.progress_bar.update(simulation_day, stats_dict=stats_dict) + class EventQueue: """A simple priority queue for events. @@ -346,3 +351,35 @@ def next_event(self) -> Tuple[Event, Date]: def __len__(self): """:return: the length of the queue""" return len(self.queue) + + +class Simulation(_BaseSimulation): + """ + Default simulation type, which runs a serial simulation. + Events in the event_queue are executed in sequence, one + after the other, in the order they appear in the queue. + + See `_BaseSimulation` for more details. + """ + + def step_through_events(self) -> None: + """Serial simulation: events are executed in the + order they occur in the queue.""" + while self.event_queue: + event, date = self.event_queue.next_event() + + self.update_progress_bar(date) + + if date >= self.end_date: + self.date = self.end_date + break + self.fire_single_event(event, date) + + def fire_single_event(self, event, date): + """Fires the event once for the given date + + :param event: :py:class:`Event` to fire + :param date: the date of the event + """ + self.date = date + event.run() diff --git a/src/tlo/threaded_simulation.py b/src/tlo/threaded_simulation.py index cf8f0be278..dbc6370681 100644 --- a/src/tlo/threaded_simulation.py +++ b/src/tlo/threaded_simulation.py @@ -4,7 +4,7 @@ from queue import Queue from warnings import warn -from tlo import Simulation +from tlo.simulation import _BaseSimulation from tlo.events import Event, IndividualScopeEventMixin MAX_THREADS = 4 # make more elegant, probably examine the OS @@ -88,14 +88,22 @@ def start_all(self) -> None: thread.start() -class ThreadedSimulation(Simulation): +class ThreadedSimulation(_BaseSimulation): """ - Class for running threaded simulations. WIP. + Class for running threaded simulations. Events in the queue that can + be executed in parallel are delegated to a worker pool, to be executed + when resources become available. - - No logging support - - No progress bar + Certain events cannot be executed in parallel threads safely (notably + population-level events, but also events that attempt to advance time). + When encountering such events, all workers complete the remaining + "thread-safe" events before the unsafe event is triggered. - NOTE: Re-implementing the simulate() method is the easiest way to overwrite for the purposes of PoC. Later, we can refactor a base simulation class and from there have two separate simulate methods in alternative subclasses. + Progress bar for threaded simulations only advances when time advances, + and statistics do not dynamically update as each event is fired. + + TODO: Prints to actually using the logger + TODO: Prints to include the worker thread they were spit out from """ # Tracks the job queue that will be dispatched to worker threads _worker_queue: Queue @@ -177,18 +185,7 @@ def event_must_run_in_main_thread(event: Event) -> bool: return True return False - def simulate(self, *, end_date): - """Simulate until the given end date, utilising threads to process events - that can be run simultaneously. - - :param end_date: when to stop simulating. Only events strictly before this - date will be allowed to occur. Must be given as a keyword parameter for clarity. - """ - self.end_date = end_date - - for module in self.modules.values(): - module.initialise_simulation(self) - + def step_through_events(self) -> None: # Start the threads self.thread_controller.start_all() @@ -204,15 +201,14 @@ def simulate(self, *, end_date): # event from the previous date but may still call sim.date # to get the "current" time, which would then be out-of-sync. elif date != self.date: - print("MAIN THREAD: Waiting to advance time...", flush=True, end="") # This event moves time forward, wait until all jobs # from the current day have finished before advancing time self.wait_for_workers() # All jobs from the previous day have ended. # Advance time and continue. self.date = date - print("done") - + self.update_progress_bar(self.date) + # Next, determine if the event to be run can be delegated to the # worker pool. if self.event_must_run_in_main_thread(event): @@ -232,20 +228,6 @@ def simulate(self, *, end_date): self.wait_for_workers() print("MAIN THREAD: Simulation has now ended, worker queue empty.") - for module in self.modules.values(): - module.on_simulation_end() - - # From Python logging.shutdown - if self.output_file: - try: - self.output_file.acquire() - self.output_file.flush() - self.output_file.close() - except (OSError, ValueError): - pass - finally: - self.output_file.release() - def wait_for_workers(self) -> None: """ Pauses simulation progression until all worker threads From d06a9beb4f562025d3b64f78b84898b85219d6fb Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Tue, 30 Apr 2024 12:03:52 +0100 Subject: [PATCH 6/8] Attempt to make progress bar work... but it doesn't --- src/tlo/threaded_simulation.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/tlo/threaded_simulation.py b/src/tlo/threaded_simulation.py index dbc6370681..f9a5fdc16b 100644 --- a/src/tlo/threaded_simulation.py +++ b/src/tlo/threaded_simulation.py @@ -200,9 +200,9 @@ def step_through_events(self) -> None: # the worker queue. Otherwise, a worker might be running an # event from the previous date but may still call sim.date # to get the "current" time, which would then be out-of-sync. - elif date != self.date: + elif date > self.date: # This event moves time forward, wait until all jobs - # from the current day have finished before advancing time + # from the current date have finished before advancing time self.wait_for_workers() # All jobs from the previous day have ended. # Advance time and continue. @@ -212,13 +212,10 @@ def step_through_events(self) -> None: # Next, determine if the event to be run can be delegated to the # worker pool. if self.event_must_run_in_main_thread(event): - print("MAIN THREAD: Waiting to run population level event...") # Event needs all workers to finish, then to run in # the main thread (this one) self.wait_for_workers() - print("running", flush=True, end="...") event.run() - print("done") else: # This job can be delegated to the worker pool, and run safely self._worker_queue.put(event) @@ -226,7 +223,6 @@ def step_through_events(self) -> None: # We may have exhausted all the events in the queue, but the workers will # still need time to process them all! self.wait_for_workers() - print("MAIN THREAD: Simulation has now ended, worker queue empty.") def wait_for_workers(self) -> None: """ From f63536ccc6168c2c7af8183c882440338bc20238 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Tue, 30 Apr 2024 12:20:17 +0100 Subject: [PATCH 7/8] Progress bar fixed for threaded simulations --- src/tlo/threaded_simulation.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/tlo/threaded_simulation.py b/src/tlo/threaded_simulation.py index f9a5fdc16b..5e54affaec 100644 --- a/src/tlo/threaded_simulation.py +++ b/src/tlo/threaded_simulation.py @@ -129,9 +129,6 @@ def __init__(self, n_threads: int = 1, **kwargs) -> None: # Initialise as you would for any other simulation super().__init__(**kwargs) - # Progress bar currently not supported - self.show_progress_bar = False - # Setup the thread controller self.thread_controller = ThreadController(n_threads=n_threads, name = "EventWorker-") @@ -191,23 +188,23 @@ def step_through_events(self) -> None: # Whilst the event queue is not empty while self.event_queue: - event, date = self.event_queue.next_event() + event, date_of_next_event = self.event_queue.next_event() + self.update_progress_bar(self.date) # If the simulation should end, escape - if date >= self.end_date: + if date_of_next_event >= self.end_date: break # If we want to advance time, we need to ensure that # the worker queue. Otherwise, a worker might be running an # event from the previous date but may still call sim.date # to get the "current" time, which would then be out-of-sync. - elif date > self.date: + elif date_of_next_event > self.date: # This event moves time forward, wait until all jobs # from the current date have finished before advancing time self.wait_for_workers() # All jobs from the previous day have ended. # Advance time and continue. - self.date = date - self.update_progress_bar(self.date) + self.date = date_of_next_event # Next, determine if the event to be run can be delegated to the # worker pool. @@ -223,6 +220,7 @@ def step_through_events(self) -> None: # We may have exhausted all the events in the queue, but the workers will # still need time to process them all! self.wait_for_workers() + self.update_progress_bar(date_of_next_event) def wait_for_workers(self) -> None: """ From 6c7050ccc06385d0eb2fe404012a9aace984b194 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Tue, 30 Apr 2024 16:26:35 +0100 Subject: [PATCH 8/8] Lint pass --- src/scripts/profiling/scale_run.py | 2 +- src/tlo/simulation.py | 1 + src/tlo/threaded_simulation.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/scripts/profiling/scale_run.py b/src/scripts/profiling/scale_run.py index 680386b701..c0409a88ce 100644 --- a/src/scripts/profiling/scale_run.py +++ b/src/scripts/profiling/scale_run.py @@ -13,9 +13,9 @@ from shared import print_checksum, schedule_profile_log from tlo import Date, Simulation, logging -from tlo.threaded_simulation import ThreadedSimulation from tlo.analysis.utils import parse_log_file as parse_log_file_fn from tlo.methods.fullmodel import fullmodel +from tlo.threaded_simulation import ThreadedSimulation _TLO_ROOT: Path = Path(__file__).parents[3].resolve() _TLO_OUTPUT_DIR: Path = (_TLO_ROOT / "outputs").resolve() diff --git a/src/tlo/simulation.py b/src/tlo/simulation.py index bf5348d0fd..d5b863352b 100644 --- a/src/tlo/simulation.py +++ b/src/tlo/simulation.py @@ -51,6 +51,7 @@ class _BaseSimulation: fired. """ + __name__: str = "_BaseSimulation" modules: OrderedDict[str, Module] def __init__(self, *, start_date: Date, seed: int = None, log_config: dict = None, diff --git a/src/tlo/threaded_simulation.py b/src/tlo/threaded_simulation.py index 5e54affaec..072e0bde83 100644 --- a/src/tlo/threaded_simulation.py +++ b/src/tlo/threaded_simulation.py @@ -1,11 +1,11 @@ +from queue import Queue from threading import Thread from time import sleep from typing import Callable, List -from queue import Queue from warnings import warn -from tlo.simulation import _BaseSimulation from tlo.events import Event, IndividualScopeEventMixin +from tlo.simulation import _BaseSimulation MAX_THREADS = 4 # make more elegant, probably examine the OS