Skip to content

Commit 583e502

Browse files
authored
Albrja/mic-5653/Fix mypy errors in testing utilities (#622)
Albrja/mic-5653/Fix mypy errors in testing utilities Fix myy erros in testing utilities - *Category*: Type hinting - *JIRA issue*: https://jira.ihme.washington.edu/browse/MIC-5653 Changes and notes -fix mypy erros in testing_utilities.py ### Testing <!-- Details on how code was verified, any unit tests local for the repo, regression testing, etc. At a minimum, this should include an integration test for a framework change. Consider: plots, images, (small) csv file. -->
1 parent 7bc29ed commit 583e502

File tree

4 files changed

+48
-50
lines changed

4 files changed

+48
-50
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
**3.4.2 - 05/22/25**
2+
3+
- Type-hinting: Fix mypy errors in frameworking/testing_utilties.py
4+
15
**3.4.1 - 05/21/25**
26

37
- Type-hinting: Fix mypy errors in tests/framework/results/test_context.py

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ exclude = [
3737
'src/vivarium/examples/disease_model/observer.py',
3838
'src/vivarium/examples/disease_model/population.py',
3939
'src/vivarium/examples/disease_model/risk.py',
40-
'src/vivarium/testing_utilities.py',
4140
]
4241

4342
disable_error_code = []

src/vivarium/examples/disease_model/observer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# mypy: ignore-errors
12
from typing import Any
23

34
import pandas as pd

src/vivarium/testing_utilities.py

Lines changed: 43 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: ignore-errors
21
"""
32
==========================
43
Vivarium Testing Utilities
@@ -7,7 +6,10 @@
76
Utility functions and classes to make testing ``vivarium`` components easier.
87
98
"""
9+
from __future__ import annotations
1010

11+
from collections.abc import Callable, Sequence
12+
from datetime import datetime
1113
from itertools import product
1214
from pathlib import Path
1315
from typing import Any
@@ -16,11 +18,12 @@
1618
import pandas as pd
1719

1820
from vivarium import Component
19-
from vivarium.framework import randomness
2021
from vivarium.framework.engine import Builder
2122
from vivarium.framework.event import Event
2223
from vivarium.framework.population import SimulantData
2324
from vivarium.framework.randomness.index_map import IndexMap
25+
from vivarium.framework.randomness.stream import RandomnessStream
26+
from vivarium.types import ClockStepSize, ClockTime
2427

2528

2629
class NonCRNTestPopulation(Component):
@@ -64,7 +67,9 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
6467

6568
def on_time_step(self, event: Event) -> None:
6669
population = self.population_view.get(event.index, query="alive == 'alive'")
67-
population["age"] += event.step_size / pd.Timedelta(days=365)
70+
# This component won't work if event.step_size is an int
71+
if not isinstance(event.step_size, int):
72+
population["age"] += event.step_size / pd.Timedelta(days=365)
6873
self.population_view.update(population)
6974

7075

@@ -85,7 +90,11 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
8590
)
8691
age_draw = self.age_randomness.get_draw(pop_data.index)
8792
if age_start == age_end:
88-
age = age_draw * (pop_data.creation_window / pd.Timedelta(days=365)) + age_start
93+
# This component won't work if creation window is an int
94+
if not isinstance(pop_data.creation_window, int):
95+
age = (
96+
age_draw * (pop_data.creation_window / pd.Timedelta(days=365)) + age_start
97+
)
8998
else:
9099
age = age_draw * (age_end - age_start) + age_start
91100

@@ -104,7 +113,9 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
104113
self.population_view.update(population)
105114

106115

107-
def _build_population(core_population, location, randomness_stream):
116+
def _build_population(
117+
core_population: pd.DataFrame, location: str, randomness_stream: RandomnessStream
118+
) -> pd.DataFrame:
108119
index = core_population.index
109120

110121
population = pd.DataFrame(
@@ -124,13 +135,20 @@ def _build_population(core_population, location, randomness_stream):
124135

125136

126137
def _non_crn_build_population(
127-
index, age_start, age_end, location, creation_time, creation_window, randomness_stream
128-
):
138+
index: pd.Index[int],
139+
age_start: float,
140+
age_end: float,
141+
location: str,
142+
creation_time: ClockTime,
143+
creation_window: ClockStepSize,
144+
randomness_stream: RandomnessStream,
145+
) -> pd.DataFrame:
129146
if age_start == age_end:
130-
age = (
131-
randomness_stream.get_draw(index) * (creation_window / pd.Timedelta(days=365))
132-
+ age_start
133-
)
147+
if not isinstance(creation_window, int):
148+
age = (
149+
randomness_stream.get_draw(index) * (creation_window / pd.Timedelta(days=365))
150+
+ age_start
151+
)
134152
else:
135153
age = randomness_stream.get_draw(index) * (age_end - age_start) + age_start
136154

@@ -152,12 +170,12 @@ def _non_crn_build_population(
152170

153171
def build_table(
154172
value: Any,
155-
parameter_columns: dict = {
173+
parameter_columns: dict[str, Sequence[int]] = {
156174
"age": (0, 125),
157175
"year": (1990, 2020),
158176
},
159-
key_columns: dict = {"sex": ("Female", "Male")},
160-
value_columns: list = ["value"],
177+
key_columns: dict[str, Sequence[Any]] = {"sex": ("Female", "Male")},
178+
value_columns: list[str] = ["value"],
161179
) -> pd.DataFrame:
162180
"""
163181
@@ -191,7 +209,7 @@ def build_table(
191209
}
192210
# Build out dict of items we will need cartesian product of to make dataframe
193211
product_dict = dict(range_parameter_product)
194-
product_dict.update(key_columns)
212+
product_dict.update(key_columns) # type: ignore [arg-type]
195213
products = product(*product_dict.values())
196214

197215
rows = []
@@ -212,10 +230,12 @@ def build_table(
212230
# Transform parameter column values
213231
parameter_columns_index_values = item[: len(parameter_columns)]
214232
# Create intervals for parameter columns. Example year, year+1 for year_start and year_end
215-
parameter_columns_index_values = [
233+
unpacked_parameter_columns_index_values: list[Any] = [
216234
v for val in parameter_columns_index_values for v in (val, val + 1)
217235
]
218-
rows.append(parameter_columns_index_values + key_columns_index_values + r_values)
236+
rows.append(
237+
unpacked_parameter_columns_index_values + key_columns_index_values + r_values
238+
)
219239

220240
# Make list of parameter column names
221241
parameter_column_names = [
@@ -228,34 +248,13 @@ def build_table(
228248
)
229249

230250

231-
def make_dummy_column(name, initial_value):
232-
class DummyColumnMaker:
233-
@property
234-
def name(self):
235-
return "dummy_column_maker"
236-
237-
def setup(self, builder):
238-
self.population_view = builder.population.get_view(name)
239-
builder.population.initializes_simulants(self.make_column, creates_columns=name)
240-
241-
def make_column(self, pop_data):
242-
self.population_view.update(
243-
pd.Series(initial_value, index=pop_data.index, name=name)
244-
)
245-
246-
def __repr__(self):
247-
return f"dummy_column(name={name}, initial_value={initial_value})"
248-
249-
return DummyColumnMaker()
250-
251-
252251
def get_randomness(
253-
key="test",
254-
clock=lambda: pd.Timestamp(1990, 7, 2),
255-
seed=12345,
256-
initializes_crn_attributes=False,
257-
):
258-
return randomness.RandomnessStream(
252+
key: str = "test",
253+
clock: Callable[[], pd.Timestamp | datetime | int] = lambda: pd.Timestamp(1990, 7, 2),
254+
seed: int = 12345,
255+
initializes_crn_attributes: bool = False,
256+
) -> RandomnessStream:
257+
return RandomnessStream(
259258
key,
260259
clock,
261260
seed=seed,
@@ -264,10 +263,5 @@ def get_randomness(
264263
)
265264

266265

267-
def reset_mocks(mocks):
268-
for mock in mocks:
269-
mock.reset_mock()
270-
271-
272266
def metadata(file_path: str, layer: str = "override") -> dict[str, str]:
273267
return {"layer": layer, "source": str(Path(file_path).resolve())}

0 commit comments

Comments
 (0)