Merged
28 commits
135e918
:tada: Initial testing framework
jemrobinson Jul 25, 2025
dccce04
:wrench: Add pre-commit linting and formatting
jemrobinson Jul 25, 2025
d544568
:wrench: Add a code testing CI job
jemrobinson Jul 25, 2025
2d291ce
:wrench: Run mypy in strict mode
jemrobinson Jul 28, 2025
24f745a
:truck: Move mypy verification from pre-commit to uv
jemrobinson Jul 28, 2025
852eed5
:wrench: Ignore external libraries when type-checking
jemrobinson Jul 28, 2025
93f2103
:arrow_up: Add mypy and pandas-stubs to test group
jemrobinson Jul 28, 2025
9d0dd3a
:rotating_light: Fix formatting issues
jemrobinson Jul 28, 2025
0323d3f
:bug: Fix loop variable shadowing global variable
jemrobinson Jul 28, 2025
0da5f5a
:alien: Ignore type issues for explicitly overridden hydra decorator
jemrobinson Jul 28, 2025
a3e0b89
:alien: Ignore assignment type errors for DictConfigs with known contents
jemrobinson Jul 28, 2025
d6f86a3
:alien: Add explicit method for constructing a DataSpace from a DictC…
jemrobinson Jul 28, 2025
dbc7c0c
:truck: Move LightningBatch to types
jemrobinson Jul 28, 2025
4e521f6
:bug: Fix type-checking paths when loading data from trainer test_dat…
jemrobinson Jul 28, 2025
dbab2b6
:label: Fix missing types where the default was ambiguous
jemrobinson Jul 28, 2025
0d0ec38
:art: Convert multi-valued metrics into single value before logging
jemrobinson Jul 28, 2025
be8f8d6
:art: Consistently overwrite config values with calculated values whe…
jemrobinson Jul 28, 2025
2d4d157
:bug: CLI checks need to account for coloured output text
jemrobinson Jul 29, 2025
26708a8
:art: Copy function docstring into wrapper docstring when applying hy…
jemrobinson Jul 29, 2025
36f357b
:white_check_mark: Add CLI --help tests
jemrobinson Jul 29, 2025
4d10da1
:truck: Refactored tests into classes and ensured that executable nam…
jemrobinson Jul 29, 2025
2811558
:wrench: Use uv standard naming for environments and build system
jemrobinson Jul 29, 2025
8b46966
:wrench: Add test coverage comment bot
jemrobinson Jul 29, 2025
2209806
:arrow_down: Note that Python 3.13 is currently incompatible with the…
jemrobinson Jul 29, 2025
3e23f67
:memo: Add notes about running tests and style checks
jemrobinson Jul 29, 2025
39db2bd
:art: Renamed function as per CoPilot suggestion
jemrobinson Jul 29, 2025
db78acc
:wrench: Update permissions for code-coverage commenting
jemrobinson Jul 29, 2025
8e89cfd
:wrench: Update coverage settings
jemrobinson Jul 29, 2025
6 changes: 3 additions & 3 deletions .github/ISSUE_TEMPLATE/regular-issue.yml
Original file line number Diff line number Diff line change
@@ -5,9 +5,9 @@ body:
- type: markdown
attributes:
value: |
Before you submit:
Before you submit:
- Set priority label ('P1', 'P2' or 'P3')
- Link a milestone (if applicable)
- Link a milestone (if applicable)
- Assign yourself if you're picking this up
- type: textarea
id: problem
@@ -39,4 +39,4 @@ body:
- Slack thread: https://...
- Docs/Notebooks: [Notebook X](https://...), [Spec Y](https://...)
validations:
required: false
required: false
26 changes: 26 additions & 0 deletions .github/workflows/code_style.yaml
@@ -0,0 +1,26 @@
---
name: Fix code style

# Run workflow on pushes to matching branches
on: # yamllint disable-line rule:truthy
push:
branches: [main]
pull_request:

jobs:
fix_code_style:
runs-on: ubuntu-latest
name: Format and lint code
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Run pre-commit
uses: pre-commit/action@v3.0.1
with:
extra_args: --hook-stage manual --all-files
58 changes: 58 additions & 0 deletions .github/workflows/test_code.yaml
@@ -0,0 +1,58 @@
---
name: Run tests

# Run workflow on pushes to matching branches
on: # yamllint disable-line rule:truthy
push:
branches: [main]
pull_request:

jobs:
run_tests:
runs-on: ubuntu-latest
name: Run Python tests
permissions:
# Gives the action the necessary permissions for publishing new
# comments in pull requests.
pull-requests: write
# Gives the action the necessary permissions for pushing data to the
# python-coverage-comment-action branch, and for editing existing
# comments (to avoid publishing multiple comments in the same PR)
contents: write
strategy:
fail-fast: false
matrix:
python-version: ["3.11", "3.12"]
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
version: "0.8.3"
- name: Run pytest
run: uv run --group dev pytest
- name: Run mypy
run: uv run --group dev mypy .
# For security reasons, PRs created from forks cannot generate PR comments directly
# (see https://securitylab.github.com/research/github-actions-preventing-pwn-requests/).
# Instead we need to trigger another workflow after this one completes.
- name: Generate test coverage comment
id: coverage_comment
uses: py-cov-action/python-coverage-comment-action@v3
with:
GITHUB_TOKEN: ${{ github.token }}
# Save the coverage comment for later use
# See https://github.com/py-cov-action/python-coverage-comment-action/blob/main/README.md
- name: Save coverage comment as an artifact
uses: actions/upload-artifact@v4
if: steps.coverage_comment.outputs.COMMENT_FILE_WRITTEN == 'true'
with:
name: python-coverage-comment-action
path: python-coverage-comment-action.txt
32 changes: 32 additions & 0 deletions .github/workflows/test_coverage.yaml
@@ -0,0 +1,32 @@
---
name: Post test coverage GitHub comment

# Run workflow after test_code has completed
on: # yamllint disable-line rule:truthy
workflow_run:
workflows: ["Run tests"]
types:
- completed

jobs:
coverage:
runs-on: ubuntu-latest
if: github.event.workflow_run.event == 'pull_request' && github.event.workflow_run.conclusion == 'success'
permissions:
# Gives the action the necessary permissions for publishing new
# comments in pull requests.
pull-requests: write
# Gives the action the necessary permissions for editing existing
# comments (to avoid publishing multiple comments in the same PR)
contents: write
# Gives the action the necessary permissions for looking up the
# workflow that launched this workflow, and download the related
# artifact that contains the comment to be published
actions: read
steps:
# Post the pre-generated coverage comment
- name: Post coverage comment
uses: py-cov-action/python-coverage-comment-action@v3
with:
GITHUB_TOKEN: ${{ github.token }}
GITHUB_PR_RUN_ID: ${{ github.event.workflow_run.id }}
4 changes: 3 additions & 1 deletion .gitignore
@@ -1,8 +1,10 @@
__pycache__
.coverage
.DS_Store
.ipynb_checkpoints/
.python-version
.venv
*local.yaml
dist
outputs
wandb/
wandb
30 changes: 30 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,30 @@
ci:
autoupdate_commit_msg: ":rotating_light: Fix pre-commit linting errors"
autofix_commit_msg: ":rotating_light: Fix pre-commit linting errors"

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: "v5.0.0"
hooks:
- id: check-added-large-files
- id: check-case-conflict
- id: check-merge-conflict
- id: check-symlinks
- id: check-yaml
- id: debug-statements
- id: end-of-file-fixer
- id: mixed-line-ending
- id: name-tests-test
args: ["--pytest-test-first"]
exclude: ^tests/legacy_metrics.py
- id: trailing-whitespace

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.12.5"
hooks:
# Run the linter
- id: ruff
types_or: [python, pyi]
args: ["--fix", "--show-fixes"]
# Run the formatter
- id: ruff-format
11 changes: 10 additions & 1 deletion CONTRIBUTING.md
@@ -13,4 +13,13 @@ Welcome! Thanks for your interest in contributing to this project.

## Coding conventions

TBD
- We use `pre-commit` to enforce code-style conventions.
- You can install `pre-commit` following the instructions [here](https://pre-commit.com/#install).
- Run `pre-commit install` inside your locally-checked-out repository to activate it.
- You can also run the style checks without installing `pre-commit` by running `uv run --group dev pre-commit run --all-files`

## Tests

We encourage the use of tests across the whole codebase.
Run the `pytest` tests with `uv run --group dev pytest`.
Run the `mypy` checks with `uv run --group dev mypy .`.
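As a rough illustration of the kind of test this section encourages, the sketch below checks CLI help text while ignoring terminal colour codes, which the commit list above mentions as a concern for the CLI checks. It is not taken from the project's test suite: `strip_colour`, `test_help_mentions_usage`, the sample help text and the `zebra` executable name are all hypothetical.

```python
import re

# Matches ANSI "Select Graphic Rendition" sequences such as "\x1b[1m" or "\x1b[0m"
ANSI_ESCAPE = re.compile(r"\x1b\[[0-9;]*m")


def strip_colour(text: str) -> str:
    """Remove ANSI colour/style escape sequences from text."""
    return ANSI_ESCAPE.sub("", text)


def test_help_mentions_usage() -> None:
    # Assumption: a coloured help message similar to what a rich CLI might print
    coloured_help = "\x1b[1mUsage:\x1b[0m zebra [OPTIONS] COMMAND"
    assert "Usage: zebra" in strip_colour(coloured_help)
```

Saved under a `test_*.py` name, a function like this would be collected by the `pytest` command above.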
6 changes: 6 additions & 0 deletions README.md
@@ -4,6 +4,12 @@ A pipeline for predicting sea ice.

## Setting up your environment

### Tools

You will need to install the following tools if you want to develop this project:

- [`uv`](https://docs.astral.sh/uv/getting-started/installation/)

### Creating your own configuration file

Create a file in `config` that is called `<your chosen name here>.local.yaml`.
5 changes: 3 additions & 2 deletions ice_station_zebra/cli/hydra.py
@@ -56,6 +56,7 @@ def wrapper(

# Since the additional parameters are keyword arguments we can simply append them
combined_parameters = list(itertools.chain(function_params, additional_params))
wrapper.__signature__ = fn_signature.replace(parameters=combined_parameters)
wrapper.__signature__ = fn_signature.replace(parameters=combined_parameters) # type: ignore[attr-defined]
wrapper.__name__ = function.__name__
return wrapper
wrapper.__doc__ = function.__doc__
return wrapper # type: ignore[return-value]
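The signature handling in this hunk can be illustrated with a standalone sketch. It assumes nothing about the real `hydra_adaptor` beyond the lines shown in the diff; `add_keyword_options`, `greet` and the extra `overrides` parameter are hypothetical names chosen for illustration.

```python
import inspect
import itertools
from typing import Any, Callable


def add_keyword_options(function: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical decorator that exposes one extra keyword-only parameter."""
    fn_signature = inspect.signature(function)
    function_params = fn_signature.parameters.values()
    # Assumption: a single keyword-only parameter; a real adaptor may add others
    additional_params = [
        inspect.Parameter(
            "overrides",
            kind=inspect.Parameter.KEYWORD_ONLY,
            default=(),
            annotation=tuple[str, ...],
        )
    ]

    def wrapper(*args: Any, overrides: tuple[str, ...] = (), **kwargs: Any) -> Any:
        # The extra parameter is consumed by the wrapper and not forwarded
        return function(*args, **kwargs)

    # Keyword-only parameters can be appended after the function's own parameters
    combined_parameters = list(itertools.chain(function_params, additional_params))
    wrapper.__signature__ = fn_signature.replace(parameters=combined_parameters)  # type: ignore[attr-defined]
    # Copy the name and docstring so that help text describes the wrapped function
    wrapper.__name__ = function.__name__
    wrapper.__doc__ = function.__doc__
    return wrapper


@add_keyword_options
def greet(name: str) -> str:
    """Return a greeting."""
    return f"Hello, {name}"
```

Tools that build help output from `inspect.signature` and `__doc__` would then see both the original parameter and the added one, along with the original docstring.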
2 changes: 1 addition & 1 deletion ice_station_zebra/config/evaluate/default.yaml
@@ -2,4 +2,4 @@ defaults:
- callbacks:
- metric_summary
- plotting
- _self_
- _self_
2 changes: 1 addition & 1 deletion ice_station_zebra/data/anemoi/zebra_dataset.py
@@ -41,7 +41,7 @@ def __init__(
self.path_dataset = _data_path / "anemoi" / f"{name}.zarr"
self.path_preprocessor = _data_path / "preprocessing"
# Note that Anemoi 'forcings' need to be escaped with `\${}` to avoid being resolved here
self.config = OmegaConf.to_container(config, resolve=True)["datasets"][name]
self.config: DictConfig = OmegaConf.to_object(config["datasets"][name]) # type: ignore[assignment]
self.preprocessor = cls_preprocessor(self.config)

def create(self) -> None:
2 changes: 2 additions & 0 deletions ice_station_zebra/data/lightning/__init__.py
@@ -1,5 +1,7 @@
from .combined_dataset import CombinedDataset
from .zebra_data_module import ZebraDataModule

__all__ = [
"CombinedDataset",
"ZebraDataModule",
]
4 changes: 2 additions & 2 deletions ice_station_zebra/data/lightning/combined_dataset.py
@@ -19,9 +19,9 @@ def __len__(self) -> int:
"""Return the total length of the dataset"""
return min([len(ds) for ds in self.datasets] + [len(self.target)])

def __getitem__(self, idx: int) -> NDArray[np.float32]:
def __getitem__(self, idx: int) -> tuple[NDArray[np.float32]]:
"""Return a single timestep"""
return tuple([ds[idx] for ds in self.datasets] + [self.target[idx]])
return tuple([ds[idx] for ds in self.datasets] + [self.target[idx]]) # type: ignore[return-value]

def date_from_index(self, idx: int) -> datetime:
"""Return the date of the timestep"""
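A note on the new return annotation: `tuple[X]` describes a tuple of exactly one element, whereas a tuple of arbitrary length is written `tuple[X, ...]`. The sketch below shows the variadic form using plain lists of floats rather than NumPy arrays so that it stays self-contained. `TinyCombinedDataset` is a hypothetical stand-in, not the project's class, and it has not been checked here whether the variadic form would let the `type: ignore` in the real code be dropped.

```python
from collections.abc import Sequence


class TinyCombinedDataset:
    """Hypothetical stand-in for a dataset that combines several inputs and a target."""

    def __init__(
        self, datasets: Sequence[Sequence[float]], target: Sequence[float]
    ) -> None:
        self.datasets = datasets
        self.target = target

    def __len__(self) -> int:
        # The shortest member bounds the usable length
        return min([len(ds) for ds in self.datasets] + [len(self.target)])

    def __getitem__(self, idx: int) -> tuple[float, ...]:
        # `tuple[float, ...]` means "a tuple holding any number of floats"
        return tuple([ds[idx] for ds in self.datasets] + [self.target[idx]])


combined = TinyCombinedDataset(
    datasets=[[1.0, 2.0, 3.0], [4.0, 5.0]], target=[6.0, 7.0, 8.0]
)
```

Here each item holds one value per input dataset followed by the target value, so the tuple length depends on how many datasets were combined.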
5 changes: 1 addition & 4 deletions ice_station_zebra/data/lightning/zebra_data_module.py
@@ -6,7 +6,7 @@
import numpy as np
from lightning import LightningDataModule
from numpy.typing import NDArray
from omegaconf import DictConfig, OmegaConf
from omegaconf import DictConfig
from torch.utils.data import DataLoader

from ice_station_zebra.types import DataloaderArgs, DataSpace
@@ -21,9 +21,6 @@ class ZebraDataModule(LightningDataModule):
def __init__(self, config: DictConfig) -> None:
super().__init__()

# Load the resolved config into Python format
config = OmegaConf.to_container(config, resolve=True)

# Load paths
self.base_path = Path(config["base_path"])

Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import statistics
from typing import Any

from lightning import LightningModule, Trainer
@@ -14,15 +15,15 @@ def __init__(self, average_loss: bool = True) -> None:
Args:
average_loss: Whether to log average loss
"""
self.metrics = {}
self.metrics: dict[str, list[float]] = {}
if average_loss:
self.metrics["average_loss"] = []

def on_test_batch_end(
self,
trainer: Trainer,
module: LightningModule,
outputs: dict[str, Tensor],
outputs: dict[str, Tensor], # type: ignore[override]
batch: Any,
batch_idx: int,
dataloader_idx: int = 0,
@@ -33,11 +34,12 @@ def on_test_batch_end(

def on_test_epoch_end(self, trainer: Trainer, module: LightningModule) -> None:
"""Called at the end of the test epoch."""
# Post-process metrics if needed
if "average_loss" in self.metrics:
losses = self.metrics["average_loss"]
self.metrics["average_loss"] = sum(losses) / len(losses)
# Post-process accumulated metrics into a single value
metrics_: dict[str, float] = {}
for name, values in self.metrics.items():
if name.startswith("average_"):
metrics_[name] = statistics.mean(values)

# Log metrics to each logger
for logger in trainer.loggers:
logger.log_metrics(self.metrics)
logger.log_metrics(metrics_)
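The reduction step in this hunk can be exercised without Lightning. The following is a minimal sketch under stated assumptions: `summarise_metrics` is a hypothetical helper name, and the guard against empty lists is an addition that is not present in the diff (`statistics.mean` raises `StatisticsError` on empty input).

```python
import statistics


def summarise_metrics(metrics: dict[str, list[float]]) -> dict[str, float]:
    """Reduce each accumulated `average_*` metric to a single value."""
    summary: dict[str, float] = {}
    for name, values in metrics.items():
        # Assumption: metrics with no recorded values are skipped
        if name.startswith("average_") and values:
            summary[name] = statistics.mean(values)
    return summary


# Per-batch values accumulated during a test run (made-up numbers)
accumulated = {"average_loss": [0.5, 0.25, 0.75], "raw_values": [1.0, 2.0]}
summary = summarise_metrics(accumulated)
```

Building a separate dictionary leaves the accumulated lists untouched, so their type stays `dict[str, list[float]]` instead of switching from list to float partway through, as happened when the old code overwrote `self.metrics["average_loss"]`.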
37 changes: 23 additions & 14 deletions ice_station_zebra/evaluation/callbacks/plotting_callback.py
@@ -4,8 +4,10 @@
from lightning import LightningModule, Trainer
from lightning.pytorch import Callback
from torch import Tensor
from torch.utils.data import DataLoader

from ice_station_zebra.visualisations import plot_sic_comparison
from ice_station_zebra.data.lightning import CombinedDataset

logger = logging.getLogger(__name__)

@@ -32,7 +34,7 @@ def on_test_batch_end(
self,
trainer: Trainer,
module: LightningModule,
outputs: dict[str, Tensor],
outputs: dict[str, Tensor], # type: ignore[override]
batch: Any,
batch_idx: int,
dataloader_idx: int = 0,
@@ -42,11 +44,18 @@ def on_test_batch_end(
if batch_idx % self.frequency == 0:
# Get date for this batch
batch_size = outputs["target"].shape[0]
try:
dataloader = trainer.test_dataloaders[dataloader_idx]
except TypeError:
dataloader = trainer.test_dataloaders
date_ = dataloader.dataset.date_from_index(batch_size * batch_idx)
test_dataloaders: DataLoader | list[DataLoader] | None = (
trainer.test_dataloaders
)
if test_dataloaders is None:
logger.debug("No test dataloaders found, skipping plotting.")
return
dataset: CombinedDataset = (
test_dataloaders[dataloader_idx]
if isinstance(test_dataloaders, list)
else test_dataloaders
).dataset # type: ignore[assignment]
date_ = dataset.date_from_index(batch_size * batch_idx)

# Load the ground truth and prediction
np_ground_truth = outputs["target"].cpu().numpy()[0, 0, :, :]
@@ -59,11 +68,11 @@ def on_test_batch_end(
}

# Log images to each logger
for logger in trainer.loggers:
try:
for key, image_list in images.items():
logger.log_image(key=key, images=image_list)
except AttributeError:
logger.debug(
f"Logger {logger.name} does not support logging images."
)
for lightning_logger in trainer.loggers:
for key, image_list in images.items():
if hasattr(lightning_logger, "log_image"):
lightning_logger.log_image(key=key, images=image_list)
else:
logger.debug(
f"Logger {lightning_logger.name} does not support logging images."
)
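The two patterns this diff moves to, explicit `isinstance`/`None` checks instead of catching `TypeError`, and `hasattr` instead of catching `AttributeError`, can be sketched without Lightning or PyTorch. All classes and functions below (`FakeLoader`, `TextOnlyLogger`, `ImageLogger`, `select_dataset`, `log_images`) are hypothetical stand-ins for illustration only.

```python
from __future__ import annotations

from typing import Any


class FakeLoader:
    """Hypothetical stand-in for a DataLoader; only `dataset` matters here."""

    def __init__(self, dataset: Any) -> None:
        self.dataset = dataset


class TextOnlyLogger:
    """Hypothetical logger without image support."""

    name = "text_only"


class ImageLogger:
    """Hypothetical logger that records calls to `log_image`."""

    name = "image"

    def __init__(self) -> None:
        self.logged: list[tuple[str, int]] = []

    def log_image(self, key: str, images: list[Any]) -> None:
        self.logged.append((key, len(images)))


def select_dataset(
    test_dataloaders: FakeLoader | list[FakeLoader] | None, dataloader_idx: int
) -> Any | None:
    """Return the dataset of the chosen loader, or None if there are no loaders."""
    if test_dataloaders is None:
        return None
    loader = (
        test_dataloaders[dataloader_idx]
        if isinstance(test_dataloaders, list)
        else test_dataloaders
    )
    return loader.dataset


def log_images(loggers: list[Any], images: dict[str, list[Any]]) -> list[str]:
    """Log images where supported and return the names of loggers that were skipped."""
    skipped: list[str] = []
    for lightning_logger in loggers:
        if not hasattr(lightning_logger, "log_image"):
            skipped.append(lightning_logger.name)
            continue
        for key, image_list in images.items():
            lightning_logger.log_image(key=key, images=image_list)
    return skipped
```

Checking capabilities up front keeps unrelated `TypeError` or `AttributeError` exceptions raised inside the called code from being silently swallowed, which the earlier `try`/`except` blocks could have done.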
2 changes: 1 addition & 1 deletion ice_station_zebra/evaluation/cli.py
@@ -15,7 +15,7 @@
log = logging.getLogger(__name__)


@evaluation_cli.command(help="Evaluate a model")
@evaluation_cli.command()
@hydra_adaptor
def evaluate(
config: DictConfig,
2 changes: 1 addition & 1 deletion ice_station_zebra/evaluation/evaluator.py
@@ -29,7 +29,7 @@ def __init__(self, config: DictConfig, checkpoint_path: Path) -> None:
try:
ckpt_config = OmegaConf.load(config_path)
logger.debug(f"Loaded checkpoint config from {ckpt_config}.")
config["model"]["_target_"] = ckpt_config["model"]["_target_"]
config["model"]["_target_"] = ckpt_config["model"]["_target_"] # type: ignore[index]
except (NotADirectoryError, FileNotFoundError):
msg = f"Could not find model configuration file at {config_path}."
logger.debug(msg)