Skip to content

Commit e8d8146

Browse files
authored
Fill missing Ensemble Distribution Parameters (#89)
* use singledispatch to simplify formatters * [COPILOT] Add basic formatting tests * add some tests for missing data * make good tests * cleanup * remove added tests * add unsupported tests * formatting * add cl * another change * format
1 parent b7968ac commit e8d8146

File tree

5 files changed

+127
-32
lines changed

5 files changed

+127
-32
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
**2.2.0 - 10/02/25**
2+
3+
- Allow user to pass incomplete parameter sets to EnsembleDistribution (filled with zeros)
4+
- Backfill some unit tests
5+
16
**2.1.6 - 08/01/25**
27

38
- Use vivarium_dependencies for common setup constraints

src/risk_distributions/formatting.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from functools import singledispatch
12
from typing import Any, TypeVar
23

34
import numpy as np
@@ -42,22 +43,13 @@ def cast_to_series(mean: Parameter, sd: Parameter) -> tuple[pd.Series, pd.Series
4243
return mean, sd
4344

4445

46+
@singledispatch
4547
def format_data(data: Parameters, required_columns: list[Any], measure: str) -> pd.DataFrame:
4648
"""Formats parameter data into a dataframe."""
47-
if isinstance(data, np.ndarray):
48-
data = format_array(data, required_columns, measure)
49-
elif isinstance(data, pd.Series):
50-
data = format_series(data, required_columns, measure)
51-
elif isinstance(data, pd.DataFrame):
52-
data = format_data_frame(data, required_columns, measure)
53-
elif isinstance(data, (list, tuple)):
54-
data = format_list_like(data, required_columns, measure)
55-
elif isinstance(data, dict):
56-
data = format_dict(data, required_columns, measure)
57-
58-
return data
49+
raise TypeError(f"Unsupported data type {type(data)} for {measure}")
5950

6051

52+
@format_data.register
6153
def format_array(data: np.ndarray, required_columns: list[Any], measure: str) -> pd.DataFrame:
6254
"""Transforms 1d and 2d arrays into dataframes with columns for the
6355
parameters and (possibly) rows for each parameter variation."""
@@ -111,6 +103,7 @@ def format_array(data: np.ndarray, required_columns: list[Any], measure: str) ->
111103
return data
112104

113105

106+
@format_data.register
114107
def format_series(data: pd.Series, required_columns: list[Any], measure: str) -> pd.DataFrame:
115108
"""Transforms series data into dataframes with columns for the
116109
parameters and (possibly) rows for each parameter variation."""
@@ -133,6 +126,7 @@ def format_series(data: pd.Series, required_columns: list[Any], measure: str) ->
133126
return data
134127

135128

129+
@format_data.register
136130
def format_data_frame(
137131
data: pd.DataFrame, required_columns: list[Any], measure: str
138132
) -> pd.DataFrame:
@@ -154,6 +148,8 @@ def format_data_frame(
154148
return data
155149

156150

151+
@format_data.register(list)
152+
@format_data.register(tuple)
157153
def format_list_like(
158154
data: list | tuple, required_columns: list[Any], measure: str
159155
) -> pd.DataFrame:
@@ -163,9 +159,8 @@ def format_list_like(
163159
return format_array(data, required_columns, measure)
164160

165161

166-
def format_dict(
167-
data: dict[str, Parameter], required_columns: list[Any], measure: str
168-
) -> pd.DataFrame:
162+
@format_data.register
163+
def format_dict(data: dict, required_columns: list[Any], measure: str) -> pd.DataFrame:
169164
"""Transform dictionaries with scalar or list-like values into dataframes
170165
with columns for the parameters and (possibly) rows for each parameter
171166
variation."""

src/risk_distributions/risk_distributions.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import warnings
23
from collections.abc import Callable
34

@@ -566,7 +567,10 @@ def get_parameters(
566567
mean: Parameter = None,
567568
sd: Parameter = None,
568569
) -> tuple[pd.DataFrame, dict[str, pd.DataFrame]]:
569-
weights = format_data(weights, list(cls._distribution_map.keys()), "weights")
570+
expected_columns = list(cls._distribution_map.keys())
571+
572+
weights = cls.fill_missing_weights(weights, expected_columns)
573+
weights = format_data(weights, expected_columns, "weights")
570574

571575
params = {}
572576
for name, dist in cls._distribution_map.items():
@@ -591,6 +595,27 @@ def get_parameters(
591595

592596
return weights, params
593597

598+
@staticmethod
599+
def fill_missing_weights(weights: Parameters, expected_columns) -> Parameters:
600+
weights = copy.deepcopy(weights)
601+
602+
# Get existing keys/columns/index based on weights type
603+
if isinstance(weights, dict):
604+
column_names = set(weights.keys())
605+
elif isinstance(weights, pd.DataFrame):
606+
column_names = set(weights.columns)
607+
elif isinstance(weights, pd.Series):
608+
column_names = set(weights.index)
609+
else:
610+
column_names = None # For list, tuple, np.array, we can't fill missing columns
611+
612+
# Add missing columns with 0.0 value
613+
if column_names and column_names < set(expected_columns):
614+
for col in expected_columns:
615+
if col not in column_names:
616+
weights[col] = 0.0
617+
return weights
618+
594619
def pdf(self, x: pd.Series | np.ndarray | float | int) -> pd.Series | np.ndarray | float:
595620
single_val = isinstance(x, (float, int))
596621
values_only = isinstance(x, np.ndarray)

tests/test_ensemble_distribution.py

Lines changed: 79 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,44 +9,107 @@
99
from risk_distributions.risk_distributions import EnsembleDistribution
1010

1111
weights_base = {
12-
"betasr": 1 / 12,
13-
"exp": 1 / 12,
14-
"gamma": 1 / 12,
15-
"gumbel": 1 / 12,
16-
"invgamma": 1 / 12,
17-
"invweibull": 1 / 12,
18-
"llogis": 1 / 12,
19-
"lnorm": 1 / 12,
20-
"mgamma": 1 / 12,
21-
"mgumbel": 1 / 12,
22-
"norm": 1 / 12,
23-
"weibull": 1 / 12,
12+
"betasr": 1,
13+
"exp": 2,
14+
"gamma": 3,
15+
"gumbel": 5,
16+
"invgamma": 7,
17+
"invweibull": 11,
18+
"llogis": 13,
19+
"lnorm": 17,
20+
"mgamma": 19,
21+
"mgumbel": 23,
22+
"norm": 29,
23+
"weibull": 31,
2424
}
2525

26-
weights_df = pd.DataFrame({k: [v] for k, v in weights_base.items()})
26+
weights_base_missing = copy.deepcopy(weights_base)
27+
del weights_base_missing["exp"]
28+
29+
30+
def normalize_weights(weights: dict[str, float]) -> dict[str, float]:
31+
weights = copy.deepcopy(weights)
32+
total = sum(weights.values())
33+
for k in weights:
34+
weights[k] = weights[k] / total
35+
return weights
36+
37+
38+
@pytest.fixture
39+
def expected_weights() -> pd.DataFrame:
40+
return pd.DataFrame({k: [v] for k, v in normalize_weights(weights_base).items()})
41+
42+
43+
@pytest.fixture
44+
def expected_weights_missing() -> pd.DataFrame:
45+
data = pd.DataFrame({k: [v] for k, v in normalize_weights(weights_base_missing).items()})
46+
data["exp"] = 0.0
47+
return data
2748

2849

2950
@pytest.mark.parametrize(
3051
"weights",
3152
[
3253
weights_base,
54+
normalize_weights(weights_base),
3355
{k: [v] for k, v in weights_base.items()},
3456
pd.Series(weights_base),
3557
pd.Series(weights_base).reset_index(drop=True),
36-
weights_df,
58+
pd.DataFrame({k: [v] for k, v in weights_base.items()}),
3759
list(weights_base.values()),
3860
tuple(weights_base.values()),
3961
np.array(list(weights_base.values())), # Column Vector
4062
np.array([list(weights_base.values())]), # Row Vector
4163
np.array([list(weights_base.values())]).T,
4264
],
4365
)
44-
def test_weight_formats(weights: Parameters) -> None:
66+
def test_weight_formats(weights: Parameters, expected_weights: pd.DataFrame) -> None:
4567
weights_original = copy.deepcopy(weights)
4668
dist = EnsembleDistribution(
4769
weights,
4870
mean=1,
4971
sd=1,
5072
)
5173
assert_equal(weights_original, weights)
52-
pd.testing.assert_frame_equal(dist.weights, pd.DataFrame(weights_df))
74+
pd.testing.assert_frame_equal(dist.weights, expected_weights)
75+
76+
77+
@pytest.mark.parametrize(
78+
"weights",
79+
[
80+
weights_base_missing,
81+
normalize_weights(weights_base_missing),
82+
{k: [v] for k, v in weights_base_missing.items()},
83+
pd.Series(weights_base_missing),
84+
pd.DataFrame({k: [v] for k, v in weights_base_missing.items()}),
85+
],
86+
)
87+
def test_missing_weights(weights: Parameters, expected_weights_missing: pd.DataFrame) -> None:
88+
weights_original = copy.deepcopy(weights)
89+
dist = EnsembleDistribution(
90+
weights,
91+
mean=1,
92+
sd=1,
93+
)
94+
assert_equal(weights_original, weights)
95+
pd.testing.assert_frame_equal(dist.weights, expected_weights_missing)
96+
97+
98+
@pytest.mark.parametrize(
99+
"weights",
100+
[
101+
pd.Series(weights_base_missing).reset_index(drop=True),
102+
list(weights_base_missing.values()),
103+
tuple(weights_base_missing.values()),
104+
np.array(list(weights_base_missing.values())), # Column Vector
105+
np.array([list(weights_base_missing.values())]), # Row Vector
106+
np.array([list(weights_base_missing.values())]).T,
107+
],
108+
)
109+
def test_missing_weights_invalid(weights: Parameters) -> None:
110+
with pytest.raises(ValueError):
111+
EnsembleDistribution(
112+
weights,
113+
mean=1,
114+
sd=1,
115+
)

tests/test_formatting.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,13 @@ def test_format_data_frame(data_columns, required_columns, match):
121121
format_data_frame(data, required_columns, measure="test")
122122

123123

124+
@pytest.mark.parametrize("data", ["string", {1, 2, 3}, None])
125+
def test_format_data_unsupported_types(data):
126+
"""Test format_data with unsupported data types."""
127+
with pytest.raises(TypeError, match="Unsupported data type"):
128+
format_data(data, ["param1"], "test")
129+
130+
124131
@pytest.mark.parametrize(
125132
"data, required_columns, expected",
126133
[

0 commit comments

Comments
 (0)