1515from vivarium .framework .engine import Builder
1616from vivarium .framework .population import PopulationView , SimulantData
1717from vivarium .framework .randomness import RandomnessStream
18+ from vivarium .framework .resource import Resource
1819from vivarium .framework .state_machine import State , Transient , Transition , Trigger
1920from vivarium .framework .values import Pipeline , list_combiner , union_post_processor
2021from vivarium .types import DataInput , LookupTableData
@@ -57,12 +58,8 @@ def columns_required(self) -> list[str] | None:
5758 return [self .model , "alive" ]
5859
5960 @property
60- def initialization_requirements (self ) -> dict [str , list [str ]]:
61- return {
62- "requires_columns" : [self .model ],
63- "requires_values" : [],
64- "requires_streams" : [],
65- }
61+ def initialization_requirements (self ) -> list [str | Resource ]:
62+ return [self .model ]
6663
6764 #####################
6865 # Lifecycle methods #
@@ -327,6 +324,10 @@ class DiseaseState(BaseDiseaseState):
327324 # Properties #
328325 ##############
329326
327+ @property
328+ def initialization_requirements (self ) -> list [str | Resource ]:
329+ return super ().initialization_requirements + [self .randomness_prevalence ]
330+
330331 @property
331332 def configuration_defaults (self ) -> dict [str , Any ]:
332333 configuration_defaults = super ().configuration_defaults
@@ -465,19 +466,22 @@ def setup(self, builder: Builder) -> None:
465466 self .disability_weight = self .get_disability_weight_pipeline (builder )
466467
467468 builder .value .register_value_modifier (
468- "all_causes.disability_weight" , modifier = self .disability_weight
469+ "all_causes.disability_weight" ,
470+ modifier = self .disability_weight ,
471+ component = self ,
469472 )
470473
471474 self .has_excess_mortality = is_non_zero (
472475 self .lookup_tables ["excess_mortality_rate" ].data
473476 )
474- self .excess_mortality_rate = self .get_excess_mortality_rate_pipeline (builder )
475477 self .joint_paf = self .get_joint_paf (builder )
478+ self .excess_mortality_rate = self .get_excess_mortality_rate_pipeline (builder )
476479
477480 builder .value .register_value_modifier (
478481 "mortality_rate" ,
479482 modifier = self .adjust_mortality_rate ,
480- requires_values = [self .excess_mortality_rate_pipeline_name ],
483+ component = self ,
484+ required_resources = [self .excess_mortality_rate ],
481485 )
482486
483487 self .randomness_prevalence = self .get_randomness_prevalence (builder )
@@ -531,7 +535,8 @@ def get_dwell_time_pipeline(self, builder: Builder) -> Pipeline:
531535 return builder .value .register_value_producer (
532536 f"{ self .state_id } .dwell_time" ,
533537 source = self .lookup_tables ["dwell_time" ],
534- requires_columns = required_columns ,
538+ component = self ,
539+ required_resources = required_columns ,
535540 )
536541
537542 def get_disability_weight_source (self , disability_weight : DataInput | None ) -> DataInput :
@@ -556,7 +561,8 @@ def get_disability_weight_pipeline(self, builder: Builder) -> Pipeline:
556561 return builder .value .register_value_producer (
557562 f"{ self .state_id } .disability_weight" ,
558563 source = self .compute_disability_weight ,
559- requires_columns = lookup_columns + ["alive" , self .model ],
564+ component = self ,
565+ required_resources = lookup_columns + ["alive" , self .model ],
560566 )
561567
562568 def get_excess_mortality_rate_source (
@@ -583,21 +589,24 @@ def get_excess_mortality_rate_pipeline(self, builder: Builder) -> Pipeline:
583589 return builder .value .register_rate_producer (
584590 self .excess_mortality_rate_pipeline_name ,
585591 source = self .compute_excess_mortality_rate ,
586- requires_columns = lookup_columns + [ "alive" , self . model ] ,
587- requires_values = [ self .excess_mortality_rate_paf_pipeline_name ],
592+ component = self ,
593+ required_resources = lookup_columns + [ "alive" , self .model , self . joint_paf ],
588594 )
589595
590596 def get_joint_paf (self , builder : Builder ) -> Pipeline :
591597 paf = builder .lookup .build_table (0 )
592598 return builder .value .register_value_producer (
593599 self .excess_mortality_rate_paf_pipeline_name ,
594600 source = lambda idx : [paf (idx )],
601+ component = self ,
595602 preferred_combiner = list_combiner ,
596603 preferred_post_processor = union_post_processor ,
597604 )
598605
599606 def get_randomness_prevalence (self , builder : Builder ) -> RandomnessStream :
600- return builder .randomness .get_stream (f"{ self .state_id } _prevalent_cases" )
607+ return builder .randomness .get_stream (
608+ f"{ self .state_id } _prevalent_cases" , component = self
609+ )
601610
602611 ##################
603612 # Public methods #
0 commit comments