Skip to content

Commit f123d9f

Browse files
authored
Albrja/mic-5451/Epic/update required resrouces (#519)
Albrja/mic-5451/Epic/update required resrouces Update required resources for all components - *Category*: Feature - *JIRA issue*: https://jira.ihme.washington.edu/browse/MIC-5451 Changes and notes -Update required resources for all components ### Testing <!-- Details on how code was verified, any unit tests local for the repo, regression testing, etc. At a minimum, this should include an integration test for a framework change. Consider: plots, images, (small) csv file. -->
1 parent 0508e04 commit f123d9f

22 files changed

+189
-138
lines changed

.github/CODEOWNERS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# default owners
2-
* @albrja @collijk @hussain-jafari @patricktnast @rmudambi @stevebachmeier
2+
* @albrja @hussain-jafari @patricktnast @rmudambi @stevebachmeier

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
**4.0.1 - 03/21/25**
2+
3+
- Feature: Update required resources for all components
4+
15
**4.0.0 - 03/17/25**
26

37
- Feature: Use birth exposure artifact key for LBWSG components

src/vivarium_public_health/disease/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ def setup(self, builder: Builder) -> None:
114114
builder.value.register_value_modifier(
115115
"cause_specific_mortality_rate",
116116
self.adjust_cause_specific_mortality_rate,
117-
requires_columns=["age", "sex"],
117+
component=self,
118+
required_resources=["age", "sex"],
118119
)
119120

120121
def on_initialize_simulants(self, pop_data: SimulantData) -> None:

src/vivarium_public_health/disease/special_disease.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from vivarium import Component
1717
from vivarium.framework.event import Event
1818
from vivarium.framework.population import SimulantData
19+
from vivarium.framework.resource import Resource
1920
from vivarium.framework.values import list_combiner, union_post_processor
2021

2122
from vivarium_public_health.disease.transition import TransitionString
@@ -123,12 +124,8 @@ def columns_required(self) -> list[str] | None:
123124
return ["alive"]
124125

125126
@property
126-
def initialization_requirements(self) -> dict[str, list[str]]:
127-
return {
128-
"requires_columns": [],
129-
"requires_values": [f"{self.risk.name}.exposure"],
130-
"requires_streams": [],
131-
}
127+
def initialization_requirements(self) -> list[str | Resource]:
128+
return [self.exposure_pipeline]
132129

133130
@property
134131
def state_names(self):
@@ -171,17 +168,21 @@ def setup(self, builder):
171168
self.disability_weight = builder.value.register_value_producer(
172169
f"{self.cause.name}.disability_weight",
173170
source=self.compute_disability_weight,
174-
requires_columns=get_lookup_columns(
171+
component=self,
172+
required_resources=get_lookup_columns(
175173
[self.lookup_tables["raw_disability_weight"]]
176174
),
177175
)
178176
builder.value.register_value_modifier(
179-
"all_causes.disability_weight", modifier=self.disability_weight
177+
"all_causes.disability_weight",
178+
modifier=self.disability_weight,
179+
component=self,
180180
)
181181
builder.value.register_value_modifier(
182182
"cause_specific_mortality_rate",
183183
self.adjust_cause_specific_mortality_rate,
184-
requires_columns=get_lookup_columns(
184+
component=self,
185+
required_resources=get_lookup_columns(
185186
[self.lookup_tables["cause_specific_mortality_rate"]]
186187
),
187188
)
@@ -190,29 +191,32 @@ def setup(self, builder):
190191
self.excess_mortality_rate = builder.value.register_value_producer(
191192
self.excess_mortality_rate_pipeline_name,
192193
source=self.compute_excess_mortality_rate,
193-
requires_columns=get_lookup_columns(
194+
component=self,
195+
required_resources=get_lookup_columns(
194196
[self.lookup_tables["excess_mortality_rate"]]
195-
),
196-
requires_values=[self.excess_mortality_rate_paf_pipeline_name],
197+
)
198+
+ [self.joint_paf],
197199
)
198200
self.joint_paf = builder.value.register_value_producer(
199201
self.excess_mortality_rate_paf_pipeline_name,
200202
source=lambda idx: [self.lookup_tables["population_attributable_fraction"](idx)],
203+
component=self,
201204
preferred_combiner=list_combiner,
202205
preferred_post_processor=union_post_processor,
203206
)
204207
builder.value.register_value_modifier(
205208
"mortality_rate",
206209
modifier=self.adjust_mortality_rate,
207-
requires_values=[self.excess_mortality_rate_pipeline_name],
210+
component=self,
211+
required_resources=[self.excess_mortality_rate],
208212
)
209213

210214
distribution = builder.data.load(f"{self.risk}.distribution")
211-
exposure_pipeline = builder.value.get_value(f"{self.risk.name}.exposure")
215+
self.exposure_pipeline = builder.value.get_value(f"{self.risk.name}.exposure")
212216
threshold = builder.configuration[self.name].threshold
213217

214218
self.filter_by_exposure = self.get_exposure_filter(
215-
distribution, exposure_pipeline, threshold
219+
distribution, self.exposure_pipeline, threshold
216220
)
217221

218222
#################

src/vivarium_public_health/disease/state.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from vivarium.framework.engine import Builder
1616
from vivarium.framework.population import PopulationView, SimulantData
1717
from vivarium.framework.randomness import RandomnessStream
18+
from vivarium.framework.resource import Resource
1819
from vivarium.framework.state_machine import State, Transient, Transition, Trigger
1920
from vivarium.framework.values import Pipeline, list_combiner, union_post_processor
2021
from vivarium.types import DataInput, LookupTableData
@@ -57,12 +58,8 @@ def columns_required(self) -> list[str] | None:
5758
return [self.model, "alive"]
5859

5960
@property
60-
def initialization_requirements(self) -> dict[str, list[str]]:
61-
return {
62-
"requires_columns": [self.model],
63-
"requires_values": [],
64-
"requires_streams": [],
65-
}
61+
def initialization_requirements(self) -> list[str | Resource]:
62+
return [self.model]
6663

6764
#####################
6865
# Lifecycle methods #
@@ -327,6 +324,10 @@ class DiseaseState(BaseDiseaseState):
327324
# Properties #
328325
##############
329326

327+
@property
328+
def initialization_requirements(self) -> list[str | Resource]:
329+
return super().initialization_requirements + [self.randomness_prevalence]
330+
330331
@property
331332
def configuration_defaults(self) -> dict[str, Any]:
332333
configuration_defaults = super().configuration_defaults
@@ -465,19 +466,22 @@ def setup(self, builder: Builder) -> None:
465466
self.disability_weight = self.get_disability_weight_pipeline(builder)
466467

467468
builder.value.register_value_modifier(
468-
"all_causes.disability_weight", modifier=self.disability_weight
469+
"all_causes.disability_weight",
470+
modifier=self.disability_weight,
471+
component=self,
469472
)
470473

471474
self.has_excess_mortality = is_non_zero(
472475
self.lookup_tables["excess_mortality_rate"].data
473476
)
474-
self.excess_mortality_rate = self.get_excess_mortality_rate_pipeline(builder)
475477
self.joint_paf = self.get_joint_paf(builder)
478+
self.excess_mortality_rate = self.get_excess_mortality_rate_pipeline(builder)
476479

477480
builder.value.register_value_modifier(
478481
"mortality_rate",
479482
modifier=self.adjust_mortality_rate,
480-
requires_values=[self.excess_mortality_rate_pipeline_name],
483+
component=self,
484+
required_resources=[self.excess_mortality_rate],
481485
)
482486

483487
self.randomness_prevalence = self.get_randomness_prevalence(builder)
@@ -531,7 +535,8 @@ def get_dwell_time_pipeline(self, builder: Builder) -> Pipeline:
531535
return builder.value.register_value_producer(
532536
f"{self.state_id}.dwell_time",
533537
source=self.lookup_tables["dwell_time"],
534-
requires_columns=required_columns,
538+
component=self,
539+
required_resources=required_columns,
535540
)
536541

537542
def get_disability_weight_source(self, disability_weight: DataInput | None) -> DataInput:
@@ -556,7 +561,8 @@ def get_disability_weight_pipeline(self, builder: Builder) -> Pipeline:
556561
return builder.value.register_value_producer(
557562
f"{self.state_id}.disability_weight",
558563
source=self.compute_disability_weight,
559-
requires_columns=lookup_columns + ["alive", self.model],
564+
component=self,
565+
required_resources=lookup_columns + ["alive", self.model],
560566
)
561567

562568
def get_excess_mortality_rate_source(
@@ -583,21 +589,24 @@ def get_excess_mortality_rate_pipeline(self, builder: Builder) -> Pipeline:
583589
return builder.value.register_rate_producer(
584590
self.excess_mortality_rate_pipeline_name,
585591
source=self.compute_excess_mortality_rate,
586-
requires_columns=lookup_columns + ["alive", self.model],
587-
requires_values=[self.excess_mortality_rate_paf_pipeline_name],
592+
component=self,
593+
required_resources=lookup_columns + ["alive", self.model, self.joint_paf],
588594
)
589595

590596
def get_joint_paf(self, builder: Builder) -> Pipeline:
591597
paf = builder.lookup.build_table(0)
592598
return builder.value.register_value_producer(
593599
self.excess_mortality_rate_paf_pipeline_name,
594600
source=lambda idx: [paf(idx)],
601+
component=self,
595602
preferred_combiner=list_combiner,
596603
preferred_post_processor=union_post_processor,
597604
)
598605

599606
def get_randomness_prevalence(self, builder: Builder) -> RandomnessStream:
600-
return builder.randomness.get_stream(f"{self.state_id}_prevalent_cases")
607+
return builder.randomness.get_stream(
608+
f"{self.state_id}_prevalent_cases", component=self
609+
)
601610

602611
##################
603612
# Public methods #

src/vivarium_public_health/disease/transition.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,19 +129,20 @@ def __init__(
129129
# noinspection PyAttributeOutsideInit
130130
def setup(self, builder: Builder) -> None:
131131
lookup_columns = get_lookup_columns([self.lookup_tables["transition_rate"]])
132-
self.transition_rate = builder.value.register_rate_producer(
133-
self.transition_rate_pipeline_name,
134-
source=self.compute_transition_rate,
135-
requires_columns=lookup_columns + ["alive"],
136-
requires_values=[f"{self.transition_rate_pipeline_name}.paf"],
137-
)
138132
paf = builder.lookup.build_table(0)
139133
self.joint_paf = builder.value.register_value_producer(
140134
f"{self.transition_rate_pipeline_name}.paf",
141135
source=lambda index: [paf(index)],
136+
component=self,
142137
preferred_combiner=list_combiner,
143138
preferred_post_processor=union_post_processor,
144139
)
140+
self.transition_rate = builder.value.register_rate_producer(
141+
self.transition_rate_pipeline_name,
142+
source=self.compute_transition_rate,
143+
component=self,
144+
required_resources=lookup_columns + ["alive", self.joint_paf],
145+
)
145146

146147
#################
147148
# Setup methods #

src/vivarium_public_health/mslt/delay.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from vivarium.framework.engine import Builder
1717
from vivarium.framework.event import Event
1818
from vivarium.framework.population import SimulantData
19+
from vivarium.framework.resource import Resource
1920

2021

2122
class DelayedRisk(Component):
@@ -111,12 +112,8 @@ def columns_required(self) -> list[str] | None:
111112
return ["age", "sex", "population"]
112113

113114
@property
114-
def initialization_requirements(self) -> dict[str, list[str]]:
115-
return {
116-
"requires_columns": ["age", "sex", "population"],
117-
"requires_values": [],
118-
"requires_streams": [],
119-
}
115+
def initialization_requirements(self) -> list[str | Resource]:
116+
return ["age", "sex", "population"]
120117

121118
#####################
122119
# Lifecycle methods #
@@ -168,9 +165,11 @@ def setup(self, builder: Builder) -> None:
168165
)
169166
inc_name = "{}.incidence".format(self.risk)
170167
inc_int_name = "{}_intervention.incidence".format(self.risk)
171-
self.incidence = builder.value.register_rate_producer(inc_name, source=inc_data)
168+
self.incidence = builder.value.register_rate_producer(
169+
inc_name, source=inc_data, component=self
170+
)
172171
self.int_incidence = builder.value.register_rate_producer(
173-
inc_int_name, source=inc_data
172+
inc_int_name, source=inc_data, component=self
174173
)
175174

176175
# Load the remission rates for the BAU and intervention scenarios.
@@ -183,9 +182,11 @@ def setup(self, builder: Builder) -> None:
183182
)
184183
rem_name = "{}.remission".format(self.risk)
185184
rem_int_name = "{}_intervention.remission".format(self.risk)
186-
self.remission = builder.value.register_rate_producer(rem_name, source=rem_data)
185+
self.remission = builder.value.register_rate_producer(
186+
rem_name, source=rem_data, component=self
187+
)
187188
self.int_remission = builder.value.register_rate_producer(
188-
rem_int_name, source=rem_data
189+
rem_int_name, source=rem_data, component=self
189190
)
190191

191192
# We apply separate mortality rates to the different exposure bins.
@@ -255,6 +256,7 @@ def setup(self, builder: Builder) -> None:
255256
source=builder.lookup.build_table(
256257
mortality_data, key_columns=["sex"], parameter_columns=["age", "year"]
257258
),
259+
component=self,
258260
)
259261

260262
#################
@@ -312,7 +314,7 @@ def register_modifier(self, builder: Builder, disease: str) -> None:
312314
for template in rate_templates:
313315
rate_name = template.format(disease)
314316
modifier = lambda ix, rate: self.incidence_adjustment(disease, ix, rate)
315-
builder.value.register_value_modifier(rate_name, modifier)
317+
builder.value.register_value_modifier(rate_name, modifier, component=self)
316318

317319
########################
318320
# Event-driven methods #

0 commit comments

Comments
 (0)