Skip to content

Commit 6991a03

Browse files
albrjarmudambi
andauthored
Albrja/mic 5649/mypy framework engine (#562)
* Type-hinting: fix errors in engine and plugins --------- Co-authored-by: Rajan Mudambi <11376379+rmudambi@users.noreply.github.com>
1 parent 56dd0a0 commit 6991a03

File tree

9 files changed

+216
-112
lines changed

9 files changed

+216
-112
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
**3.2.12 - 12/26/24**
2+
3+
- Type-hinting: Fix mypy errors in vivarium/framework/engine.py
4+
15
**3.2.11 - 12/23/24**
26

37
- Type-hinting: Fix mypy errors in vivarium/framework/components/parser.py

docs/nitpick-exceptions

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ py:class ResultsUpdater
4545
py:class _NestedDict
4646
py:exc ResultsConfigurationError
4747
py:exc vivarium.framework.results.exceptions.ResultsConfigurationError
48+
py:class vivarium.framework.plugins.M
49+
py:class vivarium.framework.plugins.I
4850

4951
# layered_config_tree
5052
py:class layered_config_tree.main.LayeredConfigTree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ exclude = [
4444
'src/vivarium/examples/disease_model/observer.py',
4545
'src/vivarium/examples/disease_model/population.py',
4646
'src/vivarium/examples/disease_model/risk.py',
47-
'src/vivarium/framework/engine.py',
4847
'src/vivarium/framework/population/manager.py',
4948
'src/vivarium/framework/population/population_view.py',
5049
'src/vivarium/interface/cli.py',
@@ -98,5 +97,6 @@ module = [
9897
"scipy.*",
9998
"ipywidgets.*",
10099
"Ipython.*",
100+
"dill",
101101
]
102102
ignore_missing_imports = true

src/vivarium/framework/engine.py

Lines changed: 68 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: ignore-errors
21
"""
32
===================
43
The Vivarium Engine
@@ -23,6 +22,7 @@
2322
from pathlib import Path
2423
from pprint import pformat
2524
from time import time
25+
from typing import Any
2626

2727
import dill
2828
import numpy as np
@@ -32,20 +32,24 @@
3232

3333
from vivarium import Component
3434
from vivarium.exceptions import VivariumError
35-
from vivarium.framework.artifact import ArtifactInterface
36-
from vivarium.framework.components import ComponentConfigError, ComponentInterface
35+
from vivarium.framework.artifact import ArtifactInterface, ArtifactManager
36+
from vivarium.framework.components import (
37+
ComponentConfigError,
38+
ComponentInterface,
39+
ComponentManager,
40+
)
3741
from vivarium.framework.configuration import build_model_specification
38-
from vivarium.framework.event import EventInterface
39-
from vivarium.framework.lifecycle import LifeCycleInterface
40-
from vivarium.framework.logging import LoggingInterface
41-
from vivarium.framework.lookup import LookupTableInterface
42+
from vivarium.framework.event import EventInterface, EventManager
43+
from vivarium.framework.lifecycle import LifeCycleInterface, LifeCycleManager
44+
from vivarium.framework.logging import LoggingInterface, LoggingManager
45+
from vivarium.framework.lookup import LookupTableInterface, LookupTableManager
4246
from vivarium.framework.plugins import PluginManager
43-
from vivarium.framework.population import PopulationInterface
44-
from vivarium.framework.randomness import RandomnessInterface
45-
from vivarium.framework.resource import ResourceInterface
46-
from vivarium.framework.results import ResultsInterface
47-
from vivarium.framework.time import TimeInterface
48-
from vivarium.framework.values import ValuesInterface
47+
from vivarium.framework.population import PopulationInterface, PopulationManager
48+
from vivarium.framework.randomness import RandomnessInterface, RandomnessManager
49+
from vivarium.framework.resource import ResourceInterface, ResourceManager
50+
from vivarium.framework.results import ResultsInterface, ResultsManager
51+
from vivarium.framework.time import SimulationClock, TimeInterface
52+
from vivarium.framework.values import ValuesInterface, ValuesManager
4953
from vivarium.types import ClockTime
5054

5155

@@ -87,7 +91,7 @@ def _get_context_name(sim_name: str | None) -> str:
8791
return sim_name
8892

8993
@staticmethod
90-
def _clear_context_cache():
94+
def _clear_context_cache() -> None:
9195
"""Clear the cache of simulation context names.
9296
9397
Notes
@@ -99,12 +103,12 @@ def _clear_context_cache():
99103
def __init__(
100104
self,
101105
model_specification: str | Path | LayeredConfigTree | None = None,
102-
components: list[Component] | dict | LayeredConfigTree | None = None,
103-
configuration: dict | LayeredConfigTree | None = None,
104-
plugin_configuration: dict | LayeredConfigTree | None = None,
106+
components: list[Component] | dict[str, Any] | LayeredConfigTree | None = None,
107+
configuration: dict[str, Any] | LayeredConfigTree | None = None,
108+
plugin_configuration: dict[str, Any] | LayeredConfigTree | None = None,
105109
sim_name: str | None = None,
106110
logging_verbosity: int = 1,
107-
):
111+
) -> None:
108112
self._name = self._get_context_name(sim_name)
109113

110114
# Bootstrap phase: Parse arguments, make private managers
@@ -122,7 +126,7 @@ def __init__(
122126

123127
self._plugin_manager = PluginManager(self.model_specification.plugins)
124128

125-
self._logging = self._plugin_manager.get_plugin("logging")
129+
self._logging = self._plugin_manager.get_plugin(LoggingManager)
126130
self._logging.configure_logging(
127131
simulation_name=self.name,
128132
verbosity=logging_verbosity,
@@ -133,7 +137,7 @@ def __init__(
133137

134138
# This formally starts the initialization phase (this call makes the
135139
# life-cycle manager).
136-
self._lifecycle = self._plugin_manager.get_plugin("lifecycle")
140+
self._lifecycle = self._plugin_manager.get_plugin(LifeCycleManager)
137141
self._lifecycle.add_phase("setup", ["setup", "post_setup", "population_creation"])
138142
self._lifecycle.add_phase(
139143
"main_loop",
@@ -142,21 +146,22 @@ def __init__(
142146
)
143147
self._lifecycle.add_phase("simulation_end", ["simulation_end", "report"])
144148

145-
self._component_manager = self._plugin_manager.get_plugin("component_manager")
149+
self._component_manager = self._plugin_manager.get_plugin(ComponentManager)
146150
self._component_manager.setup_manager(self.configuration, self._lifecycle)
147151

148-
self._clock = self._plugin_manager.get_plugin("clock")
149-
self._values = self._plugin_manager.get_plugin("value")
150-
self._events = self._plugin_manager.get_plugin("event")
151-
self._population = self._plugin_manager.get_plugin("population")
152-
self._resource = self._plugin_manager.get_plugin("resource")
153-
self._results = self._plugin_manager.get_plugin("results")
154-
self._tables = self._plugin_manager.get_plugin("lookup")
155-
self._randomness = self._plugin_manager.get_plugin("randomness")
156-
self._data = self._plugin_manager.get_plugin("data")
152+
self._clock = self._plugin_manager.get_plugin(SimulationClock)
153+
self._values = self._plugin_manager.get_plugin(ValuesManager)
154+
self._events = self._plugin_manager.get_plugin(EventManager)
155+
self._population = self._plugin_manager.get_plugin(PopulationManager)
156+
self._resource = self._plugin_manager.get_plugin(ResourceManager)
157+
self._results = self._plugin_manager.get_plugin(ResultsManager)
158+
self._tables = self._plugin_manager.get_plugin(LookupTableManager)
159+
self._randomness = self._plugin_manager.get_plugin(RandomnessManager)
160+
self._data = self._plugin_manager.get_plugin(ArtifactManager)
157161

158-
for name, controller in self._plugin_manager.get_optional_controllers().items():
159-
setattr(self, f"_{name}", controller)
162+
optional_managers = self._plugin_manager.get_optional_controllers()
163+
for name in optional_managers:
164+
setattr(self, f"_{name}", optional_managers[name])
160165

161166
# The order the managers are added is important. It represents the
162167
# order in which they will be set up. The logging manager and the clock are
@@ -179,16 +184,14 @@ def __init__(
179184
] + list(self._plugin_manager.get_optional_controllers().values())
180185
self._component_manager.add_managers(managers)
181186

182-
component_config_parser = self._plugin_manager.get_plugin(
183-
"component_configuration_parser"
184-
)
187+
component_config_parser = self._plugin_manager.get_component_config_parser()
185188
# Tack extra components onto the end of the list generated from the model specification.
186-
components = (
189+
components_list: list[Component] = (
187190
component_config_parser.get_components(self._component_configuration)
188191
+ self._additional_components
189192
)
190193

191-
non_components = [obj for obj in components if not isinstance(obj, Component)]
194+
non_components = [obj for obj in components_list if not isinstance(obj, Component)]
192195
if non_components:
193196
message = (
194197
"Attempting to create a simulation with the following components "
@@ -202,7 +205,7 @@ def __init__(
202205
self.get_population, restrict_during=["initialization", "setup", "post_setup"]
203206
)
204207

205-
self.add_components(components)
208+
self.add_components(components_list)
206209

207210
@property
208211
def name(self) -> str:
@@ -241,7 +244,7 @@ def setup(self) -> None:
241244

242245
post_setup = self._builder.event.get_emitter("post_setup")
243246
self._lifecycle.set_state("post_setup")
244-
post_setup(None)
247+
post_setup(pd.Index([]), None)
245248

246249
def initialize_simulants(self) -> None:
247250
self._lifecycle.set_state("population_creation")
@@ -262,29 +265,29 @@ def step(self) -> None:
262265
self._clock.event_time,
263266
)
264267
self._logger.debug(f"Updating: {len(pop_to_update)}")
265-
self.time_step_emitters[event](pop_to_update)
268+
self.time_step_emitters[event](pop_to_update, None)
266269
self._clock.step_forward(self.get_population().index)
267270

268271
def run(
269272
self,
270273
backup_path: Path | None = None,
271-
backup_freq: int | float = None,
274+
backup_freq: int | float | None = None,
272275
) -> None:
273-
if backup_freq:
276+
if backup_freq and backup_path:
274277
time_to_save = time() + backup_freq
275-
while self.current_time < self._clock.stop_time:
278+
while self.current_time < self._clock.stop_time: # type: ignore [operator]
276279
self.step()
277280
if time() >= time_to_save:
278281
self._logger.debug(f"Writing Simulation Backup to {backup_path}")
279282
self.write_backup(backup_path)
280283
time_to_save = time() + backup_freq
281284
else:
282-
while self.current_time < self._clock.stop_time:
285+
while self.current_time < self._clock.stop_time: # type: ignore [operator]
283286
self.step()
284287

285288
def finalize(self) -> None:
286289
self._lifecycle.set_state("simulation_end")
287-
self.end_emitter(self.get_population().index)
290+
self.end_emitter(self.get_population().index, None)
288291
unused_config_keys = self.configuration.unused_keys()
289292
if unused_config_keys:
290293
self._logger.warning(
@@ -293,17 +296,17 @@ def finalize(self) -> None:
293296

294297
def report(self, print_results: bool = True) -> None:
295298
self._lifecycle.set_state("report")
296-
self.report_emitter(self.get_population().index)
299+
self.report_emitter(self.get_population().index, None)
297300
results = self.get_results()
298301
if print_results:
299302
for measure, df in results.items():
300303
self._logger.info(f"\n{measure}:\n{pformat(df)}")
301304
performance_metrics = self.get_performance_metrics()
302-
performance_metrics = performance_metrics.to_string(
305+
performance_metrics_str: str = performance_metrics.to_string(
303306
index=False,
304307
float_format=lambda x: f"{x:.2f}",
305308
)
306-
self._logger.info("\n" + performance_metrics)
309+
self._logger.info("\n" + performance_metrics_str)
307310
self._write_results(results)
308311

309312
def _write_results(self, results: dict[str, pd.DataFrame]) -> None:
@@ -345,7 +348,7 @@ def add_components(self, component_list: list[Component]) -> None:
345348
def get_population(self, untracked: bool = True) -> pd.DataFrame:
346349
return self._population.get_population(untracked)
347350

348-
def __repr__(self):
351+
def __repr__(self) -> str:
349352
return f"SimulationContext({self.name})"
350353

351354
def get_number_of_steps_remaining(self) -> int:
@@ -365,64 +368,60 @@ class Builder:
365368
366369
"""
367370

368-
def __init__(self, configuration: LayeredConfigTree, plugin_manager):
371+
def __init__(
372+
self, configuration: LayeredConfigTree, plugin_manager: PluginManager
373+
) -> None:
369374
self.configuration = configuration
370375
"""Provides access to the :ref:`configuration<configuration_concept>`"""
371376

372-
self.logging: LoggingInterface = plugin_manager.get_plugin_interface("logging")
377+
self.logging = plugin_manager.get_plugin_interface(LoggingInterface)
373378
"""Provides access to the :ref:`logging<logging_concept>` system."""
374379

375-
self.lookup: LookupTableInterface = plugin_manager.get_plugin_interface("lookup")
380+
self.lookup = plugin_manager.get_plugin_interface(LookupTableInterface)
376381
"""Provides access to simulant-specific data via the
377382
:ref:`lookup table<lookup_concept>` abstraction."""
378383

379-
self.value: ValuesInterface = plugin_manager.get_plugin_interface("value")
384+
self.value = plugin_manager.get_plugin_interface(ValuesInterface)
380385
"""Provides access to computed simulant attribute values via the
381386
:ref:`value pipeline<values_concept>` system."""
382387

383-
self.event: EventInterface = plugin_manager.get_plugin_interface("event")
388+
self.event = plugin_manager.get_plugin_interface(EventInterface)
384389
"""Provides access to event listeners utilized in the
385390
:ref:`event<event_concept>` system."""
386391

387-
self.population: PopulationInterface = plugin_manager.get_plugin_interface(
388-
"population"
389-
)
392+
self.population = plugin_manager.get_plugin_interface(PopulationInterface)
390393
"""Provides access to simulant state table via the
391394
:ref:`population<population_concept>` system."""
392395

393-
self.resources: ResourceInterface = plugin_manager.get_plugin_interface("resource")
396+
self.resources = plugin_manager.get_plugin_interface(ResourceInterface)
394397
"""Provides access to the :ref:`resource<resource_concept>` system,
395398
which manages dependencies between components.
396399
"""
397400

398-
self.results: ResultsInterface = plugin_manager.get_plugin_interface("results")
401+
self.results = plugin_manager.get_plugin_interface(ResultsInterface)
399402
"""Provides access to the :ref:`results<results_concept>` system."""
400403

401-
self.randomness: RandomnessInterface = plugin_manager.get_plugin_interface(
402-
"randomness"
403-
)
404+
self.randomness = plugin_manager.get_plugin_interface(RandomnessInterface)
404405
"""Provides access to the :ref:`randomness<crn_concept>` system."""
405406

406-
self.time: TimeInterface = plugin_manager.get_plugin_interface("clock")
407+
self.time: TimeInterface = plugin_manager.get_plugin_interface(TimeInterface)
407408
"""Provides access to the simulation's :ref:`clock<time_concept>`."""
408409

409-
self.components: ComponentInterface = plugin_manager.get_plugin_interface(
410-
"component_manager"
411-
)
410+
self.components = plugin_manager.get_plugin_interface(ComponentInterface)
412411
"""Provides access to the :ref:`component management<components_concept>`
413412
system, which maintains a reference to all managers and components in
414413
the simulation."""
415414

416-
self.lifecycle: LifeCycleInterface = plugin_manager.get_plugin_interface("lifecycle")
415+
self.lifecycle = plugin_manager.get_plugin_interface(LifeCycleInterface)
417416
"""Provides access to the :ref:`life-cycle<lifecycle_concept>` system,
418417
which manages the simulation's execution life-cycle."""
419418

420-
self.data = plugin_manager.get_plugin_interface("data") # type: ArtifactInterface
419+
self.data = plugin_manager.get_plugin_interface(ArtifactInterface)
421420
"""Provides access to the simulation's input data housed in the
422421
:ref:`data artifact<data_concept>`."""
423422

424423
for name, interface in plugin_manager.get_optional_interfaces().items():
425424
setattr(self, name, interface)
426425

427-
def __repr__(self):
426+
def __repr__(self) -> str:
428427
return "Builder()"

0 commit comments

Comments
 (0)