Skip to content

Commit d644cfa

Browse files
committed
Update Metric and evaluate() API to support OutputLike
1 parent f25fe33 commit d644cfa

File tree

4 files changed

+94
-15
lines changed

4 files changed

+94
-15
lines changed

autoemulate/core/compare.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,7 @@
1212

1313
from autoemulate.core.device import TorchDeviceMixin
1414
from autoemulate.core.logging_config import get_configured_logger
15-
from autoemulate.core.metrics import (
16-
TorchMetrics,
17-
get_metric_config,
18-
get_metric_configs,
19-
)
15+
from autoemulate.core.metrics import TorchMetrics, get_metric_config, get_metric_configs
2016
from autoemulate.core.model_selection import bootstrap, evaluate
2117
from autoemulate.core.plotting import (
2218
calculate_subplot_layout,

autoemulate/core/metrics.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ def __repr__(self) -> str:
3737
return f"Metric(name={self.name}, maximize={self.maximize})"
3838

3939
@abstractmethod
40-
def __call__(self, y_pred: OutputLike, y_true: TensorLike) -> TensorLike:
40+
def __call__(
41+
self, y_pred: OutputLike, y_true: TensorLike, n_samples: int = 1000
42+
) -> TensorLike:
4143
"""Calculate metric."""
4244

4345

@@ -64,13 +66,21 @@ def __init__(
6466
self.name = name
6567
self.maximize = maximize
6668

67-
def __call__(self, y_pred: OutputLike, y_true: TensorLike) -> TensorLike:
69+
def __call__(
70+
self, y_pred: OutputLike, y_true: TensorLike, n_samples: int = 1000
71+
) -> TensorLike:
6872
"""Calculate metric."""
69-
if not isinstance(y_pred, TensorLike):
73+
if not isinstance(y_pred, OutputLike):
7074
raise ValueError(f"Metric not implemented for y_pred ({type(y_pred)})")
7175
if not isinstance(y_true, TensorLike):
7276
raise ValueError(f"Metric not implemented for y_true ({type(y_true)})")
7377

78+
# Handle probabilistic predictions
79+
if isinstance(y_pred, DistributionLike):
80+
try:
81+
y_pred = y_pred.mean
82+
except Exception:
83+
y_pred = y_pred.rsample((n_samples,)).mean(dim=0)
7484
metric = self.metric()
7585
metric.to(y_pred.device)
7686
# Assume first dim is a batch dim, flatten others for metric calculation
@@ -82,7 +92,9 @@ class ProbabilisticMetric(Metric):
8292
"""Base class for probabilistic metrics."""
8393

8494
@abstractmethod
85-
def __call__(self, y_pred: OutputLike, y_true: TensorLike) -> TensorLike:
95+
def __call__(
96+
self, y_pred: OutputLike, y_true: TensorLike, n_samples: int = 1000
97+
) -> TensorLike:
8698
"""Calculate metric."""
8799

88100

autoemulate/core/model_selection.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from autoemulate.core.types import (
1515
DeviceLike,
1616
ModelParams,
17+
OutputLike,
1718
TensorLike,
1819
TransformedEmulatorParams,
1920
)
@@ -25,27 +26,31 @@
2526

2627

2728
def evaluate(
28-
y_pred: TensorLike,
29+
y_pred: OutputLike,
2930
y_true: TensorLike,
3031
metric: Metric = R2,
32+
n_samples: int = 1000,
3133
) -> float:
3234
"""
3335
Evaluate Emulator prediction performance using a `torchmetrics.Metric`.
3436
3537
Parameters
3638
----------
39+
y_pred: OutputLike
40+
Predicted target values, as returned by an Emulator.
3741
y_true: TensorLike
3842
Ground truth target values.
39-
y_pred: TensorLike
40-
Predicted target values, as returned by an Emulator.
4143
metric: Metric
4244
Metric to use for evaluation. Defaults to R2.
45+
n_samples: int
46+
Number of samples to draw to estimate the mean when y_pred does not expose a
47+
mean directly. Defaults to 1000.
4348
4449
Returns
4550
-------
4651
float
4752
"""
48-
return metric(y_pred, y_true).item()
53+
return metric(y_pred, y_true, n_samples=n_samples).item()
4954

5055

5156
def cross_validate(
@@ -139,7 +144,7 @@ def cross_validate(
139144
transformed_emulator.fit(x, y)
140145

141146
# compute and save results
142-
y_pred = transformed_emulator.predict_mean(x_val)
147+
y_pred = transformed_emulator.predict(x_val)
143148
for metric in metrics:
144149
score = evaluate(y_pred, y_val, metric)
145150
cv_results[metric.name].append(score)
@@ -192,7 +197,7 @@ def bootstrap(
192197

193198
# If no bootstraps are specified, fall back to a single evaluation on given data
194199
if n_bootstraps is None:
195-
y_pred = model.predict_mean(x, n_samples=n_samples)
200+
y_pred = model.predict(x)
196201
results = {}
197202
for metric in metrics:
198203
score = evaluate(y_pred, y, metric)

tests/core/test_metrics.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,3 +528,69 @@ def test_crps_with_1d_targets():
528528
assert result.ndim == 0, "Result should be a scalar tensor"
529529
assert isinstance(result, torch.Tensor)
530530
assert result >= 0, "CRPS should be non-negative"
531+
532+
533+
# Tests for OutputLike support in TorchMetrics
534+
535+
536+
def test_torchmetrics_with_distribution_vs_mean():
537+
"""Test TorchMetrics with distribution gives same result as using mean."""
538+
batch_size, n_targets = 10, 3
539+
y_true = torch.randn(batch_size, n_targets)
540+
541+
# Create a Normal distribution
542+
mean = torch.randn(batch_size, n_targets)
543+
std = torch.ones(batch_size, n_targets) * 0.5
544+
y_pred_dist = Normal(mean, std)
545+
546+
# Get result with distribution
547+
result_dist = MSE(y_pred_dist, y_true)
548+
549+
# Get result with mean tensor
550+
result_mean = MSE(mean, y_true)
551+
552+
assert torch.isclose(result_dist, result_mean, rtol=1e-4), "Should be close"
553+
554+
555+
@pytest.mark.parametrize(
556+
"metric_instance",
557+
[
558+
metric
559+
for metric in AVAILABLE_METRICS.values()
560+
if isinstance(metric, TorchMetrics)
561+
],
562+
)
563+
def test_all_torchmetrics_support_distributions(metric_instance):
564+
"""Test all TorchMetrics instances support distribution inputs."""
565+
batch_size = 20
566+
y_true = torch.randn(batch_size, 2)
567+
568+
# Create a distribution
569+
mean = torch.randn(batch_size, 2)
570+
std = torch.ones(batch_size, 2) * 0.3
571+
y_pred_dist = Normal(mean, std)
572+
573+
# Should work without error
574+
result = metric_instance(y_pred_dist, y_true)
575+
576+
assert isinstance(result, torch.Tensor)
577+
assert result.ndim == 0
578+
assert torch.isfinite(result), "Result should be finite"
579+
580+
581+
def test_torchmetrics_distribution_multioutput():
582+
"""Test TorchMetrics with distribution for multioutput case."""
583+
batch_size, n_outputs = 50, 5
584+
y_true = torch.randn(batch_size, n_outputs)
585+
586+
# Create distribution with different means for different outputs
587+
mean = torch.randn(batch_size, n_outputs)
588+
std = torch.rand(batch_size, n_outputs) * 0.5 + 0.1 # Avoid zero std
589+
y_pred_dist = Normal(mean, std)
590+
591+
# Test with MAE
592+
result = MAE(y_pred_dist, y_true)
593+
594+
assert isinstance(result, torch.Tensor)
595+
assert result.ndim == 0
596+
assert result >= 0, "MAE should be non-negative"

0 commit comments

Comments
 (0)