Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
7c8ef0a
:bug: Check for appropriate patch size during VitProcessor initialisa…
jemrobinson Jun 18, 2026
4a3ff5e
:truck: Rename baseline configs
jemrobinson Jun 18, 2026
df16775
:coffin: Remove non-baseline configs
jemrobinson Jun 18, 2026
0a6d457
:memo: Better mask creation output
jemrobinson Jun 22, 2026
a095677
:alien: Add missing cleanup step in dataset downloading
jemrobinson Jun 22, 2026
dda6dbf
:alien: Catch recipe failures during artifact cleanup
jemrobinson Jun 22, 2026
511a23d
:memo: Reduce unnecessary logging
jemrobinson Jun 22, 2026
c547ba3
:alembic: Check for artifacts as part of finalisation
jemrobinson Jun 22, 2026
5275b0d
:bug: Allow EMAWeightAveragingCallback to support parameter-less mode…
jemrobinson Jun 23, 2026
7746af0
:sparkles: Make scale factor configurable in CNNEncoder/CNNDecoder
jemrobinson Jun 23, 2026
f69edc4
:wrench: Fix CNN-UNet-CNN config file
jemrobinson Jun 23, 2026
95c58cd
:bug: Do not log metrics before they have been updated
jemrobinson Jun 23, 2026
cb22696
:bulb: Drop configuration message to debug as this appears on each wo…
jemrobinson Jun 23, 2026
1921ddb
:bug: Use most recent timestep for persistence
jemrobinson Jun 23, 2026
bb39cf5
:recycle: Rename _loss_fn to loss_fn
jemrobinson Jun 23, 2026
492318d
:wrench: Increase naive-unet-naive start_out_channels and bound the o…
jemrobinson Jun 25, 2026
230c6cd
:wrench: Increase kernel_size, start_out_channels and added bounding …
jemrobinson Jun 25, 2026
a5e96bb
:wrench: Increase ViT size and add decoder bounding to cnn_vit_cnn
jemrobinson Jun 25, 2026
9a313c9
:wrench: Increase DDPM size and reduce dropout in ddpm
jemrobinson Jun 25, 2026
b545ae1
:art: Tidy up use of fully_deterministic
jemrobinson Jun 26, 2026
5f85626
:white_check_mark: Fix residual use of _loss_fn in testing
jemrobinson Jun 26, 2026
467a674
:wrench: Increase complexity of ViT processor
jemrobinson Jun 27, 2026
4dbf8d7
:wrench: Minor tweaks to cnn-unet-cnn, ddpm and piecewise-unet-piecew…
jemrobinson Jun 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions icenet_mp/callbacks/ema_weight_averaging_callback.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Any

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import WeightAveraging
from torch.optim.swa_utils import get_ema_multi_avg_fn

Expand Down Expand Up @@ -26,6 +29,18 @@ def __init__(
self.every_n_epochs = every_n_epochs
self.every_n_steps = every_n_steps

def on_train_batch_end(
    self, trainer: Trainer, pl_module: LightningModule, *args: Any, **kwargs: Any
) -> None:
    """Forward the batch-end hook to the parent class only when the module has parameters.

    If ``pl_module.parameters()`` yields nothing, the parent hook is skipped and
    this method returns without doing anything.
    """
    # Peek at the first parameter only; there is no need to walk the whole iterator
    first_parameter = next(pl_module.parameters(), None)
    if first_parameter is None:
        return
    super().on_train_batch_end(trainer, pl_module, *args, **kwargs)

def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Forward the epoch-end hook to the parent class only when the module has parameters.

    If ``pl_module.parameters()`` yields nothing, the parent hook is skipped and
    this method returns without doing anything.
    """
    # Peek at the first parameter only; there is no need to walk the whole iterator
    first_parameter = next(pl_module.parameters(), None)
    if first_parameter is None:
        return
    super().on_train_epoch_end(trainer, pl_module)

def should_update(
self, step_idx: int | None = None, epoch_idx: int | None = None
) -> bool:
Expand Down
4 changes: 4 additions & 0 deletions icenet_mp/callbacks/metric_summary_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def log_per_epoch_metrics(
"""Log per-epoch metrics to W&B."""
for name, metric in metrics.items():
# Compute the metric value (e.g., SIEError) across all batches
if not metric._update_called:
continue
values: Tensor = metric.compute()

# Log the mean value of the metric across all days
Expand Down Expand Up @@ -56,6 +58,8 @@ def log_per_run_metrics(
values_per_forecast_day: dict[str, dict[str, Tensor]] = defaultdict(dict)
for stage, metric_collection in metrics.items():
for metric_name, metric in metric_collection.items():
if not metric._update_called:
continue
metric_tensor: Tensor = metric.compute()
if metric_tensor.reshape(-1).shape[0] > 1:
values_per_forecast_day[metric_name][stage] = metric_tensor
Expand Down
2 changes: 1 addition & 1 deletion icenet_mp/compatibility/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

def configure_external_libraries() -> None:
"""Configure any external libraries used by the application."""
log.info("Configuring external libraries...")
log.debug("Configuring external libraries...")
patch_parameter_deepcopy()
register_accelerators()
register_animation_backends()
Expand Down
2 changes: 1 addition & 1 deletion icenet_mp/config/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ defaults:
- evaluate: default
- loggers:
- wandb
- model: naive_unet_naive
- model: 00_naive_unet_naive
- predict: sic-icenet-2d
- random: default
- train: default
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ encoders:

processor:
_target_: icenet_mp.models.processors.UNetProcessor
start_out_channels: 100 # Reduce number of channels to support 21 day forecasts
start_out_channels: 128

decoder:
_target_: icenet_mp.models.decoders.NaiveLinearDecoder
bounded: false # Whether to bound the output between 0 and 1
bounded: true
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,31 @@ encoders:
latent_space: [144, 144]
era5:
_target_: icenet_mp.models.encoders.CNNEncoder
kernel_size: 5
n_layers: 1
scale_factor: 3
float-argo:
_target_: icenet_mp.models.encoders.CNNEncoder
kernel_size: 5
n_layers: 1
scale_factor: 3
sic-icenet:
_target_: icenet_mp.models.encoders.CNNEncoder
kernel_size: 5
n_layers: 1
scale_factor: 3
sic-ssmis:
_target_: icenet_mp.models.encoders.CNNEncoder
n_layers: 1
scale_factor: 3

processor:
_target_: icenet_mp.models.processors.UNetProcessor
start_out_channels: 128

decoder:
_target_: icenet_mp.models.decoders.CNNDecoder
bounded: false # Whether to bound the output between 0 and 1
bounded: true
kernel_size: 5
n_layers: 1
scale_factor: 3
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,30 @@ _target_: icenet_mp.models.EncodeProcessDecode
name: cnn-vit-cnn

encoders:
latent_space: [144, 144]
latent_space: [216, 216]
era5:
_target_: icenet_mp.models.encoders.CNNEncoder
n_layers: 1
float-argo:
_target_: icenet_mp.models.encoders.CNNEncoder
n_layers: 1
sic-icenet:
_target_: icenet_mp.models.encoders.CNNEncoder
n_layers: 1
sic-ssmis:
_target_: icenet_mp.models.encoders.CNNEncoder
n_layers: 1

processor:
_target_: icenet_mp.models.processors.VitProcessor
depth: 6 # number of transformer blocks
dropout: 0.1 # dropout rate
emb_dim: 256 # embedding dimension
heads: 8 # number of attention heads
mlp_dim: 1024 # dimension of the MLP in the transformer block
patch_size: 24 # 216 / 24 = 9 patches per dimension (81 in total)

decoder:
_target_: icenet_mp.models.decoders.CNNDecoder
bounded: false # Whether to bound the output between 0 and 1
bounded: true
n_layers: 1
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@ _target_: icenet_mp.models.ddpm.DDPM

# Run DDPM model with default settings, except for the overrides listed below
name: ddpm
dropout_rate: 0.05
start_out_channels: 128
timesteps: 500
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ processor:

decoder:
_target_: icenet_mp.models.decoders.PiecewiseDecoder
restrict_range: sigmoid
21 changes: 0 additions & 21 deletions icenet_mp/config/model/cnn_null_cnn.yaml

This file was deleted.

21 changes: 0 additions & 21 deletions icenet_mp/config/model/naive_null_naive.yaml

This file was deleted.

21 changes: 0 additions & 21 deletions icenet_mp/config/model/naive_vit_naive.yaml

This file was deleted.

20 changes: 0 additions & 20 deletions icenet_mp/config/model/piecewise_null_piecewise.yaml

This file was deleted.

20 changes: 0 additions & 20 deletions icenet_mp/config/model/piecewise_vit_piecewise.yaml

This file was deleted.

24 changes: 0 additions & 24 deletions icenet_mp/config/model/reproject_unet_naive.yaml

This file was deleted.

1 change: 0 additions & 1 deletion icenet_mp/config/train/scheduler/ddpm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,3 @@ defaults:

scheduler_parameters:
eta_min: 5e-6
T_max: 100
38 changes: 31 additions & 7 deletions icenet_mp/data_processors/data_downloader.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import logging
import shutil
from contextlib import suppress
from pathlib import Path
from typing import Any

import numpy as np
from anemoi.datasets import open_dataset
from anemoi.datasets.commands.cleanup import Cleanup
from anemoi.datasets.commands.finalise import Finalise
from anemoi.datasets.commands.init import Init
from anemoi.datasets.commands.inspect import InspectZarr
Expand All @@ -14,6 +16,7 @@
from zarr.errors import PathNotFoundError

from icenet_mp.types import (
AnemoiCleanupArgs,
AnemoiDatasetStatus,
AnemoiFinaliseArgs,
AnemoiInitArgs,
Expand Down Expand Up @@ -46,6 +49,14 @@ def __init__(
self.recipe = Recipe(**anemoi_config)
self.preprocessor = cls_preprocessor(anemoi_config)

def artifacts(self) -> list[Path]:
    """List the temporary artifacts left next to the dataset.

    An artifact is any path in the dataset's parent directory whose name starts
    with the dataset's stem followed by a dot, other than the dataset itself.

    Returns:
        The matching paths, in the order that ``Path.glob`` yields them.
    """
    pattern = f"{self.path_dataset.stem}.*"
    leftovers: list[Path] = []
    for candidate in self.path_dataset.parent.glob(pattern):
        # The dataset itself also matches the pattern, so leave it out
        if candidate == self.path_dataset:
            continue
        leftovers.append(candidate)
    return leftovers

def check_status(self) -> AnemoiDatasetStatus:
"""Return the status of the dataset."""
try:
Expand Down Expand Up @@ -85,7 +96,9 @@ def check_status(self) -> AnemoiDatasetStatus:
return AnemoiDatasetStatus(
copy_in_progress=copy_in_progress,
download_complete=download_complete,
is_finalised=download_complete and statistics_ready,
is_finalised=download_complete
and statistics_ready
and not self.artifacts(),
)

def create(self, *, overwrite: bool = False) -> None:
Expand Down Expand Up @@ -169,22 +182,33 @@ def finalise(self, *, overwrite: bool, status: AnemoiDatasetStatus) -> None:
)
logger.info("Finalised dataset %s at %s.", self.name, self.path_dataset)

# create active grid cell and land masks for the SSMIS dataset
# Create active grid cell and land masks if appropriate
self.generate_masks(overwrite=overwrite)

# Cleanup any temporary artifacts created during the download and finalise process
if self.artifacts():
with suppress(ValueError):
Cleanup().run(AnemoiCleanupArgs(path=str(self.path_dataset)))
if remaining := self.artifacts():
logger.warning("Residual artifacts for dataset %s:", self.name)
for artifact in remaining:
logger.warning("... %s", artifact)
else:
logger.info("Cleaned up temporary artifacts for dataset %s.", self.name)

def generate_masks(self, *, overwrite: bool) -> None:
"""Generate land and active grid cell masks for the SSMIS dataset."""
# if there is an SSMIS dataset, create the masks
# Create the masks if this is an SSMIS dataset
if "ssmis" not in self.name:
logger.info("Not SSMIS dataset, skipping mask creation.")
return
logger.debug("Generating land and active grid cell masks for SSMIS dataset.")

self.path_masks.mkdir(parents=True, exist_ok=True)
land_mask_path = self.path_masks / "land_mask.npy"
active_mask_path = self.path_masks / "active_mask.npy"

if land_mask_path.exists() and active_mask_path.exists() and not overwrite:
logger.info("Both masks already exist, skipping creation.")
logger.debug("Both masks already exist, skipping creation.")
return

# Unpack status flags into a binary array, skipping any missing dates
Expand All @@ -205,7 +229,7 @@ def generate_masks(self, *, overwrite: bool) -> None:

# land mask: land = 0, sea = 1
if land_mask_path.exists() and not overwrite:
logger.info("Land mask already exists, skipping creation.")
logger.debug("Land mask already exists, skipping creation.")
land_mask = np.load(land_mask_path)
else:
land_mask = np.squeeze(binary[..., [7]]).sum(axis=0)
Expand All @@ -219,7 +243,7 @@ def generate_masks(self, *, overwrite: bool) -> None:

# active mask: active grid cells = 1, inactive = 0
if active_mask_path.exists() and not overwrite:
logger.info("Active mask already exists, skipping creation.")
logger.debug("Active mask already exists, skipping creation.")
else:
# Identify grid cells that are inactive for all time steps
inactive_count = np.squeeze(binary[..., [0]]).sum(axis=0)
Expand Down
Loading
Loading