Merged
30 commits
7766acc  :wrench: Moved trainer and optimizer configs into dedicated files (jemrobinson, Aug 1, 2025)
4b4f1f1  :wrench: Move prediction target settings to a dedicated config file (jemrobinson, Aug 1, 2025)
b657ae6  :sparkles: Pass n_forecast_steps and n_history_steps to CombinedDataset (jemrobinson, Aug 1, 2025)
2d183f3  :safety_vest: Require that all datasets in a CombinedDataset have the… (jemrobinson, Aug 1, 2025)
be2916f  :sparkles: Add get_forecast_steps and get_history_steps to return lis… (jemrobinson, Aug 1, 2025)
beeec9d  :art: Return a dictionary of tensors instead of a list (jemrobinson, Aug 4, 2025)
b1d834f  :label: Separate numpy and torch batches (jemrobinson, Aug 4, 2025)
0166f67  :sparkles: Get all available dates when constructing a CombinedDataset (jemrobinson, Aug 4, 2025)
664d350  :heavy_plus_sign: Add cached index_from_date method (jemrobinson, Aug 4, 2025)
3084f86  squash :sparkles: Get all available dates when constructing a Combine… (jemrobinson, Aug 4, 2025)
57d242e  :recycle: Refactor Encoder and Decoder to take and return a time dime… (jemrobinson, Aug 4, 2025)
90fcbbd  :recycle: Pass NTCHW batches through training (jemrobinson, Aug 4, 2025)
2bec94a  :heavy_plus_sign: Add types-cachetools for mypy (jemrobinson, Aug 7, 2025)
2287aed  :rotating_light: Fix typing errors (jemrobinson, Aug 7, 2025)
6344289  :heavy_plus_sign: Use jaxtyping to annotate types of tensor and array… (jemrobinson, Aug 7, 2025)
319cbe9  :label: Add type-hint to ZebraDataset cache (jemrobinson, Aug 7, 2025)
d2b80b8  :truck: Move 'predict' config to top-level (jemrobinson, Aug 8, 2025)
3f7e811  :recycle: Lazy-load AnemoiDataset in ZebraDataset (jemrobinson, Aug 8, 2025)
269b253  :recycle: Lazy-load CHW in ZebraDataset (jemrobinson, Aug 11, 2025)
73841e0  :truck: Move CHW to DataSpace (jemrobinson, Aug 11, 2025)
a7868f6  :white_check_mark: Add basic ZebraDataset test (jemrobinson, Aug 11, 2025)
a68a69b  :white_check_mark: Add tests for a dummy AnemoiDataset (jemrobinson, Aug 11, 2025)
136452b  :truck: Move Anemoi helper classes to types (jemrobinson, Aug 12, 2025)
eb4aad0  :white_check_mark: Add mock dataset for testing (jemrobinson, Aug 12, 2025)
7b471fd  :white_check_mark: Add tests for DataSpace and length (jemrobinson, Aug 12, 2025)
a2961d4  :white_check_mark: Add tests for getitem (jemrobinson, Aug 12, 2025)
79d2b95  :white_check_mark: Add tests for index_from_date (jemrobinson, Aug 12, 2025)
03a79ca  :white_check_mark: Add tests for get_tchw (jemrobinson, Aug 12, 2025)
22690d0  :bulb: Fix docstring for ZebraModel::forward (jemrobinson, Aug 12, 2025)
81b730b  :rotating_light: Fix mypy type checks (jemrobinson, Aug 12, 2025)
8 changes: 8 additions & 0 deletions ice_station_zebra/config/predict/osisaf-south.yaml
@@ -0,0 +1,8 @@
# Name of the dataset group containing our prediction target
dataset_group: osisaf-south

# Number of future steps to predict
n_forecast_steps: 2

# Number of history steps to use when predicting
n_history_steps: 3
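These two settings define the sampling window used by CombinedDataset later in this PR. As a minimal sketch of the date arithmetic (the daily frequency and start date are assumptions; the expressions mirror get_history_steps and get_forecast_steps below):

```python
import numpy as np

frequency = np.timedelta64(1, "D")        # assumed daily data
start_date = np.datetime64("2020-01-01")  # assumed first available date
n_history_steps, n_forecast_steps = 3, 2  # values from this config

# Inputs cover n_history_steps consecutive dates from start_date ...
history = [start_date + idx * frequency for idx in range(n_history_steps)]
# ... and the target covers n_forecast_steps dates immediately afterwards
forecast = [
    start_date + (idx + n_history_steps) * frequency
    for idx in range(n_forecast_steps)
]
print(history)   # 2020-01-01, 2020-01-02, 2020-01-03
print(forecast)  # 2020-01-04, 2020-01-05
```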
16 changes: 2 additions & 14 deletions ice_station_zebra/config/train/default.yaml
@@ -1,18 +1,6 @@
defaults:
- callbacks:
- device_stats
- optimizer: default
- trainer: default
- _self_

# Name of the dataset group containing our prediction target
predict_target: osisaf-south

# PyTorch lightning settings
trainer:
_target_: lightning.pytorch.trainer.trainer.Trainer
accelerator: auto
devices: auto
log_every_n_steps: 10
max_epochs: 50
optimizer:
_target_: torch.optim.AdamW
lr: 5e-4
3 changes: 3 additions & 0 deletions ice_station_zebra/config/train/optimizer/default.yaml
@@ -0,0 +1,3 @@
# PyTorch Lightning optimizer settings
_target_: torch.optim.AdamW
lr: 5e-4
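For reference, a hedged sketch of how a `_target_` config like this one is typically turned into an object with hydra.utils.instantiate (the stand-in model is hypothetical; the real code would pass the trained model's parameters):

```python
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

model = torch.nn.Linear(4, 4)  # hypothetical stand-in for the real model

cfg = OmegaConf.create({"_target_": "torch.optim.AdamW", "lr": 5e-4})
# instantiate resolves _target_ and forwards remaining keys as kwargs
optimizer = instantiate(cfg, params=model.parameters())
print(type(optimizer))  # <class 'torch.optim.adamw.AdamW'>
```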
6 changes: 6 additions & 0 deletions ice_station_zebra/config/train/trainer/default.yaml
@@ -0,0 +1,6 @@
# PyTorch Lightning trainer settings
_target_: lightning.pytorch.trainer.trainer.Trainer
accelerator: auto
devices: auto
log_every_n_steps: 10
max_epochs: 50
1 change: 1 addition & 0 deletions ice_station_zebra/config/zebra.yaml
@@ -8,6 +8,7 @@ defaults:
- loggers:
- wandb
- model: encode_null_decode
- predict: osisaf-south
- split: default
- train: default
- _self_
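The new defaults entry composes config/predict/osisaf-south.yaml into the top-level `predict` key. A minimal sketch of checking the composed config (the compose-API usage and the relative config path are assumptions):

```python
from hydra import compose, initialize

with initialize(version_base=None, config_path="ice_station_zebra/config"):
    cfg = compose(config_name="zebra")

# Contents of predict/osisaf-south.yaml now live under cfg.predict
print(cfg.predict.dataset_group)     # osisaf-south
print(cfg.predict.n_forecast_steps)  # 2
print(cfg.predict.n_history_steps)   # 3
```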
20 changes: 1 addition & 19 deletions ice_station_zebra/data/anemoi/zebra_dataset.py
@@ -1,6 +1,5 @@
import logging
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Type

@@ -10,28 +9,11 @@
from zarr.errors import PathNotFoundError

from ice_station_zebra.data.anemoi.preprocessors import IPreprocessor
from ice_station_zebra.types import AnemoiCreateArgs, AnemoiInspectArgs

log = logging.getLogger(__name__)


@dataclass
class AnemoiCreateArgs:
path: str
config: DictConfig
command: str = "unused"
threads: int = 0
processes: int = 0


@dataclass
class AnemoiInspectArgs:
path: str
detailed: bool
progress: bool
statistics: bool
size: bool


class ZebraDataset:
def __init__(
self, name: str, config: DictConfig, cls_preprocessor: Type[IPreprocessor]
87 changes: 77 additions & 10 deletions ice_station_zebra/data/lightning/combined_dataset.py
@@ -2,36 +2,103 @@
from datetime import datetime

import numpy as np
from numpy.typing import NDArray
from torch.utils.data import Dataset

from .zebra_dataset import ZebraDataset
from ice_station_zebra.types import ArrayTCHW


class CombinedDataset(Dataset):
def __init__(self, datasets: Sequence[ZebraDataset], *, target: str) -> None:
def __init__(
self,
datasets: Sequence[ZebraDataset],
target: str,
*,
n_forecast_steps: int = 1,
n_history_steps: int = 1,
) -> None:
"""Constructor"""
super().__init__()

# Store the number of forecast and history steps
self.n_forecast_steps = n_forecast_steps
self.n_history_steps = n_history_steps

# Define target and input datasets
self.target = next(ds for ds in datasets if ds.name == target)
self.datasets = [ds for ds in datasets if ds != self.target]
self.inputs = [ds for ds in datasets]

# Require that all datasets have the same frequency
frequencies = sorted(set(ds.dataset.frequency for ds in datasets))
if len(frequencies) != 1:
msg = f"Cannot combine datasets with different frequencies: {frequencies}."
raise ValueError(msg)
self.frequency = np.timedelta64(frequencies[0])

# Get list of dates that are available in all datasets
self.available_dates = [
start_date
# Iterate over all dates in any dataset
for start_date in sorted(
{date for ds in datasets for date in ds.dataset.dates}
)
# Check that all inputs have n_history_steps starting on start_date
if all(
date in ds.dataset.dates
for date in self.get_history_steps(start_date)
for ds in self.inputs
)
# Check that the target has n_forecast_steps starting after the history dates
and all(
date in self.target.dataset.dates
for date in self.get_forecast_steps(start_date)
)
]

def __len__(self) -> int:
"""Return the total length of the dataset"""
return min([len(ds) for ds in self.datasets] + [len(self.target)])
return len(self.available_dates)

def __getitem__(self, idx: int) -> dict[str, ArrayTCHW]:
"""Return the data for a single timestep as a dictionary

def __getitem__(self, idx: int) -> tuple[NDArray[np.float32]]:
"""Return a single timestep"""
return tuple([ds[idx] for ds in self.datasets] + [self.target[idx]]) # type: ignore[return-value]
Returns:
A dictionary with dataset names as keys and a numpy array as the value.
The shape of each array is:
- input datasets: [n_history_steps, C_input_k, H_input_k, W_input_k]
- target dataset: [n_forecast_steps, C_target, H_target, W_target]
"""
return {
ds.name: ds.get_tchw(self.get_history_steps(self.available_dates[idx]))
for ds in self.inputs
} | {
"target": self.target.get_tchw(
self.get_forecast_steps(self.available_dates[idx])
)
}

def date_from_index(self, idx: int) -> datetime:
"""Return the date of the timestep"""
np_datetime = self.target.dataset.dates[idx]
np_datetime = self.available_dates[idx]
return datetime.strptime(str(np_datetime), r"%Y-%m-%dT%H:%M:%S")

def get_forecast_steps(self, start_date: np.datetime64) -> list[np.datetime64]:
"""Return list of consecutive forecast dates for a given start date."""
return [
start_date + (idx + self.n_history_steps) * self.frequency
for idx in range(self.n_forecast_steps)
]

def get_history_steps(self, start_date: np.datetime64) -> list[np.datetime64]:
"""Return list of consecutive history dates for a given start date."""
return [
start_date + idx * self.frequency for idx in range(self.n_history_steps)
]

@property
def end_date(self) -> np.datetime64:
"""Return the end date of the dataset."""
end_date = set(dataset.end_date for dataset in self.datasets)
end_date = set(dataset.end_date for dataset in self.inputs)
if len(end_date) != 1:
msg = f"Datasets have {len(end_date)} different end dates"
raise ValueError(msg)
@@ -40,7 +107,7 @@ def end_date(self) -> np.datetime64:
@property
def start_date(self) -> np.datetime64:
"""Return the start date of the dataset."""
start_date = set(dataset.start_date for dataset in self.datasets)
start_date = set(dataset.start_date for dataset in self.inputs)
if len(start_date) != 1:
msg = f"Datasets have {len(start_date)} different start dates"
raise ValueError(msg)
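To make the new batch structure concrete, a runnable sketch of the dictionary a CombinedDataset sample now carries (dataset names, channel counts and grid sizes are illustrative assumptions):

```python
import numpy as np

# Shapes a sample would have for inputs "era5" and "osisaf-south" with
# target="osisaf-south", n_history_steps=3 and n_forecast_steps=2:
sample: dict[str, np.ndarray] = {
    "era5": np.zeros((3, 10, 180, 360), dtype=np.float32),        # [T_hist, C, H, W]
    "osisaf-south": np.zeros((3, 1, 432, 432), dtype=np.float32),  # [T_hist, C, H, W]
    "target": np.zeros((2, 1, 432, 432), dtype=np.float32),        # [T_fcst, C, H, W]
}
# Note that the target dataset appears twice: as an input over the history
# window and under "target" over the forecast window, since self.inputs
# keeps every dataset.
```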
23 changes: 16 additions & 7 deletions ice_station_zebra/data/lightning/zebra_data_module.py
@@ -3,13 +3,11 @@
from functools import cached_property
from pathlib import Path

import numpy as np
from lightning import LightningDataModule
from numpy.typing import NDArray
from omegaconf import DictConfig
from torch.utils.data import DataLoader

from ice_station_zebra.types import DataloaderArgs, DataSpace
from ice_station_zebra.types import ArrayTCHW, DataloaderArgs, DataSpace

from .combined_dataset import CombinedDataset
from .zebra_dataset import ZebraDataset
@@ -37,10 +35,11 @@ def __init__(self, config: DictConfig) -> None:
logger.debug(f"... {dataset_group}")

# Check prediction target
self.predict_target = config["train"]["predict_target"]
self.predict_target = config["predict"]["dataset_group"]
if self.predict_target not in self.dataset_groups:
raise ValueError(f"Could not find prediction target {self.predict_target}")

# Set periods for train, validation, and test
self.batch_size = int(config["split"]["batch_size"])
self.train_period = {
k: None if v == "None" else v for k, v in config["split"]["train"].items()
@@ -53,6 +52,10 @@ def __init__(self, config: DictConfig) -> None:
k: None if v == "None" else v for k, v in config["split"]["test"].items()
}

# Set history and forecast steps
self.n_forecast_steps = int(config["predict"].get("n_forecast_steps", 1))
self.n_history_steps = int(config["predict"].get("n_history_steps", 1))

# Set common arguments for the dataloader
self._common_dataloader_kwargs = DataloaderArgs(
batch_size=self.batch_size,
@@ -82,7 +85,7 @@ def output_space(self) -> DataSpace:

def train_dataloader(
self,
) -> DataLoader[tuple[NDArray[np.float32], NDArray[np.float32]]]:
) -> DataLoader[dict[str, ArrayTCHW]]:
"""Construct train dataloader"""
dataset = CombinedDataset(
[
@@ -94,6 +97,8 @@
)
for name, paths in self.dataset_groups.items()
],
n_forecast_steps=self.n_forecast_steps,
n_history_steps=self.n_history_steps,
target=self.predict_target,
)
logger.info(
@@ -106,7 +111,7 @@

def val_dataloader(
self,
) -> DataLoader[tuple[NDArray[np.float32], NDArray[np.float32]]]:
) -> DataLoader[dict[str, ArrayTCHW]]:
"""Construct validation dataloader"""
dataset = CombinedDataset(
[
@@ -118,6 +123,8 @@
)
for name, paths in self.dataset_groups.items()
],
n_forecast_steps=self.n_forecast_steps,
n_history_steps=self.n_history_steps,
target=self.predict_target,
)
logger.info(
@@ -130,7 +137,7 @@

def test_dataloader(
self,
) -> DataLoader[tuple[NDArray[np.float32], NDArray[np.float32]]]:
) -> DataLoader[dict[str, ArrayTCHW]]:
"""Construct test dataloader"""
dataset = CombinedDataset(
[
@@ -142,6 +149,8 @@
)
for name, paths in self.dataset_groups.items()
],
n_forecast_steps=self.n_forecast_steps,
n_history_steps=self.n_history_steps,
target=self.predict_target,
)
logger.info(
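The retyped dataloaders yield dictionary batches: PyTorch's default collate_fn stacks each key across the batch and converts numpy arrays to torch tensors in the process, so every value gains a leading batch dimension. A runnable toy sketch (the ToySamples class and its shapes are hypothetical):

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset

class ToySamples(Dataset):
    """Stand-in yielding dict[str, ArrayTCHW] samples like CombinedDataset."""

    def __len__(self) -> int:
        return 4

    def __getitem__(self, idx: int) -> dict[str, np.ndarray]:
        return {"target": np.zeros((2, 1, 8, 8), dtype=np.float32)}

loader = DataLoader(ToySamples(), batch_size=4)
batch = next(iter(loader))
print(batch["target"].shape)  # torch.Size([4, 2, 1, 8, 8]), i.e. [N, T, C, H, W]
print(type(batch["target"]))  # default collate returns torch.Tensor, not numpy
```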
53 changes: 42 additions & 11 deletions ice_station_zebra/data/lightning/zebra_dataset.py
@@ -1,11 +1,13 @@
from pathlib import Path
from collections.abc import Sequence

import numpy as np
from anemoi.datasets.data import open_dataset
from numpy.typing import NDArray
from anemoi.datasets.data.dataset import Dataset as AnemoiDataset
from cachetools import LRUCache, cachedmethod
from torch.utils.data import Dataset

from ice_station_zebra.types import DataSpace
from ice_station_zebra.types import ArrayCHW, ArrayTCHW, DataSpace


class ZebraDataset(Dataset):
@@ -23,23 +25,41 @@ def __init__(
We reshape each time point to: variables; pos_x; pos_y
"""
super().__init__()
self.dataset = open_dataset(input_files, start=start, end=end)
self.dataset._name = name
self._cache: LRUCache = LRUCache(maxsize=128)
self._dataset: AnemoiDataset | None = None
self._end = end
self._input_files = input_files
self._name = name
self._start = start

@property
def dataset(self) -> AnemoiDataset:
"""Load the underlying Anemoi dataset."""
if not self._dataset:
self._dataset = open_dataset(
self._input_files, start=self._start, end=self._end
)
self._dataset._name = self._name
return self._dataset

@property
def end_date(self) -> np.datetime64:
"""Return the end date of the dataset."""
return self.dataset.end_date

@property
def name(self) -> str | None:
def name(self) -> str:
"""Return the name of the dataset."""
return self.dataset.name
return self._name

@property
def space(self) -> DataSpace:
"""Return the data space for this dataset."""
return DataSpace(channels=self.dataset.shape[1], shape=self.dataset.field_shape)
return DataSpace(
channels=self.dataset.shape[1],
name=self.name,
shape=self.dataset.field_shape,
)

@property
def start_date(self) -> np.datetime64:
@@ -50,7 +70,18 @@ def __len__(self) -> int:
"""Return the total length of the dataset"""
return len(self.dataset)

def __getitem__(self, idx: int) -> NDArray[np.float32]:
"""Return a single timestep after reshaping to [C, H, W]"""
chw = (self.space.channels, *self.space.shape)
return self.dataset[idx].reshape(chw)
def __getitem__(self, idx: int) -> ArrayCHW:
"""Return the data for a single timestep in [C, H, W] format"""
return self.dataset[idx].reshape(self.space.chw)

@cachedmethod(lambda self: self._cache)
def index_from_date(self, date: np.datetime64) -> int:
"""Return the index of a given date in the dataset."""
idx, _, _ = self.dataset.to_index(date, 0)
return idx

def get_tchw(self, dates: Sequence[np.datetime64]) -> ArrayTCHW:
"""Return the data for a series of timesteps in [T, C, H, W] format"""
return np.stack(
[self[self.index_from_date(target_date)] for target_date in dates], axis=0
)
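A hedged usage sketch of the reworked ZebraDataset (the import path follows the file location above; the keyword arguments mirror the attributes set in __init__; the zarr path and dates are placeholders):

```python
import numpy as np

from ice_station_zebra.data.lightning.zebra_dataset import ZebraDataset

ds = ZebraDataset(
    name="osisaf-south",
    input_files=["/path/to/osisaf-south.zarr"],  # placeholder path
    start=None,
    end=None,
)
# Nothing is opened yet: the AnemoiDataset is lazy-loaded on first access
idx = ds.index_from_date(np.datetime64("2020-01-01T00:00:00"))  # LRU-cached
chw = ds[idx]        # one timestep in [C, H, W]
tchw = ds.get_tchw(  # several timesteps stacked into [T, C, H, W]
    [np.datetime64("2020-01-01T00:00:00"), np.datetime64("2020-01-02T00:00:00")]
)
```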
4 changes: 2 additions & 2 deletions ice_station_zebra/evaluation/callbacks/plotting_callback.py
@@ -58,8 +58,8 @@ def on_test_batch_end(
date_ = dataset.date_from_index(batch_size * batch_idx)

# Load the ground truth and prediction
np_ground_truth = outputs["target"].cpu().numpy()[0, 0, :, :]
np_prediction = outputs["output"].cpu().numpy()[0, 0, :, :]
np_ground_truth = outputs["target"].cpu().numpy()[0, 0, 0, :, :]
np_prediction = outputs["output"].cpu().numpy()[0, 0, 0, :, :]

# Create each requested plot
images = {
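Note: the extra leading index reflects the move to NTCHW batches earlier in this PR; with outputs of shape [N, T, C, H, W], the three fixed indices select the first batch element, the first forecast step and the first channel respectively.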