Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/autocast/metrics/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,8 +579,8 @@ class SpreadSkillRatio(BTSCMMetric):

name: str = "ssr"

def __init__(self, eps: float = 1e-6):
super().__init__()
def __init__(self, eps: float = 1e-6, **kwargs):
super().__init__(**kwargs)
if eps <= 0:
msg = "eps must be > 0"
raise ValueError(msg)
Expand All @@ -598,8 +598,8 @@ def score(
Compute corrected spread-to-skill ratio.

Reductions (spatial/temporal) are applied to the variance and MSE before
taking the square root and computing the ratio, matching macroscopic
approaches like Lola's.
taking the square root and computing the ratio (i.e., reduce variance/MSE
first, then sqrt, then divide).

Args:
y_pred: (B, T, S, C, M)
Expand Down Expand Up @@ -634,7 +634,7 @@ def score(
skill_sq = skill_sq.mean(dim=1)
spread_var = spread_var.mean(dim=1)

# Compute macroscopic spread, skill, and ratio
# Reduce to spread, skill, and take ratio
skill = torch.sqrt(skill_sq)
spread = torch.sqrt(spread_var)

Expand Down
139 changes: 139 additions & 0 deletions tests/metrics/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,145 @@ def test_spread_skill_ratio_requires_multiple_ensemble_members():
SpreadSkillRatio()(y_pred, y_true)


def _controlled_ssr_batch(
sigma_t: torch.Tensor,
bias: float | torch.Tensor,
B: int,
S1: int,
S2: int,
C: int,
M: int,
seed: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Build an ensemble with known per-lead-time spread and skill.

For each lead time t, ensemble members equal ``bias + sigma_t * z_m'`` where
``z_m'`` are ensemble-centered (i.e. sum to zero across ``M``) standard
normals. This gives:
- pointwise ensemble mean == bias (so skill == |bias| exactly),
- pointwise unbiased ensemble variance with expectation ``sigma_t**2``.

``bias`` may be a scalar (constant skill across time) or a 1D tensor of
length T (per-lead-time skill). Ensemble-centering removes sampling noise
in the skill so per-t SSR values are predictable within a tight tolerance
even with modest B/S/M.
"""
T = sigma_t.shape[0]
torch.manual_seed(seed)
z = torch.randn((B, T, S1, S2, C, M))
z = z - z.mean(dim=-1, keepdim=True)
sigma = sigma_t.view(1, T, 1, 1, 1, 1)
bias_bcast = bias.view(1, T, 1, 1, 1, 1) if isinstance(bias, torch.Tensor) else bias
y_pred = bias_bcast + sigma * z
y_true = torch.zeros((B, T, S1, S2, C))
return y_pred, y_true


def test_spread_skill_ratio_over_time_tracks_growing_skill_when_spread_is_fixed():
    """With constant spread and skill growing over lead time, SSR must fall
    monotonically across T (``spread / skill -> 0``). This mimics a rollout
    whose ensemble never widens while error accumulates -- the classic
    under-dispersion signature the lead-time panel is meant to surface.
    """
    B, T, S1, S2, C, M = 8, 5, 16, 16, 1, 32
    skill_per_lead = torch.tensor([1.0, 2.0, 4.0, 8.0, 16.0])
    spread_per_lead = torch.ones(T)  # spread pinned to 1.0 at every lead time

    y_pred, y_true = _controlled_ssr_batch(
        spread_per_lead, bias=skill_per_lead, B=B, S1=S1, S2=S2, C=C, M=M
    )

    ssr = SpreadSkillRatio().score(y_pred, y_true)  # (B, T, C)
    per_t = ssr.mean(dim=(0, 2))

    correction = ((M + 1) / M) ** 0.5
    expected = correction / skill_per_lead
    assert torch.allclose(per_t, expected, rtol=5e-2, atol=5e-3), (
        per_t,
        expected,
    )
    # Belt-and-braces: each successive lead time must strictly decrease, so a
    # future change that silently reorders time or swaps numerator and
    # denominator is caught even if values stay within tolerance.
    assert torch.all(per_t[1:] < per_t[:-1]), per_t


def test_spread_skill_ratio_over_time_calibrated_stays_near_one():
    """Well-dispersed ensemble: truth and members drawn iid from the same
    per-lead-time Gaussian, truth getting its own independent noise. For a
    perfectly calibrated ensemble the corrected SSR has expectation 1.0 at
    every lead time, regardless of how the per-t variance scales.

    Classical identities behind the expectation:
    - E[spread^2] = sigma_t^2 (unbiased variance estimator)
    - E[skill^2] = sigma_t^2 * (M+1)/M (ensemble mean vs independent truth)
    -> SSR_corrected = sqrt(M/(M+1)) * sqrt((M+1)/M) = 1.
    """
    B, S1, S2, C, M = 16, 16, 16, 1, 16
    sigma_t = torch.tensor([0.5, 1.0, 2.0, 4.0])
    T = sigma_t.shape[0]

    torch.manual_seed(0)
    scale = sigma_t.view(1, T, 1, 1, 1, 1)
    y_pred = scale * torch.randn((B, T, S1, S2, C, M))
    y_true = scale.squeeze(-1) * torch.randn((B, T, S1, S2, C))

    ssr = SpreadSkillRatio().score(y_pred, y_true)  # (B, T, C)
    per_t = ssr.mean(dim=(0, 2))

    # At (B=16, S=256, M=16) the sampling noise is small; 10% slack is ample.
    assert torch.all((per_t - 1.0).abs() < 0.1), per_t


def test_spread_skill_ratio_stateful_returns_per_lead_time_and_is_mean_of_ratios():
    """Pin down the mean-of-ratios aggregation convention across update() calls.

    Two batches with deliberately different per-batch SSRs are streamed
    through ``update()``. The stateful ``compute()`` with
    ``reduce_all=False`` must:
    1. expose a per-lead-time vector of shape (T, C), and
    2. equal the arithmetic mean of the per-batch SSRs (mean of per-sample
       ``spread/skill``) rather than the macroscopic ratio of pooled second
       moments.

    Should anyone silently switch to a macroscopic ratio later, this test
    fails and flags the behaviour change.
    """
    B, S1, S2, C, M = 4, 16, 16, 1, 32
    sigma_t = torch.tensor([1.0, 1.0, 1.0])  # constant spread across lead times
    T = sigma_t.shape[0]
    correction = ((M + 1) / M) ** 0.5

    # First batch: skill=1, spread=1 -> per-t SSR ~= 1 * correction.
    pred_a, true_a = _controlled_ssr_batch(
        sigma_t, bias=1.0, B=B, S1=S1, S2=S2, C=C, M=M, seed=0
    )
    # Second batch: skill=0.1, spread=1 -> per-t SSR ~= 10 * correction.
    pred_b, true_b = _controlled_ssr_batch(
        sigma_t, bias=0.1, B=B, S1=S1, S2=S2, C=C, M=M, seed=1
    )

    metric = SpreadSkillRatio(reduce_all=False)
    for pred, true in ((pred_a, true_a), (pred_b, true_b)):
        metric.update(pred, true)
    value = metric.compute()  # (T, C)

    assert value.shape == (T, C), value.shape

    mean_of_ratios = 0.5 * (1.0 + 10.0) * correction  # ~5.5 * correction
    pooled_ratio = correction * (
        ((sigma_t[0] ** 2 + sigma_t[0] ** 2) / 2).sqrt()
        / ((1.0**2 + 0.1**2) / 2) ** 0.5
    )  # ~= correction / sqrt(0.505) ~= 1.41 * correction

    # Must agree with mean-of-ratios; must NOT agree with the pooled ratio.
    assert torch.allclose(
        value, torch.full_like(value, mean_of_ratios), rtol=5e-2
    ), value
    assert not torch.allclose(
        value, torch.full_like(value, float(pooled_ratio)), rtol=2e-1
    ), value


def test_winkler_score_manual_value():
# Shape: (B=1, T=1, S=2, C=1, M=5)
# Ensemble members: [0, 1, 2, 3, 4], alpha=0.2
Expand Down
Loading