Skip to content

Commit db1a289

Browse files
authored
Albrja/bugfix/mic-5768/lbwsg-riskeffect-rr-pipeline (#504)
Albrja/bugfix/mic-5768/lbwsg-riskeffect-rr-pipeline Fixes a bug in the creation of the relative risk pipeline for the LBWSGRiskEffect component - *Category*: Bugfix - *JIRA issue*: https://jira.ihme.washington.edu/browse/MIC-5768 Changes and notes - Updates the way the relative risk pipeline is created in LBWSGRiskEffect ### Testing <!-- Details on how code was verified, any unit tests local for the repo, regression testing, etc. At a minimum, this should include an integration test for a framework change. Consider: plots, images, (small) csv file. -->
1 parent aa23ee1 commit db1a289

File tree

6 files changed

+183
-15
lines changed

6 files changed

+183
-15
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
**3.2.0 - 01/15/25**
2+
3+
- Bugfix: Fix bug in LBWSGRiskEffect where relative risk pipeline was not properly created
4+
15
**3.1.5 - 01/14/25**
26

37
- Bugfix: Vivarium InteractiveContext no longer returns int

src/vivarium_public_health/risks/effect.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def setup(self, builder: Builder) -> None:
114114
self.exposure = self.get_risk_exposure(builder)
115115

116116
self._relative_risk_source = self.get_relative_risk_source(builder)
117-
self.relative_risk = self.get_relative_risk(builder)
117+
self.relative_risk = self.get_relative_risk_pipeline(builder)
118118

119119
self.register_target_modifier(builder)
120120
self.register_paf_modifier(builder)
@@ -297,7 +297,7 @@ def generate_relative_risk(index: pd.Index) -> pd.Series:
297297

298298
return generate_relative_risk
299299

300-
def get_relative_risk(self, builder: Builder) -> Pipeline:
300+
def get_relative_risk_pipeline(self, builder: Builder) -> Pipeline:
301301
return builder.value.register_value_producer(
302302
f"{self.risk.name}_on_{self.target.name}.relative_risk",
303303
self._relative_risk_source,

src/vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -361,18 +361,11 @@ def get_population_attributable_fraction_source(
361361
paf_data = builder.data.load(paf_key)
362362
return paf_data, builder.data.value_columns()(paf_key)
363363

364-
def get_target_modifier(
365-
self, builder: Builder
366-
) -> Callable[[pd.Index, pd.Series], pd.Series]:
367-
def adjust_target(index: pd.Index, target: pd.Series) -> pd.Series:
368-
return target * self.relative_risk(index)
369-
370-
return adjust_target
371-
372364
def register_target_modifier(self, builder: Builder) -> None:
373365
builder.value.register_value_modifier(
374366
self.target_pipeline_name,
375-
modifier=self.target_modifier,
367+
modifier=self.adjust_target,
368+
component=self,
376369
requires_values=[self.relative_risk_pipeline_name],
377370
)
378371

@@ -392,11 +385,12 @@ def get_age_intervals(self, builder: Builder) -> dict[str, pd.Interval]:
392385
for age_start in exposed_age_group_starts
393386
}
394387

395-
def get_relative_risk(self, builder: Builder) -> Pipeline:
388+
def get_relative_risk_pipeline(self, builder: Builder) -> Pipeline:
396389
return builder.value.register_value_producer(
397390
self.relative_risk_pipeline_name,
398-
source=self.get_relative_risk_source,
399-
requires_columns=["age"] + self.rr_column_names,
391+
source=self._relative_risk_source,
392+
component=self,
393+
required_resources=["age"] + self.rr_column_names,
400394
)
401395

402396
def get_interpolator(self, builder: Builder) -> pd.Series:
@@ -469,7 +463,7 @@ def get_relative_risk_for_age_group(age_group: str) -> pd.Series:
469463
# Pipeline sources and modifiers #
470464
##################################
471465

472-
def get_relative_risk_source(self, index: pd.Index) -> pd.Series:
466+
def _get_relative_risk(self, index: pd.Index) -> pd.Series:
473467
pop = self.population_view.get(index)
474468
relative_risk = pd.Series(1.0, index=index, name=self.relative_risk_pipeline_name)
475469

@@ -479,3 +473,6 @@ def get_relative_risk_source(self, index: pd.Index) -> pd.Series:
479473
age_group_mask, self.relative_risk_column_name(age_group)
480474
]
481475
return relative_risk
476+
477+
def get_relative_risk_source(self, builder: Builder) -> Callable[[pd.Index], pd.Series]:
478+
return self._get_relative_risk

tests/data/rr_interpolator.csv

Lines changed: 51 additions & 0 deletions
Large diffs are not rendered by default.

tests/risks/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,3 +262,11 @@ def coverage_gap():
262262
cg_data["affected_risk_factors"] = ["test_risk"]
263263
cg_data["distribution"] = "dichotomous"
264264
return Risk(f"coverage_gap.{cg}"), cg_data
265+
266+
267+
@pytest.fixture
268+
def mock_rr_interpolators() -> pd.DataFrame:
    """Load the mock relative risk interpolator table for the LBWSG tests.

    Reads ``tests/data/rr_interpolator.csv`` and renames its ``draw_0``
    column to ``value``. All other columns are returned unchanged.

    NOTE(review): the CSV path is relative to the current working
    directory, so this presumably assumes pytest is launched from the
    repository root -- confirm.

    Returns
    -------
        The interpolator table with ``draw_0`` exposed as ``value``.
    """
    rr_interpolators = pd.read_csv("tests/data/rr_interpolator.csv")
    return rr_interpolators.rename(columns={"draw_0": "value"})

tests/risks/test_low_birth_weight_and_short_gestation.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1+
import numpy as np
2+
import pandas as pd
13
import pytest
24

5+
from tests.risks.test_effect import _setup_risk_effect_simulation
6+
from tests.test_utilities import make_age_bins
37
from vivarium_public_health.risks.implementations.low_birth_weight_and_short_gestation import (
48
LBWSGDistribution,
9+
LBWSGRisk,
10+
LBWSGRiskEffect,
511
)
12+
from vivarium_public_health.utilities import to_snake_case
613

714

815
@pytest.mark.parametrize(
@@ -32,3 +39,104 @@ def test_parsing_lbwsg_descriptions(description, expected_weight_values, expecte
3239
assert weight_interval.right == expected_weight_values[1]
3340
assert age_interval.left == expected_age_values[0]
3441
assert age_interval.right == expected_age_values[1]
42+
43+
44+
def test_lbwsg_risk_effect_rr_pipeline(
    base_config, base_plugins, mocker, mock_rr_interpolators
):
    """Check the relative risk pipeline created by ``LBWSGRiskEffect``.

    Sets up a simulation containing an ``LBWSGRisk`` and an
    ``LBWSGRiskEffect``, asserts that the effect's relative risk pipeline is
    listed among the simulation's values, and then checks, for every sex and
    age group combination, that the pipeline returns
    ``exp(interpolator(gestational_age, birth_weight))`` for each simulant.

    ``mocker`` is not used in the body; it is kept so the requested fixtures
    are unchanged.
    """
    risk = LBWSGRisk()
    lbwsg_effect = LBWSGRiskEffect("cause.test_cause.cause_specific_mortality_rate")

    # Add mock data to artifact
    categories = {
        "cat81": "Neonatal preterm and LBWSG (estimation years) - [28, 30) wks, [2500, 3000) g",
        "cat82": "Neonatal preterm and LBWSG (estimation years) - [28, 30) wks, [3000, 3500) g",
    }
    # Create exposure data with the same demographic index as age_bins
    age_bins = make_age_bins()
    ages = age_bins.drop(columns="age_group_name")
    exposure_data = make_categorical_exposure_data(ages)

    # Data dict to add to the artifact
    data = {
        f"{risk.name}.exposure": exposure_data,
        f"{risk.name}.population_attributable_fraction": 0,
        f"{risk.name}.categories": categories,
        f"{risk.name}.relative_risk_interpolator": mock_rr_interpolators,
    }

    # Only have neonatal age groups
    age_start = 0.0
    age_end = 1.0
    base_config.update(
        {
            "population": {
                "initialization_age_start": age_start,
                "initialization_age_max": age_end,
            }
        }
    )
    sim = _setup_risk_effect_simulation(base_config, base_plugins, risk, lbwsg_effect, data)
    pop = sim.get_population()

    expected_pipeline_name = (
        f"effect_of_{lbwsg_effect.risk.name}_on_{lbwsg_effect.target.name}.relative_risk"
    )
    assert expected_pipeline_name in sim.list_values()

    # Get age group names to look up the rr interpolator later
    def map_age_groups(value):
        # Return the name of the first age bin whose (inclusive) bounds
        # contain ``value``, or None when no bin matches.
        for _, row in age_bins.iterrows():
            if row["age_start"] <= value <= row["age_end"]:
                return row["age_group_name"]
        return None

    mapped_age_groups = pop["age"].apply(map_age_groups)
    mapped_age_groups = mapped_age_groups.apply(to_snake_case)
    sim_data = pop[["sex", "birth_weight_exposure", "gestational_age_exposure"]].copy()
    sim_data["age_group_name"] = mapped_age_groups

    # The pipeline is the same for every demographic group, so look it up once.
    relative_risk_pipeline = sim.get_value(expected_pipeline_name)

    # Test every sex / age group combination (2 sexes x 3 age groups)
    for sex in ["Male", "Female"]:
        for age_group_name in ["early_neonatal", "late_neonatal", "post_neonatal"]:
            interpolator = lbwsg_effect.interpolator[sex, age_group_name]
            demo_idx = sim_data.index[
                (sim_data["sex"] == sex) & (sim_data["age_group_name"] == age_group_name)
            ]
            sub_pop = sim_data.loc[demo_idx]
            actual_rr = relative_risk_pipeline(demo_idx)
            # The interpolator is evaluated point-wise (grid=False) and its
            # output is exponentiated to get the expected relative risk.
            expected_rr = pd.Series(
                np.exp(
                    interpolator(
                        sub_pop["gestational_age_exposure"],
                        sub_pop["birth_weight_exposure"],
                        grid=False,
                    )
                ),
                index=sub_pop.index,
            )
            assert (actual_rr == expected_rr).all()
            if age_group_name == "post_neonatal":
                # Post-neonatal simulants are expected to have a relative risk of 1
                assert (actual_rr == 1.0).all()
119+
120+
121+
def make_categorical_exposure_data(data: pd.DataFrame) -> pd.DataFrame:
    """Expand age-group rows into categorical exposure data.

    For every year from 1990 through 2016, each row of ``data`` is repeated
    for the two exposure categories (``cat81`` with value 0.75 and ``cat82``
    with value 0.25) and for both sexes. The columns ``year_start``,
    ``year_end``, ``parameter``, ``value`` and ``sex`` are appended in that
    order. The original row index is kept, so it repeats in the result.

    Parameters
    ----------
    data
        One row per age group. It is not modified.

    Returns
    -------
        The concatenated exposure table.
    """
    category_values = {"cat81": 0.75, "cat82": 0.25}
    frames = []
    for year in range(1990, 2017):
        yearly = data.assign(year_start=year, year_end=year + 1)
        by_category = pd.concat(
            [
                yearly.assign(parameter=category, value=exposure)
                for category, exposure in category_values.items()
            ]
        )
        # Male rows come before female rows within each year.
        frames.extend(by_category.assign(sex=sex) for sex in ("Male", "Female"))

    return pd.concat(frames)

0 commit comments

Comments
 (0)