From b730d204267c14af8dfa701c2910f01b80aaa79d Mon Sep 17 00:00:00 2001 From: Soumyadip Sarkar Date: Tue, 3 Mar 2026 00:49:05 +0530 Subject: [PATCH 1/3] Fix prop_above_and_below typing and empty concat handling --- CHANGELOG.md | 7 ++++ balance/stats_and_plots/weights_stats.py | 41 ++++++++++++++----- tests/test_stats_and_plots.py | 52 ++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e59f4ede0..c32f89e3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -66,6 +66,11 @@ when the input sample has zero rows, instead of relying on an assertion. - This aligns runtime behavior with documented exceptions and avoids optimization-dependent assert behavior. +- **Stabilized `prop_above_and_below()` return paths** + - `prop_above_and_below()` now builds concatenated outputs only from present + Series objects and returns `None` when both `below` and `above` are `None`, + avoiding ambiguous concat inputs while preserving existing behavior for valid + threshold sets. ## Tests @@ -79,6 +84,8 @@ - Added tests for `link_transform()`, and `calc_dev()` to validate behavior for extreme probabilities, and finite 10-fold deviance summaries. - **Expanded ASCII plot dispatcher edge-case coverage** - Added tests for `ascii_plot_dist` with `comparative=False` to verify direct dispatch to `ascii_plot_hist` and mixed categorical+numeric routing in a single call. +- **Expanded `prop_above_and_below()` edge-case coverage** + - Added focused tests for empty threshold iterables, mixed `None` threshold groups in dict mode, and explicit all-`None` threshold handling across return formats. # 0.16.0 (2026-02-09) diff --git a/balance/stats_and_plots/weights_stats.py b/balance/stats_and_plots/weights_stats.py index bce521d1f..45a2af5d8 100644 --- a/balance/stats_and_plots/weights_stats.py +++ b/balance/stats_and_plots/weights_stats.py @@ -8,7 +8,7 @@ from __future__ import annotations import logging -from typing import Any +from typing import Any, Literal, overload, TypedDict import numpy as np import numpy.typing as npt @@ -17,6 +17,11 @@ logger: logging.Logger = logging.getLogger(__package__) +class PropAboveBelowResult(TypedDict): + below: pd.Series | None + above: pd.Series | None + + ########################################## # Weights diagnostics - functions for analyzing weights # These functions provide statistical measures for evaluating @@ -199,12 +204,32 @@ def nonparametric_skew( return (w.mean() - w.median()) / w.std() +@overload +def prop_above_and_below( + w: list[Any] | pd.Series | npt.NDArray | pd.DataFrame, + below: tuple[float, ...] | list[float] | None = (1 / 10, 1 / 5, 1 / 3, 1 / 2, 1), + above: tuple[float, ...] | list[float] | None = (1, 2, 3, 5, 10), + return_as_series: Literal[True] = True, +) -> pd.Series | None: + pass + + +@overload +def prop_above_and_below( + w: list[Any] | pd.Series | npt.NDArray | pd.DataFrame, + below: tuple[float, ...] | list[float] | None = (1 / 10, 1 / 5, 1 / 3, 1 / 2, 1), + above: tuple[float, ...] | list[float] | None = (1, 2, 3, 5, 10), + return_as_series: Literal[False] = False, +) -> PropAboveBelowResult | None: + pass + + def prop_above_and_below( w: list[Any] | pd.Series | npt.NDArray | pd.DataFrame, below: tuple[float, ...] | list[float] | None = (1 / 10, 1 / 5, 1 / 3, 1 / 2, 1), above: tuple[float, ...] | list[float] | None = (1, 2, 3, 5, 10), return_as_series: bool = True, -) -> pd.Series | dict[Any, Any] | None: +) -> pd.Series | PropAboveBelowResult | None: # TODO (p2): look more in the literature (are there references for using this vs another, or none at all?) # update the doc with insights, once done. """ @@ -317,16 +342,12 @@ def prop_above_and_below( # decide if to return one series or a dict if return_as_series: - out = pd.concat( - [ # pyre-ignore[6]: pd.concat supports Series. - prop_below_series, - prop_above_series, - ] - ) + pieces = [s for s in (prop_below_series, prop_above_series) if s is not None] + out = pd.concat(pieces) if pieces else None else: - out = {"below": prop_below_series, "above": prop_above_series} + out = PropAboveBelowResult(below=prop_below_series, above=prop_above_series) - return out # pyre-ignore[7]: TODO: see if we can fix this pyre + return out def weighted_median_breakdown_point( diff --git a/tests/test_stats_and_plots.py b/tests/test_stats_and_plots.py index f3ee0898d..2b4c898e0 100644 --- a/tests/test_stats_and_plots.py +++ b/tests/test_stats_and_plots.py @@ -217,6 +217,58 @@ def test_prop_above_and_below(self) -> None: } self.assertEqual({k: v.to_list() for k, v in result_dict.items()}, expected) + def test_prop_above_and_below_edge_cases(self) -> None: + """Cover edge combinations for thresholds and return formats.""" + from balance.stats_and_plots.weights_stats import prop_above_and_below + + weights = pd.Series((1.0, 2.0, 3.0, 4.0)) + + # Empty threshold iterables should return an empty Series in series mode. + result_empty = prop_above_and_below(weights, below=(), above=()) + self.assertIsNotNone(result_empty) + result_empty = _assert_type(result_empty, pd.Series) + self.assertEqual(result_empty.to_list(), []) + self.assertEqual(result_empty.index.to_list(), []) + + # Dict mode should preserve None for omitted threshold groups. + result_dict_only_above = prop_above_and_below( + weights, + below=None, + above=(1, 2), + return_as_series=False, + ) + self.assertIsNotNone(result_dict_only_above) + result_dict_only_above = _assert_type(result_dict_only_above) + self.assertIsNone(result_dict_only_above["below"]) + self.assertEqual( + result_dict_only_above["above"].index.to_list(), + ["prop(w >= 1)", "prop(w >= 2)"], + ) + + result_dict_only_below = prop_above_and_below( + weights, + below=(0.5, 1), + above=None, + return_as_series=False, + ) + self.assertIsNotNone(result_dict_only_below) + result_dict_only_below = _assert_type(result_dict_only_below) + self.assertEqual( + result_dict_only_below["below"].index.to_list(), + ["prop(w < 0.5)", "prop(w < 1)"], + ) + self.assertIsNone(result_dict_only_below["above"]) + + # If both groups are omitted, function should return None in all modes. + self.assertIsNone( + prop_above_and_below( + weights, + below=None, + above=None, + return_as_series=False, + ) + ) + def test_weights_diagnostics_accept_list_and_ndarray_input(self) -> None: """Ensure diagnostics are equivalent across list/ndarray/Series inputs.""" from balance.stats_and_plots.weights_stats import ( From f2799fb18f6ccf6950049cc32da05230fd507c1d Mon Sep 17 00:00:00 2001 From: Soumyadip Sarkar Date: Tue, 3 Mar 2026 01:11:37 +0530 Subject: [PATCH 2/3] Update docstrings --- balance/stats_and_plots/weights_stats.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/balance/stats_and_plots/weights_stats.py b/balance/stats_and_plots/weights_stats.py index 45a2af5d8..4d453c67d 100644 --- a/balance/stats_and_plots/weights_stats.py +++ b/balance/stats_and_plots/weights_stats.py @@ -219,7 +219,8 @@ def prop_above_and_below( w: list[Any] | pd.Series | npt.NDArray | pd.DataFrame, below: tuple[float, ...] | list[float] | None = (1 / 10, 1 / 5, 1 / 3, 1 / 2, 1), above: tuple[float, ...] | list[float] | None = (1, 2, 3, 5, 10), - return_as_series: Literal[False] = False, + *, + return_as_series: Literal[False], ) -> PropAboveBelowResult | None: pass @@ -249,14 +250,15 @@ def prop_above_and_below( DataFrame, only the first column is used. below (tuple[float, ...] | list[float] | None, optional): values to check which proportion of normalized weights are *below* them. - Using None returns None. + Using None omits below-threshold calculations. Defaults to (1/10, 1/5, 1/3, 1/2, 1). above (tuple[float, ...] | list[float] | None, optional): values to check which proportion of normalized weights are *above* (or equal) to them. - Using None returns None. + Using None omits above-threshold calculations. Defaults to (1, 2, 3, 5, 10). return_as_series (bool, optional): If true returns one pd.Series of values. - If False will return a dict with two pd.Series (one for below and one for above). + If False returns ``PropAboveBelowResult`` with ``below``/``above`` entries + containing a ``pd.Series`` or ``None`` for omitted groups. Defaults to True. Returns: From 2aa461eabeb47cfd77ee83a0efed49603851f1f8 Mon Sep 17 00:00:00 2001 From: Soumyadip Sarkar Date: Tue, 3 Mar 2026 01:26:27 +0530 Subject: [PATCH 3/3] Implement suggestions --- balance/stats_and_plots/weights_stats.py | 9 ++++++--- tests/test_stats_and_plots.py | 14 ++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/balance/stats_and_plots/weights_stats.py b/balance/stats_and_plots/weights_stats.py index 4d453c67d..7d9831710 100644 --- a/balance/stats_and_plots/weights_stats.py +++ b/balance/stats_and_plots/weights_stats.py @@ -262,11 +262,14 @@ def prop_above_and_below( Defaults to True. Returns: - pd.Series | dict: + pd.Series | PropAboveBelowResult | None: If return_as_series is True we get pd.Series with proportions of (normalized weights) that are below/above some numbers, the index indicates which threshold was checked (the values in the index are rounded up to 3 points for printing purposes). - If return_as_series is False we get a dict with 'below' and 'above' with the relevant pd.Series (or None). + If return_as_series is False we get ``PropAboveBelowResult`` with + ``below`` and ``above`` keys whose values are the relevant pd.Series + (or ``None`` when a side is omitted). If both ``below`` and ``above`` + are ``None``, the function returns ``None``. Examples: :: @@ -345,7 +348,7 @@ def prop_above_and_below( # decide if to return one series or a dict if return_as_series: pieces = [s for s in (prop_below_series, prop_above_series) if s is not None] - out = pd.concat(pieces) if pieces else None + out = pd.concat(pieces) else: out = PropAboveBelowResult(below=prop_below_series, above=prop_above_series) diff --git a/tests/test_stats_and_plots.py b/tests/test_stats_and_plots.py index 2b4c898e0..b8e22a05f 100644 --- a/tests/test_stats_and_plots.py +++ b/tests/test_stats_and_plots.py @@ -230,6 +230,20 @@ def test_prop_above_and_below_edge_cases(self) -> None: self.assertEqual(result_empty.to_list(), []) self.assertEqual(result_empty.index.to_list(), []) + # Empty iterables are distinct from omitted groups in dict mode. + result_dict_empty = prop_above_and_below( + weights, + below=(), + above=(), + return_as_series=False, + ) + self.assertIsNotNone(result_dict_empty) + result_dict_empty = _assert_type(result_dict_empty) + self.assertIsNotNone(result_dict_empty["below"]) + self.assertIsNotNone(result_dict_empty["above"]) + self.assertEqual(result_dict_empty["below"].to_list(), []) + self.assertEqual(result_dict_empty["above"].to_list(), []) + # Dict mode should preserve None for omitted threshold groups. result_dict_only_above = prop_above_and_below( weights,