Skip to content

Commit edeea16

Browse files
authored
Merge pull request #897 from alan-turing-institute/calibration_plot
Add calibration plot
2 parents 243a0f9 + 8070772 commit edeea16

File tree

3 files changed

+254
-9
lines changed

3 files changed

+254
-9
lines changed

autoemulate/core/compare.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
calculate_subplot_layout,
1818
create_and_plot_slice,
1919
display_figure,
20+
plot_calibration_from_distributions,
2021
plot_xy,
2122
)
2223
from autoemulate.core.reinitialize import fit_from_reinitialized
@@ -27,6 +28,7 @@
2728
DeviceLike,
2829
InputLike,
2930
ModelParams,
31+
TensorLike,
3032
TransformedEmulatorParams,
3133
)
3234
from autoemulate.data.utils import ConversionMixin, set_random_seed
@@ -36,7 +38,7 @@
3638
PYTORCH_EMULATORS,
3739
get_emulator_class,
3840
)
39-
from autoemulate.emulators.base import Emulator
41+
from autoemulate.emulators.base import Emulator, ProbabilisticEmulator
4042
from autoemulate.emulators.transformed.base import TransformedEmulator
4143
from autoemulate.transforms.base import AutoEmulateTransform
4244
from autoemulate.transforms.standardize import StandardizeTransform
@@ -839,6 +841,78 @@ def plot_surface(
839841
fig.savefig(fname, bbox_inches="tight")
840842
return None
841843

844+
def plot_calibration(
    self,
    emulator: ProbabilisticEmulator,
    x_test: TensorLike | None = None,
    y_test: TensorLike | None = None,
    levels: np.ndarray | None = None,
    n_samples: int = 2000,
    joint: bool = False,
    title: str | None = None,
    legend: bool = True,
    fname: str | None = None,
    figsize: tuple[int, int] | None = None,
    **kwargs,
):
    """Plot calibration curve(s) for a given emulator.

    Draws empirical coverage (y-axis) against nominal coverage (x-axis)
    for the emulator's predictive distribution.

    Parameters
    ----------
    emulator: ProbabilisticEmulator
        Emulator that outputs a predictive distribution.
    x_test: TensorLike | None
        Optional test inputs. If None, the held out test data is used.
        Defaults to None.
    y_test: TensorLike | None
        Optional true test outputs. If None, the held out test data is used.
        Defaults to None.
    levels: array-like, optional
        Nominal coverage levels (between 0 and 1). If None, a default grid is
        used.
    n_samples: int
        Number of Monte-Carlo samples to draw from the predictive
        distribution to compute empirical intervals if analytical quantiles
        are not available.
    joint: bool
        If True and the predictive outputs are multivariate, compute joint
        coverage (i.e., the true vector must lie inside the interval for all
        dimensions). If False (default), compute marginal coverage per output
        dimension and return the mean across data points.
    title: str | None
        An optional title for the plot. Defaults to None (no title).
    legend: bool
        Whether to display a legend. Defaults to True.
    fname: str | None
        If provided, the figure will be saved to this file path. If None, the figure
        will be displayed. Defaults to None.
    figsize: tuple[int, int] | None
        The size of the figure to create. If None, a default size is used.
        Defaults to None.
    """
    # Supplying exactly one of x_test / y_test is an error: callers must pass
    # both, or neither (in which case the held out test split is used).
    if (x_test is None) != (y_test is None):
        msg = (
            "Both x_test and y_test must be provided, or neither to use held "
            "out test data."
        )
        raise ValueError(msg)
    if x_test is None and y_test is None:
        self.logger.info(
            "Using held out test data for calibration plot. "
            "To use different data, provide x_test and y_test."
        )
        x_test, y_test = self._convert_to_tensors(self.test)

    pred_dist = emulator.predict(x_test)
    fig, _ = plot_calibration_from_distributions(
        pred_dist, y_test, levels, n_samples, joint, title, legend, figsize
    )

    # Save to disk when a filename is given; otherwise display inline.
    if fname is not None:
        fig.savefig(fname, bbox_inches="tight")
        return None
    return display_figure(fig)
915+
842916
def save(
843917
self,
844918
model_obj: int | Emulator | Result,

autoemulate/core/plotting.py

Lines changed: 158 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from matplotlib.axes import Axes
66
from matplotlib.figure import Figure
77

8-
from autoemulate.core.types import NumpyLike, TensorLike
8+
from autoemulate.core.types import DistributionLike, GaussianLike, NumpyLike, TensorLike
99
from autoemulate.emulators.base import Emulator
1010

1111

@@ -236,7 +236,6 @@ def mean_and_var_surface(
236236
The predicted variance on the grid.
237237
grid: list[TensorLike]
238238
The grid of parameter values used for predictions.
239-
240239
"""
241240
# Determine which parameters to vary and which to fix
242241
grid_params = {}
@@ -412,8 +411,162 @@ def create_and_plot_slice(
412411
param_pair_names,
413412
vmin,
414413
vmax,
415-
fixed_params_info=f"{', '.join(fixed_params)} at {quantile:.1f} quantile"
416-
if len(fixed_params) > 0
417-
else "None",
414+
fixed_params_info=(
415+
f"{', '.join(fixed_params)} at {quantile:.1f} quantile"
416+
if len(fixed_params) > 0
417+
else "None"
418+
),
418419
)
419420
return fig, ax
421+
422+
423+
def coverage_from_distributions(
    y_pred: DistributionLike,
    y_true: TensorLike,
    levels: list[float] | NumpyLike | TensorLike | None = None,
    n_samples: int = 2000,
    joint: bool = False,
) -> tuple[NumpyLike, NumpyLike]:
    """Compute empirical coverage for a set of nominal confidence levels.

    Parameters
    ----------
    y_pred: DistributionLike
        The emulator predicted distribution.
    y_true: TensorLike
        The true values.
    levels: array-like, optional
        Nominal coverage levels (between 0 and 1). If None, a default grid is
        used. Defaults to None.
    n_samples: int
        Number of Monte-Carlo samples to draw from the predictive
        distribution to compute empirical intervals if analytical quantiles
        are not available.
    joint: bool
        If True and the predictive outputs are multivariate, compute joint
        coverage (i.e., the true vector must lie inside the interval for all
        dimensions). If False (default), compute marginal coverage per output
        dimension and return the mean across data points.

    Returns
    -------
    levels: np.ndarray
        Nominal coverage levels.
    empirical: np.ndarray
        Empirical coverages. Shape is (len(levels), output_dim) when
        `joint=False` and output_dim>1, or (len(levels),) when joint=True or
        output_dim==1.
    """
    if levels is None:
        levels = np.linspace(0.0, 1.0, 51)
    levels = np.asarray(levels)

    # Prefer analytical quantiles when the predictive distribution is Gaussian
    # (possibly wrapped in Independent); otherwise fall back to empirical
    # intervals computed from Monte-Carlo sample quantiles.
    samples = None
    y_dist = None
    if isinstance(y_pred, GaussianLike):
        y_dist = y_pred
    elif isinstance(y_pred, torch.distributions.Independent) and isinstance(
        y_pred.base_dist, GaussianLike
    ):
        y_dist = y_pred.base_dist
    else:
        samples = y_pred.sample((n_samples,))

    empirical_list = []
    for p in levels:
        # Central interval: e.g. p=0.9 -> quantiles at 0.05 and 0.95.
        lower_q = (1.0 - p) / 2.0
        upper_q = 1.0 - lower_q

        if y_dist is not None:
            # BUG FIX: torch distribution icdf implementations (e.g.
            # Normal.icdf, which calls torch.erfinv) require a Tensor
            # argument; passing the raw NumPy float raised a TypeError.
            # Convert on y_true's dtype/device (also avoids a CPU/GPU
            # device mismatch).
            lower = y_dist.icdf(
                torch.as_tensor(lower_q, dtype=y_true.dtype, device=y_true.device)
            )
            upper = y_dist.icdf(
                torch.as_tensor(upper_q, dtype=y_true.dtype, device=y_true.device)
            )
        else:
            assert samples is not None
            lower = torch.quantile(samples, float(lower_q), dim=0)
            upper = torch.quantile(samples, float(upper_q), dim=0)

        inside = (y_true >= lower) & (y_true <= upper)
        if joint:
            # Joint coverage: the whole output vector must be inside.
            inside_all = inside.all(dim=-1)
            empirical = inside_all.float().mean().item()
        else:
            # marginal per-dim coverage
            empirical = inside.float().mean(dim=0).cpu().numpy()
        empirical_list.append(empirical)

    empirical_arr = np.asarray(empirical_list)

    return levels, empirical_arr
501+
502+
503+
def plot_calibration_from_distributions(
    y_pred: DistributionLike,
    y_true: TensorLike,
    levels: np.ndarray | None = None,
    n_samples: int = 2000,
    joint: bool = False,
    title: str | None = None,
    legend: bool = True,
    figsize: tuple[int, int] | None = None,
):
    """Plot calibration curve(s) given predictive distributions and true values.

    This draws empirical coverage (y-axis) against nominal coverage (x-axis).

    When points lie above or below the diagonal, this indicates that uncertainty
    is respectively being overestimated or underestimated.

    Parameters
    ----------
    y_pred: DistributionLike
        The emulator predicted distribution.
    y_true: TensorLike
        The true values.
    levels: array-like, optional
        Nominal coverage levels (between 0 and 1). If None, a default grid is
        used.
    n_samples: int
        Number of Monte-Carlo samples to draw from the predictive
        distribution to compute empirical intervals.
    joint: bool
        If True and the predictive outputs are multivariate, compute joint
        coverage (i.e., the true vector must lie inside the interval for all
        dimensions). If False (default), compute marginal coverage per output
        dimension and return the mean across data points.
    title: str | None
        An optional title for the plot. Defaults to None (no title).
    legend: bool
        Whether to display a legend. Defaults to True.
    figsize: tuple[int, int] | None
        The size of the figure to create. If None, a default size is used.
    """
    nominal, observed = coverage_from_distributions(
        y_pred, y_true, levels=levels, n_samples=n_samples, joint=joint
    )

    fig, ax = plt.subplots(figsize=(6, 6) if figsize is None else figsize)

    # Single output (or joint coverage): one curve. Otherwise one curve per
    # output dimension.
    if observed.ndim == 1 or observed.shape[1] == 1:
        ax.plot(nominal, observed, marker="o", label="empirical")
    else:
        for i in range(observed.shape[1]):
            ax.plot(nominal, observed[:, i], marker="o", label=f"$y_{i}$")

    # Perfectly calibrated predictions lie on this diagonal.
    ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="ideal")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xlabel("Expected coverage")
    ax.set_ylabel("Observed coverage")

    if title:
        ax.set_title(title)
    ax.grid(alpha=0.3)
    if legend:
        ax.legend()

    return fig, ax

docs/tutorials/emulation/01_quickstart.ipynb

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,9 @@
258258
"cell_type": "markdown",
259259
"metadata": {},
260260
"source": [
261-
"As well as plotting the data, we can directly plot the predicted mean and variance of the emulator for a pair of variables while holding the other variables constant at a given quantile. API to support plotting for a subset of the parameter and output range is also supported."
261+
"As well as plotting the data, we can directly plot the predicted mean and variance of the emulator for a pair of variables while holding the other variables constant at a given quantile. API to support plotting for a subset of the parameter and output range is also supported.\n",
262+
"\n",
263+
"The emulator predicted mean captures the simulated data plotted at the top of the tutorial well. The predicted variance is low where we have data, and increases away from the data. "
262264
]
263265
},
264266
{
@@ -270,6 +272,22 @@
270272
"ae.plot_surface(best.model, projectile.parameters_range, quantile=0.5)\n"
271273
]
272274
},
275+
{
276+
"cell_type": "markdown",
277+
"metadata": {},
278+
"source": [
279+
"We can also visualise the calibration of the emulator's predicted uncertainty on the held out test data. The closer the line is to the diagonal, the better calibrated the uncertainty is. Line above the diagonal overestimates the uncertainty while line below the diagonal underestimates it."
280+
]
281+
},
282+
{
283+
"cell_type": "code",
284+
"execution_count": null,
285+
"metadata": {},
286+
"outputs": [],
287+
"source": [
288+
"ae.plot_calibration(best.model)"
289+
]
290+
},
273291
{
274292
"cell_type": "markdown",
275293
"metadata": {},
@@ -358,7 +376,7 @@
358376
],
359377
"metadata": {
360378
"kernelspec": {
361-
"display_name": ".venv",
379+
"display_name": "autoemulate",
362380
"language": "python",
363381
"name": "python3"
364382
},
@@ -372,7 +390,7 @@
372390
"name": "python",
373391
"nbconvert_exporter": "python",
374392
"pygments_lexer": "ipython3",
375-
"version": "3.12.11"
393+
"version": "3.12.7"
376394
}
377395
},
378396
"nbformat": 4,

0 commit comments

Comments
 (0)