Skip to content

Commit 3bc0d86

Browse files
authored
Merge pull request #907 from alan-turing-institute/add_plot
Update plots
2 parents 0312414 + 1f31713 commit 3bc0d86

File tree

4 files changed

+222
-31
lines changed

4 files changed

+222
-31
lines changed

autoemulate/core/compare.py

Lines changed: 145 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import warnings
33
from datetime import datetime
44
from pathlib import Path
5+
from typing import Literal
56

67
import joblib
78
import matplotlib.pyplot as plt
@@ -621,10 +622,13 @@ def fit_from_reinitialized(
621622
def plot( # noqa: PLR0912, PLR0915
622623
self,
623624
model_obj: int | Emulator | Result,
625+
input_names: list[str] | None = None,
626+
output_names: list[str] | None = None,
624627
input_index: list[int] | int | None = None,
625628
output_index: list[int] | int | None = None,
626629
input_ranges: dict | None = None,
627630
output_ranges: dict | None = None,
631+
error_style: Literal["bars", "fill"] = "bars",
628632
figsize=None,
629633
ncols: int = 3,
630634
fname: str | None = None,
@@ -637,6 +641,10 @@ def plot( # noqa: PLR0912, PLR0915
637641
model_obj: int | Emulator | Result
638642
The model to plot. Can be an integer ID of a Result, an Emulator instance,
639643
or a Result instance.
644+
input_names: list[str] | None
645+
The names of the input features. If None, generic names are used.
646+
output_names: list[str] | None
647+
The names of the output features. If None, generic names are used.
640648
input_index: int
641649
The index of the input feature to plot against the output.
642650
output_index: int
@@ -649,6 +657,9 @@ def plot( # noqa: PLR0912, PLR0915
649657
The ranges of the output features to consider for the plot. Ranges are
650658
combined such that the final subset is the intersection data within
651659
the specified ranges. Defaults to None.
660+
error_style: Literal["bars", "fill"]
661+
The style of error representation in the plots. Can be "bars" for error
662+
bars or "fill" for shaded error regions. Defaults to "bars".
652663
figsize: tuple[int, int] | None
653664
The size of the figure to create. If None, it is set based on the number
654665
of input and output features.
@@ -728,6 +739,26 @@ def plot( # noqa: PLR0912, PLR0915
728739
fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
729740
axs = axs.flatten()
730741

742+
if input_names is not None:
743+
if len(input_names) != n_features:
744+
msg = (
745+
"Length of input_names does not match number of input features. "
746+
f"Expected {n_features}, got {len(input_names)}."
747+
)
748+
raise ValueError(msg)
749+
else:
750+
input_names = [f"$x_{i}$" for i in range(n_features)]
751+
752+
if output_names is not None:
753+
if len(output_names) != n_outputs:
754+
msg = (
755+
"Length of output_names does not match number of outputs. "
756+
f"Expected {n_outputs}, got {len(output_names)}."
757+
)
758+
raise ValueError(msg)
759+
else:
760+
output_names = [f"$y_{i}$" for i in range(n_outputs)]
761+
731762
plot_index = 0
732763
for out_idx in output_index:
733764
for in_idx in input_index:
@@ -778,10 +809,11 @@ def subset_outputs(x, y, y_p):
778809
y_pred_subset[:, out_idx],
779810
y_variance[:, out_idx] if y_variance is not None else None,
780811
ax=axs[plot_index],
781-
title=f"$x_{in_idx}$ vs. $y_{out_idx}$",
782-
input_index=in_idx,
783-
output_index=out_idx,
812+
title=f"{input_names[in_idx]} vs. {output_names[out_idx]}",
813+
input_label=input_names[in_idx],
814+
output_label=output_names[out_idx],
784815
r2_score=r2_score,
816+
error_style=error_style,
785817
)
786818
plot_index += 1
787819

@@ -795,6 +827,116 @@ def subset_outputs(x, y, y_p):
795827
fig.savefig(fname, bbox_inches="tight")
796828
return None
797829

830+
def plot_preds( # noqa: PLR0912
831+
self,
832+
model_obj: int | Emulator | Result,
833+
output_names: list[str] | None = None,
834+
figsize=None,
835+
ncols: int = 3,
836+
fname: str | None = None,
837+
):
838+
"""
839+
Plot predicted means (and variances) against observations for all outputs.
840+
841+
Parameters
842+
----------
843+
model_obj: int | Emulator | Result
844+
The model to plot. Can be an integer ID of a Result, an Emulator instance,
845+
or a Result instance.
846+
output_names: list[str] | None
847+
The names of the outputs to use in the plot titles. If None, generic names
848+
like "y_0", "y_1", etc. are used.
849+
figsize: tuple[int, int] | None
850+
The size of the figure to create. If None, it is set based on the number
851+
of outputs.
852+
ncols: int
853+
The maximum number of columns in the subplot grid. Defaults to 3.
854+
fname: str | None
855+
If provided, the figure will be saved to this file path.
856+
"""
857+
result = None
858+
if isinstance(model_obj, int):
859+
if model_obj not in self._id_to_result:
860+
raise ValueError(f"No result found with ID: {model_obj}")
861+
result = self.get_result(model_obj)
862+
model = result.model
863+
elif isinstance(model_obj, Emulator):
864+
model = model_obj
865+
elif isinstance(model_obj, Result):
866+
model = model_obj.model
867+
868+
test_x, test_y = self._convert_to_tensors(self.test)
869+
870+
# Re-run prediction with just this model to get the predictions
871+
y_pred, y_variance = model.predict_mean_and_variance(test_x)
872+
y_std = None
873+
if y_variance is not None:
874+
y_variance, _ = self._convert_to_numpy(y_variance, None)
875+
y_variance = self._ensure_numpy_2d(y_variance)
876+
y_std = np.sqrt(y_variance)
877+
878+
# Convert to numpy for plotting
879+
test_x, test_y = self._convert_to_numpy(test_x, test_y)
880+
assert test_x is not None
881+
assert test_y is not None
882+
assert y_pred is not None
883+
y_pred, _ = self._convert_to_numpy(y_pred, None)
884+
test_x = self._ensure_numpy_2d(test_x)
885+
test_y = self._ensure_numpy_2d(test_y)
886+
y_pred = self._ensure_numpy_2d(y_pred)
887+
888+
# Figure out layout
889+
n_outputs = test_y.shape[1] if test_y.ndim > 1 else 1
890+
nrows, ncols = calculate_subplot_layout(n_outputs, ncols)
891+
if figsize is None:
892+
figsize = (5 * ncols, 4 * nrows)
893+
fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
894+
axs = axs.flatten()
895+
896+
if output_names is not None:
897+
if len(output_names) != n_outputs:
898+
msg = (
899+
"Length of output_names does not match number of outputs. "
900+
f"Expected {n_outputs}, got {len(output_names)}."
901+
)
902+
raise ValueError(msg)
903+
else:
904+
output_names = [f"$y_{i}$" for i in range(n_outputs)]
905+
906+
for i in range(n_outputs):
907+
if y_std is not None:
908+
axs[i].errorbar(
909+
test_y[:, i],
910+
y_pred[:, i],
911+
yerr=2 * y_std[:, i],
912+
fmt="none",
913+
alpha=0.4,
914+
capsize=3,
915+
)
916+
axs[i].scatter(
917+
test_y[:, i],
918+
y_pred[:, i],
919+
alpha=0.6,
920+
linewidth=0.5,
921+
)
922+
axs[i].plot(
923+
[test_y[:, i].min(), test_y[:, i].max()],
924+
[test_y[:, i].min(), test_y[:, i].max()],
925+
linestyle="--",
926+
color="gray",
927+
)
928+
axs[i].set_title(output_names[i])
929+
axs[i].set_xlabel("True values")
930+
axs[i].set_ylabel("Predicted values ±2\u03c3")
931+
plt.tight_layout()
932+
933+
if figsize is not None:
934+
fig.set_size_inches(figsize)
935+
if fname is None:
936+
return display_figure(fig)
937+
fig.savefig(fname, bbox_inches="tight")
938+
return None
939+
798940
def plot_surface(
799941
self,
800942
model: Emulator,

autoemulate/core/plotting.py

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Literal
2+
13
import matplotlib.pyplot as plt
24
import numpy as np
35
import torch
@@ -49,9 +51,10 @@ def plot_xy(
4951
y_variance: NumpyLike | None = None,
5052
ax: Axes | None = None,
5153
title: str = "xy",
52-
input_index: int | None = None,
53-
output_index: int | None = None,
54+
input_label: str | None = None,
55+
output_label: str | None = None,
5456
r2_score: float | None = None,
57+
error_style: Literal["bars", "fill"] = "bars",
5558
):
5659
"""
5760
Plot observed and predicted values vs. features.
@@ -70,12 +73,15 @@ def plot_xy(
7073
An optional matplotlib Axes object to plot on.
7174
title: str
7275
An optional title for the plot.
73-
input_index: int | None
74-
An optional index of the input dimension to plot.
75-
output_index: int | None
76-
An optional index of the output dimension to plot.
76+
input_label: str | None
77+
An optional input label to plot.
78+
output_label: str | None
79+
An optional output label to plot.
7780
r2_score: float | None
7881
An option r2 score to include in the plot legend.
82+
error_style: Literal["bars", "fill"]
83+
The style of error representation in the plots. Can be "bars" for error
84+
bars or "fill" for shaded error regions. Defaults to "bars".
7985
"""
8086
# Sort the data
8187
sort_idx = np.argsort(x).flatten()
@@ -93,18 +99,46 @@ def plot_xy(
9399
assert ax is not None, "ax must be provided"
94100
# Scatter plot with error bars for predictions
95101
if y_std is not None:
96-
ax.errorbar(
97-
x_sorted,
98-
y_pred_sorted,
99-
yerr=2 * y_std,
100-
fmt="o",
101-
color=pred_points_color,
102-
elinewidth=2,
103-
capsize=3,
104-
alpha=0.5,
105-
# use unicode for sigma
106-
label="pred. (±2\u03c3)",
107-
)
102+
if error_style.lower() not in ["bars", "fill"]:
103+
msg = "error_style must be one of ['bars', 'fill']"
104+
raise ValueError(msg)
105+
if error_style.lower() == "bars":
106+
ax.errorbar(
107+
x_sorted,
108+
y_pred_sorted,
109+
yerr=2 * y_std,
110+
fmt="o",
111+
color=pred_points_color,
112+
elinewidth=2,
113+
capsize=3,
114+
alpha=0.5,
115+
# use unicode for sigma
116+
label="pred. (±2\u03c3)",
117+
)
118+
ax.scatter(
119+
x_sorted,
120+
y_pred_sorted,
121+
color=pred_points_color,
122+
edgecolor="black",
123+
linewidth=0.5,
124+
alpha=0.5,
125+
)
126+
else:
127+
ax.fill_between(
128+
x_sorted,
129+
y_pred_sorted - 2 * y_std,
130+
y_pred_sorted + 2 * y_std,
131+
color=pred_points_color,
132+
alpha=0.2,
133+
label="±2\u03c3",
134+
)
135+
ax.plot(
136+
x_sorted,
137+
y_pred_sorted,
138+
color=pred_points_color,
139+
alpha=0.75,
140+
label="pred.",
141+
)
108142
else:
109143
ax.scatter(
110144
x_sorted,
@@ -126,19 +160,18 @@ def plot_xy(
126160
label="data",
127161
)
128162

129-
ax.set_xlabel(f"$x_{input_index}$", fontsize=13)
130-
ax.set_ylabel(f"$y_{output_index}$", fontsize=13)
163+
x_label = input_label if input_label is not None else "x"
164+
y_label = output_label if output_label is not None else "y"
165+
ax.set_xlabel(x_label, fontsize=13)
166+
ax.set_ylabel(y_label, fontsize=13)
131167
ax.set_title(title, fontsize=13)
132168
ax.grid(True, alpha=0.3)
133169

134170
# Get the handles and labels for the scatter plots
135171
handles, _ = ax.get_legend_handles_labels()
136172

137173
# Add legend and get its bounding box
138-
lbl = "pred." if y_variance is None else "pred. (±2\u03c3)"
139174
legend = ax.legend(
140-
handles[-2:],
141-
["data", lbl],
142175
loc="best",
143176
handletextpad=0,
144177
columnspacing=0,

docs/tutorials/emulation/01_quickstart.ipynb

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,14 +272,30 @@
272272
"metadata": {},
273273
"outputs": [],
274274
"source": [
275-
"ae.plot(best, fname=\"best_model_plot.png\")"
275+
"ae.plot_preds(best, output_names=projectile.output_names)"
276276
]
277277
},
278278
{
279279
"cell_type": "markdown",
280280
"metadata": {},
281281
"source": [
282-
"We can also subset the data included in the plots by providing input and output ranges."
282+
"We can also visualise the predictions against each input feature."
283+
]
284+
},
285+
{
286+
"cell_type": "code",
287+
"execution_count": null,
288+
"metadata": {},
289+
"outputs": [],
290+
"source": [
291+
"ae.plot(best, output_names=projectile.output_names, input_names=projectile.param_names)"
292+
]
293+
},
294+
{
295+
"cell_type": "markdown",
296+
"metadata": {},
297+
"source": [
298+
"We can subset the data included in the feature plots by providing input and output ranges."
283299
]
284300
},
285301
{
@@ -306,7 +322,7 @@
306322
"metadata": {},
307323
"outputs": [],
308324
"source": [
309-
"ae.plot_surface(best.model, projectile.parameters_range, quantile=0.5)\n"
325+
"ae.plot_surface(best.model, projectile.parameters_range, quantile=0.5)"
310326
]
311327
},
312328
{

tests/core/test_plotting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_plot_xy():
3535
# plot without error bars
3636
fig, ax = plt.subplots()
3737
plotting.plot_xy(
38-
X, y, y_pred, None, ax=ax, input_index=1, output_index=2, r2_score=0.5
38+
X, y, y_pred, None, ax=ax, input_label="1", output_label="2", r2_score=0.5
3939
)
4040
# test for error bars
4141
assert len(ax.containers) == 0
@@ -45,7 +45,7 @@ def test_plot_xy():
4545
# plot with error bars
4646
fig, ax = plt.subplots()
4747
plotting.plot_xy(
48-
X, y, y_pred, y_variance, ax=ax, input_index=1, output_index=2, r2_score=0.5
48+
X, y, y_pred, y_variance, ax=ax, input_label="1", output_label="2", r2_score=0.5
4949
)
5050
assert len(ax.containers) > 0
5151
assert len(ax.collections) > 0

0 commit comments

Comments
 (0)