Skip to content

Commit faa1a53

Browse files
authored
Albrja/mic 6510/gbd to databundle (#86)
Albrja/mic 6510/gbd to databundle: Pipe GBD feature to DataBundle. - *Category*: Feature - *JIRA issue*: https://jira.ihme.washington.edu/browse/MIC-6510 - Changes and notes: allows GBD to be a source for DataBundle, using the same implementation as the artifact source.
1 parent de60e81 commit faa1a53

File tree

5 files changed

+76
-30
lines changed

5 files changed

+76
-30
lines changed

src/vivarium_testing_utils/automated_validation/bundle.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ def __init__(
4747
def dataset_names(self) -> dict[str, str]:
4848
"""Return a dictionary of required datasets for the specified source."""
4949
if self.source == DataSource.SIM:
50-
return self.measure.sim_datasets
51-
elif self.source == DataSource.ARTIFACT:
50+
return self.measure.sim_output_datasets
51+
elif self.source in ([DataSource.ARTIFACT, DataSource.GBD]):
5252
return self.measure.sim_input_datasets
5353
else:
5454
raise ValueError(f"Unsupported data source: {self.source}")
@@ -107,11 +107,9 @@ def _get_formatted_datasets(
107107
datasets = self.measure.get_ratio_datasets_from_sim(
108108
**raw_datasets,
109109
)
110-
elif self.source == DataSource.ARTIFACT:
110+
elif self.source in [DataSource.ARTIFACT, DataSource.GBD]:
111111
data = self.measure.get_measure_data_from_sim_inputs(**raw_datasets)
112112
datasets = {"data": data}
113-
elif self.source == DataSource.GBD:
114-
raise NotImplementedError
115113
elif self.source == DataSource.CUSTOM:
116114
raise NotImplementedError
117115
else:
@@ -132,7 +130,7 @@ def _get_aggregated_weights(
132130
self, data_loader: DataLoader, age_group_data: pd.DataFrame
133131
) -> pd.DataFrame | None:
134132
"""Fetches and aggregates weights if required by the measure."""
135-
if self.source != DataSource.ARTIFACT:
133+
if self.source not in [DataSource.ARTIFACT, DataSource.GBD]:
136134
return None
137135

138136
raw_weights = data_loader._get_raw_data_from_source(
@@ -147,10 +145,8 @@ def get_measure_data(
147145
"""Get the measure data, optionally aggregated over specified stratifications."""
148146
if self.source == DataSource.SIM:
149147
return self._aggregate_scenario_stratifications(self.datasets, stratifications)
150-
elif self.source == DataSource.ARTIFACT:
151-
return self._aggregate_artifact_stratifications(stratifications)
152-
elif self.source == DataSource.GBD:
153-
raise NotImplementedError
148+
elif self.source in [DataSource.ARTIFACT, DataSource.GBD]:
149+
return self._aggregate_sim_input_stratifications(stratifications)
154150
elif self.source == DataSource.CUSTOM:
155151
raise NotImplementedError
156152
else:
@@ -167,7 +163,7 @@ def _aggregate_scenario_stratifications(
167163
}
168164
return self.measure.get_measure_data_from_ratio(**datasets)
169165

170-
def _aggregate_artifact_stratifications(
166+
def _aggregate_sim_input_stratifications(
171167
self, stratifications: Collection[str] | Literal["all"]
172168
) -> pd.DataFrame:
173169
"""Aggregate the sim input data (artifact or GBD) over specified stratifications. Stratifications will be retained

src/vivarium_testing_utils/automated_validation/data_transformation/measures.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __str__(self) -> str:
5858

5959
@property
6060
@abstractmethod
61-
def sim_datasets(self) -> dict[str, str]:
61+
def sim_output_datasets(self) -> dict[str, str]:
6262
"""Return a dictionary of required datasets for this measure."""
6363
pass
6464

@@ -101,7 +101,7 @@ def __init__(
101101
self.denominator = denominator
102102

103103
@property
104-
def sim_datasets(self) -> dict[str, str]:
104+
def sim_output_datasets(self) -> dict[str, str]:
105105
"""Return a dictionary of required datasets for this measure."""
106106
return {
107107
"numerator_data": self.numerator.raw_dataset_name,

tests/automated_validation/data_transformation/test_measures.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def test_incidence(
4242
measure = Incidence(cause)
4343
assert measure.measure_key == f"cause.{cause}.incidence_rate"
4444
assert measure.title == "Disease Incidence Rate"
45-
assert measure.sim_datasets == {
45+
assert measure.sim_output_datasets == {
4646
"numerator_data": f"transition_count_{cause}",
4747
"denominator_data": f"person_time_{cause}",
4848
}
@@ -72,7 +72,7 @@ def test_prevalence(person_time_data: pd.DataFrame) -> None:
7272
measure = Prevalence(cause)
7373
assert measure.measure_key == f"cause.{cause}.prevalence"
7474
assert measure.title == "Disease Prevalence"
75-
assert measure.sim_datasets == {
75+
assert measure.sim_output_datasets == {
7676
"numerator_data": f"person_time_{cause}",
7777
"denominator_data": f"person_time_{cause}",
7878
}
@@ -160,7 +160,7 @@ def test_si_remission(
160160
measure = SIRemission(cause)
161161
assert measure.measure_key == f"cause.{cause}.remission_rate"
162162
assert measure.title == "Disease Remission Rate"
163-
assert measure.sim_datasets == {
163+
assert measure.sim_output_datasets == {
164164
"numerator_data": f"transition_count_{cause}",
165165
"denominator_data": f"person_time_{cause}",
166166
}
@@ -190,7 +190,7 @@ def test_all_cause_mortality_rate(
190190
measure = CauseSpecificMortalityRate("all_causes")
191191
assert measure.measure_key == "cause.all_causes.cause_specific_mortality_rate"
192192
assert measure.title == "All Causes Cause Specific Mortality Rate"
193-
assert measure.sim_datasets == {
193+
assert measure.sim_output_datasets == {
194194
"numerator_data": "deaths",
195195
"denominator_data": "person_time_total",
196196
}
@@ -226,7 +226,7 @@ def test_cause_specific_mortality_rate(
226226
measure = CauseSpecificMortalityRate(cause)
227227
assert measure.measure_key == f"cause.{cause}.cause_specific_mortality_rate"
228228
assert measure.title == "Disease Cause Specific Mortality Rate"
229-
assert measure.sim_datasets == {
229+
assert measure.sim_output_datasets == {
230230
"numerator_data": f"deaths",
231231
"denominator_data": "person_time_total",
232232
}
@@ -261,7 +261,7 @@ def test_excess_mortality_rate(
261261
measure = ExcessMortalityRate(cause)
262262
assert measure.measure_key == f"cause.{cause}.excess_mortality_rate"
263263
assert measure.title == "Disease Excess Mortality Rate"
264-
assert measure.sim_datasets == {
264+
assert measure.sim_output_datasets == {
265265
"numerator_data": f"deaths",
266266
"denominator_data": f"person_time_{cause}",
267267
}
@@ -296,7 +296,7 @@ def test_risk_exposure(risk_state_person_time_data: pd.DataFrame) -> None:
296296
measure = RiskExposure(risk_factor)
297297
assert measure.measure_key == f"risk_factor.{risk_factor}.exposure"
298298
assert measure.title == "Child Stunting Exposure"
299-
assert measure.sim_datasets == {
299+
assert measure.sim_output_datasets == {
300300
"numerator_data": f"person_time_{risk_factor}",
301301
"denominator_data": f"person_time_{risk_factor}",
302302
}
@@ -366,7 +366,7 @@ def test_population_structure(person_time_data: pd.DataFrame) -> None:
366366

367367
assert measure.measure_key == "population.structure"
368368
assert measure.title == "Population Structure"
369-
assert measure.sim_datasets == {
369+
assert measure.sim_output_datasets == {
370370
"numerator_data": "person_time_total",
371371
"denominator_data": "person_time_total",
372372
}
@@ -465,7 +465,7 @@ def test_categorical_relative_risk(
465465
assert measure.title == "Effect of Risky Risk on Disease Excess Mortality Rate"
466466
assert measure.affected_entity == affected_entity
467467
assert measure.affected_measure_name == "excess_mortality_rate"
468-
assert measure.sim_datasets == {
468+
assert measure.sim_output_datasets == {
469469
"numerator_data": "deaths",
470470
"denominator_data": f"person_time_{affected_entity}",
471471
}

tests/automated_validation/test_data_bundle.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
import pandas as pd
77
import pytest
88
from pytest_mock import MockFixture
9+
from vivarium_inputs import interface
910

11+
from tests.automated_validation.conftest import NO_GBD_ACCESS
1012
from vivarium_testing_utils.automated_validation.bundle import RatioMeasureDataBundle
11-
from vivarium_testing_utils.automated_validation.constants import DataSource
13+
from vivarium_testing_utils.automated_validation.constants import DRAW_INDEX, DataSource
1214
from vivarium_testing_utils.automated_validation.data_loader import DataLoader
1315
from vivarium_testing_utils.automated_validation.data_transformation import age_groups
1416
from vivarium_testing_utils.automated_validation.data_transformation.measures import (
@@ -35,7 +37,7 @@ def test_data_bundle_init(
3537
)
3638

3739
if data_source == DataSource.SIM:
38-
expected_keys = set(measure.sim_datasets.keys())
40+
expected_keys = set(measure.sim_output_datasets.keys())
3941
else:
4042
expected_keys = set(measure.sim_input_datasets.keys())
4143
assert set(bundle.dataset_names) == expected_keys
@@ -100,12 +102,10 @@ def test_get_metadata(
100102
assert metadata["size"] == "4 rows × 1 columns"
101103

102104

103-
@pytest.mark.parametrize("source", [DataSource.GBD, DataSource.CUSTOM])
104-
def test_dataset_names_value_error(
105+
def test_custom_data_source_dataset_names_value_error(
105106
mocker: MockFixture,
106107
mock_ratio_measure: RatioMeasure,
107108
sample_age_group_df: pd.DataFrame,
108-
source: DataSource,
109109
) -> None:
110110
"""Test that building a bundle with the CUSTOM data source raises a ValueError."""
111111
mock_data_loader = mocker.MagicMock(spec=DataLoader)
@@ -114,7 +114,7 @@ def test_dataset_names_value_error(
114114
with pytest.raises(ValueError):
115115
RatioMeasureDataBundle(
116116
measure=mock_ratio_measure,
117-
source=source,
117+
source=DataSource.CUSTOM,
118118
data_loader=mock_data_loader,
119119
age_group_df=sample_age_group_df,
120120
)
@@ -186,7 +186,7 @@ def test_aggregate_reference_stratifications(
186186
data_loader=mocker.MagicMock(spec=DataLoader),
187187
age_group_df=sample_age_group_df,
188188
)
189-
aggregated = bundle._aggregate_artifact_stratifications(stratifications)
189+
aggregated = bundle._aggregate_sim_input_stratifications(stratifications)
190190

191191
if stratifications == "all":
192192
aggregated.equals(reference_data)
@@ -208,3 +208,53 @@ def test_aggregate_reference_stratifications(
208208
),
209209
)
210210
pd.testing.assert_frame_equal(aggregated, expected)
211+
212+
213+
@pytest.mark.slow
214+
def test_data_bundle_gbd_source(sim_result_dir: Path) -> None:
215+
"""Test that GBD data source is handled correctly in RatioMeasureDataBundle."""
216+
if NO_GBD_ACCESS:
217+
pytest.skip("GBD access not available for this test.")
218+
219+
age_bins = interface.get_age_bins()
220+
age_bins.index.rename({"age_group_name": age_groups.AGE_GROUP_COLUMN}, inplace=True)
221+
222+
incidence = Incidence("diarrheal_diseases")
223+
bundle = RatioMeasureDataBundle(
224+
measure=incidence,
225+
source=DataSource.GBD,
226+
data_loader=DataLoader(sim_result_dir),
227+
age_group_df=age_bins,
228+
)
229+
230+
assert set(bundle.dataset_names) == {"data"}
231+
# Validate datasets and weights schema
232+
dataset_index_names = {
233+
"sex",
234+
age_groups.AGE_GROUP_COLUMN,
235+
"year_start",
236+
"year_end",
237+
DRAW_INDEX,
238+
}
239+
assert set(bundle.datasets["data"].index.names) == dataset_index_names
240+
assert set(bundle.datasets["data"].columns) == {"value"}
241+
assert bundle.weights is not None
242+
assert set(bundle.weights.index.names) == dataset_index_names.union({"location"})
243+
assert set(bundle.weights.columns) == {"value"}
244+
245+
# Validate data aggregation
246+
stratify_1 = bundle.get_measure_data("all")
247+
pd.testing.assert_frame_equal(stratify_1, bundle.datasets["data"])
248+
stratify_2 = bundle.get_measure_data(["sex", age_groups.AGE_GROUP_COLUMN])
249+
assert set(stratify_2.index.names) == {"sex", age_groups.AGE_GROUP_COLUMN, DRAW_INDEX}
250+
251+
metadata = bundle.get_metadata()
252+
assert metadata["source"] == "gbd"
253+
assert metadata["index_columns"] == "sex, year_start, year_end, input_draw, age_group"
254+
assert set(metadata.keys()) == {
255+
"source",
256+
"index_columns",
257+
"size",
258+
"num_draws",
259+
"input_draws",
260+
}

tests/automated_validation/test_data_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def test___get_raw_data_from_source(
210210
data_loader = DataLoader(sim_result_dir)
211211
measure = Incidence("disease")
212212
test_raw_data = data_loader._get_raw_data_from_source(
213-
measure.sim_datasets, DataSource.SIM
213+
measure.sim_output_datasets, DataSource.SIM
214214
)
215215
ref_raw_data = data_loader._get_raw_data_from_source(
216216
measure.sim_input_datasets, DataSource.ARTIFACT

0 commit comments

Comments
 (0)