Skip to content

Commit 5244efc

Browse files
mypy fixes tests/framework/test_state_machine.py (#598)
1 parent a4397ac commit 5244efc

File tree

3 files changed

+24
-15
lines changed

3 files changed

+24
-15
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
**3.3.10 - TBD/TBD/TBD**
1+
**3.3.10 - 03/11/25**
22

3+
- Type-hinting: Fix mypy errors in tests/framework/test_state_machine.py
34
- Type-hinting: Fix mypy errors in tests/framework/test_time.py
45

56
**3.3.9 - 03/10/25**

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ exclude = [
5353
'tests/framework/test_event.py',
5454
'tests/framework/test_lifecycle.py',
5555
'tests/framework/test_plugins.py',
56-
'tests/framework/test_state_machine.py',
5756
'tests/framework/test_values.py',
5857
'tests/helpers.py',
5958
'tests/interface/test_cli.py',

tests/framework/test_state_machine.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pandas as pd
77
import pytest
88
from layered_config_tree import LayeredConfigTree
9+
from pytest_mock import MockerFixture
910

1011
from tests.helpers import ColumnCreator
1112
from vivarium import InteractiveContext
@@ -14,7 +15,7 @@
1415
from vivarium.framework.population import SimulantData
1516
from vivarium.framework.resource import Resource
1617
from vivarium.framework.state_machine import Machine, State, Transition
17-
from vivarium.types import ClockTime, LookupTableData
18+
from vivarium.types import ClockTime, DataInput, LookupTableData
1819

1920

2021
def test_initialize_allowing_self_transition() -> None:
@@ -37,33 +38,35 @@ def test_initialize_with_initial_state() -> None:
3738

3839
@pytest.mark.parametrize("weights_type", ["artifact", "callable", "scalar"])
3940
def test_initialize_with_scalar_initialization_weights(
40-
base_config: LayeredConfigTree, weights_type: str
41+
base_config: LayeredConfigTree, weights_type: str, mocker: MockerFixture
4142
) -> None:
4243
state_weights = {"state_a.weights": 0.2, "state_b.weights": 0.8}
4344

4445
def mock_load(key: str) -> float:
45-
return state_weights.get(key)
46+
return state_weights[key]
4647

4748
base_config.update(
4849
{"population": {"population_size": 10000}, "randomness": {"key_columns": []}}
4950
)
5051

51-
def initialization_weights(
52-
key: str,
53-
) -> LookupTableData | str | Callable[[Builder], LookupTableData]:
54-
return {
52+
def initialization_weights(key: str) -> DataInput:
53+
weights = {
5554
"artifact": key,
5655
"callable": lambda _: state_weights[key],
5756
"scalar": state_weights[key],
5857
}[weights_type]
58+
assert isinstance(weights, (str, float)) or callable(weights)
59+
return weights
5960

6061
state_a = State("a", initialization_weights=initialization_weights("state_a.weights"))
6162
state_b = State("b", initialization_weights=initialization_weights("state_b.weights"))
6263
machine = Machine("state", states=[state_a, state_b])
6364
simulation = InteractiveContext(
6465
components=[machine], configuration=base_config, setup=False
6566
)
66-
simulation._builder.data.load = mock_load
67+
mocker.patch(
68+
"vivarium.framework.artifact.manager.ArtifactInterface.load", side_effect=mock_load
69+
)
6770
simulation.setup()
6871

6972
state = simulation.get_population()["state"]
@@ -73,7 +76,9 @@ def initialization_weights(
7376

7477

7578
@pytest.mark.parametrize("weights_type", ["artifact", "callable", "dataframe"])
76-
def test_initialize_with_array_initialization_weights(weights_type: str) -> None:
79+
def test_initialize_with_array_initialization_weights(
80+
weights_type: str, mocker: MockerFixture
81+
) -> None:
7782
state_weights = {
7883
"state_a.weights": pd.DataFrame(
7984
{"test_column_1": [0, 1, 2], "value": [0.2, 0.7, 0.4]}
@@ -84,7 +89,7 @@ def test_initialize_with_array_initialization_weights(weights_type: str) -> None
8489
}
8590

8691
def mock_load(key: str) -> pd.DataFrame:
87-
return state_weights.get(key)
92+
return state_weights[key]
8893

8994
config = build_simulation_configuration()
9095
config.update(
@@ -101,20 +106,24 @@ def initialization_requirements(self) -> list[str | Resource]:
101106

102107
def initialization_weights(
103108
key: str,
104-
) -> LookupTableData | str | Callable[[Builder], LookupTableData]:
105-
return {
109+
) -> DataInput:
110+
weights = {
106111
"artifact": key,
107112
"callable": lambda _: state_weights[key],
108113
"dataframe": state_weights[key],
109114
}[weights_type]
115+
assert isinstance(weights, (str, pd.DataFrame)) or callable(weights)
116+
return weights
110117

111118
state_a = State("a", initialization_weights=initialization_weights("state_a.weights"))
112119
state_b = State("b", initialization_weights=initialization_weights("state_b.weights"))
113120
machine = TestMachine("state", states=[state_a, state_b])
114121
simulation = InteractiveContext(
115122
components=[machine, ColumnCreator()], configuration=config, setup=False
116123
)
117-
simulation._builder.data.load = mock_load
124+
mocker.patch(
125+
"vivarium.framework.artifact.manager.ArtifactInterface.load", side_effect=mock_load
126+
)
118127
simulation.setup()
119128

120129
pop = simulation.get_population()[["state", "test_column_1"]]

0 commit comments

Comments
 (0)