Skip to content

Commit 16fbbf6

Browse files
authored
Albrja/mic 6508/load gbd data (#84)
Albrja/mic 6508/load gbd data Add feature to load GBD data from DataLoader - *Category*: Feature - *JIRA issue*: https://jira.ihme.washington.edu/browse/MIC-6508 Changes and notes: - pin pandas below 2.0 to resolve package dependencies with Vivarium Inputs - use load_standard_data to load GBD data in DataLoader
1 parent ce9d79b commit 16fbbf6

File tree

14 files changed

+106
-40
lines changed

14 files changed

+106
-40
lines changed

setup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545

4646
install_requirements = [
4747
"vivarium_dependencies[numpy,pyyaml,scipy,click,tables,loguru,networkx]",
48-
"pandas>2.0.0",
48+
"pandas",
4949
"vivarium_build_utils>=2.0.1,<3.0.0",
5050
"pyarrow",
5151
"seaborn",
@@ -57,8 +57,8 @@
5757

5858
validation_requirements = [
5959
"vivarium>=3.4.0",
60-
"vivarium-inputs",
61-
"pandera",
60+
"vivarium-inputs>=7.1.0, <8.0.0",
61+
"pandera<0.23.0",
6262
]
6363

6464
interactive_requirements = [

src/vivarium_testing_utils/automated_validation/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,6 @@ def from_str(cls, source: str) -> DataSource:
2020
return cls(source)
2121
except ValueError:
2222
raise ValueError(f"Source {source} not recognized. Must be one of {DataSource}")
23+
24+
25+
LOCATION_ARTIFACT_KEY = "population.location"

src/vivarium_testing_utils/automated_validation/data_loader.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,13 @@
66
import pandas as pd
77
import yaml
88
from vivarium import Artifact
9+
from vivarium_inputs.interface import load_standard_data
910

10-
from vivarium_testing_utils.automated_validation.constants import DRAW_PREFIX, DataSource
11+
from vivarium_testing_utils.automated_validation.constants import (
12+
DRAW_PREFIX,
13+
LOCATION_ARTIFACT_KEY,
14+
DataSource,
15+
)
1116
from vivarium_testing_utils.automated_validation.data_transformation import (
1217
calculations,
1318
utils,
@@ -23,7 +28,7 @@ def __init__(self, sim_output_dir: Path, cache_size_mb: int = 1000):
2328
self._cache_size_mb = cache_size_mb
2429

2530
self._results_dir = self._sim_output_dir / "results"
26-
self._raw_data_cache: dict[DataSource, dict[str, pd.DataFrame]] = {
31+
self._raw_data_cache: dict[DataSource, dict[str, pd.DataFrame | str]] = {
2732
data_source: {} for data_source in DataSource
2833
}
2934
self._loader_mapping = {
@@ -39,6 +44,8 @@ def __init__(self, sim_output_dir: Path, cache_size_mb: int = 1000):
3944
self._add_to_cache(
4045
data_key="person_time_total", data=person_time_total, source=DataSource.SIM
4146
)
47+
# TODO: MIC-6533 - Update when all locations are in one artifact in the future.
48+
self.location = self.get_data(LOCATION_ARTIFACT_KEY, DataSource.ARTIFACT)
4249

4350
def _create_person_time_total_dataset(self) -> pd.DataFrame | None:
4451
"""
@@ -82,7 +89,8 @@ def get_artifact_keys(self) -> list[str]:
8289
def get_data(self, data_key: str, source: DataSource) -> Any:
8390
"""Return the data from the cache if it exists, otherwise load it from the source."""
8491
try:
85-
return self._raw_data_cache[source][data_key].copy()
92+
data = self._raw_data_cache[source][data_key]
93+
return data.copy() if isinstance(data, pd.DataFrame) else data
8694
except KeyError:
8795
if source == DataSource.CUSTOM:
8896
raise ValueError(
@@ -100,11 +108,14 @@ def _load_from_source(self, data_key: str, source: DataSource) -> Any:
100108
"""Load the data from the given source via the loader mapping."""
101109
return self._loader_mapping[source](data_key)
102110

103-
def _add_to_cache(self, data_key: str, source: DataSource, data: pd.DataFrame) -> None:
111+
def _add_to_cache(
112+
self, data_key: str, source: DataSource, data: pd.DataFrame | str
113+
) -> None:
104114
"""Update the raw_data_cache with the given data."""
105115
if data_key in self._raw_data_cache.get(source, {}):
106116
raise ValueError(f"Data for {data_key} already exist in the cache.")
107-
self._raw_data_cache[source].update({data_key: data.copy()})
117+
cache_data = data.copy() if isinstance(data, pd.DataFrame) else data
118+
self._raw_data_cache[source].update({data_key: cache_data})
108119

109120
@utils.check_io(out=SimOutputData)
110121
def _load_from_sim(self, data_key: str) -> pd.DataFrame:
@@ -148,11 +159,18 @@ def _load_from_artifact(self, data_key: str) -> Any:
148159
and not data.columns.empty
149160
and data.columns.str.startswith(DRAW_PREFIX).all()
150161
):
151-
data = calculations.clean_artifact_draws(data)
162+
data = calculations.clean_draw_columns(data)
152163
return data
153164

154-
def _load_from_gbd(self, data_key: str) -> pd.DataFrame:
155-
raise NotImplementedError
165+
def _load_from_gbd(self, data_key: str) -> Any:
166+
data = load_standard_data(data_key, self.location)
167+
if (
168+
isinstance(data, pd.DataFrame)
169+
and not data.columns.empty
170+
and data.columns.str.startswith(DRAW_PREFIX).all()
171+
):
172+
data = calculations.clean_draw_columns(data)
173+
return data
156174

157175
def _get_raw_data_from_source(
158176
self, measure_keys: dict[str, str], source: DataSource

src/vivarium_testing_utils/automated_validation/data_transformation/age_groups.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -507,9 +507,7 @@ def rebin_count_dataframe(
507507
result_matrix_for_col.columns.name = AGE_GROUP_COLUMN
508508

509509
# Stack the new age group columns into the index
510-
stacked_series_for_col = result_matrix_for_col.stack(
511-
level=AGE_GROUP_COLUMN, future_stack=True
512-
)
510+
stacked_series_for_col = result_matrix_for_col.stack(level=AGE_GROUP_COLUMN)
513511
stacked_series_for_col.name = val_col
514512

515513
all_results_series.append(stacked_series_for_col)

src/vivarium_testing_utils/automated_validation/data_transformation/calculations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def linear_combination(
124124

125125

126126
@utils.check_io(data=DrawData, out=SingleNumericColumn)
127-
def clean_artifact_draws(
127+
def clean_draw_columns(
128128
data: pd.DataFrame,
129129
) -> pd.DataFrame:
130130
"""Clean the artifact data by dropping unnecessary columns and renaming the value column."""

src/vivarium_testing_utils/automated_validation/data_transformation/data_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pandas as pd
2-
import pandera.pandas as pa
2+
import pandera as pa
33
from pandera.typing import Index
44

55

src/vivarium_testing_utils/automated_validation/data_transformation/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Callable, TypeVar
44

55
import pandas as pd
6-
import pandera.pandas as pa
6+
import pandera as pa
77

88
F = TypeVar("F", bound=Callable[..., Any])
99

tests/automated_validation/conftest.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import shutil
12
from pathlib import Path
23
from unittest import mock
34

@@ -7,7 +8,11 @@
78
from pytest import TempPathFactory
89
from vivarium.framework.artifact import Artifact
910

10-
from vivarium_testing_utils.automated_validation.constants import DRAW_INDEX, SEED_INDEX
11+
from vivarium_testing_utils.automated_validation.constants import (
12+
DRAW_INDEX,
13+
LOCATION_ARTIFACT_KEY,
14+
SEED_INDEX,
15+
)
1116
from vivarium_testing_utils.automated_validation.data_loader import (
1217
_convert_to_total_person_time,
1318
)
@@ -529,7 +534,7 @@ def _make_artifact_prevalence() -> pd.DataFrame:
529534

530535

531536
@pytest.fixture(scope="session")
532-
def _artifact_keys_mapper() -> dict[str, pd.DataFrame | dict[str, str]]:
537+
def _artifact_keys_mapper() -> dict[str, str | pd.DataFrame | dict[str, str]]:
533538
_raw_artifact_disease_incidence = _create_raw_artifact_disease_incidence()
534539
_raw_artifact_risk_exposure = _create_raw_artifact_risk_exposure()
535540
_sample_age_group_df = _create_sample_age_group_df()
@@ -543,6 +548,7 @@ def _artifact_keys_mapper() -> dict[str, pd.DataFrame | dict[str, str]]:
543548
"risk_factor.risky_risk.categories": _risk_categories,
544549
"population.structure": _population_structure,
545550
"cause.disease.prevalence": _artifact_prevalence,
551+
LOCATION_ARTIFACT_KEY: "Ethiopia",
546552
}
547553

548554

@@ -603,3 +609,11 @@ def reference_weights() -> pd.DataFrame:
603609
names=["year", "sex", "age"],
604610
),
605611
)
612+
613+
614+
def is_on_slurm() -> bool:
615+
"""Returns True if the current environment is a SLURM cluster."""
616+
return not shutil.which("sbatch") is not None
617+
618+
619+
NO_GBD_ACCESS = is_on_slurm()

tests/automated_validation/data_transformation/test_calculations.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -244,19 +244,19 @@ def test_aggregate_sum_preserves_string_order() -> None:
244244
(
245245
["sex"],
246246
[
247-
2.83,
248247
6.92,
248+
2.83,
249249
], # Male: (20*2 + 100*3)/(20+100) ≈ 2.83, Female: (2*5 + 50*7)/(2+50) ≈ 6.92
250-
pd.Index(["Male", "Female"], name="sex"),
250+
pd.Index(["Female", "Male"], name="sex"),
251251
),
252252
# Test aggregating by color
253253
(
254254
["color"],
255255
[
256-
2.27,
257256
4.33,
257+
2.27,
258258
], # Red: (20*2 + 2*5)/(20+2) ≈ 2.27, Blue: (100*3 + 50*7)/(100+50) ≈ 4.33
259-
pd.Index(["Red", "Blue"], name="color"),
259+
pd.Index(["Blue", "Red"], name="color"),
260260
),
261261
# Test no aggregation - keeping all index levels
262262
(

tests/automated_validation/data_transformation/test_data_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pandas as pd
2-
import pandera.pandas as pa
2+
import pandera as pa
33
import pytest
44
from pandera.errors import SchemaError
55

0 commit comments

Comments (0)