
Commit 38b2202

➕ Use jaxtyping to annotate types of tensor and array being used in different functions

1 parent 3717f84

13 files changed: 97 additions & 59 deletions
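Every file in this commit imports its new aliases from ice_station_zebra.types; the diff for that module is not shown in this excerpt. As a rough sketch of what jaxtyping-based aliases of this kind can look like (the alias names come from the diffs below, the axis labels are assumptions):

```python
# Hypothetical sketch of the aliases used throughout this commit; the
# actual ice_station_zebra/types.py diff is not shown in this excerpt.
import numpy as np
from jaxtyping import Float
from torch import Tensor

# Floating-point NumPy arrays with named axes
ArrayCHW = Float[np.ndarray, "channels height width"]
ArrayTCHW = Float[np.ndarray, "time channels height width"]

# Floating-point torch Tensors with a leading batch axis
TensorNCHW = Float[Tensor, "batch channels height width"]
TensorNTCHW = Float[Tensor, "batch time channels height width"]
```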

ice_station_zebra/data/lightning/combined_dataset.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 from torch.utils.data import Dataset

 from .zebra_dataset import ZebraDataset
-from ice_station_zebra.types import CombinedNumpyBatch
+from ice_station_zebra.types import ArrayTCHW


 class CombinedDataset(Dataset):
@@ -59,7 +59,7 @@ def __len__(self) -> int:
         """Return the total length of the dataset"""
         return len(self.available_dates)

-    def __getitem__(self, idx: int) -> CombinedNumpyBatch:
+    def __getitem__(self, idx: int) -> dict[str, ArrayTCHW]:
         """Return the data for a single timestep as a dictionary

         Returns:
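For illustration only, a sample from __getitem__ maps each dataset name to a float32 array in [T, C, H, W] layout; the key names and shapes below are invented:

```python
import numpy as np

# Invented keys and shapes, purely to show the [T, C, H, W] layout
sample: dict[str, np.ndarray] = {
    "era5": np.zeros((2, 10, 128, 128), dtype=np.float32),
    "target": np.zeros((1, 1, 128, 128), dtype=np.float32),
}
for array in sample.values():
    assert array.ndim == 4  # time, channels, height, width
```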

ice_station_zebra/data/lightning/zebra_data_module.py

Lines changed: 4 additions & 6 deletions
@@ -3,13 +3,11 @@
 from functools import cached_property
 from pathlib import Path

-import numpy as np
 from lightning import LightningDataModule
-from numpy.typing import NDArray
 from omegaconf import DictConfig
 from torch.utils.data import DataLoader

-from ice_station_zebra.types import DataloaderArgs, DataSpace
+from ice_station_zebra.types import ArrayTCHW, DataloaderArgs, DataSpace

 from .combined_dataset import CombinedDataset
 from .zebra_dataset import ZebraDataset
@@ -89,7 +87,7 @@ def output_space(self) -> DataSpace:

     def train_dataloader(
         self,
-    ) -> DataLoader[tuple[NDArray[np.float32], NDArray[np.float32]]]:
+    ) -> DataLoader[dict[str, ArrayTCHW]]:
         """Construct train dataloader"""
         dataset = CombinedDataset(
             [
@@ -115,7 +113,7 @@ def train_dataloader(

     def val_dataloader(
         self,
-    ) -> DataLoader[tuple[NDArray[np.float32], NDArray[np.float32]]]:
+    ) -> DataLoader[dict[str, ArrayTCHW]]:
         """Construct validation dataloader"""
         dataset = CombinedDataset(
             [
@@ -141,7 +139,7 @@ def val_dataloader(

     def test_dataloader(
         self,
-    ) -> DataLoader[tuple[NDArray[np.float32], NDArray[np.float32]]]:
+    ) -> DataLoader[dict[str, ArrayTCHW]]:
         """Construct test dataloader"""
         dataset = CombinedDataset(
             [
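This element type plays well with PyTorch's default collation: a list of dict samples holding [T, C, H, W] arrays collates into a dict of [N, T, C, H, W] tensors, which matches the dict[str, TensorNTCHW] inputs the models below accept. A minimal sketch:

```python
import numpy as np
from torch.utils.data import default_collate

# Four samples, each a dict of [T, C, H, W] float32 arrays
samples = [{"x": np.zeros((2, 3, 8, 8), dtype=np.float32)} for _ in range(4)]
batch = default_collate(samples)
assert batch["x"].shape == (4, 2, 3, 8, 8)  # a [N, T, C, H, W] torch.Tensor
```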

ice_station_zebra/data/lightning/zebra_dataset.py

Lines changed: 6 additions & 6 deletions
@@ -1,12 +1,12 @@
 from pathlib import Path
+from collections.abc import Sequence

 import numpy as np
 from anemoi.datasets.data import open_dataset
 from cachetools import LRUCache, cachedmethod
-from numpy.typing import NDArray
 from torch.utils.data import Dataset

-from ice_station_zebra.types import DataSpace
+from ice_station_zebra.types import ArrayCHW, ArrayTCHW, DataSpace


 class ZebraDataset(Dataset):
@@ -57,8 +57,8 @@ def __len__(self) -> int:
         """Return the total length of the dataset"""
         return len(self.dataset)

-    def __getitem__(self, idx: int) -> NDArray[np.float32]:
-        """Return a single timestep after reshaping to [C, H, W]"""
+    def __getitem__(self, idx: int) -> ArrayCHW:
+        """Return the data for a single timestep in [C, H, W] format"""
         return self.dataset[idx].reshape(self.chw)

     @cachedmethod(lambda self: self._cache)
@@ -67,8 +67,8 @@ def index_from_date(self, date: np.datetime64) -> int:
         idx, _, _ = self.dataset.to_index(date, 0)
         return idx

-    def get_tchw(self, dates: list[np.datetime64]) -> NDArray[np.float32]:
-        """Return the data for a given set of dates in [T, C, H, W] format"""
+    def get_tchw(self, dates: Sequence[np.datetime64]) -> ArrayTCHW:
+        """Return the data for a series of timesteps in [T, C, H, W] format"""
         return np.stack(
             [self[self.index_from_date(target_date)] for target_date in dates], axis=0
         )
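get_tchw builds the time axis with np.stack: stacking T arrays of shape [C, H, W] along a new leading axis yields a [T, C, H, W] array. For example:

```python
import numpy as np

# Three [C, H, W] timesteps stacked along a new leading time axis
frames = [np.zeros((5, 64, 64), dtype=np.float32) for _ in range(3)]
tchw = np.stack(frames, axis=0)
assert tchw.shape == (3, 5, 64, 64)  # [T, C, H, W]
```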
ice_station_zebra/models/decoders/base_decoder.py

Lines changed: 7 additions & 6 deletions
@@ -1,32 +1,33 @@
 from abc import ABC, abstractmethod

 import torch.nn as nn
-from torch import Tensor
+
+from ice_station_zebra.types import TensorNCHW, TensorNTCHW


 class BaseDecoder(nn.Module, ABC):
     """
     Decoder that takes data in a latent space and translates it to a larger output space

     Latent space:
-        Tensor[NCHW] with (batch_size, latent_channels, latent_height, latent_width)
+        TensorNCHW with (batch_size, latent_channels, latent_height, latent_width)

     Output space:
-        Tensor[NTCHW] with (batch_size, n_forecast_steps, output_channels, output_height, output_width)
+        TensorNTCHW with (batch_size, n_forecast_steps, output_channels, output_height, output_width)
     """

     def __init__(self, *, n_forecast_steps: int) -> None:
         super().__init__()
         self.n_forecast_steps = n_forecast_steps

     @abstractmethod
-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: TensorNCHW) -> TensorNTCHW:
         """
         Transformation summary

         Args:
-            x: Tensor[NCHW] with (batch_size, latent_channels, latent_height, latent_width)
+            x: TensorNCHW with (batch_size, latent_channels, latent_height, latent_width)

         Returns:
-            Tensor[NTCHW] with (batch_size, n_forecast_steps, output_channels, output_height, output_width)
+            TensorNTCHW with (batch_size, n_forecast_steps, output_channels, output_height, output_width)
         """

ice_station_zebra/models/decoders/naive_latent_space_decoder.py

Lines changed: 6 additions & 7 deletions
@@ -2,9 +2,8 @@
 from typing import Any

 import torch.nn as nn
-from torch import Tensor

-from ice_station_zebra.types import DataSpace
+from ice_station_zebra.types import DataSpace, TensorNCHW, TensorNTCHW
 from .base_decoder import BaseDecoder


@@ -13,10 +12,10 @@ class NaiveLatentSpaceDecoder(BaseDecoder):
     Naive, linear decoder that takes data in a latent space and translates it to a larger output space

     Latent space:
-        Tensor[NCHW] with (batch_size, latent_channels, latent_height, latent_width)
+        TensorNCHW with (batch_size, latent_channels, latent_height, latent_width)

     Output space:
-        Tensor[NTCHW] with (batch_size, n_forecast_steps, output_channels, output_height, output_width)
+        TensorNTCHW with (batch_size, n_forecast_steps, output_channels, output_height, output_width)
     """

     def __init__(
@@ -54,14 +53,14 @@ def __init__(
         # Combine the layers sequentially
         self.model = nn.Sequential(*layers)

-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: TensorNCHW) -> TensorNTCHW:
         """
         Transformation summary

         Args:
-            x: Tensor[NCHW] with (batch_size, latent_channels, latent_height, latent_width)
+            x: TensorNCHW with (batch_size, latent_channels, latent_height, latent_width)

         Returns:
-            Tensor[NTCHW] with (batch_size, n_forecast_steps, output_channels, output_height, output_width)
+            TensorNTCHW with (batch_size, n_forecast_steps, output_channels, output_height, output_width)
         """
         return self.model(x)
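The layers behind self.model are defined in an __init__ hunk that is mostly elided here. As an illustration of the NCHW-to-NTCHW mapping only (not the actual NaiveLatentSpaceDecoder), channels can be widened and then split into separate time and channel axes:

```python
import torch
import torch.nn as nn

# Illustrative only: widen channels with a 1x1 conv, then split the
# channel axis into (n_forecast_steps, output_channels)
n, c_latent, h, w = 2, 8, 16, 16
n_forecast_steps, c_out = 3, 5

expand = nn.Conv2d(c_latent, n_forecast_steps * c_out, kernel_size=1)
x = torch.randn(n, c_latent, h, w)                    # NCHW
y = expand(x).view(n, n_forecast_steps, c_out, h, w)  # NTCHW
assert y.shape == (n, 3, 5, h, w)
```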

ice_station_zebra/models/encode_process_decode.py

Lines changed: 6 additions & 7 deletions
@@ -3,9 +3,8 @@
 import hydra
 import torch
 from omegaconf import DictConfig
-from torch import Tensor

-from ice_station_zebra.types import CombinedTensorBatch, DataSpace
+from ice_station_zebra.types import DataSpace, TensorNCHW, TensorNTCHW
 from ice_station_zebra.models.encoders import BaseEncoder

 from .zebra_model import ZebraModel
@@ -62,7 +61,7 @@ def __init__(
         self.model_list.append(self.processor)
         self.model_list.append(self.decoder)

-    def forward(self, inputs: CombinedTensorBatch) -> torch.Tensor:
+    def forward(self, inputs: dict[str, TensorNTCHW]) -> TensorNTCHW:
         """Forward step of the model

         - start with multiple [NTCHW] inputs each with shape [batch, n_history_steps, C_input_k, H_input_k, W_input_k]
@@ -72,18 +71,18 @@ def forward(self, inputs: CombinedTensorBatch) -> torch.Tensor:
         - decode back to [NTCHW] output space [batch, n_forecast_steps, C_output, H_output, W_output]
         """
         # Encode inputs into latent space: list of tensors with (batch_size, variables, latent_height, latent_width)
-        latent_inputs: list[Tensor] = [
+        latent_inputs: list[TensorNCHW] = [
             encoder(inputs[encoder.name]) for encoder in self.encoders
         ]

         # Combine in the variable dimension: tensor with (batch_size, all_variables, latent_height, latent_width)
-        latent_input_combined = torch.cat(latent_inputs, dim=1)
+        latent_input_combined: TensorNCHW = torch.cat(latent_inputs, dim=1)

         # Process in latent space: tensor with (batch_size, all_variables, latent_height, latent_width)
-        latent_output: Tensor = self.processor(latent_input_combined)
+        latent_output: TensorNCHW = self.processor(latent_input_combined)

         # Decode to output space: tensor with (batch_size, output_variables, output_height, output_width)
-        output: Tensor = self.decoder(latent_output)
+        output: TensorNTCHW = self.decoder(latent_output)

         # Return
         return output
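The combine step hinges on torch.cat along dim=1: every encoder must emit the same batch and latent spatial sizes, while channel counts are free to differ. For example:

```python
import torch

# Two latent encodings: same batch and spatial dims, different channels
a = torch.randn(4, 8, 16, 16)   # (batch, C_a, latent_H, latent_W)
b = torch.randn(4, 12, 16, 16)  # (batch, C_b, latent_H, latent_W)
combined = torch.cat([a, b], dim=1)
assert combined.shape == (4, 20, 16, 16)  # channels are concatenated
```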
ice_station_zebra/models/encoders/base_encoder.py

Lines changed: 6 additions & 6 deletions
@@ -1,18 +1,18 @@
 from abc import ABC, abstractmethod

 import torch.nn as nn
-from torch import Tensor
+from ice_station_zebra.types import TensorNCHW, TensorNTCHW


 class BaseEncoder(nn.Module, ABC):
     """
     Encoder that takes data in an input space and translates it to a smaller latent space

     Input space:
-        Tensor[NTCHW] with (batch_size, n_history_steps, input_channels, input_height, input_width)
+        TensorNTCHW with (batch_size, n_history_steps, input_channels, input_height, input_width)

     Latent space:
-        Tensor[NCHW] with (batch_size, latent_channels, latent_height, latent_width)
+        TensorNCHW with (batch_size, latent_channels, latent_height, latent_width)
     """

     def __init__(self, *, name: str, n_history_steps: int) -> None:
@@ -21,13 +21,13 @@ def __init__(self, *, name: str, n_history_steps: int) -> None:
         self.n_history_steps = n_history_steps

     @abstractmethod
-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: TensorNTCHW) -> TensorNCHW:
         """
         Transformation summary

         Args:
-            x: Tensor[NTCHW] with (batch_size, n_history_steps, input_channels, input_height, input_width)
+            x: TensorNTCHW with (batch_size, n_history_steps, input_channels, input_height, input_width)

         Returns:
-            Tensor[NCHW] with (batch_size, latent_channels, latent_height, latent_width)
+            TensorNCHW with (batch_size, latent_channels, latent_height, latent_width)
         """

ice_station_zebra/models/encoders/naive_latent_space_encoder.py

Lines changed: 6 additions & 7 deletions
@@ -2,9 +2,8 @@
 from typing import Any

 import torch.nn as nn
-from torch import Tensor

-from ice_station_zebra.types import DataSpace
+from ice_station_zebra.types import DataSpace, TensorNCHW, TensorNTCHW
 from .base_encoder import BaseEncoder


@@ -13,10 +12,10 @@ class NaiveLatentSpaceEncoder(BaseEncoder):
     Naive, linear encoder that takes data in an input space and translates it to a smaller latent space

     Input space:
-        Tensor[NTCHW] with (batch_size, n_history_steps, input_channels, input_height, input_width)
+        TensorNTCHW with (batch_size, n_history_steps, input_channels, input_height, input_width)

     Latent space:
-        Tensor[NCHW] with (batch_size, latent_channels, latent_height, latent_width)
+        TensorNCHW with (batch_size, latent_channels, latent_height, latent_width)
     """

     def __init__(
@@ -52,14 +51,14 @@ def __init__(
         # Combine the layers sequentially
         self.model = nn.Sequential(*layers)

-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: TensorNTCHW) -> TensorNCHW:
         """
         Transformation summary

         Args:
-            x: Tensor[NTCHW] with (batch_size, n_history_steps, input_channels, input_height, input_width)
+            x: TensorNTCHW with (batch_size, n_history_steps, input_channels, input_height, input_width)

         Returns:
-            Tensor[NCHW] with (batch_size, latent_channels, latent_height, latent_width)
+            TensorNCHW with (batch_size, latent_channels, latent_height, latent_width)
         """
         return self.model(x)
Lines changed: 16 additions & 3 deletions
@@ -1,14 +1,27 @@
 import torch.nn as nn
-from torch import Tensor
+from ice_station_zebra.types import TensorNCHW


 class NullProcessor(nn.Module):
-    """Null model that simply returns input"""
+    """Null model that simply returns input
+
+    Operations all occur in latent space:
+        TensorNCHW with (batch_size, latent_channels, latent_height, latent_width)
+    """

     def __init__(self, n_latent_channels: int) -> None:
         super().__init__()
         self.n_latent_channels = n_latent_channels
         self.model = nn.Identity()

-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: TensorNCHW) -> TensorNCHW:
+        """
+        Transformation summary
+
+        Args:
+            x: TensorNCHW with (batch_size, latent_channels, latent_height, latent_width)
+
+        Returns:
+            TensorNCHW with (batch_size, latent_channels, latent_height, latent_width)
+        """
         return self.model(x)

ice_station_zebra/models/zebra_model.py

Lines changed: 8 additions & 6 deletions
@@ -7,7 +7,7 @@
 from omegaconf import DictConfig
 from torch.optim import Optimizer

-from ice_station_zebra.types import CombinedTensorBatch, DataSpace
+from ice_station_zebra.types import DataSpace, TensorNTCHW


 class ZebraModel(LightningModule, ABC):
@@ -46,7 +46,7 @@ def __init__(
         self.save_hyperparameters()

     @abstractmethod
-    def forward(self, inputs: CombinedTensorBatch) -> torch.Tensor:
+    def forward(self, inputs: dict[str, TensorNTCHW]) -> TensorNTCHW:
         """Forward step of the model

         - start with multiple [NTCHW] inputs each with shape [batch, n_history_steps, C_input_k, H_input_k, W_input_k]
@@ -62,11 +62,11 @@ def configure_optimizers(self) -> Optimizer:
             dict(**self.optimizer_cfg) | {"params": self.model_list.parameters()}
         )

-    def loss(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    def loss(self, output: TensorNTCHW, target: TensorNTCHW) -> torch.Tensor:
         return torch.nn.functional.l1_loss(output, target)

     def test_step(
-        self, batch: CombinedTensorBatch, batch_idx: int
+        self, batch: dict[str, TensorNTCHW], batch_idx: int
     ) -> dict[str, torch.Tensor]:
         """Run the test step, in PyTorch eval model (i.e. no gradients)

@@ -82,7 +82,9 @@ def test_step(
         loss = self.loss(output, target)
         return {"output": output, "target": target, "loss": loss}

-    def training_step(self, batch: CombinedTensorBatch, batch_idx: int) -> torch.Tensor:
+    def training_step(
+        self, batch: dict[str, TensorNTCHW], batch_idx: int
+    ) -> torch.Tensor:
         """Run the training step

         A batch contains one tensor for each input dataset and one for the target
@@ -97,7 +99,7 @@ def training_step(self, batch: CombinedTensorBatch, batch_idx: int) -> torch.Ten
         return self.loss(output, target)

     def validation_step(
-        self, batch: CombinedTensorBatch, batch_idx: int
+        self, batch: dict[str, TensorNTCHW], batch_idx: int
     ) -> torch.Tensor:
         """Run the validation step
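The loss defined here is plain mean absolute error. For example:

```python
import torch

output = torch.tensor([[1.0, 2.0]])
target = torch.tensor([[1.5, 1.0]])
loss = torch.nn.functional.l1_loss(output, target)
assert torch.isclose(loss, torch.tensor(0.75))  # mean of |0.5| and |1.0|
```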
