11"""Summaries for various statistics and diagnostics."""
22
33import numpy as np
4+ import pandas as pd
45import xarray as xr
56from arviz_base import dataset_to_dataframe , extract , rcParams , references_to_dataset
67from xarray_einstats import stats
78
9+ from arviz_stats .base .stats_utils import get_decimal_places_from_se , round_num
810from arviz_stats .utils import _apply_multi_input_function
911from arviz_stats .validate import validate_dims
1012
@@ -22,7 +24,7 @@ def summary(
2224 fmt = "wide" ,
2325 ci_prob = None ,
2426 ci_kind = None ,
25- round_to = 2 ,
27+ round_to = "auto" ,
2628 skipna = False ,
2729):
2830 """
@@ -60,8 +62,22 @@ def summary(
6062 ci_kind : {"hdi", "eti"}, optional
6163 Type of credible interval. Defaults to ``rcParams["stats.ci_kind"]``.
6264 If `kind` is stats_median or all_median, `ci_kind` is forced to "eti".
63- round_to : int
64- Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
65+ round_to : int or {"auto", "none"}, optional
66+ Rounding specification. Defaults to "auto". If integer, number of decimal places to
67+ round to. If "none", no rounding is applied. If "auto", and `fmt` is "xarray" defaults to
68+ ``rcParams["stats.round_to"]``. If "auto" and `fmt` is in {"wide", "long"}, applies the
69+ following rounding rules:
70+
71+ * ESS values (ess_bulk, ess_tail, ess_mean, ess_median, min_ss) are rounded down to int
72+ * R-hat always shows 2 digits after the decimal
73+ * If a column *stat* and *mcse_stat* are both present then the mcse is shown to 2
74+ significant figures, and *stat* is shown with precision based on 2*mcse.
75+ * All other floating point numbers are shown following ``rcParams["stats.round_to"]``.
76+ * For all floating point numbers except R-hat, trailing zeros are removed and values are
77+ converted to string for consistent display.
78+
79+ Note: "auto" is intended for display purposes, using it is not recommended when the output
80+ will be used for further numerical computations.
6581 skipna: bool
6682 If true ignores nan values when computing the summary statistics. Defaults to false.
6783
@@ -127,6 +143,10 @@ def summary(
127143 ci_kind = rcParams ["stats.ci_kind" ]
128144 if sample_dims is None :
129145 sample_dims = rcParams ["data.sample_dims" ]
146+ if round_to == "auto" :
147+ round_val = rcParams ["stats.round_to" ]
148+ else :
149+ round_val = round_to
130150
131151 ci_perc = int (ci_prob * 100 )
132152
@@ -224,12 +244,101 @@ def summary(
224244 summary_result = summary_result .to_dataframe ().reset_index ().set_index ("summary" )
225245 summary_result .index = list (summary_result .index )
226246
227- if (round_to is not None ) and (round_to not in ("None" , "none" )):
228- summary_result = summary_result .round (round_to )
247+ if fmt == "xarray" :
248+ if (round_to is not None ) and (round_to not in ("None" , "none" )):
249+ summary_result = xr .apply_ufunc (round_num , summary_result , round_val , vectorize = True )
250+ else :
251+ if round_to == "auto" :
252+ summary_result = _round_summary (summary_result , round_val )
253+ else :
254+ if (round_to is not None ) and (round_to not in ("None" , "none" )):
255+ summary_result = summary_result .map (lambda x : round_num (x , round_val ))
229256
230257 return summary_result
231258
232259
260+ def _round_summary (summary_result , round_val ):
261+ """Apply custom rounding rules to summary statistics.
262+
263+ Parameters
264+ ----------
265+ summary_result : pandas.DataFrame
266+ The summary result to round
267+ round_val : int or str
268+ Number of decimals or significant figures to round to.
269+
270+ Returns
271+ -------
272+ pandas.DataFrame
273+ """
274+ result = summary_result .copy ()
275+ columns = result .columns
276+ rounded_columns = set ()
277+ use_scientific = {}
278+
279+ # Rule 1: ESS columns and min_ss are rounded down to int
280+ ess_cols = [col for col in columns if col .startswith ("ess_" ) or col == "min_ss" ]
281+ for col in ess_cols :
282+ result [col ] = result [col ].apply (lambda x : pd .NA if not np .isfinite (x ) else np .floor (x ))
283+ result [col ] = result [col ].astype ("Int64" )
284+ rounded_columns .add (col )
285+
286+ # Rule 2: R-hat always shows 2 digits after decimal
287+ if "r_hat" in columns :
288+ result ["r_hat" ] = result ["r_hat" ].round (2 )
289+ rounded_columns .add ("r_hat" )
290+
291+ # Rule 3: Handle stat/mcse pairs
292+ stat_se_pairs = []
293+ for col in columns :
294+ if col .startswith ("mcse_" ):
295+ stat_col = col [5 :]
296+ if stat_col in columns :
297+ stat_se_pairs .append ((stat_col , col ))
298+
299+ for stat_col , se_col in stat_se_pairs :
300+ result [se_col ] = result [se_col ].apply (lambda x : round_num (x , round_val ))
301+
302+ for idx in result .index :
303+ stat_val = result .loc [idx , stat_col ]
304+ se_val = result .loc [idx , se_col ]
305+ if not np .isfinite (se_val ):
306+ continue
307+ decimal_places = get_decimal_places_from_se (se_val )
308+ if decimal_places < 0 :
309+ use_scientific [(idx , stat_col )] = True
310+ else :
311+ result .loc [idx , stat_col ] = round_num (stat_val , decimal_places )
312+
313+ rounded_columns .add (stat_col )
314+ rounded_columns .add (se_col )
315+
316+ # Rule 4: Other floating point numbers to round_val significant figures
317+ for col in columns :
318+ if col not in rounded_columns :
319+ if result [col ].dtype .kind == "f" :
320+ result [col ] = result [col ].apply (lambda x : round_num (x , round_val ))
321+
322+ # Rule 5: Format
323+ for col in columns :
324+ if result [col ].dtype .kind == "f" :
325+ if col == "r_hat" :
326+ result [col ] = result [col ].apply (lambda x : f"{ x :.2f} " if np .isfinite (x ) else x )
327+ else :
328+ formatted_values = []
329+ for idx , val in zip (result .index , result [col ]):
330+ if not np .isfinite (val ):
331+ formatted_values .append (val )
332+ elif use_scientific .get ((idx , col ), False ):
333+ formatted_values .append (f"{ val :.0e} " )
334+ else :
335+ formatted_values .append (f"{ val :g} " )
336+
337+ result [col ] = formatted_values
338+
339+ return result
340+
341+
233342def ci_in_rope (
234343 data ,
235344 rope ,
0 commit comments