Skip to content

Commit e8c5460

Browse files
committed
VTU Mypy (#34)
* type hint for calculations * in flux * in progress * do approximately half of typing * removoe more series types * lint * data_loader * type measures.py * fix strats * type test_interface * type test_formatting * type test_data_schema * refactor patch * lint * remove fuzzy checker ignores * "delete unused types.py file" * merge changes from refactor and fix * address comments * remve ref ro 'fuzzy' * move error up * remove trailing commas * change to collection * empty commit * lint * remove unused import
1 parent 10a1532 commit e8c5460

File tree

15 files changed

+121
-99
lines changed

15 files changed

+121
-99
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,10 @@ exclude = [
3030
"build",
3131
"setup.py",
3232
"docs/source/conf.py",
33+
"src/vivarium_testing_utils/automated_validation/comparison.py", # Only stubbed out currently
34+
"src/vivarium_testing_utils/automated_validation/plot_utils.py", # Only stubbed out currently
3335
]
36+
plugins = ["pandera.mypy"]
3437

3538
# handle mypy errors when 3rd party packages are not typed.
3639
[[tool.mypy.overrides]]

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,9 +47,9 @@
4747
"vivarium_dependencies[pandas,numpy,pyyaml,scipy,click,tables,loguru,networkx]",
4848
"vivarium_build_utils>=2.0.1,<3.0.0",
4949
"pyarrow",
50+
"vivarium",
5051
# Type stubs
5152
"types-setuptools",
52-
"vivarium",
5353
]
5454

5555
setup_requires = ["setuptools_scm"]
Lines changed: 9 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,8 @@
11
from abc import ABC, abstractmethod
2+
from typing import Collection
23

34
import pandas as pd
45

5-
from vivarium_testing_utils.automated_validation.data_transformation.data_schema import (
6-
RatioData,
7-
SingleNumericColumn,
8-
)
96
from vivarium_testing_utils.automated_validation.data_transformation.measures import (
107
Measure,
118
RatioMeasure,
@@ -24,36 +21,36 @@ class Comparison(ABC):
2421
stratifications: list[str]
2522

2623
@abstractmethod
27-
def verify(self, stratifications: list[str]):
24+
def verify(self, stratifications: Collection[str] = ()):
2825
pass
2926

3027
@abstractmethod
31-
def summarize(self, stratifications: list[str]):
28+
def summarize(self, stratifications: Collection[str] = ()):
3229
pass
3330

3431
@abstractmethod
35-
def heads(self, stratifications: list[str]):
32+
def heads(self, stratifications: Collection[str] = ()):
3633
pass
3734

3835

39-
class FuzzyComparison:
36+
class FuzzyComparison(Comparison):
4037
def __init__(
4138
self,
4239
measure: RatioMeasure,
4340
test_data: pd.DataFrame,
4441
reference_data: pd.DataFrame,
45-
stratifications: list[str] = [],
42+
stratifications: Collection[str] = (),
4643
):
4744
self.measure = measure
4845
self.test_data = test_data
4946
self.reference_data = reference_data
5047
self.stratifications = stratifications
5148

52-
def verify(self, stratifications: list[str]):
49+
def verify(self, stratifications: Collection[str] = ()):
5350
raise NotImplementedError
5451

55-
def summarize(self, stratifications: list[str]):
52+
def summarize(self, stratifications: Collection[str] = ()):
5653
raise NotImplementedError
5754

58-
def heads(self, stratifications: list[str]):
55+
def heads(self, stratifications: Collection[str] = ()):
5956
raise NotImplementedError

src/vivarium_testing_utils/automated_validation/data_loader.py

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
from pathlib import Path
55

66
import pandas as pd
7-
import pandera as pa
87
import yaml
98
from vivarium import Artifact
109

@@ -33,16 +32,18 @@ def from_str(cls, source: str) -> DataSource:
3332

3433

3534
class DataLoader:
36-
def __init__(self, sim_output_dir: str, cache_size_mb: int = 1000):
37-
self._sim_output_dir = Path(sim_output_dir)
38-
self._results_dir = self._sim_output_dir / "results"
35+
def __init__(self, sim_output_dir: Path, cache_size_mb: int = 1000):
36+
self._sim_output_dir = sim_output_dir
3937
self._cache_size_mb = cache_size_mb
40-
self._raw_datasets = {data_source: {} for data_source in DataSource}
38+
39+
self._results_dir = self._sim_output_dir / "results"
40+
self._raw_datasets: dict[DataSource, dict[str, pd.DataFrame]] = {
41+
data_source: {} for data_source in DataSource
42+
}
4143
self._loader_mapping = {
4244
DataSource.SIM: self._load_from_sim,
4345
DataSource.GBD: self._load_from_gbd,
4446
DataSource.ARTIFACT: self._load_from_artifact,
45-
DataSource.CUSTOM: self._raise_custom_data_error,
4647
}
4748
self._artifact = self._load_artifact(self._sim_output_dir)
4849

@@ -59,14 +60,19 @@ def get_dataset(self, dataset_key: str, source: DataSource) -> pd.DataFrame:
5960
try:
6061
return self._raw_datasets[source][dataset_key].copy()
6162
except KeyError:
63+
if source == DataSource.CUSTOM:
64+
raise ValueError(
65+
f"No custom dataset found for {dataset_key}."
66+
"Please upload a dataset using ValidationContext.upload_custom_data."
67+
)
6268
dataset = self._load_from_source(dataset_key, source)
6369
self._add_to_cache(dataset_key, source, dataset)
6470
return dataset
6571

66-
def upload_custom_data(self, dataset_key: str, data: pd.DataFrame | pd.Series) -> None:
72+
def upload_custom_data(self, dataset_key: str, data: pd.DataFrame) -> None:
6773
self._add_to_cache(dataset_key, DataSource.CUSTOM, data)
6874

69-
def _load_from_source(self, dataset_key: str, source: DataSource) -> None:
75+
def _load_from_source(self, dataset_key: str, source: DataSource) -> pd.DataFrame:
7076
"""Load the data from the given source via the loader mapping."""
7177
return self._loader_mapping[source](dataset_key)
7278

@@ -102,8 +108,8 @@ def _load_from_sim(self, dataset_key: str) -> pd.DataFrame:
102108
return multi_index_df
103109

104110
@staticmethod
105-
def _load_artifact(results_dir: str) -> Artifact:
106-
model_spec_path = Path(results_dir) / "model_specification.yaml"
111+
def _load_artifact(results_dir: Path) -> Artifact:
112+
model_spec_path = results_dir / "model_specification.yaml"
107113
artifact_path = yaml.safe_load(model_spec_path.open("r"))["configuration"][
108114
"input_data"
109115
]["artifact_path"]
@@ -117,9 +123,3 @@ def _load_from_artifact(self, dataset_key: str) -> pd.DataFrame:
117123

118124
def _load_from_gbd(self, dataset_key: str) -> pd.DataFrame:
119125
raise NotImplementedError
120-
121-
def _raise_custom_data_error(self, dataset_key: str) -> pd.DataFrame:
122-
raise ValueError(
123-
f"No custom dataset found for {dataset_key}."
124-
"Please upload a dataset using ValidationContext.upload_custom_data."
125-
)

src/vivarium_testing_utils/automated_validation/data_transformation/calculations.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import TypeVar
24

35
import pandas as pd
@@ -12,22 +14,20 @@
1214
series_to_dataframe,
1315
)
1416

15-
DataSet = TypeVar("DataSet", pd.DataFrame, pd.Series)
16-
1717
DRAW_PREFIX = "draw_"
1818

1919

20-
def align_indexes(datasets: list[DataSet]) -> list[DataSet]:
20+
def align_indexes(datasets: list[pd.DataFrame]) -> list[pd.DataFrame]:
2121
"""Put each dataframe on a common index by choosing the intersection of index columns
2222
and marginalizing over the rest."""
2323
# Get the common index columns
24-
common_index = set.intersection(*(set(data.index.names) for data in datasets))
24+
common_index = list(set.intersection(*(set(data.index.names) for data in datasets)))
2525

2626
# Marginalize over the rest
2727
return [marginalize(data, common_index) for data in datasets]
2828

2929

30-
def filter_data(data: DataSet, filter_cols: dict[str, list]) -> DataSet:
30+
def filter_data(data: pd.DataFrame, filter_cols: dict[str, list[str]]) -> pd.DataFrame:
3131
"""Filter a DataFrame by the given index columns and values.
3232
3333
The filter_cols argument
@@ -57,21 +57,21 @@ def ratio(data: pd.DataFrame, numerator: str, denominator: str) -> pd.DataFrame:
5757
return series_to_dataframe(data[numerator] / data[denominator])
5858

5959

60-
def aggregate_sum(data: DataSet, groupby_cols: list[str]) -> DataSet:
60+
def aggregate_sum(data: pd.DataFrame, groupby_cols: list[str]) -> pd.DataFrame:
6161
"""Aggregate the dataframe over the specified index columns by summing."""
6262
if not groupby_cols:
6363
return data
6464
return data.groupby(groupby_cols).sum()
6565

6666

67-
def stratify(data: DataSet, stratification_cols: list[str]) -> DataSet:
67+
def stratify(data: pd.DataFrame, stratification_cols: list[str]) -> pd.DataFrame:
6868
"""Stratify the data by the index columns, summing over everything else. Syntactic sugar for aggregate."""
6969
return aggregate_sum(data, stratification_cols)
7070

7171

72-
def marginalize(data: DataSet, marginalize_cols: list[str]) -> DataSet:
72+
def marginalize(data: pd.DataFrame, marginalize_cols: list[str]) -> pd.DataFrame:
7373
"""Sum over marginalize columns, keeping the rest. Syntactic sugar for aggregate."""
74-
return aggregate_sum(data, data.index.names.difference(marginalize_cols))
74+
return aggregate_sum(data, [x for x in data.index.names if x not in marginalize_cols])
7575

7676

7777
def linear_combination(

src/vivarium_testing_utils/automated_validation/data_transformation/formatting.py

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,9 @@
1-
from abc import ABC, abstractmethod
2-
31
import pandas as pd
42

53
from vivarium_testing_utils.automated_validation.data_transformation.calculations import (
64
filter_data,
75
marginalize,
86
)
9-
from vivarium_testing_utils.automated_validation.data_transformation.data_schema import (
10-
SimOutputData,
11-
)
127

138

149
class SimDataFormatter:
@@ -30,7 +25,7 @@ def __init__(self, type: str, cause: str, filter_value: str) -> None:
3025
self.filter_value = filter_value
3126
self.new_value_column_name = f"{self.filter_value}_{self.type}"
3227

33-
def format_dataset(self, dataset: SimOutputData) -> SimOutputData:
28+
def format_dataset(self, dataset: pd.DataFrame) -> pd.DataFrame:
3429
"""Clean up redundant columns, filter for the state, and rename the value column."""
3530
for column, value in self.redundant_columns.items():
3631
dataset = _drop_redundant_index(
@@ -56,13 +51,13 @@ def __init__(self, cause: str, start_state: str, end_state: str) -> None:
5651
class PersonTime(SimDataFormatter):
5752
"""Formatter for simulation data that contains person time."""
5853

59-
def __init__(self, cause: str, state=None) -> None:
54+
def __init__(self, cause: str, state: str | None = None) -> None:
6055
super().__init__("person_time", cause, state or "total")
6156

6257

6358
def _drop_redundant_index(
6459
data: pd.DataFrame, idx_column_name: str, idx_column_value: str
65-
) -> None:
60+
) -> pd.DataFrame:
6661
"""Validate that a DataFrame column is singular-valued, then drop it from the index."""
6762
# TODO: Make sure we handle this case appropriately when we
6863
# want to automatically add many comparisons

src/vivarium_testing_utils/automated_validation/data_transformation/measures.py

Lines changed: 18 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
from abc import ABC, abstractmethod
2+
from typing import Any
23

34
import pandas as pd
45
import pandera as pa
@@ -25,21 +26,32 @@ class Measure(ABC):
2526
"""A Measure contains key information and methods to take raw data from a DataSource
2627
and process it into an epidemiological measure suitable for use in a Comparison."""
2728

28-
sim_datasets: dict[str, str]
29-
artifact_datasets: dict[str, str]
29+
measure_key: str
30+
31+
@property
32+
@abstractmethod
33+
def sim_datasets(self) -> dict[str, str]:
34+
"""Return a dictionary of required datasets for this measure."""
35+
pass
36+
37+
@property
38+
@abstractmethod
39+
def artifact_datasets(self) -> dict[str, str]:
40+
"""Return a dictionary of required datasets for this measure."""
41+
pass
3042

3143
@abstractmethod
32-
def get_measure_data_from_artifact(self, *args, **kwargs) -> pd.DataFrame:
44+
def get_measure_data_from_artifact(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
3345
"""Process artifact data into a format suitable for calculations."""
3446
pass
3547

3648
@abstractmethod
37-
def get_measure_data_from_sim(self, *args, **kwargs) -> pd.DataFrame:
49+
def get_measure_data_from_sim(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
3850
"""Process raw simulation data into a format suitable for calculations."""
3951
pass
4052

4153
@check_io(out=SingleNumericColumn)
42-
def get_measure_data(self, source: DataSource, *args, **kwargs) -> pd.DataFrame:
54+
def get_measure_data(self, source: DataSource, *args: Any, **kwargs: Any) -> pd.DataFrame:
4355
"""Process data from the specified source into a format suitable for calculations."""
4456
if source == DataSource.SIM:
4557
return self.get_measure_data_from_sim(*args, **kwargs)
@@ -80,15 +92,6 @@ def artifact_datasets(self) -> dict[str, str]:
8092
"artifact_data": self.measure_key,
8193
}
8294

83-
@abstractmethod
84-
def get_ratio_data_from_sim(
85-
self,
86-
numerator_data: pd.DataFrame,
87-
denominator_data: pd.DataFrame,
88-
) -> pd.DataFrame:
89-
"""Process raw simulation data into a format suitable for calculations."""
90-
pass
91-
9295
@check_io(artifact_data=SingleNumericColumn, out=SingleNumericColumn)
9396
def get_measure_data_from_artifact(self, artifact_data: pd.DataFrame) -> pd.DataFrame:
9497
return artifact_data
@@ -103,7 +106,7 @@ def get_measure_data_from_ratio(self, ratio_data: pd.DataFrame) -> pd.DataFrame:
103106
)
104107

105108
@check_io(out=SingleNumericColumn)
106-
def get_measure_data_from_sim(self, *args, **kwargs) -> pd.DataFrame:
109+
def get_measure_data_from_sim(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
107110
"""Process raw simulation data into a format suitable for calculations."""
108111
return self.get_measure_data_from_ratio(self.get_ratio_data_from_sim(*args, **kwargs))
109112

src/vivarium_testing_utils/automated_validation/data_transformation/utils.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,14 @@
11
from __future__ import annotations
22

3+
from typing import Any, Callable, TypeVar
4+
35
import pandas as pd
46
import pandera as pa
57

8+
F = TypeVar("F", bound=Callable[..., Any])
9+
610

7-
def check_io(**model_dict):
11+
def check_io(**model_dict: type) -> Callable[[F], F]:
812
"""
913
A wrapper for pa.check_io that automatically converts SchemaModels to schemas.
1014

0 commit comments

Comments
 (0)