Skip to content

Commit f6df1f1

Browse files
type hint tests/framework/test_engine.py (#616)
1 parent d5b628a commit f6df1f1

File tree

3 files changed

+111
-49
lines changed

3 files changed

+111
-49
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
**3.3.19 - 04/29/25**
22

3-
- Type-hinting: Fix mypy errors in test/framework/test_values.py
3+
- Type-hinting: Fix mypy errors in tests/framework/test_values.py
4+
- Type-hinting: Fix mypy errors in tests/framework/test_engine.py
45

56
**3.3.18 - 04/24/25**
67

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ exclude = [
4040
'src/vivarium/examples/disease_model/risk.py',
4141
'src/vivarium/testing_utilities.py',
4242
'tests/framework/results/test_context.py',
43-
'tests/framework/test_engine.py',
4443
'tests/framework/test_event.py',
4544
'tests/interface/test_utilities.py',
4645
]

tests/framework/test_engine.py

Lines changed: 109 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
import math
2+
from collections.abc import Callable, Generator
23
from itertools import product
34
from pathlib import Path
45
from time import time
6+
from types import MethodType
7+
from typing import Any, cast
58

69
import dill
710
import pandas as pd
811
import pytest
9-
import pytest_mock
12+
from _pytest.logging import LogCaptureFixture
13+
from layered_config_tree import LayeredConfigTree
14+
from pytest_mock import MockerFixture
1015

1116
from tests.framework.results.helpers import (
1217
FAMILIARS,
@@ -42,40 +47,43 @@
4247
from vivarium.framework.values import ValuesInterface, ValuesManager
4348

4449

45-
def is_same_object_method(m1, m2):
46-
return m1.__func__ is m2.__func__ and m1.__self__ is m2.__self__
50+
def is_same_object_method(
51+
m1: MethodType | Callable[..., Any], m2: Callable[..., Any]
52+
) -> bool:
53+
method1: MethodType = cast(MethodType, m1)
54+
method2: MethodType = cast(MethodType, m2)
55+
return method1.__func__ is method2.__func__ and method1.__self__ is method2.__self__
4756

4857

4958
@pytest.fixture()
50-
def SimulationContext():
59+
def SimulationContext() -> Generator[type[SimulationContext_], None, None]:
5160
yield SimulationContext_
5261
SimulationContext_._clear_context_cache()
5362

5463

5564
@pytest.fixture
56-
def components():
65+
def components() -> list[Component]:
5766
return [
5867
MockComponentA("gretchen", "whimsy"),
5968
Listener("listener"),
6069
MockComponentB("spoon", "antelope", "23"),
6170
]
6271

6372

64-
@pytest.fixture
65-
def log(mocker):
66-
return mocker.patch("vivarium.framework.logging.manager.loguru.logger")
67-
68-
69-
def test_simulation_with_non_components(SimulationContext, components: list[Component]):
73+
def test_simulation_with_non_components(
74+
SimulationContext: type[SimulationContext_], components: list[Component]
75+
) -> None:
7076
class NonComponent:
71-
def __init__(self):
77+
def __init__(self) -> None:
7278
self.name = "non_component"
7379

74-
with pytest.raises(ComponentConfigError):
75-
SimulationContext(components=components + [NonComponent()])
80+
with pytest.raises(
81+
ComponentConfigError, match="that do not inherit from `vivarium.Component`"
82+
):
83+
SimulationContext(components=components + [NonComponent()]) # type: ignore[list-item]
7684

7785

78-
def test_SimulationContext_get_sim_name(SimulationContext):
86+
def test_SimulationContext_get_sim_name(SimulationContext: type[SimulationContext_]) -> None:
7987
assert SimulationContext._created_simulation_contexts == set()
8088

8189
assert SimulationContext._get_context_name(None) == "simulation_1"
@@ -84,7 +92,9 @@ def test_SimulationContext_get_sim_name(SimulationContext):
8492
assert SimulationContext._created_simulation_contexts == {"simulation_1", "foo"}
8593

8694

87-
def test_SimulationContext_init_default(SimulationContext, components):
95+
def test_SimulationContext_init_default(
96+
SimulationContext: type[SimulationContext_], components: list[Component]
97+
) -> None:
8898
sim = SimulationContext(components=components)
8999

90100
assert isinstance(sim._logging, LoggingManager)
@@ -151,7 +161,9 @@ def test_SimulationContext_init_default(SimulationContext, components):
151161
assert list(sim._component_manager._components) == unpacked_components
152162

153163

154-
def test_SimulationContext_name_management(SimulationContext):
164+
def test_SimulationContext_name_management(
165+
SimulationContext: type[SimulationContext_],
166+
) -> None:
155167
assert SimulationContext._created_simulation_contexts == set()
156168

157169
sim1 = SimulationContext()
@@ -171,7 +183,9 @@ def test_SimulationContext_name_management(SimulationContext):
171183
}
172184

173185

174-
def test_SimulationContext_run_simulation(SimulationContext, mocker):
186+
def test_SimulationContext_run_simulation(
187+
SimulationContext: type[SimulationContext_], mocker: MockerFixture
188+
) -> None:
175189
sim = SimulationContext()
176190

177191
expected_calls = [
@@ -197,9 +211,13 @@ def test_SimulationContext_run_simulation(SimulationContext, mocker):
197211
assert actual_calls == expected_calls
198212

199213

200-
def test_SimulationContext_setup_default(SimulationContext, base_config, components):
214+
def test_SimulationContext_setup_default(
215+
SimulationContext: type[SimulationContext_],
216+
base_config: LayeredConfigTree,
217+
components: list[Component],
218+
) -> None:
201219
sim = SimulationContext(base_config, components)
202-
listener = [c for c in components if "listener" in c.args][0]
220+
listener: Listener = cast(Listener, [c for c in components if "listener" in c.name][0])
203221
assert not listener.post_setup_called
204222
sim.setup()
205223

@@ -212,7 +230,12 @@ def test_SimulationContext_setup_default(SimulationContext, base_config, compone
212230
for a, b in zip(sim._component_manager._components, unpacked_components):
213231
assert type(a) == type(b)
214232
if hasattr(a, "args"):
215-
assert a.args == b.args
233+
if isinstance(a, (MockComponentA, MockComponentB, Listener)) and isinstance(
234+
b, (MockComponentA, MockComponentB, Listener)
235+
):
236+
assert a.args == b.args
237+
else:
238+
raise RuntimeError("Unexpected component type")
216239

217240
assert is_same_object_method(sim.simulant_creator, sim._population._create_simulants)
218241
assert sim.time_step_events == [
@@ -233,7 +256,11 @@ def test_SimulationContext_setup_default(SimulationContext, base_config, compone
233256
assert listener.post_setup_called
234257

235258

236-
def test_SimulationContext_initialize_simulants(SimulationContext, base_config, components):
259+
def test_SimulationContext_initialize_simulants(
260+
SimulationContext: type[SimulationContext_],
261+
base_config: LayeredConfigTree,
262+
components: list[Component],
263+
) -> None:
237264
sim = SimulationContext(base_config, components)
238265
sim.setup()
239266
pop_size = sim.configuration.population.population_size
@@ -245,24 +272,30 @@ def test_SimulationContext_initialize_simulants(SimulationContext, base_config,
245272
assert sim._clock.time == current_time
246273

247274

248-
def test_SimulationContext_step(SimulationContext, log, base_config, components):
275+
def test_SimulationContext_step(
276+
SimulationContext: type[SimulationContext_],
277+
base_config: LayeredConfigTree,
278+
components: list[Component],
279+
caplog: LogCaptureFixture,
280+
) -> None:
249281
sim = SimulationContext(base_config, components)
250282
sim.setup()
251283
sim.initialize_simulants()
252284

253285
current_time = sim._clock.time
254286
step_size = sim._clock.step_size
255287

256-
listener = [c for c in components if "listener" in c.args][0]
288+
listener: Listener = cast(Listener, [c for c in components if "listener" in c.name][0])
257289

258290
assert not listener.time_step_prepare_called
259291
assert not listener.time_step_called
260292
assert not listener.time_step_cleanup_called
261293
assert not listener.collect_metrics_called
262294

295+
assert f"{current_time}" not in caplog.text
263296
sim.step()
297+
assert f"{current_time}" in caplog.text
264298

265-
assert log.debug.called_once_with(current_time)
266299
assert listener.time_step_prepare_called
267300
assert listener.time_step_called
268301
assert listener.time_step_cleanup_called
@@ -271,9 +304,13 @@ def test_SimulationContext_step(SimulationContext, log, base_config, components)
271304
assert sim._clock.time == current_time + step_size
272305

273306

274-
def test_SimulationContext_finalize(SimulationContext, base_config, components):
307+
def test_SimulationContext_finalize(
308+
SimulationContext: type[SimulationContext_],
309+
base_config: LayeredConfigTree,
310+
components: list[Component],
311+
) -> None:
275312
sim = SimulationContext(base_config, components)
276-
listener = [c for c in components if "listener" in c.args][0]
313+
listener: Listener = cast(Listener, [c for c in components if "listener" in c.name][0])
277314
sim.setup()
278315
sim.initialize_simulants()
279316
sim.step()
@@ -282,7 +319,9 @@ def test_SimulationContext_finalize(SimulationContext, base_config, components):
282319
assert listener.simulation_end_called
283320

284321

285-
def test_get_results(SimulationContext, base_config):
322+
def test_get_results(
323+
SimulationContext: type[SimulationContext_], base_config: LayeredConfigTree
324+
) -> None:
286325
"""Test that get_results returns expected values. This does NOT test for
287326
correct formatting.
288327
"""
@@ -300,7 +339,11 @@ def test_get_results(SimulationContext, base_config):
300339
assert results.set_index(raw_results.index.names)[[VALUE_COLUMN]].equals(raw_results)
301340

302341

303-
def test_SimulationContext_report_no_write_warning(SimulationContext, base_config, caplog):
342+
def test_SimulationContext_report_no_write_warning(
343+
SimulationContext: type[SimulationContext_],
344+
base_config: LayeredConfigTree,
345+
caplog: LogCaptureFixture,
346+
) -> None:
304347
components = [
305348
Hogwarts(),
306349
HousePointsObserver(),
@@ -315,13 +358,20 @@ def test_SimulationContext_report_no_write_warning(SimulationContext, base_confi
315358
assert set(results) == set(
316359
["house_points", "quidditch_wins", "no_stratifications_quidditch_wins"]
317360
)
318-
assert all([isinstance(df, pd.DataFrame) for df in results.values()])
361+
assert all(isinstance(df, pd.DataFrame) for df in results.values())
319362

320363

321-
def test_SimulationContext_report_write(SimulationContext, base_config, components, tmpdir):
364+
def test_SimulationContext_report_write(
365+
SimulationContext: type[SimulationContext_],
366+
base_config: LayeredConfigTree,
367+
components: list[Component],
368+
tmp_path: Path,
369+
) -> None:
322370
"""Test that the written results match get_results"""
323-
results_root = Path(tmpdir)
324-
configuration = {"output_data": {"results_directory": str(results_root)}}
371+
results_root = tmp_path
372+
configuration: dict[str, object] = {
373+
"output_data": {"results_directory": str(results_root)}
374+
}
325375
configuration.update(HARRY_POTTER_CONFIG)
326376
components = [
327377
Hogwarts(),
@@ -349,24 +399,33 @@ def test_SimulationContext_report_write(SimulationContext, base_config, componen
349399
assert results.equals(written_results)
350400

351401

352-
def test_SimulationContext_write_backup(mocker, SimulationContext, tmpdir):
402+
def test_SimulationContext_write_backup(
403+
mocker: MockerFixture, SimulationContext: type[SimulationContext_], tmp_path: Path
404+
) -> None:
353405
# TODO MIC-5216: Remove mocks when we can use dill in pytest.
354406
mocker.patch("vivarium.framework.engine.dill.dump")
355407
mocker.patch("vivarium.framework.engine.dill.load", return_value=SimulationContext())
356408
sim = SimulationContext()
357-
backup_path = tmpdir / "backup.pkl"
409+
backup_path = tmp_path / "backup.pkl"
358410
sim.write_backup(backup_path)
359411
assert backup_path.exists()
360412
with open(backup_path, "rb") as f:
361413
sim_backup = dill.load(f)
362414
assert isinstance(sim_backup, SimulationContext)
363415

364416

365-
def test_SimulationContext_run_with_backup(mocker, SimulationContext, base_config, tmpdir):
366-
mocker.patch("vivarium.framework.engine.SimulationContext.write_backup")
417+
def test_SimulationContext_run_with_backup(
418+
mocker: MockerFixture,
419+
SimulationContext: type[SimulationContext_],
420+
base_config: LayeredConfigTree,
421+
tmp_path: Path,
422+
) -> None:
423+
mocked_write_backup = mocker.patch(
424+
"vivarium.framework.engine.SimulationContext.write_backup"
425+
)
367426
original_time = time()
368427

369-
def time_generator():
428+
def time_generator() -> Generator[float, None, None]:
370429
current_time = original_time
371430
while True:
372431
yield current_time
@@ -381,14 +440,16 @@ def time_generator():
381440
HogwartsResultsStratifier(),
382441
]
383442
sim = SimulationContext(base_config, components, configuration=HARRY_POTTER_CONFIG)
384-
backup_path = tmpdir / "backup.pkl"
443+
backup_path = tmp_path / "backup.pkl"
385444
sim.setup()
386445
sim.initialize_simulants()
387446
sim.run(backup_path=backup_path, backup_freq=5)
388-
assert sim.write_backup.call_count == _get_num_steps(sim)
447+
assert mocked_write_backup.call_count == _get_num_steps(sim)
389448

390449

391-
def test_get_results_formatting(SimulationContext, base_config):
450+
def test_get_results_formatting(
451+
SimulationContext: type[SimulationContext_], base_config: LayeredConfigTree
452+
) -> None:
392453
"""Test formatted results are as expected"""
393454
components = [
394455
Hogwarts(),
@@ -445,15 +506,15 @@ def test_get_results_formatting(SimulationContext, base_config):
445506

446507

447508
def test_SimulationContext_load_from_backup(
448-
mocker: pytest_mock.MockFixture,
449-
SimulationContext: SimulationContext_,
450-
tmpdir: Path,
451-
):
509+
mocker: MockerFixture,
510+
SimulationContext: type[SimulationContext_],
511+
tmp_path: Path,
512+
) -> None:
452513
# TODO MIC-5216: Remove mocks when we can use dill in pytest.
453514
mocker.patch("vivarium.framework.engine.dill.dump")
454515
mocker.patch("vivarium.framework.engine.dill.load", return_value=SimulationContext())
455516
sim = SimulationContext()
456-
backup_path = tmpdir / "backup.pkl"
517+
backup_path = tmp_path / "backup.pkl"
457518
sim.write_backup(backup_path)
458519
# Load from backup
459520
sim_backup = SimulationContext.load_from_backup(backup_path)
@@ -469,9 +530,10 @@ def _convert_to_datetime(date_dict: dict[str, int]) -> pd.Timestamp:
469530
)
470531

471532

472-
def _get_num_steps(sim: SimulationContext) -> int:
533+
def _get_num_steps(sim: SimulationContext_) -> int:
473534
time_dict = sim.configuration.time.to_dict()
474535
end_date = _convert_to_datetime(time_dict["end"])
475536
start_date = _convert_to_datetime(time_dict["start"])
476537
num_steps = math.ceil((end_date - start_date).days / time_dict["step_size"])
538+
assert isinstance(num_steps, int)
477539
return num_steps

0 commit comments

Comments
 (0)