Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/SCRIPTS_AND_CONFIGS.md
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ Defines the data source.

#### `eval` (Evaluation Script Only)
* **`checkpoint`**: Path to the trained model checkpoint to load.
* **`metrics`**: List of metrics to compute (e.g., `["mse", "rmse"]`).
* **`metrics`**: List of metrics to compute (e.g., `["mse", "rmse", "spread", "skill"]` for ensemble evaluation). For ensemble predictions, `skill` is the RMSE of the ensemble mean; it therefore matches `rmse` numerically and is kept mainly to make spread/skill reporting explicit.
* **`video_dir`**: Where to save rollout visualizations.

## Workflow Examples
Expand Down
6 changes: 4 additions & 2 deletions src/autocast/configs/eval/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ All eval configs support these parameters:
run. See [Evaluation modes](#evaluation-modes) below.
- `metrics`: List of metrics to compute (default includes mse/mae/rmse/vrmse,
power spectrum scores `psrmse*`, cross-correlation spectrum scores `pscc*`,
and ensemble scores `crps`, `fcrps`, `afcrps`, `energy`, `ssr`; `variogram`
remains available via explicit opt-in)
and ensemble scores `crps`, `fcrps`, `afcrps`, `energy`, `spread`, `skill`,
`ssr`; `variogram` remains available via explicit opt-in. Note that for
ensemble predictions, `skill` is the RMSE of the ensemble mean, so it matches
`rmse` numerically and is included for explicit spread/skill reporting.)
- `csv_path`: Custom path for metrics CSV (default: work_dir/evaluation_metrics.csv)
- `video_dir`: Custom directory for rollout videos (default: work_dir/videos)
- `batch_indices`: List of rollout sample indices to visualize (resolved across
Expand Down
6 changes: 6 additions & 0 deletions src/autocast/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
CRPS,
AlphaFairCRPS,
EnergyScore,
EnsembleSkill,
EnsembleSpread,
FairCRPS,
SpreadSkillRatio,
VariogramScore,
Expand All @@ -43,6 +45,8 @@
"AlphaFairCRPS",
"Coverage",
"EnergyScore",
"EnsembleSkill",
"EnsembleSpread",
"FairCRPS",
"LInfinity",
"MultiCoverage",
Expand Down Expand Up @@ -88,6 +92,8 @@
FairCRPS,
EnergyScore,
VariogramScore,
EnsembleSpread,
EnsembleSkill,
SpreadSkillRatio,
WinklerScore,
Coverage,
Expand Down
128 changes: 128 additions & 0 deletions src/autocast/metrics/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,134 @@ def score(
return ssr


class EnsembleSpread(BTSCMMetric):
    r"""
    Spread (standard deviation) of an ensemble forecast.

    Notes
    -----
    With ``corrected=True`` (the default) a finite-ensemble correction factor
    is applied to the macroscopic standard deviation:

    .. math::
        \text{Spread}_{\text{corr}} =
        \sqrt{\left\langle \mathrm{Var}_{m,\text{unbiased}}(x_m)\right\rangle}
        \sqrt{\frac{M + 1}{M}}.

    The factor makes spread comparable to skill at finite ensemble size
    :math:`M` when the unbiased sample variance is used, and matches the
    LoLA/paper formulation (Appendix "Spread / Skill"):
    ``spread = sqrt((M+1)/(M-1) * mean((x_m - mean_m)^2))``, which follows
    because ``Var_unbiased = (M/(M-1)) * mean((x_m - mean_m)^2)``.

    With ``corrected=False`` the uncorrected macroscopic ensemble standard
    deviation (still based on the unbiased variance estimator) is returned:

    .. math::
        \sqrt{\left\langle \mathrm{Var}_{m,\text{unbiased}}(x_m)\right\rangle}.
    """

    name: str = "spread"

    def __init__(
        self,
        *,
        corrected: bool = True,
        score_dims: Literal["spatial", "temporal"] | None = "spatial",
        reduce_all: bool = True,
        dist_sync_on_step: bool = False,
    ):
        # Only ``corrected`` is specific to this metric; everything else is
        # forwarded unchanged to the base metric.
        super().__init__(
            score_dims=score_dims,
            reduce_all=reduce_all,
            dist_sync_on_step=dist_sync_on_step,
        )
        self.corrected = corrected

    def _score(self, y_pred: TensorBTSCM, y_true: TensorBTSC) -> TensorBTSC:
        """Not used directly; we override score() to change reduction order."""
        msg = "EnsembleSpread overrides score() directly."
        raise NotImplementedError(msg)

    def score(
        self, y_pred: ArrayLike, y_true: ArrayLike
    ) -> TensorBTC | TensorBSC | TensorBTSC:
        pred, true = self._check_input(y_pred, y_true)

        n_members = pred.shape[-1]
        if n_members < 2:
            raise ValueError(
                "EnsembleSpread requires at least 2 ensemble members "
                f"(got {n_members})."
            )

        # Unbiased per-point variance over the ensemble axis: (B, T, S..., C).
        variance = pred.var(dim=-1, unbiased=True)

        # Macroscopic spread: average the variance first, take sqrt second.
        if self.score_dims == "temporal":
            variance = variance.mean(dim=1)
        elif self.score_dims == "spatial":
            n_spatial = self._infer_n_spatial_dims(true)
            variance = variance.mean(dim=tuple(range(2, 2 + n_spatial)))

        if not self.corrected:
            return torch.sqrt(variance)

        factor = float(np.sqrt((n_members + 1) / n_members))
        return torch.sqrt(variance) * factor


class EnsembleSkill(BTSCMMetric):
    r"""
    Ensemble skill: RMSE of the ensemble-mean forecast.

    Notes
    -----
    Skill is the root-mean-square error of the ensemble mean:

    .. math::
        \text{Skill} = \sqrt{\left\langle (\bar{x} - y)^2 \right\rangle},

    where :math:`\bar{x}` is the ensemble mean and
    :math:`\langle \cdot \rangle` denotes the spatial mean.

    The squared error is averaged over the spatial/temporal dimensions
    *before* the square root is taken (macroscopic RMSE), as is standard in
    ensemble forecast evaluation (and in the LoLA/paper appendices).

    Under the default spatial reduction this coincides numerically with the
    deterministic ``RMSE`` metric evaluated on an ensemble prediction tensor,
    since ``RMSE`` averages over the ensemble dimension first and then
    computes the RMSE.
    """

    name: str = "skill"

    def _score(self, y_pred: TensorBTSCM, y_true: TensorBTSC) -> TensorBTSC:
        """Not used directly; we override score() to change reduction order."""
        msg = "EnsembleSkill overrides score() directly."
        raise NotImplementedError(msg)

    def score(
        self, y_pred: ArrayLike, y_true: ArrayLike
    ) -> TensorBTC | TensorBSC | TensorBTSC:
        pred, true = self._check_input(y_pred, y_true)

        # Squared error of the ensemble-mean forecast: (B, T, S..., C).
        sq_err = (pred.mean(dim=-1) - true) ** 2

        # Macroscopic RMSE: average the squared error first, sqrt second.
        if self.score_dims == "temporal":
            sq_err = sq_err.mean(dim=1)
        elif self.score_dims == "spatial":
            n_spatial = self._infer_n_spatial_dims(true)
            sq_err = sq_err.mean(dim=tuple(range(2, 2 + n_spatial)))

        return torch.sqrt(sq_err)


class WinklerScore(BTSCMMetric):
r"""
Winkler interval score for central prediction intervals.
Expand Down
6 changes: 6 additions & 0 deletions src/autocast/scripts/eval/encoder_processor_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
CRPS,
AlphaFairCRPS,
EnergyScore,
EnsembleSkill,
EnsembleSpread,
FairCRPS,
SpreadSkillRatio,
VariogramScore,
Expand Down Expand Up @@ -109,6 +111,8 @@
"afcrps": AlphaFairCRPS,
"energy": EnergyScore,
"variogram": VariogramScore,
"spread": EnsembleSpread,
"skill": EnsembleSkill,
"ssr": SpreadSkillRatio,
"winkler": WinklerScore,
}
Expand All @@ -132,6 +136,8 @@
"fcrps",
"afcrps",
"energy",
"spread",
"skill",
"ssr",
"winkler",
]
Expand Down
67 changes: 66 additions & 1 deletion tests/metrics/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
from autocast.metrics import ALL_ENSEMBLE_METRICS
from autocast.metrics.base import BaseMetric
from autocast.metrics.coverage import Coverage
from autocast.metrics.deterministic import RMSE
from autocast.metrics.ensemble import (
EnergyScore,
EnsembleSkill,
EnsembleSpread,
SpreadSkillRatio,
VariogramScore,
WinklerScore,
Expand All @@ -25,7 +28,7 @@
ENSEMBLE_ERROR_METRICS = tuple(
m
for m in ENSEMBLE_BASE_METRICS
if m not in [Coverage, VariogramScore, SpreadSkillRatio]
if m not in [Coverage, VariogramScore, SpreadSkillRatio, EnsembleSpread]
)


Expand Down Expand Up @@ -351,6 +354,68 @@ def test_spread_skill_ratio_stateful_returns_per_lead_time_and_is_mean_of_ratios
), value


def test_ensemble_spread_matches_lola_correction():
    """Corrected spread on a two-member ensemble equals sqrt(3)."""
    # Shape (B=1, T=1, S=1, C=1, M=2); members [0, 2].
    # Unbiased variance = 2, so sqrt(var) = sqrt(2); the finite-ensemble
    # correction sqrt((M+1)/M) = sqrt(3/2) yields sqrt(3) overall.
    members = torch.tensor([[[[[0.0, 2.0]]]]])
    target = torch.zeros((1, 1, 1, 1))

    result = EnsembleSpread()(members, target)
    assert torch.allclose(result, torch.tensor(3.0**0.5), atol=1e-6)


def test_ensemble_spread_can_be_uncorrected():
    """With corrected=False the raw sqrt of unbiased variance is returned."""
    members = torch.tensor([[[[[0.0, 2.0]]]]])  # unbiased variance = 2
    target = torch.zeros((1, 1, 1, 1))

    result = EnsembleSpread(corrected=False)(members, target)
    assert torch.allclose(result, torch.tensor(2.0**0.5), atol=1e-6)


def test_ensemble_spread_requires_multiple_ensemble_members():
    """A single-member ensemble has no defined spread and must be rejected."""
    single_member_pred = torch.ones((1, 1, 1, 1, 1))
    target = torch.ones((1, 1, 1, 1))

    with pytest.raises(ValueError, match="at least 2 ensemble members"):
        EnsembleSpread()(single_member_pred, target)


def test_ensemble_spread_accepts_base_metric_kwargs():
    """Temporal reduction averages per-lead-time variances before the sqrt."""
    # Two lead times with member pairs [0, 2] and [0, 4]: unbiased variances
    # 2 and 8 average to 5; corrected spread = sqrt(5 * 3/2) = sqrt(7.5).
    members = torch.tensor([[[[[0.0, 2.0]]], [[[0.0, 4.0]]]]])
    target = torch.zeros((1, 2, 1, 1))

    spread = EnsembleSpread(score_dims="temporal", reduce_all=False).score(
        members, target
    )

    assert torch.allclose(spread, torch.tensor([[[7.5**0.5]]]), atol=1e-6)


def test_ensemble_skill_is_rmse_of_ensemble_mean():
    """Skill reduces to the RMSE of the ensemble mean."""
    # Shape (B=1, T=1, S=1, C=1, M=2); members [0, 2] against truth 0:
    # ensemble mean is 1, so the RMSE is 1.
    members = torch.tensor([[[[[0.0, 2.0]]]]])
    target = torch.zeros((1, 1, 1, 1))

    result = EnsembleSkill()(members, target)
    assert torch.allclose(result, torch.tensor(1.0), atol=1e-6)


def test_ensemble_skill_matches_deterministic_rmse_on_ensemble_predictions():
    """Under spatial reduction, skill and deterministic RMSE coincide."""
    torch.manual_seed(0)
    ensemble_pred = torch.randn((2, 3, 4, 5, 2, 7))
    target = torch.randn((2, 3, 4, 5, 2))

    assert torch.allclose(
        EnsembleSkill()(ensemble_pred, target),
        RMSE()(ensemble_pred, target),
        atol=1e-6,
    )


def test_winkler_score_manual_value():
# Shape: (B=1, T=1, S=2, C=1, M=5)
# Ensemble members: [0, 1, 2, 3, 4], alpha=0.2
Expand Down
7 changes: 7 additions & 0 deletions tests/scripts/test_eval_encoder_processor_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from autocast.metrics.ensemble import CRPS, AlphaFairCRPS, SpreadSkillRatio
from autocast.scripts.eval.encoder_processor_decoder import (
DEFAULT_EVAL_METRICS,
DEFAULT_EVAL_MODE,
EVAL_PATH_AMBIENT_EPD,
EVAL_PATH_ENCODE_ONCE,
Expand Down Expand Up @@ -287,6 +288,12 @@ def test_build_per_timestep_metric_factory_sets_reduce_all_false_for_ssr():
assert getattr(metric, "reduce_all", None) is False


def test_default_eval_metrics_include_spread_and_skill_for_lola_comparison():
    """Spread/skill/SSR metrics ship in the default evaluation metric list."""
    for metric_name in ("spread", "skill", "ssr"):
        assert metric_name in DEFAULT_EVAL_METRICS


def test_should_skip_metric_variogram_only():
assert _should_skip_metric("variogram") is True
assert _should_skip_metric("crps") is False
Expand Down
Loading