diff --git a/ice_station_zebra/callbacks/metric_summary_callback.py b/ice_station_zebra/callbacks/metric_summary_callback.py
index 37338566..869ed8e3 100644
--- a/ice_station_zebra/callbacks/metric_summary_callback.py
+++ b/ice_station_zebra/callbacks/metric_summary_callback.py
@@ -1,10 +1,16 @@
+import logging
 import statistics
+from collections.abc import Mapping
 from typing import Any
 
 from lightning import LightningModule, Trainer
 from lightning.pytorch import Callback
 from torch import Tensor
 
+from ice_station_zebra.types import ModelTestOutput
+
+logger = logging.getLogger(__name__)
+
 
 class MetricSummaryCallback(Callback):
     """A callback to summarise metrics during evaluation."""
@@ -23,14 +29,19 @@ def on_test_batch_end(
         self,
         trainer: Trainer,
         module: LightningModule,
-        outputs: dict[str, Tensor],  # type: ignore[override]
+        outputs: Tensor | Mapping[str, Any] | None,
         batch: Any,
         batch_idx: int,
         dataloader_idx: int = 0,
     ) -> None:
         """Called when the test batch ends."""
+        if not isinstance(outputs, ModelTestOutput):
+            msg = f"Output is of type {type(outputs)}, skipping metric accumulation."
+            logger.warning(msg)
+            return
+
         if "average_loss" in self.metrics:
-            self.metrics["average_loss"].append(outputs["loss"].item())
+            self.metrics["average_loss"].append(outputs.loss.item())
 
     def on_test_epoch_end(self, trainer: Trainer, module: LightningModule) -> None:
         """Called at the end of the test epoch."""
diff --git a/ice_station_zebra/callbacks/plotting_callback.py b/ice_station_zebra/callbacks/plotting_callback.py
index c0ec8b8a..1b151b87 100644
--- a/ice_station_zebra/callbacks/plotting_callback.py
+++ b/ice_station_zebra/callbacks/plotting_callback.py
@@ -1,4 +1,5 @@
 import logging
+from collections.abc import Mapping, Sequence
 from typing import Any
 
 from lightning import LightningModule, Trainer
@@ -7,6 +8,7 @@
 from torch.utils.data import DataLoader
 
 from ice_station_zebra.data_loaders import CombinedDataset
+from ice_station_zebra.types import ModelTestOutput
 from ice_station_zebra.visualisations import plot_sic_comparison
 
 logger = logging.getLogger(__name__)
@@ -34,45 +36,51 @@ def on_test_batch_end(
         self,
         trainer: Trainer,
         module: LightningModule,
-        outputs: dict[str, Tensor],  # type: ignore[override]
+        outputs: Tensor | Mapping[str, Any] | None,
         batch: Any,
         batch_idx: int,
         dataloader_idx: int = 0,
     ) -> None:
         """Called when the test batch ends."""
-        # Run plotting every `frequency` batches
-        if batch_idx % self.frequency == 0:
-            # Get date for this batch
-            batch_size = outputs["target"].shape[0]
-            test_dataloaders: DataLoader | list[DataLoader] | None = (
-                trainer.test_dataloaders
-            )
-            if test_dataloaders is None:
-                logger.debug("No test dataloaders found, skipping plotting.")
-                return
-            dataset: CombinedDataset = (
-                test_dataloaders[dataloader_idx]
-                if isinstance(test_dataloaders, list)
-                else test_dataloaders
-            ).dataset  # type: ignore[assignment]
-            date_ = dataset.date_from_index(batch_size * batch_idx)
+        # Only run plotting every `frequency` batches
+        if batch_idx % self.frequency:
+            return
 
-            # Load the ground truth and prediction
-            np_ground_truth = outputs["target"].cpu().numpy()[0, 0, 0, :, :]
-            np_prediction = outputs["output"].cpu().numpy()[0, 0, 0, :, :]
+        # Check that outputs is a ModelTestOutput
+        if not isinstance(outputs, ModelTestOutput):
+            msg = f"Output is of type {type(outputs)}, skipping plotting."
+            logger.warning(msg)
+            return
 
-            # Create each requested plot
-            images = {
-                name: plot_fn(np_ground_truth, np_prediction, date_)
-                for name, plot_fn in self.plot_fns.items()
-            }
+        # Get date for this batch
+        dl: DataLoader | list[DataLoader] | None = trainer.test_dataloaders
+        if dl is None:
+            logger.warning("No test dataloaders found, skipping plotting.")
+            return
+        dataset = (dl[dataloader_idx] if isinstance(dl, Sequence) else dl).dataset
+        if not isinstance(dataset, CombinedDataset):
+            logger.warning("Dataset is not a CombinedDataset, skipping plotting.")
+            return
+        batch_size = outputs.target.shape[0]
+        date_ = dataset.date_from_index(batch_size * batch_idx)
 
-            # Log images to each logger
-            for lightning_logger in trainer.loggers:
-                for key, image_list in images.items():
-                    if hasattr(lightning_logger, "log_image"):
-                        lightning_logger.log_image(key=key, images=image_list)
-                    else:
-                        logger.debug(
-                            f"Logger {lightning_logger.name} does not support logging images."
-                        )
+        # Load the ground truth and prediction
+        # Both prediction and target are TensorNTCHW
+        np_ground_truth = outputs.target.cpu().numpy()[0, 0, 0, :, :]
+        np_prediction = outputs.prediction.cpu().numpy()[0, 0, 0, :, :]
+
+        # Create each requested plot
+        images = {
+            name: plot_fn(np_ground_truth, np_prediction, date_)
+            for name, plot_fn in self.plot_fns.items()
+        }
+
+        # Log images to each logger
+        for lightning_logger in trainer.loggers:
+            for key, image_list in images.items():
+                if hasattr(lightning_logger, "log_image"):
+                    lightning_logger.log_image(key=key, images=image_list)
+                else:
+                    logger.debug(
+                        f"Logger {lightning_logger.name} does not support logging images."
+                    )
diff --git a/ice_station_zebra/models/zebra_model.py b/ice_station_zebra/models/zebra_model.py
index 554929df..a06d97ca 100644
--- a/ice_station_zebra/models/zebra_model.py
+++ b/ice_station_zebra/models/zebra_model.py
@@ -7,10 +7,12 @@
 from lightning.pytorch.utilities.types import OptimizerLRScheduler
 from omegaconf import DictConfig
 
-from ice_station_zebra.types import DataSpace, TensorNTCHW
+from ice_station_zebra.types import DataSpace, ModelTestOutput, TensorNTCHW
 
 
 class ZebraModel(LightningModule, ABC):
+    """A base class for all models used in the Ice Station Zebra project."""
+
     def __init__(
         self,
         *,
@@ -61,41 +63,52 @@ def configure_optimizers(self) -> OptimizerLRScheduler:
             }
         )
 
-    def loss(self, output: TensorNTCHW, target: TensorNTCHW) -> torch.Tensor:
-        return torch.nn.functional.l1_loss(output, target)
+    def loss(self, prediction: TensorNTCHW, target: TensorNTCHW) -> torch.Tensor:
+        """Calculate the loss given a prediction and target."""
+        return torch.nn.functional.l1_loss(prediction, target)
 
     def test_step(
         self, batch: dict[str, TensorNTCHW], batch_idx: int
-    ) -> dict[str, torch.Tensor]:
-        """Run the test step, in PyTorch eval model (i.e. no gradients)
+    ) -> ModelTestOutput:
+        """Run the test step, in PyTorch eval mode (i.e. no gradients)
 
-        A batch contains one tensor for each input dataset and one for the target
-        These are [NTCHW] tensors with (batch_size, n_history_steps, C, H, W)
-
         - Separate the batch into inputs and target
         - Run inputs through the model
-        Return the output, target and loss
+        - Return the prediction, target and loss
+
+        Args:
+            batch: Dictionary mapping dataset name to its contents. There is one entry
+                for each input dataset and one for the target. Each of these is a
+                TensorNTCHW with (batch_size, n_history_steps, C, H, W).
+
+        Returns:
+            A ModelTestOutput containing the prediction, target and loss for the batch.
         """
         target = batch.pop("target")
-        output = self(batch)
-        loss = self.loss(output, target)
-        return {"output": output, "target": target, "loss": loss}
+        prediction = self(batch)
+        loss = self.loss(prediction, target)
+        return ModelTestOutput(prediction, target, loss)
 
     def training_step(
         self, batch: dict[str, TensorNTCHW], batch_idx: int
     ) -> torch.Tensor:
         """Run the training step
 
-        A batch contains one tensor for each input dataset and one for the target
-        These are [NTCHW] tensors with (batch_size, n_history_steps, C, H, W)
-
         - Separate the batch into inputs and target
         - Run inputs through the model
         - Calculate the loss wrt. the target
+
+        Args:
+            batch: Dictionary mapping dataset name to its contents. There is one entry
+                for each input dataset and one for the target. Each of these is a
+                TensorNTCHW with (batch_size, n_history_steps, C, H, W).
+
+        Returns:
+            A Tensor containing the loss for the batch.
         """
         target = batch.pop("target")
-        output = self(batch)
-        return self.loss(output, target)
+        prediction = self(batch)
+        return self.loss(prediction, target)
 
     def validation_step(
         self, batch: dict[str, TensorNTCHW], batch_idx: int
@@ -107,10 +120,18 @@ def validation_step(
 
         - Separate the batch into inputs and target
         - Run inputs through the model
-        - Calculate the loss wrt. the target
+        - Calculate and log the loss wrt. the target
+
+        Args:
+            batch: Dictionary mapping dataset name to its contents. There is one entry
+                for each input dataset and one for the target. Each of these is a
+                TensorNTCHW with (batch_size, n_history_steps, C, H, W).
+
+        Returns:
+            A Tensor containing the loss for the batch.
         """
         target = batch.pop("target")
-        output = self(batch)
-        loss = self.loss(output, target)
+        prediction = self(batch)
+        loss = self.loss(prediction, target)
         self.log("validation_loss", loss)
         return loss
diff --git a/ice_station_zebra/types.py b/ice_station_zebra/types.py
index cfaa81e4..1125c7db 100644
--- a/ice_station_zebra/types.py
+++ b/ice_station_zebra/types.py
@@ -1,4 +1,4 @@
-from collections.abc import Sequence
+from collections.abc import Iterator, Mapping, Sequence
 from dataclasses import dataclass
 from typing import Any, Self, TypedDict
 
@@ -67,3 +67,29 @@ def to_dict(self) -> DictConfig:
         return DictConfig(
             {"channels": self.channels, "name": self.name, "shape": self.shape}
         )
+
+
+@dataclass
+class ModelTestOutput(Mapping[str, Tensor]):
+    """Output of a model test step."""
+
+    prediction: TensorNTCHW
+    target: TensorNTCHW
+    loss: Tensor
+
+    def __getitem__(self, key: str) -> Tensor:
+        if key == "prediction":
+            return self.prediction
+        if key == "target":
+            return self.target
+        if key == "loss":
+            return self.loss
+        raise KeyError(f"Key {key} not found in ModelTestOutput")
+
+    def __iter__(self) -> Iterator[str]:
+        yield "prediction"
+        yield "target"
+        yield "loss"
+
+    def __len__(self) -> int:
+        return 3