Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__pycache__
.coverage
.coverage*
.DS_Store
.ipynb_checkpoints/
.python-version
Expand Down
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ repos:
- id: mixed-line-ending
- id: name-tests-test
args: ["--pytest-test-first"]
exclude: ^tests/legacy_metrics.py
- id: trailing-whitespace

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.12.5"
hooks:
# Run the linter
- id: ruff
types_or: [python, pyi]
- id: ruff-check
args: ["--fix", "--show-fixes"]
pass_filenames: false
# Run the formatter
- id: ruff-format
pass_filenames: false
2 changes: 1 addition & 1 deletion ice_station_zebra/cli/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import typer
from hydra.core.utils import simple_stdout_log_config

from ice_station_zebra.data.anemoi import datasets_cli
from ice_station_zebra.data_processors import datasets_cli
from ice_station_zebra.evaluation import evaluation_cli
from ice_station_zebra.training import training_cli

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import numpy as np
from torch.utils.data import Dataset

from .zebra_dataset import ZebraDataset
from ice_station_zebra.types import ArrayTCHW

from .zebra_dataset import ZebraDataset


class CombinedDataset(Dataset):
def __init__(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from collections.abc import Sequence
from pathlib import Path

import numpy as np
from anemoi.datasets.data import open_dataset
Expand Down Expand Up @@ -74,14 +74,14 @@ def __getitem__(self, idx: int) -> ArrayCHW:
"""Return the data for a single timestep in [C, H, W] format"""
return self.dataset[idx].reshape(self.space.chw)

@cachedmethod(lambda self: self._cache)
def index_from_date(self, date: np.datetime64) -> int:
"""Return the index of a given date in the dataset."""
idx, _, _ = self.dataset.to_index(date, 0)
return idx

def get_tchw(self, dates: Sequence[np.datetime64]) -> ArrayTCHW:
"""Return the data for a series of timesteps in [T, C, H, W] format"""
return np.stack(
[self[self.index_from_date(target_date)] for target_date in dates], axis=0
)

@cachedmethod(lambda self: self._cache)
def index_from_date(self, date: np.datetime64) -> int:
"""Return the index of a given date in the dataset."""
idx, _, _ = self.dataset.to_index(date, 0)
return idx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ice_station_zebra.cli import hydra_adaptor

from .zebra_dataset_factory import ZebraDatasetFactory
from .zebra_data_processor_factory import ZebraDataProcessorFactory

# Create the typer app
datasets_cli = typer.Typer(help="Manage datasets")
Expand All @@ -17,7 +17,7 @@
@hydra_adaptor
def create(config: DictConfig) -> None:
"""Create all datasets"""
factory = ZebraDatasetFactory(config)
factory = ZebraDataProcessorFactory(config)
for dataset in factory.datasets:
log.info(f"Working on {dataset.name}")
dataset.create()
Expand All @@ -27,7 +27,7 @@ def create(config: DictConfig) -> None:
@hydra_adaptor
def inspect(config: DictConfig) -> None:
"""Inspect all datasets"""
factory = ZebraDatasetFactory(config)
factory = ZebraDataProcessorFactory(config)
for dataset in factory.datasets:
log.info(f"Working on {dataset.name}")
dataset.inspect()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from omegaconf import DictConfig, OmegaConf
from zarr.errors import PathNotFoundError

from ice_station_zebra.data.anemoi.preprocessors import IPreprocessor
from ice_station_zebra.types import AnemoiCreateArgs, AnemoiInspectArgs

from .preprocessors import IPreprocessor

log = logging.getLogger(__name__)


class ZebraDataset:
class ZebraDataProcessor:
def __init__(
self, name: str, config: DictConfig, cls_preprocessor: Type[IPreprocessor]
) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
from omegaconf import DictConfig

from .preprocessors import IceNetSICPreprocessor, NullPreprocessor
from .zebra_dataset import ZebraDataset
from .zebra_data_processor import ZebraDataProcessor


class ZebraDatasetFactory:
class ZebraDataProcessorFactory:
preprocessors = {
"None": NullPreprocessor,
"IceNetSIC": IceNetSICPreprocessor,
}

def __init__(self, config: DictConfig) -> None:
self.datasets: list[ZebraDataset] = []
self.datasets: list[ZebraDataProcessor] = []
for dataset_name in config["datasets"]:
cls_preprocessor = self.preprocessors[
config["datasets"][dataset_name]
.get("preprocessor", {})
.get("type", "None")
]
self.datasets.append(ZebraDataset(dataset_name, config, cls_preprocessor))
self.datasets.append(
ZebraDataProcessor(dataset_name, config, cls_preprocessor)
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from torch import Tensor
from torch.utils.data import DataLoader

from ice_station_zebra.data_loaders import CombinedDataset
from ice_station_zebra.visualisations import plot_sic_comparison
from ice_station_zebra.data.lightning import CombinedDataset

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion ice_station_zebra/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from lightning import Callback, Trainer
from omegaconf import DictConfig, OmegaConf

from ice_station_zebra.data.lightning import ZebraDataModule
from ice_station_zebra.data_loaders import ZebraDataModule
from ice_station_zebra.models import ZebraModel
from ice_station_zebra.utils import get_timestamp

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch.nn as nn

from ice_station_zebra.types import DataSpace, TensorNCHW, TensorNTCHW

from .base_decoder import BaseDecoder


Expand Down
2 changes: 1 addition & 1 deletion ice_station_zebra/models/encode_process_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import torch
from omegaconf import DictConfig

from ice_station_zebra.types import DataSpace, TensorNCHW, TensorNTCHW
from ice_station_zebra.models.encoders import BaseEncoder
from ice_station_zebra.types import DataSpace, TensorNCHW, TensorNTCHW

from .zebra_model import ZebraModel

Expand Down
1 change: 1 addition & 0 deletions ice_station_zebra/models/encoders/base_encoder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod

import torch.nn as nn

from ice_station_zebra.types import TensorNCHW, TensorNTCHW


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch.nn as nn

from ice_station_zebra.types import DataSpace, TensorNCHW, TensorNTCHW

from .base_encoder import BaseEncoder


Expand Down
1 change: 1 addition & 0 deletions ice_station_zebra/models/processors/null.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch.nn as nn

from ice_station_zebra.types import TensorNCHW


Expand Down
4 changes: 2 additions & 2 deletions ice_station_zebra/models/processors/unet.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import torch
import torch.nn as nn
from torch import Tensor
import torch

from ice_station_zebra.models.common.convblock import ConvBlock
from ice_station_zebra.models.common.bottleneckblock import BottleneckBlock
from ice_station_zebra.models.common.convblock import ConvBlock
from ice_station_zebra.models.common.upconvblock import UpconvBlock


Expand Down
2 changes: 1 addition & 1 deletion ice_station_zebra/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from lightning import Callback, Trainer
from omegaconf import DictConfig, OmegaConf

from ice_station_zebra.data.lightning import ZebraDataModule
from ice_station_zebra.data_loaders import ZebraDataModule
from ice_station_zebra.models import ZebraModel
from ice_station_zebra.utils import generate_run_name, get_wandb_logger

Expand Down
8 changes: 4 additions & 4 deletions ice_station_zebra/types.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from collections.abc import Sequence
from dataclasses import dataclass
from jaxtyping import Float
from numpy.typing import NDArray
from numpy import float32
from typing import Any, Self, TypedDict
from torch import Tensor

from jaxtyping import Float
from numpy import float32
from numpy.typing import NDArray
from omegaconf import DictConfig
from torch import Tensor

ArrayCHW = Float[NDArray[float32], "channels height width"]
ArrayTCHW = Float[NDArray[float32], "time channels height width"]
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ addopts = [

[tool.ruff]
line-length = 88
exclude = ["notebooks/*ipynb"]

[tool.ruff.lint]
extend-select = ["I"]

[tool.uv.build-backend]
module-root = ""
6 changes: 4 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import datetime
from pathlib import Path
from typing import Any

import pytest
import xarray as xr
from typing import Any
from omegaconf import DictConfig
from anemoi.datasets.commands.create import Create
from omegaconf import DictConfig

from ice_station_zebra.types import AnemoiCreateArgs


Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from pathlib import Path

import numpy as np
import pytest

from ice_station_zebra.data.lightning.zebra_dataset import ZebraDataset
from ice_station_zebra.data_loaders.zebra_dataset import ZebraDataset
from ice_station_zebra.types import DataSpace


Expand Down
Loading