Skip to content

Commit 396faa5

Browse files
committed
Initial CRPS impl
1 parent 817014e commit 396faa5

File tree

1 file changed

+100
-2
lines changed

1 file changed

+100
-2
lines changed

autoemulate/core/metrics.py

Lines changed: 100 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,15 @@
77
from functools import partial
88

99
import torchmetrics
10-
11-
from autoemulate.core.types import OutputLike, TensorLike, TorchMetricsLike
10+
from einops import rearrange
11+
from torchmetrics.regression.crps import ContinuousRankedProbabilityScore
12+
13+
from autoemulate.core.types import (
14+
DistributionLike,
15+
OutputLike,
16+
TensorLike,
17+
TorchMetricsLike,
18+
)
1219

1320

1421
class Metric:
@@ -71,6 +78,97 @@ def __call__(self, y_pred: OutputLike, y_true: TensorLike) -> TensorLike:
7178
return metric.compute()
7279

7380

81+
class ProbabilisticMetric(Metric):
    """Common base class for probabilistic metrics.

    Concrete subclasses provide the scoring logic by overriding ``__call__``.
    """

    @abstractmethod
    def __call__(self, y_pred: OutputLike, y_true: TensorLike) -> TensorLike:
        """Compute the metric for ``y_pred`` against ``y_true``.

        Parameters
        ----------
        y_pred: OutputLike
            Predicted outputs.
        y_true: TensorLike
            True target values.

        Returns
        -------
        TensorLike
            The computed metric value.
        """
87+
88+
89+
class CRPS(ProbabilisticMetric):
    """Continuous Ranked Probability Score (CRPS) metric.

    Attributes
    ----------
    name : str
        Display name for the metric.
    maximize : bool
        Whether higher values are better. Defaults to False.
    """

    name: str = "crps"
    maximize: bool = False

    def __call__(
        self, y_pred: OutputLike, y_true: TensorLike, n_samples: int = 1000
    ) -> TensorLike:
        """Calculate CRPS metric.

        The metric can handle both deterministic predictions (tensors) and
        probabilistic predictions (sample tensors or distributions).

        Batch and target dimensions are flattened together, so every element of
        `y_true` is scored as an independent point against its own set of samples;
        the aggregation over those points is the one performed by torchmetrics'
        `ContinuousRankedProbabilityScore`.

        Parameters
        ----------
        y_pred: OutputLike
            Predicted outputs. Can be a tensor or a distribution.
            If `y_pred` is a tensor of shape `(batch_size, *(target_shape))`, it is
            treated as a deterministic prediction, for which the CRPS reduces to
            the absolute error.
            If `y_pred` is a tensor of shape
            `(batch_size, *(target_shape), n_samples)`, it is treated as a
            probabilistic prediction and the metric is computed across the samples.
            If `y_pred` is a distribution, then `n_samples` are drawn from the
            predicted distribution to estimate the CRPS.
        y_true: TensorLike
            True target values.
        n_samples: int
            Number of samples to draw from the predicted distribution if `y_pred`
            is a distribution. Defaults to 1000.

        Returns
        -------
        TensorLike
            The CRPS value computed by `ContinuousRankedProbabilityScore`.

        Raises
        ------
        ValueError
            If `y_true` is not a tensor, if `y_pred` is neither a tensor nor a
            distribution, or if the shape of `y_pred` (excluding the sample
            dimension) does not match the shape of `y_true`.
        """
        if not isinstance(y_true, TensorLike):
            raise ValueError(f"Metric not implemented for y_true ({type(y_true)})")

        if isinstance(y_pred, TensorLike) and y_pred.dim() == y_true.dim():
            # Deterministic case: append a sample dimension holding two identical
            # members. Identical members have zero spread, so the score is just
            # the absolute error to the target. Two members are used instead of
            # one in case the backend rejects single-member ensembles
            # (NOTE(review): confirm against the installed torchmetrics version).
            samples = y_pred.unsqueeze(-1).expand(*y_pred.shape, 2)
        elif isinstance(y_pred, TensorLike) and y_pred.dim() == y_true.dim() + 1:
            # Sample tensor case: the trailing dimension already holds the samples
            samples = y_pred
        elif isinstance(y_pred, DistributionLike):
            # Distribution case: draw samples and move the sample dim to the end
            samples = rearrange(y_pred.sample((n_samples,)), "s b ... -> b ... s")
        elif isinstance(y_pred, TensorLike):
            msg = (
                f"Metric not implemented for y_pred shape ({y_pred.shape}) given "
                f"y_true shape ({y_true.shape})"
            )
            raise ValueError(msg)
        else:
            msg = (
                f"Metric not implemented for y_pred ({type(y_pred)}) and y_true "
                f"({type(y_true)})"
            )
            raise ValueError(msg)

        # Validate explicitly (not via `assert`, which is stripped under -O) and
        # for every input kind, so a mismatch can never be silently flattened into
        # misaligned (sample, target) pairs.
        if samples.shape[:-1] != y_true.shape:
            msg = (
                f"y_pred samples shape {samples.shape} does not match "
                f"y_true shape {y_true.shape}"
            )
            raise ValueError(msg)

        crps_metric = ContinuousRankedProbabilityScore()
        crps_metric.to(y_true.device)

        # Collapse batch and target dims: samples -> (-1, n_samples), y_true -> (-1,)
        samples = samples.reshape(-1, samples.shape[-1])
        return crps_metric(samples, y_true.reshape(-1))
170+
171+
74172
R2 = TorchMetrics(
75173
metric=torchmetrics.R2Score,
76174
name="r2",

0 commit comments

Comments
 (0)