Skip to content

Commit e5505e2

Browse files
authored
Albrja/mic-5971/lbwsg-exposure-data (#524)
Albrja/mic-5971/lbwsg-exposure-data Update LBWSG components to use birth_exposure or exposure artifact keys depending on simulant age - *Category*: Feature - *JIRA issue*: https://jira.ihme.washington.edu/browse/MIC-5971 Changes and notes - update the LBWSG risk to determine which artifact key/pipeline to use depending upon the simulants' age_end - update the LBWSG Distribution component to create lookup tables for both the birth_exposure and exposure artifact keys and use the provided table name to source the exposure_parameter pipeline ### Testing <!-- Details on how code was verified, any unit tests local for the repo, regression testing, etc. At a minimum, this should include an integration test for a framework change. Consider: plots, images, (small) csv file. -->
1 parent 4fa9815 commit e5505e2

File tree

4 files changed

+217
-20
lines changed

4 files changed

+217
-20
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
**4.2.0 - 06/11/25**
2+
3+
- Feature: Update LBWSG to use exposure data based on simulant age
4+
15
**4.1.1 - 05/23/25**
26

37
- Feature: Update Observer to use super class get_configuration method

src/vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py

Lines changed: 88 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,33 +14,98 @@
1414

1515
import numpy as np
1616
import pandas as pd
17+
from layered_config_tree import ConfigurationError
18+
from loguru import logger
1719
from vivarium.framework.engine import Builder
1820
from vivarium.framework.lifecycle import LifeCycleError
1921
from vivarium.framework.population import SimulantData
2022
from vivarium.framework.resource import Resource
2123
from vivarium.framework.values import Pipeline
2224

2325
from vivarium_public_health.risks import Risk, RiskEffect
24-
from vivarium_public_health.risks.data_transformations import get_exposure_post_processor
26+
from vivarium_public_health.risks.data_transformations import (
27+
get_exposure_post_processor,
28+
pivot_categorical,
29+
)
2530
from vivarium_public_health.risks.distributions import PolytomousDistribution
26-
from vivarium_public_health.utilities import get_lookup_columns, to_snake_case
31+
from vivarium_public_health.utilities import EntityString, get_lookup_columns, to_snake_case
2732

2833
CATEGORICAL = "categorical"
2934
BIRTH_WEIGHT = "birth_weight"
3035
GESTATIONAL_AGE = "gestational_age"
3136

3237

3338
class LBWSGDistribution(PolytomousDistribution):
39+
@property
def categories(self) -> list[str]:
    """Return the sorted category names of the selected exposure table.

    The names are the value columns of the lookup table stored under
    ``self.exposure_key``.
    """
    # Sorting fixes the order in which categories are cumulatively summed,
    # which keeps results reproducible and correct.
    value_columns = self.lookup_tables[self.exposure_key].value_columns
    return sorted(value_columns)
3444

3545
#################
3646
# Setup methods #
3747
#################
3848

49+
def __init__(
    self,
    risk: EntityString,
    distribution_type: str,
    exposure_data: int | float | pd.DataFrame | None = None,
) -> None:
    """Initialize the distribution with birth exposure selected by default.

    All three arguments are forwarded unchanged to the parent initializer.
    """
    super().__init__(risk, distribution_type, exposure_data)
    # Key of the lookup table that currently supplies exposure parameters.
    self.exposure_key = "birth_exposure"
57+
3958
# noinspection PyAttributeOutsideInit
def setup(self, builder: Builder) -> None:
    """Run the parent setup, then store the category intervals."""
    super().setup(builder)
    intervals = self.get_category_intervals(builder)
    self.category_intervals = intervals
4362

63+
def build_all_lookup_tables(self, builder: Builder) -> None:
    """Build the birth-exposure lookup table and the parent's lookup tables.

    A ``ConfigurationError`` raised while building either set of tables is
    logged as a warning instead of being propagated, so a simulation may
    provide only one of the two exposure data sources.

    Raises
    ------
    ConfigurationError
        If neither a ``"birth_exposure"`` nor an ``"exposure"`` lookup table
        exists after both attempts.
    """
    try:
        raw_data = self.get_data(
            builder, self.configuration["data_sources"]["birth_exposure"]
        )
        value_columns = self.get_exposure_value_columns(raw_data)
        # Only tabular data is pivoted on "parameter"; anything else is used as-is.
        table_data = (
            pivot_categorical(builder, self.risk, raw_data, "parameter")
            if isinstance(raw_data, pd.DataFrame)
            else raw_data
        )
        self.lookup_tables["birth_exposure"] = self.build_lookup_table(
            builder, table_data, value_columns
        )
    except ConfigurationError:
        logger.warning("Birth exposure data for LBWSG is missing from the simulation")

    try:
        super().build_all_lookup_tables(builder)
    except ConfigurationError:
        logger.warning("The data for LBWSG exposure is missing from the simulation.")

    has_exposure_table = any(
        key in self.lookup_tables for key in ("birth_exposure", "exposure")
    )
    if not has_exposure_table:
        raise ConfigurationError(
            "The LBWSG distribution requires either 'birth_exposure' or 'exposure' data to be "
            "available in the simulation."
        )
95+
96+
def get_exposure_parameter_pipeline(self, builder: Builder) -> Pipeline:
    """Register the exposure-parameters pipeline for this distribution.

    The pipeline source looks up ``self.exposure_key`` each time it is
    called, so changing that attribute changes which lookup table is used.
    The required resources are the de-duplicated lookup columns of whichever
    of the ``"exposure"`` and ``"birth_exposure"`` tables exist.
    """
    lookup_columns = []
    for table_key in ("exposure", "birth_exposure"):
        if table_key not in self.lookup_tables:
            continue
        table = self.lookup_tables[table_key]
        lookup_columns.extend(get_lookup_columns([table]))

    return builder.value.register_value_producer(
        self.parameters_pipeline_name,
        source=lambda index: self.lookup_tables[self.exposure_key](index),
        component=self,
        required_resources=list(set(lookup_columns)),
    )
108+
44109
def get_category_intervals(self, builder: Builder) -> dict[str, dict[str, pd.Interval]]:
45110
"""Gets the intervals for each category.
46111
@@ -203,8 +268,9 @@ def get_exposure_column_name(axis: str) -> str:
203268
@property
def configuration_defaults(self) -> dict[str, Any]:
    """Return the parent's configuration defaults with LBWSG entries added.

    Adds a ``"birth_exposure"`` data source and sets the distribution type
    to ``"lbwsg"`` under this component's name.
    """
    defaults = super().configuration_defaults
    component_defaults = defaults[self.name]
    # Add birth exposure data source
    component_defaults["data_sources"]["birth_exposure"] = f"{self.risk}.birth_exposure"
    component_defaults["distribution_type"] = "lbwsg"
    return defaults
@@ -224,6 +290,7 @@ def __init__(self):
224290
def setup(self, builder: Builder) -> None:
    """Run the parent setup, then store the birth exposure pipelines and max age.

    ``configuration_age_end`` is read from the population configuration's
    ``initialization_age_max``.
    """
    super().setup(builder)
    self.birth_exposures = self.get_birth_exposure_pipelines(builder)
    population_config = builder.configuration.population
    self.configuration_age_end = population_config.initialization_age_max
227294

228295
#################
229296
# Setup methods #
@@ -242,7 +309,7 @@ def get_birth_exposure_pipelines(self, builder: Builder) -> dict[str, Pipeline]:
242309
self.exposure_distribution.lookup_tables.values()
243310
)
244311

245-
def get_pipeline(axis_: str):
312+
def get_pipeline(axis_: str) -> Pipeline:
246313
return builder.value.register_value_producer(
247314
self.birth_exposure_pipeline_name(axis_),
248315
source=lambda index: self.get_birth_exposure(axis_, index),
@@ -260,12 +327,23 @@ def get_pipeline(axis_: str):
260327
########################
261328

262329
def on_initialize_simulants(self, pop_data: SimulantData) -> None:
    """Select the exposure table and write birth exposures for new simulants.

    The ``"age_end"`` entry of ``pop_data.user_data`` is used, falling back to
    ``self.configuration_age_end``. A value of 0 selects the
    ``"birth_exposure"`` table on the exposure distribution; any other value
    selects ``"exposure"``. Each axis's birth exposure pipeline is then
    evaluated for the new simulants and the results are written to the
    population view.

    Raises
    ------
    ConfigurationError
        If a ``KeyError`` is raised while evaluating the birth exposure
        pipelines; the original error is chained as the cause.
    """
    age_end = pop_data.user_data.get("age_end", self.configuration_age_end)
    self.exposure_distribution.exposure_key = (
        "birth_exposure" if age_end == 0 else "exposure"
    )

    try:
        birth_exposures = {
            self.get_exposure_column_name(axis): self.birth_exposures[
                self.birth_exposure_pipeline_name(axis)
            ](pop_data.index)
            for axis in self.AXES
        }
    except KeyError as err:
        # Chain the KeyError so the traceback still shows which lookup failed.
        raise ConfigurationError(
            f"{self.exposure_distribution.exposure_key} data for {self.name} is missing from the "
            "simulation. Simulants cannot be initialized."
        ) from err
    self.population_view.update(pd.DataFrame(birth_exposures))
270348

271349
##################################

tests/conftest.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from collections.abc import Callable
1+
from collections.abc import Callable, Generator
22
from pathlib import Path
33

44
import pytest
5+
from _pytest.logging import LogCaptureFixture
56
from layered_config_tree import LayeredConfigTree
7+
from loguru import logger
68
from vivarium.framework.configuration import build_simulation_configuration
79
from vivarium_testing_utils import FuzzyChecker
810

@@ -69,3 +71,10 @@ def fuzzy_checker() -> FuzzyChecker:
6971
yield checker
7072
test_dir = Path(__file__).resolve().parent
7173
checker.save_diagnostic_output(test_dir)
74+
75+
76+
@pytest.fixture
def caplog(caplog: LogCaptureFixture) -> Generator[LogCaptureFixture, None, None]:
    """Route loguru messages into pytest's ``caplog`` for the duration of a test.

    The capture handler is added as a loguru sink before the test runs and
    removed again afterwards.
    """
    sink_id = logger.add(caplog.handler, format="{message}")
    yield caplog
    logger.remove(sink_id)

tests/risks/test_low_birth_weight_and_short_gestation.py

Lines changed: 115 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import numpy as np
22
import pandas as pd
33
import pytest
4+
from layered_config_tree import ConfigurationError
5+
from vivarium import InteractiveContext
6+
from vivarium.testing_utilities import TestPopulation
47

58
from tests.risks.test_effect import _setup_risk_effect_simulation
69
from tests.test_utilities import make_age_bins
@@ -57,11 +60,15 @@ def test_lbwsg_risk_effect_rr_pipeline(base_config, base_plugins, mock_rr_interp
5760
# Have to match age bins and rr data to make age intervals
5861
rr_data = make_categorical_data(agees)
5962
# Exposure data used for risk component
60-
exposure = make_categorical_data(agees)
63+
birth_exposure = make_categorical_data(agees)
64+
exposure = birth_exposure.copy()
65+
exposure.loc[exposure["value"] == 0.75, "value"] = 0.65
66+
exposure.loc[exposure["value"] == 0.25, "value"] = 0.35
6167

6268
# Add data dict to add to artifact
6369
data = {
64-
f"{risk.name}.birth_exposure": exposure,
70+
f"{risk.name}.birth_exposure": birth_exposure,
71+
f"{risk.name}.exposure": exposure,
6572
f"{risk.name}.relative_risk": rr_data,
6673
f"{risk.name}.population_attributable_fraction": 0,
6774
f"{risk.name}.categories": categories,
@@ -81,6 +88,16 @@ def test_lbwsg_risk_effect_rr_pipeline(base_config, base_plugins, mock_rr_interp
8188
)
8289
sim = _setup_risk_effect_simulation(base_config, base_plugins, risk, lbwsg_effect, data)
8390
pop = sim.get_population()
91+
# Verify exposure is used instead of birth_exposure since age end is 1.0
92+
# Check values of pipeline match the exposure data rather than the birth exposure data
93+
exposure_pipeline_values = sim.get_value(
94+
"risk_factor.low_birth_weight_and_short_gestation.exposure_parameters"
95+
)(pop.index)
96+
assert isinstance(exposure_pipeline_values, pd.DataFrame)
97+
assert "cat81" in exposure_pipeline_values.columns
98+
assert "cat82" in exposure_pipeline_values.columns
99+
assert (exposure_pipeline_values["cat81"] == 0.65).all()
100+
assert (exposure_pipeline_values["cat82"] == 0.35).all()
84101

85102
expected_pipeline_name = (
86103
f"effect_of_{lbwsg_effect.risk.name}_on_{lbwsg_effect.target.name}.relative_risk"
@@ -120,7 +137,8 @@ def map_age_groups(value):
120137
assert (actual_rr == 1.0).all()
121138

122139

123-
def test_use_birth_exposure(base_config, base_plugins, mock_rr_interpolators):
140+
@pytest.mark.parametrize("age_end", [0.0, 1.0])
141+
def test_use_exposure(base_config, base_plugins, mock_rr_interpolators, age_end):
124142
risk = LBWSGRisk()
125143
lbwsg_effect = LBWSGRiskEffect("cause.test_cause.cause_specific_mortality_rate")
126144

@@ -135,7 +153,7 @@ def test_use_birth_exposure(base_config, base_plugins, mock_rr_interpolators):
135153
# Have to match age bins and rr data to make age intervals
136154
rr_data = make_categorical_data(ages)
137155
# Format birth exposure data
138-
exposure = pd.DataFrame(
156+
birth_exposure = pd.DataFrame(
139157
{
140158
"sex": ["Male", "Female", "Male", "Female"],
141159
"year_start": [2021, 2021, 2021, 2021],
@@ -144,35 +162,123 @@ def test_use_birth_exposure(base_config, base_plugins, mock_rr_interpolators):
144162
"value": [0.75, 0.75, 0.25, 0.25],
145163
}
146164
)
165+
exposure = birth_exposure.copy()
166+
exposure["value"] = [0.65, 0.65, 0.35, 0.35]
147167

148168
# Add data dict to add to artifact
149169
data = {
150-
f"{risk.name}.birth_exposure": exposure,
170+
f"{risk.name}.birth_exposure": birth_exposure,
171+
f"{risk.name}.exposure": exposure,
151172
f"{risk.name}.relative_risk": rr_data,
152173
f"{risk.name}.population_attributable_fraction": 0,
153174
f"{risk.name}.categories": categories,
154175
f"{risk.name}.relative_risk_interpolator": mock_rr_interpolators,
155176
}
156177

157178
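# Only have neonatal age groups
# NOTE(review): the assignment below overwrites the parametrized age_end, so only the 0.0 case is exercised - confirm intent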
158-
age_start = 0.0
159-
age_end = 1.0
179+
age_end = 0.0
160180
base_config.update(
161181
{
162182
"population": {
163-
"initialization_age_start": age_start,
183+
"initialization_age_start": 0.0,
164184
"initialization_age_max": age_end,
165-
}
185+
},
166186
}
167187
)
168188
sim = _setup_risk_effect_simulation(base_config, base_plugins, risk, lbwsg_effect, data)
169189
pop = sim.get_population()
190+
# Check values of pipeline match birth exposure data since age_end is 0.0
191+
exposure_pipeline_values = sim.get_value(
192+
"risk_factor.low_birth_weight_and_short_gestation.exposure_parameters"
193+
)(pop.index)
194+
assert isinstance(exposure_pipeline_values, pd.DataFrame)
195+
assert "cat81" in exposure_pipeline_values.columns
196+
assert "cat82" in exposure_pipeline_values.columns
197+
exposure_values = {
198+
0.0: {"cat81": 0.75, "cat82": 0.25},
199+
1.0: {"cat81": 0.65, "cat82": 0.35},
200+
}
201+
assert (exposure_pipeline_values["cat81"] == exposure_values[age_end]["cat81"]).all()
202+
assert (exposure_pipeline_values["cat82"] == exposure_values[age_end]["cat82"]).all()
170203

171204
# Assert LBWSG birth exposure columns were created
172205
assert "birth_weight_exposure" in pop.columns
173206
assert "gestational_age_exposure" in pop.columns
174207

175208

209+
@pytest.mark.parametrize("exposure_key", ["birth_exposure", "exposure", "missing"])
def test_lbwsg_exposure_data_logging(exposure_key, base_config, mocker, caplog) -> None:
    """Check the warning or error produced when LBWSG exposure data is missing.

    When only one of the two data sources is supplied, a warning naming the
    other source must be logged. When neither is supplied, building the
    simulation must raise a ``ConfigurationError``.
    """
    risk = LBWSGRisk()

    # Mock exposure data, supplied directly through the configuration
    exposure_data = pd.DataFrame(
        {
            "sex": ["Male", "Female", "Male", "Female"],
            "year_start": [2021, 2021, 2021, 2021],
            "year_end": [2022, 2022, 2022, 2022],
            "parameter": ["cat81", "cat81", "cat82", "cat82"],
            "value": [0.75, 0.75, 0.25, 0.25],
        }
    )

    # A maximum initialization age of 0.0 selects birth exposure; otherwise exposure is used
    age_end = 0.0 if exposure_key == "birth_exposure" else 1.0

    override_config = {
        "population": {
            "initialization_age_start": 0.0,
            "initialization_age_max": age_end,
        },
    }
    if exposure_key != "missing":
        override_config[risk.name] = {"data_sources": {exposure_key: exposure_data}}

    # Patch get_category_intervals so we do not need the mock artifact
    mocker.patch(
        "vivarium_public_health.risks.implementations.low_birth_weight_and_short_gestation.LBWSGDistribution.get_category_intervals"
    )
    assert not caplog.records

    if exposure_key == "missing":
        with pytest.raises(
            ConfigurationError,
            match="The LBWSG distribution requires either 'birth_exposure' or 'exposure' data to be "
            "available in the simulation.",
        ):
            InteractiveContext(
                base_config,
                components=[TestPopulation(), risk],
                configuration=override_config,
            )
        return

    # Warning text logged by LBWSGDistribution for each data source that is absent
    expected_warnings = {
        "birth_exposure": "Birth exposure data for LBWSG is missing from the simulation",
        "exposure": "The data for LBWSG exposure is missing from the simulation",
    }
    missing_key = "exposure" if exposure_key == "birth_exposure" else "birth_exposure"
    InteractiveContext(
        base_config,
        components=[TestPopulation(), risk],
        configuration=override_config,
    )
    # Assert against the captured log; a bare string assertion would always pass
    assert expected_warnings[missing_key] in caplog.text
280+
281+
176282
def make_categorical_data(data: pd.DataFrame) -> pd.DataFrame:
177283
# Takes age groups and adds sex, years, categories, and values
178284
dfs = []

0 commit comments

Comments
 (0)