Commit 1765886

Merge pull request #37 from alan-turing-institute/34-fno
Update `EncoderProcessorDecoder` and add FNO (#34)
2 parents cd9f8c2 + 16d6333 · commit 1765886

17 files changed: 1795 additions & 45 deletions


notebooks/00_exploration.ipynb

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "## AutoCast encoder-processor-decoder model API exploration\n",
    "\n",
    "This notebook aims to explore the end-to-end API.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {},
   "source": [
    "### Example dataset\n",
    "\n",
    "We use the `AdvectionDiffusion` dataset as an example dataset to illustrate training and evaluation of models. This dataset simulates the advection-diffusion equation in 2D."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from autoemulate.simulations.advection_diffusion import AdvectionDiffusion\n",
    "\n",
    "sim = AdvectionDiffusion(return_timeseries=True, log_level=\"error\")\n",
    "\n",
    "def generate_split(\n",
    "    simulator: AdvectionDiffusion, n_train: int = 4, n_valid: int = 2, n_test: int = 2\n",
    "):\n",
    "    \"\"\"Generate training, validation, and test splits from the simulator.\"\"\"\n",
    "    train = simulator.forward_samples_spatiotemporal(n_train)\n",
    "    valid = simulator.forward_samples_spatiotemporal(n_valid)\n",
    "    test = simulator.forward_samples_spatiotemporal(n_test)\n",
    "    return {\"train\": train, \"valid\": valid, \"test\": test}\n",
    "\n",
    "\n",
    "combined_data = generate_split(sim)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3",
   "metadata": {},
   "source": [
    "### Read combined data into datamodule\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from auto_cast.data.datamodule import SpatioTemporalDataModule\n",
    "\n",
    "datamodule = SpatioTemporalDataModule(\n",
    "    data=combined_data, data_path=None, n_steps_input=4, n_steps_output=1, batch_size=16\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5",
   "metadata": {},
   "source": [
    "### Example batch\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = next(iter(datamodule.train_dataloader()))\n",
    "\n",
    "# batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from auto_cast.decoders.channels_last import ChannelsLast\n",
    "from auto_cast.encoders.permute_concat import PermuteConcat\n",
    "from auto_cast.models.encoder_decoder import EncoderDecoder\n",
    "from auto_cast.models.encoder_processor_decoder import EncoderProcessorDecoder\n",
    "from auto_cast.nn.fno import FNOProcessor\n",
    "\n",
    "processor = FNOProcessor(\n",
    "    in_channels=1, out_channels=1, n_modes=(16, 16, 1), hidden_channels=64\n",
    ")\n",
    "encoder = PermuteConcat(with_constants=False)\n",
    "decoder = ChannelsLast()\n",
    "\n",
    "model = EncoderProcessorDecoder.from_encoder_processor_decoder(\n",
    "    encoder_decoder=EncoderDecoder(encoder=encoder, decoder=decoder),\n",
    "    processor=processor,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8",
   "metadata": {},
   "source": [
    "### Run trainer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import lightning as L\n",
    "\n",
    "device = \"mps\"  # \"cpu\"\n",
    "trainer = L.Trainer(max_epochs=5, accelerator=device, log_every_n_steps=10)\n",
    "trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "### Run the evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.test(model, datamodule.test_dataloader())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
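
The notebook assembles the model but never prints the tensor layouts involved, so here is a minimal standalone sketch of how a batch flows through the pipeline above. All sizes are illustrative assumptions, and a plain identity stands in for FNOProcessor so the snippet only needs torch and einops:

import torch
from einops import rearrange

# Assumed sizes: batch, input steps, width, height, channels.
b, t, w, h, c = 16, 4, 32, 32, 1
input_fields = torch.randn(b, t, w, h, c)

# PermuteConcat (with_constants=False) moves channels ahead of time and space.
x = rearrange(input_fields, "b t w h c -> b c t w h")
print(x.shape)  # torch.Size([16, 1, 4, 32, 32])

# The processor maps this channels-first tensor to one of the same layout;
# an identity stands in for FNOProcessor here.
y = x

# ChannelsLast restores the channels-last layout of the original fields.
out = rearrange(y, "b c t w h -> b t w h c")
print(out.shape)  # torch.Size([16, 4, 32, 32, 1])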

pyproject.toml

Lines changed: 13 additions & 0 deletions
@@ -8,9 +8,11 @@ authors = [
 ]
 requires-python = ">=3.11,<3.13"
 dependencies = [
+    "autoemulate>=1.2.0",
     "einops>=0.8.1",
     "h5py>=3.15.1",
     "lightning>=2.5.6",
+    "neuraloperator>=2.0.0",
     "the-well>=1.1.0",
     "torch>=2.9.1",
 ]
@@ -91,3 +93,14 @@ convention = "numpy"
 
 [tool.ruff.lint.per-file-ignores]
 "tests/*.py" = ["D"]
+
+[tool.uv.sources]
+autoemulate = { git = "https://github.com/alan-turing-institute/autoemulate.git" }
+
+[tool.pytest.ini_options]
+filterwarnings = [
+    # Ignore Lightning warnings that are expected/benign in test environment
+    "ignore:You are trying to `self.log\\(\\)` but the `self.trainer` reference is not registered:UserWarning",
+    "ignore:GPU available but not used:UserWarning",
+    "ignore:The '.*_dataloader' does not have many workers:UserWarning",
+]

src/auto_cast/decoders/base.py

Lines changed: 0 additions & 5 deletions
@@ -9,11 +9,6 @@
 class Decoder(nn.Module, ABC):
     """Base Decoder."""
 
-    def __init__(self, latent_dim: int, output_channels: int) -> None:
-        super().__init__()
-        self.latent_dim = latent_dim
-        self.output_channels = output_channels
-
     def decode(self, z: Tensor) -> Tensor:
         """Decode the latent tensor back to the original space.

src/auto_cast/decoders/channels_last.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
from einops import rearrange

from auto_cast.decoders.base import Decoder
from auto_cast.types import Tensor


class ChannelsLast(Decoder):
    """Decoder that rearranges latent output back to channels-last layout."""

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass through the ChannelsLast decoder."""
        return rearrange(x, "b c t w h -> b t w h c")
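
As a quick sanity check, a hedged usage sketch: calling the decoder on a random channels-first tensor (sizes here are arbitrary, not taken from the notebook) shows the permutation it performs.

import torch

from auto_cast.decoders.channels_last import ChannelsLast

decoder = ChannelsLast()
z = torch.randn(2, 3, 4, 8, 8)  # (batch, channels, time, width, height)
out = decoder(z)                # -> (batch, time, width, height, channels)
print(out.shape)                # torch.Size([2, 4, 8, 8, 3])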

src/auto_cast/encoders/base.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,6 @@
99
class Encoder(nn.Module, ABC):
1010
"""Base encoder."""
1111

12-
def __init__(self, latent_dim: int, input_channels: int) -> None:
13-
super().__init__()
14-
self.latent_dim = latent_dim
15-
self.input_channels = input_channels
16-
1712
def encode(self, x: Tensor) -> Tensor:
1813
"""Encode the input tensor into the latent space.
1914

src/auto_cast/encoders/permute_concat.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
import torch
from einops import rearrange

from auto_cast.encoders.base import Encoder
from auto_cast.types import Batch, Tensor


class PermuteConcat(Encoder):
    """Permute and concatenate Encoder."""

    def __init__(self, with_constants: bool = False) -> None:
        super().__init__()
        self.with_constants = with_constants

    def forward(self, batch: Batch) -> Tensor:
        # Destructure batch, time, space, channels
        b, t, w, h, _ = batch.input_fields.shape  # TODO: generalize beyond 2D spatial
        x = batch.input_fields
        x = rearrange(x, "b t w h c -> b c t w h")
        if self.with_constants and batch.constant_fields is not None:
            constants = batch.constant_fields
            constants = rearrange(constants, "b w h c -> b c 1 w h")
            x = torch.cat([x, constants], dim=1)
        if self.with_constants and batch.constant_scalars is not None:
            scalars = batch.constant_scalars
            scalars = rearrange(scalars, "b c -> b c 1 1 1")
            scalars = scalars.expand(b, -1, t, w, h)
            x = torch.cat([x, scalars], dim=1)
        return x
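
The constant-scalar branch is the least obvious part of this encoder: each scalar is broadcast across every time step and grid point, then stacked as extra channels. The sketch below emulates that branch on plain tensors; the Batch container is omitted and all sizes are illustrative assumptions.

import torch
from einops import rearrange

b, t, w, h = 2, 4, 8, 8
fields = torch.randn(b, t, w, h, 1)   # one physical channel
scalars = torch.randn(b, 3)           # e.g. three constant parameters per sample

x = rearrange(fields, "b t w h c -> b c t w h")
s = rearrange(scalars, "b c -> b c 1 1 1").expand(b, -1, t, w, h)
x = torch.cat([x, s], dim=1)          # scalars become extra channels
print(x.shape)                        # torch.Size([2, 4, 4, 8, 8]), i.e. (b, 1 + 3, t, w, h)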

src/auto_cast/models/encoder_decoder.py

Lines changed: 17 additions & 10 deletions
@@ -1,5 +1,3 @@
-from typing import Any
-
 import lightning as L
 import torch
 from torch import nn
@@ -14,16 +12,25 @@ class EncoderDecoder(L.LightningModule):
 
     encoder: Encoder
     decoder: Decoder
-    loss_func: nn.Module
+    loss_func: nn.Module | None
 
-    def __init__(self):
-        pass
+    def __init__(
+        self, encoder: Encoder, decoder: Decoder, loss_func: nn.Module | None = None
+    ) -> None:
+        super().__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.loss_func = loss_func
 
-    def forward(self, *args: Any, **kwargs: Any) -> Any:
-        return self.decoder(self.encoder(*args, **kwargs))
+    def forward(self, batch: Batch) -> Tensor:
+        return self.decoder(self.encoder(batch))
 
     def training_step(self, batch: Batch, batch_idx: int) -> Tensor:  # noqa: ARG002
-        output = self.encode(batch)
+        if self.loss_func is None:
+            msg = "Loss function not defined for EncoderDecoder model."
+            raise ValueError(msg)
+        x = self.encode(batch)
+        output = self.decoder(x)
         loss = self.loss_func(output, batch.output_fields)
         return loss  # noqa: RET504
 
@@ -43,8 +50,8 @@ def configure_optmizers(self):
 class VAE(EncoderDecoder):
     """Variational Autoencoder Model."""
 
-    def forward(self, x: Tensor) -> Tensor:
-        mu, log_var = self.encoder(x)
+    def forward(self, batch: Batch) -> Tensor:
+        mu, log_var = self.encoder(batch)
         z = self.reparametrize(mu, log_var)
         x = self.decoder(z)
         return x  # noqa: RET504
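
With the new constructor, an EncoderDecoder is assembled directly from its parts. A minimal sketch, pairing the encoder and decoder from this PR with an MSE loss (an illustrative choice, not mandated by the code):

from torch import nn

from auto_cast.decoders.channels_last import ChannelsLast
from auto_cast.encoders.permute_concat import PermuteConcat
from auto_cast.models.encoder_decoder import EncoderDecoder

model = EncoderDecoder(
    encoder=PermuteConcat(with_constants=False),
    decoder=ChannelsLast(),
    loss_func=nn.MSELoss(),  # optional; training_step raises ValueError if left as None
)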
