11import math
2+ from collections .abc import Callable , Generator
23from itertools import product
34from pathlib import Path
45from time import time
6+ from types import MethodType
7+ from typing import Any , cast
58
69import dill
710import pandas as pd
811import pytest
9- import pytest_mock
12+ from _pytest .logging import LogCaptureFixture
13+ from layered_config_tree import LayeredConfigTree
14+ from pytest_mock import MockerFixture
1015
1116from tests .framework .results .helpers import (
1217 FAMILIARS ,
4247from vivarium .framework .values import ValuesInterface , ValuesManager
4348
4449
45- def is_same_object_method (m1 , m2 ):
46- return m1 .__func__ is m2 .__func__ and m1 .__self__ is m2 .__self__
50+ def is_same_object_method (
51+ m1 : MethodType | Callable [..., Any ], m2 : Callable [..., Any ]
52+ ) -> bool :
53+ method1 : MethodType = cast (MethodType , m1 )
54+ method2 : MethodType = cast (MethodType , m2 )
55+ return method1 .__func__ is method2 .__func__ and method1 .__self__ is method2 .__self__
4756
4857
4958@pytest .fixture ()
50- def SimulationContext ():
59+ def SimulationContext () -> Generator [ type [ SimulationContext_ ], None , None ] :
5160 yield SimulationContext_
5261 SimulationContext_ ._clear_context_cache ()
5362
5463
5564@pytest .fixture
56- def components ():
65+ def components () -> list [ Component ] :
5766 return [
5867 MockComponentA ("gretchen" , "whimsy" ),
5968 Listener ("listener" ),
6069 MockComponentB ("spoon" , "antelope" , "23" ),
6170 ]
6271
6372
64- @pytest .fixture
65- def log (mocker ):
66- return mocker .patch ("vivarium.framework.logging.manager.loguru.logger" )
67-
68-
69- def test_simulation_with_non_components (SimulationContext , components : list [Component ]):
73+ def test_simulation_with_non_components (
74+ SimulationContext : type [SimulationContext_ ], components : list [Component ]
75+ ) -> None :
7076 class NonComponent :
71- def __init__ (self ):
77+ def __init__ (self ) -> None :
7278 self .name = "non_component"
7379
74- with pytest .raises (ComponentConfigError ):
75- SimulationContext (components = components + [NonComponent ()])
80+ with pytest .raises (
81+ ComponentConfigError , match = "that do not inherit from `vivarium.Component`"
82+ ):
83+ SimulationContext (components = components + [NonComponent ()]) # type: ignore[list-item]
7684
7785
78- def test_SimulationContext_get_sim_name (SimulationContext ) :
86+ def test_SimulationContext_get_sim_name (SimulationContext : type [ SimulationContext_ ]) -> None :
7987 assert SimulationContext ._created_simulation_contexts == set ()
8088
8189 assert SimulationContext ._get_context_name (None ) == "simulation_1"
@@ -84,7 +92,9 @@ def test_SimulationContext_get_sim_name(SimulationContext):
8492 assert SimulationContext ._created_simulation_contexts == {"simulation_1" , "foo" }
8593
8694
87- def test_SimulationContext_init_default (SimulationContext , components ):
95+ def test_SimulationContext_init_default (
96+ SimulationContext : type [SimulationContext_ ], components : list [Component ]
97+ ) -> None :
8898 sim = SimulationContext (components = components )
8999
90100 assert isinstance (sim ._logging , LoggingManager )
@@ -151,7 +161,9 @@ def test_SimulationContext_init_default(SimulationContext, components):
151161 assert list (sim ._component_manager ._components ) == unpacked_components
152162
153163
154- def test_SimulationContext_name_management (SimulationContext ):
164+ def test_SimulationContext_name_management (
165+ SimulationContext : type [SimulationContext_ ],
166+ ) -> None :
155167 assert SimulationContext ._created_simulation_contexts == set ()
156168
157169 sim1 = SimulationContext ()
@@ -171,7 +183,9 @@ def test_SimulationContext_name_management(SimulationContext):
171183 }
172184
173185
174- def test_SimulationContext_run_simulation (SimulationContext , mocker ):
186+ def test_SimulationContext_run_simulation (
187+ SimulationContext : type [SimulationContext_ ], mocker : MockerFixture
188+ ) -> None :
175189 sim = SimulationContext ()
176190
177191 expected_calls = [
@@ -197,9 +211,13 @@ def test_SimulationContext_run_simulation(SimulationContext, mocker):
197211 assert actual_calls == expected_calls
198212
199213
200- def test_SimulationContext_setup_default (SimulationContext , base_config , components ):
214+ def test_SimulationContext_setup_default (
215+ SimulationContext : type [SimulationContext_ ],
216+ base_config : LayeredConfigTree ,
217+ components : list [Component ],
218+ ) -> None :
201219 sim = SimulationContext (base_config , components )
202- listener = [c for c in components if "listener" in c .args ][0 ]
220+ listener : Listener = cast ( Listener , [c for c in components if "listener" in c .name ][0 ])
203221 assert not listener .post_setup_called
204222 sim .setup ()
205223
@@ -212,7 +230,12 @@ def test_SimulationContext_setup_default(SimulationContext, base_config, compone
212230 for a , b in zip (sim ._component_manager ._components , unpacked_components ):
213231 assert type (a ) == type (b )
214232 if hasattr (a , "args" ):
215- assert a .args == b .args
233+ if isinstance (a , (MockComponentA , MockComponentB , Listener )) and isinstance (
234+ b , (MockComponentA , MockComponentB , Listener )
235+ ):
236+ assert a .args == b .args
237+ else :
238+ raise RuntimeError ("Unexpected component type" )
216239
217240 assert is_same_object_method (sim .simulant_creator , sim ._population ._create_simulants )
218241 assert sim .time_step_events == [
@@ -233,7 +256,11 @@ def test_SimulationContext_setup_default(SimulationContext, base_config, compone
233256 assert listener .post_setup_called
234257
235258
236- def test_SimulationContext_initialize_simulants (SimulationContext , base_config , components ):
259+ def test_SimulationContext_initialize_simulants (
260+ SimulationContext : type [SimulationContext_ ],
261+ base_config : LayeredConfigTree ,
262+ components : list [Component ],
263+ ) -> None :
237264 sim = SimulationContext (base_config , components )
238265 sim .setup ()
239266 pop_size = sim .configuration .population .population_size
@@ -245,24 +272,30 @@ def test_SimulationContext_initialize_simulants(SimulationContext, base_config,
245272 assert sim ._clock .time == current_time
246273
247274
248- def test_SimulationContext_step (SimulationContext , log , base_config , components ):
275+ def test_SimulationContext_step (
276+ SimulationContext : type [SimulationContext_ ],
277+ base_config : LayeredConfigTree ,
278+ components : list [Component ],
279+ caplog : LogCaptureFixture ,
280+ ) -> None :
249281 sim = SimulationContext (base_config , components )
250282 sim .setup ()
251283 sim .initialize_simulants ()
252284
253285 current_time = sim ._clock .time
254286 step_size = sim ._clock .step_size
255287
256- listener = [c for c in components if "listener" in c .args ][0 ]
288+ listener : Listener = cast ( Listener , [c for c in components if "listener" in c .name ][0 ])
257289
258290 assert not listener .time_step_prepare_called
259291 assert not listener .time_step_called
260292 assert not listener .time_step_cleanup_called
261293 assert not listener .collect_metrics_called
262294
295+ assert f"{ current_time } " not in caplog .text
263296 sim .step ()
297+ assert f"{ current_time } " in caplog .text
264298
265- assert log .debug .called_once_with (current_time )
266299 assert listener .time_step_prepare_called
267300 assert listener .time_step_called
268301 assert listener .time_step_cleanup_called
@@ -271,9 +304,13 @@ def test_SimulationContext_step(SimulationContext, log, base_config, components)
271304 assert sim ._clock .time == current_time + step_size
272305
273306
274- def test_SimulationContext_finalize (SimulationContext , base_config , components ):
307+ def test_SimulationContext_finalize (
308+ SimulationContext : type [SimulationContext_ ],
309+ base_config : LayeredConfigTree ,
310+ components : list [Component ],
311+ ) -> None :
275312 sim = SimulationContext (base_config , components )
276- listener = [c for c in components if "listener" in c .args ][0 ]
313+ listener : Listener = cast ( Listener , [c for c in components if "listener" in c .name ][0 ])
277314 sim .setup ()
278315 sim .initialize_simulants ()
279316 sim .step ()
@@ -282,7 +319,9 @@ def test_SimulationContext_finalize(SimulationContext, base_config, components):
282319 assert listener .simulation_end_called
283320
284321
285- def test_get_results (SimulationContext , base_config ):
322+ def test_get_results (
323+ SimulationContext : type [SimulationContext_ ], base_config : LayeredConfigTree
324+ ) -> None :
286325 """Test that get_results returns expected values. This does NOT test for
287326 correct formatting.
288327 """
@@ -300,7 +339,11 @@ def test_get_results(SimulationContext, base_config):
300339 assert results .set_index (raw_results .index .names )[[VALUE_COLUMN ]].equals (raw_results )
301340
302341
303- def test_SimulationContext_report_no_write_warning (SimulationContext , base_config , caplog ):
342+ def test_SimulationContext_report_no_write_warning (
343+ SimulationContext : type [SimulationContext_ ],
344+ base_config : LayeredConfigTree ,
345+ caplog : LogCaptureFixture ,
346+ ) -> None :
304347 components = [
305348 Hogwarts (),
306349 HousePointsObserver (),
@@ -315,13 +358,20 @@ def test_SimulationContext_report_no_write_warning(SimulationContext, base_confi
315358 assert set (results ) == set (
316359 ["house_points" , "quidditch_wins" , "no_stratifications_quidditch_wins" ]
317360 )
318- assert all ([ isinstance (df , pd .DataFrame ) for df in results .values ()] )
361+ assert all (isinstance (df , pd .DataFrame ) for df in results .values ())
319362
320363
321- def test_SimulationContext_report_write (SimulationContext , base_config , components , tmpdir ):
364+ def test_SimulationContext_report_write (
365+ SimulationContext : type [SimulationContext_ ],
366+ base_config : LayeredConfigTree ,
367+ components : list [Component ],
368+ tmp_path : Path ,
369+ ) -> None :
322370 """Test that the written results match get_results"""
323- results_root = Path (tmpdir )
324- configuration = {"output_data" : {"results_directory" : str (results_root )}}
371+ results_root = tmp_path
372+ configuration : dict [str , object ] = {
373+ "output_data" : {"results_directory" : str (results_root )}
374+ }
325375 configuration .update (HARRY_POTTER_CONFIG )
326376 components = [
327377 Hogwarts (),
@@ -349,24 +399,33 @@ def test_SimulationContext_report_write(SimulationContext, base_config, componen
349399 assert results .equals (written_results )
350400
351401
352- def test_SimulationContext_write_backup (mocker , SimulationContext , tmpdir ):
402+ def test_SimulationContext_write_backup (
403+ mocker : MockerFixture , SimulationContext : type [SimulationContext_ ], tmp_path : Path
404+ ) -> None :
353405 # TODO MIC-5216: Remove mocks when we can use dill in pytest.
354406 mocker .patch ("vivarium.framework.engine.dill.dump" )
355407 mocker .patch ("vivarium.framework.engine.dill.load" , return_value = SimulationContext ())
356408 sim = SimulationContext ()
357- backup_path = tmpdir / "backup.pkl"
409+ backup_path = tmp_path / "backup.pkl"
358410 sim .write_backup (backup_path )
359411 assert backup_path .exists ()
360412 with open (backup_path , "rb" ) as f :
361413 sim_backup = dill .load (f )
362414 assert isinstance (sim_backup , SimulationContext )
363415
364416
365- def test_SimulationContext_run_with_backup (mocker , SimulationContext , base_config , tmpdir ):
366- mocker .patch ("vivarium.framework.engine.SimulationContext.write_backup" )
417+ def test_SimulationContext_run_with_backup (
418+ mocker : MockerFixture ,
419+ SimulationContext : type [SimulationContext_ ],
420+ base_config : LayeredConfigTree ,
421+ tmp_path : Path ,
422+ ) -> None :
423+ mocked_write_backup = mocker .patch (
424+ "vivarium.framework.engine.SimulationContext.write_backup"
425+ )
367426 original_time = time ()
368427
369- def time_generator ():
428+ def time_generator () -> Generator [ float , None , None ] :
370429 current_time = original_time
371430 while True :
372431 yield current_time
@@ -381,14 +440,16 @@ def time_generator():
381440 HogwartsResultsStratifier (),
382441 ]
383442 sim = SimulationContext (base_config , components , configuration = HARRY_POTTER_CONFIG )
384- backup_path = tmpdir / "backup.pkl"
443+ backup_path = tmp_path / "backup.pkl"
385444 sim .setup ()
386445 sim .initialize_simulants ()
387446 sim .run (backup_path = backup_path , backup_freq = 5 )
388- assert sim . write_backup .call_count == _get_num_steps (sim )
447+ assert mocked_write_backup .call_count == _get_num_steps (sim )
389448
390449
391- def test_get_results_formatting (SimulationContext , base_config ):
450+ def test_get_results_formatting (
451+ SimulationContext : type [SimulationContext_ ], base_config : LayeredConfigTree
452+ ) -> None :
392453 """Test formatted results are as expected"""
393454 components = [
394455 Hogwarts (),
@@ -445,15 +506,15 @@ def test_get_results_formatting(SimulationContext, base_config):
445506
446507
447508def test_SimulationContext_load_from_backup (
448- mocker : pytest_mock . MockFixture ,
449- SimulationContext : SimulationContext_ ,
450- tmpdir : Path ,
451- ):
509+ mocker : MockerFixture ,
510+ SimulationContext : type [ SimulationContext_ ] ,
511+ tmp_path : Path ,
512+ ) -> None :
452513 # TODO MIC-5216: Remove mocks when we can use dill in pytest.
453514 mocker .patch ("vivarium.framework.engine.dill.dump" )
454515 mocker .patch ("vivarium.framework.engine.dill.load" , return_value = SimulationContext ())
455516 sim = SimulationContext ()
456- backup_path = tmpdir / "backup.pkl"
517+ backup_path = tmp_path / "backup.pkl"
457518 sim .write_backup (backup_path )
458519 # Load from backup
459520 sim_backup = SimulationContext .load_from_backup (backup_path )
@@ -469,9 +530,10 @@ def _convert_to_datetime(date_dict: dict[str, int]) -> pd.Timestamp:
469530 )
470531
471532
472- def _get_num_steps (sim : SimulationContext ) -> int :
533+ def _get_num_steps (sim : SimulationContext_ ) -> int :
473534 time_dict = sim .configuration .time .to_dict ()
474535 end_date = _convert_to_datetime (time_dict ["end" ])
475536 start_date = _convert_to_datetime (time_dict ["start" ])
476537 num_steps = math .ceil ((end_date - start_date ).days / time_dict ["step_size" ])
538+ assert isinstance (num_steps , int )
477539 return num_steps
0 commit comments