diff --git a/.gitignore b/.gitignore index 8b26e14b..fc3c8009 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ __pycache__ -.coverage +.coverage* .DS_Store .ipynb_checkpoints/ .python-version diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 008a63ec..0c8f6c3a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,15 +16,15 @@ repos: - id: mixed-line-ending - id: name-tests-test args: ["--pytest-test-first"] - exclude: ^tests/legacy_metrics.py - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit rev: "v0.12.5" hooks: # Run the linter - - id: ruff - types_or: [python, pyi] + - id: ruff-check args: ["--fix", "--show-fixes"] + pass_filenames: false # Run the formatter - id: ruff-format + pass_filenames: false diff --git a/ice_station_zebra/cli/main.py b/ice_station_zebra/cli/main.py index 8f3b718c..6ed99b86 100644 --- a/ice_station_zebra/cli/main.py +++ b/ice_station_zebra/cli/main.py @@ -1,7 +1,7 @@ import typer from hydra.core.utils import simple_stdout_log_config -from ice_station_zebra.data.anemoi import datasets_cli +from ice_station_zebra.data_processors import datasets_cli from ice_station_zebra.evaluation import evaluation_cli from ice_station_zebra.training import training_cli diff --git a/ice_station_zebra/data/lightning/__init__.py b/ice_station_zebra/data_loaders/__init__.py similarity index 100% rename from ice_station_zebra/data/lightning/__init__.py rename to ice_station_zebra/data_loaders/__init__.py diff --git a/ice_station_zebra/data/lightning/combined_dataset.py b/ice_station_zebra/data_loaders/combined_dataset.py similarity index 99% rename from ice_station_zebra/data/lightning/combined_dataset.py rename to ice_station_zebra/data_loaders/combined_dataset.py index 894245ec..0c040d18 100644 --- a/ice_station_zebra/data/lightning/combined_dataset.py +++ b/ice_station_zebra/data_loaders/combined_dataset.py @@ -4,9 +4,10 @@ import numpy as np from torch.utils.data import Dataset -from .zebra_dataset import ZebraDataset from ice_station_zebra.types import ArrayTCHW +from .zebra_dataset import ZebraDataset + class CombinedDataset(Dataset): def __init__( diff --git a/ice_station_zebra/data/lightning/zebra_data_module.py b/ice_station_zebra/data_loaders/zebra_data_module.py similarity index 100% rename from ice_station_zebra/data/lightning/zebra_data_module.py rename to ice_station_zebra/data_loaders/zebra_data_module.py diff --git a/ice_station_zebra/data/lightning/zebra_dataset.py b/ice_station_zebra/data_loaders/zebra_dataset.py similarity index 100% rename from ice_station_zebra/data/lightning/zebra_dataset.py rename to ice_station_zebra/data_loaders/zebra_dataset.py index 7faf741e..231d7d03 100644 --- a/ice_station_zebra/data/lightning/zebra_dataset.py +++ b/ice_station_zebra/data_loaders/zebra_dataset.py @@ -1,5 +1,5 @@ -from pathlib import Path from collections.abc import Sequence +from pathlib import Path import numpy as np from anemoi.datasets.data import open_dataset @@ -74,14 +74,14 @@ def __getitem__(self, idx: int) -> ArrayCHW: """Return the data for a single timestep in [C, H, W] format""" return self.dataset[idx].reshape(self.space.chw) - @cachedmethod(lambda self: self._cache) - def index_from_date(self, date: np.datetime64) -> int: - """Return the index of a given date in the dataset.""" - idx, _, _ = self.dataset.to_index(date, 0) - return idx - def get_tchw(self, dates: Sequence[np.datetime64]) -> ArrayTCHW: """Return the data for a series of timesteps in [T, C, H, W] format""" return np.stack( [self[self.index_from_date(target_date)] for target_date in dates], axis=0 ) + + @cachedmethod(lambda self: self._cache) + def index_from_date(self, date: np.datetime64) -> int: + """Return the index of a given date in the dataset.""" + idx, _, _ = self.dataset.to_index(date, 0) + return idx diff --git a/ice_station_zebra/data/anemoi/__init__.py b/ice_station_zebra/data_processors/__init__.py similarity index 100% rename from ice_station_zebra/data/anemoi/__init__.py rename to ice_station_zebra/data_processors/__init__.py diff --git a/ice_station_zebra/data/anemoi/cli.py b/ice_station_zebra/data_processors/cli.py similarity index 81% rename from ice_station_zebra/data/anemoi/cli.py rename to ice_station_zebra/data_processors/cli.py index c4142290..808a17a0 100644 --- a/ice_station_zebra/data/anemoi/cli.py +++ b/ice_station_zebra/data_processors/cli.py @@ -5,7 +5,7 @@ from ice_station_zebra.cli import hydra_adaptor -from .zebra_dataset_factory import ZebraDatasetFactory +from .zebra_data_processor_factory import ZebraDataProcessorFactory # Create the typer app datasets_cli = typer.Typer(help="Manage datasets") @@ -17,7 +17,7 @@ @hydra_adaptor def create(config: DictConfig) -> None: """Create all datasets""" - factory = ZebraDatasetFactory(config) + factory = ZebraDataProcessorFactory(config) for dataset in factory.datasets: log.info(f"Working on {dataset.name}") dataset.create() @@ -27,7 +27,7 @@ def create(config: DictConfig) -> None: @hydra_adaptor def inspect(config: DictConfig) -> None: """Inspect all datasets""" - factory = ZebraDatasetFactory(config) + factory = ZebraDataProcessorFactory(config) for dataset in factory.datasets: log.info(f"Working on {dataset.name}") dataset.inspect() diff --git a/ice_station_zebra/data/anemoi/preprocessors/__init__.py b/ice_station_zebra/data_processors/preprocessors/__init__.py similarity index 100% rename from ice_station_zebra/data/anemoi/preprocessors/__init__.py rename to ice_station_zebra/data_processors/preprocessors/__init__.py diff --git a/ice_station_zebra/data/anemoi/preprocessors/base.py b/ice_station_zebra/data_processors/preprocessors/base.py similarity index 100% rename from ice_station_zebra/data/anemoi/preprocessors/base.py rename to ice_station_zebra/data_processors/preprocessors/base.py diff --git a/ice_station_zebra/data/anemoi/preprocessors/icenet_sic.py b/ice_station_zebra/data_processors/preprocessors/icenet_sic.py similarity index 100% rename from ice_station_zebra/data/anemoi/preprocessors/icenet_sic.py rename to ice_station_zebra/data_processors/preprocessors/icenet_sic.py diff --git a/ice_station_zebra/data/anemoi/zebra_dataset.py b/ice_station_zebra/data_processors/zebra_data_processor.py similarity index 96% rename from ice_station_zebra/data/anemoi/zebra_dataset.py rename to ice_station_zebra/data_processors/zebra_data_processor.py index ba060e10..cdbf99a5 100644 --- a/ice_station_zebra/data/anemoi/zebra_dataset.py +++ b/ice_station_zebra/data_processors/zebra_data_processor.py @@ -8,13 +8,14 @@ from omegaconf import DictConfig, OmegaConf from zarr.errors import PathNotFoundError -from ice_station_zebra.data.anemoi.preprocessors import IPreprocessor from ice_station_zebra.types import AnemoiCreateArgs, AnemoiInspectArgs +from .preprocessors import IPreprocessor + log = logging.getLogger(__name__) -class ZebraDataset: +class ZebraDataProcessor: def __init__( self, name: str, config: DictConfig, cls_preprocessor: Type[IPreprocessor] ) -> None: diff --git a/ice_station_zebra/data/anemoi/zebra_dataset_factory.py b/ice_station_zebra/data_processors/zebra_data_processor_factory.py similarity index 65% rename from ice_station_zebra/data/anemoi/zebra_dataset_factory.py rename to ice_station_zebra/data_processors/zebra_data_processor_factory.py index 642740b1..ee213194 100644 --- a/ice_station_zebra/data/anemoi/zebra_dataset_factory.py +++ b/ice_station_zebra/data_processors/zebra_data_processor_factory.py @@ -1,21 +1,23 @@ from omegaconf import DictConfig from .preprocessors import IceNetSICPreprocessor, NullPreprocessor -from .zebra_dataset import ZebraDataset +from .zebra_data_processor import ZebraDataProcessor -class ZebraDatasetFactory: +class ZebraDataProcessorFactory: preprocessors = { "None": NullPreprocessor, "IceNetSIC": IceNetSICPreprocessor, } def __init__(self, config: DictConfig) -> None: - self.datasets: list[ZebraDataset] = [] + self.datasets: list[ZebraDataProcessor] = [] for dataset_name in config["datasets"]: cls_preprocessor = self.preprocessors[ config["datasets"][dataset_name] .get("preprocessor", {}) .get("type", "None") ] - self.datasets.append(ZebraDataset(dataset_name, config, cls_preprocessor)) + self.datasets.append( + ZebraDataProcessor(dataset_name, config, cls_preprocessor) + ) diff --git a/ice_station_zebra/evaluation/callbacks/plotting_callback.py b/ice_station_zebra/evaluation/callbacks/plotting_callback.py index 527408c0..c0ec8b8a 100644 --- a/ice_station_zebra/evaluation/callbacks/plotting_callback.py +++ b/ice_station_zebra/evaluation/callbacks/plotting_callback.py @@ -6,8 +6,8 @@ from torch import Tensor from torch.utils.data import DataLoader +from ice_station_zebra.data_loaders import CombinedDataset from ice_station_zebra.visualisations import plot_sic_comparison -from ice_station_zebra.data.lightning import CombinedDataset logger = logging.getLogger(__name__) diff --git a/ice_station_zebra/evaluation/evaluator.py b/ice_station_zebra/evaluation/evaluator.py index 88a26f51..1b40f11e 100644 --- a/ice_station_zebra/evaluation/evaluator.py +++ b/ice_station_zebra/evaluation/evaluator.py @@ -5,7 +5,7 @@ from lightning import Callback, Trainer from omegaconf import DictConfig, OmegaConf -from ice_station_zebra.data.lightning import ZebraDataModule +from ice_station_zebra.data_loaders import ZebraDataModule from ice_station_zebra.models import ZebraModel from ice_station_zebra.utils import get_timestamp diff --git a/ice_station_zebra/models/decoders/naive_latent_space_decoder.py b/ice_station_zebra/models/decoders/naive_latent_space_decoder.py index ba734935..f18393e0 100644 --- a/ice_station_zebra/models/decoders/naive_latent_space_decoder.py +++ b/ice_station_zebra/models/decoders/naive_latent_space_decoder.py @@ -4,6 +4,7 @@ import torch.nn as nn from ice_station_zebra.types import DataSpace, TensorNCHW, TensorNTCHW + from .base_decoder import BaseDecoder diff --git a/ice_station_zebra/models/encode_process_decode.py b/ice_station_zebra/models/encode_process_decode.py index a3441db0..f9dde6b4 100644 --- a/ice_station_zebra/models/encode_process_decode.py +++ b/ice_station_zebra/models/encode_process_decode.py @@ -4,8 +4,8 @@ import torch from omegaconf import DictConfig -from ice_station_zebra.types import DataSpace, TensorNCHW, TensorNTCHW from ice_station_zebra.models.encoders import BaseEncoder +from ice_station_zebra.types import DataSpace, TensorNCHW, TensorNTCHW from .zebra_model import ZebraModel diff --git a/ice_station_zebra/models/encoders/base_encoder.py b/ice_station_zebra/models/encoders/base_encoder.py index e0f82702..b3c0e5f9 100644 --- a/ice_station_zebra/models/encoders/base_encoder.py +++ b/ice_station_zebra/models/encoders/base_encoder.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod import torch.nn as nn + from ice_station_zebra.types import TensorNCHW, TensorNTCHW diff --git a/ice_station_zebra/models/encoders/naive_latent_space_encoder.py b/ice_station_zebra/models/encoders/naive_latent_space_encoder.py index 1a2d6abc..3458d3e9 100644 --- a/ice_station_zebra/models/encoders/naive_latent_space_encoder.py +++ b/ice_station_zebra/models/encoders/naive_latent_space_encoder.py @@ -4,6 +4,7 @@ import torch.nn as nn from ice_station_zebra.types import DataSpace, TensorNCHW, TensorNTCHW + from .base_encoder import BaseEncoder diff --git a/ice_station_zebra/models/processors/null.py b/ice_station_zebra/models/processors/null.py index 39be27d4..410daf80 100644 --- a/ice_station_zebra/models/processors/null.py +++ b/ice_station_zebra/models/processors/null.py @@ -1,4 +1,5 @@ import torch.nn as nn + from ice_station_zebra.types import TensorNCHW diff --git a/ice_station_zebra/models/processors/unet.py b/ice_station_zebra/models/processors/unet.py index 69845c62..5d14be9d 100644 --- a/ice_station_zebra/models/processors/unet.py +++ b/ice_station_zebra/models/processors/unet.py @@ -1,9 +1,9 @@ +import torch import torch.nn as nn from torch import Tensor -import torch -from ice_station_zebra.models.common.convblock import ConvBlock from ice_station_zebra.models.common.bottleneckblock import BottleneckBlock +from ice_station_zebra.models.common.convblock import ConvBlock from ice_station_zebra.models.common.upconvblock import UpconvBlock diff --git a/ice_station_zebra/training/trainer.py b/ice_station_zebra/training/trainer.py index ba415763..2e65e5e0 100644 --- a/ice_station_zebra/training/trainer.py +++ b/ice_station_zebra/training/trainer.py @@ -5,7 +5,7 @@ from lightning import Callback, Trainer from omegaconf import DictConfig, OmegaConf -from ice_station_zebra.data.lightning import ZebraDataModule +from ice_station_zebra.data_loaders import ZebraDataModule from ice_station_zebra.models import ZebraModel from ice_station_zebra.utils import generate_run_name, get_wandb_logger diff --git a/ice_station_zebra/types.py b/ice_station_zebra/types.py index f42aeede..cfaa81e4 100644 --- a/ice_station_zebra/types.py +++ b/ice_station_zebra/types.py @@ -1,12 +1,12 @@ from collections.abc import Sequence from dataclasses import dataclass -from jaxtyping import Float -from numpy.typing import NDArray -from numpy import float32 from typing import Any, Self, TypedDict -from torch import Tensor +from jaxtyping import Float +from numpy import float32 +from numpy.typing import NDArray from omegaconf import DictConfig +from torch import Tensor ArrayCHW = Float[NDArray[float32], "channels height width"] ArrayTCHW = Float[NDArray[float32], "time channels height width"] diff --git a/pyproject.toml b/pyproject.toml index fdf5d4fb..015ab7fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,10 @@ addopts = [ [tool.ruff] line-length = 88 +exclude = ["notebooks/*ipynb"] + +[tool.ruff.lint] +extend-select = ["I"] [tool.uv.build-backend] module-root = "" diff --git a/tests/conftest.py b/tests/conftest.py index e6f011a2..81db733a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,12 @@ import datetime from pathlib import Path +from typing import Any + import pytest import xarray as xr -from typing import Any -from omegaconf import DictConfig from anemoi.datasets.commands.create import Create +from omegaconf import DictConfig + from ice_station_zebra.types import AnemoiCreateArgs diff --git a/tests/data/lightning/test_zebra_dataset.py b/tests/data_loaders/test_zebra_dataset.py similarity index 98% rename from tests/data/lightning/test_zebra_dataset.py rename to tests/data_loaders/test_zebra_dataset.py index 039f1b60..54fa8f2d 100644 --- a/tests/data/lightning/test_zebra_dataset.py +++ b/tests/data_loaders/test_zebra_dataset.py @@ -1,8 +1,9 @@ from pathlib import Path + import numpy as np import pytest -from ice_station_zebra.data.lightning.zebra_dataset import ZebraDataset +from ice_station_zebra.data_loaders.zebra_dataset import ZebraDataset from ice_station_zebra.types import DataSpace