Skip to content

Commit 1fadd98

Browse files
authored
Use custom rounding rules for summary (#294)
* Use custom rounding rules for summary * use scientific * fix tests * add/fix tests * fix test for pandas 3.0 * clarify round_to * clarify round_to * improve format docstring and wording * fix indetantion
1 parent b7afa80 commit 1fadd98

File tree

3 files changed

+157
-6
lines changed

3 files changed

+157
-6
lines changed

src/arviz_stats/base/stats_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,3 +378,10 @@ def round_num(value, precision):
378378
return round(value, sig_digits - int(np.floor(np.log10(abs(value)))) - 1)
379379

380380
return value
381+
382+
383+
def get_decimal_places_from_se(se_val):
384+
"""Get number of decimal places from standard error value."""
385+
two_se = 2 * se_val
386+
se_magnitude = np.floor(np.log10(np.abs(two_se))) if two_se != 0 else 0
387+
return -int(se_magnitude)

src/arviz_stats/summary.py

Lines changed: 114 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""Summaries for various statistics and diagnostics."""
22

33
import numpy as np
4+
import pandas as pd
45
import xarray as xr
56
from arviz_base import dataset_to_dataframe, extract, rcParams, references_to_dataset
67
from xarray_einstats import stats
78

9+
from arviz_stats.base.stats_utils import get_decimal_places_from_se, round_num
810
from arviz_stats.utils import _apply_multi_input_function
911
from arviz_stats.validate import validate_dims
1012

@@ -22,7 +24,7 @@ def summary(
2224
fmt="wide",
2325
ci_prob=None,
2426
ci_kind=None,
25-
round_to=2,
27+
round_to="auto",
2628
skipna=False,
2729
):
2830
"""
@@ -60,8 +62,22 @@ def summary(
6062
ci_kind : {"hdi", "eti"}, optional
6163
Type of credible interval. Defaults to ``rcParams["stats.ci_kind"]``.
6264
If `kind` is stats_median or all_median, `ci_kind` is forced to "eti".
63-
round_to : int
64-
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
65+
round_to : int or {"auto", "none"}, optional
66+
Rounding specification. Defaults to "auto". If integer, number of decimal places to
67+
round to. If "none", no rounding is applied. If "auto", and `fmt` is "xarray" defaults to
68+
``rcParams["stats.round_to"]``. If "auto" and `fmt` is in {"wide", "long"}, applies the
69+
following rounding rules:
70+
71+
* ESS values (ess_bulk, ess_tail, ess_mean, ess_median, min_ss) are rounded down to int
72+
* R-hat always shows 2 digits after the decimal
73+
* If a column *stat* and *mcse_stat* are both present then the mcse is shown to 2
74+
significant figures, and *stat* is shown with precision based on 2*mcse.
75+
* All other floating point numbers are shown following ``rcParams["stats.round_to"]``.
76+
* For all floating point numbers except R-hat, trailing zeros are removed and values are
77+
converted to string for consistent display.
78+
79+
Note: "auto" is intended for display purposes, using it is not recommended when the output
80+
will be used for further numerical computations.
6581
skipna: bool
6682
If true ignores nan values when computing the summary statistics. Defaults to false.
6783
@@ -127,6 +143,10 @@ def summary(
127143
ci_kind = rcParams["stats.ci_kind"]
128144
if sample_dims is None:
129145
sample_dims = rcParams["data.sample_dims"]
146+
if round_to == "auto":
147+
round_val = rcParams["stats.round_to"]
148+
else:
149+
round_val = round_to
130150

131151
ci_perc = int(ci_prob * 100)
132152

@@ -224,12 +244,101 @@ def summary(
224244
summary_result = summary_result.to_dataframe().reset_index().set_index("summary")
225245
summary_result.index = list(summary_result.index)
226246

227-
if (round_to is not None) and (round_to not in ("None", "none")):
228-
summary_result = summary_result.round(round_to)
247+
if fmt == "xarray":
248+
if (round_to is not None) and (round_to not in ("None", "none")):
249+
summary_result = xr.apply_ufunc(round_num, summary_result, round_val, vectorize=True)
250+
else:
251+
if round_to == "auto":
252+
summary_result = _round_summary(summary_result, round_val)
253+
else:
254+
if (round_to is not None) and (round_to not in ("None", "none")):
255+
summary_result = summary_result.map(lambda x: round_num(x, round_val))
229256

230257
return summary_result
231258

232259

260+
def _round_summary(summary_result, round_val):
261+
"""Apply custom rounding rules to summary statistics.
262+
263+
Parameters
264+
----------
265+
summary_result : pandas.DataFrame
266+
The summary result to round
267+
round_val : int or str
268+
Number of decimals or significant figures to round to.
269+
270+
Returns
271+
-------
272+
pandas.DataFrame
273+
"""
274+
result = summary_result.copy()
275+
columns = result.columns
276+
rounded_columns = set()
277+
use_scientific = {}
278+
279+
# Rule 1: ESS columns and min_ss are rounded down to int
280+
ess_cols = [col for col in columns if col.startswith("ess_") or col == "min_ss"]
281+
for col in ess_cols:
282+
result[col] = result[col].apply(lambda x: pd.NA if not np.isfinite(x) else np.floor(x))
283+
result[col] = result[col].astype("Int64")
284+
rounded_columns.add(col)
285+
286+
# Rule 2: R-hat always shows 2 digits after decimal
287+
if "r_hat" in columns:
288+
result["r_hat"] = result["r_hat"].round(2)
289+
rounded_columns.add("r_hat")
290+
291+
# Rule 3: Handle stat/mcse pairs
292+
stat_se_pairs = []
293+
for col in columns:
294+
if col.startswith("mcse_"):
295+
stat_col = col[5:]
296+
if stat_col in columns:
297+
stat_se_pairs.append((stat_col, col))
298+
299+
for stat_col, se_col in stat_se_pairs:
300+
result[se_col] = result[se_col].apply(lambda x: round_num(x, round_val))
301+
302+
for idx in result.index:
303+
stat_val = result.loc[idx, stat_col]
304+
se_val = result.loc[idx, se_col]
305+
if not np.isfinite(se_val):
306+
continue
307+
decimal_places = get_decimal_places_from_se(se_val)
308+
if decimal_places < 0:
309+
use_scientific[(idx, stat_col)] = True
310+
else:
311+
result.loc[idx, stat_col] = round_num(stat_val, decimal_places)
312+
313+
rounded_columns.add(stat_col)
314+
rounded_columns.add(se_col)
315+
316+
# Rule 4: Other floating point numbers to round_val significant figures
317+
for col in columns:
318+
if col not in rounded_columns:
319+
if result[col].dtype.kind == "f":
320+
result[col] = result[col].apply(lambda x: round_num(x, round_val))
321+
322+
# Rule 5: Format
323+
for col in columns:
324+
if result[col].dtype.kind == "f":
325+
if col == "r_hat":
326+
result[col] = result[col].apply(lambda x: f"{x:.2f}" if np.isfinite(x) else x)
327+
else:
328+
formatted_values = []
329+
for idx, val in zip(result.index, result[col]):
330+
if not np.isfinite(val):
331+
formatted_values.append(val)
332+
elif use_scientific.get((idx, col), False):
333+
formatted_values.append(f"{val:.0e}")
334+
else:
335+
formatted_values.append(f"{val:g}")
336+
337+
result[col] = formatted_values
338+
339+
return result
340+
341+
233342
def ci_in_rope(
234343
data,
235344
rope,

tests/test_summary.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
from .helpers import importorskip
1111

1212
azb = importorskip("arviz_base")
13+
pd = importorskip("pandas")
1314
xr = importorskip("xarray")
1415

1516
from arviz_stats import ci_in_rope, eti, hdi, mode, qds, summary
17+
from arviz_stats.summary import _round_summary
1618

1719

1820
def test_summary_ndarray():
@@ -396,7 +398,7 @@ def test_mode_single_value_array():
396398
def test_summary_zero_variance():
397399
array = np.ones((4, 100, 2))
398400
summary_df = summary(array)
399-
assert summary_df["sd"].iloc[0] == 0.0
401+
assert summary_df["sd"].iloc[0] == "0"
400402

401403

402404
@pytest.mark.parametrize("prob", [0.5, 0.89, 0.95])
@@ -423,3 +425,36 @@ def test_summary_fmt(datatree, fmt):
423425
else:
424426
assert isinstance(result, xr.Dataset)
425427
assert "summary" in result.dims
428+
429+
430+
def test_round_summary():
431+
labels = ["a", "bb", "ccc", "d", "e"]
432+
data = {
433+
"mean": [111.11, 1.2345e-6, 5.4321e8, np.inf, np.nan],
434+
"mcse_mean": [0.0012345, 5.432e-5, 2.1234e5, np.inf, np.nan],
435+
"sd": [0.0012345, 5.432e-5, 2.1234e5, np.inf, np.nan],
436+
"r_hat": [1.009, 1.011, 0.99, np.inf, np.nan],
437+
"ess_bulk": [312.45, 23.32, 1011.98, np.inf, np.nan],
438+
"ess_tail": [9.2345, 876.321, 999.99, np.inf, np.nan],
439+
}
440+
df = pd.DataFrame(data, index=labels)
441+
result = _round_summary(df, round_val=2)
442+
443+
assert result["ess_bulk"].dtype == "Int64"
444+
assert result["ess_tail"].dtype == "Int64"
445+
expected_ess_bulk = pd.Series(
446+
[312, 23, 1011, pd.NA, pd.NA], index=labels, dtype="Int64", name="ess_bulk"
447+
)
448+
pd.testing.assert_series_equal(result["ess_bulk"], expected_ess_bulk)
449+
expected_ess_tail = pd.Series(
450+
[9, 876, 999, pd.NA, pd.NA], index=labels, dtype="Int64", name="ess_tail"
451+
)
452+
pd.testing.assert_series_equal(result["ess_tail"], expected_ess_tail)
453+
expected_r_hat = pd.Series(["1.01", "1.01", "0.99", np.inf, np.nan], index=labels, name="r_hat")
454+
pd.testing.assert_series_equal(result["r_hat"], expected_r_hat, check_dtype=False)
455+
expected_mcse = pd.Series(["0", "0", "212340", np.inf, np.nan], index=labels, name="mcse_mean")
456+
pd.testing.assert_series_equal(result["mcse_mean"], expected_mcse, check_dtype=False)
457+
expected_mean = pd.Series(["111", "0", "5e+08"], index=labels[:3], name="mean")
458+
pd.testing.assert_series_equal(result.loc[labels[:3], "mean"], expected_mean, check_dtype=False)
459+
expected_sd = pd.Series(["0", "0", "212340"], index=labels[:3], name="sd")
460+
pd.testing.assert_series_equal(result.loc[labels[:3], "sd"], expected_sd, check_dtype=False)

0 commit comments

Comments
 (0)