Skip to content
15 changes: 13 additions & 2 deletions ice_station_zebra/callbacks/metric_summary_callback.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import logging
import statistics
from collections.abc import Mapping
from typing import Any

from lightning import LightningModule, Trainer
from lightning.pytorch import Callback
from torch import Tensor

from ice_station_zebra.types import ModelTestOutput

logger = logging.getLogger(__name__)


class MetricSummaryCallback(Callback):
"""A callback to summarise metrics during evaluation."""
Expand All @@ -23,14 +29,19 @@ def on_test_batch_end(
self,
trainer: Trainer,
module: LightningModule,
outputs: dict[str, Tensor], # type: ignore[override]
outputs: Tensor | Mapping[str, Any] | None,
batch: Any,
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
"""Called when the test batch ends."""
if not isinstance(outputs, ModelTestOutput):
msg = f"Output is of type {type(outputs)}, skipping metric accumulation."
logger.warning(msg)
return

if "average_loss" in self.metrics:
self.metrics["average_loss"].append(outputs["loss"].item())
self.metrics["average_loss"].append(outputs.loss.item())

def on_test_epoch_end(self, trainer: Trainer, module: LightningModule) -> None:
"""Called at the end of the test epoch."""
Expand Down
76 changes: 42 additions & 34 deletions ice_station_zebra/callbacks/plotting_callback.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from collections.abc import Mapping, Sequence
from typing import Any

from lightning import LightningModule, Trainer
Expand All @@ -7,6 +8,7 @@
from torch.utils.data import DataLoader

from ice_station_zebra.data_loaders import CombinedDataset
from ice_station_zebra.types import ModelTestOutput
from ice_station_zebra.visualisations import plot_sic_comparison

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -34,45 +36,51 @@ def on_test_batch_end(
self,
trainer: Trainer,
module: LightningModule,
outputs: dict[str, Tensor], # type: ignore[override]
outputs: Tensor | Mapping[str, Any] | None,
batch: Any,
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
"""Called when the test batch ends."""
# Run plotting every `frequency` batches
if batch_idx % self.frequency == 0:
# Get date for this batch
batch_size = outputs["target"].shape[0]
test_dataloaders: DataLoader | list[DataLoader] | None = (
trainer.test_dataloaders
)
if test_dataloaders is None:
logger.debug("No test dataloaders found, skipping plotting.")
return
dataset: CombinedDataset = (
test_dataloaders[dataloader_idx]
if isinstance(test_dataloaders, list)
else test_dataloaders
).dataset # type: ignore[assignment]
date_ = dataset.date_from_index(batch_size * batch_idx)
# Only run plotting every `frequency` batches
if batch_idx % self.frequency:
return

# Load the ground truth and prediction
np_ground_truth = outputs["target"].cpu().numpy()[0, 0, 0, :, :]
np_prediction = outputs["output"].cpu().numpy()[0, 0, 0, :, :]
# Check that outputs is a ModelTestOutput
if not isinstance(outputs, ModelTestOutput):
msg = f"Output is of type {type(outputs)}, skipping plotting."
logger.warning(msg)
return

# Create each requested plot
images = {
name: plot_fn(np_ground_truth, np_prediction, date_)
for name, plot_fn in self.plot_fns.items()
}
# Get date for this batch
dl: DataLoader | list[DataLoader] | None = trainer.test_dataloaders
if dl is None:
logger.warning("No test dataloaders found, skipping plotting.")
return
dataset = (dl[dataloader_idx] if isinstance(dl, Sequence) else dl).dataset
if not isinstance(dataset, CombinedDataset):
logger.warning("Dataset is not a CombinedDataset, skipping plotting.")
return
batch_size = outputs.target.shape[0]
date_ = dataset.date_from_index(batch_size * batch_idx)

# Log images to each logger
for lightning_logger in trainer.loggers:
for key, image_list in images.items():
if hasattr(lightning_logger, "log_image"):
lightning_logger.log_image(key=key, images=image_list)
else:
logger.debug(
f"Logger {lightning_logger.name} does not support logging images."
)
# Load the ground truth and prediction
# Both prediction and target are TensorNTCHW
np_ground_truth = outputs.target.cpu().numpy()[0, 0, 0, :, :]
np_prediction = outputs.prediction.cpu().numpy()[0, 0, 0, :, :]

# Create each requested plot
images = {
name: plot_fn(np_ground_truth, np_prediction, date_)
for name, plot_fn in self.plot_fns.items()
}

# Log images to each logger
for lightning_logger in trainer.loggers:
for key, image_list in images.items():
if hasattr(lightning_logger, "log_image"):
lightning_logger.log_image(key=key, images=image_list)
else:
logger.debug(
f"Logger {lightning_logger.name} does not support logging images."
)
10 changes: 5 additions & 5 deletions ice_station_zebra/models/zebra_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from omegaconf import DictConfig

from ice_station_zebra.types import DataSpace, TensorNTCHW
from ice_station_zebra.types import DataSpace, ModelTestOutput, TensorNTCHW


class ZebraModel(LightningModule, ABC):
Expand Down Expand Up @@ -66,7 +66,7 @@ def loss(self, output: TensorNTCHW, target: TensorNTCHW) -> torch.Tensor:

def test_step(
self, batch: dict[str, TensorNTCHW], batch_idx: int
) -> dict[str, torch.Tensor]:
) -> ModelTestOutput:
"""Run the test step, in PyTorch eval model (i.e. no gradients)

A batch contains one tensor for each input dataset and one for the target
Expand All @@ -77,9 +77,9 @@ def test_step(
- Return the output, target and loss
Comment thread
jemrobinson marked this conversation as resolved.
Outdated
"""
target = batch.pop("target")
output = self(batch)
loss = self.loss(output, target)
return {"output": output, "target": target, "loss": loss}
prediction = self(batch)
loss = self.loss(prediction, target)
Comment thread
jemrobinson marked this conversation as resolved.
return ModelTestOutput(prediction, target, loss)

def training_step(
self, batch: dict[str, TensorNTCHW], batch_idx: int
Expand Down
28 changes: 27 additions & 1 deletion ice_station_zebra/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Sequence
from collections.abc import Iterator, Mapping, Sequence
from dataclasses import dataclass
from typing import Any, Self, TypedDict

Expand Down Expand Up @@ -67,3 +67,29 @@ def to_dict(self) -> DictConfig:
return DictConfig(
{"channels": self.channels, "name": self.name, "shape": self.shape}
)


@dataclass
class ModelTestOutput(Mapping[str, Tensor]):
    """Output of a model test step.

    Bundles the prediction, target and loss tensors produced by a test step,
    while also supporting read-only dict-style access (``outputs["loss"]``)
    through the ``Mapping`` interface so existing dict-consuming callbacks
    keep working.
    """

    prediction: TensorNTCHW
    target: TensorNTCHW
    loss: Tensor

    def __getitem__(self, key: str) -> Tensor:
        """Return the field named ``key``; raise ``KeyError`` for anything else."""
        if key in ("prediction", "target", "loss"):
            return getattr(self, key)
        raise KeyError(f"Key {key} not found in ModelTestOutput")

    def __iter__(self) -> Iterator[str]:
        """Iterate over the mapping keys in field-declaration order."""
        return iter(("prediction", "target", "loss"))

    def __len__(self) -> int:
        """A ModelTestOutput always has exactly three entries."""
        return 3
Loading