Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 145 additions & 3 deletions autoemulate/core/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from datetime import datetime
from pathlib import Path
from typing import Literal

import joblib
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -591,10 +592,13 @@ def fit_from_reinitialized(
def plot( # noqa: PLR0912, PLR0915
self,
model_obj: int | Emulator | Result,
input_names: list[str] | None = None,
output_names: list[str] | None = None,
input_index: list[int] | int | None = None,
output_index: list[int] | int | None = None,
input_ranges: dict | None = None,
output_ranges: dict | None = None,
error_style: Literal["bars", "fill"] = "bars",
figsize=None,
ncols: int = 3,
fname: str | None = None,
Expand All @@ -607,6 +611,10 @@ def plot( # noqa: PLR0912, PLR0915
model_obj: int | Emulator | Result
The model to plot. Can be an integer ID of a Result, an Emulator instance,
or a Result instance.
input_names: list[str] | None
The names of the input features. If None, generic names are used.
output_names: list[str] | None
The names of the output features. If None, generic names are used.
input_index: int
The index of the input feature to plot against the output.
output_index: int
Expand All @@ -619,6 +627,9 @@ def plot( # noqa: PLR0912, PLR0915
The ranges of the output features to consider for the plot. Ranges are
combined such that the final subset is the intersection data within
the specified ranges. Defaults to None.
error_style: Literal["bars", "fill"]
The style of error representation in the plots. Can be "bars" for error
bars or "fill" for shaded error regions. Defaults to "bars".
figsize: tuple[int, int] | None
The size of the figure to create. If None, it is set based on the number
of input and output features.
Expand Down Expand Up @@ -698,6 +709,26 @@ def plot( # noqa: PLR0912, PLR0915
fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
axs = axs.flatten()

if input_names is not None:
if len(input_names) != n_features:
msg = (
"Length of input_names does not match number of input features. "
f"Expected {n_features}, got {len(input_names)}."
)
raise ValueError(msg)
else:
input_names = [f"$x_{i}$" for i in range(n_features)]

if output_names is not None:
if len(output_names) != n_outputs:
msg = (
"Length of output_names does not match number of outputs. "
f"Expected {n_outputs}, got {len(output_names)}."
)
raise ValueError(msg)
else:
output_names = [f"$y_{i}$" for i in range(n_outputs)]

plot_index = 0
for out_idx in output_index:
for in_idx in input_index:
Expand Down Expand Up @@ -748,10 +779,11 @@ def subset_outputs(x, y, y_p):
y_pred_subset[:, out_idx],
y_variance[:, out_idx] if y_variance is not None else None,
ax=axs[plot_index],
title=f"$x_{in_idx}$ vs. $y_{out_idx}$",
input_index=in_idx,
output_index=out_idx,
title=f"{input_names[in_idx]} vs. {output_names[out_idx]}",
input_label=input_names[in_idx],
output_label=output_names[out_idx],
r2_score=r2_score,
error_style=error_style,
)
plot_index += 1

Expand All @@ -765,6 +797,116 @@ def subset_outputs(x, y, y_p):
fig.savefig(fname, bbox_inches="tight")
return None

def plot_preds(  # noqa: PLR0912
    self,
    model_obj: int | Emulator | Result,
    output_names: list[str] | None = None,
    figsize=None,
    ncols: int = 3,
    fname: str | None = None,
):
    """
    Plot predicted means (and variances) against observations for all outputs.

    One subplot is drawn per output column of the test set: a scatter of
    predicted vs. true values, a dashed gray identity line spanning the
    observed range, and, when the model returns a variance, error bars of
    two standard deviations around each prediction.

    Parameters
    ----------
    model_obj: int | Emulator | Result
        The model to plot. Can be an integer ID of a Result, an Emulator instance,
        or a Result instance.
    output_names: list[str] | None
        The names of the outputs to use in the plot titles. If None, generic names
        like "y_0", "y_1", etc. are used.
    figsize: tuple[int, int] | None
        The size of the figure to create. If None, it is set based on the number
        of outputs.
    ncols: int
        The maximum number of columns in the subplot grid. Defaults to 3.
    fname: str | None
        If provided, the figure will be saved to this file path.

    Returns
    -------
    The result of ``display_figure(fig)`` when ``fname`` is None, otherwise
    None after the figure has been written to ``fname``.

    Raises
    ------
    ValueError
        If ``model_obj`` is an integer ID with no matching result, or if the
        length of ``output_names`` does not match the number of outputs.
    TypeError
        If ``model_obj`` is not an int, an Emulator or a Result.
    """
    # Resolve the emulator to evaluate from the accepted argument forms.
    if isinstance(model_obj, int):
        if model_obj not in self._id_to_result:
            raise ValueError(f"No result found with ID: {model_obj}")
        model = self.get_result(model_obj).model
    elif isinstance(model_obj, Emulator):
        model = model_obj
    elif isinstance(model_obj, Result):
        model = model_obj.model
    else:
        # Fail with a clear message instead of hitting an unbound `model` below.
        msg = (
            "model_obj must be an int ID, an Emulator or a Result, "
            f"got {type(model_obj).__name__}."
        )
        raise TypeError(msg)

    test_x, test_y = self._convert_to_tensors(self.test)

    # Re-run prediction with just this model to get the predictions
    y_pred, y_variance = model.predict_mean_and_variance(test_x)
    y_std = None
    if y_variance is not None:
        y_variance, _ = self._convert_to_numpy(y_variance, None)
        y_variance = self._ensure_numpy_2d(y_variance)
        y_std = np.sqrt(y_variance)

    # Convert to numpy for plotting
    test_x, test_y = self._convert_to_numpy(test_x, test_y)
    assert test_x is not None
    assert test_y is not None
    assert y_pred is not None
    y_pred, _ = self._convert_to_numpy(y_pred, None)
    test_x = self._ensure_numpy_2d(test_x)
    test_y = self._ensure_numpy_2d(test_y)
    y_pred = self._ensure_numpy_2d(y_pred)

    n_outputs = test_y.shape[1] if test_y.ndim > 1 else 1

    # Validate the titles before any figure exists so that a bad argument
    # does not leave an orphaned open matplotlib figure behind.
    if output_names is not None:
        if len(output_names) != n_outputs:
            msg = (
                "Length of output_names does not match number of outputs. "
                f"Expected {n_outputs}, got {len(output_names)}."
            )
            raise ValueError(msg)
    else:
        output_names = [f"$y_{i}$" for i in range(n_outputs)]

    # Figure out layout
    nrows, ncols = calculate_subplot_layout(n_outputs, ncols)
    if figsize is None:
        figsize = (5 * ncols, 4 * nrows)
    fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axs = axs.flatten()

    # Only advertise the +/- 2 sigma band when error bars are actually drawn.
    y_label = (
        "Predicted values ±2\u03c3" if y_std is not None else "Predicted values"
    )

    for i in range(n_outputs):
        ax = axs[i]
        observed = test_y[:, i]
        predicted = y_pred[:, i]
        if y_std is not None:
            ax.errorbar(
                observed,
                predicted,
                yerr=2 * y_std[:, i],
                fmt="none",
                alpha=0.4,
                capsize=3,
            )
        ax.scatter(
            observed,
            predicted,
            alpha=0.6,
            linewidth=0.5,
        )
        # Identity line (perfect prediction) across the observed range.
        lo, hi = observed.min(), observed.max()
        ax.plot(
            [lo, hi],
            [lo, hi],
            linestyle="--",
            color="gray",
        )
        ax.set_title(output_names[i])
        ax.set_xlabel("True values")
        ax.set_ylabel(y_label)

    # Hide grid cells that have no output assigned to them.
    for unused_ax in axs[n_outputs:]:
        unused_ax.set_visible(False)

    # Lay out this specific figure rather than pyplot's implicit current one.
    fig.tight_layout()

    if fname is None:
        return display_figure(fig)
    fig.savefig(fname, bbox_inches="tight")
    return None

def plot_surface(
self,
model: Emulator,
Expand Down
79 changes: 56 additions & 23 deletions autoemulate/core/plotting.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Literal

import matplotlib.pyplot as plt
import numpy as np
import torch
Expand Down Expand Up @@ -49,9 +51,10 @@ def plot_xy(
y_variance: NumpyLike | None = None,
ax: Axes | None = None,
title: str = "xy",
input_index: int | None = None,
output_index: int | None = None,
input_label: str | None = None,
output_label: str | None = None,
r2_score: float | None = None,
error_style: Literal["bars", "fill"] = "bars",
):
"""
Plot observed and predicted values vs. features.
Expand All @@ -70,12 +73,15 @@ def plot_xy(
An optional matplotlib Axes object to plot on.
title: str
An optional title for the plot.
input_index: int | None
An optional index of the input dimension to plot.
output_index: int | None
An optional index of the output dimension to plot.
input_label: str | None
An optional input label to plot.
output_label: str | None
An optional output label to plot.
r2_score: float | None
An optional r2 score to include in the plot legend.
error_style: Literal["bars", "fill"]
The style of error representation in the plots. Can be "bars" for error
bars or "fill" for shaded error regions. Defaults to "bars".
"""
# Sort the data
sort_idx = np.argsort(x).flatten()
Expand All @@ -93,18 +99,46 @@ def plot_xy(
assert ax is not None, "ax must be provided"
# Scatter plot with error bars for predictions
if y_std is not None:
ax.errorbar(
x_sorted,
y_pred_sorted,
yerr=2 * y_std,
fmt="o",
color=pred_points_color,
elinewidth=2,
capsize=3,
alpha=0.5,
# use unicode for sigma
label="pred. (±2\u03c3)",
)
if error_style.lower() not in ["bars", "fill"]:
msg = "error_style must be one of ['bars', 'fill']"
raise ValueError(msg)
if error_style.lower() == "bars":
ax.errorbar(
x_sorted,
y_pred_sorted,
yerr=2 * y_std,
fmt="o",
color=pred_points_color,
elinewidth=2,
capsize=3,
alpha=0.5,
# use unicode for sigma
label="pred. (±2\u03c3)",
)
ax.scatter(
x_sorted,
y_pred_sorted,
color=pred_points_color,
edgecolor="black",
linewidth=0.5,
alpha=0.5,
)
else:
ax.fill_between(
x_sorted,
y_pred_sorted - 2 * y_std,
y_pred_sorted + 2 * y_std,
color=pred_points_color,
alpha=0.2,
label="±2\u03c3",
)
ax.plot(
x_sorted,
y_pred_sorted,
color=pred_points_color,
alpha=0.75,
label="pred.",
)
else:
ax.scatter(
x_sorted,
Expand All @@ -126,19 +160,18 @@ def plot_xy(
label="data",
)

ax.set_xlabel(f"$x_{input_index}$", fontsize=13)
ax.set_ylabel(f"$y_{output_index}$", fontsize=13)
x_label = input_label if input_label is not None else "x"
y_label = output_label if output_label is not None else "y"
ax.set_xlabel(x_label, fontsize=13)
ax.set_ylabel(y_label, fontsize=13)
ax.set_title(title, fontsize=13)
ax.grid(True, alpha=0.3)

# Get the handles and labels for the scatter plots
handles, _ = ax.get_legend_handles_labels()

# Add legend and get its bounding box
lbl = "pred." if y_variance is None else "pred. (±2\u03c3)"
legend = ax.legend(
handles[-2:],
["data", lbl],
loc="best",
handletextpad=0,
columnspacing=0,
Expand Down
22 changes: 19 additions & 3 deletions docs/tutorials/emulation/01_quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -235,14 +235,30 @@
"metadata": {},
"outputs": [],
"source": [
"ae.plot(best, fname=\"best_model_plot.png\")"
"ae.plot_preds(best, output_names=projectile.output_names)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also subset the data included in the plots by providing input and output ranges."
"We can also visualise the predictions against each input feature."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ae.plot(best, output_names=projectile.output_names, input_names=projectile.param_names)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can subset the data included in the feature plots by providing input and output ranges."
]
},
{
Expand All @@ -269,7 +285,7 @@
"metadata": {},
"outputs": [],
"source": [
"ae.plot_surface(best.model, projectile.parameters_range, quantile=0.5)\n"
"ae.plot_surface(best.model, projectile.parameters_range, quantile=0.5)"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions tests/core/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_plot_xy():
# plot without error bars
fig, ax = plt.subplots()
plotting.plot_xy(
X, y, y_pred, None, ax=ax, input_index=1, output_index=2, r2_score=0.5
X, y, y_pred, None, ax=ax, input_label="1", output_label="2", r2_score=0.5
)
# test for error bars
assert len(ax.containers) == 0
Expand All @@ -45,7 +45,7 @@ def test_plot_xy():
# plot with error bars
fig, ax = plt.subplots()
plotting.plot_xy(
X, y, y_pred, y_variance, ax=ax, input_index=1, output_index=2, r2_score=0.5
X, y, y_pred, y_variance, ax=ax, input_label="1", output_label="2", r2_score=0.5
)
assert len(ax.containers) > 0
assert len(ax.collections) > 0
Expand Down