Skip to content

Commit d73ded1

Browse files
authored
Epic/auto validation (#100)
Epic/auto validation, phase 1. Phase 1 Automated V&V work, including the new ValidationContext tool. - *Category*: Feature - *JIRA issue*: https://jira.ihme.washington.edu/browse/MIC-XYZ Changes and notes: - Automated V&V, phase 1
1 parent e5e4568 commit d73ded1

36 files changed

+8942
-3
lines changed

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,8 @@ notebooks/
115115

116116
# Version file
117117
src/vivarium_testing_utils/_version.py
118+
119+
# Copilot instructions
120+
.github/copilot_instructions.md
121+
.github/prompts/
122+
.github/chatmodes/

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
**0.3.0 - 12/12/25**
2+
3+
- Phase 1 Automated Validation, ValidationContext component for simulation validation
4+
15
**0.2.6 - 11/20/25**
26

37
- Improve 'make build-env': better handle args and make the env name optional

pyproject.toml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ source = ["vivarium_testing_utils"]
1010

1111
[tool.coverage.report]
1212
show_missing = true
13+
exclude_also = [
14+
"raise NotImplementedError",
15+
"@(abc\\.)?abstractmethod",
16+
]
1317

1418
[tool.black]
1519
line_length = 94
@@ -30,13 +34,17 @@ exclude = [
3034
"build",
3135
"setup.py",
3236
"docs/source/conf.py",
37+
"src/vivarium_testing_utils/automated_validation/visualization/plot_utils.py", # Only stubbed out currently
3338
]
39+
plugins = ["pandera.mypy"]
3440

3541
# handle mypy errors when 3rd party packages are not typed.
3642
[[tool.mypy.overrides]]
3743
module = [
3844
"py._path.local",
3945
"scipy.*",
4046
# "sklearn.*",
47+
"vivarium_inputs.*",
48+
"gbd_mapping.*",
4149
]
42-
ignore_missing_imports = true
50+
ignore_missing_imports = true

setup.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,31 @@
4444
long_description = f.read()
4545

4646
install_requirements = [
47-
"vivarium_dependencies[pandas,numpy,pyyaml,scipy,click,tables,loguru,networkx]",
47+
"vivarium_dependencies[numpy,pyyaml,scipy,click,tables,loguru,networkx]",
48+
"pandas",
4849
"vivarium_build_utils>=2.0.1,<3.0.0",
50+
"pyarrow",
51+
"seaborn",
4952
# Type stubs
5053
"types-setuptools",
5154
]
5255

5356
setup_requires = ["setuptools_scm"]
5457

58+
validation_requirements = [
59+
"vivarium>=3.4.0",
60+
"vivarium-inputs>=7.1.0, <8.0.0",
61+
"pandera<0.23.0",
62+
"gbd_mapping",
63+
]
64+
5565
interactive_requirements = [
5666
"vivarium_dependencies[interactive]",
5767
]
5868

5969
test_requirements = [
6070
"vivarium_dependencies[pytest]",
71+
"pytest-check",
6172
]
6273

6374
doc_requirements = [
@@ -106,10 +117,12 @@
106117
"docs": doc_requirements,
107118
"test": test_requirements,
108119
"interactive": interactive_requirements,
120+
"validation": validation_requirements,
109121
"dev": doc_requirements
110122
+ test_requirements
111123
+ interactive_requirements
112-
+ lint_requirements,
124+
+ lint_requirements
125+
+ validation_requirements,
113126
},
114127
zip_safe=False,
115128
use_scm_version={
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from vivarium_testing_utils.automated_validation.interface import ValidationContext
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
from abc import ABC
2+
from collections.abc import Collection
3+
from typing import Any, Literal
4+
5+
import pandas as pd
6+
7+
from vivarium_testing_utils.automated_validation.constants import (
8+
DRAW_INDEX,
9+
SEED_INDEX,
10+
DataSource,
11+
)
12+
from vivarium_testing_utils.automated_validation.data_loader import DataLoader
13+
from vivarium_testing_utils.automated_validation.data_transformation import (
14+
age_groups,
15+
calculations,
16+
)
17+
from vivarium_testing_utils.automated_validation.data_transformation.measures import (
18+
CategoricalRelativeRisk,
19+
Measure,
20+
RatioMeasure,
21+
RiskExposure,
22+
)
23+
from vivarium_testing_utils.automated_validation.visualization import dataframe_utils
24+
25+
26+
class MeasureDataBundle(ABC):
    """Abstract container pairing a measure with the data needed to evaluate it.

    Concrete bundles declare which measure they serve, where its data comes
    from, and how that data is loaded and (optionally) filtered to scenarios.
    """

    # The measure definition this bundle provides data for.
    measure: Measure
    # Origin of the data (simulation output, artifact, GBD, ...).
    source: DataSource
    # Loader used to fetch the raw datasets from ``source``.
    data_loader: DataLoader
    # Optional mapping of scenario column -> scenario value used to filter data.
    scenarios: dict[str, str] | None
31+
32+
33+
class RatioMeasureDataBundle(MeasureDataBundle):
    """Concrete :class:`MeasureDataBundle` for ratio-type measures.

    Fetches the raw datasets required by ``measure`` from ``source`` via the
    data loader, formats them against the provided age-bin table, filters them
    to the requested scenarios, and exposes the (optionally stratified)
    measure data for validation comparisons.
    """

    def __init__(
        self,
        measure: RatioMeasure,
        source: DataSource,
        data_loader: DataLoader,
        age_group_df: pd.DataFrame,
        scenarios: dict[str, str] | None = None,
    ) -> None:
        """Build the bundle and eagerly load/format its datasets.

        Parameters
        ----------
        measure
            The ratio measure whose data this bundle holds.
        source
            Where the data comes from (sim output, artifact, GBD, custom).
        data_loader
            Loader used to fetch raw datasets for ``source``.
        age_group_df
            Age-bin lookup table used to format age columns.
        scenarios
            Optional mapping of scenario column -> value to filter on.
            ``None`` is normalized to an empty dict.
        """
        self.measure = measure
        self.source = source
        self.scenarios = scenarios if scenarios is not None else {}
        # Datasets and (source-dependent) aggregation weights are loaded up
        # front so later accessors are pure transformations.
        self.datasets = self._get_formatted_datasets(data_loader, age_group_df)
        self.weights = self._get_aggregated_weights(data_loader, age_group_df)

    @property
    def dataset_names(self) -> dict[str, str]:
        """Return a dictionary of required datasets for the specified source.

        Raises
        ------
        ValueError
            If ``self.source`` is not a supported data source.
        """
        if self.source == DataSource.SIM:
            return self.measure.sim_output_datasets
        elif self.source in (DataSource.ARTIFACT, DataSource.GBD):
            return self.measure.sim_input_datasets
        else:
            raise ValueError(f"Unsupported data source: {self.source}")

    @property
    def index_names(self) -> set[str]:
        """The union of index level names across all loaded datasets."""
        return {
            index_name
            for key in self.datasets
            for index_name in self.datasets[key].index.names
        }

    def get_metadata(self) -> dict[str, Any]:
        """Organize the data information into a dictionary for display by a styled pandas DataFrame.
        Apply formatting to values that need special handling.

        Returns:
        --------
        A dictionary containing the formatted data information.

        """
        dataframe = self.get_measure_data("all")
        data_info: dict[str, Any] = {}

        # Source as string
        data_info["source"] = self.source.value

        # Index columns as comma-separated string
        data_info["index_columns"] = list(dataframe.index.names)

        # Size as formatted string
        size = dataframe.shape
        data_info["size"] = f"{size[0]:,} rows × {size[1]:,} columns"

        # Draw information (only present when the data carries a draw level)
        if DRAW_INDEX in dataframe.index.names:
            num_draws = dataframe.index.get_level_values(DRAW_INDEX).nunique()
            data_info["num_draws"] = f"{num_draws:,}"
            draw_values = list(dataframe.index.get_level_values(DRAW_INDEX).unique())
            data_info[DRAW_INDEX + "s"] = dataframe_utils.format_draws_sample(
                draw_values, self.source
            )

        # Seeds information (only present when the data carries a seed level)
        if SEED_INDEX in dataframe.index.names:
            num_seeds = dataframe.index.get_level_values(SEED_INDEX).nunique()
            data_info["num_seeds"] = f"{num_seeds:,}"

        return data_info

    def _get_formatted_datasets(
        self, data_loader: DataLoader, age_group_data: pd.DataFrame
    ) -> dict[str, pd.DataFrame]:
        """Formats measure datasets depending on the source.

        Raises
        ------
        NotImplementedError
            For ``DataSource.CUSTOM`` (not yet supported).
        ValueError
            For any other unrecognized data source.
        """
        # NOTE(review): reaches into DataLoader's private API; consider making
        # _get_raw_data_from_source public if bundles are the intended caller.
        raw_datasets = data_loader._get_raw_data_from_source(self.dataset_names, self.source)
        if self.source == DataSource.SIM:
            datasets = self.measure.get_ratio_datasets_from_sim(
                **raw_datasets,
            )
        elif self.source in (DataSource.ARTIFACT, DataSource.GBD):
            data = self.measure.get_measure_data_from_sim_inputs(**raw_datasets)
            datasets = {"data": data}
        elif self.source == DataSource.CUSTOM:
            raise NotImplementedError
        else:
            raise ValueError(f"Unsupported data source: {self.source}")

        # Normalize age columns against the age-bin table, then filter to the
        # requested scenarios (dropping single-valued scenario levels).
        datasets = {
            dataset_name: age_groups.format_dataframe_from_age_bin_df(dataset, age_group_data)
            for dataset_name, dataset in datasets.items()
        }
        datasets = {
            key: calculations.filter_data(dataset, self.scenarios, drop_singles=True)
            for key, dataset in datasets.items()
        }

        return datasets

    def _get_aggregated_weights(
        self, data_loader: DataLoader, age_group_data: pd.DataFrame
    ) -> pd.DataFrame | None:
        """Fetches and aggregates weights if required by the measure.

        Returns ``None`` for sources (e.g. SIM) that do not need weights.
        """
        if self.source not in (DataSource.ARTIFACT, DataSource.GBD):
            return None

        raw_weights = data_loader._get_raw_data_from_source(
            self.measure.rate_aggregation_weights.weight_keys, self.source
        )
        weights = self.measure.rate_aggregation_weights.get_weights(**raw_weights)
        return age_groups.format_dataframe_from_age_bin_df(weights, age_group_data)

    def get_measure_data(
        self, stratifications: Collection[str] | Literal["all"]
    ) -> pd.DataFrame:
        """Get the measure data, optionally aggregated over specified stratifications.

        Raises
        ------
        NotImplementedError
            For ``DataSource.CUSTOM`` (not yet supported).
        ValueError
            For any other unrecognized data source.
        """
        if self.source == DataSource.SIM:
            return self._aggregate_scenario_stratifications(self.datasets, stratifications)
        elif self.source in (DataSource.ARTIFACT, DataSource.GBD):
            return self._aggregate_sim_input_stratifications(stratifications)
        elif self.source == DataSource.CUSTOM:
            raise NotImplementedError
        else:
            raise ValueError(f"Unsupported data source: {self.source}")

    def _aggregate_scenario_stratifications(
        self,
        datasets: dict[str, pd.DataFrame],
        stratifications: Collection[str] | Literal["all"],
    ) -> pd.DataFrame:
        """This will remove index levels corresponding to the specified stratifications"""
        datasets = {
            key: calculations.stratify(dataset, stratifications)
            for key, dataset in datasets.items()
        }
        return self.measure.get_measure_data_from_ratio(**datasets)

    def _aggregate_sim_input_stratifications(
        self, stratifications: Collection[str] | Literal["all"]
    ) -> pd.DataFrame:
        """Aggregate the artifact data over specified stratifications. Stratifications will be retained
        in the returned data."""
        data = self.datasets["data"].copy()
        if stratifications != "all":
            stratifications = list(stratifications)
            # Retain input_draw, comparison._aggregate_over_draws is the only place we should aggregate over draws.
            if DRAW_INDEX in data.index.names and DRAW_INDEX not in stratifications:
                stratifications.append(DRAW_INDEX)
        if self.weights is None:
            raise ValueError("Weights are required for aggregating artifact data.")

        # Update scenario columns to retain during aggregation
        scenario_cols = []
        # NOTE: This is a hack to handle alignment of index levels in weighted_average. Risk
        # stratification column is treated as a scenario column and the population can be
        # broadcast across each index group since the exposure for each group should sum to 1.
        if isinstance(self.measure, (RiskExposure, CategoricalRelativeRisk)):
            scenario_cols.append(self.measure.risk_stratification_column)
        scenario_cols.extend(list(self.scenarios.keys()))
        weighted_avg = calculations.weighted_average(
            data, self.weights, stratifications, scenario_cols
        )

        # Reference data can be a float or dataframe. Convert floats so dataframes are aligned
        if not isinstance(weighted_avg, pd.DataFrame):
            weighted_avg = pd.DataFrame(
                {"value": [weighted_avg]}, index=pd.Index([0], name="index")
            )
        return weighted_avg

0 commit comments

Comments
 (0)