Skip to content

Commit 2a505ab

Browse files
fix type-hinting in framework/utilities.py (#702)
1 parent 26efdcb commit 2a505ab

File tree

3 files changed

+27
-17
lines changed

3 files changed

+27
-17
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
**3.6.5 - 12/16/25**
2+
3+
- Type-hinting: Fix mypy errors in framework/utilities.py
4+
15
**3.6.4 - 12/16/25**
26

37
- Add documentation for Component LookupTable configuration

src/vivarium/framework/utilities.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# mypy: ignore-errors
21
"""
32
===========================
43
Framework Utility Functions
@@ -11,21 +10,21 @@
1110
from bdb import BdbQuit
1211
from collections.abc import Callable, Sequence
1312
from importlib import import_module
14-
from typing import Any, Literal, TypeVar
13+
from typing import Any, Literal
1514

1615
import numpy as np
1716
from loguru import logger
1817

1918
from vivarium.types import NumberLike, NumericArray, Timedelta
2019

21-
TimeValue = TypeVar("T", bound=NumberLike)
2220

23-
24-
def from_yearly(value: TimeValue, time_step: Timedelta) -> TimeValue:
21+
def from_yearly(value: NumberLike, time_step: Timedelta) -> NumberLike:
22+
"""Rescale a yearly rate to the size of a time step."""
2523
return value * (time_step.total_seconds() / (60 * 60 * 24 * 365.0))
2624

2725

28-
def to_yearly(value: TimeValue, time_step: Timedelta) -> TimeValue:
26+
def to_yearly(value: NumberLike, time_step: Timedelta) -> NumberLike:
27+
"""Convert a time-step-scaled rate back to a yearly rate."""
2928
return value / (time_step.total_seconds() / (60 * 60 * 24 * 365.0))
3029

3130

@@ -56,15 +55,16 @@ def rate_to_probability(
5655
f"Rate conversion type {rate_conversion_type} is not implemented. "
5756
"Allowable types are 'linear' or 'exponential'."
5857
)
58+
probability: NumericArray
5959
if rate_conversion_type == "linear":
6060
# NOTE: The default behavior for randomness streams is to use a rate that is already
6161
# scaled to the time step which is why the default time scaling factor is 1.0.
62-
probability = np.array(rate * time_scaling_factor)
62+
# Use asarray to handle both scalars and arrays
63+
probability = np.asarray(rate) * time_scaling_factor
6364

6465
# Clip to 1.0 if the probability is greater than 1.0.
65-
exceeds_one = probability > 1.0
66-
if exceeds_one.any():
67-
probability[exceeds_one] = 1.0
66+
if np.any(probability > 1.0):
67+
probability = np.clip(probability, None, 1.0)
6868
logger.warning(
6969
"The rate to probability conversion resulted in a probability greater than 1.0. "
7070
"The probability has been clipped to 1.0 and indicates the rate is too high. "
@@ -73,9 +73,9 @@ def rate_to_probability(
7373
# encountered underflow from rate > 30k
7474
# for rates greater than 250, exp(-rate) evaluates to 1e-109
7575
# beware machine-specific floating point issues
76-
rate = np.array(rate)
76+
rate = np.asarray(rate)
7777
rate[rate > 250] = 250.0
78-
probability: NumericArray = 1 - np.exp(-rate * time_scaling_factor)
78+
probability = 1 - np.exp(-rate * time_scaling_factor)
7979

8080
return probability
8181

@@ -109,11 +109,13 @@ def probability_to_rate(
109109
f"Rate conversion type {rate_conversion_type} is not implemented. "
110110
"Allowable types are 'linear' or 'exponential'."
111111
)
112+
rate: NumericArray
112113
if rate_conversion_type == "linear":
113-
rate = np.array(probability / time_scaling_factor)
114+
# Use asarray to handle both scalars and arrays
115+
rate = np.asarray(probability) / time_scaling_factor
114116
else:
115-
probability = np.array(probability)
116-
rate: NumericArray = -np.log(1 - probability)
117+
probability = np.asarray(probability)
118+
rate = -np.log(1 - probability)
117119
return rate
118120

119121

tests/examples/test_disease_model.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,23 @@ def test_disease_model(fuzzy_checker: FuzzyChecker, disease_model_spec: Path) ->
6060
pop = simulation.get_population()
6161
is_alive = pop["alive"] == "alive"
6262

63+
alive_target = from_yearly(20, timedelta(days=0.5))
64+
assert isinstance(alive_target, float)
6365
fuzzy_checker.fuzzy_assert_proportion(
6466
observed_numerator=(len(pop[~is_alive])),
6567
observed_denominator=len(pop),
66-
target_proportion=from_yearly(20, timedelta(days=0.5)),
68+
target_proportion=alive_target,
6769
# todo: remove this parameter when MIC-5412 is resolved
6870
name="alive_proportion",
6971
)
7072

7173
has_lri = pop["lower_respiratory_infections"] == "infected_with_lower_respiratory_infections"
74+
lri_target = from_yearly(25, timedelta(days=0.5))
75+
assert isinstance(lri_target, float)
7276
fuzzy_checker.fuzzy_assert_proportion(
7377
observed_numerator=(len(pop[is_alive & has_lri])),
7478
observed_denominator=len(pop[is_alive]),
75-
target_proportion=from_yearly(25, timedelta(days=0.5)),
79+
target_proportion=lri_target,
7680
# todo: remove this parameter when MIC-5412 is resolved
7781
name="lri_proportion",
7882
)

0 commit comments

Comments
 (0)