Skip to content

Commit f251195

Browse files
Merge pull request #7 from ihmeuw/improvements
Ensure `default_category` is valid for Categoricals during sampling
2 parents 9350552 + 9e67d8f commit f251195

File tree

1 file changed

+39
-14
lines changed

1 file changed

+39
-14
lines changed

src/vivarium_helpers/prob_distributions/sampling.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@ def sample_array_from_propensity(
1515
If method='select', `category_cdf` must be a mapping of categories
1616
to cumulative probabilities.
1717
If method='array', `category_cdf` must be an nd-array of cumulative
18-
probabilities, broadcastable to shape (propensity.shape, len(categories)).
19-
# TODO: check that this actually works when category_cdf has ndim != 2,
20-
# and verify that I am specifying the correct shapes here.
18+
probabilities, broadcastable to shape
19+
(propensity.shape, len(categories)).
20+
# TODO: check that this actually works when category_cdf has
21+
# ndim != 2, and verify that I am specifying the correct shapes
22+
# here.
2123
"""
2224
if method == 'select':
2325
# logger.debug(f"{categories=}")
@@ -30,29 +32,34 @@ def sample_array_from_propensity(
3032
default_category=categories[-1]
3133
# If category_cdf is 1-dimensional, broadcast instead of failing
3234
# when there is more than one propensity
33-
# Note: If there is only 1 propensity, this makes it so its index
34-
# does NOT need to be aligned with the index of a single row CDF,
35-
# which may or may not be what's desired.
35+
# Note: If there is only 1 propensity, this makes it so its
36+
# index does NOT need to be aligned with the index of a single
37+
# row CDF, which may or may not be what's desired.
3638
if isinstance(category_cdf, pd.DataFrame):
3739
category_cdf = category_cdf.squeeze()
3840
condlist = [propensity <= category_cdf[cat] for cat in categories]
39-
category = np.select(condlist, choicelist=categories, default=default_category)
41+
category = np.select(
42+
condlist, choicelist=categories, default=default_category)
4043
elif method == 'array':
4144
if default_category is not None:
42-
raise ValueError("`default_category` is unsupported with method='array'")
45+
raise ValueError(
46+
"`default_category` is unsupported with method='array'")
4347
category_index = (
4448
# TODO: Explain why this works...
4549
np.asarray(propensity).reshape((-1,1))
4650
> np.asarray(category_cdf)
4751
).sum(axis=1)
4852
category = np.asarray(categories)[category_index]
4953
else:
50-
raise ValueError(f"Unknown method: {method}. Acceptable values are 'select' and 'array'.")
54+
raise ValueError(
55+
f"Unknown method: {method}. "
56+
"Acceptable values are 'select' and 'array'.")
5157
return category
5258

5359
def sample_categorical_from_propensity(
5460
propensity,
55-
categories, # pandas CategoricalDtype or iterable of unique categories
61+
# pandas CategoricalDtype or iterable of unique categories
62+
categories,
5663
category_cdf,
5764
method='select',
5865
default_category=None,
@@ -77,17 +84,35 @@ def sample_categorical_from_propensity(
7784
cat_to_code = {cat: code for cat, code in zip(cats, codes)}
7885
category_cdf = category_cdf.rename(columns=cat_to_code)
7986
else:
80-
category_cdf = {code: category_cdf[cat] for code, cat in zip(codes, cats)}
87+
category_cdf = {code: category_cdf[cat]
88+
for code, cat in zip(codes, cats)}
89+
90+
# Ensure that default category is valid for the CategoricalDtype
91+
if default_category is not None:
92+
if pd.isna(default_category):
93+
# code -1 corresponds to NaN in pandas Categoricals
94+
default_category = -1
95+
elif default_category in cats:
96+
# Find the code for this category
97+
default_category = list(cats).index(default_category)
98+
else:
99+
raise ValueError(
100+
"`default_category` must either be an element of `categories` "
101+
"or an object for which `pandas.isna` returns True.")
81102

82103
sampled_codes = sample_array_from_propensity(
83104
propensity,
84105
codes,
85106
category_cdf,
86107
method=method,
108+
# Pass a default of -1 to be converted to NaN during .from_codes,
109+
# which indicates that the propensity does not correspond to any
110+
# of the specified categories
87111
default_category=default_category,
88112
)
89113
sampled_categories = pd.Categorical.from_codes(
90-
sampled_codes, categories=categories, ordered=ordered, dtype=dtype)
114+
sampled_codes, categories=categories, ordered=ordered,
115+
dtype=dtype)
91116
return sampled_categories
92117

93118
def sample_series_from_propensity(
@@ -108,7 +133,7 @@ def sample_series_from_propensity(
108133
if is_categorical:
109134
if dtype is not None:
110135
raise ValueError(
111-
"`dtype` not allowed for categorical data."
136+
"`dtype` not allowed for categorical data. "
112137
"Pass an instance of `CategoricalDtype` to `categories` instead."
113138
)
114139
sample_array = sample_categorical_from_propensity(
@@ -135,7 +160,7 @@ def sample_series_from_propensity(
135160
if isinstance(category_cdf, pd.DataFrame):
136161
name = category_cdf.columns.name
137162
if name is None and isinstance(category_cdf, pd.Series):
138-
name = category_cdf.name
163+
name = category_cdf.index.name
139164
if name is None and isinstance(categories, (pd.Index, pd.Series)):
140165
name = categories.name
141166
sampled_categories = pd.Series(

0 commit comments

Comments
 (0)