Skip to content

Commit 49cb099

Browse files
Cache exposure pipeline for nonloglinear effects
1 parent 21960e8 commit 49cb099

File tree

2 files changed

+42
-4
lines changed

2 files changed

+42
-4
lines changed

src/vivarium_public_health/risks/base_risk.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import pandas as pd
1414
from vivarium import Component
1515
from vivarium.framework.engine import Builder
16+
from vivarium.framework.event import Event
1617
from vivarium.framework.population import SimulantData
1718
from vivarium.framework.randomness import RandomnessStream
1819

@@ -171,6 +172,7 @@ def __init__(self, risk: str):
171172
self.randomness_stream_name = f"initial_{self.risk.name}_propensity"
172173
self.propensity_name = f"{self.risk.name}.propensity"
173174
self.exposure_name = f"{self.risk.name}.exposure"
175+
self.exposure_column_name = f"{self.risk.name}_exposure_for_non_loglinear_riskeffect"
174176

175177
#################
176178
# Setup methods #
@@ -184,9 +186,20 @@ def setup(self, builder: Builder) -> None:
184186

185187
self.randomness = self.get_randomness_stream(builder)
186188
self.register_exposure_pipeline(builder)
189+
190+
self.includes_non_loglinear_risk_effect = bool(
191+
[
192+
component
193+
for component in builder.components.list_components()
194+
if component.startswith(f"non_log_linear_risk_effect.{self.risk.name}_on_")
195+
]
196+
)
197+
columns_to_create = [self.propensity_name]
198+
if self.includes_non_loglinear_risk_effect:
199+
columns_to_create.append(self.exposure_column_name)
187200
builder.population.register_initializer(
188201
initializer=self.on_initialize_simulants,
189-
columns=self.propensity_name,
202+
columns=columns_to_create,
190203
required_resources=[self.randomness],
191204
)
192205

@@ -279,3 +292,22 @@ def on_initialize_simulants(self, pop_data: SimulantData) -> None:
279292
self.randomness.get_draw(pop_data.index), name=self.propensity_name
280293
)
281294
self.population_view.update(propensity)
295+
self.update_exposure_column(pop_data.index)
296+
297+
def on_time_step_prepare(self, event: Event) -> None:
298+
self.update_exposure_column(event.index)
299+
300+
def update_exposure_column(self, index: pd.Index) -> None:
301+
"""Updates the exposure column with pipeline values.
302+
303+
HACK: This is effectively caching the exposure pipeline for use by other
304+
components. Specifically, :meth:`vivarium_public_health.risks.effect.NonLogLinearRiskEffect.get_relative_risk_source`
305+
needs the exposure values but calling that pipeline was very slow. By
306+
maintaining a cached copy of the exposure values in a private column, we
307+
can then request that corresponding "simple" pipeline from the population
308+
view instead which is significantly faster.
309+
"""
310+
if self.includes_non_loglinear_risk_effect:
311+
exposure = self.population_view.get_attributes(index, self.exposure_name)
312+
exposure.name = self.exposure_column_name
313+
self.population_view.update(exposure)

src/vivarium_public_health/risks/effect.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -443,8 +443,12 @@ def define_rr_intervals(df: pd.DataFrame) -> pd.DataFrame:
443443
.reset_index()
444444
)
445445
rr_data = rr_data.drop("parameter", axis=1)
446-
rr_data[f"{self.risk.name}.exposure_start"] = rr_data["left_exposure"]
447-
rr_data[f"{self.risk.name}.exposure_end"] = rr_data["right_exposure"]
446+
rr_data[f"{self.risk.name}_exposure_for_nonloglinear_riskeffect_start"] = rr_data[
447+
"left_exposure"
448+
]
449+
rr_data[f"{self.risk.name}_exposure_for_nonloglinear_riskeffect_end"] = rr_data[
450+
"right_exposure"
451+
]
448452
# build lookup table
449453
rr_value_cols = ["left_exposure", "left_rr", "right_exposure", "right_rr"]
450454
return self.build_lookup_table(
@@ -511,8 +515,10 @@ def get_rr_at_tmrel(rr_data: pd.DataFrame) -> float:
511515
def get_relative_risk_source(self, builder: Builder) -> Callable[[pd.Index], pd.Series]:
512516
def generate_relative_risk(index: pd.Index) -> pd.Series:
513517
rr_intervals = self.relative_risk_table(index)
518+
# NOTE: We are calling the cached exposure pipeline here for performance
519+
# purposes (as opposed to the f{self.risk.name}.expousure pipeline).
514520
exposure = self.population_view.get_attributes(
515-
index, f"{self.risk.name}.exposure"
521+
index, f"{self.risk.name}_exposure_for_nonloglinear_riskeffect"
516522
)
517523
x1, x2 = (
518524
rr_intervals["left_exposure"].values,

0 commit comments

Comments
 (0)