Skip to content

Commit 2e4cfcf

Browse files
Fix plot_lm handling of multidimensional data with y_model (#2408)
* Fix plot_lm handling of multidimensional data with y_model * Fix pylint, black issues * Fix bokeh backend color handling in plot_lm HDI plots * Fix pylint * Fix bokeh backend color handling in plot_hdi HDI plots * Skip edge case plotting for HDI case in bokeh/lmplot * use fixtures for tests * black * add to changelog --------- Co-authored-by: Oriol (ProDesk) <[email protected]>
1 parent 3205b82 commit 2e4cfcf

File tree

7 files changed

+182
-33
lines changed

7 files changed

+182
-33
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
### Maintenance and fixes
1111
- `reference_values` and `labeller` now work together in `plot_pair` ([2437](https://github.com/arviz-devs/arviz/issues/2437))
12-
12+
- Fix `plot_lm` for multidimensional data ([2408](https://github.com/arviz-devs/arviz/issues/2408))
1313
- Add [`scipy-stubs`](https://github.com/scipy/scipy-stubs) as a development dependency ([2445](https://github.com/arviz-devs/arviz/pull/2445))
1414

1515
### Documentation

arviz/plots/backends/bokeh/hdiplot.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,13 @@ def plot_hdi(ax, x_data, y_data, color, figsize, plot_kwargs, fill_kwargs, backe
2121
plot_kwargs["color"] = vectorized_to_hex(plot_kwargs.get("color", color))
2222
plot_kwargs.setdefault("alpha", 0)
2323

24-
fill_kwargs = {} if fill_kwargs is None else fill_kwargs
25-
fill_kwargs["color"] = vectorized_to_hex(fill_kwargs.get("color", color))
26-
fill_kwargs.setdefault("alpha", 0.5)
24+
fill_kwargs = {} if fill_kwargs is None else fill_kwargs.copy()
25+
# Convert matplotlib color to bokeh fill_color if needed
26+
if "color" in fill_kwargs and "fill_color" not in fill_kwargs:
27+
fill_kwargs["fill_color"] = vectorized_to_hex(fill_kwargs.pop("color"))
28+
else:
29+
fill_kwargs["fill_color"] = vectorized_to_hex(fill_kwargs.get("fill_color", color))
30+
fill_kwargs.setdefault("fill_alpha", fill_kwargs.pop("alpha", 0.5))
2731

2832
figsize, *_ = _scale_fig_size(figsize, None)
2933

@@ -38,9 +42,6 @@ def plot_hdi(ax, x_data, y_data, color, figsize, plot_kwargs, fill_kwargs, backe
3842
plot_kwargs.setdefault("line_color", plot_kwargs.pop("color"))
3943
plot_kwargs.setdefault("line_alpha", plot_kwargs.pop("alpha", 0))
4044

41-
fill_kwargs.setdefault("fill_color", fill_kwargs.pop("color"))
42-
fill_kwargs.setdefault("fill_alpha", fill_kwargs.pop("alpha", 0))
43-
4445
ax.patch(
4546
np.concatenate((x_data, x_data[::-1])),
4647
np.concatenate((y_data[:, 0], y_data[:, 1][::-1])),

arviz/plots/backends/bokeh/lmplot.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,13 @@ def plot_lm(
6868

6969
if y_hat_fill_kwargs is None:
7070
y_hat_fill_kwargs = {}
71-
y_hat_fill_kwargs.setdefault("color", "orange")
71+
else:
72+
y_hat_fill_kwargs = y_hat_fill_kwargs.copy()
73+
# Convert matplotlib color to bokeh fill_color if needed
74+
if "color" in y_hat_fill_kwargs and "fill_color" not in y_hat_fill_kwargs:
75+
y_hat_fill_kwargs["fill_color"] = y_hat_fill_kwargs.pop("color")
76+
y_hat_fill_kwargs.setdefault("fill_color", "orange")
77+
y_hat_fill_kwargs.setdefault("fill_alpha", 0.5)
7278

7379
if y_model_plot_kwargs is None:
7480
y_model_plot_kwargs = {}
@@ -78,8 +84,13 @@ def plot_lm(
7884

7985
if y_model_fill_kwargs is None:
8086
y_model_fill_kwargs = {}
81-
y_model_fill_kwargs.setdefault("color", "black")
82-
y_model_fill_kwargs.setdefault("alpha", 0.5)
87+
else:
88+
y_model_fill_kwargs = y_model_fill_kwargs.copy()
89+
# Convert matplotlib color to bokeh fill_color if needed
90+
if "color" in y_model_fill_kwargs and "fill_color" not in y_model_fill_kwargs:
91+
y_model_fill_kwargs["fill_color"] = y_model_fill_kwargs.pop("color")
92+
y_model_fill_kwargs.setdefault("fill_color", "black")
93+
y_model_fill_kwargs.setdefault("fill_alpha", 0.5)
8394

8495
if y_model_mean_kwargs is None:
8596
y_model_mean_kwargs = {}
@@ -149,6 +160,11 @@ def plot_lm(
149160
)
150161

151162
y_model_mean = np.mean(y_model_plotters, axis=(0, 1))
163+
# Plot mean line across all x values instead of just edges
164+
mean_legend = ax_i.line(x_plotters, y_model_mean, **y_model_mean_kwargs)
165+
legend_it.append(("Mean", [mean_legend]))
166+
continue # Skip the edge plotting since we plotted full line
167+
152168
x_plotters_edge = [min(x_plotters), max(x_plotters)]
153169
y_model_mean_edge = [min(y_model_mean), max(y_model_mean)]
154170
mean_legend = ax_i.line(x_plotters_edge, y_model_mean_edge, **y_model_mean_kwargs)

arviz/plots/backends/matplotlib/lmplot.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,23 +115,29 @@ def plot_lm(
115115

116116
if y_model is not None:
117117
_, _, _, y_model_plotters = y_model[i]
118+
118119
if kind_model == "lines":
119-
for j in range(num_samples):
120-
ax_i.plot(x_plotters, y_model_plotters[..., j], **y_model_plot_kwargs)
121-
ax_i.plot([], **y_model_plot_kwargs, label="Uncertainty in mean")
120+
# y_model_plotters should be (points, samples)
121+
y_points = y_model_plotters.shape[0]
122+
if x_plotters.shape[0] == y_points:
123+
for j in range(num_samples):
124+
ax_i.plot(x_plotters, y_model_plotters[:, j], **y_model_plot_kwargs)
125+
126+
ax_i.plot([], **y_model_plot_kwargs, label="Uncertainty in mean")
127+
y_model_mean = np.mean(y_model_plotters, axis=1)
128+
ax_i.plot(x_plotters, y_model_mean, **y_model_mean_kwargs, label="Mean")
122129

123-
y_model_mean = np.mean(y_model_plotters, axis=1)
124130
else:
125131
plot_hdi(
126132
x_plotters,
127133
y_model_plotters,
128134
fill_kwargs=y_model_fill_kwargs,
129135
ax=ax_i,
130136
)
131-
ax_i.plot([], color=y_model_fill_kwargs["color"], label="Uncertainty in mean")
132137

133-
y_model_mean = np.mean(y_model_plotters, axis=(0, 1))
134-
ax_i.plot(x_plotters, y_model_mean, **y_model_mean_kwargs, label="Mean")
138+
ax_i.plot([], color=y_model_fill_kwargs["color"], label="Uncertainty in mean")
139+
y_model_mean = np.mean(y_model_plotters, axis=0)
140+
ax_i.plot(x_plotters, y_model_mean, **y_model_mean_kwargs, label="Mean")
135141

136142
if legend:
137143
ax_i.legend(fontsize=xt_labelsize, loc="upper left")

arviz/plots/lmplot.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -300,20 +300,47 @@ def plot_lm(
300300
# Filter out the required values to generate plotters
301301
if y_model is not None:
302302
if kind_model == "lines":
303-
y_model = y_model.stack(__sample__=("chain", "draw"))[..., pp_sample_ix]
304-
305-
y_model = [
306-
tup
307-
for _, tup in zip(
308-
range(len_y),
309-
xarray_var_iter(
310-
y_model,
311-
skip_dims=set(y_model.dims),
312-
combined=True,
313-
),
314-
)
315-
]
316-
y_model = _repeat_flatten_list(y_model, len_x)
303+
var_name = y_model.name if y_model.name else "y_model"
304+
data = y_model.values
305+
306+
total_samples = data.shape[0] * data.shape[1]
307+
data = data.reshape(total_samples, *data.shape[2:])
308+
309+
if pp_sample_ix is not None:
310+
data = data[pp_sample_ix]
311+
312+
if plot_dim is not None:
313+
# For plot_dim case, transpose to get dimension first
314+
data = data.transpose(1, 0, 2)[..., 0]
315+
316+
# Create plotter tuple(s)
317+
if plot_dim is not None:
318+
y_model = [(var_name, {}, {}, data) for _ in range(length_plotters)]
319+
else:
320+
y_model = [(var_name, {}, {}, data)]
321+
y_model = _repeat_flatten_list(y_model, len_x)
322+
323+
elif kind_model == "hdi":
324+
var_name = y_model.name if y_model.name else "y_model"
325+
data = y_model.values
326+
327+
if plot_dim is not None:
328+
# First transpose to get plot_dim first
329+
data = data.transpose(2, 0, 1, 3)
330+
# For plot_dim case, we just want HDI for first dimension
331+
data = data[..., 0]
332+
333+
# Reshape to (samples, points)
334+
data = data.transpose(1, 2, 0).reshape(-1, data.shape[0])
335+
y_model = [(var_name, {}, {}, data) for _ in range(length_plotters)]
336+
337+
else:
338+
data = data.reshape(-1, data.shape[-1])
339+
y_model = [(var_name, {}, {}, data)]
340+
y_model = _repeat_flatten_list(y_model, len_x)
341+
342+
if len(y_model) == 1:
343+
y_model = _repeat_flatten_list(y_model, len_x)
317344

318345
rows, cols = default_grid(length_plotters)
319346

arviz/tests/base_tests/test_plots_bokeh.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1216,7 +1216,7 @@ def test_plot_dot_rotated(continuous_model, kwargs):
12161216
},
12171217
],
12181218
)
1219-
def test_plot_lm(models, kwargs):
1219+
def test_plot_lm_1d(models, kwargs):
12201220
"""Test functionality for 1D data."""
12211221
idata = models.model_1
12221222
if "constant_data" not in idata.groups():
@@ -1243,3 +1243,46 @@ def test_plot_lm_list():
12431243
"""Test the plots when input data is list or ndarray."""
12441244
y = [1, 2, 3, 4, 5]
12451245
assert plot_lm(y=y, x=np.arange(len(y)), show=False, backend="bokeh")
1246+
1247+
1248+
def generate_lm_1d_data():
1249+
rng = np.random.default_rng()
1250+
return from_dict(
1251+
observed_data={"y": rng.normal(size=7)},
1252+
posterior_predictive={"y": rng.normal(size=(4, 1000, 7)) / 2},
1253+
posterior={"y_model": rng.normal(size=(4, 1000, 7))},
1254+
dims={"y": ["dim1"]},
1255+
coords={"dim1": range(7)},
1256+
)
1257+
1258+
1259+
def generate_lm_2d_data():
1260+
rng = np.random.default_rng()
1261+
return from_dict(
1262+
observed_data={"y": rng.normal(size=(5, 7))},
1263+
posterior_predictive={"y": rng.normal(size=(4, 1000, 5, 7)) / 2},
1264+
posterior={"y_model": rng.normal(size=(4, 1000, 5, 7))},
1265+
dims={"y": ["dim1", "dim2"]},
1266+
coords={"dim1": range(5), "dim2": range(7)},
1267+
)
1268+
1269+
1270+
@pytest.mark.parametrize("data", ("1d", "2d"))
1271+
@pytest.mark.parametrize("kind", ("lines", "hdi"))
1272+
@pytest.mark.parametrize("use_y_model", (True, False))
1273+
def test_plot_lm(data, kind, use_y_model):
1274+
if data == "1d":
1275+
idata = generate_lm_1d_data()
1276+
else:
1277+
idata = generate_lm_2d_data()
1278+
1279+
kwargs = {"idata": idata, "y": "y", "kind_model": kind, "backend": "bokeh", "show": False}
1280+
if data == "2d":
1281+
kwargs["plot_dim"] = "dim1"
1282+
if use_y_model:
1283+
kwargs["y_model"] = "y_model"
1284+
if kind == "lines":
1285+
kwargs["num_samples"] = 50
1286+
1287+
ax = plot_lm(**kwargs)
1288+
assert ax is not None

arviz/tests/base_tests/test_plots_matplotlib.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1930,7 +1930,7 @@ def test_wilkinson_algorithm(continuous_model):
19301930
},
19311931
],
19321932
)
1933-
def test_plot_lm(models, kwargs):
1933+
def test_plot_lm_1d(models, kwargs):
19341934
"""Test functionality for 1D data."""
19351935
idata = models.model_1
19361936
if "constant_data" not in idata.groups():
@@ -2118,3 +2118,59 @@ def test_plot_bf():
21182118
)
21192119
_, bf_plot = plot_bf(idata, var_name="a", ref_val=0)
21202120
assert bf_plot is not None
2121+
2122+
2123+
def generate_lm_1d_data():
2124+
rng = np.random.default_rng()
2125+
return from_dict(
2126+
observed_data={"y": rng.normal(size=7)},
2127+
posterior_predictive={"y": rng.normal(size=(4, 1000, 7)) / 2},
2128+
posterior={"y_model": rng.normal(size=(4, 1000, 7))},
2129+
dims={"y": ["dim1"]},
2130+
coords={"dim1": range(7)},
2131+
)
2132+
2133+
2134+
def generate_lm_2d_data():
2135+
rng = np.random.default_rng()
2136+
return from_dict(
2137+
observed_data={"y": rng.normal(size=(5, 7))},
2138+
posterior_predictive={"y": rng.normal(size=(4, 1000, 5, 7)) / 2},
2139+
posterior={"y_model": rng.normal(size=(4, 1000, 5, 7))},
2140+
dims={"y": ["dim1", "dim2"]},
2141+
coords={"dim1": range(5), "dim2": range(7)},
2142+
)
2143+
2144+
2145+
@pytest.mark.parametrize("data", ("1d", "2d"))
2146+
@pytest.mark.parametrize("kind", ("lines", "hdi"))
2147+
@pytest.mark.parametrize("use_y_model", (True, False))
2148+
def test_plot_lm(data, kind, use_y_model):
2149+
if data == "1d":
2150+
idata = generate_lm_1d_data()
2151+
else:
2152+
idata = generate_lm_2d_data()
2153+
2154+
# test_cases = [
2155+
# # Single dimensional cases
2156+
# (data_1d, None, "lines", True, 50), # y_model with lines, default samples
2157+
# (data_1d, None, "hdi", True, None), # y_model with hdi, no samples needed
2158+
# (data_1d, None, "lines", False, 50), # without y_model, lines
2159+
# (data_1d, None, "hdi", False, None), # without y_model, hdi
2160+
# # Multi-dimensional cases with plot_dim
2161+
# (data_2d, "dim1", "lines", True, 20), # y_model with lines, fewer samples
2162+
# (data_2d, "dim1", "hdi", True, None), # y_model with hdi
2163+
# (data_2d, "dim1", "lines", False, 50), # without y_model, lines
2164+
# (data_2d, "dim1", "hdi", False, None), # without y_model, hdi
2165+
# ]
2166+
2167+
kwargs = {"idata": idata, "y": "y", "kind_model": kind}
2168+
if data == "2d":
2169+
kwargs["plot_dim"] = "dim1"
2170+
if use_y_model:
2171+
kwargs["y_model"] = "y_model"
2172+
if kind == "lines":
2173+
kwargs["num_samples"] = 50
2174+
2175+
ax = plot_lm(**kwargs)
2176+
assert ax is not None

0 commit comments

Comments
 (0)