diff --git a/CHANGELOG.md b/CHANGELOG.md index e59f4ede0..c32f89e3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -66,6 +66,11 @@ when the input sample has zero rows, instead of relying on an assertion. - This aligns runtime behavior with documented exceptions and avoids optimization-dependent assert behavior. +- **Stabilized `prop_above_and_below()` return paths** + - `prop_above_and_below()` now builds concatenated outputs only from present + Series objects and returns `None` when both `below` and `above` are `None`, + avoiding ambiguous concat inputs while preserving existing behavior for valid + threshold sets. ## Tests @@ -79,6 +84,8 @@ - Added tests for `link_transform()`, and `calc_dev()` to validate behavior for extreme probabilities, and finite 10-fold deviance summaries. - **Expanded ASCII plot dispatcher edge-case coverage** - Added tests for `ascii_plot_dist` with `comparative=False` to verify direct dispatch to `ascii_plot_hist` and mixed categorical+numeric routing in a single call. +- **Expanded `prop_above_and_below()` edge-case coverage** + - Added focused tests for empty threshold iterables in both return formats, and, in dict mode, for mixed `None` threshold groups and the explicit all-`None` case (which returns `None`). 
# 0.16.0 (2026-02-09) diff --git a/balance/stats_and_plots/weights_stats.py b/balance/stats_and_plots/weights_stats.py index bce521d1f..7d9831710 100644 --- a/balance/stats_and_plots/weights_stats.py +++ b/balance/stats_and_plots/weights_stats.py @@ -8,7 +8,7 @@ from __future__ import annotations import logging -from typing import Any +from typing import Any, Literal, overload, TypedDict import numpy as np import numpy.typing as npt @@ -17,6 +17,11 @@ logger: logging.Logger = logging.getLogger(__package__) +class PropAboveBelowResult(TypedDict): + below: pd.Series | None + above: pd.Series | None + + ########################################## # Weights diagnostics - functions for analyzing weights # These functions provide statistical measures for evaluating @@ -199,12 +204,33 @@ def nonparametric_skew( return (w.mean() - w.median()) / w.std() +@overload +def prop_above_and_below( + w: list[Any] | pd.Series | npt.NDArray | pd.DataFrame, + below: tuple[float, ...] | list[float] | None = (1 / 10, 1 / 5, 1 / 3, 1 / 2, 1), + above: tuple[float, ...] | list[float] | None = (1, 2, 3, 5, 10), + return_as_series: Literal[True] = True, +) -> pd.Series | None: + pass + + +@overload +def prop_above_and_below( + w: list[Any] | pd.Series | npt.NDArray | pd.DataFrame, + below: tuple[float, ...] | list[float] | None = (1 / 10, 1 / 5, 1 / 3, 1 / 2, 1), + above: tuple[float, ...] | list[float] | None = (1, 2, 3, 5, 10), + *, + return_as_series: Literal[False], +) -> PropAboveBelowResult | None: + pass + + def prop_above_and_below( w: list[Any] | pd.Series | npt.NDArray | pd.DataFrame, below: tuple[float, ...] | list[float] | None = (1 / 10, 1 / 5, 1 / 3, 1 / 2, 1), above: tuple[float, ...] | list[float] | None = (1, 2, 3, 5, 10), return_as_series: bool = True, -) -> pd.Series | dict[Any, Any] | None: +) -> pd.Series | PropAboveBelowResult | None: # TODO (p2): look more in the literature (are there references for using this vs another, or none at all?) 
# update the doc with insights, once done. """ @@ -224,22 +250,26 @@ def prop_above_and_below( DataFrame, only the first column is used. below (tuple[float, ...] | list[float] | None, optional): values to check which proportion of normalized weights are *below* them. - Using None returns None. + Using None omits below-threshold calculations. Defaults to (1/10, 1/5, 1/3, 1/2, 1). above (tuple[float, ...] | list[float] | None, optional): values to check which proportion of normalized weights are *above* (or equal) to them. - Using None returns None. + Using None omits above-threshold calculations. Defaults to (1, 2, 3, 5, 10). return_as_series (bool, optional): If true returns one pd.Series of values. - If False will return a dict with two pd.Series (one for below and one for above). + If False returns ``PropAboveBelowResult`` with ``below``/``above`` entries + containing a ``pd.Series`` or ``None`` for omitted groups. Defaults to True. Returns: - pd.Series | dict: + pd.Series | PropAboveBelowResult | None: If return_as_series is True we get pd.Series with proportions of (normalized weights) that are below/above some numbers, the index indicates which threshold was checked (the values in the index are rounded up to 3 points for printing purposes). - If return_as_series is False we get a dict with 'below' and 'above' with the relevant pd.Series (or None). + If return_as_series is False we get ``PropAboveBelowResult`` with + ``below`` and ``above`` keys whose values are the relevant pd.Series + (or ``None`` when a side is omitted). If both ``below`` and ``above`` + are ``None``, the function returns ``None``. Examples: :: @@ -317,16 +347,12 @@ def prop_above_and_below( # decide if to return one series or a dict if return_as_series: - out = pd.concat( - [ # pyre-ignore[6]: pd.concat supports Series. 
- prop_below_series, - prop_above_series, - ] - ) + pieces = [s for s in (prop_below_series, prop_above_series) if s is not None] + out = pd.concat(pieces) else: - out = {"below": prop_below_series, "above": prop_above_series} + out = PropAboveBelowResult(below=prop_below_series, above=prop_above_series) - return out # pyre-ignore[7]: TODO: see if we can fix this pyre + return out def weighted_median_breakdown_point( diff --git a/tests/test_stats_and_plots.py b/tests/test_stats_and_plots.py index f3ee0898d..b8e22a05f 100644 --- a/tests/test_stats_and_plots.py +++ b/tests/test_stats_and_plots.py @@ -217,6 +217,72 @@ def test_prop_above_and_below(self) -> None: } self.assertEqual({k: v.to_list() for k, v in result_dict.items()}, expected) + def test_prop_above_and_below_edge_cases(self) -> None: + """Cover edge combinations for thresholds and return formats.""" + from balance.stats_and_plots.weights_stats import prop_above_and_below + + weights = pd.Series((1.0, 2.0, 3.0, 4.0)) + + # Empty threshold iterables should return an empty Series in series mode. + result_empty = prop_above_and_below(weights, below=(), above=()) + self.assertIsNotNone(result_empty) + result_empty = _assert_type(result_empty, pd.Series) + self.assertEqual(result_empty.to_list(), []) + self.assertEqual(result_empty.index.to_list(), []) + + # Empty iterables are distinct from omitted groups in dict mode. + result_dict_empty = prop_above_and_below( + weights, + below=(), + above=(), + return_as_series=False, + ) + self.assertIsNotNone(result_dict_empty) + result_dict_empty = _assert_type(result_dict_empty) + self.assertIsNotNone(result_dict_empty["below"]) + self.assertIsNotNone(result_dict_empty["above"]) + self.assertEqual(result_dict_empty["below"].to_list(), []) + self.assertEqual(result_dict_empty["above"].to_list(), []) + + # Dict mode should preserve None for omitted threshold groups. 
+ result_dict_only_above = prop_above_and_below( + weights, + below=None, + above=(1, 2), + return_as_series=False, + ) + self.assertIsNotNone(result_dict_only_above) + result_dict_only_above = _assert_type(result_dict_only_above) + self.assertIsNone(result_dict_only_above["below"]) + self.assertEqual( + result_dict_only_above["above"].index.to_list(), + ["prop(w >= 1)", "prop(w >= 2)"], + ) + + result_dict_only_below = prop_above_and_below( + weights, + below=(0.5, 1), + above=None, + return_as_series=False, + ) + self.assertIsNotNone(result_dict_only_below) + result_dict_only_below = _assert_type(result_dict_only_below) + self.assertEqual( + result_dict_only_below["below"].index.to_list(), + ["prop(w < 0.5)", "prop(w < 1)"], + ) + self.assertIsNone(result_dict_only_below["above"]) + + # If both groups are omitted, the function returns None (checked here with return_as_series=False). + self.assertIsNone( + prop_above_and_below( + weights, + below=None, + above=None, + return_as_series=False, + ) + ) + def test_weights_diagnostics_accept_list_and_ndarray_input(self) -> None: """Ensure diagnostics are equivalent across list/ndarray/Series inputs.""" from balance.stats_and_plots.weights_stats import (