Commit 433373c

Refactor Module Imports (#58)
* gitignore
* change to common stratify index
* make some dfs use from_product instead
* more products
* align indexes but not for pop structure
* add unique strat columns to fixtures
* cleanup
* lint
* move title to measures.py (#56)
* move title to measures.py
* lint
* add type
* fix test
* change function import strategy
* lint
* lint
* rename with "test"
* add module names
* remove accidental copy
* revert interface change
* fix typo
1 parent 587e233 commit 433373c
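
The central change is the import strategy: call sites now import whole modules (calculations, dataframe_utils, utils) and qualify each call, rather than importing individual functions. A minimal sketch of the old and new styles, using filter_data as it appears in this diff; apply_scenario_filter is a hypothetical wrapper, not code from this repo:

import pandas as pd

# New style: import the module, qualify the call.
from vivarium_testing_utils.automated_validation.data_transformation import calculations


def apply_scenario_filter(dataset: pd.DataFrame, scenarios: dict[str, str]) -> pd.DataFrame:
    # Before this commit the call site read:
    #     from ...data_transformation.calculations import filter_data
    #     return filter_data(dataset, scenarios, drop_singles=False)
    return calculations.filter_data(dataset, scenarios, drop_singles=False)

Qualified calls keep each call site explicit about which module a helper comes from, and the import block no longer has to re-list every function as the set of helpers changes.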

File tree

13 files changed: +152 additions, -161 deletions


src/vivarium_testing_utils/automated_validation/comparison.py

Lines changed: 12 additions & 15 deletions
@@ -5,19 +5,12 @@

 from vivarium_testing_utils.automated_validation.constants import DRAW_INDEX, SEED_INDEX
 from vivarium_testing_utils.automated_validation.data_loader import DataSource
-from vivarium_testing_utils.automated_validation.data_transformation.calculations import (
-    filter_data,
-    get_singular_indices,
-    marginalize,
-)
+from vivarium_testing_utils.automated_validation.data_transformation import calculations
 from vivarium_testing_utils.automated_validation.data_transformation.measures import (
     Measure,
     RatioMeasure,
 )
-from vivarium_testing_utils.automated_validation.visualization.dataframe_utils import (
-    format_draws_sample,
-    format_metadata,
-)
+from vivarium_testing_utils.automated_validation.visualization import dataframe_utils


 class Comparison(ABC):
@@ -100,14 +93,14 @@ def __init__(
         self.test_source = test_source
         self.test_scenarios: dict[str, str] = test_scenarios if test_scenarios else {}
         self.test_datasets = {
-            key: filter_data(dataset, self.test_scenarios, drop_singles=False)
+            key: calculations.filter_data(dataset, self.test_scenarios, drop_singles=False)
             for key, dataset in test_datasets.items()
         }
         self.reference_source = reference_source
         self.reference_scenarios: dict[str, str] = (
             reference_scenarios if reference_scenarios else {}
         )
-        self.reference_data = filter_data(
+        self.reference_data = calculations.filter_data(
             reference_data, self.reference_scenarios, drop_singles=False
         )

@@ -131,7 +124,7 @@ def metadata(self) -> pd.DataFrame:
         measure_key = self.measure.measure_key
         test_info = self._get_metadata_from_datasets("test")
         reference_info = self._get_metadata_from_datasets("reference")
-        return format_metadata(measure_key, test_info, reference_info)
+        return dataframe_utils.format_metadata(measure_key, test_info, reference_info)

     def get_diff(
         self,
@@ -232,7 +225,7 @@ def _get_metadata_from_datasets(
         num_draws = dataframe.index.get_level_values(DRAW_INDEX).nunique()
         data_info["num_draws"] = f"{num_draws:,}"
         draw_values = list(dataframe.index.get_level_values(DRAW_INDEX).unique())
-        data_info[DRAW_INDEX + "s"] = format_draws_sample(draw_values)
+        data_info[DRAW_INDEX + "s"] = dataframe_utils.format_draws_sample(draw_values)

         # Seeds information
         if SEED_INDEX in dataframe.index.names:
@@ -266,13 +259,17 @@ def _align_datasets(self) -> tuple[pd.DataFrame, pd.DataFrame]:
         # If the test data has any index levels that are not in the reference data, marginalize
         # over those index levels.
         test_datasets = {
-            key: marginalize(self.test_datasets[key], test_indexes_to_marginalize)
+            key: calculations.marginalize(
+                self.test_datasets[key], test_indexes_to_marginalize
+            )
             for key in self.test_datasets
         }

         # Drop any singular index levels from the reference data if they are not in the test data.
         # If any ref-only index level is not singular, raise an error.
-        redundant_ref_indexes = set(get_singular_indices(self.reference_data).keys())
+        redundant_ref_indexes = set(
+            calculations.get_singular_indices(self.reference_data).keys()
+        )
         if not reference_indexes_to_drop.issubset(redundant_ref_indexes):
             # TODO: MIC-6075
             diff = reference_indexes_to_drop - redundant_ref_indexes

src/vivarium_testing_utils/automated_validation/constants.py

Lines changed: 18 additions & 0 deletions
@@ -1,4 +1,22 @@
+from __future__ import annotations
+
+from enum import Enum
+
 DRAW_PREFIX = "draw_"

 DRAW_INDEX = "input_draw"
 SEED_INDEX = "random_seed"
+
+
+class DataSource(Enum):
+    SIM = "sim"
+    GBD = "gbd"
+    ARTIFACT = "artifact"
+    CUSTOM = "custom"
+
+    @classmethod
+    def from_str(cls, source: str) -> DataSource:
+        try:
+            return cls(source)
+        except ValueError:
+            raise ValueError(f"Source {source} not recognized. Must be one of {DataSource}")
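
DataSource moves out of data_loader.py (next file) into constants.py, so other modules can import it without importing the loader. A small usage sketch built directly on the enum added above:

from vivarium_testing_utils.automated_validation.constants import DataSource

# Parse a user-supplied string into an enum member.
source = DataSource.from_str("sim")
assert source is DataSource.SIM

# Strings that are not one of the enum values raise, roughly:
# ValueError: Source <value> not recognized. Must be one of <enum 'DataSource'>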

src/vivarium_testing_utils/automated_validation/data_loader.py

Lines changed: 8 additions & 24 deletions
@@ -1,36 +1,20 @@
 from __future__ import annotations

-from enum import Enum
 from pathlib import Path

 import pandas as pd
 import yaml
 from vivarium import Artifact

-from vivarium_testing_utils.automated_validation.data_transformation.calculations import (
-    clean_artifact_data,
-    marginalize,
+from vivarium_testing_utils.automated_validation.constants import DataSource
+from vivarium_testing_utils.automated_validation.data_transformation import (
+    calculations,
+    utils,
 )
 from vivarium_testing_utils.automated_validation.data_transformation.data_schema import (
     SimOutputData,
     SingleNumericColumn,
 )
-from vivarium_testing_utils.automated_validation.data_transformation.utils import check_io
-
-
-class DataSource(Enum):
-    SIM = "sim"
-    GBD = "gbd"
-    ARTIFACT = "artifact"
-    CUSTOM = "custom"
-
-    @classmethod
-    def from_str(cls, source: str) -> DataSource:
-        try:
-            return cls(source)
-        except ValueError:
-            raise ValueError(f"Source {source} not recognized. Must be one of {DataSource}")
-

 NONSTANDARD_ARTIFACT_KEYS = {"population.age_bins"}

@@ -129,7 +113,7 @@ def _add_to_cache(self, dataset_key: str, source: DataSource, data: pd.DataFrame
             raise ValueError(f"Dataset {dataset_key} already exists in the cache.")
         self._raw_datasets[source].update({dataset_key: data.copy()})

-    @check_io(out=SimOutputData)
+    @utils.check_io(out=SimOutputData)
     def _load_from_sim(self, dataset_key: str) -> pd.DataFrame:
         """Load the data from the simulation output directory and set the non-value columns as indices."""
         sim_data = pd.read_parquet(self._results_dir / f"{dataset_key}.parquet")
@@ -168,12 +152,12 @@ def _load_nonstandard_artifact(self, dataset_key: str) -> pd.DataFrame:
         self._artifact.clear_cache()
         return data

-    @check_io(out=SingleNumericColumn)
+    @utils.check_io(out=SingleNumericColumn)
     def _load_from_artifact(self, dataset_key: str) -> pd.DataFrame:
         """Load data directly from artifact, assuming correctly formatted data."""
         data: pd.DataFrame = self._artifact.load(dataset_key)
         self._artifact.clear_cache()
-        return clean_artifact_data(dataset_key, data)
+        return calculations.clean_artifact_data(dataset_key, data)

     def _load_from_gbd(self, dataset_key: str) -> pd.DataFrame:
         raise NotImplementedError
@@ -186,7 +170,7 @@ def _load_from_gbd(self, dataset_key: str) -> pd.DataFrame:

 def _convert_to_total_person_time(data: pd.DataFrame) -> pd.DataFrame:
     old_index_names = data.index.names
-    data = marginalize(data, ["entity_type", "entity", "sub_entity"])
+    data = calculations.marginalize(data, ["entity_type", "entity", "sub_entity"])
     data["entity_type"] = "none"
     data["entity"] = "total"
     data["sub_entity"] = "total"

src/vivarium_testing_utils/automated_validation/data_transformation/age_groups.py

Lines changed: 17 additions & 3 deletions
@@ -12,10 +12,10 @@
 AgeTuple = tuple[str, int | float, int | float]
 AgeRange = tuple[int | float, int | float]

+from vivarium_testing_utils.automated_validation.data_transformation import utils
 from vivarium_testing_utils.automated_validation.data_transformation.data_schema import (
     SingleNumericColumn,
 )
-from vivarium_testing_utils.automated_validation.data_transformation.utils import check_io


 class AgeGroup:
@@ -396,7 +396,7 @@ def can_coerce_to(self, other: AgeSchema) -> bool:
         return True


-def format_dataframe(target_schema: AgeSchema, df: pd.DataFrame) -> pd.DataFrame:
+def _format_dataframe(target_schema: AgeSchema, df: pd.DataFrame) -> pd.DataFrame:
    """
    Format a DataFrame to match the current schema.

@@ -451,7 +451,7 @@ def format_dataframe(target_schema: AgeSchema, df: pd.DataFrame) -> pd.DataFrame
    return data


-@check_io(df=SingleNumericColumn, out=SingleNumericColumn)
+@utils.check_io(df=SingleNumericColumn, out=SingleNumericColumn)
 def rebin_count_dataframe(
     target_schema: AgeSchema,
     df: pd.DataFrame,
@@ -535,3 +535,17 @@ def _get_transform_matrix(source_schema: AgeSchema, target_schema: AgeSchema) ->
            if fraction > 0:
                transform_matrix.loc[target_group.name, source_group.name] = fraction
    return transform_matrix
+
+
+def format_dataframe_from_age_bin_df(
+    data: pd.DataFrame, age_bin_df: pd.DataFrame
+) -> pd.DataFrame:
+    """Try to merge the age groups with the data. If it fails, just return the data."""
+    context_age_schema = AgeSchema.from_dataframe(age_bin_df)
+    try:
+        return _format_dataframe(context_age_schema, data)
+    except ValueError:
+        logger.info(
+            "Could not resolve age groups. The DataFrame likely has no age data. Returning dataframe as-is."
+        )
+        return data
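
resolve_age_groups, deleted from calculations.py below, reappears here as format_dataframe_from_age_bin_df, next to the age-schema code it wraps. A minimal call sketch; attach_age_bins is a hypothetical wrapper, and its two arguments stand in for a simulation-output frame and the artifact's population.age_bins table:

import pandas as pd

from vivarium_testing_utils.automated_validation.data_transformation import age_groups


def attach_age_bins(sim_data: pd.DataFrame, age_bin_df: pd.DataFrame) -> pd.DataFrame:
    # Coerce sim_data onto the age bins described by age_bin_df; if sim_data carries
    # no age information, the helper logs a message and returns it unchanged.
    return age_groups.format_dataframe_from_age_bin_df(sim_data, age_bin_df)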

src/vivarium_testing_utils/automated_validation/data_transformation/calculations.py

Lines changed: 5 additions & 24 deletions
@@ -10,18 +10,11 @@
     DRAW_INDEX,
     DRAW_PREFIX,
 )
-from vivarium_testing_utils.automated_validation.data_transformation.age_groups import (
-    AgeSchema,
-    format_dataframe,
-)
+from vivarium_testing_utils.automated_validation.data_transformation import utils
 from vivarium_testing_utils.automated_validation.data_transformation.data_schema import (
     DrawData,
     SingleNumericColumn,
 )
-from vivarium_testing_utils.automated_validation.data_transformation.utils import (
-    check_io,
-    series_to_dataframe,
-)


 def filter_data(
@@ -53,7 +46,7 @@ def filter_data(
     return data


-@check_io(
+@utils.check_io(
     numerator_data=SingleNumericColumn,
     denominator_data=SingleNumericColumn,
     out=SingleNumericColumn,
@@ -109,10 +102,10 @@ def linear_combination(
     data: pd.DataFrame, coeff_a: float, col_a: str, coeff_b: float, col_b: str
 ) -> pd.DataFrame:
     """Return a series that is the linear combination of two columns in a DataFrame."""
-    return series_to_dataframe((data[col_a] * coeff_a) + (data[col_b] * coeff_b))
+    return utils.series_to_dataframe((data[col_a] * coeff_a) + (data[col_b] * coeff_b))


-@check_io(out=SingleNumericColumn)
+@utils.check_io(out=SingleNumericColumn)
 def clean_artifact_data(
     dataset_key: str,
     data: pd.DataFrame,
@@ -125,7 +118,7 @@ def clean_artifact_data(
     return data


-@check_io(data=DrawData, out=SingleNumericColumn)
+@utils.check_io(data=DrawData, out=SingleNumericColumn)
 def _clean_artifact_draws(
     data: pd.DataFrame,
 ) -> pd.DataFrame:
@@ -144,18 +137,6 @@ def _clean_artifact_draws(
     return data


-def resolve_age_groups(data: pd.DataFrame, age_groups: pd.DataFrame) -> pd.DataFrame:
-    """Try to merge the age groups with the data. If it fails, just return the data."""
-    context_age_schema = AgeSchema.from_dataframe(age_groups)
-    try:
-        return format_dataframe(context_age_schema, data)
-    except ValueError:
-        logger.info(
-            "Could not resolve age groups. The DataFrame likely has no age data. Returning dataframe as-is."
-        )
-        return data
-
-
 def get_singular_indices(data: pd.DataFrame) -> dict[str, Any]:
     """Get index levels and their values that are singular (i.e. have only one unique value)."""
     singular_metadata: dict[str, Any] = {}
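
linear_combination is untouched here apart from reaching series_to_dataframe through utils; per the body shown above, it returns a weighted sum of two columns wrapped back into a one-column DataFrame. A small sketch with made-up column names and values:

import pandas as pd

from vivarium_testing_utils.automated_validation.data_transformation import calculations

df = pd.DataFrame({"incidence": [0.10, 0.20], "remission": [0.02, 0.04]})

# 1.0 * incidence + (-1.0) * remission, i.e. the element-wise difference.
net = calculations.linear_combination(df, 1.0, "incidence", -1.0, "remission")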

src/vivarium_testing_utils/automated_validation/data_transformation/formatting.py

Lines changed: 6 additions & 10 deletions
@@ -1,11 +1,7 @@
 import pandas as pd

 from vivarium_testing_utils.automated_validation.constants import DRAW_INDEX, SEED_INDEX
-from vivarium_testing_utils.automated_validation.data_transformation.calculations import (
-    filter_data,
-    marginalize,
-    stratify,
-)
+from vivarium_testing_utils.automated_validation.data_transformation import calculations


 class SimDataFormatter:
@@ -29,11 +25,11 @@ def __init__(self, measure: str, entity: str, filter_value: str) -> None:

     def format_dataset(self, dataset: pd.DataFrame) -> pd.DataFrame:
         """Clean up unused columns, and filter for the state."""
-        dataset = marginalize(dataset, self.unused_columns)
+        dataset = calculations.marginalize(dataset, self.unused_columns)
         if self.filter_value == "total":
-            dataset = marginalize(dataset, [*self.filters])
+            dataset = calculations.marginalize(dataset, [*self.filters])
         else:
-            dataset = filter_data(dataset, self.filters)
+            dataset = calculations.filter_data(dataset, self.filters)
         return dataset


@@ -82,7 +78,7 @@ def format_dataset(self, dataset: pd.DataFrame) -> pd.DataFrame:
         levels_to_stratify = [
             level for level in between_scenario_levels if level in dataset.index.names
         ]
-        return stratify(
+        return calculations.stratify(
             data=dataset,
             stratification_cols=levels_to_stratify,
         )
@@ -124,7 +120,7 @@ def __init__(self, entity: str, sum_all: bool = False) -> None:
         self.unused_columns = ["measure", "entity_type", "entity"]

     def format_dataset(self, dataset: pd.DataFrame) -> pd.DataFrame:
-        dataset = marginalize(dataset, self.unused_columns)
+        dataset = calculations.marginalize(dataset, self.unused_columns)
         if self.sum_all:
             # Get the levels to group by (all except 'sub_entity')
             group_levels = [
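
The formatter classes now route every transformation through calculations as well. A hypothetical call sketch for SimDataFormatter, based only on the constructor and format_dataset signatures visible in this diff; format_sim_results and the "person_time"/"some_cause" strings are placeholders, not values from this repo:

import pandas as pd

from vivarium_testing_utils.automated_validation.data_transformation import formatting


def format_sim_results(sim_dataset: pd.DataFrame) -> pd.DataFrame:
    # With filter_value="total" the formatter marginalizes over its filters;
    # otherwise it filters the dataset down to them (see the diff above).
    formatter = formatting.SimDataFormatter(
        measure="person_time", entity="some_cause", filter_value="total"
    )
    return formatter.format_dataset(sim_dataset)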
