Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
d9d9517
:wrench: Move standard ViT parameters into model defaults
jemrobinson Mar 2, 2026
8d7d701
:wrench: Set common parameters for UNet
jemrobinson Mar 2, 2026
26094b8
:wrench: Move standard UNet parameters into model defaults
jemrobinson Mar 2, 2026
7c14fff
:wrench: Standardise CNN encoder latent space
jemrobinson Mar 2, 2026
432bd0d
:wrench: Standardise naive encoder latent space
jemrobinson Mar 2, 2026
291f1eb
:wrench: Standardise piecewise latent space
jemrobinson Mar 2, 2026
3a8f120
:art: Set hemisphere as W&B run group for easier filtering
jemrobinson Mar 2, 2026
764623a
:coffin: Use CommonDataModule hemisphere instead of CombinedDataset
jemrobinson Mar 2, 2026
9381776
:wrench: Clean up DDPM config file by making clear that it is using d…
jemrobinson Mar 2, 2026
66ff9b5
:loud_sound: Log datasets being used
jemrobinson Mar 2, 2026
69562e8
:wrench: Synchronise naming of losses between DDPM and BaseModel
jemrobinson Mar 2, 2026
5d7bc68
:bug: Attach hemisphere property to BaseModel
jemrobinson Mar 5, 2026
0ed690c
:wrench: Add a config for predicting 14 days into the future
jemrobinson Mar 5, 2026
c804041
:coffin: Drop unused sample_weight argument from DDPM.sample
jemrobinson Mar 5, 2026
62cdc11
:boom: Use common loss function in DDPM for easier comparison with ot…
jemrobinson Mar 5, 2026
8700774
:sparkles: Use common test metrics in DDPM
jemrobinson Mar 5, 2026
c0a0c9e
:coffin: Drop unused sample_weight in DDPM.test_step
jemrobinson Mar 5, 2026
e59216b
:art: Move clamp into DDPM.sample
jemrobinson Mar 5, 2026
76fd724
:label: Add type hints for tensor shapes to DDPM
jemrobinson Mar 6, 2026
4c54500
:coffin: Remove slurm specific options from DDPM model
jemrobinson Mar 6, 2026
a7c76bd
:wrench: Drop convolutional size in piecewise encoders to reduce memo…
jemrobinson Mar 6, 2026
46a2887
:wrench: Add a config for predicting 21 days into the future
jemrobinson Mar 9, 2026
b4b69af
:rewind: Hemisphere is no longer needed as a grouping
jemrobinson Mar 9, 2026
c17bb25
:wrench: Reduce size of CNN latent space to support 21 forecast days
jemrobinson Mar 9, 2026
8f502b3
:wrench: Reduce number of channels in naive-unet-naive to support 21 …
jemrobinson Mar 9, 2026
32398c7
:wrench: Fix config name
jemrobinson Mar 9, 2026
f18ef85
:wrench: Use 2 convolutional blocks in piecewise encoder by default
jemrobinson Mar 9, 2026
090f5bb
:wrench: Use 3 convolutional blocks and clamp in piecewise decoder by…
jemrobinson Mar 9, 2026
2dae804
:rotating_light: Fix linting/testing errors
jemrobinson Mar 9, 2026
18f003c
:wrench: updating predict config for the demo
IFenton Mar 20, 2026
54e36e1
:wrench: Add a new model called 'quick_test' that runs faster at the …
jemrobinson Mar 23, 2026
cde8d30
:memo: Added note about default quick test model
jemrobinson Mar 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ You can then run this with, e.g.:
```bash
uv run imp <command> --config-name <your local config>.yaml
```

This will run using the default model setup (rescaling encoder, small UNet, rescaling decoder), which is sufficient for quick tests but not appropriate for larger training runs.

You can also use this config to override other options in the `base.yaml` file, as shown below:

```yaml
Expand All @@ -74,7 +77,7 @@ uv run imp <command> ++base_path=/local/path/to/my/data

See `config/demo_north.yaml` for an example of this.

Note that `base_persistence.yaml` overrides the specific options in `base.yaml` needed to run the `Persistence` model.
:warning: Note that `base_persistence.yaml` overrides the specific options in `base.yaml` needed to run the `Persistence` model.

### HPC-specific configurations

Expand Down
9 changes: 7 additions & 2 deletions icenet_mp/callbacks/plotting_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch import Tensor

from icenet_mp.data_loaders import CombinedDataset
from icenet_mp.models import BaseModel
from icenet_mp.types import ModelTestOutput, PlotSpec
from icenet_mp.utils import datetime_from_npdatetime
from icenet_mp.visualisations import DEFAULT_SIC_SPEC, Plotter
Expand Down Expand Up @@ -56,7 +57,7 @@ def set_metadata(self, config: DictConfig, model_name: str) -> None:
def on_test_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule, # noqa: ARG002
pl_module: LightningModule,
outputs: Tensor | Mapping[str, Any] | None,
batch: Any, # noqa: ANN401, ARG002
batch_idx: int,
Expand Down Expand Up @@ -93,7 +94,11 @@ def on_test_batch_end(
map(datetime_from_npdatetime, dataset.get_forecast_steps(start_date))
)
# Set hemisphere for plotting based on dataset
self.plotter.set_hemisphere(dataset.hemisphere)
if not isinstance(pl_module, BaseModel):
msg = f"Lightning module is of type {type(pl_module)}, skipping plotting."
logger.warning(msg)
return
self.plotter.set_hemisphere(pl_module.hemisphere)

# Get loggers that support image and video logging
image_loggers = [ll for ll in trainer.loggers if hasattr(ll, "log_image")]
Expand Down
2 changes: 1 addition & 1 deletion icenet_mp/cli/hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def wrapper(
config_name: Annotated[
str | None,
Option(help="Specify the name of a file to load from the config directory"),
] = "base",
] = "sample",
*args: Param.args,
**kwargs: Param.kwargs,
) -> RetType:
Expand Down
2 changes: 1 addition & 1 deletion icenet_mp/config/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ defaults:
- loggers:
- wandb
- model: naive_unet_naive
- predict: sic-icenet
- predict: sic-icenet-2d
- train: default
- _self_

Expand Down
2 changes: 1 addition & 1 deletion icenet_mp/config/demo_north.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
defaults:
- base
- override /data: demo
- override /predict: osisaf-north
- override /predict: sic-icenet-2d
- _self_
6 changes: 1 addition & 5 deletions icenet_mp/config/model/cnn_null_cnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,11 @@ name: cnn-null-cnn

encoder:
_target_: icenet_mp.models.encoders.CNNEncoder
kernel_size: 3 # Size of the kernel for convolutional layers
latent_space: [128, 128] # Shape of the latent space
n_layers: 3 # Number of convolutional layers
latent_space: [144, 144] # Shape of the latent space

processor:
_target_: icenet_mp.models.processors.NullProcessor

decoder:
_target_: icenet_mp.models.decoders.CNNDecoder
kernel_size: 3 # Size of the kernel for convolutional layers
n_layers: 3 # Number of convolutional layers
bounded: false # Whether to bound the output between 0 and 1
8 changes: 1 addition & 7 deletions icenet_mp/config/model/cnn_unet_cnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,11 @@ name: cnn-unet-cnn

encoder:
_target_: icenet_mp.models.encoders.CNNEncoder
kernel_size: 3 # Size of the kernel for convolutional layers
latent_space: [128, 128] # Shape of the latent space
n_layers: 3 # Number of convolutional layers
latent_space: [144, 144] # Shape of the latent space

processor:
_target_: icenet_mp.models.processors.UNetProcessor
kernel_size: 3 # Size of the kernel for convolutional layers
start_out_channels: 64 # Initial number of channels for the first convolutional layer

decoder:
_target_: icenet_mp.models.decoders.CNNDecoder
kernel_size: 3 # Size of the kernel for convolutional layers
n_layers: 3 # Number of convolutional layers
bounded: false # Whether to bound the output between 0 and 1
12 changes: 1 addition & 11 deletions icenet_mp/config/model/cnn_vit_cnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,11 @@ name: cnn-vit-cnn

encoder:
_target_: icenet_mp.models.encoders.CNNEncoder
kernel_size: 3 # Size of the kernel for convolutional layers
latent_space: [192, 192] # Shape of the latent space
n_layers: 3 # Number of convolutional layers
latent_space: [144, 144] # Shape of the latent space

processor:
_target_: icenet_mp.models.processors.VitProcessor
patch_size: 16
emb_dim: 128
depth: 3
heads: 4
mlp_dim: 256
dropout: 0.3

decoder:
_target_: icenet_mp.models.decoders.CNNDecoder
kernel_size: 3 # Size of the kernel for convolutional layers
n_layers: 3 # Number of convolutional layers
bounded: false # Whether to bound the output between 0 and 1
11 changes: 1 addition & 10 deletions icenet_mp/config/model/ddpm.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,4 @@
_target_: icenet_mp.models.ddpm.DDPM

# Run DDPM model with default settings
name: ddpm

# DDPM parameters
timesteps: 1000
learning_rate: 5e-4
start_out_channels: 32
kernel_size: 3
activation: "SiLU"
normalization: "groupnorm"
time_embed_dim : 256
dropout_rate: 0.1
2 changes: 1 addition & 1 deletion icenet_mp/config/model/naive_null_naive.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ name: naive-null-naive

encoder:
_target_: icenet_mp.models.encoders.NaiveLinearEncoder
latent_space: [128, 128] # Shape of the latent space
latent_space: [432, 432] # Shape of the latent space

processor:
_target_: icenet_mp.models.processors.NullProcessor
Expand Down
5 changes: 2 additions & 3 deletions icenet_mp/config/model/naive_unet_naive.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@ name: naive-unet-naive

encoder:
_target_: icenet_mp.models.encoders.NaiveLinearEncoder
latent_space: [128, 128] # Shape of the latent space
latent_space: [432, 432] # Shape of the latent space

processor:
_target_: icenet_mp.models.processors.UNetProcessor
kernel_size: 3 # Size of the kernel for convolutional layers
start_out_channels: 64 # Initial number of channels for the first convolutional layer
start_out_channels: 100 # Reduce number of channels to support 21 day forecasts

decoder:
_target_: icenet_mp.models.decoders.NaiveLinearDecoder
Expand Down
8 changes: 1 addition & 7 deletions icenet_mp/config/model/naive_vit_naive.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,10 @@ name: naive-vit-naive

encoder:
_target_: icenet_mp.models.encoders.NaiveLinearEncoder
latent_space: [192, 192] # Shape of the latent space
latent_space: [432, 432] # Shape of the latent space

processor:
_target_: icenet_mp.models.processors.VitProcessor
patch_size: 16
emb_dim: 128
depth: 3
heads: 4
mlp_dim: 256
dropout: 0.3

decoder:
_target_: icenet_mp.models.decoders.NaiveLinearDecoder
Expand Down
13 changes: 13 additions & 0 deletions icenet_mp/config/model/piecewise_null_piecewise.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
_target_: icenet_mp.models.EncodeProcessDecode

name: piecewise-null-piecewise

encoder:
_target_: icenet_mp.models.encoders.PiecewiseEncoder
latent_space: [192, 192] # Shape of the latent space

processor:
_target_: icenet_mp.models.processors.NullProcessor

decoder:
_target_: icenet_mp.models.decoders.PiecewiseDecoder
7 changes: 1 addition & 6 deletions icenet_mp/config/model/piecewise_unet_piecewise.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,10 @@ name: piecewise-unet-piecewise

encoder:
_target_: icenet_mp.models.encoders.PiecewiseEncoder
latent_space: [128, 128] # Shape of the latent space
n_conv_blocks: 3 # Number of convolutional blocks to add after encoding
latent_space: [192, 192] # Shape of the latent space

processor:
_target_: icenet_mp.models.processors.UNetProcessor
kernel_size: 3 # Size of the kernel for convolutional layers
start_out_channels: 64 # Initial number of channels for the first convolutional layer

decoder:
_target_: icenet_mp.models.decoders.PiecewiseDecoder
restrict_range: clamp # Method for restricting output range (e.g., clamp, sigmoid, tanh)
n_conv_blocks: 3 # Number of convolutional blocks to add before decoding
13 changes: 13 additions & 0 deletions icenet_mp/config/model/piecewise_vit_piecewise.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
_target_: icenet_mp.models.EncodeProcessDecode

name: piecewise-vit-piecewise

encoder:
_target_: icenet_mp.models.encoders.PiecewiseEncoder
latent_space: [192, 192] # Shape of the latent space

processor:
_target_: icenet_mp.models.processors.VitProcessor

decoder:
_target_: icenet_mp.models.decoders.PiecewiseDecoder
14 changes: 14 additions & 0 deletions icenet_mp/config/model/quick_test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Small, fast model configuration intended for quick tests
_target_: icenet_mp.models.EncodeProcessDecode

name: sample

encoder:
_target_: icenet_mp.models.encoders.NaiveLinearEncoder
latent_space: [128, 128] # Shape of the latent space

processor:
_target_: icenet_mp.models.processors.UNetProcessor
start_out_channels: 64 # Initial number of channels for the first convolutional layer

decoder:
_target_: icenet_mp.models.decoders.NaiveLinearDecoder
9 changes: 9 additions & 0 deletions icenet_mp/config/predict/sic-icenet-14d.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Name of the dataset group containing our prediction target
target:
group_name: sic-icenet

# Number of future steps to predict
n_forecast_steps: 14

# Number of history steps to use when predicting
n_history_steps: 3
9 changes: 9 additions & 0 deletions icenet_mp/config/predict/sic-icenet-21d.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Name of the dataset group containing our prediction target
target:
group_name: sic-icenet

# Number of future steps to predict
n_forecast_steps: 21

# Number of history steps to use when predicting
n_history_steps: 3
11 changes: 11 additions & 0 deletions icenet_mp/config/predict/sic-ssmis-14d.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Name of the dataset group containing our prediction target
target:
group_name: sic-ssmis
variables:
- ice_conc

# Number of future steps to predict
n_forecast_steps: 14

# Number of history steps to use when predicting
n_history_steps: 3
11 changes: 11 additions & 0 deletions icenet_mp/config/predict/sic-ssmis-21d.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Name of the dataset group containing our prediction target
target:
group_name: sic-ssmis
variables:
- ice_conc

# Number of future steps to predict
n_forecast_steps: 21

# Number of history steps to use when predicting
n_history_steps: 3
4 changes: 4 additions & 0 deletions icenet_mp/config/sample.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
defaults:
- base
- override /model: quick_test
- _self_
12 changes: 0 additions & 12 deletions icenet_mp/data_loaders/combined_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from collections.abc import Sequence
from typing import Literal

import numpy as np
from torch.utils.data import Dataset
Expand Down Expand Up @@ -87,17 +86,6 @@ def start_date(self) -> np.datetime64:
"""Return the start date of the dataset."""
return self.dates[0]

@property
def hemisphere(self) -> Literal["north", "south"]:
"""Return the hemisphere of the dataset."""
hemisphere: set[Literal["north", "south"]] = {
ds.hemisphere for ds in self.inputs
}
if len(hemisphere) != 1:
msg = f"Found {len(hemisphere)} different hemisphere indicators across {len(self.inputs)} datasets."
raise ValueError(msg)
return hemisphere.pop()

def __len__(self) -> int:
"""Return the total length of the dataset."""
return len(self.dates)
Expand Down
22 changes: 18 additions & 4 deletions icenet_mp/data_loaders/common_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from omegaconf import DictConfig
from torch.utils.data import DataLoader

from icenet_mp.types import ArrayTCHW, DataloaderArgs, DataSpace
from icenet_mp.types import ArrayTCHW, DataloaderArgs, DataSpace, Hemisphere

from .combined_dataset import CombinedDataset
from .single_dataset import SingleDataset
Expand Down Expand Up @@ -35,9 +35,11 @@ def __init__(self, config: DictConfig) -> None:
self.base_path / "data" / "anemoi" / f"{dataset['name']}.zarr"
).resolve()
)
logger.info("Found %d dataset_groups.", len(self.dataset_groups))
for dataset_group in self.dataset_groups:
logger.debug("... %s.", dataset_group)
logger.info("Found %d dataset groups.", len(self.dataset_groups))
for idx, (name, paths) in enumerate(self.dataset_groups.items(), start=1):
logger.info("%d) %s:", idx, name)
for path in paths:
logger.info("%s - %s", " " * (len(str(idx)) + 1), path)

# Check prediction target
self.target_group_name = config["predict"]["target"]["group_name"]
Expand Down Expand Up @@ -82,6 +84,18 @@ def __init__(self, config: DictConfig) -> None:
worker_init_fn=None,
)

@property
def hemisphere(self) -> Hemisphere:
    """Return the one hemisphere shared by all dataset groups.

    A ``SingleDataset`` is built for each entry in ``self.dataset_groups``
    and its ``hemisphere`` attribute is collected.

    Raises:
        ValueError: if the groups do not yield exactly one distinct hemisphere.
    """
    found: set[Hemisphere] = set()
    for group_name, group_paths in self.dataset_groups.items():
        found.add(SingleDataset(group_name, group_paths).hemisphere)
    # Exactly one distinct value means every group agrees.
    if len(found) == 1:
        return found.pop()
    msg = f"Found {len(found)} different hemisphere indicators across {len(self.dataset_groups)} dataset groups."
    raise ValueError(msg)

@cached_property
def input_spaces(self) -> list[DataSpace]:
"""Return the data space for each input."""
Expand Down
5 changes: 2 additions & 3 deletions icenet_mp/data_loaders/single_dataset.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from collections.abc import Sequence
from functools import cached_property
from pathlib import Path
from typing import Literal

import numpy as np
from anemoi.datasets.data import open_dataset
from anemoi.datasets.data.dataset import Dataset as AnemoiDataset
from torch.utils.data import Dataset

from icenet_mp.types import ArrayCHW, ArrayTCHW, DataSpace
from icenet_mp.types import ArrayCHW, ArrayTCHW, DataSpace, Hemisphere
from icenet_mp.utils import normalise_date


Expand All @@ -31,7 +30,7 @@ def __init__(
self._date_ranges = sorted(
date_ranges, key=lambda dr: "" if dr["start"] is None else dr["start"]
)
self.hemisphere: Literal["north", "south"] = (
self.hemisphere: Hemisphere = (
"north"
if any("north" in str(input_file).lower() for input_file in input_files)
else "south"
Expand Down
1 change: 1 addition & 0 deletions icenet_mp/model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def from_config(cls, config: DictConfig) -> "ModelService":
builder.model_ = hydra.utils.instantiate(
dict(
{
"hemisphere": builder.data_module.hemisphere,
"input_spaces": [
s.to_dict() for s in builder.data_module.input_spaces
],
Expand Down
Loading
Loading