Skip to content

Improve pair_plot's reference_values compatibility, flexibility, and documentation #2438

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
## Unreleased

### New features
- `plot_pair` now has more flexible support for `reference_values` ([2438](https://github.com/arviz-devs/arviz/pull/2438))

### Maintenance and fixes
- `reference_values` and `labeller` now work together in `plot_pair` ([2437](https://github.com/arviz-devs/arviz/issues/2437))

### Documentation
- Added documentation for `reference_values` ([2438](https://github.com/arviz-devs/arviz/pull/2438))

## v0.21.0 (2025 Mar 06)

Expand Down
66 changes: 18 additions & 48 deletions arviz/plots/backends/bokeh/pairplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
diverging_mask,
divergences_kwargs,
flat_var_names,
flat_ref_slices,
flat_var_labels,
backend_kwargs,
marginal_kwargs,
show,
Expand Down Expand Up @@ -72,50 +74,12 @@
kde_kwargs["contour_kwargs"].setdefault("line_alpha", 1)

if reference_values:
reference_values_copy = {}
label = []
for variable in list(reference_values.keys()):
if " " in variable:
variable_copy = variable.replace(" ", "\n", 1)
else:
variable_copy = variable

label.append(variable_copy)
reference_values_copy[variable_copy] = reference_values[variable]

difference = set(flat_var_names).difference(set(label))

if difference:
warn = [diff.replace("\n", " ", 1) for diff in difference]
warnings.warn(
"Argument reference_values does not include reference value for: {}".format(
", ".join(warn)
),
UserWarning,
)

if reference_values:
reference_values_copy = {}
label = []
for variable in list(reference_values.keys()):
if " " in variable:
variable_copy = variable.replace(" ", "\n", 1)
else:
variable_copy = variable

label.append(variable_copy)
reference_values_copy[variable_copy] = reference_values[variable]

difference = set(flat_var_names).difference(set(label))

for dif in difference:
reference_values_copy[dif] = None
difference = set(flat_var_names).difference(set(reference_values.keys()))

if difference:
warn = [dif.replace("\n", " ", 1) for dif in difference]
warnings.warn(
"Argument reference_values does not include reference value for: {}".format(
", ".join(warn)
", ".join(difference)
),
UserWarning,
)
Expand Down Expand Up @@ -262,8 +226,8 @@
**marginal_kwargs,
)

ax[j, i].xaxis.axis_label = flat_var_names[i]
ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset]
ax[j, i].xaxis.axis_label = flat_var_labels[i]
ax[j, i].yaxis.axis_label = flat_var_labels[j + marginals_offset]

Check warning on line 230 in arviz/plots/backends/bokeh/pairplot.py

View check run for this annotation

Codecov / codecov/patch

arviz/plots/backends/bokeh/pairplot.py#L229-L230

Added lines #L229 - L230 were not covered by tests

elif j + marginals_offset > i:
if "scatter" in kind:
Expand Down Expand Up @@ -346,12 +310,18 @@
ax[-1, -1].add_layout(ax_pe_hline)

if reference_values:
x = reference_values_copy[flat_var_names[j + marginals_offset]]
y = reference_values_copy[flat_var_names[i]]
if x and y:
ax[j, i].scatter(y, x, **reference_values_kwargs)
ax[j, i].xaxis.axis_label = flat_var_names[i]
ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset]
x_name = flat_var_names[j + marginals_offset]
y_name = flat_var_names[i]
if (x_name not in difference) and (y_name not in difference):
ax[j, i].scatter(
np.array(reference_values[y_name])[flat_ref_slices[i]],
np.array(reference_values[x_name])[
flat_ref_slices[j + marginals_offset]
],
**reference_values_kwargs,
)
ax[j, i].xaxis.axis_label = flat_var_labels[i]
ax[j, i].yaxis.axis_label = flat_var_labels[j + marginals_offset]

show_layout(ax, show)

Expand Down
36 changes: 14 additions & 22 deletions arviz/plots/backends/matplotlib/pairplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def plot_pair(
diverging_mask,
divergences_kwargs,
flat_var_names,
flat_ref_slices,
flat_var_labels,
backend_kwargs,
marginal_kwargs,
show,
Expand Down Expand Up @@ -77,24 +79,12 @@ def plot_pair(
kde_kwargs["contour_kwargs"].setdefault("colors", "k")

if reference_values:
reference_values_copy = {}
label = []
for variable in list(reference_values.keys()):
if " " in variable:
variable_copy = variable.replace(" ", "\n", 1)
else:
variable_copy = variable

label.append(variable_copy)
reference_values_copy[variable_copy] = reference_values[variable]

difference = set(flat_var_names).difference(set(label))
difference = set(flat_var_names).difference(set(reference_values.keys()))

if difference:
warn = [diff.replace("\n", " ", 1) for diff in difference]
warnings.warn(
"Argument reference_values does not include reference value for: {}".format(
", ".join(warn)
", ".join(difference)
),
UserWarning,
)
Expand Down Expand Up @@ -211,12 +201,12 @@ def plot_pair(

if reference_values:
ax.plot(
reference_values_copy[flat_var_names[0]],
reference_values_copy[flat_var_names[1]],
np.array(reference_values[flat_var_names[0]])[flat_ref_slices[0]],
np.array(reference_values[flat_var_names[1]])[flat_ref_slices[1]],
**reference_values_kwargs,
)
ax.set_xlabel(f"{flat_var_names[0]}", fontsize=ax_labelsize, wrap=True)
ax.set_ylabel(f"{flat_var_names[1]}", fontsize=ax_labelsize, wrap=True)
ax.set_xlabel(f"{flat_var_labels[0]}", fontsize=ax_labelsize, wrap=True)
ax.set_ylabel(f"{flat_var_labels[1]}", fontsize=ax_labelsize, wrap=True)
ax.tick_params(labelsize=xt_labelsize)

else:
Expand Down Expand Up @@ -336,20 +326,22 @@ def plot_pair(
y_name = flat_var_names[j + not_marginals]
if (x_name not in difference) and (y_name not in difference):
ax[j, i].plot(
reference_values_copy[x_name],
reference_values_copy[y_name],
np.array(reference_values[x_name])[flat_ref_slices[i]],
np.array(reference_values[y_name])[
flat_ref_slices[j + not_marginals]
],
**reference_values_kwargs,
)

if j != vars_to_plot - 1:
plt.setp(ax[j, i].get_xticklabels(), visible=False)
else:
ax[j, i].set_xlabel(f"{flat_var_names[i]}", fontsize=ax_labelsize, wrap=True)
ax[j, i].set_xlabel(f"{flat_var_labels[i]}", fontsize=ax_labelsize, wrap=True)
if i != 0:
plt.setp(ax[j, i].get_yticklabels(), visible=False)
else:
ax[j, i].set_ylabel(
f"{flat_var_names[j + not_marginals]}",
f"{flat_var_labels[j + not_marginals]}",
fontsize=ax_labelsize,
wrap=True,
)
Expand Down
13 changes: 10 additions & 3 deletions arviz/plots/pairplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,14 @@ def plot_pair(
get_coords(dataset, coords), var_names=var_names, skip_dims=combine_dims, combined=True
)
)
flat_var_names = [
labeller.make_label_vert(var_name, sel, isel) for var_name, sel, isel, _ in plotters
]
flat_var_names = []
flat_ref_slices = []
flat_var_labels = []
for var_name, sel, isel, _ in plotters:
dims = [dim for dim in dataset[var_name].dims if dim not in ["chain", "draw"]]
flat_var_names.append(var_name)
flat_ref_slices.append(tuple(isel[dim] if dim in isel else slice(None) for dim in dims))
flat_var_labels.append(labeller.make_label_vert(var_name, sel, isel))

divergent_data = None
diverging_mask = None
Expand Down Expand Up @@ -253,6 +258,8 @@ def plot_pair(
diverging_mask=diverging_mask,
divergences_kwargs=divergences_kwargs,
flat_var_names=flat_var_names,
flat_ref_slices=flat_ref_slices,
flat_var_labels=flat_var_labels,
backend_kwargs=backend_kwargs,
marginal_kwargs=marginal_kwargs,
show=show,
Expand Down
17 changes: 16 additions & 1 deletion arviz/tests/base_tests/test_plots_bokeh.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from scipy.stats import norm # pylint: disable=wrong-import-position

from ...data import from_dict, load_arviz_data # pylint: disable=wrong-import-position
from ...labels import MapLabeller # pylint: disable=wrong-import-position
from ...plots import ( # pylint: disable=wrong-import-position
plot_autocorr,
plot_bpv,
Expand Down Expand Up @@ -773,7 +774,6 @@ def test_plot_mcse_no_divergences(models):
{"divergences": True, "var_names": ["theta", "mu"]},
{"kind": "kde", "var_names": ["theta"]},
{"kind": "hexbin", "var_names": ["theta"]},
{"kind": "hexbin", "var_names": ["theta"]},
{
"kind": "hexbin",
"var_names": ["theta"],
Expand All @@ -785,6 +785,21 @@ def test_plot_mcse_no_divergences(models):
"reference_values": {"mu": 0, "tau": 0},
"reference_values_kwargs": {"line_color": "blue"},
},
{
"var_names": ["mu", "tau"],
"reference_values": {"mu": 0, "tau": 0},
"labeller": MapLabeller({"mu": r"$\mu$", "theta": r"$\theta"}),
},
{
"var_names": ["theta"],
"reference_values": {"theta": [0.0] * 8},
"labeller": MapLabeller({"theta": r"$\theta$"}),
},
{
"var_names": ["theta"],
"reference_values": {"theta": np.zeros(8)},
"labeller": MapLabeller({"theta": r"$\theta$"}),
},
],
)
def test_plot_pair(models, kwargs):
Expand Down
16 changes: 16 additions & 0 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from scipy.stats import gaussian_kde, norm

from ...data import from_dict, load_arviz_data
from ...labels import MapLabeller
from ...plots import (
plot_autocorr,
plot_bf,
Expand Down Expand Up @@ -599,6 +600,21 @@ def test_plot_kde_inference_data(models):
"reference_values": {"mu": 0, "tau": 0},
"reference_values_kwargs": {"c": "C0", "marker": "*"},
},
{
"var_names": ["mu", "tau"],
"reference_values": {"mu": 0, "tau": 0},
"labeller": MapLabeller({"mu": r"$\mu$", "theta": r"$\theta"}),
},
{
"var_names": ["theta"],
"reference_values": {"theta": [0.0] * 8},
"labeller": MapLabeller({"theta": r"$\theta$"}),
},
{
"var_names": ["theta"],
"reference_values": {"theta": np.zeros(8)},
"labeller": MapLabeller({"theta": r"$\theta$"}),
},
],
)
def test_plot_pair(models, kwargs):
Expand Down
25 changes: 25 additions & 0 deletions doc/source/user_guide/plots_arguments_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,28 @@ These are kwargs specific to the backend being used, passed to `matplotlib.pyplo
## `show`

Call backend show function.

(common_reference_values)=
## `reference_values`
`plot_pair` accepts `reference_values` to highlight specific values on the probability distributions. The keys of `reference_values` are the associated variable names in `var_names`. The values are the reference values, which must have the same shape as the coordinates selected for plotting since it is indexed as such. For example, here `theta` must have shape `(2,)` since that is the shape of the selected coordinates on `theta`.

```{code-cell} ipython3
coords = {"school": ["Choate", "Deerfield"]}
reference_values = {
"mu": 0.0,
"theta": np.zeros(2),
}
az.plot_pair(centered_eight, var_names=["mu", "theta"], coords=coords, reference_values=reference_values);
```

When used with `combine_dims`, each reference value along the combined dimension is plotted on the same axis.
```{code-cell} ipython3
coords = {"school": ["Choate", "Deerfield"]}
reference_values = {
"theta": [-5.0, 5.0],
"theta_t": [-2.0, 2.0],
}
az.plot_pair(non_centered_eight, var_names=["theta", "theta_t"], coords=coords, reference_values=reference_values, combine_dims={"school"});
```

The values of the `reference_values` dictionary can be scalars (e.g., `0`) or zero-dimensional `numpy` arrays (e.g., `np.array(0)`) for scalar variables, or anything that can be cast to `np.array` (e.g., `[0.0, 0.0]` or `np.array([0.0, 0.0])`) for multi-dimensional variables.
26 changes: 26 additions & 0 deletions examples/bokeh/bokeh_plot_pair_reference_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
Pairplot with Reference Values
==============================
"""

import arviz as az
import numpy as np

data = az.load_arviz_data("centered_eight")

coords = {"school": ["Choate", "Deerfield"]}
reference_values = {
"mu": 0.0,
"theta": np.zeros(2),
}

ax = az.plot_pair(
data,
var_names=["mu", "theta"],
kind=["scatter", "kde"],
kde_kwargs={"fill_last": False},
coords=coords,
reference_values=reference_values,
figsize=(11.5, 5),
backend="bokeh",
)
32 changes: 32 additions & 0 deletions examples/matplotlib/mpl_plot_pair_reference_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
Pairplot with Reference Values
==============================
_gallery_category: Distributions
"""

import matplotlib.pyplot as plt

import arviz as az
import numpy as np

az.style.use("arviz-doc")

data = az.load_arviz_data("centered_eight")

coords = {"school": ["Choate", "Deerfield"]}
reference_values = {
"mu": 0.0,
"theta": np.zeros(2),
}

ax = az.plot_pair(
data,
var_names=["mu", "theta"],
kind=["scatter", "kde"],
kde_kwargs={"fill_last": False},
coords=coords,
reference_values=reference_values,
figsize=(11.5, 5),
)

plt.show()