Skip to content

Commit 2c31a32

Browse files
committed
Merge remote-tracking branch 'origin/main' into 45-encoder-decoder
2 parents ce109d4 + f5dfad9 commit 2c31a32

17 files changed

Lines changed: 1421 additions & 170 deletions

notebooks/00_exploration.ipynb

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "0",
6+
"metadata": {},
7+
"source": [
8+
"## AutoCast encoder-processor-decoder model API Exploration\n",
9+
"\n",
10+
"This notebook aims to explore the end-to-end API.\n"
11+
]
12+
},
13+
{
14+
"cell_type": "markdown",
15+
"id": "1",
16+
"metadata": {},
17+
"source": [
18+
"### Example dataaset\n",
19+
"\n",
20+
"We use the `AdvectionDiffusion` dataset as an example dataset to illustrate training and evaluation of models. This dataset simulates the advection-diffusion equation in 2D."
21+
]
22+
},
23+
{
24+
"cell_type": "code",
25+
"execution_count": null,
26+
"id": "2",
27+
"metadata": {},
28+
"outputs": [],
29+
"source": [
30+
"\n",
31+
"from autoemulate.simulations.reaction_diffusion import ReactionDiffusion as Sim\n",
32+
"\n",
33+
"sim = Sim(return_timeseries=True, log_level=\"error\")\n",
34+
"\n",
35+
"def generate_split(simulator: Sim, n_train: int = 1, n_valid: int = 1, n_test: int = 1):\n",
36+
" \"\"\"Generate training, validation, and test splits from the simulator.\"\"\"\n",
37+
" train = simulator.forward_samples_spatiotemporal(n_train)\n",
38+
" valid = simulator.forward_samples_spatiotemporal(n_valid)\n",
39+
" test = simulator.forward_samples_spatiotemporal(n_test)\n",
40+
" return {\"train\": train, \"valid\": valid, \"test\": test}\n",
41+
"\n",
42+
"\n",
43+
"combined_data = generate_split(sim)"
44+
]
45+
},
46+
{
47+
"cell_type": "markdown",
48+
"id": "3",
49+
"metadata": {},
50+
"source": [
51+
"### Read combined data into datamodule\n"
52+
]
53+
},
54+
{
55+
"cell_type": "code",
56+
"execution_count": null,
57+
"id": "4",
58+
"metadata": {},
59+
"outputs": [],
60+
"source": [
61+
"from auto_cast.data.datamodule import SpatioTemporalDataModule\n",
62+
"\n",
63+
"n_steps_input = 4\n",
64+
"n_steps_output = 1\n",
65+
"datamodule = SpatioTemporalDataModule(\n",
66+
" data=combined_data, data_path=None, n_steps_input=n_steps_input, n_steps_output=n_steps_output, batch_size=16\n",
67+
")"
68+
]
69+
},
70+
{
71+
"cell_type": "markdown",
72+
"id": "5",
73+
"metadata": {},
74+
"source": [
75+
"### Example batch\n"
76+
]
77+
},
78+
{
79+
"cell_type": "code",
80+
"execution_count": null,
81+
"id": "6",
82+
"metadata": {},
83+
"outputs": [],
84+
"source": [
85+
"batch = next(iter(datamodule.train_dataloader()))\n",
86+
"\n",
87+
"# batch"
88+
]
89+
},
90+
{
91+
"cell_type": "code",
92+
"execution_count": null,
93+
"id": "7",
94+
"metadata": {},
95+
"outputs": [],
96+
"source": [
97+
"from auto_cast.decoders.channels_last import ChannelsLast\n",
98+
"from auto_cast.encoders.permute_concat import PermuteConcat\n",
99+
"from auto_cast.models.encoder_decoder import EncoderDecoder\n",
100+
"from auto_cast.models.encoder_processor_decoder import EncoderProcessorDecoder\n",
101+
"from auto_cast.nn.fno import FNOProcessor\n",
102+
"\n",
103+
"batch = next(iter(datamodule.train_dataloader()))\n",
104+
"n_channels = batch.input_fields.shape[-1]\n",
105+
"processor = FNOProcessor(\n",
106+
" in_channels=n_channels * n_steps_input,\n",
107+
" out_channels=n_channels * n_steps_output,\n",
108+
" n_modes=(16, 16),\n",
109+
" hidden_channels=64,\n",
110+
")\n",
111+
"encoder = PermuteConcat(with_constants=False)\n",
112+
"decoder = ChannelsLast(output_channels=n_channels, time_steps=n_steps_output)\n",
113+
"\n",
114+
"model = EncoderProcessorDecoder.from_encoder_processor_decoder(\n",
115+
" encoder_decoder=EncoderDecoder(encoder=encoder, decoder=decoder),\n",
116+
" processor=processor,\n",
117+
")"
118+
]
119+
},
120+
{
121+
"cell_type": "code",
122+
"execution_count": null,
123+
"id": "8",
124+
"metadata": {},
125+
"outputs": [],
126+
"source": [
127+
"model(batch).shape"
128+
]
129+
},
130+
{
131+
"cell_type": "markdown",
132+
"id": "9",
133+
"metadata": {},
134+
"source": [
135+
"### Run trainer\n"
136+
]
137+
},
138+
{
139+
"cell_type": "code",
140+
"execution_count": null,
141+
"id": "10",
142+
"metadata": {},
143+
"outputs": [],
144+
"source": [
145+
"import lightning as L\n",
146+
"\n",
147+
"# device = \"mps\" # \"cpu\"\n",
148+
"device = \"cpu\"\n",
149+
"trainer = L.Trainer(max_epochs=1, accelerator=device, log_every_n_steps=10)\n",
150+
"trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())"
151+
]
152+
},
153+
{
154+
"cell_type": "markdown",
155+
"id": "11",
156+
"metadata": {},
157+
"source": [
158+
"### Run the evaluation"
159+
]
160+
},
161+
{
162+
"cell_type": "code",
163+
"execution_count": null,
164+
"id": "12",
165+
"metadata": {},
166+
"outputs": [],
167+
"source": [
168+
"trainer.test(model, datamodule.test_dataloader())"
169+
]
170+
},
171+
{
172+
"cell_type": "markdown",
173+
"id": "13",
174+
"metadata": {},
175+
"source": [
176+
"### Example rollout"
177+
]
178+
},
179+
{
180+
"cell_type": "code",
181+
"execution_count": null,
182+
"id": "14",
183+
"metadata": {},
184+
"outputs": [],
185+
"source": [
186+
"# A single element is the full trajectory\n",
187+
"batch = next(iter(datamodule.rollout_test_dataloader()))"
188+
]
189+
},
190+
{
191+
"cell_type": "code",
192+
"execution_count": null,
193+
"id": "15",
194+
"metadata": {},
195+
"outputs": [],
196+
"source": [
197+
"# First n_steps_input are inputs\n",
198+
"print(batch.input_fields.shape)\n",
199+
"# Remaining n_steps_output are outputs\n",
200+
"print(batch.output_fields.shape)"
201+
]
202+
},
203+
{
204+
"cell_type": "code",
205+
"execution_count": null,
206+
"id": "16",
207+
"metadata": {},
208+
"outputs": [],
209+
"source": [
210+
"# Run rollout on one trajectory\n",
211+
"preds, trues = model.rollout(batch)"
212+
]
213+
},
214+
{
215+
"cell_type": "code",
216+
"execution_count": null,
217+
"id": "17",
218+
"metadata": {},
219+
"outputs": [],
220+
"source": [
221+
"print(preds.shape)"
222+
]
223+
},
224+
{
225+
"cell_type": "code",
226+
"execution_count": null,
227+
"id": "18",
228+
"metadata": {},
229+
"outputs": [],
230+
"source": [
231+
"print(trues.shape)\n"
232+
]
233+
}
234+
],
235+
"metadata": {
236+
"kernelspec": {
237+
"display_name": ".venv",
238+
"language": "python",
239+
"name": "python3"
240+
},
241+
"language_info": {
242+
"codemirror_mode": {
243+
"name": "ipython",
244+
"version": 3
245+
},
246+
"file_extension": ".py",
247+
"mimetype": "text/x-python",
248+
"name": "python",
249+
"nbconvert_exporter": "python",
250+
"pygments_lexer": "ipython3",
251+
"version": "3.12.12"
252+
}
253+
},
254+
"nbformat": 4,
255+
"nbformat_minor": 5
256+
}

pyproject.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ dependencies = [
1414
"h5py>=3.15.1",
1515
"jaxtyping>=0.3.3",
1616
"lightning>=2.5.6",
17+
"neuraloperator>=2.0.0",
1718
"the-well>=1.1.0",
1819
"torch>=2.9.1",
1920
]
@@ -95,3 +96,14 @@ convention = "numpy"
9596

9697
[tool.ruff.lint.per-file-ignores]
9798
"tests/*.py" = ["D"]
99+
100+
[tool.uv.sources]
101+
autoemulate = { git = "https://github.com/alan-turing-institute/autoemulate.git" }
102+
103+
[tool.pytest.ini_options]
104+
filterwarnings = [
105+
# Ignore Lightning warnings that are expected/benign in test environment
106+
"ignore:You are trying to `self.log\\(\\)` but the `self.trainer` reference is not registered:UserWarning",
107+
"ignore:GPU available but not used:UserWarning",
108+
"ignore:The '.*_dataloader' does not have many workers:UserWarning",
109+
]
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from einops import rearrange
2+
3+
from auto_cast.decoders.base import Decoder
4+
from auto_cast.types import Tensor
5+
6+
7+
class ChannelsLast(Decoder):
    """Undo the (channel*time) merge and move channels to the last axis.

    Inverse of an encoder that folded C channels over T time steps into a
    single (C*T) axis: restores separate time and channel axes and emits
    the channels-last layout expected downstream.
    """

    def __init__(self, output_channels: int, time_steps: int = 1) -> None:
        """Initialize the ChannelsLast decoder.

        Parameters
        ----------
        output_channels: int
            Number of channels (C) to recover from the merged axis.
        time_steps: int
            Number of time steps (T) that were folded into the channel
            axis during encoding.
        """
        super().__init__()
        self.output_channels = output_channels
        self.time_steps = time_steps

    def forward(self, x: Tensor) -> Tensor:
        """Map a (B, C*T, W, H) input to a (B, T, W, H, C) output.

        The axis split and the reorder are expressed as a single einops
        pattern: "(c t)" is unpacked using the configured sizes and the
        channel axis is moved last in one pass.
        """
        return rearrange(
            x,
            "b (c t) w h -> b t w h c",
            c=self.output_channels,
            t=self.time_steps,
        )
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import torch
2+
from einops import rearrange
3+
4+
from auto_cast.encoders.base import Encoder
5+
from auto_cast.types import Batch, Tensor
6+
7+
8+
class PermuteConcat(Encoder):
    """Permute and concatenate Encoder.

    Moves the batch's channels-last fields to channels-first, optionally
    appends constant fields/scalars as additional channels, and merges the
    channel and time axes into a single (C*T) axis for channels-first
    processors.
    """

    def __init__(self, with_constants: bool = False) -> None:
        """Initialize the encoder.

        Parameters
        ----------
        with_constants: bool
            If True, append ``batch.constant_fields`` and
            ``batch.constant_scalars`` (when present) as extra channels.
        """
        super().__init__()
        self.with_constants = with_constants

    def forward(self, batch: Batch) -> Tensor:
        """Encode a batch into a (B, C*T, W, H) tensor.

        Assumes ``batch.input_fields`` is shaped (B, T, W, H, C).
        """
        # Destructure batch, time, space, channels
        b, t, w, h, _ = batch.input_fields.shape  # TODO: generalize beyond 2D spatial
        x = batch.input_fields
        x = rearrange(x, "b t w h c -> b c t w h")
        if self.with_constants and batch.constant_fields is not None:
            constants = batch.constant_fields
            constants = rearrange(constants, "b w h c -> b c 1 w h")
            # Broadcast the static fields across the time axis so the channel
            # concatenation is shape-compatible: torch.cat requires all
            # non-cat dims to match, and the previous singleton time dim only
            # worked when t == 1 (scalars below were already expanded).
            constants = constants.expand(b, -1, t, w, h)
            x = torch.cat([x, constants], dim=1)
        if self.with_constants and batch.constant_scalars is not None:
            scalars = batch.constant_scalars
            scalars = rearrange(scalars, "b c -> b c 1 1 1")
            scalars = scalars.expand(b, -1, t, w, h)
            x = torch.cat([x, scalars], dim=1)
        return rearrange(x, "b c t w h -> b (c t) w h")

src/auto_cast/models/encoder_decoder.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ def forward_with_latent(self, batch: Batch) -> tuple[TensorBTSPlusC, TensorBMSta
2727
return decoded, encoded
2828

2929
def training_step(self, batch: Batch, batch_idx: int) -> Tensor: # noqa: ARG002
30-
output = self(batch)
30+
if self.loss_func is None:
31+
msg = "Loss function not defined for EncoderDecoder model."
32+
raise ValueError(msg)
33+
x = self(batch)
34+
output = self.decoder(x)
3135
loss = self.loss_func(output, batch.output_fields)
3236
self.log(
3337
"train_loss", loss, prog_bar=True, batch_size=batch.input_fields.shape[0]
@@ -54,3 +58,18 @@ def decode(self, z: TensorBMStarL) -> TensorBTSPlusC:
5458
def configure_optimizers(self):
    """Build the optimizer used by the Lightning trainer.

    Returns an Adam optimizer over all model parameters, using the
    model's configured learning rate.
    """
    optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
    return optimizer  # noqa: RET504
61+
62+
63+
class VAE(EncoderDecoder):
    """Variational Autoencoder Model."""

    def forward(self, batch: Batch) -> Tensor:
        """Encode the batch, sample a latent via reparametrization, decode."""
        mu, log_var = self.encoder(batch)
        latent = self.reparametrize(mu, log_var)
        return self.decoder(latent)

    def reparametrize(self, mu: Tensor, log_var: Tensor) -> Tensor:
        """Differentiable sample from N(mu, sigma^2): mu + sigma * eps."""
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return noise * sigma + mu

0 commit comments

Comments
 (0)