Skip to content

Commit e7a6ad7

Browse files
authored
Simplify LBWSG category parsing (#546)
* use re findall * add error * change check * isort * CL update
1 parent 6249097 commit e7a6ad7

File tree

3 files changed

+22
-21
lines changed

3 files changed

+22
-21
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
**4.3.11 - 09/24/25**
2+
3+
- Refactor: Simplify LBWSG parsing logic
4+
15
**4.3.10 - 09/23/25**
26

37
- Bugfix: Fix bug in PublicHealthObserver to retain stratification columns in results

src/vivarium_public_health/risks/implementations/low_birth_weight_and_short_gestation.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""
1010

1111
import pickle
12+
import re
1213
from collections.abc import Callable
1314
from typing import Any
1415

@@ -119,13 +120,12 @@ def get_category_intervals(self, builder: Builder) -> dict[str, dict[str, pd.Int
119120
The intervals for each category.
120121
"""
121122
categories: dict[str, str] = builder.data.load(f"{self.risk}.categories")
122-
category_intervals = {
123-
axis: {
124-
category: self._parse_description(axis, description)
125-
for category, description in categories.items()
126-
}
127-
for axis in [BIRTH_WEIGHT, GESTATIONAL_AGE]
128-
}
123+
category_intervals = {GESTATIONAL_AGE: {}, BIRTH_WEIGHT: {}}
124+
125+
for category, description in categories.items():
126+
gestation_interval, birth_weight_interval = self._parse_description(description)
127+
category_intervals[GESTATIONAL_AGE][category] = gestation_interval
128+
category_intervals[BIRTH_WEIGHT][category] = birth_weight_interval
129129
return category_intervals
130130

131131
##################
@@ -224,7 +224,7 @@ def single_axis_ppf(
224224
##################
225225

226226
@staticmethod
227-
def _parse_description(axis: str, description: str) -> pd.Interval:
227+
def _parse_description(description: str) -> tuple[pd.Interval, pd.Interval]:
228228
"""Parses a string corresponding to a low birth weight and short gestation
229229
category to an Interval.
230230
@@ -235,17 +235,15 @@ def _parse_description(axis: str, description: str) -> pd.Interval:
235235
An example of an edge case of birth weight:
236236
'Neonatal preterm and LBWSG (estimation years) - [36, 37) wks, [4000, 9999] g'
237237
"""
238-
endpoints = {
239-
BIRTH_WEIGHT: [
240-
float(val)
241-
for val in description.split(", [")[1].split(")")[0].split("]")[0].split(", ")
242-
],
243-
GESTATIONAL_AGE: [
244-
float(val)
245-
for val in description.split("- [")[1].split(")")[0].split("+")[0].split(", ")
246-
],
247-
}[axis]
248-
return pd.Interval(*endpoints, closed="left") # noqa
238+
lbwsg_values = [float(val) for val in re.findall(r"(\d+)", description)]
239+
if len(list(lbwsg_values)) != 4:
240+
raise ValueError(
241+
f"Could not parse LBWSG description '{description}'. Expected 4 numeric values."
242+
)
243+
return (
244+
pd.Interval(*lbwsg_values[:2], closed="left"), # Gestational Age
245+
pd.Interval(*lbwsg_values[2:], closed="left"), # Birth Weight
246+
)
249247

250248

251249
class LBWSGRisk(Risk):

tests/risks/test_low_birth_weight_and_short_gestation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@
3636
],
3737
)
3838
def test_parsing_lbwsg_descriptions(description, expected_weight_values, expected_age_values):
39-
weight_interval = LBWSGDistribution._parse_description("birth_weight", description)
40-
age_interval = LBWSGDistribution._parse_description("gestational_age", description)
39+
age_interval, weight_interval = LBWSGDistribution._parse_description(description)
4140
assert weight_interval.left == expected_weight_values[0]
4241
assert weight_interval.right == expected_weight_values[1]
4342
assert age_interval.left == expected_age_values[0]

0 commit comments

Comments
 (0)