diff --git a/autoemulate/core/compare.py b/autoemulate/core/compare.py
index 9eb3c2453..a40297a3f 100644
--- a/autoemulate/core/compare.py
+++ b/autoemulate/core/compare.py
@@ -2,6 +2,7 @@
 import warnings
 from datetime import datetime
 from pathlib import Path
+from typing import Literal
 
 import joblib
 import matplotlib.pyplot as plt
@@ -591,10 +592,13 @@ def fit_from_reinitialized(
     def plot(  # noqa: PLR0912, PLR0915
         self,
         model_obj: int | Emulator | Result,
+        input_names: list[str] | None = None,
+        output_names: list[str] | None = None,
         input_index: list[int] | int | None = None,
         output_index: list[int] | int | None = None,
         input_ranges: dict | None = None,
         output_ranges: dict | None = None,
+        error_style: Literal["bars", "fill"] = "bars",
         figsize=None,
         ncols: int = 3,
         fname: str | None = None,
@@ -607,6 +611,10 @@ def plot(  # noqa: PLR0912, PLR0915
         model_obj: int | Emulator | Result
             The model to plot. Can be an integer ID of a Result, an Emulator
             instance, or a Result instance.
+        input_names: list[str] | None
+            The names of the input features. If None, generic names are used.
+        output_names: list[str] | None
+            The names of the output features. If None, generic names are used.
         input_index: int
             The index of the input feature to plot against the output.
         output_index: int
@@ -619,6 +627,9 @@ def plot(  # noqa: PLR0912, PLR0915
             The ranges of the output features to consider for the plot. Ranges are
             combined such that the final subset is the intersection data within the
             specified ranges. Defaults to None.
+        error_style: Literal["bars", "fill"]
+            The style of error representation in the plots. Can be "bars" for error
+            bars or "fill" for shaded error regions. Defaults to "bars".
         figsize: tuple[int, int] | None
             The size of the figure to create. If None, it is set based on the number
             of input and output features.
@@ -698,6 +709,26 @@ def plot(  # noqa: PLR0912, PLR0915
         fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
         axs = axs.flatten()
 
+        if input_names is not None:
+            if len(input_names) != n_features:
+                msg = (
+                    "Length of input_names does not match number of input features. "
+                    f"Expected {n_features}, got {len(input_names)}."
+                )
+                raise ValueError(msg)
+        else:
+            input_names = [f"$x_{i}$" for i in range(n_features)]
+
+        if output_names is not None:
+            if len(output_names) != n_outputs:
+                msg = (
+                    "Length of output_names does not match number of outputs. "
+                    f"Expected {n_outputs}, got {len(output_names)}."
+                )
+                raise ValueError(msg)
+        else:
+            output_names = [f"$y_{i}$" for i in range(n_outputs)]
+
         plot_index = 0
         for out_idx in output_index:
             for in_idx in input_index:
@@ -748,10 +779,11 @@ def subset_outputs(x, y, y_p):
                     y_pred_subset[:, out_idx],
                     y_variance[:, out_idx] if y_variance is not None else None,
                     ax=axs[plot_index],
-                    title=f"$x_{in_idx}$ vs. $y_{out_idx}$",
-                    input_index=in_idx,
-                    output_index=out_idx,
+                    title=f"{input_names[in_idx]} vs. {output_names[out_idx]}",
+                    input_label=input_names[in_idx],
+                    output_label=output_names[out_idx],
                     r2_score=r2_score,
+                    error_style=error_style,
                 )
                 plot_index += 1
 
@@ -765,6 +797,116 @@ def subset_outputs(x, y, y_p):
         fig.savefig(fname, bbox_inches="tight")
         return None
 
+    def plot_preds(  # noqa: PLR0912
+        self,
+        model_obj: int | Emulator | Result,
+        output_names: list[str] | None = None,
+        figsize=None,
+        ncols: int = 3,
+        fname: str | None = None,
+    ):
+        """
+        Plot predicted means (and variances) against observations for all outputs.
+
+        Parameters
+        ----------
+        model_obj: int | Emulator | Result
+            The model to plot. Can be an integer ID of a Result, an Emulator instance,
+            or a Result instance.
+        output_names: list[str] | None
+            The names of the outputs to use in the plot titles. If None, generic names
+            like "y_0", "y_1", etc. are used.
+        figsize: tuple[int, int] | None
+            The size of the figure to create. If None, it is set based on the number
+            of outputs.
+        ncols: int
+            The maximum number of columns in the subplot grid. Defaults to 3.
+        fname: str | None
+            If provided, the figure will be saved to this file path.
+        """
+        result = None
+        if isinstance(model_obj, int):
+            if model_obj not in self._id_to_result:
+                raise ValueError(f"No result found with ID: {model_obj}")
+            result = self.get_result(model_obj)
+            model = result.model
+        elif isinstance(model_obj, Emulator):
+            model = model_obj
+        elif isinstance(model_obj, Result):
+            model = model_obj.model
+
+        test_x, test_y = self._convert_to_tensors(self.test)
+
+        # Re-run prediction with just this model to get the predictions
+        y_pred, y_variance = model.predict_mean_and_variance(test_x)
+        y_std = None
+        if y_variance is not None:
+            y_variance, _ = self._convert_to_numpy(y_variance, None)
+            y_variance = self._ensure_numpy_2d(y_variance)
+            y_std = np.sqrt(y_variance)
+
+        # Convert to numpy for plotting
+        test_x, test_y = self._convert_to_numpy(test_x, test_y)
+        assert test_x is not None
+        assert test_y is not None
+        assert y_pred is not None
+        y_pred, _ = self._convert_to_numpy(y_pred, None)
+        test_x = self._ensure_numpy_2d(test_x)
+        test_y = self._ensure_numpy_2d(test_y)
+        y_pred = self._ensure_numpy_2d(y_pred)
+
+        # Figure out layout
+        n_outputs = test_y.shape[1] if test_y.ndim > 1 else 1
+        nrows, ncols = calculate_subplot_layout(n_outputs, ncols)
+        if figsize is None:
+            figsize = (5 * ncols, 4 * nrows)
+        fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
+        axs = axs.flatten()
+
+        if output_names is not None:
+            if len(output_names) != n_outputs:
+                msg = (
+                    "Length of output_names does not match number of outputs. "
+                    f"Expected {n_outputs}, got {len(output_names)}."
+                )
+                raise ValueError(msg)
+        else:
+            output_names = [f"$y_{i}$" for i in range(n_outputs)]
+
+        for i in range(n_outputs):
+            if y_std is not None:
+                axs[i].errorbar(
+                    test_y[:, i],
+                    y_pred[:, i],
+                    yerr=2 * y_std[:, i],
+                    fmt="none",
+                    alpha=0.4,
+                    capsize=3,
+                )
+            axs[i].scatter(
+                test_y[:, i],
+                y_pred[:, i],
+                alpha=0.6,
+                linewidth=0.5,
+            )
+            axs[i].plot(
+                [test_y[:, i].min(), test_y[:, i].max()],
+                [test_y[:, i].min(), test_y[:, i].max()],
+                linestyle="--",
+                color="gray",
+            )
+            axs[i].set_title(output_names[i])
+            axs[i].set_xlabel("True values")
+            axs[i].set_ylabel("Predicted values ±2\u03c3")
+        plt.tight_layout()
+
+        if figsize is not None:
+            fig.set_size_inches(figsize)
+        if fname is None:
+            return display_figure(fig)
+        fig.savefig(fname, bbox_inches="tight")
+        return None
+
     def plot_surface(
         self,
         model: Emulator,
diff --git a/autoemulate/core/plotting.py b/autoemulate/core/plotting.py
index c206edb27..1d6082345 100644
--- a/autoemulate/core/plotting.py
+++ b/autoemulate/core/plotting.py
@@ -1,3 +1,5 @@
+from typing import Literal
+
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
@@ -49,9 +51,10 @@ def plot_xy(
     y_variance: NumpyLike | None = None,
     ax: Axes | None = None,
     title: str = "xy",
-    input_index: int | None = None,
-    output_index: int | None = None,
+    input_label: str | None = None,
+    output_label: str | None = None,
     r2_score: float | None = None,
+    error_style: Literal["bars", "fill"] = "bars",
 ):
     """
     Plot observed and predicted values vs. features.
@@ -70,12 +73,15 @@ def plot_xy(
         An optional matplotlib Axes object to plot on.
     title: str
         An optional title for the plot.
-    input_index: int | None
-        An optional index of the input dimension to plot.
-    output_index: int | None
-        An optional index of the output dimension to plot.
+    input_label: str | None
+        An optional input label to plot.
+    output_label: str | None
+        An optional output label to plot.
     r2_score: float | None
         An option r2 score to include in the plot legend.
+    error_style: Literal["bars", "fill"]
+        The style of error representation in the plots. Can be "bars" for error
+        bars or "fill" for shaded error regions. Defaults to "bars".
     """
     # Sort the data
     sort_idx = np.argsort(x).flatten()
@@ -93,18 +99,46 @@ def plot_xy(
     assert ax is not None, "ax must be provided"
     # Scatter plot with error bars for predictions
     if y_std is not None:
-        ax.errorbar(
-            x_sorted,
-            y_pred_sorted,
-            yerr=2 * y_std,
-            fmt="o",
-            color=pred_points_color,
-            elinewidth=2,
-            capsize=3,
-            alpha=0.5,
-            # use unicode for sigma
-            label="pred. (±2\u03c3)",
-        )
+        if error_style.lower() not in ["bars", "fill"]:
+            msg = "error_style must be one of ['bars', 'fill']"
+            raise ValueError(msg)
+        if error_style.lower() == "bars":
+            ax.errorbar(
+                x_sorted,
+                y_pred_sorted,
+                yerr=2 * y_std,
+                fmt="o",
+                color=pred_points_color,
+                elinewidth=2,
+                capsize=3,
+                alpha=0.5,
+                # use unicode for sigma
+                label="pred. (±2\u03c3)",
+            )
+            ax.scatter(
+                x_sorted,
+                y_pred_sorted,
+                color=pred_points_color,
+                edgecolor="black",
+                linewidth=0.5,
+                alpha=0.5,
+            )
+        else:
+            ax.fill_between(
+                x_sorted,
+                y_pred_sorted - 2 * y_std,
+                y_pred_sorted + 2 * y_std,
+                color=pred_points_color,
+                alpha=0.2,
+                label="±2\u03c3",
+            )
+            ax.plot(
+                x_sorted,
+                y_pred_sorted,
+                color=pred_points_color,
+                alpha=0.75,
+                label="pred.",
+            )
     else:
         ax.scatter(
             x_sorted,
@@ -126,8 +160,10 @@ def plot_xy(
         label="data",
     )
 
-    ax.set_xlabel(f"$x_{input_index}$", fontsize=13)
-    ax.set_ylabel(f"$y_{output_index}$", fontsize=13)
+    x_label = input_label if input_label is not None else "x"
+    y_label = output_label if output_label is not None else "y"
+    ax.set_xlabel(x_label, fontsize=13)
+    ax.set_ylabel(y_label, fontsize=13)
     ax.set_title(title, fontsize=13)
     ax.grid(True, alpha=0.3)
 
@@ -135,10 +171,7 @@ def plot_xy(
     handles, _ = ax.get_legend_handles_labels()
 
    # Add legend and get its bounding box
-    lbl = "pred." if y_variance is None else "pred. (±2\u03c3)"
     legend = ax.legend(
-        handles[-2:],
-        ["data", lbl],
         loc="best",
         handletextpad=0,
         columnspacing=0,
diff --git a/docs/tutorials/emulation/01_quickstart.ipynb b/docs/tutorials/emulation/01_quickstart.ipynb
index 8af50bf2e..c9612bace 100644
--- a/docs/tutorials/emulation/01_quickstart.ipynb
+++ b/docs/tutorials/emulation/01_quickstart.ipynb
@@ -235,14 +235,30 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "ae.plot(best, fname=\"best_model_plot.png\")"
+    "ae.plot_preds(best, output_names=projectile.output_names)"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "We can also subset the data included in the plots by providing input and output ranges."
+    "We can also visualise the predictions against each input feature."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ae.plot(best, output_names=projectile.output_names, input_names=projectile.param_names)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can subset the data included in the feature plots by providing input and output ranges."
    ]
   },
   {
@@ -269,7 +285,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "ae.plot_surface(best.model, projectile.parameters_range, quantile=0.5)\n"
+    "ae.plot_surface(best.model, projectile.parameters_range, quantile=0.5)"
    ]
   },
   {
diff --git a/tests/core/test_plotting.py b/tests/core/test_plotting.py
index 7878e7086..69b604b5d 100644
--- a/tests/core/test_plotting.py
+++ b/tests/core/test_plotting.py
@@ -35,7 +35,7 @@ def test_plot_xy():
     # plot without error bars
     fig, ax = plt.subplots()
     plotting.plot_xy(
-        X, y, y_pred, None, ax=ax, input_index=1, output_index=2, r2_score=0.5
+        X, y, y_pred, None, ax=ax, input_label="1", output_label="2", r2_score=0.5
     )
     # test for error bars
     assert len(ax.containers) == 0
@@ -45,7 +45,7 @@ def test_plot_xy():
     # plot with error bars
     fig, ax = plt.subplots()
     plotting.plot_xy(
-        X, y, y_pred, y_variance, ax=ax, input_index=1, output_index=2, r2_score=0.5
+        X, y, y_pred, y_variance, ax=ax, input_label="1", output_label="2", r2_score=0.5
     )
     assert len(ax.containers) > 0
     assert len(ax.collections) > 0
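Usage note (illustrative, not part of the patch): the snippet below sketches how the new plotting options are exercised, following the calls added to the quickstart notebook above. It assumes the `ae` (AutoEmulate), `best` (Result) and `projectile` (simulator) objects set up earlier in that notebook; the `error_style="fill"` call simply demonstrates the new keyword added to `plot`.

    # Predicted vs. observed values for every output, titled with the simulator's output names
    ae.plot_preds(best, output_names=projectile.output_names)

    # Per-feature plots with named axes and shaded ±2σ regions instead of error bars
    ae.plot(
        best,
        input_names=projectile.param_names,
        output_names=projectile.output_names,
        error_style="fill",
    )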