6 changes: 1 addition & 5 deletions autoemulate/core/compare.py
@@ -12,11 +12,7 @@

from autoemulate.core.device import TorchDeviceMixin
from autoemulate.core.logging_config import get_configured_logger
-from autoemulate.core.metrics import (
-    TorchMetrics,
-    get_metric_config,
-    get_metric_configs,
-)
+from autoemulate.core.metrics import TorchMetrics, get_metric_config, get_metric_configs
from autoemulate.core.model_selection import bootstrap, evaluate
from autoemulate.core.plotting import (
    calculate_subplot_layout,
159 changes: 150 additions & 9 deletions autoemulate/core/metrics.py
@@ -7,30 +7,39 @@
from functools import partial

import torchmetrics

-from autoemulate.core.types import OutputLike, TensorLike, TorchMetricsLike
+from einops import rearrange
+from torchmetrics.regression.crps import ContinuousRankedProbabilityScore
+
+from autoemulate.core.types import (
+    DistributionLike,
+    OutputLike,
+    TensorLike,
+    TorchMetricsLike,
+)


class Metric:
"""Configuration for a single metric.

Parameters
----------
name : str
name: str
Display name for the metric.
maximize : bool
maximize: bool
Whether higher values are better. Defaults to True.
"""

name: str
maximize: bool

def __repr__(self) -> str:
"""Return the string representation of the MetricConfig."""
return f"MetricConfig(name={self.name}, maximize={self.maximize})"
"""Return the string representation of the Metric."""
return f"Metric(name={self.name}, maximize={self.maximize})"

@abstractmethod
def __call__(self, y_pred: OutputLike, y_true: TensorLike) -> TensorLike:
def __call__(
self, y_pred: OutputLike, y_true: TensorLike, n_samples: int = 1000
) -> TensorLike:
"""Calculate metric."""


@@ -57,20 +66,149 @@ def __init__(
        self.name = name
        self.maximize = maximize

-    def __call__(self, y_pred: OutputLike, y_true: TensorLike) -> TensorLike:
+    def __call__(
+        self, y_pred: OutputLike, y_true: TensorLike, n_samples: int = 1000
+    ) -> TensorLike:
        """Calculate metric."""
-        if not isinstance(y_pred, TensorLike):
+        if not isinstance(y_pred, OutputLike):
            raise ValueError(f"Metric not implemented for y_pred ({type(y_pred)})")
        if not isinstance(y_true, TensorLike):
            raise ValueError(f"Metric not implemented for y_true ({type(y_true)})")

+        # Handle probabilistic predictions
+        if isinstance(y_pred, DistributionLike):
+            try:
+                y_pred = y_pred.mean
+            except Exception:
+                y_pred = y_pred.rsample((n_samples,)).mean(dim=0)
        metric = self.metric()
        metric.to(y_pred.device)
        # Assume first dim is a batch dim, flatten others for metric calculation
        metric.update(y_pred.flatten(start_dim=1), y_true.flatten(start_dim=1))
        return metric.compute()


+class ProbabilisticMetric(Metric):
+    """Base class for probabilistic metrics."""
+
+    @abstractmethod
+    def __call__(
+        self, y_pred: OutputLike, y_true: TensorLike, n_samples: int = 1000
+    ) -> TensorLike:
+        """Calculate metric."""
+
+
+class CRPSMetric(ProbabilisticMetric):
+    """Continuous Ranked Probability Score (CRPS) metric.
+
+    CRPS is a scoring rule for evaluating probabilistic predictions. It reduces to mean
+    absolute error (MAE) for deterministic predictions and generalizes to distributions
+    by measuring the integral difference between predicted and actual CDFs.
+
+    The metric aggregates over batch and target dimensions by computing the mean
+    CRPS across all scalar outputs, making it comparable across different batch
+    sizes and output dimensions.
+
+    Attributes
+    ----------
+    name: str
+        Display name for the metric.
+    maximize: bool
+        Whether higher values are better. False for CRPS (lower is better).
+    """
+
+    name: str = "crps"
+    maximize: bool = False
+
+    def __call__(
+        self, y_pred: OutputLike, y_true: TensorLike, n_samples: int = 1000
+    ) -> TensorLike:
+        """Calculate CRPS metric.
+
+        The metric handles both deterministic predictions (tensors) and probabilistic
+        predictions (tensors of samples or distributions).
+
+        Aggregation across batch and target dimensions is performed by computing the
+        mean CRPS across all scalar outputs. This makes the metric comparable across
+        different batch sizes and target dimensions.
+
+        Parameters
+        ----------
+        y_pred: OutputLike
+            Predicted outputs. Can be a tensor or a distribution.
+            - If tensor with shape `(batch_size, *target_shape)`: treated as
+              deterministic prediction (reduces to MAE).
+            - If tensor with shape `(batch_size, *target_shape, n_samples)`: treated as
+              samples from a probabilistic prediction.
+            - If distribution: `n_samples` are drawn to estimate CRPS.
+        y_true: TensorLike
+            True target values of shape `(batch_size, *target_shape)`.
+        n_samples: int
+            Number of samples to draw from the predicted distribution if `y_pred` is a
+            distribution. Defaults to 1000.
+
+        Returns
+        -------
+        TensorLike
+            Mean CRPS score across all batch elements and target dimensions.
+
+        Raises
+        ------
+        ValueError
+            If input types or shapes are incompatible.
+        """
+        if not isinstance(y_true, TensorLike):
+            raise ValueError(f"y_true must be a tensor, got {type(y_true)}")
+
+        # Ensure 2D y_true for consistent handling
+        y_true = y_true.unsqueeze(-1) if y_true.ndim == 1 else y_true
+
+        # Initialize CRPS metric (computes mean by default)
+        crps_metric = ContinuousRankedProbabilityScore()
+        crps_metric.to(y_true.device)
+
+        # Handle different prediction types
+        if isinstance(y_pred, DistributionLike):
+            # Distribution case: sample from it
+            samples = rearrange(y_pred.sample((n_samples,)), "s b ... -> b ... s")
+            if samples.shape[:-1] != y_true.shape:
+                raise ValueError(
+                    f"Sampled predictions shape {samples.shape[:-1]} (excluding sample "
+                    f"dimension) does not match y_true shape {y_true.shape}"
+                )
+        elif isinstance(y_pred, TensorLike):
+            # Tensor case: check dimensions
+            if y_pred.dim() == y_true.dim():
+                # Deterministic: same shape as y_true
+                # CRPS requires at least 2 ensemble members, so duplicate the prediction
+                samples = y_pred.unsqueeze(-1).repeat_interleave(2, dim=-1)
+            elif y_pred.dim() == y_true.dim() + 1:
+                # Probabilistic: already has sample dimension at end
+                samples = y_pred
+                if samples.shape[:-1] != y_true.shape:
+                    raise ValueError(
+                        f"y_pred shape {samples.shape[:-1]} (excluding last dimension) "
+                        f"does not match y_true shape {y_true.shape}"
+                    )
+            else:
+                raise ValueError(
+                    f"y_pred dimensions ({y_pred.dim()}) incompatible with y_true "
+                    f"dimensions ({y_true.dim()}). Expected same dimensions or "
+                    f"y_true.dim() + 1"
+                )
+        else:
+            raise ValueError(
+                f"y_pred must be a tensor or distribution, got {type(y_pred)}"
+            )
+
+        # Flatten batch and target dimensions
+        samples_flat = samples.flatten(end_dim=-2)  # (batch * targets, n_samples)
+        y_true_flat = y_true.flatten()  # (batch * targets,)
+
+        # ContinuousRankedProbabilityScore computes mean by default
+        return crps_metric(samples_flat, y_true_flat)
+
+
R2 = TorchMetrics(
    metric=torchmetrics.R2Score,
    name="r2",
@@ -95,11 +233,14 @@ def __call__(self, y_pred: OutputLike, y_true: TensorLike) -> TensorLike:
    maximize=False,
)

+CRPS = CRPSMetric()
+
AVAILABLE_METRICS = {
"r2": R2,
"rmse": RMSE,
"mse": MSE,
"mae": MAE,
"crps": CRPS,
}


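For reference, a minimal sketch of how the `CRPS` metric defined above could be exercised with the three supported prediction formats. The tensors and the `Normal` distribution are made up for illustration, and it is assumed that `DistributionLike` covers `torch.distributions` objects, as the handling in `TorchMetrics.__call__` suggests.

```python
import torch
from torch.distributions import Normal

from autoemulate.core.metrics import CRPS  # CRPS = CRPSMetric(), added in this PR

y_true = torch.randn(32, 2)  # (batch_size, n_targets)

# Deterministic tensor with the same shape as y_true: CRPS reduces to MAE
det_pred = y_true + 0.1 * torch.randn(32, 2)
print(CRPS(det_pred, y_true))

# Tensor of samples with a trailing sample dimension
sample_pred = y_true.unsqueeze(-1) + torch.randn(32, 2, 500)
print(CRPS(sample_pred, y_true))

# Distribution: n_samples draws are taken internally
dist_pred = Normal(y_true, torch.ones_like(y_true))
print(CRPS(dist_pred, y_true, n_samples=1000))
```

In all three cases the call returns a single scalar: the mean CRPS over the flattened batch and target dimensions.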
17 changes: 11 additions & 6 deletions autoemulate/core/model_selection.py
@@ -14,6 +14,7 @@
from autoemulate.core.types import (
    DeviceLike,
    ModelParams,
+    OutputLike,
    TensorLike,
    TransformedEmulatorParams,
)
@@ -25,27 +26,31 @@


def evaluate(
-    y_pred: TensorLike,
+    y_pred: OutputLike,
    y_true: TensorLike,
    metric: Metric = R2,
+    n_samples: int = 1000,
) -> float:
"""
Evaluate Emulator prediction performance using a `torchmetrics.Metric`.

Parameters
----------
y_pred: OutputLike
Predicted target values, as returned by an Emulator.
y_true: TensorLike
Ground truth target values.
y_pred: TensorLike
Predicted target values, as returned by an Emulator.
metric: Metric
Metric to use for evaluation. Defaults to R2.
n_samples: int
Number of samples to generate to predict mean when y_pred does not have a mean
directly available. Defaults to 1000.

Returns
-------
float
"""
return metric(y_pred, y_true).item()
return metric(y_pred, y_true, n_samples=n_samples).item()


def cross_validate(
@@ -139,7 +144,7 @@ def cross_validate(
        transformed_emulator.fit(x, y)

        # compute and save results
-        y_pred = transformed_emulator.predict_mean(x_val)
+        y_pred = transformed_emulator.predict(x_val)
        for metric in metrics:
            score = evaluate(y_pred, y_val, metric)
            cv_results[metric.name].append(score)
@@ -192,7 +197,7 @@ def bootstrap(

    # If no bootstraps are specified, fall back to a single evaluation on given data
    if n_bootstraps is None:
-        y_pred = model.predict_mean(x, n_samples=n_samples)
+        y_pred = model.predict(x)
        results = {}
        for metric in metrics:
            score = evaluate(y_pred, y, metric)
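A rough illustration of the widened `evaluate` signature, showing that it now accepts any `OutputLike` and forwards `n_samples` to the metric. A plain `torch.distributions.Normal` stands in for an emulator's predictive output here; the shapes and noise level are arbitrary, not taken from the PR.

```python
import torch
from torch.distributions import Normal

from autoemulate.core.metrics import CRPS, R2
from autoemulate.core.model_selection import evaluate

y_true = torch.randn(64, 1)  # (batch_size, n_targets)

# A distribution standing in for the OutputLike returned by an emulator's predict()
y_pred = Normal(y_true, 0.2 * torch.ones_like(y_true))

# Deterministic metrics fall back to the predictive mean (or a sample mean
# when the distribution has no closed-form mean)
r2 = evaluate(y_pred, y_true, metric=R2)

# CRPS scores the full predictive distribution; n_samples controls the draws
crps = evaluate(y_pred, y_true, metric=CRPS, n_samples=500)
```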
1 change: 1 addition & 0 deletions pyproject.toml
@@ -40,6 +40,7 @@ dependencies = [
"torchrbf>=0.0.1",
"arviz>=0.21.0",
"getdist>=1.7.2",
"einops>=0.8.1",
]

[project.urls]
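As far as this diff shows, the new `einops` dependency is only needed for the axis shuffle in `CRPSMetric`. A quick sanity check of what that pattern does (shapes chosen arbitrarily):

```python
import torch
from einops import rearrange

# "s b ... -> b ... s" moves the leading sample axis to the end, turning
# (n_samples, batch_size, *target_shape) into (batch_size, *target_shape, n_samples)
samples = torch.randn(1000, 32, 2)
reordered = rearrange(samples, "s b ... -> b ... s")
print(reordered.shape)  # torch.Size([32, 2, 1000])
```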