From d3b33616cef96f79704443a6667ed704f76c91b6 Mon Sep 17 00:00:00 2001 From: Uzair Gheewala Date: Sat, 11 Jan 2025 18:18:09 -0800 Subject: [PATCH 1/4] Add test for compare method standard error sorting consistency (#2350) --- arviz/tests/base_tests/test_stats.py | 42 +++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py index 12b19c2e96..8946d4a7d7 100644 --- a/arviz/tests/base_tests/test_stats.py +++ b/arviz/tests/base_tests/test_stats.py @@ -14,7 +14,7 @@ from xarray import DataArray, Dataset from xarray_einstats.stats import XrContinuousRV -from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data +from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data, InferenceData from ...rcparams import rcParams from ...stats import ( apply_test_function, @@ -882,3 +882,43 @@ def test_bayes_factor(): bf_dict1 = bayes_factor(idata, prior=np.random.normal(0, 10, 5000), var_name="a", ref_val=0) assert bf_dict0["BF10"] > bf_dict0["BF01"] assert bf_dict1["BF10"] < bf_dict1["BF01"] + +def test_compare_sorting_consistency(): + chains, draws = 4, 1000 + + # Model 1 - good fit + log_lik1 = np.random.normal(-2, 1, size=(chains, draws)) + posterior1 = Dataset( + {"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))}, + coords={"chain": range(chains), "draw": range(draws)}, + ) + log_like1 = Dataset( + {"y": (("chain", "draw"), log_lik1)}, + coords={"chain": range(chains), "draw": range(draws)}, + ) + data1 = InferenceData(posterior=posterior1, log_likelihood=log_like1) + + # Model 2 - poor fit (higher variance) + log_lik2 = np.random.normal(-5, 2, size=(chains, draws)) + posterior2 = Dataset( + {"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))}, + coords={"chain": range(chains), "draw": range(draws)}, + ) + log_like2 = Dataset( + {"y": (("chain", "draw"), log_lik2)}, + coords={"chain": range(chains), "draw": range(draws)}, + ) + data2 = InferenceData(posterior=posterior2, log_likelihood=log_like2) + + # Compare models in different orders + comp_dict1 = {"M1": data1, "M2": data2} + comp_dict2 = {"M2": data2, "M1": data1} + + comparison1 = compare(comp_dict1, method="bb-pseudo-bma") + comparison2 = compare(comp_dict2, method="bb-pseudo-bma") + + assert comparison1.index.tolist() == comparison2.index.tolist() + + se1 = comparison1["se"].values + se2 = comparison2["se"].values + np.testing.assert_array_almost_equal(se1, se2) \ No newline at end of file From 83d8f6669208228a1773fac82620d2b957c564d5 Mon Sep 17 00:00:00 2001 From: Uzair Gheewala Date: Sat, 11 Jan 2025 18:48:41 -0800 Subject: [PATCH 2/4] Fix pylint issues: remove trailing whitespace and add final newline --- arviz/tests/base_tests/test_stats.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py index 8946d4a7d7..5b2cd3fc6d 100644 --- a/arviz/tests/base_tests/test_stats.py +++ b/arviz/tests/base_tests/test_stats.py @@ -883,7 +883,7 @@ def test_bayes_factor(): assert bf_dict0["BF10"] > bf_dict0["BF01"] assert bf_dict1["BF10"] < bf_dict1["BF01"] -def test_compare_sorting_consistency(): +def test_compare_sorting_consistency(): chains, draws = 4, 1000 # Model 1 - good fit @@ -921,4 +921,4 @@ def test_compare_sorting_consistency(): se1 = comparison1["se"].values se2 = comparison2["se"].values - np.testing.assert_array_almost_equal(se1, se2) \ No newline at end of file + np.testing.assert_array_almost_equal(se1, se2) From 3c7be25c8c0cf1da24c2248d7a2cb59384b8b6f3 Mon Sep 17 00:00:00 2001 From: Uzair Gheewala Date: Sat, 11 Jan 2025 19:06:05 -0800 Subject: [PATCH 3/4] Fix Black formatting: add blank line before function definition --- arviz/tests/base_tests/test_stats.py | 1 + 1 file changed, 1 insertion(+) diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py index 5b2cd3fc6d..bc6f4d0739 100644 --- a/arviz/tests/base_tests/test_stats.py +++ b/arviz/tests/base_tests/test_stats.py @@ -883,6 +883,7 @@ def test_bayes_factor(): assert bf_dict0["BF10"] > bf_dict0["BF01"] assert bf_dict1["BF10"] < bf_dict1["BF01"] + def test_compare_sorting_consistency(): chains, draws = 4, 1000 From 1125f9e83b844e2adf44d184700c02fe487d7840 Mon Sep 17 00:00:00 2001 From: "Oriol (ProDesk)" Date: Tue, 27 May 2025 22:10:17 +0200 Subject: [PATCH 4/4] add to changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 33e49bfa7f..b997ae924a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ - `reference_values` and `labeller` now work together in `plot_pair` ([2437](https://github.com/arviz-devs/arviz/issues/2437)) - Add [`scipy-stubs`](https://github.com/scipy/scipy-stubs) as a development dependency ([2445](https://github.com/arviz-devs/arviz/pull/2445)) +- Test compare dataframe stays consistent independently of input order ([2407](https://github.com/arviz-devs/arviz/pull/2407)) ### Documentation - Added documentation for `reference_values` ([2438](https://github.com/arviz-devs/arviz/pull/2438))