diff --git a/README.md b/README.md
index 141b849c..7738710b 100644
--- a/README.md
+++ b/README.md
@@ -50,6 +50,9 @@ You can then run this with, e.g.:
 ```bash
 uv run imp --config-name .yaml
 ```
+
+This will run using the default model setup (rescaling encoder, small UNet, rescaling decoder) that is sufficient for quick tests, but not appropriate for larger training runs.
+
 You can also use this config to override other options in the `base.yaml` file, as shown below:
 
 ```yaml
@@ -74,7 +77,7 @@ uv run imp ++base_path=/local/path/to/my/data
 
 See `config/demo_north.yaml` for an example of this.
 
-Note that `base_persistence.yaml` overrides the specific options in `base.yaml` needed to run the `Persistence` model.
+:warning: Note that `base_persistence.yaml` overrides the specific options in `base.yaml` needed to run the `Persistence` model.
 
 ### HPC-specific configurations
diff --git a/icenet_mp/callbacks/plotting_callback.py b/icenet_mp/callbacks/plotting_callback.py
index d80ee8be..32943687 100644
--- a/icenet_mp/callbacks/plotting_callback.py
+++ b/icenet_mp/callbacks/plotting_callback.py
@@ -8,6 +8,7 @@ from torch import Tensor
 
 from icenet_mp.data_loaders import CombinedDataset
+from icenet_mp.models import BaseModel
 from icenet_mp.types import ModelTestOutput, PlotSpec
 from icenet_mp.utils import datetime_from_npdatetime
 from icenet_mp.visualisations import DEFAULT_SIC_SPEC, Plotter
@@ -56,7 +57,7 @@ def set_metadata(self, config: DictConfig, model_name: str) -> None:
     def on_test_batch_end(
         self,
         trainer: Trainer,
-        pl_module: LightningModule,  # noqa: ARG002
+        pl_module: LightningModule,
         outputs: Tensor | Mapping[str, Any] | None,
         batch: Any,  # noqa: ANN401, ARG002
         batch_idx: int,
@@ -93,7 +94,11 @@ def on_test_batch_end(
             map(datetime_from_npdatetime, dataset.get_forecast_steps(start_date))
         )
         # Set hemisphere for plotting based on dataset
-        self.plotter.set_hemisphere(dataset.hemisphere)
+        if not isinstance(pl_module, BaseModel):
+            msg = f"Lightning module is of type {type(pl_module)}, skipping plotting."
+            logger.warning(msg)
+            return
+        self.plotter.set_hemisphere(pl_module.hemisphere)
         # Get loggers that support image and video logging
         image_loggers = [ll for ll in trainer.loggers if hasattr(ll, "log_image")]
diff --git a/icenet_mp/cli/hydra.py b/icenet_mp/cli/hydra.py
index 364ebf47..dbf7f2d5 100644
--- a/icenet_mp/cli/hydra.py
+++ b/icenet_mp/cli/hydra.py
@@ -32,7 +32,7 @@ def wrapper(
         config_name: Annotated[
             str | None,
             Option(help="Specify the name of a file to load from the config directory"),
-        ] = "base",
+        ] = "sample",
         *args: Param.args,
         **kwargs: Param.kwargs,
     ) -> RetType:
diff --git a/icenet_mp/config/base.yaml b/icenet_mp/config/base.yaml
index b21b2a82..991e1654 100644
--- a/icenet_mp/config/base.yaml
+++ b/icenet_mp/config/base.yaml
@@ -5,7 +5,7 @@ defaults:
   - loggers:
       - wandb
   - model: naive_unet_naive
-  - predict: sic-icenet
+  - predict: sic-icenet-2d
   - train: default
   - _self_
diff --git a/icenet_mp/config/demo_north.yaml b/icenet_mp/config/demo_north.yaml
index 1c8d2edf..b0dc6ff8 100644
--- a/icenet_mp/config/demo_north.yaml
+++ b/icenet_mp/config/demo_north.yaml
@@ -1,5 +1,5 @@
 defaults:
   - base
   - override /data: demo
-  - override /predict: osisaf-north
+  - override /predict: sic-icenet-2d
   - _self_
diff --git a/icenet_mp/config/model/cnn_null_cnn.yaml b/icenet_mp/config/model/cnn_null_cnn.yaml
index 1dcc30f9..a2923f03 100644
--- a/icenet_mp/config/model/cnn_null_cnn.yaml
+++ b/icenet_mp/config/model/cnn_null_cnn.yaml
@@ -4,15 +4,11 @@ name: cnn-null-cnn
 
 encoder:
   _target_: icenet_mp.models.encoders.CNNEncoder
-  kernel_size: 3 # Size of the kernel for convolutional layers
-  latent_space: [128, 128] # Shape of the latent space
-  n_layers: 3 # Number of convolutional layers
+  latent_space: [144, 144] # Shape of the latent space
 
 processor:
   _target_: icenet_mp.models.processors.NullProcessor
 
 decoder:
   _target_: icenet_mp.models.decoders.CNNDecoder
-  kernel_size: 3 # Size of the kernel for convolutional layers
-  n_layers: 3 # Number of convolutional layers
   bounded: false # Whether to bound the output between 0 and 1
diff --git a/icenet_mp/config/model/cnn_unet_cnn.yaml b/icenet_mp/config/model/cnn_unet_cnn.yaml
index 5ec990a8..88db54cc 100644
--- a/icenet_mp/config/model/cnn_unet_cnn.yaml
+++ b/icenet_mp/config/model/cnn_unet_cnn.yaml
@@ -4,17 +4,11 @@ name: cnn-unet-cnn
 
 encoder:
   _target_: icenet_mp.models.encoders.CNNEncoder
-  kernel_size: 3 # Size of the kernel for convolutional layers
-  latent_space: [128, 128] # Shape of the latent space
-  n_layers: 3 # Number of convolutional layers
+  latent_space: [144, 144] # Shape of the latent space
 
 processor:
   _target_: icenet_mp.models.processors.UNetProcessor
-  kernel_size: 3 # Size of the kernel for convolutional layers
-  start_out_channels: 64 # Initial number of channels for the first convolutional layer
 
 decoder:
   _target_: icenet_mp.models.decoders.CNNDecoder
-  kernel_size: 3 # Size of the kernel for convolutional layers
-  n_layers: 3 # Number of convolutional layers
   bounded: false # Whether to bound the output between 0 and 1
diff --git a/icenet_mp/config/model/cnn_vit_cnn.yaml b/icenet_mp/config/model/cnn_vit_cnn.yaml
index 63d4246f..a6902d0f 100644
--- a/icenet_mp/config/model/cnn_vit_cnn.yaml
+++ b/icenet_mp/config/model/cnn_vit_cnn.yaml
@@ -4,21 +4,11 @@ name: cnn-vit-cnn
 
 encoder:
   _target_: icenet_mp.models.encoders.CNNEncoder
-  kernel_size: 3 # Size of the kernel for convolutional layers
-  latent_space: [192, 192] # Shape of the latent space
-  n_layers: 3 # Number of convolutional layers
+  latent_space: [144, 144] # Shape of the latent space
 
 processor:
   _target_: icenet_mp.models.processors.VitProcessor
-  patch_size: 16
-  emb_dim: 128
-  depth: 3
-  heads: 4
-  mlp_dim: 256
-  dropout: 0.3
 
 decoder:
   _target_: icenet_mp.models.decoders.CNNDecoder
-  kernel_size: 3 # Size of the kernel for convolutional layers
-  n_layers: 3 # Number of convolutional layers
   bounded: false # Whether to bound the output between 0 and 1
diff --git a/icenet_mp/config/model/ddpm.yaml b/icenet_mp/config/model/ddpm.yaml
index 3a3bb180..d551f64e 100644
--- a/icenet_mp/config/model/ddpm.yaml
+++ b/icenet_mp/config/model/ddpm.yaml
@@ -1,13 +1,4 @@
 _target_: icenet_mp.models.ddpm.DDPM
 
+# Run DDPM model with default settings
 name: ddpm
-
-# DDPM parameters
-timesteps: 1000
-learning_rate: 5e-4
-start_out_channels: 32
-kernel_size: 3
-activation: "SiLU"
-normalization: "groupnorm"
-time_embed_dim : 256
-dropout_rate: 0.1
diff --git a/icenet_mp/config/model/naive_null_naive.yaml b/icenet_mp/config/model/naive_null_naive.yaml
index 55effefa..20b7dd19 100644
--- a/icenet_mp/config/model/naive_null_naive.yaml
+++ b/icenet_mp/config/model/naive_null_naive.yaml
@@ -4,7 +4,7 @@ name: naive-null-naive
 
 encoder:
   _target_: icenet_mp.models.encoders.NaiveLinearEncoder
-  latent_space: [128, 128] # Shape of the latent space
+  latent_space: [432, 432] # Shape of the latent space
 
 processor:
   _target_: icenet_mp.models.processors.NullProcessor
diff --git a/icenet_mp/config/model/naive_unet_naive.yaml b/icenet_mp/config/model/naive_unet_naive.yaml
index 6bf7bd71..e87a289a 100644
--- a/icenet_mp/config/model/naive_unet_naive.yaml
+++ b/icenet_mp/config/model/naive_unet_naive.yaml
@@ -4,12 +4,11 @@ name: naive-unet-naive
 
 encoder:
   _target_: icenet_mp.models.encoders.NaiveLinearEncoder
-  latent_space: [128, 128] # Shape of the latent space
+  latent_space: [432, 432] # Shape of the latent space
 
 processor:
   _target_: icenet_mp.models.processors.UNetProcessor
-  kernel_size: 3 # Size of the kernel for convolutional layers
-  start_out_channels: 64 # Initial number of channels for the first convolutional layer
+  start_out_channels: 100 # Reduce number of channels to support 21 day forecasts
 
 decoder:
   _target_: icenet_mp.models.decoders.NaiveLinearDecoder
diff --git a/icenet_mp/config/model/naive_vit_naive.yaml b/icenet_mp/config/model/naive_vit_naive.yaml
index 7c5caf89..4f3b0ac8 100644
--- a/icenet_mp/config/model/naive_vit_naive.yaml
+++ b/icenet_mp/config/model/naive_vit_naive.yaml
@@ -4,16 +4,10 @@ name: naive-vit-naive
 
 encoder:
   _target_: icenet_mp.models.encoders.NaiveLinearEncoder
-  latent_space: [192, 192] # Shape of the latent space
+  latent_space: [432, 432] # Shape of the latent space
 
 processor:
   _target_: icenet_mp.models.processors.VitProcessor
-  patch_size: 16
-  emb_dim: 128
-  depth: 3
-  heads: 4
-  mlp_dim: 256
-  dropout: 0.3
 
 decoder:
   _target_: icenet_mp.models.decoders.NaiveLinearDecoder
diff --git a/icenet_mp/config/model/piecewise_null_piecewise.yaml b/icenet_mp/config/model/piecewise_null_piecewise.yaml
new file mode 100644
index 00000000..424f96a1
--- /dev/null
+++ b/icenet_mp/config/model/piecewise_null_piecewise.yaml
@@ -0,0 +1,13 @@
+_target_: icenet_mp.models.EncodeProcessDecode
+
+name: piecewise-null-piecewise
+
+encoder:
+  _target_: icenet_mp.models.encoders.PiecewiseEncoder
+  latent_space: [192, 192] # Shape of the latent space
+
+processor:
+  _target_: icenet_mp.models.processors.NullProcessor
+
+decoder:
+  _target_: icenet_mp.models.decoders.PiecewiseDecoder
diff --git a/icenet_mp/config/model/piecewise_unet_piecewise.yaml b/icenet_mp/config/model/piecewise_unet_piecewise.yaml
index 6aaa77aa..4ef2e5f7 100644
--- a/icenet_mp/config/model/piecewise_unet_piecewise.yaml
+++ b/icenet_mp/config/model/piecewise_unet_piecewise.yaml
@@ -4,15 +4,10 @@ name: piecewise-unet-piecewise
 
 encoder:
   _target_: icenet_mp.models.encoders.PiecewiseEncoder
-  latent_space: [128, 128] # Shape of the latent space
-  n_conv_blocks: 3 # Number of convolutional blocks to add after encoding
+  latent_space: [192, 192] # Shape of the latent space
 
 processor:
   _target_: icenet_mp.models.processors.UNetProcessor
-  kernel_size: 3 # Size of the kernel for convolutional layers
-  start_out_channels: 64 # Initial number of channels for the first convolutional layer
 
 decoder:
   _target_: icenet_mp.models.decoders.PiecewiseDecoder
-  restrict_range: clamp # Method for restricting output range (e.g., clamp, sigmoid, tanh)
-  n_conv_blocks: 3 # Number of convolutional blocks to add before decoding
diff --git a/icenet_mp/config/model/piecewise_vit_piecewise.yaml b/icenet_mp/config/model/piecewise_vit_piecewise.yaml
new file mode 100644
index 00000000..3382f030
--- /dev/null
+++ b/icenet_mp/config/model/piecewise_vit_piecewise.yaml
@@ -0,0 +1,13 @@
+_target_: icenet_mp.models.EncodeProcessDecode
+
+name: piecewise-vit-piecewise
+
+encoder:
+  _target_: icenet_mp.models.encoders.PiecewiseEncoder
+  latent_space: [192, 192] # Shape of the latent space
+
+processor:
+  _target_: icenet_mp.models.processors.VitProcessor
+
+decoder:
+  _target_: icenet_mp.models.decoders.PiecewiseDecoder
diff --git a/icenet_mp/config/model/quick_test.yaml b/icenet_mp/config/model/quick_test.yaml
new file mode 100644
index 00000000..8e94d902
--- /dev/null
+++ b/icenet_mp/config/model/quick_test.yaml
@@ -0,0 +1,14 @@
+_target_: icenet_mp.models.EncodeProcessDecode
+
+name: sample
+
+encoder:
+  _target_: icenet_mp.models.encoders.NaiveLinearEncoder
+  latent_space: [128, 128]
+
+processor:
+  _target_: icenet_mp.models.processors.UNetProcessor
+  start_out_channels: 64
+
+decoder:
+  _target_: icenet_mp.models.decoders.NaiveLinearDecoder
diff --git a/icenet_mp/config/predict/sic-icenet-14d.yaml b/icenet_mp/config/predict/sic-icenet-14d.yaml
new file mode 100644
index 00000000..0b69f573
--- /dev/null
+++ b/icenet_mp/config/predict/sic-icenet-14d.yaml
@@ -0,0 +1,9 @@
+# Name of the dataset group containing our prediction target
+target:
+  group_name: sic-icenet
+
+# Number of future steps to predict
+n_forecast_steps: 14
+
+# Number of history steps to use when predicting
+n_history_steps: 3
diff --git a/icenet_mp/config/predict/sic-icenet-21d.yaml b/icenet_mp/config/predict/sic-icenet-21d.yaml
new file mode 100644
index 00000000..50843146
--- /dev/null
+++ b/icenet_mp/config/predict/sic-icenet-21d.yaml
@@ -0,0 +1,9 @@
+# Name of the dataset group containing our prediction target
+target:
+  group_name: sic-icenet
+
+# Number of future steps to predict
+n_forecast_steps: 21
+
+# Number of history steps to use when predicting
+n_history_steps: 3
diff --git a/icenet_mp/config/predict/sic-icenet.yaml b/icenet_mp/config/predict/sic-icenet-2d.yaml
similarity index 100%
rename from icenet_mp/config/predict/sic-icenet.yaml
rename to icenet_mp/config/predict/sic-icenet-2d.yaml
diff --git a/icenet_mp/config/predict/sic-ssmis-14d.yaml b/icenet_mp/config/predict/sic-ssmis-14d.yaml
new file mode 100644
index 00000000..3bc9bf0d
--- /dev/null
+++ b/icenet_mp/config/predict/sic-ssmis-14d.yaml
@@ -0,0 +1,11 @@
+# Name of the dataset group containing our prediction target
+target:
+  group_name: sic-ssmis
+  variables:
+    - ice_conc
+
+# Number of future steps to predict
+n_forecast_steps: 14
+
+# Number of history steps to use when predicting
+n_history_steps: 3
diff --git a/icenet_mp/config/predict/sic-ssmis-21d.yaml b/icenet_mp/config/predict/sic-ssmis-21d.yaml
new file mode 100644
index 00000000..2338611b
--- /dev/null
+++ b/icenet_mp/config/predict/sic-ssmis-21d.yaml
@@ -0,0 +1,11 @@
+# Name of the dataset group containing our prediction target
+target:
+  group_name: sic-ssmis
+  variables:
+    - ice_conc
+
+# Number of future steps to predict
+n_forecast_steps: 21
+
+# Number of history steps to use when predicting
+n_history_steps: 3
diff --git a/icenet_mp/config/predict/sic-ssmis.yaml b/icenet_mp/config/predict/sic-ssmis-2d.yaml
similarity index 100%
rename from icenet_mp/config/predict/sic-ssmis.yaml
rename to icenet_mp/config/predict/sic-ssmis-2d.yaml
diff --git a/icenet_mp/config/sample.yaml b/icenet_mp/config/sample.yaml
new file mode 100644
index 00000000..1749e14e
--- /dev/null
+++ b/icenet_mp/config/sample.yaml
@@ -0,0 +1,4 @@
+defaults:
+  - base
+  - override /model: quick_test
+  - _self_
diff --git a/icenet_mp/data_loaders/combined_dataset.py b/icenet_mp/data_loaders/combined_dataset.py
index 0a8ee421..8f8f6842 100644
--- a/icenet_mp/data_loaders/combined_dataset.py
+++ b/icenet_mp/data_loaders/combined_dataset.py
@@ -1,5 +1,4 @@
 from collections.abc import Sequence
-from typing import Literal
 
 import numpy as np
 from torch.utils.data import Dataset
@@ -87,17 +86,6 @@ def start_date(self) -> np.datetime64:
         """Return the start date of the dataset."""
         return self.dates[0]
 
-    @property
-    def hemisphere(self) -> Literal["north", "south"]:
-        """Return the hemisphere of the dataset."""
-        hemisphere: set[Literal["north", "south"]] = {
-            ds.hemisphere for ds in self.inputs
-        }
-        if len(hemisphere) != 1:
-            msg = f"Found {len(hemisphere)} different hemisphere indicators across {len(self.inputs)} datasets."
-            raise ValueError(msg)
-        return hemisphere.pop()
-
     def __len__(self) -> int:
         """Return the total length of the dataset."""
         return len(self.dates)
diff --git a/icenet_mp/data_loaders/common_data_module.py b/icenet_mp/data_loaders/common_data_module.py
index 3a73eac8..49c0a427 100644
--- a/icenet_mp/data_loaders/common_data_module.py
+++ b/icenet_mp/data_loaders/common_data_module.py
@@ -7,7 +7,7 @@ from omegaconf import DictConfig
 from torch.utils.data import DataLoader
 
-from icenet_mp.types import ArrayTCHW, DataloaderArgs, DataSpace
+from icenet_mp.types import ArrayTCHW, DataloaderArgs, DataSpace, Hemisphere
 
 from .combined_dataset import CombinedDataset
 from .single_dataset import SingleDataset
@@ -35,9 +35,11 @@ def __init__(self, config: DictConfig) -> None:
                     self.base_path / "data" / "anemoi" / f"{dataset['name']}.zarr"
                 ).resolve()
             )
-        logger.info("Found %d dataset_groups.", len(self.dataset_groups))
-        for dataset_group in self.dataset_groups:
-            logger.debug("... %s.", dataset_group)
%s.", dataset_group) + logger.info("Found %d dataset groups.", len(self.dataset_groups)) + for idx, (name, paths) in enumerate(self.dataset_groups.items(), start=1): + logger.info("%d) %s:", idx, name) + for path in paths: + logger.info("%s - %s", " " * (len(str(idx)) + 1), path) # Check prediction target self.target_group_name = config["predict"]["target"]["group_name"] @@ -82,6 +84,18 @@ def __init__(self, config: DictConfig) -> None: worker_init_fn=None, ) + @property + def hemisphere(self) -> Hemisphere: + """Return the hemisphere of the dataset.""" + hemisphere: set[Hemisphere] = { + SingleDataset(name, paths).hemisphere + for name, paths in self.dataset_groups.items() + } + if len(hemisphere) != 1: + msg = f"Found {len(hemisphere)} different hemisphere indicators across {len(self.dataset_groups)} dataset groups." + raise ValueError(msg) + return hemisphere.pop() + @cached_property def input_spaces(self) -> list[DataSpace]: """Return the data space for each input.""" diff --git a/icenet_mp/data_loaders/single_dataset.py b/icenet_mp/data_loaders/single_dataset.py index 61ef4f57..48861f2e 100644 --- a/icenet_mp/data_loaders/single_dataset.py +++ b/icenet_mp/data_loaders/single_dataset.py @@ -1,14 +1,13 @@ from collections.abc import Sequence from functools import cached_property from pathlib import Path -from typing import Literal import numpy as np from anemoi.datasets.data import open_dataset from anemoi.datasets.data.dataset import Dataset as AnemoiDataset from torch.utils.data import Dataset -from icenet_mp.types import ArrayCHW, ArrayTCHW, DataSpace +from icenet_mp.types import ArrayCHW, ArrayTCHW, DataSpace, Hemisphere from icenet_mp.utils import normalise_date @@ -31,7 +30,7 @@ def __init__( self._date_ranges = sorted( date_ranges, key=lambda dr: "" if dr["start"] is None else dr["start"] ) - self.hemisphere: Literal["north", "south"] = ( + self.hemisphere: Hemisphere = ( "north" if any("north" in str(input_file).lower() for input_file in input_files) else "south" diff --git a/icenet_mp/model_service.py b/icenet_mp/model_service.py index aa6c78cf..1f4ec885 100644 --- a/icenet_mp/model_service.py +++ b/icenet_mp/model_service.py @@ -44,6 +44,7 @@ def from_config(cls, config: DictConfig) -> "ModelService": builder.model_ = hydra.utils.instantiate( dict( { + "hemisphere": builder.data_module.hemisphere, "input_spaces": [ s.to_dict() for s in builder.data_module.input_spaces ], diff --git a/icenet_mp/models/base_model.py b/icenet_mp/models/base_model.py index a9ba98c2..a1e3814b 100644 --- a/icenet_mp/models/base_model.py +++ b/icenet_mp/models/base_model.py @@ -16,7 +16,7 @@ from icenet_mp.metrics.base_metrics import MAEPerForecastDay, RMSEPerForecastDay from icenet_mp.metrics.sie_error_abs import SeaIceExtentErrorPerForecastDay -from icenet_mp.types import DataSpace, ModelTestOutput, TensorNTCHW +from icenet_mp.types import DataSpace, Hemisphere, ModelTestOutput, TensorNTCHW class BaseModel(LightningModule, ABC): @@ -25,12 +25,13 @@ class BaseModel(LightningModule, ABC): def __init__( # noqa: PLR0913 self, *, - name: str, + hemisphere: Hemisphere, input_spaces: list[DictConfig], n_forecast_steps: int, n_history_steps: int, - output_space: DictConfig, + name: str, optimizer: DictConfig, + output_space: DictConfig, scheduler: DictConfig, **_kwargs: Any, ) -> None: @@ -43,8 +44,9 @@ def __init__( # noqa: PLR0913 """ super().__init__() - # Save model name + # Save model name and hemisphere self.name = name + self.hemisphere = hemisphere # Save history and forecast steps if 
@@ -182,10 +184,10 @@ def training_step(
         self.log(
             "train_loss",
             loss,
-            sync_dist=True,
             on_step=False,
             on_epoch=True,
             prog_bar=True,
+            sync_dist=True,
         )
         return loss
@@ -218,9 +220,9 @@ def validation_step(
         self.log(
             "validation_loss",
             loss,
-            sync_dist=True,
             on_step=False,
             on_epoch=True,
             prog_bar=True,
+            sync_dist=True,
         )
         return loss
diff --git a/icenet_mp/models/ddpm.py b/icenet_mp/models/ddpm.py
index 8efaf446..babe510a 100644
--- a/icenet_mp/models/ddpm.py
+++ b/icenet_mp/models/ddpm.py
@@ -1,27 +1,15 @@
-import os
 from typing import Any, NoReturn
 
 import torch
 import torch.nn.functional as F  # noqa: N812
 from torchmetrics import Metric, MetricCollection
 
-from icenet_mp.losses import WeightedMSELoss
 from icenet_mp.metrics import IceNetAccuracy, SIEError
 from icenet_mp.models.diffusion import GaussianDiffusion, UNetDiffusion
-from icenet_mp.types import ModelTestOutput, TensorNTCHW
+from icenet_mp.types import ModelTestOutput, TensorNCHW, TensorNTCHW
 
 from .base_model import BaseModel
 
-# Unset SLURM_NTASKS if it's causing issues
-if "SLURM_NTASKS" in os.environ:
-    del os.environ["SLURM_NTASKS"]
-
-# Optionally, set SLURM_NTASKS_PER_NODE if needed
-os.environ["SLURM_NTASKS_PER_NODE"] = "1"
-
-# Force all new tensors to be float32 by default
-torch.set_default_dtype(torch.float32)
-
 
 class SimpleEncoder2D(torch.nn.Module):
     def __init__(self, in_channels: int, out_channels: int) -> None:
@@ -39,14 +27,14 @@ def __init__(self, in_channels: int, out_channels: int) -> None:
             torch.nn.SiLU(),
         )
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: TensorNCHW) -> TensorNCHW:
         """Forward pass through the encoder block.
 
         Args:
-            x (torch.Tensor): Input tensor of shape (B, C, H, W).
+            x (TensorNCHW): Input tensor of shape (B, C, H, W).
 
         Returns:
-            torch.Tensor: Output tensor after applying the block.
+            TensorNCHW: Output tensor after applying the block.
 
         """
         return self.net(x)
@@ -176,40 +164,20 @@ def __init__(  # noqa: PLR0913
             metrics[f"val_sieerror_{i}"] = SIEError(leadtimes_to_evaluate=[i])
         self.metrics = MetricCollection(metrics)
 
-        test_metrics: dict[str, Metric | MetricCollection] = {
-            "test_accuracy": IceNetAccuracy(
-                leadtimes_to_evaluate=list(range(self.n_forecast_steps))
-            ),
-            "test_sieerror": SIEError(
-                leadtimes_to_evaluate=list(range(self.n_forecast_steps))
-            ),
-        }
-        for i in range(self.n_forecast_steps):
-            test_metrics[f"test_accuracy_{i}"] = IceNetAccuracy(
-                leadtimes_to_evaluate=[i]
-            )
-            test_metrics[f"test_sieerror_{i}"] = SIEError(leadtimes_to_evaluate=[i])
-        self.test_metrics = MetricCollection(test_metrics)
-
         self.save_hyperparameters()
 
     def forward(self, *args: Any, **kwargs: Any) -> NoReturn:
         msg = "This model uses `training_step`, `validation_step`, and `test_step` instead of `forward()`"
         raise NotImplementedError(msg)
 
-    def sample(
-        self,
-        x: torch.Tensor,
-        sample_weight: torch.Tensor | None,  # noqa: ARG002
-    ) -> torch.Tensor:
+    def sample(self, x: TensorNCHW) -> TensorNCHW:
         """Perform reverse diffusion sampling starting from noise.
 
         Args:
-            x (torch.Tensor): Conditioning input [B, C, H, W].
-            sample_weight (torch.Tensor or None): Optional weights.
+            x (TensorNCHW): Conditioning input [B, C, H, W].
 
         Returns:
-            torch.Tensor: Denoised output of shape [B, C, H, W].
+            TensorNCHW: Denoised output of shape [B, C, H, W].
""" shape = ( @@ -227,25 +195,15 @@ def sample( t_batch = torch.full_like( x[:, 0, 0, 0], t, dtype=torch.long, device=self.device ) - pred_v: torch.Tensor = self.model(y, t_batch, x) + pred_v: TensorNCHW = self.model(y, t_batch, x) pred_v = ( pred_v.squeeze(3) if pred_v.dim() > dim_threshold else pred_v.squeeze() ) y = self.diffusion.p_sample(y, t_batch, pred_v) - return y - - def loss( - self, - prediction: TensorNTCHW, - target: TensorNTCHW, - sample_weight: TensorNTCHW | None = None, - ) -> torch.Tensor: - if sample_weight is None: - sample_weight = torch.ones_like(prediction) - return WeightedMSELoss(reduction="none")(prediction, target, sample_weight) + return torch.clamp(y, 0, 1) - def prepare_inputs(self, batch: dict[str, TensorNTCHW]) -> torch.Tensor: + def prepare_inputs(self, batch: dict[str, TensorNTCHW]) -> TensorNCHW: """Encode OSISAF and ERA5 separately, then concatenate. ERA5 -> Norm -> Project -> Resize -> Flatten Time -> Encode @@ -343,12 +301,12 @@ def training_step( target_v = self.diffusion.calculate_v(y, noise, t) # Compute loss - loss = F.mse_loss(pred_v, target_v) + loss = self.loss(pred_v, target_v) self.log( "train_loss", loss, - on_step=True, + on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, @@ -385,14 +343,12 @@ def validation_step( sample_weight = batch.get("sample_weight", torch.ones_like(y)) # Generate samples - outputs = self.sample(x, sample_weight) - - y_hat = torch.clamp(outputs, 0, 1) + y_hat = self.sample(x) # Calculate loss - loss = self.loss(y_hat, y, sample_weight) + loss = self.loss(y_hat, y) self.log( - "val_loss", + "validation_loss", loss, on_step=False, on_epoch=True, @@ -432,17 +388,10 @@ def test_step( """ x = self.prepare_inputs(batch) # [B, T, C_combined, H, W] - y = batch["target"].squeeze(2) - sample_weight = batch.get("sample_weight", torch.ones_like(y)) - - outputs = self.sample(x, sample_weight) - - y_hat = torch.clamp(outputs, 0, 1).unsqueeze(2) - - y = y.unsqueeze(2) - sample_weight = sample_weight.unsqueeze(2) + y = batch["target"] + y_hat = self.sample(x).unsqueeze(2) # note that this assumes C=1 - loss = self.loss(y_hat, y, sample_weight) + loss = self.loss(y_hat, y) self.log( "test_loss", loss, @@ -452,6 +401,7 @@ def test_step( sync_dist=True, ) - self.test_metrics.update(y_hat, y, sample_weight) + # Use BaseModel test metrics + self.test_metrics.update(y_hat, y) return ModelTestOutput(prediction=y_hat, target=y, loss=loss) diff --git a/icenet_mp/models/decoders/cnn_decoder.py b/icenet_mp/models/decoders/cnn_decoder.py index 30c45bd4..486af645 100644 --- a/icenet_mp/models/decoders/cnn_decoder.py +++ b/icenet_mp/models/decoders/cnn_decoder.py @@ -32,7 +32,7 @@ def __init__( *, activation: str = "ReLU", kernel_size: int = 3, - n_layers: int = 2, + n_layers: int = 3, bounded: bool = False, **kwargs: Any, ) -> None: diff --git a/icenet_mp/models/decoders/piecewise_decoder.py b/icenet_mp/models/decoders/piecewise_decoder.py index 318cc04d..6bd18c87 100644 --- a/icenet_mp/models/decoders/piecewise_decoder.py +++ b/icenet_mp/models/decoders/piecewise_decoder.py @@ -33,8 +33,8 @@ def __init__( *, conv_activation: str = "SiLU", conv_kernel_size: int = 3, - n_conv_blocks: int = 0, - restrict_range: str = "none", + n_conv_blocks: int = 3, + restrict_range: str = "clamp", **kwargs: Any, ) -> None: """Initialise a PiecewiseDecoder.""" diff --git a/icenet_mp/models/encoders/cnn_encoder.py b/icenet_mp/models/encoders/cnn_encoder.py index 6b13c7fc..634658c3 100644 --- a/icenet_mp/models/encoders/cnn_encoder.py +++ 
@@ -29,7 +29,7 @@ def __init__(
         *,
         activation: str = "ReLU",
         kernel_size: int = 3,
-        n_layers: int = 2,
+        n_layers: int = 3,
         **kwargs: Any,
     ) -> None:
         """Initialise a CNNEncoder."""
diff --git a/icenet_mp/models/encoders/piecewise_encoder.py b/icenet_mp/models/encoders/piecewise_encoder.py
index 02ce816a..ab9bd9d7 100644
--- a/icenet_mp/models/encoders/piecewise_encoder.py
+++ b/icenet_mp/models/encoders/piecewise_encoder.py
@@ -25,7 +25,7 @@ def __init__(
         self,
         conv_activation: str = "SiLU",
         conv_kernel_size: int = 3,
-        n_conv_blocks: int = 0,
+        n_conv_blocks: int = 2,
         **kwargs: Any,
     ) -> None:
         """Initialise a PiecewiseEncoder."""
diff --git a/icenet_mp/models/processors/unet.py b/icenet_mp/models/processors/unet.py
index 7833edec..3037ff15 100644
--- a/icenet_mp/models/processors/unet.py
+++ b/icenet_mp/models/processors/unet.py
@@ -25,8 +25,8 @@ class UNetProcessor(BaseProcessor):
     def __init__(
         self,
         *,
-        kernel_size: int,
-        start_out_channels: int,
+        kernel_size: int = 3,
+        start_out_channels: int = 128,
         **kwargs: Any,
     ) -> None:
         """Initialise a UNetProcessor.
diff --git a/icenet_mp/models/processors/vit.py b/icenet_mp/models/processors/vit.py
index 5eecb629..796b0203 100644
--- a/icenet_mp/models/processors/vit.py
+++ b/icenet_mp/models/processors/vit.py
@@ -22,12 +22,12 @@ class VitProcessor(BaseProcessor):
     def __init__(
         self,
         *,
-        depth: int,
-        dropout: float,
-        emb_dim: int,
-        heads: int,
-        mlp_dim: int,
-        patch_size: int,
+        depth: int = 3,
+        dropout: float = 0.3,
+        emb_dim: int = 128,
+        heads: int = 4,
+        mlp_dim: int = 256,
+        patch_size: int = 16,
         **kwargs: Any,
     ) -> None:
         """Initialize Vision Transformer model for sea ice forecasting."""
diff --git a/icenet_mp/types/__init__.py b/icenet_mp/types/__init__.py
index 2036b94c..a2a87810 100644
--- a/icenet_mp/types/__init__.py
+++ b/icenet_mp/types/__init__.py
@@ -18,6 +18,7 @@
     ArrayTHW,
     DiffMode,
     DiffStrategy,
+    Hemisphere,
     TensorNCHW,
     TensorNTCHW,
 )
@@ -37,6 +38,7 @@
     "DiffColourmapSpec",
     "DiffMode",
     "DiffStrategy",
+    "Hemisphere",
     "Metadata",
     "ModelTestOutput",
     "PlotSpec",
diff --git a/icenet_mp/types/typedefs.py b/icenet_mp/types/typedefs.py
index 2f88acae..1f99aa39 100644
--- a/icenet_mp/types/typedefs.py
+++ b/icenet_mp/types/typedefs.py
@@ -23,3 +23,5 @@
 # - "two-pass": scan once to figure the scale, compute per-frame (balanced)
 # - "per-frame": compute per-frame (low RAM, more CPU)
 DiffStrategy = Literal["precompute", "two-pass", "per-frame"]
+
+Hemisphere = Literal["north", "south"]
diff --git a/tests/conftest.py b/tests/conftest.py
index 56547256..e80fe55b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -80,6 +80,7 @@ def cfg_model_service() -> DictConfig:
                 },
             },
             "evaluate": {"callbacks": {}},
+            "hemisphere": "north",
             "loggers": {},
             "model": {
                 "_target_": "MockModel",
diff --git a/tests/models/test_base_model.py b/tests/models/test_base_model.py
index 07b0899b..c1b27e5c 100644
--- a/tests/models/test_base_model.py
+++ b/tests/models/test_base_model.py
@@ -11,7 +11,7 @@ class FakeDataModel(BaseModel):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         """Initialise a fake data model for testing purposes."""
-        super().__init__(*args, **kwargs)
+        super().__init__(*args, hemisphere="north", **kwargs)
         self.t = kwargs["n_forecast_steps"]
         self.c = kwargs["output_space"]["channels"]
         self.h = kwargs["output_space"]["shape"][0]
diff --git a/tests/models/test_encode_process_decode.py b/tests/models/test_encode_process_decode.py
index 8d3ced1e..a4c15b49 100644
--- a/tests/models/test_encode_process_decode.py
+++ b/tests/models/test_encode_process_decode.py
@@ -23,6 +23,7 @@ def test_init(
             encoder=cfg_encoder,
             processor=cfg_processor,
             decoder=cfg_decoder,
+            hemisphere="north",
             input_spaces=[cfg_input_space],
             n_forecast_steps=test_n_forecast_steps,
             n_history_steps=test_n_history_steps,
@@ -58,6 +59,7 @@ def test_forward(
             encoder=cfg_encoder,
             processor=cfg_processor,
             decoder=cfg_decoder,
+            hemisphere="north",
             input_spaces=[cfg_input_space],
             n_forecast_steps=test_n_forecast_steps,
             n_history_steps=test_n_history_steps,
diff --git a/tests/models/test_persistence.py b/tests/models/test_persistence.py
index 21843c39..f830a68a 100644
--- a/tests/models/test_persistence.py
+++ b/tests/models/test_persistence.py
@@ -30,6 +30,7 @@ def test_forward_shape(
         }
         model = Persistence(
             name="persistence",
+            hemisphere="north",
             input_spaces=[input_space],
             n_forecast_steps=test_n_forecast_steps,
             n_history_steps=test_n_history_steps,
@@ -59,6 +60,7 @@ def test_forward_shape(
     def test_optimizer(self) -> None:
         model = Persistence(
             name="persistence",
+            hemisphere="north",
             input_spaces=[
                 {
                     "channels": 1,
diff --git a/tests/models/test_piecewise_encode_decode.py b/tests/models/test_piecewise_encode_decode.py
index 1bdda06c..61504250 100644
--- a/tests/models/test_piecewise_encode_decode.py
+++ b/tests/models/test_piecewise_encode_decode.py
@@ -65,6 +65,7 @@ def test_forward(
             data_space_out=input_space,
             n_conv_blocks=n_conv_blocks,
             n_forecast_steps=n_forecast_steps,
+            restrict_range="none",
         )
         output_ntchw = decoder.rollout(latent_ntchw)
         assert torch.equal(input_ntchw, output_ntchw)
diff --git a/tests/test_model_service.py b/tests/test_model_service.py
index d5ccbf8c..7d7e27c8 100644
--- a/tests/test_model_service.py
+++ b/tests/test_model_service.py
@@ -12,6 +12,7 @@ class MockCommonDataModule:
     def __init__(self, config: DictConfig) -> None:
         """Mock CommonDataModule."""
         self.config = config
+        self.hemisphere = "north"
         self.input_spaces = [DataSpace(5, "input", (20, 20))]
         self.n_forecast_steps = 2
         self.n_history_steps = 3