11"""
22=================================
3- Risk Exposure Distribution Models
3+ Exposure Distribution Models
44=================================
55
66This module contains tools for modeling several different risk
77exposure distributions.
88
99"""
1010
11+ from __future__ import annotations
12+
13+ import warnings
1114from abc import ABC , abstractmethod
1215from collections .abc import Callable
16+ from typing import TYPE_CHECKING
1317
1418import numpy as np
1519import pandas as pd
2428from vivarium_public_health .risks .data_transformations import pivot_categorical
2529from vivarium_public_health .utilities import EntityString , get_lookup_columns
2630
31+ if TYPE_CHECKING :
32+ from vivarium_public_health .exposure import Exposure
33+
2734
2835class MissingDataError (Exception ):
2936 pass
3037
3138
32- class RiskExposureDistribution (Component , ABC ):
39+ class ExposureDistribution (Component , ABC ):
3340
3441 #####################
3542 # Lifecycle methods #
3643 #####################
3744
3845 def __init__ (
3946 self ,
40- risk : EntityString ,
47+ exposure_component : Exposure ,
4148 distribution_type : str ,
4249 exposure_data : int | float | pd .DataFrame | None = None ,
4350 ) -> None :
4451 super ().__init__ ()
45- self .risk = risk
52+ self .exposure_component = exposure_component
4653 self .distribution_type = distribution_type
54+ if (
55+ self .distribution_type != "dichotomous"
56+ and self .exposure_component .entity .type == "intervention"
57+ ):
58+ raise NotImplementedError (
59+ f"Distribution type { self .distribution_type } is not supported for interventions."
60+ )
4761 self ._exposure_data = exposure_data
4862
49- self .parameters_pipeline_name = f"{ self .risk } .exposure_parameters"
63+ self .parameters_pipeline_name = (
64+ f"{ self .exposure_component .entity } .exposure_parameters"
65+ )
5066
5167 #################
5268 # Setup methods #
5369 #################
5470
5571 def get_configuration (self , builder : "Builder" ) -> LayeredConfigTree | None :
56- return builder .configuration [self .risk ]
72+ return builder .configuration [self .exposure_component . entity ]
5773
5874 @abstractmethod
5975 def build_all_lookup_tables (self , builder : "Builder" ) -> None :
@@ -62,7 +78,9 @@ def build_all_lookup_tables(self, builder: "Builder") -> None:
6278 def get_exposure_data (self , builder : Builder ) -> int | float | pd .DataFrame :
6379 if self ._exposure_data is not None :
6480 return self ._exposure_data
65- return self .get_data (builder , self .configuration ["data_sources" ]["exposure" ])
81+ return self .get_data (
82+ builder , self .configuration ["data_sources" ][self .exposure_component .exposure_type ]
83+ )
6684
6785 # noinspection PyAttributeOutsideInit
6886 def setup (self , builder : Builder ) -> None :
@@ -87,7 +105,7 @@ def ppf(self, quantiles: pd.Series) -> pd.Series:
87105 raise NotImplementedError
88106
89107
90- class EnsembleDistribution (RiskExposureDistribution ):
108+ class EnsembleDistribution (ExposureDistribution ):
91109 ##############
92110 # Properties #
93111 ##############
@@ -106,7 +124,7 @@ def initialization_requirements(self) -> list[str | Resource]:
106124
107125 def __init__ (self , risk : EntityString , distribution_type : str = "ensemble" ) -> None :
108126 super ().__init__ (risk , distribution_type )
109- self ._propensity = f"ensemble_propensity_{ self .risk } "
127+ self ._propensity = f"ensemble_propensity_{ self .exposure_component . entity } "
110128
111129 #################
112130 # Setup methods #
@@ -129,7 +147,11 @@ def build_all_lookup_tables(self, builder: Builder) -> None:
129147 distributions = list (raw_weights ["parameter" ].unique ())
130148
131149 raw_weights = pivot_categorical (
132- builder , self .risk , raw_weights , pivot_column = "parameter" , reset_index = False
150+ builder ,
151+ self .exposure_component .entity ,
152+ raw_weights ,
153+ pivot_column = "parameter" ,
154+ reset_index = False ,
133155 )
134156
135157 weights , parameters = rd .EnsembleDistribution .get_parameters (
@@ -201,7 +223,7 @@ def ppf(self, quantiles: pd.Series) -> pd.Series:
201223 return x
202224
203225
204- class ContinuousDistribution (RiskExposureDistribution ):
226+ class ContinuousDistribution (ExposureDistribution ):
205227 #####################
206228 # Lifecycle methods #
207229 #####################
@@ -261,12 +283,12 @@ def ppf(self, quantiles: pd.Series) -> pd.Series:
261283 return x
262284
263285
264- class PolytomousDistribution (RiskExposureDistribution ):
286+ class PolytomousDistribution (ExposureDistribution ):
265287 @property
266288 def categories (self ) -> list [str ]:
267- # These need to be sorted so the cumulative sum is in the ocrrect order of categories
289+ # These need to be sorted so the cumulative sum is in the correct order of categories
268290 # and results are therefore reproducible and correct
269- return sorted (self .lookup_tables ["exposure" ].value_columns )
291+ return sorted (self .lookup_tables [self . exposure_component . exposure_type ].value_columns )
270292
271293 #################
272294 # Setup methods #
@@ -277,9 +299,11 @@ def build_all_lookup_tables(self, builder: "Builder") -> None:
277299 exposure_value_columns = self .get_exposure_value_columns (exposure_data )
278300
279301 if isinstance (exposure_data , pd .DataFrame ):
280- exposure_data = pivot_categorical (builder , self .risk , exposure_data , "parameter" )
302+ exposure_data = pivot_categorical (
303+ builder , self .exposure_component .entity , exposure_data , "parameter"
304+ )
281305
282- self .lookup_tables ["exposure" ] = self .build_lookup_table (
306+ self .lookup_tables [self . exposure_component . exposure_type ] = self .build_lookup_table (
283307 builder , exposure_data , exposure_value_columns
284308 )
285309
@@ -293,9 +317,11 @@ def get_exposure_value_columns(
293317 def get_exposure_parameter_pipeline (self , builder : Builder ) -> Pipeline :
294318 return builder .value .register_value_producer (
295319 self .parameters_pipeline_name ,
296- source = self .lookup_tables ["exposure" ],
320+ source = self .lookup_tables [self . exposure_component . exposure_type ],
297321 component = self ,
298- required_resources = get_lookup_columns ([self .lookup_tables ["exposure" ]]),
322+ required_resources = get_lookup_columns (
323+ [self .lookup_tables [self .exposure_component .exposure_type ]]
324+ ),
299325 )
300326
301327 ##################
@@ -313,12 +339,12 @@ def ppf(self, quantiles: pd.Series) -> pd.Series:
313339 ).sum (axis = 1 )
314340 return pd .Series (
315341 np .array (self .categories )[category_index ],
316- name = self .risk + " .exposure" ,
342+ name = f" { self .exposure_component . entity } .exposure" ,
317343 index = quantiles .index ,
318344 )
319345
320346
321- class DichotomousDistribution (RiskExposureDistribution ):
347+ class DichotomousDistribution (ExposureDistribution ):
322348
323349 #################
324350 # Setup methods #
@@ -332,11 +358,15 @@ def build_all_lookup_tables(self, builder: "Builder") -> None:
332358 any_negatives = (exposure_data [exposure_value_columns ] < 0 ).any ().any ()
333359 any_over_one = (exposure_data [exposure_value_columns ] > 1 ).any ().any ()
334360 if any_negatives or any_over_one :
335- raise ValueError (f"All exposures must be in the range [0, 1] for { self .risk } " )
361+ raise ValueError (
362+ f"All exposures must be in the range [0, 1] for { self .exposure_component .entity } "
363+ )
336364 elif exposure_data < 0 or exposure_data > 1 :
337- raise ValueError (f"Exposure must be in the range [0, 1] for { self .risk } " )
365+ raise ValueError (
366+ f"Exposure must be in the range [0, 1] for { self .exposure_component .entity } "
367+ )
338368
339- self .lookup_tables ["exposure" ] = self .build_lookup_table (
369+ self .lookup_tables [self . exposure_component . exposure_type ] = self .build_lookup_table (
340370 builder , exposure_data , exposure_value_columns
341371 )
342372 self .lookup_tables ["paf" ] = self .build_lookup_table (builder , 0.0 )
@@ -350,20 +380,43 @@ def get_exposure_data(self, builder: Builder) -> int | float | pd.DataFrame:
350380 # rebin exposure categories
351381 self .validate_rebin_source (builder , exposure_data )
352382 rebin_exposed_categories = set (self .configuration ["rebinned_exposed" ])
383+ # Check if risk exposure is exposed vs cat1
384+ if (
385+ "cat1" in exposure_data ["parameter" ].unique ()
386+ and self .exposure_component .entity .type == "risk_factor"
387+ ):
388+ warnings .warn (
389+ "Using 'cat1' and 'cat2' for dichotomous exposure is deprecated and will be removed in a future release. Use 'exposed' and 'unexposed' instead." ,
390+ FutureWarning ,
391+ stacklevel = 2 ,
392+ )
393+ exposure_data ["parameter" ] = exposure_data ["parameter" ].replace (
394+ {
395+ "cat1" : self .exposure_component .dichotomous_exposure_category_names .exposed ,
396+ "cat2" : self .exposure_component .dichotomous_exposure_category_names .unexposed ,
397+ }
398+ )
353399 if rebin_exposed_categories :
354- exposure_data = self ._rebin_exposure_data (exposure_data , rebin_exposed_categories )
400+ exposure_data = self ._rebin_exposure_data (
401+ exposure_data ,
402+ rebin_exposed_categories ,
403+ self .exposure_component .dichotomous_exposure_category_names .exposed ,
404+ )
355405
356- exposure_data = exposure_data [exposure_data ["parameter" ] == "cat1" ]
406+ exposure_data = exposure_data [
407+ exposure_data ["parameter" ]
408+ == self .exposure_component .dichotomous_exposure_category_names .exposed
409+ ]
357410 return exposure_data .drop (columns = "parameter" )
358411
359412 @staticmethod
360413 def _rebin_exposure_data (
361- exposure_data : pd .DataFrame , rebin_exposed_categories : set
414+ exposure_data : pd .DataFrame , rebin_exposed_categories : set , exposed_category_name : str
362415 ) -> pd .DataFrame :
363416 exposure_data = exposure_data [
364417 exposure_data ["parameter" ].isin (rebin_exposed_categories )
365418 ]
366- exposure_data ["parameter" ] = "cat1"
419+ exposure_data ["parameter" ] = exposed_category_name
367420 exposure_data = (
368421 exposure_data .groupby (list (exposure_data .columns .difference (["value" ])))
369422 .sum ()
@@ -382,7 +435,7 @@ def get_exposure_value_columns(
382435 def setup (self , builder : Builder ) -> None :
383436 super ().setup (builder )
384437 self .joint_paf = builder .value .register_value_producer (
385- f"{ self .risk } .exposure_parameters.paf" ,
438+ f"{ self .exposure_component . entity } .exposure_parameters.paf" ,
386439 source = lambda index : [self .lookup_tables ["paf" ](index )],
387440 component = self ,
388441 preferred_combiner = list_combiner ,
@@ -391,10 +444,12 @@ def setup(self, builder: Builder) -> None:
391444
392445 def get_exposure_parameter_pipeline (self , builder : Builder ) -> Pipeline :
393446 return builder .value .register_value_producer (
394- f"{ self .risk } .exposure_parameters" ,
447+ f"{ self .exposure_component . entity } .exposure_parameters" ,
395448 source = self .exposure_parameter_source ,
396449 component = self ,
397- required_resources = get_lookup_columns ([self .lookup_tables ["exposure" ]]),
450+ required_resources = get_lookup_columns (
451+ [self .lookup_tables [self .exposure_component .exposure_type ]]
452+ ),
398453 )
399454
400455 ##############
@@ -405,29 +460,31 @@ def validate_rebin_source(self, builder, data: pd.DataFrame) -> None:
405460 if not isinstance (data , pd .DataFrame ):
406461 return
407462
408- rebin_exposed_categories = set (builder .configuration [self .risk ]["rebinned_exposed" ])
463+ rebin_exposed_categories = set (
464+ builder .configuration [self .exposure_component .entity ]["rebinned_exposed" ]
465+ )
409466
410467 if (
411468 rebin_exposed_categories
412- and builder .configuration [self .risk ]["category_thresholds" ]
469+ and builder .configuration [self .exposure_component . entity ]["category_thresholds" ]
413470 ):
414471 raise ValueError (
415472 f"Rebinning and category thresholds are mutually exclusive. "
416- f"You provided both for { self .risk .name } ."
473+ f"You provided both for { self .exposure_component . entity .name } ."
417474 )
418475
419476 invalid_cats = rebin_exposed_categories .difference (set (data .parameter ))
420477 if invalid_cats :
421478 raise ValueError (
422479 f"The following provided categories for the rebinned exposed "
423- f"category of { self .risk .name } are not found in the exposure data: "
480+ f"category of { self .exposure_component . entity .name } are not found in the exposure data: "
424481 f"{ invalid_cats } ."
425482 )
426483
427484 if rebin_exposed_categories == set (data .parameter ):
428485 raise ValueError (
429486 f"The provided categories for the rebinned exposed category of "
430- f"{ self .risk .name } comprise all categories for the exposure data. "
487+ f"{ self .exposure_component . entity .name } comprise all categories for the exposure data. "
431488 f"At least one category must be left out of the provided categories "
432489 f"to be rebinned into the unexposed category."
433490 )
@@ -437,7 +494,9 @@ def validate_rebin_source(self, builder, data: pd.DataFrame) -> None:
437494 ##################################
438495
439496 def exposure_parameter_source (self , index : pd .Index ) -> pd .Series :
440- base_exposure = self .lookup_tables ["exposure" ](index ).values
497+ base_exposure = self .lookup_tables [self .exposure_component .exposure_type ](
498+ index
499+ ).values
441500 joint_paf = self .joint_paf (index ).values
442501 return pd .Series (base_exposure * (1 - joint_paf ), index = index , name = "values" )
443502
@@ -448,8 +507,13 @@ def exposure_parameter_source(self, index: pd.Index) -> pd.Series:
448507 def ppf (self , quantiles : pd .Series ) -> pd .Series :
449508 exposed = quantiles < self .exposure_parameters (quantiles .index )
450509 return pd .Series (
451- exposed .replace ({True : "cat1" , False : "cat2" }),
452- name = self .risk + ".exposure" ,
510+ exposed .replace (
511+ {
512+ True : self .exposure_component .dichotomous_exposure_category_names .exposed ,
513+ False : self .exposure_component .dichotomous_exposure_category_names .unexposed ,
514+ }
515+ ),
516+ name = f"{ self .exposure_component .entity } .{ self .exposure_component .exposure_type } " ,
453517 index = quantiles .index ,
454518 )
455519
0 commit comments