66import pandas as pd
77import pytest
88from layered_config_tree import LayeredConfigTree
9+ from pytest_mock import MockerFixture
910
1011from tests .helpers import ColumnCreator
1112from vivarium import InteractiveContext
1415from vivarium .framework .population import SimulantData
1516from vivarium .framework .resource import Resource
1617from vivarium .framework .state_machine import Machine , State , Transition
17- from vivarium .types import ClockTime , LookupTableData
18+ from vivarium .types import ClockTime , DataInput , LookupTableData
1819
1920
2021def test_initialize_allowing_self_transition () -> None :
@@ -37,33 +38,35 @@ def test_initialize_with_initial_state() -> None:
3738
3839@pytest .mark .parametrize ("weights_type" , ["artifact" , "callable" , "scalar" ])
3940def test_initialize_with_scalar_initialization_weights (
40- base_config : LayeredConfigTree , weights_type : str
41+ base_config : LayeredConfigTree , weights_type : str , mocker : MockerFixture
4142) -> None :
4243 state_weights = {"state_a.weights" : 0.2 , "state_b.weights" : 0.8 }
4344
4445 def mock_load (key : str ) -> float :
45- return state_weights . get ( key )
46+ return state_weights [ key ]
4647
4748 base_config .update (
4849 {"population" : {"population_size" : 10000 }, "randomness" : {"key_columns" : []}}
4950 )
5051
51- def initialization_weights (
52- key : str ,
53- ) -> LookupTableData | str | Callable [[Builder ], LookupTableData ]:
54- return {
52+ def initialization_weights (key : str ) -> DataInput :
53+ weights = {
5554 "artifact" : key ,
5655 "callable" : lambda _ : state_weights [key ],
5756 "scalar" : state_weights [key ],
5857 }[weights_type ]
58+ assert isinstance (weights , (str , float )) or callable (weights )
59+ return weights
5960
6061 state_a = State ("a" , initialization_weights = initialization_weights ("state_a.weights" ))
6162 state_b = State ("b" , initialization_weights = initialization_weights ("state_b.weights" ))
6263 machine = Machine ("state" , states = [state_a , state_b ])
6364 simulation = InteractiveContext (
6465 components = [machine ], configuration = base_config , setup = False
6566 )
66- simulation ._builder .data .load = mock_load
67+ mocker .patch (
68+ "vivarium.framework.artifact.manager.ArtifactInterface.load" , side_effect = mock_load
69+ )
6770 simulation .setup ()
6871
6972 state = simulation .get_population ()["state" ]
@@ -73,7 +76,9 @@ def initialization_weights(
7376
7477
7578@pytest .mark .parametrize ("weights_type" , ["artifact" , "callable" , "dataframe" ])
76- def test_initialize_with_array_initialization_weights (weights_type : str ) -> None :
79+ def test_initialize_with_array_initialization_weights (
80+ weights_type : str , mocker : MockerFixture
81+ ) -> None :
7782 state_weights = {
7883 "state_a.weights" : pd .DataFrame (
7984 {"test_column_1" : [0 , 1 , 2 ], "value" : [0.2 , 0.7 , 0.4 ]}
@@ -84,7 +89,7 @@ def test_initialize_with_array_initialization_weights(weights_type: str) -> None
8489 }
8590
8691 def mock_load (key : str ) -> pd .DataFrame :
87- return state_weights . get ( key )
92+ return state_weights [ key ]
8893
8994 config = build_simulation_configuration ()
9095 config .update (
@@ -101,20 +106,24 @@ def initialization_requirements(self) -> list[str | Resource]:
101106
102107 def initialization_weights (
103108 key : str ,
104- ) -> LookupTableData | str | Callable [[ Builder ], LookupTableData ] :
105- return {
109+ ) -> DataInput :
110+ weights = {
106111 "artifact" : key ,
107112 "callable" : lambda _ : state_weights [key ],
108113 "dataframe" : state_weights [key ],
109114 }[weights_type ]
115+ assert isinstance (weights , (str , pd .DataFrame )) or callable (weights )
116+ return weights
110117
111118 state_a = State ("a" , initialization_weights = initialization_weights ("state_a.weights" ))
112119 state_b = State ("b" , initialization_weights = initialization_weights ("state_b.weights" ))
113120 machine = TestMachine ("state" , states = [state_a , state_b ])
114121 simulation = InteractiveContext (
115122 components = [machine , ColumnCreator ()], configuration = config , setup = False
116123 )
117- simulation ._builder .data .load = mock_load
124+ mocker .patch (
125+ "vivarium.framework.artifact.manager.ArtifactInterface.load" , side_effect = mock_load
126+ )
118127 simulation .setup ()
119128
120129 pop = simulation .get_population ()[["state" , "test_column_1" ]]
0 commit comments