Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
- **Drop NA rows when requested in model_matrix**
- `model_matrix(add_na=False)` now actually drops rows containing NA values while preserving categorical levels, matching the documented behavior.
- Previously, `add_na=False` only logged a warning without dropping rows; code relying on the old behavior may now see fewer rows and should either handle missingness explicitly or use `add_na=True`.
- **Poststratify missing data handling**
- `poststratify()` now accepts `na_action` to either drop rows with missing
values or treat missing values as their own category during weighting.
- **Breaking change:** the default behavior now treats missing values as a
distinct category. Previously, rows with missing values were implicitly
excluded by pandas groupby operations. To preserve the old behavior, pass
`na_action="drop"` (or `False`) explicitly.

## Bug Fixes

Expand Down
26 changes: 25 additions & 1 deletion balance/weighting_methods/poststratify.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@


# TODO: Add tests for all arguments of function
# TODO: Add argument for na_action
def poststratify(
sample_df: pd.DataFrame,
sample_weights: pd.Series,
Expand All @@ -27,6 +26,7 @@ def poststratify(
transformations: str = "default",
transformations_drop: bool = True,
strict_matching: bool = True,
na_action: Union[str, bool] = "add_indicator",
weight_trimming_mean_ratio: Union[float, int, None] = None,
weight_trimming_percentile: Union[float, None] = None,
keep_sum_of_weights: bool = True,
Expand All @@ -50,6 +50,10 @@ def poststratify(
transformations (str, optional): Transformations to apply to data before fitting the model. Default is "default". See `balance.adjustment.apply_transformations`.
transformations_drop (bool, optional): If True, drops variables not affected by transformations. Default is True.
strict_matching (bool, optional): If True, requires all sample cells to be present in the target. If False, cells missing in the target are assigned weight 0 (and a warning is raised). Default is True.
na_action (Union[str, bool], optional): How to handle missing values. Use
``True``/``"add_indicator"`` to treat missing values as their own category, or
``False``/``"drop"`` to remove rows with missing values from both sample and
target. Defaults to ``"add_indicator"``.
weight_trimming_mean_ratio (Union[float, int, None], optional): Forwarded to
:func:`balance.adjustment.trim_weights` to clip weights at a multiple of the mean.
weight_trimming_percentile (Union[float, None], optional): Percentile limit(s) for
Expand Down Expand Up @@ -145,6 +149,26 @@ def poststratify(
variables = list(sample_df.columns)
logger.debug(f"Final variables in the model after transformations: {variables}")

if na_action is True:
na_action = "add_indicator"
elif na_action is False:
na_action = "drop"

if na_action == "drop":
(sample_df, sample_weights) = balance_util.drop_na_rows(
sample_df, sample_weights, "sample"
)
(target_df, target_weights) = balance_util.drop_na_rows(
target_df, target_weights, "target"
)
elif na_action == "add_indicator":
from balance.util import _safe_fillna_and_infer

sample_df = _safe_fillna_and_infer(sample_df, "__NaN__")
target_df = _safe_fillna_and_infer(target_df, "__NaN__")
else:
raise ValueError("`na_action` must be 'add_indicator' or 'drop'")

target_df = target_df.assign(weight=target_weights)
target_cell_props = target_df.groupby(list(variables))["weight"].sum()

Expand Down
106 changes: 106 additions & 0 deletions tests/test_poststratify.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,101 @@ def test_poststratify_transformations(self) -> None:
self.assertAlmostEqual(result[s.x == "b"].sum() / size, 0.035, delta=eps)
self.assertAlmostEqual(result[s.x == "c"].sum() / size, 0.015, delta=eps)

def test_poststratify_na_action(self) -> None:
s = pd.DataFrame(
{
"a": (1, np.nan, 2),
"b": ("x", "x", "y"),
}
)
t = s.copy()
s_weights = pd.Series([1, 1, 1])
t_weights = pd.Series([2, 3, 4])

result_add = poststratify(
sample_df=s,
sample_weights=s_weights,
target_df=t,
target_weights=t_weights,
na_action="add_indicator",
transformations=None,
)["weight"]
self.assertEqual(result_add, t_weights.astype("float64"))

result_drop = poststratify(
sample_df=s,
sample_weights=s_weights,
target_df=t,
target_weights=t_weights,
na_action="drop",
transformations=None,
)["weight"]
expected = t_weights.loc[s.dropna().index].astype("float64")
self.assertEqual(result_drop, expected)

result_bool_add = poststratify(
sample_df=s,
sample_weights=s_weights,
target_df=t,
target_weights=t_weights,
na_action=True,
transformations=None,
)["weight"]
self.assertEqual(result_bool_add, t_weights.astype("float64"))

result_bool_drop = poststratify(
sample_df=s,
sample_weights=s_weights,
target_df=t,
target_weights=t_weights,
na_action=False,
transformations=None,
)["weight"]
self.assertEqual(result_bool_drop, expected)

def test_poststratify_na_drop_warns(self) -> None:
sample = Sample.from_frame(
pd.DataFrame(
{
"a": (1, np.nan, 2),
"id": (1, 2, 3),
}
),
id_column="id",
)
target = Sample.from_frame(
pd.DataFrame(
{
"a": (1, 2, np.nan),
"id": (1, 2, 3),
}
),
id_column="id",
)
self.assertWarnsRegexp(
"Dropped 1/3 rows of sample",
sample.adjust,
target,
method="poststratify",
na_action="drop",
transformations=None,
)

def test_poststratify_dropna_empty(self) -> None:
s = pd.DataFrame({"a": (np.nan, None), "b": (np.nan, None)})
s_w = pd.Series((1, 2))
self.assertRaisesRegex(
ValueError,
"Dropping rows led to empty",
poststratify,
s,
s_w,
s,
s_w,
na_action="drop",
transformations=None,
)

def test_poststratify_exceptions(self) -> None:
# column with name weight
s = pd.DataFrame(
Expand Down Expand Up @@ -308,3 +403,14 @@ def test_poststratify_exceptions(self) -> None:
strict_matching=False,
)["weight"]
self.assertEqual(result, pd.Series([2.0, 0.0]))

with self.assertRaisesRegex(
ValueError, "`na_action` must be 'add_indicator' or 'drop'"
):
poststratify(
sample_df=s,
sample_weights=s_weights,
target_df=t,
target_weights=t_weights,
na_action="invalid",
)
Loading