Skip to content

Commit b326e5b

Browse files
IFentonCopilot
and authored
Adding plotting metrics to WandB (#216)
* 🚧 Adding plots of MAE by day * 🐛 Fixing error with plotting callback dates * 🚧 Adding a new way of calculating SIEError * 🚨 linting * 🚧 Calculating the SIEError metric by day * 🚧 Adding line plot to wandb * 🚧 automating table / line plot creation * 🚧 Log sieerror at each test step * 🚧 Calculating MAE / RMSE daily * 🎨 Remove duplicate code * 🐛 Removing code that was plotting the wrong value for the last epoch value * 🚧 Calculating the mean value across all days and plotting * 🚧 Tidying up the metrics to calculate * 🚨 linting * 🚨 linting * ✅ Adding tests for metrics * 🚨 linting * ✅ Checking logic of tests * 🚚 Rename SIE_error_new to SIE_error_abs * 🚨 linting * Correct typo Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Correcting doc string Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * 🥅 Make sure Table / line plot are only done with a WandB logger * ⚡ Only run the metric compute step at the end of each epoch * ✅ Adding a unit test for wandb.Table and wandb.plot.line * 🎨 better way of calculating RMSEDaily * 🎨 Remove unnecessary logging * 🎨 Remove batch size dependence for SIEError * ♻️ Create base metric class for MAE and RMSE * 🚨 linting * ♻️ Refactoring SIEError code to make it more robust * ♻️ Removing average_loss as a metric * 🚚 Move metrics / losses out of models folder * 🎨 Update metric names from Daily to PerForecastDay * 📝 Adding description of how SIEError is calculated * ♻️ Use get_wand_run * 🎨 Refactor use of get_wandb_run * ♻️ Improve check of test_metrics type --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent f72f4cc commit b326e5b

15 files changed

Lines changed: 515 additions & 55 deletions
Lines changed: 30 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,44 @@
11
import logging
2-
import statistics
3-
from collections.abc import Mapping
4-
from typing import Any
52

3+
import wandb
64
from lightning import LightningModule, Trainer
75
from lightning.pytorch import Callback
8-
from torch import Tensor
6+
from torchmetrics import MetricCollection
97

10-
from icenet_mp.types import ModelTestOutput
8+
from icenet_mp.utils import get_wandb_run
119

1210
logger = logging.getLogger(__name__)
1311

1412

1513
class MetricSummaryCallback(Callback):
    """A callback to summarise metrics during evaluation."""

    def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Summarise and log the module's test metrics at the end of testing.

        Args:
            trainer: The trainer running the test loop; its loggers receive
                the per-metric mean values.
            pl_module: The module under test, expected to expose a
                ``test_metrics`` MetricCollection attribute.

        """
        # Use getattr so a module without `test_metrics` degrades to the
        # warning below instead of raising AttributeError.
        test_metrics = getattr(pl_module, "test_metrics", None)
        if not isinstance(test_metrics, MetricCollection):
            logger.warning("Could not load test metrics!")
            return

        for name, metric in test_metrics.items():
            # Compute the metric value (e.g., SIEError) across all batches and log it
            values = metric.compute()

            for logger_ in trainer.loggers:
                # Log the mean value of the metric across all days
                logger_.log_metrics({f"{name} (mean)": values.mean().item()})

            # check if WandB is being used as a logger and metrics are calculated for multiple days
            # if so, log the metric values as a table and plot
            if (
                isinstance(run := get_wandb_run(trainer), wandb.Run)
                and values.numel() > 1
            ):
                # Table rows are (day, value) pairs, days numbered from 1.
                table = wandb.Table(
                    data=list(enumerate(values.tolist(), start=1)),
                    columns=["day", name],
                )
                plot_name = name + " per day"
                run.log(
                    {plot_name: wandb.plot.line(table, "day", name, title=plot_name)}
                )

icenet_mp/callbacks/plotting_callback.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,12 @@ def on_test_batch_end(
8686

8787
# Get sequence dates for static and video plots
8888
batch_size = int(outputs.target.shape[0])
89-
n_timesteps = int(outputs.target.shape[1])
90-
dates = [
91-
datetime_from_npdatetime(dataset.dates[batch_size * batch_idx + tt])
92-
for tt in range(n_timesteps)
93-
]
9489

90+
start_date = dataset.dates[batch_size * batch_idx]
91+
92+
dates = list(
93+
map(datetime_from_npdatetime, dataset.get_forecast_steps(start_date))
94+
)
9595
# Set hemisphere for plotting based on dataset
9696
self.plotter.set_hemisphere(dataset.hemisphere)
9797

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
metric_summary:
22
_target_: icenet_mp.callbacks.MetricSummaryCallback
3-
average_loss: true
File renamed without changes.
File renamed without changes.
File renamed without changes.

icenet_mp/metrics/base_metrics.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
"""Calculating RMSE, MAE by forecast step."""
2+
3+
import torch
4+
from torchmetrics import Metric
5+
6+
7+
class BaseErrorMetricDaily(Metric):
    """Base class for per-timestep error metrics using sufficient statistics.

    Accumulates the running sum of element-wise errors and the element count
    for each forecast time step, so the per-day mean can be computed exactly
    across batches (and reduced across processes via ``dist_reduce_fx="sum"``).
    """

    def __init__(self) -> None:
        """Initialize the metric with empty per-timestep accumulators."""
        super().__init__()
        # Declared for type checkers; the actual buffers are registered below.
        self.sum_errors: torch.Tensor
        self.count: torch.Tensor
        self.add_state(
            "sum_errors",
            default=torch.tensor([], dtype=torch.float32),
            dist_reduce_fx="sum",
        )
        self.add_state(
            "count",
            default=torch.tensor([], dtype=torch.long),
            dist_reduce_fx="sum",
        )

    def _compute_errors(
        self, preds: torch.Tensor, targets: torch.Tensor
    ) -> torch.Tensor:
        """Compute element-wise errors. Override in subclasses."""
        raise NotImplementedError

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
        """Update metrics with a batch of predictions and targets.

        Args:
            preds: Tensor of shape (batch, time, channels, height, width)
            targets: Tensor of shape (batch, time, channels, height, width)

        Raises:
            ValueError: If the time dimension differs from previous updates.

        """
        # Compute errors: (batch, time, channels, height, width)
        errors = self._compute_errors(preds, targets)

        batch_size = errors.shape[0]
        # Product of all trailing (spatial) dims; generalizes beyond 5-D input.
        num_spatial = errors.shape[2:].numel()

        # Collapse to (batch, time, spatial), then sum over batch + spatial.
        # `reshape` (unlike `view`) also handles non-contiguous error tensors.
        # Cast to float32 so the accumulator never adopts a lower-precision
        # input dtype (e.g. float16) on the first update.
        errors_reshaped = errors.reshape(batch_size, errors.shape[1], num_spatial)
        batch_sum_errors = errors_reshaped.sum(dim=(0, 2)).to(torch.float32)

        # Count samples per time step
        batch_count = torch.full(
            (errors.shape[1],),
            batch_size * num_spatial,
            dtype=torch.long,
            device=errors.device,
        )

        # Initialize buffers on first update
        if self.sum_errors.numel() == 0:
            self.sum_errors = batch_sum_errors
            self.count = batch_count
        else:
            if self.sum_errors.shape[0] != batch_sum_errors.shape[0]:
                msg = f"Time dimension mismatch: expected {self.sum_errors.shape[0]}, got {batch_sum_errors.shape[0]}"
                raise ValueError(msg)
            self.sum_errors += batch_sum_errors
            self.count += batch_count

    def _finalize(self, mean_errors: torch.Tensor) -> torch.Tensor:
        """Apply final transformation to mean errors. Override in subclasses."""
        return mean_errors

    def compute(self) -> torch.Tensor:
        """Compute metric per lead time from accumulated sufficient statistics.

        Returns:
            Tensor of shape (T,) with metric value for each time step

        """
        if self.count.numel() == 0:
            # No updates yet: return an empty tensor rather than dividing by zero.
            return torch.tensor([], dtype=torch.float32, device=self.sum_errors.device)

        # Clamp guards the division; counts are always >= 1 after a valid update.
        count = torch.clamp(self.count, min=1)
        mean_errors = self.sum_errors / count.float()
        return self._finalize(mean_errors)
86+
87+
88+
class RMSEPerForecastDay(BaseErrorMetricDaily):
    """Root Mean Squared Error per forecast lead time."""

    def _compute_errors(
        self, preds: torch.Tensor, targets: torch.Tensor
    ) -> torch.Tensor:
        # Squared error per element; the base class averages these per day.
        return (preds - targets).pow(2)

    def _finalize(self, mean_errors: torch.Tensor) -> torch.Tensor:
        # RMSE is the square root of the mean squared error.
        return mean_errors.sqrt()
98+
99+
100+
class MAEPerForecastDay(BaseErrorMetricDaily):
    """Mean Absolute Error per forecast lead time."""

    def _compute_errors(
        self, preds: torch.Tensor, targets: torch.Tensor
    ) -> torch.Tensor:
        # Absolute error per element; the base class averages these per day.
        return (preds - targets).abs()
File renamed without changes.

0 commit comments

Comments
 (0)