Skip to content

Commit 06ec6d3

Browse files
authored
Albrja/mic 6186/intervention component (#527)
Albrja/mic 6186/intervention component Refactor and implementation for Intervention component - *Category*: Implementation/Refactor - *JIRA issue*: https://jira.ihme.washington.edu/browse/MIC-6186 Changes and notes - Adds superclasses Exposure and ExposureEffect - Adds Intervention and InterventionEffect classes equivalent to Risk and RiskEffect classes - Updates RiskExposureDistribution classes to take the Exposure component instead of the entity string (exposure vs coverage)
1 parent 7c7f669 commit 06ec6d3

27 files changed

+1237
-617
lines changed

CHANGELOG.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
**4.3.0 - 07/29/25**
2+
3+
- Refactor to create two new abstract classes, Exposure and ExposureEffect
4+
- Add Intervention and InterventionEffect classes
5+
- Refactor to pass Exposure component to ExposureDistribution (formerly RiskExposureDistribution) classes
6+
- Update exposure category names for DichotomousDistribution
7+
18
**4.2.6 - 07/25/25**
29

310
- Feature: Support new environment creation via 'make build-env'
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.. automodule:: vivarium_public_health.exposure.distributions
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.. automodule:: vivarium_public_health.exposure.effect
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.. automodule:: vivarium_public_health.exposure.exposure
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
=======================
2+
Exposure Modeling Tools
3+
=======================
4+
5+
.. automodule:: vivarium_public_health.exposure
6+
7+
.. toctree::
8+
:maxdepth: 2
9+
:glob:
10+
11+
*

docs/source/api_reference/risks/distributions.rst

Lines changed: 0 additions & 1 deletion
This file was deleted.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.. automodule:: vivarium_public_health.treatment.intervention

src/vivarium_public_health/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,13 @@
4646
Risk,
4747
RiskEffect,
4848
)
49-
from vivarium_public_health.treatment import AbsoluteShift, LinearScaleUp, TherapeuticInertia
49+
from vivarium_public_health.treatment import (
50+
AbsoluteShift,
51+
Intervention,
52+
InterventionEffect,
53+
LinearScaleUp,
54+
TherapeuticInertia,
55+
)
5056

5157
__all__ = [
5258
__author__,
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .effect import ExposureEffect
2+
from .exposure import Exposure

src/vivarium_public_health/risks/distributions.py renamed to src/vivarium_public_health/exposure/distributions.py

Lines changed: 102 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
"""
22
=================================
3-
Risk Exposure Distribution Models
3+
Exposure Distribution Models
44
=================================
55
66
This module contains tools for modeling several different risk
77
exposure distributions.
88
99
"""
1010

11+
from __future__ import annotations
12+
13+
import warnings
1114
from abc import ABC, abstractmethod
1215
from collections.abc import Callable
16+
from typing import TYPE_CHECKING
1317

1418
import numpy as np
1519
import pandas as pd
@@ -24,36 +28,48 @@
2428
from vivarium_public_health.risks.data_transformations import pivot_categorical
2529
from vivarium_public_health.utilities import EntityString, get_lookup_columns
2630

31+
if TYPE_CHECKING:
32+
from vivarium_public_health.exposure import Exposure
33+
2734

2835
class MissingDataError(Exception):
2936
pass
3037

3138

32-
class RiskExposureDistribution(Component, ABC):
39+
class ExposureDistribution(Component, ABC):
3340

3441
#####################
3542
# Lifecycle methods #
3643
#####################
3744

3845
def __init__(
3946
self,
40-
risk: EntityString,
47+
exposure_component: Exposure,
4148
distribution_type: str,
4249
exposure_data: int | float | pd.DataFrame | None = None,
4350
) -> None:
4451
super().__init__()
45-
self.risk = risk
52+
self.exposure_component = exposure_component
4653
self.distribution_type = distribution_type
54+
if (
55+
self.distribution_type != "dichotomous"
56+
and self.exposure_component.entity.type == "intervention"
57+
):
58+
raise NotImplementedError(
59+
f"Distribution type {self.distribution_type} is not supported for interventions."
60+
)
4761
self._exposure_data = exposure_data
4862

49-
self.parameters_pipeline_name = f"{self.risk}.exposure_parameters"
63+
self.parameters_pipeline_name = (
64+
f"{self.exposure_component.entity}.exposure_parameters"
65+
)
5066

5167
#################
5268
# Setup methods #
5369
#################
5470

5571
def get_configuration(self, builder: "Builder") -> LayeredConfigTree | None:
56-
return builder.configuration[self.risk]
72+
return builder.configuration[self.exposure_component.entity]
5773

5874
@abstractmethod
5975
def build_all_lookup_tables(self, builder: "Builder") -> None:
@@ -62,7 +78,9 @@ def build_all_lookup_tables(self, builder: "Builder") -> None:
6278
def get_exposure_data(self, builder: Builder) -> int | float | pd.DataFrame:
6379
if self._exposure_data is not None:
6480
return self._exposure_data
65-
return self.get_data(builder, self.configuration["data_sources"]["exposure"])
81+
return self.get_data(
82+
builder, self.configuration["data_sources"][self.exposure_component.exposure_type]
83+
)
6684

6785
# noinspection PyAttributeOutsideInit
6886
def setup(self, builder: Builder) -> None:
@@ -87,7 +105,7 @@ def ppf(self, quantiles: pd.Series) -> pd.Series:
87105
raise NotImplementedError
88106

89107

90-
class EnsembleDistribution(RiskExposureDistribution):
108+
class EnsembleDistribution(ExposureDistribution):
91109
##############
92110
# Properties #
93111
##############
@@ -106,7 +124,7 @@ def initialization_requirements(self) -> list[str | Resource]:
106124

107125
def __init__(self, risk: EntityString, distribution_type: str = "ensemble") -> None:
108126
super().__init__(risk, distribution_type)
109-
self._propensity = f"ensemble_propensity_{self.risk}"
127+
self._propensity = f"ensemble_propensity_{self.exposure_component.entity}"
110128

111129
#################
112130
# Setup methods #
@@ -129,7 +147,11 @@ def build_all_lookup_tables(self, builder: Builder) -> None:
129147
distributions = list(raw_weights["parameter"].unique())
130148

131149
raw_weights = pivot_categorical(
132-
builder, self.risk, raw_weights, pivot_column="parameter", reset_index=False
150+
builder,
151+
self.exposure_component.entity,
152+
raw_weights,
153+
pivot_column="parameter",
154+
reset_index=False,
133155
)
134156

135157
weights, parameters = rd.EnsembleDistribution.get_parameters(
@@ -201,7 +223,7 @@ def ppf(self, quantiles: pd.Series) -> pd.Series:
201223
return x
202224

203225

204-
class ContinuousDistribution(RiskExposureDistribution):
226+
class ContinuousDistribution(ExposureDistribution):
205227
#####################
206228
# Lifecycle methods #
207229
#####################
@@ -261,12 +283,12 @@ def ppf(self, quantiles: pd.Series) -> pd.Series:
261283
return x
262284

263285

264-
class PolytomousDistribution(RiskExposureDistribution):
286+
class PolytomousDistribution(ExposureDistribution):
265287
@property
266288
def categories(self) -> list[str]:
267-
# These need to be sorted so the cumulative sum is in the ocrrect order of categories
289+
# These need to be sorted so the cumulative sum is in the correct order of categories
268290
# and results are therefore reproducible and correct
269-
return sorted(self.lookup_tables["exposure"].value_columns)
291+
return sorted(self.lookup_tables[self.exposure_component.exposure_type].value_columns)
270292

271293
#################
272294
# Setup methods #
@@ -277,9 +299,11 @@ def build_all_lookup_tables(self, builder: "Builder") -> None:
277299
exposure_value_columns = self.get_exposure_value_columns(exposure_data)
278300

279301
if isinstance(exposure_data, pd.DataFrame):
280-
exposure_data = pivot_categorical(builder, self.risk, exposure_data, "parameter")
302+
exposure_data = pivot_categorical(
303+
builder, self.exposure_component.entity, exposure_data, "parameter"
304+
)
281305

282-
self.lookup_tables["exposure"] = self.build_lookup_table(
306+
self.lookup_tables[self.exposure_component.exposure_type] = self.build_lookup_table(
283307
builder, exposure_data, exposure_value_columns
284308
)
285309

@@ -293,9 +317,11 @@ def get_exposure_value_columns(
293317
def get_exposure_parameter_pipeline(self, builder: Builder) -> Pipeline:
294318
return builder.value.register_value_producer(
295319
self.parameters_pipeline_name,
296-
source=self.lookup_tables["exposure"],
320+
source=self.lookup_tables[self.exposure_component.exposure_type],
297321
component=self,
298-
required_resources=get_lookup_columns([self.lookup_tables["exposure"]]),
322+
required_resources=get_lookup_columns(
323+
[self.lookup_tables[self.exposure_component.exposure_type]]
324+
),
299325
)
300326

301327
##################
@@ -313,12 +339,12 @@ def ppf(self, quantiles: pd.Series) -> pd.Series:
313339
).sum(axis=1)
314340
return pd.Series(
315341
np.array(self.categories)[category_index],
316-
name=self.risk + ".exposure",
342+
name=f"{self.exposure_component.entity}.exposure",
317343
index=quantiles.index,
318344
)
319345

320346

321-
class DichotomousDistribution(RiskExposureDistribution):
347+
class DichotomousDistribution(ExposureDistribution):
322348

323349
#################
324350
# Setup methods #
@@ -332,11 +358,15 @@ def build_all_lookup_tables(self, builder: "Builder") -> None:
332358
any_negatives = (exposure_data[exposure_value_columns] < 0).any().any()
333359
any_over_one = (exposure_data[exposure_value_columns] > 1).any().any()
334360
if any_negatives or any_over_one:
335-
raise ValueError(f"All exposures must be in the range [0, 1] for {self.risk}")
361+
raise ValueError(
362+
f"All exposures must be in the range [0, 1] for {self.exposure_component.entity}"
363+
)
336364
elif exposure_data < 0 or exposure_data > 1:
337-
raise ValueError(f"Exposure must be in the range [0, 1] for {self.risk}")
365+
raise ValueError(
366+
f"Exposure must be in the range [0, 1] for {self.exposure_component.entity}"
367+
)
338368

339-
self.lookup_tables["exposure"] = self.build_lookup_table(
369+
self.lookup_tables[self.exposure_component.exposure_type] = self.build_lookup_table(
340370
builder, exposure_data, exposure_value_columns
341371
)
342372
self.lookup_tables["paf"] = self.build_lookup_table(builder, 0.0)
@@ -350,20 +380,43 @@ def get_exposure_data(self, builder: Builder) -> int | float | pd.DataFrame:
350380
# rebin exposure categories
351381
self.validate_rebin_source(builder, exposure_data)
352382
rebin_exposed_categories = set(self.configuration["rebinned_exposed"])
383+
# Check whether risk factor exposure data still uses the deprecated 'cat1'/'cat2' category names
384+
if (
385+
"cat1" in exposure_data["parameter"].unique()
386+
and self.exposure_component.entity.type == "risk_factor"
387+
):
388+
warnings.warn(
389+
"Using 'cat1' and 'cat2' for dichotomous exposure is deprecated and will be removed in a future release. Use 'exposed' and 'unexposed' instead.",
390+
FutureWarning,
391+
stacklevel=2,
392+
)
393+
exposure_data["parameter"] = exposure_data["parameter"].replace(
394+
{
395+
"cat1": self.exposure_component.dichotomous_exposure_category_names.exposed,
396+
"cat2": self.exposure_component.dichotomous_exposure_category_names.unexposed,
397+
}
398+
)
353399
if rebin_exposed_categories:
354-
exposure_data = self._rebin_exposure_data(exposure_data, rebin_exposed_categories)
400+
exposure_data = self._rebin_exposure_data(
401+
exposure_data,
402+
rebin_exposed_categories,
403+
self.exposure_component.dichotomous_exposure_category_names.exposed,
404+
)
355405

356-
exposure_data = exposure_data[exposure_data["parameter"] == "cat1"]
406+
exposure_data = exposure_data[
407+
exposure_data["parameter"]
408+
== self.exposure_component.dichotomous_exposure_category_names.exposed
409+
]
357410
return exposure_data.drop(columns="parameter")
358411

359412
@staticmethod
360413
def _rebin_exposure_data(
361-
exposure_data: pd.DataFrame, rebin_exposed_categories: set
414+
exposure_data: pd.DataFrame, rebin_exposed_categories: set, exposed_category_name: str
362415
) -> pd.DataFrame:
363416
exposure_data = exposure_data[
364417
exposure_data["parameter"].isin(rebin_exposed_categories)
365418
]
366-
exposure_data["parameter"] = "cat1"
419+
exposure_data["parameter"] = exposed_category_name
367420
exposure_data = (
368421
exposure_data.groupby(list(exposure_data.columns.difference(["value"])))
369422
.sum()
@@ -382,7 +435,7 @@ def get_exposure_value_columns(
382435
def setup(self, builder: Builder) -> None:
383436
super().setup(builder)
384437
self.joint_paf = builder.value.register_value_producer(
385-
f"{self.risk}.exposure_parameters.paf",
438+
f"{self.exposure_component.entity}.exposure_parameters.paf",
386439
source=lambda index: [self.lookup_tables["paf"](index)],
387440
component=self,
388441
preferred_combiner=list_combiner,
@@ -391,10 +444,12 @@ def setup(self, builder: Builder) -> None:
391444

392445
def get_exposure_parameter_pipeline(self, builder: Builder) -> Pipeline:
393446
return builder.value.register_value_producer(
394-
f"{self.risk}.exposure_parameters",
447+
f"{self.exposure_component.entity}.exposure_parameters",
395448
source=self.exposure_parameter_source,
396449
component=self,
397-
required_resources=get_lookup_columns([self.lookup_tables["exposure"]]),
450+
required_resources=get_lookup_columns(
451+
[self.lookup_tables[self.exposure_component.exposure_type]]
452+
),
398453
)
399454

400455
##############
@@ -405,29 +460,31 @@ def validate_rebin_source(self, builder, data: pd.DataFrame) -> None:
405460
if not isinstance(data, pd.DataFrame):
406461
return
407462

408-
rebin_exposed_categories = set(builder.configuration[self.risk]["rebinned_exposed"])
463+
rebin_exposed_categories = set(
464+
builder.configuration[self.exposure_component.entity]["rebinned_exposed"]
465+
)
409466

410467
if (
411468
rebin_exposed_categories
412-
and builder.configuration[self.risk]["category_thresholds"]
469+
and builder.configuration[self.exposure_component.entity]["category_thresholds"]
413470
):
414471
raise ValueError(
415472
f"Rebinning and category thresholds are mutually exclusive. "
416-
f"You provided both for {self.risk.name}."
473+
f"You provided both for {self.exposure_component.entity.name}."
417474
)
418475

419476
invalid_cats = rebin_exposed_categories.difference(set(data.parameter))
420477
if invalid_cats:
421478
raise ValueError(
422479
f"The following provided categories for the rebinned exposed "
423-
f"category of {self.risk.name} are not found in the exposure data: "
480+
f"category of {self.exposure_component.entity.name} are not found in the exposure data: "
424481
f"{invalid_cats}."
425482
)
426483

427484
if rebin_exposed_categories == set(data.parameter):
428485
raise ValueError(
429486
f"The provided categories for the rebinned exposed category of "
430-
f"{self.risk.name} comprise all categories for the exposure data. "
487+
f"{self.exposure_component.entity.name} comprise all categories for the exposure data. "
431488
f"At least one category must be left out of the provided categories "
432489
f"to be rebinned into the unexposed category."
433490
)
@@ -437,7 +494,9 @@ def validate_rebin_source(self, builder, data: pd.DataFrame) -> None:
437494
##################################
438495

439496
def exposure_parameter_source(self, index: pd.Index) -> pd.Series:
440-
base_exposure = self.lookup_tables["exposure"](index).values
497+
base_exposure = self.lookup_tables[self.exposure_component.exposure_type](
498+
index
499+
).values
441500
joint_paf = self.joint_paf(index).values
442501
return pd.Series(base_exposure * (1 - joint_paf), index=index, name="values")
443502

@@ -448,8 +507,13 @@ def exposure_parameter_source(self, index: pd.Index) -> pd.Series:
448507
def ppf(self, quantiles: pd.Series) -> pd.Series:
449508
exposed = quantiles < self.exposure_parameters(quantiles.index)
450509
return pd.Series(
451-
exposed.replace({True: "cat1", False: "cat2"}),
452-
name=self.risk + ".exposure",
510+
exposed.replace(
511+
{
512+
True: self.exposure_component.dichotomous_exposure_category_names.exposed,
513+
False: self.exposure_component.dichotomous_exposure_category_names.unexposed,
514+
}
515+
),
516+
name=f"{self.exposure_component.entity}.{self.exposure_component.exposure_type}",
453517
index=quantiles.index,
454518
)
455519

0 commit comments

Comments
 (0)