Skip to content

Commit d583dc6

Browse files
authored
Albrja/mic-6068/rate-to-probability-configuration (#521)
Albrja/mic-6068/rate-to-probability-configuration Update RateTransition to have rate conversion type configurable - *Category*: Feature - *JIRA issue*: https://jira.ihme.washington.edu/browse/MIC-6068 Changes and notes -Update RateTransition to have rate conversion type configurable ### Testing <!-- Details on how code was verified, any unit tests local for the repo, regression testing, etc. At a minimum, this should include an integration test for a framework change. Consider: plots, images, (small) csv file. -->
1 parent 6b4c03f commit d583dc6

File tree

5 files changed

+129
-3
lines changed

5 files changed

+129
-3
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
**4.1.0 - 05/20/25**
2+
3+
- Feature: Update RateTransition to have configuration key for rate conversion type
4+
15
**4.0.2 - 05/01/25**
26

37
- Bugfix: Fix configuration key for untracking age in FertilityCrudeBirthRate

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
long_description = f.read()
4343

4444
install_requirements = [
45-
"vivarium>=3.2.3",
45+
"vivarium>=3.4.0",
4646
"layered_config_tree>=1.0.1",
4747
"loguru",
4848
"numpy<2.0.0",

src/vivarium_public_health/disease/model.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,16 @@
1414
from typing import Any
1515

1616
import pandas as pd
17+
from layered_config_tree import ConfigurationError
1718
from vivarium.framework.engine import Builder
19+
from vivarium.framework.event import Event
1820
from vivarium.framework.population import SimulantData
1921
from vivarium.framework.state_machine import Machine
2022
from vivarium.types import DataInput, LookupTableData
2123

2224
from vivarium_public_health.disease.exceptions import DiseaseModelError
2325
from vivarium_public_health.disease.state import BaseDiseaseState, SusceptibleState
24-
from vivarium_public_health.disease.transition import TransitionString
26+
from vivarium_public_health.disease.transition import RateTransition, TransitionString
2527

2628

2729
class DiseaseModel(Machine):
@@ -118,6 +120,18 @@ def setup(self, builder: Builder) -> None:
118120
required_resources=["age", "sex"],
119121
)
120122

123+
def on_post_setup(self, event: Event) -> None:
124+
conversion_types = set()
125+
for state in self.states:
126+
for transition in state.transition_set.transitions:
127+
if isinstance(transition, RateTransition):
128+
conversion_types.add(transition.rate_conversion_type)
129+
if len(conversion_types) > 1:
130+
raise ConfigurationError(
131+
"All transitions in a disease model must have the same rate conversion type."
132+
f" Found: {conversion_types}."
133+
)
134+
121135
def on_initialize_simulants(self, pop_data: SimulantData) -> None:
122136
"""Initialize the simulants in the population.
123137

src/vivarium_public_health/disease/transition.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def configuration_defaults(self) -> dict[str, Any]:
4848
"data_sources": {
4949
"transition_rate": self._rate_source,
5050
},
51+
"rate_conversion_type": "linear",
5152
},
5253
}
5354

@@ -143,6 +144,7 @@ def setup(self, builder: Builder) -> None:
143144
component=self,
144145
required_resources=lookup_columns + ["alive", self.joint_paf],
145146
)
147+
self.rate_conversion_type = self.configuration["rate_conversion_type"]
146148

147149
#################
148150
# Setup methods #
@@ -184,7 +186,12 @@ def compute_transition_rate(self, index: pd.Index) -> pd.Series:
184186
##################
185187

186188
def _probability(self, index: pd.Index) -> pd.Series:
187-
return pd.Series(rate_to_probability(self.transition_rate(index)))
189+
return pd.Series(
190+
rate_to_probability(
191+
self.transition_rate(index),
192+
rate_conversion_type=self.rate_conversion_type,
193+
)
194+
)
188195

189196

190197
class ProportionTransition(Transition):

tests/disease/test_disease.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from unittest.mock import patch
2+
13
import numpy as np
24
import pandas as pd
35
import pytest
6+
from layered_config_tree import ConfigurationError, LayeredConfigTree
47
from vivarium import Component, InteractiveContext
58
from vivarium.framework.state_machine import Transition
69
from vivarium.framework.utilities import from_yearly
@@ -554,3 +557,101 @@ def test_artifact_transition_keys(mocker, disease):
554557
# check remission rate
555558
remissive_transition = with_condition.add_rate_transition(healthy)
556559
assert remissive_transition._rate_source == f"cause.{cause}.remission_rate"
560+
561+
562+
@pytest.mark.parametrize("rate_conversion_type", ["linear", "exponential"])
563+
def test_transition_rate_to_probability_configuration(
564+
base_config: LayeredConfigTree,
565+
base_plugins: LayeredConfigTree,
566+
disease: str,
567+
rate_conversion_type: str,
568+
):
569+
"""
570+
Test that the transition rate to probability configuration is set correctly.
571+
"""
572+
healthy = BaseDiseaseState("healthy")
573+
sick = DiseaseState("sick")
574+
key = "sequela.acute_myocardial_infarction_first_2_days.incidence_rate"
575+
transition = RateTransition(
576+
input_state=healthy,
577+
output_state=sick,
578+
get_data_functions={"incidence_rate": lambda builder, _: builder.data.load(key)},
579+
)
580+
healthy.transition_set.append(transition)
581+
model = DiseaseModel(disease, initial_state=healthy, states=[healthy, sick])
582+
583+
base_config.update(
584+
{
585+
f"{transition.name}": {
586+
"rate_conversion_type": rate_conversion_type,
587+
}
588+
}
589+
)
590+
591+
# Sets the configuration
592+
sim = InteractiveContext(
593+
components=[TestPopulation(), model],
594+
configuration=base_config,
595+
plugin_configuration=base_plugins,
596+
setup=True,
597+
)
598+
599+
assert transition.rate_conversion_type == rate_conversion_type
600+
with patch(
601+
"vivarium_public_health.disease.transition.rate_to_probability", return_value=1.0
602+
) as mock_rate_to_probability:
603+
idx = pd.Index(list(range(10)))
604+
transition._probability(idx)
605+
606+
mock_rate_to_probability.assert_called_once()
607+
args, kwargs = mock_rate_to_probability.call_args
608+
assert len(args) == 1
609+
assert len(kwargs) == 1
610+
assert args[0].index.equals(idx)
611+
assert kwargs["rate_conversion_type"] == rate_conversion_type
612+
613+
614+
def test_disease_model_rate_conversion_config_error(
615+
base_config: LayeredConfigTree,
616+
base_plugins: LayeredConfigTree,
617+
disease: str,
618+
):
619+
"""
620+
Test that the transition rate to probability configuration is set correctly.
621+
"""
622+
healthy = BaseDiseaseState("healthy")
623+
sick = DiseaseState("sick")
624+
key = "sequela.acute_myocardial_infarction_first_2_days.incidence_rate"
625+
transition = RateTransition(
626+
input_state=healthy,
627+
output_state=sick,
628+
get_data_functions={"incidence_rate": lambda builder, _: builder.data.load(key)},
629+
)
630+
another_transition = RateTransition(
631+
input_state=sick,
632+
output_state=healthy,
633+
get_data_functions={"incidence_rate": lambda builder, _: builder.data.load(key)},
634+
)
635+
healthy.transition_set.append(transition)
636+
sick.transition_set.append(another_transition)
637+
model = DiseaseModel(disease, initial_state=healthy, states=[healthy, sick])
638+
639+
base_config.update(
640+
{
641+
f"{transition.name}": {
642+
"rate_conversion_type": "linear",
643+
},
644+
f"{another_transition.name}": {
645+
"rate_conversion_type": "exponential",
646+
},
647+
}
648+
)
649+
650+
# Sets the configuration
651+
with pytest.raises(ConfigurationError):
652+
InteractiveContext(
653+
components=[TestPopulation(), model],
654+
configuration=base_config,
655+
plugin_configuration=base_plugins,
656+
setup=True,
657+
)

0 commit comments

Comments
 (0)